
Compare commits

..

4 Commits

Author SHA1 Message Date
David Robertson
4d343db081 Get rid of my home dir, whoops 2021-11-16 16:37:43 +00:00
David Robertson
a1367dcf8c Require networkx 2021-11-16 16:34:33 +00:00
David Robertson
9e361c8550 Changelog 2021-11-16 13:52:59 +00:00
David Robertson
51fec1a534 Commit hacky script to visualise store inheritance
Use e.g. with `scripts-dev/storage_inheritance.py DataStore --show`.
2021-11-16 13:51:50 +00:00
70 changed files with 612 additions and 750 deletions

View File

@@ -1,19 +1,3 @@
-Synapse 1.47.0 (2021-11-17)
-===========================
-
-No significant changes since 1.47.0rc3.
-
-
-Synapse 1.47.0rc3 (2021-11-16)
-==============================
-
-Bugfixes
---------
-
-- Fix a bug introduced in 1.47.0rc1 which caused worker processes to not halt startup in the presence of outstanding database migrations. ([\#11346](https://github.com/matrix-org/synapse/issues/11346))
-- Fix a bug introduced in 1.47.0rc1 which prevented the 'remove deleted devices from `device_inbox` column' background process from running when updating from a recent Synapse version. ([\#11303](https://github.com/matrix-org/synapse/issues/11303), [\#11353](https://github.com/matrix-org/synapse/issues/11353))
-
 Synapse 1.47.0rc2 (2021-11-10)
 ==============================

View File

@@ -1 +0,0 @@
-Add type annotations to `synapse.metrics`.

View File

@@ -1 +0,0 @@
-Add support for the `/_matrix/client/v3` APIs from Matrix v1.1.

View File

@@ -1 +0,0 @@
-Changed the word 'Home server' to the single word 'homeserver' in the documentation.

View File

@@ -1 +0,0 @@
-Add type hints to `synapse.util`.

View File

@@ -1 +0,0 @@
-Improve type annotations in Synapse's test suite.

View File

@@ -1 +0,0 @@
-Fix a bug, introduced in Synapse 1.46.0, which caused the `check_3pid_auth` and `on_logged_out` callbacks in legacy password authentication provider modules to not be registered. Modules using the generic module API were not affected.

View File

@@ -1 +0,0 @@
-Add admin API to un-shadow-ban a user.

View File

@@ -1 +0,0 @@
-Fix a bug introduced in 1.41.0 where space hierarchy responses would be incorrectly reused if multiple users were to make the same request at the same time.

changelog.d/11357.misc Normal file
View File

@@ -0,0 +1 @@
+Add a development script for visualising the storage class inheritance hierarchy.

View File

@@ -1 +0,0 @@
-Require all files in synapse/ and tests/ to pass mypy unless specifically excluded.

View File

@@ -1 +0,0 @@
-Fix running `scripts-dev/complement.sh`, which was broken in v1.47.0rc1.

View File

@@ -1 +0,0 @@
-Rename `get_access_token_for_user_id` to `create_access_token_for_user_id` to better reflect what it does.

View File

@@ -1 +0,0 @@
-Add support for the `/_matrix/media/v3` APIs from Matrix v1.1.

View File

@@ -1 +0,0 @@
-Trim redundant DataStore inheritance.

debian/changelog vendored
View File

@@ -1,15 +1,3 @@
-matrix-synapse-py3 (1.47.0) stable; urgency=medium
-
-  * New synapse release 1.47.0.
-
- -- Synapse Packaging team <packages@matrix.org>  Wed, 17 Nov 2021 13:09:43 +0000
-
-matrix-synapse-py3 (1.47.0~rc3) stable; urgency=medium
-
-  * New synapse release 1.47.0~rc3.
-
- -- Synapse Packaging team <packages@matrix.org>  Tue, 16 Nov 2021 14:32:47 +0000
-
 matrix-synapse-py3 (1.47.0~rc2) stable; urgency=medium
 
   [ Dan Callahan ]

View File

@@ -48,7 +48,7 @@ WORKERS_CONFIG = {
         "app": "synapse.app.user_dir",
         "listener_resources": ["client"],
         "endpoint_patterns": [
-            "^/_matrix/client/(api/v1|r0|v3|unstable)/user_directory/search$"
+            "^/_matrix/client/(api/v1|r0|unstable)/user_directory/search$"
         ],
         "shared_extra_conf": {"update_user_directory": False},
         "worker_extra_conf": "",
@@ -85,10 +85,10 @@ WORKERS_CONFIG = {
         "app": "synapse.app.generic_worker",
         "listener_resources": ["client"],
         "endpoint_patterns": [
-            "^/_matrix/client/(v2_alpha|r0|v3)/sync$",
-            "^/_matrix/client/(api/v1|v2_alpha|r0|v3)/events$",
-            "^/_matrix/client/(api/v1|r0|v3)/initialSync$",
-            "^/_matrix/client/(api/v1|r0|v3)/rooms/[^/]+/initialSync$",
+            "^/_matrix/client/(v2_alpha|r0)/sync$",
+            "^/_matrix/client/(api/v1|v2_alpha|r0)/events$",
+            "^/_matrix/client/(api/v1|r0)/initialSync$",
+            "^/_matrix/client/(api/v1|r0)/rooms/[^/]+/initialSync$",
         ],
         "shared_extra_conf": {},
         "worker_extra_conf": "",
@@ -146,11 +146,11 @@ WORKERS_CONFIG = {
         "app": "synapse.app.generic_worker",
         "listener_resources": ["client"],
         "endpoint_patterns": [
-            "^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/redact",
-            "^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/send",
-            "^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/(join|invite|leave|ban|unban|kick)$",
-            "^/_matrix/client/(api/v1|r0|v3|unstable)/join/",
-            "^/_matrix/client/(api/v1|r0|v3|unstable)/profile/",
+            "^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/redact",
+            "^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/send",
+            "^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/(join|invite|leave|ban|unban|kick)$",
+            "^/_matrix/client/(api/v1|r0|unstable)/join/",
+            "^/_matrix/client/(api/v1|r0|unstable)/profile/",
         ],
         "shared_extra_conf": {},
         "worker_extra_conf": "",
@@ -158,7 +158,7 @@ WORKERS_CONFIG = {
     "frontend_proxy": {
         "app": "synapse.app.frontend_proxy",
         "listener_resources": ["client", "replication"],
-        "endpoint_patterns": ["^/_matrix/client/(api/v1|r0|v3|unstable)/keys/upload"],
+        "endpoint_patterns": ["^/_matrix/client/(api/v1|r0|unstable)/keys/upload"],
        "shared_extra_conf": {},
         "worker_extra_conf": (
             "worker_main_http_uri: http://127.0.0.1:%d"

View File

@@ -948,7 +948,7 @@ The following fields are returned in the JSON response body:
 See also the
 [Client-Server API Spec on pushers](https://matrix.org/docs/spec/client_server/latest#get-matrix-client-r0-pushers).
 
-## Controlling whether a user is shadow-banned
+## Shadow-banning users
 
 Shadow-banning is a useful tool for moderating malicious or egregiously abusive users.
 A shadow-banned user receives successful responses to their client-server API requests,
@@ -961,22 +961,16 @@ or broken behaviour for the client. A shadow-banned user will not receive any
 notification and it is generally more appropriate to ban or kick abusive users.
 A shadow-banned user will be unable to contact anyone on the server.
 
-To shadow-ban a user the API is:
+The API is:
 
 ```
 POST /_synapse/admin/v1/users/<user_id>/shadow_ban
 ```
 
-To un-shadow-ban a user the API is:
-
-```
-DELETE /_synapse/admin/v1/users/<user_id>/shadow_ban
-```
-
 To use it, you will need to authenticate by providing an `access_token` for a
 server admin: [Admin API](../usage/administration/admin_api)
 
-An empty JSON dict is returned in both cases.
+An empty JSON dict is returned.
 
 **Parameters**
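
For concreteness, a minimal sketch of calling the endpoint above from Python with `requests`; the server URL, user ID, and token are placeholders, not values from this changeset:

```python
from urllib.parse import quote

import requests

base = "https://synapse.example.com"  # placeholder homeserver URL
user_id = "@spammer:example.com"      # placeholder fully-qualified user ID

resp = requests.post(
    f"{base}/_synapse/admin/v1/users/{quote(user_id)}/shadow_ban",
    headers={"Authorization": "Bearer <admin access token>"},
)
resp.raise_for_status()
assert resp.json() == {}  # the API returns an empty JSON dict
```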

View File

@@ -23,7 +23,7 @@ Server with a domain specific API.
 1. **Messaging Layer**
 
-   This is what the rest of the homeserver hits to send messages, join rooms,
+   This is what the rest of the Home Server hits to send messages, join rooms,
    etc. It also allows you to register callbacks for when it gets notified by
    lower levels that e.g. a new message has been received.

View File

@@ -1,7 +1,7 @@
 <h2 style="color:red">
 This page of the Synapse documentation is now deprecated. For up to date
 documentation on setting up or writing a password auth provider module, please see
-<a href="modules/index.md">this page</a>.
+<a href="modules.md">this page</a>.
 </h2>
 
 # Password auth provider modules

View File

@@ -1,12 +1,12 @@
 # Overview
 
-This document explains how to enable VoIP relaying on your homeserver with
+This document explains how to enable VoIP relaying on your Home Server with
 TURN.
 
-The synapse Matrix homeserver supports integration with TURN server via the
+The synapse Matrix Home Server supports integration with TURN server via the
 [TURN server REST API](<https://tools.ietf.org/html/draft-uberti-behave-turn-rest-00>). This
-allows the homeserver to generate credentials that are valid for use on the
-TURN server through the use of a secret shared between the homeserver and the
+allows the Home Server to generate credentials that are valid for use on the
+TURN server through the use of a secret shared between the Home Server and the
 TURN server.
 
 The following sections describe how to install [coturn](<https://github.com/coturn/coturn>) (which implements the TURN REST API) and integrate it with synapse.
@@ -171,10 +171,10 @@ Your homeserver configuration file needs the following extra keys:
    for your TURN server to be given out to your clients. Add separate
    entries for each transport your TURN server supports.
 2. "`turn_shared_secret`": This is the secret shared between your
-   homeserver and your TURN server, so you should set it to the same
+   Home server and your TURN server, so you should set it to the same
    string you used in turnserver.conf.
 3. "`turn_user_lifetime`": This is the amount of time credentials
-   generated by your homeserver are valid for (in milliseconds).
+   generated by your Home Server are valid for (in milliseconds).
    Shorter times offer less potential for abuse at the expense of
    increased traffic between web clients and your home server to
    refresh credentials. The TURN REST API specification recommends
@@ -220,7 +220,7 @@ Here are a few things to try:
   anyone who has successfully set this up.
 
 * Check that you have opened your firewall to allow TCP and UDP traffic to the
-  TURN ports (normally 3478 and 5349).
+  TURN ports (normally 3478 and 5479).
 
 * Check that you have opened your firewall to allow UDP traffic to the UDP
   relay ports (49152-65535 by default).
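
As background for the `turn_shared_secret` mechanism above, here is a sketch of the credential scheme from the TURN REST API draft: the username encodes an expiry timestamp and the password is an HMAC-SHA1 of that username under the shared secret. This illustrates the draft, not Synapse's or coturn's actual code:

```python
import base64
import hashlib
import hmac
import time

def turn_credentials(shared_secret: str, user: str, lifetime_ms: int):
    # Username is "<expiry unix time>:<user part>"; password is
    # base64(HMAC-SHA1(shared_secret, username)), per the REST API draft.
    expiry = int(time.time()) + lifetime_ms // 1000
    username = f"{expiry}:{user}"
    digest = hmac.new(shared_secret.encode(), username.encode(), hashlib.sha1)
    return username, base64.b64encode(digest.digest()).decode()
```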

View File

@@ -182,10 +182,10 @@ This worker can handle API requests matching the following regular
 expressions:
 
     # Sync requests
-    ^/_matrix/client/(v2_alpha|r0|v3)/sync$
-    ^/_matrix/client/(api/v1|v2_alpha|r0|v3)/events$
-    ^/_matrix/client/(api/v1|r0|v3)/initialSync$
-    ^/_matrix/client/(api/v1|r0|v3)/rooms/[^/]+/initialSync$
+    ^/_matrix/client/(v2_alpha|r0)/sync$
+    ^/_matrix/client/(api/v1|v2_alpha|r0)/events$
+    ^/_matrix/client/(api/v1|r0)/initialSync$
+    ^/_matrix/client/(api/v1|r0)/rooms/[^/]+/initialSync$
 
     # Federation requests
     ^/_matrix/federation/v1/event/
@@ -216,40 +216,40 @@ expressions:
     ^/_matrix/federation/v1/send/
 
     # Client API requests
-    ^/_matrix/client/(api/v1|r0|v3|unstable)/createRoom$
-    ^/_matrix/client/(api/v1|r0|v3|unstable)/publicRooms$
-    ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/joined_members$
-    ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/context/.*$
-    ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/members$
-    ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/state$
+    ^/_matrix/client/(api/v1|r0|unstable)/createRoom$
+    ^/_matrix/client/(api/v1|r0|unstable)/publicRooms$
+    ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/joined_members$
+    ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/context/.*$
+    ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/members$
+    ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/state$
     ^/_matrix/client/unstable/org.matrix.msc2946/rooms/.*/spaces$
     ^/_matrix/client/unstable/org.matrix.msc2946/rooms/.*/hierarchy$
     ^/_matrix/client/unstable/im.nheko.summary/rooms/.*/summary$
-    ^/_matrix/client/(api/v1|r0|v3|unstable)/account/3pid$
-    ^/_matrix/client/(api/v1|r0|v3|unstable)/devices$
-    ^/_matrix/client/(api/v1|r0|v3|unstable)/keys/query$
-    ^/_matrix/client/(api/v1|r0|v3|unstable)/keys/changes$
+    ^/_matrix/client/(api/v1|r0|unstable)/account/3pid$
+    ^/_matrix/client/(api/v1|r0|unstable)/devices$
+    ^/_matrix/client/(api/v1|r0|unstable)/keys/query$
+    ^/_matrix/client/(api/v1|r0|unstable)/keys/changes$
     ^/_matrix/client/versions$
-    ^/_matrix/client/(api/v1|r0|v3|unstable)/voip/turnServer$
-    ^/_matrix/client/(api/v1|r0|v3|unstable)/joined_groups$
-    ^/_matrix/client/(api/v1|r0|v3|unstable)/publicised_groups$
-    ^/_matrix/client/(api/v1|r0|v3|unstable)/publicised_groups/
-    ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/event/
-    ^/_matrix/client/(api/v1|r0|v3|unstable)/joined_rooms$
-    ^/_matrix/client/(api/v1|r0|v3|unstable)/search$
+    ^/_matrix/client/(api/v1|r0|unstable)/voip/turnServer$
+    ^/_matrix/client/(api/v1|r0|unstable)/joined_groups$
+    ^/_matrix/client/(api/v1|r0|unstable)/publicised_groups$
+    ^/_matrix/client/(api/v1|r0|unstable)/publicised_groups/
+    ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/event/
+    ^/_matrix/client/(api/v1|r0|unstable)/joined_rooms$
+    ^/_matrix/client/(api/v1|r0|unstable)/search$
 
     # Registration/login requests
-    ^/_matrix/client/(api/v1|r0|v3|unstable)/login$
-    ^/_matrix/client/(r0|v3|unstable)/register$
+    ^/_matrix/client/(api/v1|r0|unstable)/login$
+    ^/_matrix/client/(r0|unstable)/register$
     ^/_matrix/client/unstable/org.matrix.msc3231/register/org.matrix.msc3231.login.registration_token/validity$
 
     # Event sending requests
-    ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/redact
-    ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/send
-    ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/state/
-    ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/(join|invite|leave|ban|unban|kick)$
-    ^/_matrix/client/(api/v1|r0|v3|unstable)/join/
-    ^/_matrix/client/(api/v1|r0|v3|unstable)/profile/
+    ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/redact
+    ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/send
+    ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/state/
+    ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/(join|invite|leave|ban|unban|kick)$
+    ^/_matrix/client/(api/v1|r0|unstable)/join/
+    ^/_matrix/client/(api/v1|r0|unstable)/profile/
 
 Additionally, the following REST endpoints can be handled for GET requests:
@@ -261,14 +261,14 @@ room must be routed to the same instance. Additionally, care must be taken to
 ensure that the purge history admin API is not used while pagination requests
 for the room are in flight:
 
-    ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/messages$
+    ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/messages$
 
 Additionally, the following endpoints should be included if Synapse is configured
 to use SSO (you only need to include the ones for whichever SSO provider you're
 using):
 
     # for all SSO providers
-    ^/_matrix/client/(api/v1|r0|v3|unstable)/login/sso/redirect
+    ^/_matrix/client/(api/v1|r0|unstable)/login/sso/redirect
     ^/_synapse/client/pick_idp$
     ^/_synapse/client/pick_username
     ^/_synapse/client/new_user_consent$
@@ -281,7 +281,7 @@ using):
     ^/_synapse/client/saml2/authn_response$
 
     # CAS requests.
-    ^/_matrix/client/(api/v1|r0|v3|unstable)/login/cas/ticket$
+    ^/_matrix/client/(api/v1|r0|unstable)/login/cas/ticket$
 
 Ensure that all SSO logins go to a single process.
 For multiple workers not handling the SSO endpoints properly, see
@@ -465,7 +465,7 @@ Note that if a reverse proxy is used, then `/_matrix/media/` must be routed for
 Handles searches in the user directory. It can handle REST endpoints matching
 the following regular expressions:
 
-    ^/_matrix/client/(api/v1|r0|v3|unstable)/user_directory/search$
+    ^/_matrix/client/(api/v1|r0|unstable)/user_directory/search$
 
 When using this worker you must also set `update_user_directory: False` in the
 shared configuration file to stop the main synapse running background
@@ -477,12 +477,12 @@ Proxies some frequently-requested client endpoints to add caching and remove
 load from the main synapse. It can handle REST endpoints matching the following
 regular expressions:
 
-    ^/_matrix/client/(api/v1|r0|v3|unstable)/keys/upload
+    ^/_matrix/client/(api/v1|r0|unstable)/keys/upload
 
 If `use_presence` is False in the homeserver config, it can also handle REST
 endpoints matching the following regular expressions:
 
-    ^/_matrix/client/(api/v1|r0|v3|unstable)/presence/[^/]+/status
+    ^/_matrix/client/(api/v1|r0|unstable)/presence/[^/]+/status
 
 This "stub" presence handler will pass through `GET` requests but make the
 `PUT` effectively a no-op.
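
To make the routing concrete, a hypothetical sketch of matching request paths against an (abbreviated) pattern list; the function name is illustrative and not part of Synapse:

```python
import re

# Two of the sync-worker patterns from the table above (branch-side versions).
SYNC_PATTERNS = [
    re.compile(r"^/_matrix/client/(v2_alpha|r0)/sync$"),
    re.compile(r"^/_matrix/client/(api/v1|r0)/initialSync$"),
]

def handled_by_sync_worker(path: str) -> bool:
    return any(p.match(path) for p in SYNC_PATTERNS)

assert handled_by_sync_worker("/_matrix/client/r0/sync")
assert not handled_by_sync_worker("/_matrix/client/r0/login")
```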

View File

@@ -160,9 +160,6 @@ disallow_untyped_defs = True
 [mypy-synapse.handlers.*]
 disallow_untyped_defs = True
 
-[mypy-synapse.metrics.*]
-disallow_untyped_defs = True
-
 [mypy-synapse.push.*]
 disallow_untyped_defs = True
 
@@ -199,11 +196,92 @@ disallow_untyped_defs = True
 [mypy-synapse.streams.*]
 disallow_untyped_defs = True
 
-[mypy-synapse.util.*]
+[mypy-synapse.util.batching_queue]
 disallow_untyped_defs = True
 
-[mypy-synapse.util.caches.treecache]
-disallow_untyped_defs = False
+[mypy-synapse.util.caches.cached_call]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.caches.dictionary_cache]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.caches.lrucache]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.caches.response_cache]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.caches.stream_change_cache]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.caches.ttl_cache]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.daemonize]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.file_consumer]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.frozenutils]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.hash]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.httpresourcetree]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.iterutils]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.linked_list]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.logcontext]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.logformatter]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.macaroons]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.manhole]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.module_loader]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.msisdn]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.patch_inline_callbacks]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.ratelimitutils]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.retryutils]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.rlimit]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.stringutils]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.templates]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.threepids]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.wheel_timer]
+disallow_untyped_defs = True
+
+[mypy-synapse.util.versionstring]
+disallow_untyped_defs = True
 
 [mypy-tests.handlers.test_user_directory]
 disallow_untyped_defs = True
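
As a reminder of what these sections enforce: with `disallow_untyped_defs = True`, mypy rejects any unannotated function in the matched modules. A toy illustration, not code from the tree:

```python
# mypy flags this under disallow_untyped_defs:
#     error: Function is missing a type annotation
def untyped(x):
    return x + 1

# Fully annotated, so it passes.
def typed(x: int) -> int:
    return x + 1
```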

View File

@@ -24,7 +24,7 @@
 set -e
 
 # Change to the repository root
-cd "$(dirname $0)/.."
+cd "$(dirname "$0")/.."
 
 # Check for a user-specified Complement checkout
 if [[ -z "$COMPLEMENT_DIR" ]]; then
@@ -61,8 +61,8 @@ cd "$COMPLEMENT_DIR"
 EXTRA_COMPLEMENT_ARGS=""
 if [[ -n "$1" ]]; then
   # A test name regex has been set, supply it to Complement
-  EXTRA_COMPLEMENT_ARGS+="-run $1 "
+  EXTRA_COMPLEMENT_ARGS=(-run "$1")
 fi
 
 # Run the tests!
-go test -v -tags synapse_blacklist,msc2946,msc3083,msc2403,msc2716 -count=1 $EXTRA_COMPLEMENT_ARGS ./tests/...
+go test -v -tags synapse_blacklist,msc2946,msc3083,msc2403,msc2716 -count=1 "${EXTRA_COMPLEMENT_ARGS[@]}" ./tests/...
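
The change above swaps string concatenation for a bash array so the test name survives as a single argument even if it contains spaces. The same principle, sketched in Python with `subprocess` (command and test name are placeholders):

```python
import subprocess

test_name = "TestRoomCreate"  # placeholder test regex
extra_args = []
if test_name:
    extra_args += ["-run", test_name]  # one argv entry per argument

# Each list element reaches `go` as exactly one argument; no word-splitting.
subprocess.run(["go", "test", "-count=1", *extra_args, "./tests/..."], check=True)
```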

View File

@@ -0,0 +1,179 @@
+#! /usr/bin/env python3
+
+import argparse
+import os
+import re
+import subprocess
+import sys
+import tempfile
+from typing import Iterable, Optional, Set
+
+import networkx
+
+
+def scrape_storage_classes() -> str:
+    """Grep the source tree for classes ending with "Store"; extract their parents.
+
+    Returns the stdout from `rg` as a single string."""
+    # TODO: this is a big hack which assumes that each Store class has a unique name.
+    # That assumption is wrong: there are two DirectoryStores, one in
+    # synapse/replication/slave/storage/directory.py and the other in
+    # synapse/storage/databases/main/directory.py
+    # Would be nice to have a way to account for this.
+    return subprocess.check_output(
+        [
+            "rg",
+            "-o",
+            "--no-line-number",
+            "--no-filename",
+            "--multiline",
+            r"class .*Store\((.|\n)*?\):$",
+            "synapse",
+            "tests",
+        ],
+    ).decode()
+
+
+oneline_class_pattern = re.compile(r"^class (.*)\((.*)\):$")
+opening_class_pattern = re.compile(r"^class (.*)\($")
+
+
+def load_graph(lines: Iterable[str]) -> networkx.DiGraph:
+    """Process the output of scrape_storage_classes to build an inheritance graph.
+
+    Every time a class C is created that explicitly inherits from a parent P, we
+    add an edge C -> P.
+    """
+    G = networkx.DiGraph()
+
+    child: Optional[str] = None
+    for line in lines:
+        line = line.strip()
+        if not line or line.startswith("#"):
+            continue
+        if (match := oneline_class_pattern.match(line)) is not None:
+            child, parents = match.groups()
+            for parent in parents.split(", "):
+                if "metaclass" not in parent:
+                    G.add_edge(child, parent)
+            child = None
+        elif (match := opening_class_pattern.match(line)) is not None:
+            (child,) = match.groups()
+        elif line == "):":
+            child = None
+        else:
+            assert child is not None, repr(line)
+            parent = line.strip(",")
+            if "metaclass" not in parent:
+                G.add_edge(child, parent)
+
+    return G
+
+
+def select_vertices_of_interest(G: networkx.DiGraph, target: Optional[str]) -> Set[str]:
+    """Find all nodes we want to visualise.
+
+    If no TARGET is given, we visualise all of G. Otherwise we visualise a given
+    TARGET, its parents, and all of their parents recursively.
+
+    Requires that G is a DAG. If not None, the TARGET must belong to G.
+    """
+    assert networkx.is_directed_acyclic_graph(G)
+    if target is not None:
+        component: Set[str] = networkx.descendants(G, target)
+        component.add(target)
+    else:
+        component = set(G.nodes)
+    return component
+
+
+def generate_dot_source(G: networkx.DiGraph, nodes: Set[str]) -> str:
+    output = """\
+strict digraph {
+rankdir="LR";
+node [shape=box];
+"""
+    for (child, parent) in G.edges:
+        if child in nodes and parent in nodes:
+            output += f" {child} -> {parent};\n"
+    output += "}\n"
+    return output
+
+
+def render_png(dot_source: str, destination: Optional[str]) -> str:
+    if destination is None:
+        handle, destination = tempfile.mkstemp()
+        os.close(handle)
+        print("Warning: writing to", destination, "which will persist", file=sys.stderr)
+    subprocess.run(
+        [
+            "dot",
+            "-o",
+            destination,
+            "-Tpng",
+        ],
+        input=dot_source,
+        encoding="utf-8",
+        check=True,
+    )
+    return destination
+
+
+def show_graph(location: str) -> None:
+    subprocess.run(
+        ["xdg-open", location],
+        check=True,
+    )
+
+
+def main(parser: argparse.ArgumentParser, args: argparse.Namespace) -> int:
+    if not (args.output or args.show):
+        parser.print_help(file=sys.stderr)
+        print("Must either --output or --show, or both.", file=sys.stderr)
+        return os.EX_USAGE
+
+    lines = scrape_storage_classes().split("\n")
+    G = load_graph(lines)
+    nodes = select_vertices_of_interest(G, args.target)
+    dot_source = generate_dot_source(G, nodes)
+    output_location = render_png(dot_source, args.output)
+    if args.show:
+        show_graph(output_location)
+    return os.EX_OK
+
+
+def build_parser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser(
+        description="Visualise the inheritance of Synapse's storage classes. Requires "
+        "ripgrep (https://github.com/BurntSushi/ripgrep) as 'rg'; graphviz "
+        "(https://graphviz.org/) for the 'dot' program; and networkx "
+        "(https://networkx.org/). Requires Python 3.8+ for the walrus "
+        "operator."
+    )
+    parser.add_argument(
+        "target",
+        nargs="?",
+        help="Show only TARGET and its ancestors. Otherwise, show the entire hierarchy.",
+    )
+    parser.add_argument(
+        "--output",
+        help="Render inheritance graph to a png file.",
+    )
+    parser.add_argument(
+        "--show",
+        action="store_true",
+        help="Open the inheritance graph in an image viewer.",
+    )
+    return parser
+
+
+if __name__ == "__main__":
+    parser = build_parser()
+    args = parser.parse_args()
+    sys.exit(main(parser, args))
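
The heart of the script is a plain networkx reachability query. A self-contained sketch with toy edges (the real edges come from `scrape_storage_classes`):

```python
import networkx

G = networkx.DiGraph()
# Edges point child -> parent, so descendants() walks up to the ancestors.
G.add_edge("DataStore", "RoomStore")
G.add_edge("RoomStore", "SQLBaseStore")

print(sorted(networkx.descendants(G, "DataStore")))
# ['RoomStore', 'SQLBaseStore']
```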

View File

@@ -135,6 +135,8 @@ CONDITIONAL_REQUIREMENTS["dev"] = (
         # The following are executed as commands by the release script.
         "twine",
         "towncrier",
+        # For the storage_inheritance script
+        "networkx==2.6.3",
     ]
 )

View File

@@ -47,7 +47,7 @@ try:
 except ImportError:
     pass
 
-__version__ = "1.47.0"
+__version__ = "1.47.0rc2"
 
 if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
     # We import here so that we don't have to install a bunch of deps when

View File

@@ -30,8 +30,7 @@ FEDERATION_UNSTABLE_PREFIX = FEDERATION_PREFIX + "/unstable"
 STATIC_PREFIX = "/_matrix/static"
 WEB_CLIENT_PREFIX = "/_matrix/client"
 SERVER_KEY_V2_PREFIX = "/_matrix/key/v2"
-MEDIA_R0_PREFIX = "/_matrix/media/r0"
-MEDIA_V3_PREFIX = "/_matrix/media/v3"
+MEDIA_PREFIX = "/_matrix/media/r0"
 LEGACY_MEDIA_PREFIX = "/_matrix/media/v1"

View File

@@ -402,7 +402,7 @@ async def start(hs: "HomeServer") -> None:
     if hasattr(signal, "SIGHUP"):
 
         @wrap_as_background_process("sighup")
-        async def handle_sighup(*args: Any, **kwargs: Any) -> None:
+        def handle_sighup(*args: Any, **kwargs: Any) -> None:
             # Tell systemd our state, if we're using it. This will silently fail if
             # we're not using systemd.
             sdnotify(b"RELOADING=1")

View File

@@ -28,11 +28,13 @@ from synapse.config.homeserver import HomeServerConfig
 from synapse.config.logger import setup_logging
 from synapse.events import EventBase
 from synapse.handlers.admin import ExfiltrationWriter
+from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
 from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
 from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
 from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
 from synapse.replication.slave.storage.devices import SlavedDeviceStore
+from synapse.replication.slave.storage.events import SlavedEventStore
 from synapse.replication.slave.storage.filtering import SlavedFilteringStore
 from synapse.replication.slave.storage.groups import SlavedGroupServerStore
 from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
@@ -57,7 +59,9 @@ class AdminCmdSlavedStore(
     SlavedDeviceInboxStore,
     SlavedDeviceStore,
     SlavedPushRuleStore,
+    SlavedEventStore,
     SlavedClientIpStore,
+    BaseSlavedStore,
     RoomWorkerStore,
 ):
     pass
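
For background on why the "Trim redundant DataStore inheritance" commit can drop bases like `BaseSlavedStore` from lists such as this one: C3 linearisation already includes any base that another parent inherits, so restating it does not change the MRO. A toy sketch with made-up names:

```python
class BaseSlaved: ...
class EventsStore(BaseSlaved): ...

class WithRedundantBase(EventsStore, BaseSlaved): ...
class WithoutRedundantBase(EventsStore): ...

# Identical resolution order after the class itself:
assert WithRedundantBase.__mro__[1:] == WithoutRedundantBase.__mro__[1:]
```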

View File

@@ -26,8 +26,7 @@ from synapse.api.urls import (
     CLIENT_API_PREFIX,
     FEDERATION_PREFIX,
     LEGACY_MEDIA_PREFIX,
-    MEDIA_R0_PREFIX,
-    MEDIA_V3_PREFIX,
+    MEDIA_PREFIX,
     SERVER_KEY_V2_PREFIX,
 )
 from synapse.app import _base
@@ -48,12 +47,14 @@ from synapse.http.site import SynapseRequest, SynapseSite
 from synapse.logging.context import LoggingContext
 from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy
 from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource
+from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
 from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
 from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
 from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
 from synapse.replication.slave.storage.devices import SlavedDeviceStore
 from synapse.replication.slave.storage.directory import DirectoryStore
+from synapse.replication.slave.storage.events import SlavedEventStore
 from synapse.replication.slave.storage.filtering import SlavedFilteringStore
 from synapse.replication.slave.storage.groups import SlavedGroupServerStore
 from synapse.replication.slave.storage.keys import SlavedKeyStore
@@ -235,6 +236,7 @@ class GenericWorkerSlavedStore(
     SlavedPusherStore,
     CensorEventsStore,
     ClientIpWorkerStore,
+    SlavedEventStore,
     SlavedKeyStore,
     RoomWorkerStore,
     DirectoryStore,
@@ -250,6 +252,7 @@ class GenericWorkerSlavedStore(
     TransactionWorkerStore,
     LockStore,
     SessionStore,
+    BaseSlavedStore,
 ):
     # Properties that multiple storage classes define. Tell mypy what the
     # expected type is.
@@ -335,8 +338,7 @@ class GenericWorkerServer(HomeServer):
         resources.update(
             {
-                MEDIA_R0_PREFIX: media_repo,
-                MEDIA_V3_PREFIX: media_repo,
+                MEDIA_PREFIX: media_repo,
                 LEGACY_MEDIA_PREFIX: media_repo,
                 "/_synapse/admin": admin_resource,
             }

View File

@@ -29,8 +29,7 @@ from synapse import events
 from synapse.api.urls import (
     FEDERATION_PREFIX,
     LEGACY_MEDIA_PREFIX,
-    MEDIA_R0_PREFIX,
-    MEDIA_V3_PREFIX,
+    MEDIA_PREFIX,
     SERVER_KEY_V2_PREFIX,
     STATIC_PREFIX,
     WEB_CLIENT_PREFIX,
@@ -194,7 +193,6 @@ class SynapseHomeServer(HomeServer):
             {
                 "/_matrix/client/api/v1": client_resource,
                 "/_matrix/client/r0": client_resource,
-                "/_matrix/client/v3": client_resource,
                 "/_matrix/client/unstable": client_resource,
                 "/_matrix/client/v2_alpha": client_resource,
                 "/_matrix/client/versions": client_resource,
@@ -246,11 +244,7 @@ class SynapseHomeServer(HomeServer):
         if self.config.server.enable_media_repo:
             media_repo = self.get_media_repository_resource()
             resources.update(
-                {
-                    MEDIA_R0_PREFIX: media_repo,
-                    MEDIA_V3_PREFIX: media_repo,
-                    LEGACY_MEDIA_PREFIX: media_repo,
-                }
+                {MEDIA_PREFIX: media_repo, LEGACY_MEDIA_PREFIX: media_repo}
             )
         elif name == "media":
             raise ConfigError(

View File

@@ -40,8 +40,6 @@ from typing import TYPE_CHECKING, Optional, Tuple
 
 from signedjson.sign import sign_json
 
-from twisted.internet.defer import Deferred
-
 from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.types import JsonDict, get_domain_from_id
@@ -168,7 +166,7 @@ class GroupAttestionRenewer:
         return {}
 
-    def _start_renew_attestations(self) -> "Deferred[None]":
+    def _start_renew_attestations(self) -> None:
         return run_as_background_process("renew_attestations", self._renew_attestations)
 
     async def _renew_attestations(self) -> None:

View File

@@ -793,7 +793,7 @@ class AuthHandler:
         ) = await self.get_refresh_token_for_user_id(
             user_id=existing_token.user_id, device_id=existing_token.device_id
         )
-        access_token = await self.create_access_token_for_user_id(
+        access_token = await self.get_access_token_for_user_id(
             user_id=existing_token.user_id,
             device_id=existing_token.device_id,
             valid_until_ms=valid_until_ms,
@@ -855,7 +855,7 @@ class AuthHandler:
         )
         return refresh_token, refresh_token_id
 
-    async def create_access_token_for_user_id(
+    async def get_access_token_for_user_id(
         self,
         user_id: str,
         device_id: Optional[str],
@@ -1828,6 +1828,13 @@ def load_single_legacy_password_auth_provider(
         logger.error("Error while initializing %r: %s", module, e)
         raise
 
+    # The known hooks. If a module implements a method whose name appears in
+    # this set, we'll want to register it.
+    password_auth_provider_methods = {
+        "check_3pid_auth",
+        "on_logged_out",
+    }
+
     # All methods that the module provides should be async, but this wasn't enforced
     # in the old module system, so we wrap them if needed
     def async_wrapper(f: Optional[Callable]) -> Optional[Callable[..., Awaitable]]:
@@ -1912,14 +1919,11 @@ def load_single_legacy_password_auth_provider(
         return run
 
-    # If the module has these methods implemented, then we pull them out
-    # and register them as hooks.
-    check_3pid_auth_hook: Optional[CHECK_3PID_AUTH_CALLBACK] = async_wrapper(
-        getattr(provider, "check_3pid_auth", None)
-    )
-    on_logged_out_hook: Optional[ON_LOGGED_OUT_CALLBACK] = async_wrapper(
-        getattr(provider, "on_logged_out", None)
-    )
+    # populate hooks with the implemented methods, wrapped with async_wrapper
+    hooks = {
+        hook: async_wrapper(getattr(provider, hook, None))
+        for hook in password_auth_provider_methods
+    }
 
     supported_login_types = {}
     # call get_supported_login_types and add that to the dict
@@ -1946,11 +1950,7 @@ def load_single_legacy_password_auth_provider(
     # need to use a tuple here for ("password",) not a list since lists aren't hashable
     auth_checkers[(LoginType.PASSWORD, ("password",))] = check_password
 
-    api.register_password_auth_provider_callbacks(
-        check_3pid_auth=check_3pid_auth_hook,
-        on_logged_out=on_logged_out_hook,
-        auth_checkers=auth_checkers,
-    )
+    api.register_password_auth_provider_callbacks(hooks, auth_checkers=auth_checkers)
 
 CHECK_3PID_AUTH_CALLBACK = Callable[

View File

@@ -819,7 +819,7 @@ class RegistrationHandler:
         )
 
         valid_until_ms = self.clock.time_msec() + self.access_token_lifetime
-        access_token = await self._auth_handler.create_access_token_for_user_id(
+        access_token = await self._auth_handler.get_access_token_for_user_id(
             user_id,
             device_id=registered_device_id,
             valid_until_ms=valid_until_ms,

View File

@@ -97,7 +97,7 @@ class RoomSummaryHandler:
         # If a user tries to fetch the same page multiple times in quick succession,
         # only process the first attempt and return its result to subsequent requests.
         self._pagination_response_cache: ResponseCache[
-            Tuple[str, str, bool, Optional[int], Optional[int], Optional[str]]
+            Tuple[str, bool, Optional[int], Optional[int], Optional[str]]
         ] = ResponseCache(
             hs.get_clock(),
             "get_room_hierarchy",
@@ -282,14 +282,7 @@ class RoomSummaryHandler:
         # This is due to the pagination process mutating internal state, attempting
         # to process multiple requests for the same page will result in errors.
         return await self._pagination_response_cache.wrap(
-            (
-                requester,
-                requested_room_id,
-                suggested_only,
-                max_depth,
-                limit,
-                from_token,
-            ),
+            (requested_room_id, suggested_only, max_depth, limit, from_token),
             self._get_room_hierarchy,
             requester,
             requested_room_id,
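
The tuple above is the cache key, which is what the "space hierarchy responses incorrectly reused" fix in the changelog is about: without the requester in the key, two users paginating the same room share one cached response. A toy sketch with a plain dict standing in for `ResponseCache` (names illustrative):

```python
from typing import Dict, Tuple

_cache: Dict[Tuple[str, str], dict] = {}

def get_hierarchy(requester: str, room_id: str) -> dict:
    key = (requester, room_id)  # drop `requester` and users share entries
    if key not in _cache:
        _cache[key] = {"room_id": room_id, "computed_for": requester}
    return _cache[key]
```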

View File

@@ -90,7 +90,7 @@ class FollowerTypingHandler:
             self.wheel_timer = WheelTimer(bucket_size=5000)
 
     @wrap_as_background_process("typing._handle_timeouts")
-    async def _handle_timeouts(self) -> None:
+    def _handle_timeouts(self) -> None:
         logger.debug("Checking for typing timeouts")
 
         now = self.clock.time_msec()

View File

@@ -20,25 +20,10 @@ import os
 import platform
 import threading
 import time
-from typing import (
-    Any,
-    Callable,
-    Dict,
-    Generic,
-    Iterable,
-    Mapping,
-    Optional,
-    Sequence,
-    Set,
-    Tuple,
-    Type,
-    TypeVar,
-    Union,
-    cast,
-)
+from typing import Callable, Dict, Iterable, Mapping, Optional, Tuple, Union
 
 import attr
-from prometheus_client import CollectorRegistry, Counter, Gauge, Histogram, Metric
+from prometheus_client import Counter, Gauge, Histogram
 from prometheus_client.core import (
     REGISTRY,
     CounterMetricFamily,
@@ -47,7 +32,6 @@ from prometheus_client.core import (
 )
 
 from twisted.internet import reactor
-from twisted.internet.base import ReactorBase
 from twisted.python.threadpool import ThreadPool
 
 import synapse
@@ -70,7 +54,7 @@ HAVE_PROC_SELF_STAT = os.path.exists("/proc/self/stat")
 class RegistryProxy:
     @staticmethod
-    def collect() -> Iterable[Metric]:
+    def collect():
         for metric in REGISTRY.collect():
             if not metric.name.startswith("__"):
                 yield metric
@@ -90,7 +74,7 @@ class LaterGauge:
         ]
     )
 
-    def collect(self) -> Iterable[Metric]:
+    def collect(self):
 
         g = GaugeMetricFamily(self.name, self.desc, labels=self.labels)
@@ -109,10 +93,10 @@ class LaterGauge:
         yield g
 
-    def __attrs_post_init__(self) -> None:
+    def __attrs_post_init__(self):
         self._register()
 
-    def _register(self) -> None:
+    def _register(self):
         if self.name in all_gauges.keys():
             logger.warning("%s already registered, reregistering" % (self.name,))
             REGISTRY.unregister(all_gauges.pop(self.name))
@@ -121,12 +105,7 @@ class LaterGauge:
         all_gauges[self.name] = self
 
-# `MetricsEntry` only makes sense when it is a `Protocol`,
-# but `Protocol` can't be used as a `TypeVar` bound.
-MetricsEntry = TypeVar("MetricsEntry")
-
-class InFlightGauge(Generic[MetricsEntry]):
+class InFlightGauge:
     """Tracks number of things (e.g. requests, Measure blocks, etc) in flight
     at any given time.
@@ -136,19 +115,14 @@ class InFlightGauge(Generic[MetricsEntry]):
     callbacks.
 
     Args:
-        name
-        desc
-        labels
-        sub_metrics: A list of sub metrics that the callbacks will update.
+        name (str)
+        desc (str)
+        labels (list[str])
+        sub_metrics (list[str]): A list of sub metrics that the callbacks
+            will update.
     """
 
-    def __init__(
-        self,
-        name: str,
-        desc: str,
-        labels: Sequence[str],
-        sub_metrics: Sequence[str],
-    ):
+    def __init__(self, name, desc, labels, sub_metrics):
         self.name = name
         self.desc = desc
         self.labels = labels
@@ -156,25 +130,19 @@ class InFlightGauge(Generic[MetricsEntry]):
         # Create a class which has the sub_metrics values as attributes, which
         # default to 0 on initialization. Used to pass to registered callbacks.
-        self._metrics_class: Type[MetricsEntry] = attr.make_class(
+        self._metrics_class = attr.make_class(
             "_MetricsEntry", attrs={x: attr.ib(0) for x in sub_metrics}, slots=True
         )
 
         # Counts number of in flight blocks for a given set of label values
-        self._registrations: Dict[
-            Tuple[str, ...], Set[Callable[[MetricsEntry], None]]
-        ] = {}
+        self._registrations: Dict = {}
 
         # Protects access to _registrations
         self._lock = threading.Lock()
 
         self._register_with_collector()
 
-    def register(
-        self,
-        key: Tuple[str, ...],
-        callback: Callable[[MetricsEntry], None],
-    ) -> None:
+    def register(self, key, callback):
         """Registers that we've entered a new block with labels `key`.
 
         `callback` gets called each time the metrics are collected. The same
@@ -190,17 +158,13 @@ class InFlightGauge(Generic[MetricsEntry]):
         with self._lock:
             self._registrations.setdefault(key, set()).add(callback)
 
-    def unregister(
-        self,
-        key: Tuple[str, ...],
-        callback: Callable[[MetricsEntry], None],
-    ) -> None:
+    def unregister(self, key, callback):
         """Registers that we've exited a block with labels `key`."""
 
         with self._lock:
             self._registrations.setdefault(key, set()).discard(callback)
 
-    def collect(self) -> Iterable[Metric]:
+    def collect(self):
         """Called by prometheus client when it reads metrics.
 
         Note: may be called by a separate thread.
@@ -236,7 +200,7 @@ class InFlightGauge(Generic[MetricsEntry]):
             gauge.add_metric(key, getattr(metrics, name))
         yield gauge
 
-    def _register_with_collector(self) -> None:
+    def _register_with_collector(self):
         if self.name in all_gauges.keys():
             logger.warning("%s already registered, reregistering" % (self.name,))
             REGISTRY.unregister(all_gauges.pop(self.name))
@@ -266,7 +230,7 @@ class GaugeBucketCollector:
         name: str,
         documentation: str,
         buckets: Iterable[float],
-        registry: CollectorRegistry = REGISTRY,
+        registry=REGISTRY,
     ):
         """
         Args:
@@ -293,12 +257,12 @@ class GaugeBucketCollector:
         registry.register(self)
 
-    def collect(self) -> Iterable[Metric]:
+    def collect(self):
 
         # Don't report metrics unless we've already collected some data
         if self._metric is not None:
             yield self._metric
 
-    def update_data(self, values: Iterable[float]) -> None:
+    def update_data(self, values: Iterable[float]):
         """Update the data to be reported by the metric
 
         The existing data is cleared, and each measurement in the input is assigned
@@ -340,7 +304,7 @@ class GaugeBucketCollector:
 class CPUMetrics:
-    def __init__(self) -> None:
+    def __init__(self):
         ticks_per_sec = 100
         try:
             # Try and get the system config
@@ -350,7 +314,7 @@ class CPUMetrics:
         self.ticks_per_sec = ticks_per_sec
 
-    def collect(self) -> Iterable[Metric]:
+    def collect(self):
         if not HAVE_PROC_SELF_STAT:
             return
@@ -400,7 +364,7 @@ gc_time = Histogram(
 class GCCounts:
-    def collect(self) -> Iterable[Metric]:
+    def collect(self):
         cm = GaugeMetricFamily("python_gc_counts", "GC object counts", labels=["gen"])
         for n, m in enumerate(gc.get_count()):
             cm.add_metric([str(n)], m)
@@ -418,7 +382,7 @@ if not running_on_pypy:
 class PyPyGCStats:
-    def collect(self) -> Iterable[Metric]:
+    def collect(self):
         # @stats is a pretty-printer object with __str__() returning a nice table,
         # plus some fields that contain data from that table.
@@ -601,7 +565,7 @@ def register_threadpool(name: str, threadpool: ThreadPool) -> None:
 class ReactorLastSeenMetric:
-    def collect(self) -> Iterable[Metric]:
+    def collect(self):
         cm = GaugeMetricFamily(
             "python_twisted_reactor_last_seen",
             "Seconds since the Twisted reactor was last seen",
@@ -620,12 +584,9 @@ MIN_TIME_BETWEEN_GCS = (1.0, 10.0, 30.0)
 _last_gc = [0.0, 0.0, 0.0]
 
-F = TypeVar("F", bound=Callable[..., Any])
-
-def runUntilCurrentTimer(reactor: ReactorBase, func: F) -> F:
+def runUntilCurrentTimer(reactor, func):
     @functools.wraps(func)
-    def f(*args: Any, **kwargs: Any) -> Any:
+    def f(*args, **kwargs):
         now = reactor.seconds()
         num_pending = 0
@@ -688,7 +649,7 @@ def runUntilCurrentTimer(reactor: ReactorBase, func: F) -> F:
         return ret
 
-    return cast(F, f)
+    return f
 
 try:
@@ -716,5 +677,5 @@ __all__ = [
     "start_http_server",
     "LaterGauge",
     "InFlightGauge",
-    "GaugeBucketCollector",
+    "BucketCollector",
 ]
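
For orientation, a minimal sketch of how the (right-hand, untyped) `InFlightGauge` above is used: callers register a callback per label-set, and each callback mutates the per-key metrics object when Prometheus scrapes. Assumes `synapse.metrics` is importable:

```python
from synapse.metrics import InFlightGauge

gauge = InFlightGauge(
    "synapse_example_in_flight",  # illustrative metric name
    "Example in-flight blocks",
    labels=["name"],
    sub_metrics=["real_time"],
)

def _on_collect(metrics) -> None:
    # `metrics` is the attrs class built from sub_metrics above.
    metrics.real_time += 1.0

gauge.register(("some_block",), _on_collect)    # entering the block
gauge.unregister(("some_block",), _on_collect)  # leaving the block
```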

View File

@@ -25,25 +25,27 @@ import math
 import threading
 from http.server import BaseHTTPRequestHandler, HTTPServer
 from socketserver import ThreadingMixIn
-from typing import Any, Dict, List, Type, Union
+from typing import Dict, List
 from urllib.parse import parse_qs, urlparse
 
-from prometheus_client import REGISTRY, CollectorRegistry
-from prometheus_client.core import Sample
+from prometheus_client import REGISTRY
 
 from twisted.web.resource import Resource
-from twisted.web.server import Request
 
 from synapse.util import caches
 
 CONTENT_TYPE_LATEST = "text/plain; version=0.0.4; charset=utf-8"
 
-def floatToGoString(d: Union[int, float]) -> str:
+INF = float("inf")
+MINUS_INF = float("-inf")
+
+def floatToGoString(d):
     d = float(d)
-    if d == math.inf:
+    if d == INF:
         return "+Inf"
-    elif d == -math.inf:
+    elif d == MINUS_INF:
         return "-Inf"
     elif math.isnan(d):
         return "NaN"
@@ -58,7 +60,7 @@ def floatToGoString(d: Union[int, float]) -> str:
     return s
 
-def sample_line(line: Sample, name: str) -> str:
+def sample_line(line, name):
     if line.labels:
         labelstr = "{{{0}}}".format(
             ",".join(
@@ -80,7 +82,7 @@ def sample_line(line: Sample, name: str) -> str:
     return "{}{} {}{}\n".format(name, labelstr, floatToGoString(line.value), timestamp)
 
-def generate_latest(registry: CollectorRegistry, emit_help: bool = False) -> bytes:
+def generate_latest(registry, emit_help=False):
 
     # Trigger the cache metrics to be rescraped, which updates the common
     # metrics but do not produce metrics themselves
@@ -185,7 +187,7 @@ class MetricsHandler(BaseHTTPRequestHandler):
     registry = REGISTRY
 
-    def do_GET(self) -> None:
+    def do_GET(self):
         registry = self.registry
         params = parse_qs(urlparse(self.path).query)
@@ -205,11 +207,11 @@ class MetricsHandler(BaseHTTPRequestHandler):
             self.end_headers()
             self.wfile.write(output)
 
-    def log_message(self, format: str, *args: Any) -> None:
+    def log_message(self, format, *args):
         """Log nothing."""
 
     @classmethod
-    def factory(cls, registry: CollectorRegistry) -> Type:
+    def factory(cls, registry):
         """Returns a dynamic MetricsHandler class tied
         to the passed registry.
         """
@@ -234,9 +236,7 @@ class _ThreadingSimpleServer(ThreadingMixIn, HTTPServer):
     daemon_threads = True
 
-def start_http_server(
-    port: int, addr: str = "", registry: CollectorRegistry = REGISTRY
-) -> None:
+def start_http_server(port, addr="", registry=REGISTRY):
     """Starts an HTTP server for prometheus metrics as a daemon thread"""
     CustomMetricsHandler = MetricsHandler.factory(registry)
     httpd = _ThreadingSimpleServer((addr, port), CustomMetricsHandler)
@@ -252,10 +252,10 @@ class MetricsResource(Resource):
     isLeaf = True
 
-    def __init__(self, registry: CollectorRegistry = REGISTRY):
+    def __init__(self, registry=REGISTRY):
         self.registry = registry
 
-    def render_GET(self, request: Request) -> bytes:
+    def render_GET(self, request):
         request.setHeader(b"Content-Type", CONTENT_TYPE_LATEST.encode("ascii"))
         response = generate_latest(self.registry)
         request.setHeader(b"Content-Length", str(len(response)))

View File

@@ -15,37 +15,19 @@
 import logging
 import threading
 from functools import wraps
-from types import TracebackType
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Awaitable,
-    Callable,
-    Dict,
-    Iterable,
-    Optional,
-    Set,
-    Type,
-    TypeVar,
-    Union,
-    cast,
-)
-from prometheus_client import Metric
+from typing import TYPE_CHECKING, Dict, Optional, Set, Union
 from prometheus_client.core import REGISTRY, Counter, Gauge
 from twisted.internet import defer
-from synapse.logging.context import (
-    ContextResourceUsage,
-    LoggingContext,
-    PreserveLoggingContext,
-)
+from synapse.logging.context import LoggingContext, PreserveLoggingContext
 from synapse.logging.opentracing import (
     SynapseTags,
     noop_context_manager,
     start_active_span,
 )
+from synapse.util.async_helpers import maybe_awaitable
 if TYPE_CHECKING:
     import resource
@@ -134,7 +116,7 @@ class _Collector:
     before they are returned.
     """
-    def collect(self) -> Iterable[Metric]:
+    def collect(self):
         global _background_processes_active_since_last_scrape
         # We swap out the _background_processes set with an empty one so that
@@ -162,12 +144,12 @@ REGISTRY.register(_Collector())
 class _BackgroundProcess:
-    def __init__(self, desc: str, ctx: LoggingContext):
+    def __init__(self, desc, ctx):
         self.desc = desc
         self._context = ctx
-        self._reported_stats: Optional[ContextResourceUsage] = None
+        self._reported_stats = None
-    def update_metrics(self) -> None:
+    def update_metrics(self):
         """Updates the metrics with values from this process."""
         new_stats = self._context.get_resource_usage()
         if self._reported_stats is None:
@@ -187,16 +169,7 @@ class _BackgroundProcess:
     )
-R = TypeVar("R")
-def run_as_background_process(
-    desc: str,
-    func: Callable[..., Awaitable[Optional[R]]],
-    *args: Any,
-    bg_start_span: bool = True,
-    **kwargs: Any,
-) -> "defer.Deferred[Optional[R]]":
+def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwargs):
     """Run the given function in its own logcontext, with resource metrics
     This should be used to wrap processes which are fired off to run in the
@@ -216,13 +189,11 @@ def run_as_background_process(
         args: positional args for func
         kwargs: keyword args for func
-    Returns:
-        Deferred which returns the result of func, or `None` if func raises.
-        Note that the returned Deferred does not follow the synapse logcontext
-        rules.
+    Returns: Deferred which returns the result of func, but note that it does not
+        follow the synapse logcontext rules.
     """
-    async def run() -> Optional[R]:
+    async def run():
         with _bg_metrics_lock:
             count = _background_process_counts.get(desc, 0)
             _background_process_counts[desc] = count + 1
@@ -239,13 +210,12 @@ def run_as_background_process(
             else:
                 ctx = noop_context_manager()
             with ctx:
-                return await func(*args, **kwargs)
+                return await maybe_awaitable(func(*args, **kwargs))
         except Exception:
             logger.exception(
                 "Background process '%s' threw an exception",
                 desc,
             )
-            return None
         finally:
             _background_process_in_flight_count.labels(desc).dec()
@@ -255,24 +225,19 @@ def run_as_background_process(
     return defer.ensureDeferred(run())
-F = TypeVar("F", bound=Callable[..., Awaitable[Optional[Any]]])
-def wrap_as_background_process(desc: str) -> Callable[[F], F]:
+def wrap_as_background_process(desc):
     """Decorator that wraps a function that gets called as a background
     process.
-    Equivalent to calling the function with `run_as_background_process`
+    Equivalent of calling the function with `run_as_background_process`
     """
-    def wrap_as_background_process_inner(func: F) -> F:
+    def wrap_as_background_process_inner(func):
         @wraps(func)
-        def wrap_as_background_process_inner_2(
-            *args: Any, **kwargs: Any
-        ) -> "defer.Deferred[Optional[R]]":
+        def wrap_as_background_process_inner_2(*args, **kwargs):
             return run_as_background_process(desc, func, *args, **kwargs)
-        return cast(F, wrap_as_background_process_inner_2)
+        return wrap_as_background_process_inner_2
     return wrap_as_background_process_inner
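As a usage sketch for the two entry points in this hunk (the handler class and method names are hypothetical):

```python
from synapse.metrics.background_process_metrics import (
    run_as_background_process,
    wrap_as_background_process,
)

class ExampleHandler:
    # Decorator form: each call runs in its own logcontext, with resource
    # metrics reported against the "rotate_example_caches" description.
    @wrap_as_background_process("rotate_example_caches")
    async def rotate_caches(self) -> None:
        ...

async def _example_cleanup() -> None:
    ...

def schedule_cleanup() -> None:
    # Explicit form. The returned Deferred does not follow the synapse
    # logcontext rules, so callers usually fire and forget it.
    run_as_background_process("example_cleanup", _example_cleanup)
```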
@@ -300,7 +265,7 @@ class BackgroundProcessLoggingContext(LoggingContext):
         super().__init__("%s-%s" % (name, instance_id))
         self._proc = _BackgroundProcess(name, self)
-    def start(self, rusage: "Optional[resource.struct_rusage]") -> None:
+    def start(self, rusage: "Optional[resource.struct_rusage]"):
         """Log context has started running (again)."""
         super().start(rusage)
@@ -311,12 +276,7 @@ class BackgroundProcessLoggingContext(LoggingContext):
         with _bg_metrics_lock:
             _background_processes_active_since_last_scrape.add(self._proc)
-    def __exit__(
-        self,
-        type: Optional[Type[BaseException]],
-        value: Optional[BaseException],
-        traceback: Optional[TracebackType],
-    ) -> None:
+    def __exit__(self, type, value, traceback) -> None:
         """Log context has finished."""
         super().__exit__(type, value, traceback)

View File

@@ -16,16 +16,14 @@ import ctypes
 import logging
 import os
 import re
-from typing import Iterable, Optional
+from typing import Optional
-from prometheus_client import Metric
 from synapse.metrics import REGISTRY, GaugeMetricFamily
 logger = logging.getLogger(__name__)
-def _setup_jemalloc_stats() -> None:
+def _setup_jemalloc_stats():
     """Checks to see if jemalloc is loaded, and hooks up a collector to record
     statistics exposed by jemalloc.
     """
@@ -137,7 +135,7 @@ def _setup_jemalloc_stats() -> None:
     class JemallocCollector:
         """Metrics for internal jemalloc stats."""
-        def collect(self) -> Iterable[Metric]:
+        def collect(self):
             _jemalloc_refresh_stats()
             g = GaugeMetricFamily(
@@ -187,7 +185,7 @@ def _setup_jemalloc_stats() -> None:
     logger.debug("Added jemalloc stats")
-def setup_jemalloc_stats() -> None:
+def setup_jemalloc_stats():
     """Try to setup jemalloc stats, if jemalloc is loaded."""
     try:
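The hook-up pattern above is the standard `prometheus_client` custom-collector protocol; a minimal sketch with illustrative metric names:

```python
from prometheus_client.core import REGISTRY, GaugeMetricFamily

class ExampleCollector:
    """Yields fresh samples on every scrape, like JemallocCollector above."""

    def collect(self):
        g = GaugeMetricFamily("example_resident_bytes", "Illustrative gauge.")
        g.add_metric([], 12345.0)
        yield g

REGISTRY.register(ExampleCollector())
```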

View File

@@ -898,7 +898,7 @@ class UserTokenRestServlet(RestServlet):
         if auth_user.to_string() == user_id:
             raise SynapseError(400, "Cannot use admin API to login as self")
-        token = await self.auth_handler.create_access_token_for_user_id(
+        token = await self.auth_handler.get_access_token_for_user_id(
             user_id=auth_user.to_string(),
             device_id=None,
             valid_until_ms=valid_until_ms,
@@ -909,7 +909,7 @@ class UserTokenRestServlet(RestServlet):
 class ShadowBanRestServlet(RestServlet):
-    """An admin API for controlling whether a user is shadow-banned.
+    """An admin API for shadow-banning a user.
     A shadow-banned users receives successful responses to their client-server
     API requests, but the events are not propagated into rooms.
@@ -917,19 +917,11 @@ class ShadowBanRestServlet(RestServlet):
     Shadow-banning a user should be used as a tool of last resort and may lead
     to confusing or broken behaviour for the client.
-    Example of shadow-banning a user:
+    Example:
         POST /_synapse/admin/v1/users/@test:example.com/shadow_ban
         {}
-        200 OK
-        {}
-    Example of removing a user from being shadow-banned:
-        DELETE /_synapse/admin/v1/users/@test:example.com/shadow_ban
-        {}
         200 OK
         {}
     """
@@ -953,18 +945,6 @@ class ShadowBanRestServlet(RestServlet):
         return 200, {}
-    async def on_DELETE(
-        self, request: SynapseRequest, user_id: str
-    ) -> Tuple[int, JsonDict]:
-        await assert_requester_is_admin(self.auth, request)
-        if not self.hs.is_mine_id(user_id):
-            raise SynapseError(400, "Only local users can be shadow-banned")
-        await self.store.set_shadow_banned(UserID.from_string(user_id), False)
-        return 200, {}
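For illustration, exercising both endpoints documented on the left-hand side might look like this (a sketch using the third-party `requests` package; the homeserver URL and token are placeholders):

```python
import requests

BASE = "https://homeserver.example"  # placeholder homeserver
HEADERS = {"Authorization": "Bearer <admin_access_token>"}
URL = BASE + "/_synapse/admin/v1/users/@test:example.com/shadow_ban"

# Shadow-ban the user (present on both sides of this diff).
requests.post(URL, headers=HEADERS, json={})

# Remove the shadow-ban (the on_DELETE handler dropped on the right-hand side).
requests.delete(URL, headers=HEADERS)
```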
 class RateLimitRestServlet(RestServlet):
     """An admin API to override ratelimiting for an user.

View File

@@ -27,7 +27,7 @@ logger = logging.getLogger(__name__)
 def client_patterns(
     path_regex: str,
-    releases: Iterable[str] = ("r0", "v3"),
+    releases: Iterable[int] = (0,),
     unstable: bool = True,
     v1: bool = False,
 ) -> Iterable[Pattern]:
@@ -52,7 +52,7 @@ def client_patterns(
         v1_prefix = CLIENT_API_PREFIX + "/api/v1"
         patterns.append(re.compile("^" + v1_prefix + path_regex))
     for release in releases:
-        new_prefix = CLIENT_API_PREFIX + f"/{release}"
+        new_prefix = CLIENT_API_PREFIX + "/r%d" % (release,)
         patterns.append(re.compile("^" + new_prefix + path_regex))
     return patterns
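To make the two variants concrete, here is roughly what they generate for `path_regex="/sync$"`, assuming `CLIENT_API_PREFIX` is `/_matrix/client` (an assumption, since the constant is defined outside this hunk):

```python
CLIENT_API_PREFIX = "/_matrix/client"  # assumed value, for illustration

# Left-hand side: string releases, so both r0 and v3 prefixes are emitted.
left = ["^" + CLIENT_API_PREFIX + f"/{r}" + "/sync$" for r in ("r0", "v3")]
# ['^/_matrix/client/r0/sync$', '^/_matrix/client/v3/sync$']

# Right-hand side: integer releases, so only numbered "rN" prefixes.
right = ["^" + CLIENT_API_PREFIX + "/r%d" % (r,) + "/sync$" for r in (0,)]
# ['^/_matrix/client/r0/sync$']
```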

View File

@@ -262,7 +262,7 @@ class SigningKeyUploadServlet(RestServlet):
     }
     """
-    PATTERNS = client_patterns("/keys/device_signing/upload$", releases=("v3",))
+    PATTERNS = client_patterns("/keys/device_signing/upload$", releases=())
     def __init__(self, hs: "HomeServer"):
         super().__init__()

View File

@@ -188,7 +188,7 @@ class LoggingDatabaseConnection:
 # The type of entry which goes on our after_callbacks and exception_callbacks lists.
-_CallbackListEntry = Tuple[Callable[..., object], Iterable[Any], Dict[str, Any]]
+_CallbackListEntry = Tuple[Callable[..., None], Iterable[Any], Dict[str, Any]]
 R = TypeVar("R")
@@ -235,7 +235,7 @@ class LoggingTransaction:
         self.after_callbacks = after_callbacks
         self.exception_callbacks = exception_callbacks
-    def call_after(self, callback: Callable[..., object], *args: Any, **kwargs: Any):
+    def call_after(self, callback: Callable[..., None], *args: Any, **kwargs: Any):
         """Call the given callback on the main twisted thread after the
         transaction has finished. Used to invalidate the caches on the
         correct thread.
@@ -247,7 +247,7 @@ class LoggingTransaction:
         self.after_callbacks.append((callback, args, kwargs))
     def call_on_exception(
-        self, callback: Callable[..., object], *args: Any, **kwargs: Any
+        self, callback: Callable[..., None], *args: Any, **kwargs: Any
     ):
         # if self.exception_callbacks is None, that means that whatever constructed the
         # LoggingTransaction isn't expecting there to be any callbacks; assert that
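A typical `call_after` use, sketched with a hypothetical store method and cached getter:

```python
def _delete_thing_txn(self, txn: LoggingTransaction, thing_id: str) -> None:
    txn.execute("DELETE FROM things WHERE thing_id = ?", (thing_id,))
    # Runs on the main twisted thread only after the transaction commits,
    # so the cache entry is dropped only if the delete actually happened.
    txn.call_after(self.get_thing.invalidate, (thing_id,))
```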

View File

@@ -31,6 +31,7 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
 from .account_data import AccountDataStore
 from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore
+from .cache import CacheInvalidationWorkerStore
 from .censor_events import CensorEventsStore
 from .client_ips import ClientIpStore
 from .deviceinbox import DeviceInboxStore
@@ -48,6 +49,7 @@ from .keys import KeyStore
 from .lock import LockStore
 from .media_repository import MediaRepositoryStore
 from .metrics import ServerMetricsStore
+from .monthly_active_users import MonthlyActiveUsersStore
 from .openid import OpenIdStore
 from .presence import PresenceStore
 from .profile import ProfileStore
@@ -61,9 +63,11 @@ from .relations import RelationsStore
 from .room import RoomStore
 from .room_batch import RoomBatchStore
 from .roommember import RoomMemberStore
+from .search import SearchStore
 from .session import SessionStore
 from .signatures import SignatureStore
 from .state import StateStore
+from .stats import StatsStore
 from .stream import StreamStore
 from .tags import TagsStore
 from .transactions import TransactionWorkerStore
@@ -103,6 +107,7 @@ class DataStore(
     ReceiptsStore,
     EndToEndKeyStore,
     EndToEndRoomKeyStore,
+    SearchStore,
     TagsStore,
     AccountDataStore,
     EventPushActionsStore,
@@ -113,10 +118,13 @@ class DataStore(
     UserDirectoryStore,
     GroupServerStore,
     UserErasureStore,
+    MonthlyActiveUsersStore,
+    StatsStore,
     RelationsStore,
     CensorEventsStore,
     UIAuthStore,
     EventForwardExtremitiesStore,
+    CacheInvalidationWorkerStore,
     ServerMetricsStore,
     LockStore,
     SessionStore,
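Since this whole hunk is about the `DataStore` method resolution order, one quick way to see the effect of re-adding these bases is plain `inspect` against a checkout (a sketch, not part of the change):

```python
import inspect

from synapse.storage.databases.main import DataStore

# The re-added mixins (SearchStore, StatsStore, MonthlyActiveUsersStore,
# CacheInvalidationWorkerStore) should show up in this listing.
for klass in inspect.getmro(DataStore):
    print(f"{klass.__module__}.{klass.__qualname__}")
```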

View File

@@ -17,7 +17,7 @@ from typing import Iterable, List, Optional, Tuple
 from synapse.api.errors import SynapseError
 from synapse.storage.database import LoggingTransaction
-from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
+from synapse.storage.databases.main import CacheInvalidationWorkerStore
 from synapse.types import RoomAlias
 from synapse.util.caches.descriptors import cached

View File

@@ -17,7 +17,7 @@ from typing import Any, Dict, List
 from synapse.api.errors import SynapseError
 from synapse.storage.database import LoggingTransaction
-from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
+from synapse.storage.databases.main import CacheInvalidationWorkerStore
 from synapse.storage.databases.main.event_federation import EventFederationWorkerStore
 logger = logging.getLogger(__name__)

View File

@@ -16,7 +16,7 @@ import logging
 from typing import Any, List, Set, Tuple
 from synapse.api.errors import SynapseError
-from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
+from synapse.storage.databases.main import CacheInvalidationWorkerStore
 from synapse.storage.databases.main.state import StateGroupWorkerStore
 from synapse.types import RoomStreamToken

View File

@@ -476,7 +476,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             shadow_banned: true iff the user is to be shadow-banned, false otherwise.
         """
-        def set_shadow_banned_txn(txn: LoggingTransaction) -> None:
+        def set_shadow_banned_txn(txn):
             user_id = user.to_string()
             self.db_pool.simple_update_one_txn(
                 txn,

View File

@@ -15,7 +15,7 @@
 from typing import Dict, Iterable
 from synapse.storage.database import LoggingTransaction
-from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
+from synapse.storage.databases.main import CacheInvalidationWorkerStore
 from synapse.util.caches.descriptors import cached, cachedList

View File

@@ -131,9 +131,9 @@ def prepare_database(
                 "config==None in prepare_database, but database is not empty"
             )
-        # This should be run on all processes, master or worker. The master will
-        # apply the deltas, while workers will check if any outstanding deltas
-        # exist and raise an PrepareDatabaseException if they do.
-        _upgrade_existing_database(
-            cur,
-            version_info,
+        # if it's a worker app, refuse to upgrade the database, to avoid multiple
+        # workers doing it at once.
+        if config.worker.worker_app is None:
+            _upgrade_existing_database(
+                cur,
+                version_info,
@@ -141,6 +141,14 @@ def prepare_database(
-            config,
-            databases=databases,
-        )
+                config,
+                databases=databases,
+            )
+        elif version_info.current_version < SCHEMA_VERSION:
+            # If the DB is on an older version than we expect then we refuse
+            # to start the worker (as the main process needs to run first to
+            # update the schema).
+            raise UpgradeDatabaseException(
+                OUTDATED_SCHEMA_ON_WORKER_ERROR
+                % (SCHEMA_VERSION, version_info.current_version)
+            )
     else:
         logger.info("%r: Initialising new database", databases)
@@ -350,18 +358,6 @@ def _upgrade_existing_database(
     is_worker = config and config.worker.worker_app is not None
-    # If the schema version needs to be updated, and we are on a worker, we immediately
-    # know to bail out as workers cannot update the database schema. Only one process
-    # must update the database at the time, therefore we delegate this task to the master.
-    if is_worker and current_schema_state.current_version < SCHEMA_VERSION:
-        # If the DB is on an older version than we expect then we refuse
-        # to start the worker (as the main process needs to run first to
-        # update the schema).
-        raise UpgradeDatabaseException(
-            OUTDATED_SCHEMA_ON_WORKER_ERROR
-            % (SCHEMA_VERSION, current_schema_state.current_version)
-        )
     if (
         current_schema_state.compat_version is not None
         and current_schema_state.compat_version > SCHEMA_VERSION

View File

@@ -18,17 +18,5 @@
 -- when a device was deleted using Synapse earlier than 1.47.0.
 -- This runs as background task, but may take a bit to finish.
--- Remove any existing instances of this job running. It's OK to stop and restart this job,
--- as it's just deleting entries from a table - no progress will be lost.
---
--- This is necessary due a similar migration running the job accidentally
--- being included in schema version 64 during v1.47.0rc1,rc2. If a
--- homeserver had updated from Synapse <=v1.45.0 (schema version <=64),
--- then they would have started running this background update already.
--- If that update was still running, then simply inserting it again would
--- cause an SQL failure. So we effectively do an "upsert" here instead.
-DELETE FROM background_updates WHERE update_name = 'remove_deleted_devices_from_device_inbox';
 INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
-  (6506, 'remove_deleted_devices_from_device_inbox', '{}');
+  (6505, 'remove_deleted_devices_from_device_inbox', '{}');

View File

@@ -27,7 +27,6 @@ from typing import (
     Generic,
     Hashable,
     Iterable,
-    Iterator,
     Optional,
     Set,
     TypeVar,
@@ -41,6 +40,7 @@ from typing_extensions import ContextManager
 from twisted.internet import defer
 from twisted.internet.defer import CancelledError
 from twisted.internet.interfaces import IReactorTime
+from twisted.python import failure
 from twisted.python.failure import Failure
 from synapse.logging.context import (
@@ -78,7 +78,7 @@ class ObservableDeferred(Generic[_T]):
         object.__setattr__(self, "_result", None)
         object.__setattr__(self, "_observers", [])
-        def callback(r: _T) -> _T:
+        def callback(r):
             object.__setattr__(self, "_result", (True, r))
             # once we have set _result, no more entries will be added to _observers,
@@ -98,7 +98,7 @@ class ObservableDeferred(Generic[_T]):
             )
             return r
-        def errback(f: Failure) -> Optional[Failure]:
+        def errback(f):
             object.__setattr__(self, "_result", (False, f))
             # once we have set _result, no more entries will be added to _observers,
@@ -109,7 +109,7 @@ class ObservableDeferred(Generic[_T]):
             for observer in observers:
                 # This is a little bit of magic to correctly propagate stack
                 # traces when we `await` on one of the observer deferreds.
-                f.value.__failure__ = f  # type: ignore[union-attr]
+                f.value.__failure__ = f
                 try:
                     observer.errback(f)
                 except Exception as e:
@@ -314,7 +314,7 @@ class Linearizer:
         # will release the lock.
         @contextmanager
-        def _ctx_manager(_: None) -> Iterator[None]:
+        def _ctx_manager(_):
             try:
                 yield
             finally:
@@ -355,7 +355,7 @@ class Linearizer:
         new_defer = make_deferred_yieldable(defer.Deferred())
         entry.deferreds[new_defer] = 1
-        def cb(_r: None) -> "defer.Deferred[None]":
+        def cb(_r):
             logger.debug("Acquired linearizer lock %r for key %r", self.name, key)
             entry.count += 1
@@ -371,7 +371,7 @@ class Linearizer:
             # code must be synchronous, so this is the only sensible place.)
             return self._clock.sleep(0)
-        def eb(e: Failure) -> Failure:
+        def eb(e):
             logger.info("defer %r got err %r", new_defer, e)
             if isinstance(e, CancelledError):
                 logger.debug(
@@ -435,7 +435,7 @@ class ReadWriteLock:
             await make_deferred_yieldable(curr_writer)
         @contextmanager
-        def _ctx_manager() -> Iterator[None]:
+        def _ctx_manager():
             try:
                 yield
             finally:
@@ -464,7 +464,7 @@ class ReadWriteLock:
             await make_deferred_yieldable(defer.gatherResults(to_wait_on))
         @contextmanager
-        def _ctx_manager() -> Iterator[None]:
+        def _ctx_manager():
             try:
                 yield
             finally:
@@ -524,7 +524,7 @@ def timeout_deferred(
     delayed_call = reactor.callLater(timeout, time_it_out)
-    def convert_cancelled(value: Failure) -> Failure:
+    def convert_cancelled(value: failure.Failure):
         # if the original deferred was cancelled, and our timeout has fired, then
         # the reason it was cancelled was due to our timeout. Turn the CancelledError
         # into a TimeoutError.
@@ -534,7 +534,7 @@ def timeout_deferred(
     deferred.addErrback(convert_cancelled)
-    def cancel_timeout(result: _T) -> _T:
+    def cancel_timeout(result):
         # stop the pending call to cancel the deferred if it's been fired
         if delayed_call.active():
             delayed_call.cancel()
@@ -542,11 +542,11 @@ def timeout_deferred(
     deferred.addBoth(cancel_timeout)
-    def success_cb(val: _T) -> None:
+    def success_cb(val):
         if not new_d.called:
             new_d.callback(val)
-    def failure_cb(val: Failure) -> None:
+    def failure_cb(val):
         if not new_d.called:
             new_d.errback(val)
@@ -557,13 +557,13 @@ def timeout_deferred(
 # This class can't be generic because it uses slots with attrs.
 # See: https://github.com/python-attrs/attrs/issues/313
-@attr.s(slots=True, frozen=True, auto_attribs=True)
+@attr.s(slots=True, frozen=True)
 class DoneAwaitable:  # should be: Generic[R]
     """Simple awaitable that returns the provided value."""
-    value: Any  # should be: R
+    value = attr.ib(type=Any)  # should be: R
-    def __await__(self) -> Any:
+    def __await__(self):
         return self
     def __iter__(self) -> "DoneAwaitable":
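As a rough usage sketch for two of the helpers touched in this file (calling conventions as of this era of the code; treat as approximate):

```python
from twisted.internet import defer, reactor

from synapse.util.async_helpers import ObservableDeferred, timeout_deferred

underlying: "defer.Deferred[str]" = defer.Deferred()

# Several callers can independently wait on one underlying deferred.
observable = ObservableDeferred(underlying)
first = observable.observe()
second = observable.observe()

# Wrap another deferred so it errbacks with a TimeoutError after 10s.
slow: "defer.Deferred[None]" = defer.Deferred()
guarded = timeout_deferred(slow, timeout=10.0, reactor=reactor)
```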

View File

@@ -17,7 +17,7 @@ import logging
 import typing
 from enum import Enum, auto
 from sys import intern
-from typing import Any, Callable, Dict, List, Optional, Sized
+from typing import Callable, Dict, Optional, Sized
 import attr
 from prometheus_client.core import Gauge
@@ -58,20 +58,20 @@ class EvictionReason(Enum):
     time = auto()
-@attr.s(slots=True, auto_attribs=True)
+@attr.s(slots=True)
 class CacheMetric:
-    _cache: Sized
-    _cache_type: str
-    _cache_name: str
-    _collect_callback: Optional[Callable]
-    hits: int = 0
-    misses: int = 0
+    _cache = attr.ib()
+    _cache_type = attr.ib(type=str)
+    _cache_name = attr.ib(type=str)
+    _collect_callback = attr.ib(type=Optional[Callable])
+    hits = attr.ib(default=0)
+    misses = attr.ib(default=0)
     eviction_size_by_reason: typing.Counter[EvictionReason] = attr.ib(
         factory=collections.Counter
     )
-    memory_usage: Optional[int] = None
+    memory_usage = attr.ib(default=None)
     def inc_hits(self) -> None:
         self.hits += 1
@@ -89,14 +89,13 @@ class CacheMetric:
             self.memory_usage += memory
     def dec_memory_usage(self, memory: int) -> None:
-        assert self.memory_usage is not None
         self.memory_usage -= memory
     def clear_memory_usage(self) -> None:
         if self.memory_usage is not None:
             self.memory_usage = 0
-    def describe(self) -> List[str]:
+    def describe(self):
         return []
     def collect(self) -> None:
@@ -119,9 +118,8 @@ class CacheMetric:
                     self.eviction_size_by_reason[reason]
                 )
                 cache_total.labels(self._cache_name).set(self.hits + self.misses)
-                max_size = getattr(self._cache, "max_size", None)
-                if max_size:
-                    cache_max_size.labels(self._cache_name).set(max_size)
+                if getattr(self._cache, "max_size", None):
+                    cache_max_size.labels(self._cache_name).set(self._cache.max_size)
                 if TRACK_MEMORY_USAGE:
                     # self.memory_usage can be None if nothing has been inserted
@@ -195,7 +193,7 @@ KNOWN_KEYS = {
 }
-def intern_string(string: Optional[str]) -> Optional[str]:
+def intern_string(string):
     """Takes a (potentially) unicode string and interns it if it's ascii"""
     if string is None:
         return None
@@ -206,7 +204,7 @@ def intern_string(string: Optional[str]) -> Optional[str]:
     return string
-def intern_dict(dictionary: Dict[str, Any]) -> Dict[str, Any]:
+def intern_dict(dictionary):
     """Takes a dictionary and interns well known keys and their values"""
     return {
         KNOWN_KEYS.get(key, key): _intern_known_values(key, value)
@@ -214,7 +212,7 @@ def intern_dict(dictionary: Dict[str, Any]) -> Dict[str, Any]:
     }
-def _intern_known_values(key: str, value: Any) -> Any:
+def _intern_known_values(key, value):
     intern_keys = ("event_id", "room_id", "sender", "user_id", "type", "state_key")
     if key in intern_keys:
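The decorator churn above is the same class written in the two attrs spellings; a side-by-side sketch:

```python
import attr

# Annotation style (left-hand side of the hunk).
@attr.s(slots=True, auto_attribs=True)
class Typed:
    hits: int = 0

# attr.ib style (right-hand side of the hunk).
@attr.s(slots=True)
class Untyped:
    hits = attr.ib(default=0)

assert Typed().hits == Untyped().hits == 0
```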

View File

@@ -289,7 +289,7 @@ class DeferredCache(Generic[KT, VT]):
         callbacks = [callback] if callback else []
         self.cache.set(key, value, callbacks=callbacks)
-    def invalidate(self, key: KT) -> None:
+    def invalidate(self, key) -> None:
         """Delete a key, or tree of entries
         If the cache is backed by a regular dict, then "key" must be of

View File

@@ -19,15 +19,12 @@ import logging
 from typing import (
     Any,
     Callable,
-    Dict,
     Generic,
-    Hashable,
     Iterable,
     Mapping,
     Optional,
     Sequence,
     Tuple,
-    Type,
     TypeVar,
     Union,
     cast,
@@ -35,7 +32,6 @@ from typing import (
 from weakref import WeakValueDictionary
 from twisted.internet import defer
-from twisted.python.failure import Failure
 from synapse.logging.context import make_deferred_yieldable, preserve_fn
 from synapse.util import unwrapFirstError
@@ -64,12 +60,7 @@ class _CachedFunction(Generic[F]):
 class _CacheDescriptorBase:
-    def __init__(
-        self,
-        orig: Callable[..., Any],
-        num_args: Optional[int],
-        cache_context: bool = False,
-    ):
+    def __init__(self, orig: Callable[..., Any], num_args, cache_context=False):
         self.orig = orig
         arg_spec = inspect.getfullargspec(orig)
@@ -181,14 +172,14 @@ class LruCacheDescriptor(_CacheDescriptorBase):
     def __init__(
         self,
-        orig: Callable[..., Any],
+        orig,
         max_entries: int = 1000,
         cache_context: bool = False,
     ):
         super().__init__(orig, num_args=None, cache_context=cache_context)
         self.max_entries = max_entries
-    def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
+    def __get__(self, obj, owner):
         cache: LruCache[CacheKey, Any] = LruCache(
             cache_name=self.orig.__name__,
             max_size=self.max_entries,
@@ -198,7 +189,7 @@ class LruCacheDescriptor(_CacheDescriptorBase):
         sentinel = LruCacheDescriptor._Sentinel.sentinel
         @functools.wraps(self.orig)
-        def _wrapped(*args: Any, **kwargs: Any) -> Any:
+        def _wrapped(*args, **kwargs):
             invalidate_callback = kwargs.pop("on_invalidate", None)
             callbacks = (invalidate_callback,) if invalidate_callback else ()
@@ -254,19 +245,19 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
             return r1 + r2
     Args:
-        num_args: number of positional arguments (excluding ``self`` and
+        num_args (int): number of positional arguments (excluding ``self`` and
             ``cache_context``) to use as cache keys. Defaults to all named
             args of the function.
     """
     def __init__(
         self,
-        orig: Callable[..., Any],
-        max_entries: int = 1000,
-        num_args: Optional[int] = None,
-        tree: bool = False,
-        cache_context: bool = False,
-        iterable: bool = False,
+        orig,
+        max_entries=1000,
+        num_args=None,
+        tree=False,
+        cache_context=False,
+        iterable=False,
         prune_unread_entries: bool = True,
     ):
         super().__init__(orig, num_args=num_args, cache_context=cache_context)
@@ -281,7 +272,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
         self.iterable = iterable
         self.prune_unread_entries = prune_unread_entries
-    def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
+    def __get__(self, obj, owner):
         cache: DeferredCache[CacheKey, Any] = DeferredCache(
             name=self.orig.__name__,
             max_entries=self.max_entries,
@@ -293,7 +284,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
         get_cache_key = self.cache_key_builder
         @functools.wraps(self.orig)
-        def _wrapped(*args: Any, **kwargs: Any) -> Any:
+        def _wrapped(*args, **kwargs):
             # If we're passed a cache_context then we'll want to call its invalidate()
             # whenever we are invalidated
             invalidate_callback = kwargs.pop("on_invalidate", None)
@@ -344,19 +335,13 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
     of results.
     """
-    def __init__(
-        self,
-        orig: Callable[..., Any],
-        cached_method_name: str,
-        list_name: str,
-        num_args: Optional[int] = None,
-    ):
+    def __init__(self, orig, cached_method_name, list_name, num_args=None):
         """
         Args:
-            orig
-            cached_method_name: The name of the cached method.
-            list_name: Name of the argument which is the bulk lookup list
-            num_args: number of positional arguments (excluding ``self``,
+            orig (function)
+            cached_method_name (str): The name of the cached method.
+            list_name (str): Name of the argument which is the bulk lookup list
+            num_args (int): number of positional arguments (excluding ``self``,
                 but including list_name) to use as cache keys. Defaults to all
                 named args of the function.
         """
@@ -375,15 +360,13 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
                 % (self.list_name, cached_method_name)
             )
-    def __get__(
-        self, obj: Optional[Any], objtype: Optional[Type] = None
-    ) -> Callable[..., Any]:
+    def __get__(self, obj, objtype=None):
         cached_method = getattr(obj, self.cached_method_name)
         cache: DeferredCache[CacheKey, Any] = cached_method.cache
         num_args = cached_method.num_args
         @functools.wraps(self.orig)
-        def wrapped(*args: Any, **kwargs: Any) -> Any:
+        def wrapped(*args, **kwargs):
             # If we're passed a cache_context then we'll want to call its
             # invalidate() whenever we are invalidated
             invalidate_callback = kwargs.pop("on_invalidate", None)
@@ -394,7 +377,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
             results = {}
-            def update_results_dict(res: Any, arg: Hashable) -> None:
+            def update_results_dict(res, arg):
                 results[arg] = res
             # list of deferreds to wait for
@@ -406,13 +389,13 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
             # otherwise a tuple is used.
             if num_args == 1:
-                def arg_to_cache_key(arg: Hashable) -> Hashable:
+                def arg_to_cache_key(arg):
                     return arg
             else:
                 keylist = list(keyargs)
-                def arg_to_cache_key(arg: Hashable) -> Hashable:
+                def arg_to_cache_key(arg):
                     keylist[self.list_pos] = arg
                     return tuple(keylist)
@@ -438,7 +421,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
                     key = arg_to_cache_key(arg)
                     cache.set(key, deferred, callback=invalidate_callback)
-            def complete_all(res: Dict[Hashable, Any]) -> None:
+            def complete_all(res):
                 # the wrapped function has completed. It returns a
                 # a dict. We can now resolve the observable deferreds in
                 # the cache and update our own result map.
@@ -447,7 +430,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
                     deferreds_map[e].callback(val)
                     results[e] = val
-            def errback(f: Failure) -> Failure:
+            def errback(f):
                 # the wrapped function has failed. Invalidate any cache
                 # entries we're supposed to be populating, and fail
                 # their deferreds.

View File

@@ -19,8 +19,6 @@ from typing import Any, Generic, Optional, TypeVar, Union, overload
 import attr
 from typing_extensions import Literal
-from twisted.internet import defer
 from synapse.config import cache as cache_config
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.util import Clock
@@ -83,7 +81,7 @@ class ExpiringCache(Generic[KT, VT]):
             # Don't bother starting the loop if things never expire
             return
-        def f() -> "defer.Deferred[None]":
+        def f():
             return run_as_background_process(
                 "prune_cache_%s" % self._cache_name, self._prune_cache
             )
@@ -159,7 +157,7 @@ class ExpiringCache(Generic[KT, VT]):
             self[key] = value
         return value
-    async def _prune_cache(self) -> None:
+    def _prune_cache(self) -> None:
         if not self._expiry_ms:
             # zero expiry time means don't expire. This should never get called
             # since we have this check in start too.
@@ -212,7 +210,7 @@ class ExpiringCache(Generic[KT, VT]):
         return False
-@attr.s(slots=True, auto_attribs=True)
+@attr.s(slots=True)
 class _CacheEntry:
-    time: int
-    value: Any
+    time = attr.ib(type=int)
+    value = attr.ib()

View File

@@ -18,13 +18,12 @@ from twisted.internet import defer
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.types import UserID
 from synapse.util.async_helpers import maybe_awaitable
 logger = logging.getLogger(__name__)
-def user_left_room(distributor: "Distributor", user: UserID, room_id: str) -> None:
+def user_left_room(distributor, user, room_id):
     distributor.fire("user_left_room", user=user, room_id=room_id)
@@ -64,7 +63,7 @@ class Distributor:
             self.pre_registration[name] = []
         self.pre_registration[name].append(observer)
-    def fire(self, name: str, *args: Any, **kwargs: Any) -> None:
+    def fire(self, name: str, *args, **kwargs) -> None:
         """Dispatches the given signal to the registered observers.
         Runs the observers as a background process. Does not return a deferred.
@@ -96,7 +95,7 @@ class Signal:
         Each observer callable may return a Deferred."""
         self.observers.append(observer)
-    def fire(self, *args: Any, **kwargs: Any) -> "defer.Deferred[List[Any]]":
+    def fire(self, *args, **kwargs) -> "defer.Deferred[List[Any]]":
         """Invokes every callable in the observer list, passing in the args and
         kwargs. Exceptions thrown by observers are logged but ignored. It is
         not an error to fire a signal with no observers.
@@ -104,7 +103,7 @@ class Signal:
         Returns a Deferred that will complete when all the observers have
         completed."""
-        async def do(observer: Callable[..., Any]) -> Any:
+        async def do(observer):
             try:
                 return await maybe_awaitable(observer(*args, **kwargs))
             except Exception as e:
@@ -121,5 +120,5 @@ class Signal:
             defer.gatherResults(deferreds, consumeErrors=True)
         )
-    def __repr__(self) -> str:
+    def __repr__(self):
         return "<Signal name=%r>" % (self.name,)

View File

@@ -3,52 +3,23 @@
 # We copy it here as we need to instantiate `GAIResolver` manually, but it is a
 # private class.
 from socket import (
     AF_INET,
     AF_INET6,
     AF_UNSPEC,
     SOCK_DGRAM,
     SOCK_STREAM,
-    AddressFamily,
-    SocketKind,
     gaierror,
     getaddrinfo,
 )
-from typing import (
-    TYPE_CHECKING,
-    Callable,
-    List,
-    NoReturn,
-    Optional,
-    Sequence,
-    Tuple,
-    Type,
-    Union,
-)
 from zope.interface import implementer
 from twisted.internet.address import IPv4Address, IPv6Address
-from twisted.internet.interfaces import (
-    IAddress,
-    IHostnameResolver,
-    IHostResolution,
-    IReactorThreads,
-    IResolutionReceiver,
-)
+from twisted.internet.interfaces import IHostnameResolver, IHostResolution
 from twisted.internet.threads import deferToThreadPool
-if TYPE_CHECKING:
-    # The types below are copied from
-    # https://github.com/twisted/twisted/blob/release-21.2.0-10091/src/twisted/internet/interfaces.py
-    # so that the type hints can match the interfaces.
-    from twisted.python.runtime import platform
-    if platform.supportsThreads():
-        from twisted.python.threadpool import ThreadPool
-    else:
-        ThreadPool = object  # type: ignore[misc, assignment]
 @implementer(IHostResolution)
 class HostResolution:
@@ -56,13 +27,13 @@ class HostResolution:
     The in-progress resolution of a given hostname.
     """
-    def __init__(self, name: str):
+    def __init__(self, name):
         """
         Create a L{HostResolution} with the given name.
         """
         self.name = name
-    def cancel(self) -> NoReturn:
+    def cancel(self):
         # IHostResolution.cancel
         raise NotImplementedError()
@@ -91,17 +62,6 @@ _socktypeToType = {
 }
-_GETADDRINFO_RESULT = List[
-    Tuple[
-        AddressFamily,
-        SocketKind,
-        int,
-        str,
-        Union[Tuple[str, int], Tuple[str, int, int, int]],
-    ]
-]
 @implementer(IHostnameResolver)
 class GAIResolver:
     """
@@ -109,12 +69,7 @@ class GAIResolver:
     L{getaddrinfo} in a thread.
     """
-    def __init__(
-        self,
-        reactor: IReactorThreads,
-        getThreadPool: Optional[Callable[[], "ThreadPool"]] = None,
-        getaddrinfo: Callable[[str, int, int, int], _GETADDRINFO_RESULT] = getaddrinfo,
-    ):
+    def __init__(self, reactor, getThreadPool=None, getaddrinfo=getaddrinfo):
         """
         Create a L{GAIResolver}.
         @param reactor: the reactor to schedule result-delivery on
@@ -134,16 +89,14 @@ class GAIResolver:
         )
         self._getaddrinfo = getaddrinfo
-    # The types on IHostnameResolver is incorrect in Twisted, see
-    # https://twistedmatrix.com/trac/ticket/10276
-    def resolveHostName(  # type: ignore[override]
+    def resolveHostName(
         self,
-        resolutionReceiver: IResolutionReceiver,
-        hostName: str,
-        portNumber: int = 0,
-        addressTypes: Optional[Sequence[Type[IAddress]]] = None,
-        transportSemantics: str = "TCP",
-    ) -> IHostResolution:
+        resolutionReceiver,
+        hostName,
+        portNumber=0,
+        addressTypes=None,
+        transportSemantics="TCP",
+    ):
         """
         See L{IHostnameResolver.resolveHostName}
         @param resolutionReceiver: see interface
@@ -159,7 +112,7 @@ class GAIResolver:
         ]
         socketType = _transportToSocket[transportSemantics]
-        def get() -> _GETADDRINFO_RESULT:
+        def get():
             try:
                 return self._getaddrinfo(
                     hostName, portNumber, addressFamily, socketType
@@ -172,7 +125,7 @@ class GAIResolver:
         resolutionReceiver.resolutionBegan(resolution)
         @d.addCallback
-        def deliverResults(result: _GETADDRINFO_RESULT) -> None:
+        def deliverResults(result):
             for family, socktype, _proto, _cannoname, sockaddr in result:
                 addrType = _afToType[family]
                 resolutionReceiver.addressResolved(
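Instantiating the class manually, as the header comment says Synapse needs to, looks roughly like this (a sketch; `installNameResolver` is the pluggable-resolver hook on modern Twisted reactors):

```python
from twisted.internet import reactor

# Replace the reactor's name resolver with the threaded
# getaddrinfo-based implementation defined above.
reactor.installNameResolver(GAIResolver(reactor))
```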

View File

@@ -56,22 +56,14 @@ block_db_sched_duration = Counter(
     "synapse_util_metrics_block_db_sched_duration_seconds", "", ["block_name"]
 )
-# This is dynamically created in InFlightGauge.__init__.
-class _InFlightMetric(Protocol):
-    real_time_max: float
-    real_time_sum: float
 # Tracks the number of blocks currently active
-in_flight: InFlightGauge[_InFlightMetric] = InFlightGauge(
+in_flight = InFlightGauge(
     "synapse_util_metrics_block_in_flight",
     "",
     labels=["block_name"],
     sub_metrics=["real_time_max", "real_time_sum"],
 )
 T = TypeVar("T", bound=Callable[..., Any])
@@ -188,7 +180,7 @@ class Measure:
         """
         return self._logging_context.get_resource_usage()
-    def _update_in_flight(self, metrics: _InFlightMetric) -> None:
+    def _update_in_flight(self, metrics) -> None:
         """Gets called when processing in flight metrics"""
         assert self.start is not None
         duration = self.clock.time() - self.start
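The callback contract `_update_in_flight` satisfies is the `InFlightGauge` register/unregister pair; a sketch mirroring how `Measure` appears to use it (names approximate):

```python
class ExampleMeasure:
    def __init__(self, name: str) -> None:
        self.name = name

    def __enter__(self) -> "ExampleMeasure":
        # Track this block; at scrape time the gauge calls back so we can
        # fill in real_time_max / real_time_sum on the metrics object.
        in_flight.register((self.name,), self._update_in_flight)
        return self

    def __exit__(self, *exc) -> None:
        in_flight.unregister((self.name,), self._update_in_flight)

    def _update_in_flight(self, metrics) -> None:
        ...
```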

View File

@@ -116,7 +116,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         self.auth_blocking._limit_usage_by_mau = False
         # Ensure does not throw exception
         self.get_success(
-            self.auth_handler.create_access_token_for_user_id(
+            self.auth_handler.get_access_token_for_user_id(
                 self.user1, device_id=None, valid_until_ms=None
             )
         )
@@ -134,7 +134,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         )
         self.get_failure(
-            self.auth_handler.create_access_token_for_user_id(
+            self.auth_handler.get_access_token_for_user_id(
                 self.user1, device_id=None, valid_until_ms=None
             ),
             ResourceLimitError,
@@ -162,7 +162,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         # If not in monthly active cohort
         self.get_failure(
-            self.auth_handler.create_access_token_for_user_id(
+            self.auth_handler.get_access_token_for_user_id(
                 self.user1, device_id=None, valid_until_ms=None
             ),
             ResourceLimitError,
@@ -179,7 +179,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
             return_value=make_awaitable(self.clock.time_msec())
         )
         self.get_success(
-            self.auth_handler.create_access_token_for_user_id(
+            self.auth_handler.get_access_token_for_user_id(
                 self.user1, device_id=None, valid_until_ms=None
             )
         )
@@ -197,7 +197,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         )
         # Ensure does not raise exception
         self.get_success(
-            self.auth_handler.create_access_token_for_user_id(
+            self.auth_handler.get_access_token_for_user_id(
                 self.user1, device_id=None, valid_until_ms=None
             )
         )

View File

@@ -193,8 +193,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
     @override_config({"limit_usage_by_mau": True})
     def test_get_or_create_user_mau_not_blocked(self):
-        # Type ignore: mypy doesn't like us assigning to methods.
-        self.store.count_monthly_users = Mock(  # type: ignore[assignment]
+        self.store.count_monthly_users = Mock(
             return_value=make_awaitable(self.hs.config.server.max_mau_value - 1)
         )
         # Ensure does not throw exception
@@ -202,8 +201,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
     @override_config({"limit_usage_by_mau": True})
     def test_get_or_create_user_mau_blocked(self):
-        # Type ignore: mypy doesn't like us assigning to methods.
-        self.store.get_monthly_active_count = Mock(  # type: ignore[assignment]
+        self.store.get_monthly_active_count = Mock(
             return_value=make_awaitable(self.lots_of_users)
         )
         self.get_failure(
@@ -211,8 +209,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
             ResourceLimitError,
         )
-        # Type ignore: mypy doesn't like us assigning to methods.
-        self.store.get_monthly_active_count = Mock(  # type: ignore[assignment]
+        self.store.get_monthly_active_count = Mock(
             return_value=make_awaitable(self.hs.config.server.max_mau_value)
         )
         self.get_failure(

View File

@@ -14,8 +14,6 @@
 from typing import Any, Iterable, List, Optional, Tuple
 from unittest import mock

-from twisted.internet.defer import ensureDeferred
-
 from synapse.api.constants import (
     EventContentFields,
     EventTypes,
@@ -318,59 +316,6 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
             AuthError,
         )

-    def test_room_hierarchy_cache(self) -> None:
-        """In-flight room hierarchy requests are deduplicated."""
-        # Run two `get_room_hierarchy` calls up until they block.
-        deferred1 = ensureDeferred(
-            self.handler.get_room_hierarchy(self.user, self.space)
-        )
-        deferred2 = ensureDeferred(
-            self.handler.get_room_hierarchy(self.user, self.space)
-        )
-
-        # Complete the two calls.
-        result1 = self.get_success(deferred1)
-        result2 = self.get_success(deferred2)
-
-        # Both `get_room_hierarchy` calls should return the same result.
-        expected = [(self.space, [self.room]), (self.room, ())]
-        self._assert_hierarchy(result1, expected)
-        self._assert_hierarchy(result2, expected)
-        self.assertIs(result1, result2)
-
-        # A subsequent `get_room_hierarchy` call should not reuse the result.
-        result3 = self.get_success(
-            self.handler.get_room_hierarchy(self.user, self.space)
-        )
-        self._assert_hierarchy(result3, expected)
-        self.assertIsNot(result1, result3)
-
-    def test_room_hierarchy_cache_sharing(self) -> None:
-        """Room hierarchy responses for different users are not shared."""
-        user2 = self.register_user("user2", "pass")
-
-        # Make the room within the space invite-only.
-        self.helper.send_state(
-            self.room,
-            event_type=EventTypes.JoinRules,
-            body={"join_rule": JoinRules.INVITE},
-            tok=self.token,
-        )
-
-        # Run two `get_room_hierarchy` calls for different users up until they block.
-        deferred1 = ensureDeferred(
-            self.handler.get_room_hierarchy(self.user, self.space)
-        )
-        deferred2 = ensureDeferred(self.handler.get_room_hierarchy(user2, self.space))
-
-        # Complete the two calls.
-        result1 = self.get_success(deferred1)
-        result2 = self.get_success(deferred2)
-
-        # The `get_room_hierarchy` calls should return different results.
-        self._assert_hierarchy(result1, [(self.space, [self.room]), (self.room, ())])
-        self._assert_hierarchy(result2, [(self.space, [self.room])])
-
     def _create_room_with_join_rule(
         self, join_rule: str, room_version: Optional[str] = None, **extra_content
     ) -> str:
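
The two tests deleted above cover deduplication of concurrent in-flight `get_room_hierarchy` calls, keyed per requester so results are never shared between users. A rough sketch of that technique in plain asyncio (not the actual Twisted-based Synapse implementation; all names here are illustrative):

```python
import asyncio
from typing import Awaitable, Callable, Dict, List, Tuple


class InFlightCache:
    """Share one in-flight computation between concurrent callers with the same key."""

    def __init__(self) -> None:
        self._in_flight: Dict[Tuple[str, str], "asyncio.Task[List[str]]"] = {}

    async def get(
        self,
        requester: str,
        room_id: str,
        compute: Callable[[], Awaitable[List[str]]],
    ) -> List[str]:
        # Keying on the requester keeps responses private to each user.
        key = (requester, room_id)
        task = self._in_flight.get(key)
        if task is None:
            task = asyncio.ensure_future(compute())
            self._in_flight[key] = task
            # Evict on completion so later calls recompute, as the
            # `assertIsNot` check in the removed test expects.
            task.add_done_callback(lambda _: self._in_flight.pop(key, None))
        return await task
```

Two concurrent calls with the same key resolve to the identical object (the removed test's `assertIs`), while different requesters never share a result.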

View File

@@ -1169,14 +1169,14 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         # regardless of whether password login or SSO is allowed
         self.admin_user = self.register_user("admin", "pass", admin=True)
         self.admin_user_tok = self.get_success(
-            self.auth_handler.create_access_token_for_user_id(
+            self.auth_handler.get_access_token_for_user_id(
                 self.admin_user, device_id=None, valid_until_ms=None
             )
         )

         self.other_user = self.register_user("user", "pass", displayname="User")
         self.other_user_token = self.get_success(
-            self.auth_handler.create_access_token_for_user_id(
+            self.auth_handler.get_access_token_for_user_id(
                 self.other_user, device_id=None, valid_until_ms=None
             )
         )
@@ -3592,34 +3592,31 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
             self.other_user
         )

-    @parameterized.expand(["POST", "DELETE"])
-    def test_no_auth(self, method: str):
+    def test_no_auth(self):
         """
         Try to get information of an user without authentication.
         """
-        channel = self.make_request(method, self.url)
+        channel = self.make_request("POST", self.url)
         self.assertEqual(401, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])

-    @parameterized.expand(["POST", "DELETE"])
-    def test_requester_is_not_admin(self, method: str):
+    def test_requester_is_not_admin(self):
         """
         If the user is not a server admin, an error is returned.
         """
         other_user_token = self.login("user", "pass")

-        channel = self.make_request(method, self.url, access_token=other_user_token)
+        channel = self.make_request("POST", self.url, access_token=other_user_token)
         self.assertEqual(403, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])

-    @parameterized.expand(["POST", "DELETE"])
-    def test_user_is_not_local(self, method: str):
+    def test_user_is_not_local(self):
         """
         Tests that shadow-banning for a user that is not a local returns a 400
         """
         url = "/_synapse/admin/v1/whois/@unknown_person:unknown_domain"

-        channel = self.make_request(method, url, access_token=self.admin_user_tok)
+        channel = self.make_request("POST", url, access_token=self.admin_user_tok)
         self.assertEqual(400, channel.code, msg=channel.json_body)

     def test_success(self):
@@ -3639,17 +3636,6 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
         result = self.get_success(self.store.get_user_by_access_token(other_user_token))
         self.assertTrue(result.shadow_banned)

-        # Un-shadow-ban the user.
-        channel = self.make_request(
-            "DELETE", self.url, access_token=self.admin_user_tok
-        )
-        self.assertEqual(200, channel.code, msg=channel.json_body)
-        self.assertEqual({}, channel.json_body)
-
-        # Ensure the user is no longer shadow-banned (and the cache was cleared).
-        result = self.get_success(self.store.get_user_by_access_token(other_user_token))
-        self.assertFalse(result.shadow_banned)
-

 class RateLimitTestCase(unittest.HomeserverTestCase):
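
The `@parameterized.expand` decorators removed above come from the `parameterized` package: the test body runs once per listed value, so `POST` and `DELETE` get identical coverage without duplication. A minimal self-contained sketch of the pattern:

```python
from unittest import TestCase

from parameterized import parameterized


class ExampleTestCase(TestCase):
    @parameterized.expand(["POST", "DELETE"])
    def test_method_is_covered(self, method: str) -> None:
        # Runs twice: once with method="POST", once with method="DELETE".
        self.assertIn(method, {"POST", "DELETE"})
```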

View File

@@ -71,7 +71,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
     @override_config({"password_config": {"localdb_enabled": False}})
     def test_get_change_password_capabilities_localdb_disabled(self):
         access_token = self.get_success(
-            self.auth_handler.create_access_token_for_user_id(
+            self.auth_handler.get_access_token_for_user_id(
                 self.user, device_id=None, valid_until_ms=None
             )
         )
@@ -85,7 +85,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
     @override_config({"password_config": {"enabled": False}})
     def test_get_change_password_capabilities_password_disabled(self):
         access_token = self.get_success(
-            self.auth_handler.create_access_token_for_user_id(
+            self.auth_handler.get_access_token_for_user_id(
                 self.user, device_id=None, valid_until_ms=None
             )
         )
@@ -174,7 +174,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
     @override_config({"experimental_features": {"msc3244_enabled": False}})
     def test_get_does_not_include_msc3244_fields_when_disabled(self):
         access_token = self.get_success(
-            self.auth_handler.create_access_token_for_user_id(
+            self.auth_handler.get_access_token_for_user_id(
                 self.user, device_id=None, valid_until_ms=None
             )
         )
@@ -189,7 +189,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
     def test_get_does_include_msc3244_fields_when_enabled(self):
         access_token = self.get_success(
-            self.auth_handler.create_access_token_for_user_id(
+            self.auth_handler.get_access_token_for_user_id(
                 self.user, device_id=None, valid_until_ms=None
             )
         )

View File

@@ -28,12 +28,11 @@ from typing import (
     MutableMapping,
     Optional,
     Tuple,
-    overload,
+    Union,
 )
 from unittest.mock import patch

 import attr
-from typing_extensions import Literal

 from twisted.web.resource import Resource
 from twisted.web.server import Site
@@ -56,32 +55,6 @@ class RestHelper:
     site = attr.ib(type=Site)
     auth_user_id = attr.ib()

-    @overload
-    def create_room_as(
-        self,
-        room_creator: Optional[str] = ...,
-        is_public: Optional[bool] = ...,
-        room_version: Optional[str] = ...,
-        tok: Optional[str] = ...,
-        expect_code: Literal[200] = ...,
-        extra_content: Optional[Dict] = ...,
-        custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = ...,
-    ) -> str:
-        ...
-
-    @overload
-    def create_room_as(
-        self,
-        room_creator: Optional[str] = ...,
-        is_public: Optional[bool] = ...,
-        room_version: Optional[str] = ...,
-        tok: Optional[str] = ...,
-        expect_code: int = ...,
-        extra_content: Optional[Dict] = ...,
-        custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = ...,
-    ) -> Optional[str]:
-        ...
-
     def create_room_as(
         self,
         room_creator: Optional[str] = None,
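
The deleted `@overload` declarations tell mypy that the helper returns `str` when the expected code is the default 200 and `Optional[str]` otherwise. A condensed sketch of the same pattern with a hypothetical stand-in function:

```python
from typing import Optional, overload

from typing_extensions import Literal


@overload
def create_room(expect_code: Literal[200] = ...) -> str:
    ...


@overload
def create_room(expect_code: int = ...) -> Optional[str]:
    ...


def create_room(expect_code: int = 200) -> Optional[str]:
    # Callers using the default get a plain `str` from mypy's point of view;
    # anyone passing a failure code must handle `None`.
    return "!room:example.org" if expect_code == 200 else None
```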
@@ -91,7 +64,7 @@
         expect_code: int = 200,
         extra_content: Optional[Dict] = None,
         custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
-    ) -> Optional[str]:
+    ) -> str:
         """
         Create a room.
@@ -134,8 +107,6 @@
         if expect_code == 200:
             return channel.json_body["room_id"]
-        else:
-            return None

     def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None):
         self.change_membership(
@@ -205,7 +176,7 @@
         extra_data: Optional[dict] = None,
         tok: Optional[str] = None,
         expect_code: int = 200,
-        expect_errcode: Optional[str] = None,
+        expect_errcode: str = None,
     ) -> None:
         """
         Send a membership state event into a room.
@@ -289,7 +260,9 @@
         txn_id=None,
         tok=None,
         expect_code=200,
-        custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
+        custom_headers: Optional[
+            Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
+        ] = None,
     ):
         if txn_id is None:
             txn_id = "m%s" % (str(time.time()))
@@ -536,7 +509,7 @@
             went.
         """
-        cookies: Dict[str, str] = {}
+        cookies = {}

         # if we're doing a ui auth, hit the ui auth redirect endpoint
         if ui_auth_session_id:
@@ -658,13 +631,7 @@
         # hit the redirect url again with the right Host header, which should now issue
         # a cookie and redirect to the SSO provider.
-        def get_location(channel: FakeChannel) -> str:
-            location_values = channel.headers.getRawHeaders("Location")
-            # Keep mypy happy by asserting that location_values is nonempty
-            assert location_values
-            return location_values[0]
-
-        location = get_location(channel)
+        location = channel.headers.getRawHeaders("Location")[0]
         parts = urllib.parse.urlsplit(location)
         channel = make_request(
             self.hs.get_reactor(),
@@ -678,7 +645,7 @@
         assert channel.code == 302
         channel.extract_cookies(cookies)
-        return get_location(channel)
+        return channel.headers.getRawHeaders("Location")[0]

     def initiate_sso_ui_auth(
         self, ui_auth_session_id: str, cookies: MutableMapping[str, str]
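
The deleted `get_location` helper exists because Twisted types `Headers.getRawHeaders` as returning an optional list, so indexing the result directly fails strict mypy checks. A small sketch of the narrowing trick:

```python
from typing import List, Optional


def first_location(raw_headers: Optional[List[str]]) -> str:
    # The assert narrows Optional[List[str]] to List[str] for mypy
    # (and fails loudly if the Location header is genuinely missing).
    assert raw_headers
    return raw_headers[0]
```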

View File

@@ -24,7 +24,6 @@ from typing import (
     MutableMapping,
     Optional,
     Tuple,
-    Type,
     Union,
 )
@@ -227,7 +226,7 @@ def make_request(
     path: Union[bytes, str],
     content: Union[bytes, str, JsonDict] = b"",
     access_token: Optional[str] = None,
-    request: Type[Request] = SynapseRequest,
+    request: Request = SynapseRequest,
     shorthand: bool = True,
     federation_auth_origin: Optional[bytes] = None,
     content_is_form: bool = False,
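
The `Type[Request]` annotation reverted here marks the parameter as a class to instantiate, not an instance; `request: Request` would describe an already-constructed object. A minimal illustration with stand-in classes:

```python
from typing import Type


class Request:
    pass


class SynapseRequest(Request):
    pass


def make_request(request: Type[Request] = SynapseRequest) -> Request:
    # `request` is the class itself, so it can be instantiated here.
    return request()
```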

View File

@@ -11,9 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import List
-from unittest import mock
-
 from synapse.app.generic_worker import GenericWorkerServer
 from synapse.storage.database import LoggingDatabaseConnection
 from synapse.storage.prepare_database import PrepareDatabaseException, prepare_database
@@ -22,22 +19,6 @@ from synapse.storage.schema import SCHEMA_VERSION

 from tests.unittest import HomeserverTestCase


-def fake_listdir(filepath: str) -> List[str]:
-    """
-    A fake implementation of os.listdir which we can use to mock out the filesystem.
-
-    Args:
-        filepath: The directory to list files for.
-
-    Returns:
-        A list of files and folders in the directory.
-    """
-    if filepath.endswith("full_schemas"):
-        return [str(SCHEMA_VERSION)]
-
-    return ["99_add_unicorn_to_database.sql"]
-
-
 class WorkerSchemaTests(HomeserverTestCase):
     def make_homeserver(self, reactor, clock):
         hs = self.setup_test_homeserver(
@@ -70,7 +51,7 @@ class WorkerSchemaTests(HomeserverTestCase):
             prepare_database(db_conn, db_pool.engine, self.hs.config)

-    def test_not_upgraded_old_schema_version(self):
+    def test_not_upgraded(self):
         """Test that workers don't start if the DB has an older schema version"""
         db_pool = self.hs.get_datastore().db_pool
         db_conn = LoggingDatabaseConnection(
@@ -86,34 +67,3 @@ class WorkerSchemaTests(HomeserverTestCase):

         with self.assertRaises(PrepareDatabaseException):
             prepare_database(db_conn, db_pool.engine, self.hs.config)
-
-    def test_not_upgraded_current_schema_version_with_outstanding_deltas(self):
-        """
-        Test that workers don't start if the DB is on the current schema version,
-        but there are still outstanding delta migrations to run.
-        """
-        db_pool = self.hs.get_datastore().db_pool
-        db_conn = LoggingDatabaseConnection(
-            db_pool._db_pool.connect(),
-            db_pool.engine,
-            "tests",
-        )
-
-        # Set the schema version of the database to the current version
-        cur = db_conn.cursor()
-        cur.execute("UPDATE schema_version SET version = ?", (SCHEMA_VERSION,))
-
-        db_conn.commit()
-
-        # Path `os.listdir` here to make synapse think that there is a migration
-        # file ready to be run.
-        # Note that we can't patch this function for the whole method, else Synapse
-        # will try to find the file when building the database initially.
-        with mock.patch("os.listdir", mock.Mock(side_effect=fake_listdir)):
-            with self.assertRaises(PrepareDatabaseException):
-                # Synapse should think that there is an outstanding migration file due to
-                # patching 'os.listdir' in the function decorator.
-                #
-                # We expect Synapse to raise an exception to indicate the master process
-                # needs to apply this migration file.
-                prepare_database(db_conn, db_pool.engine, self.hs.config)
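
The deleted test patches `os.listdir` so `prepare_database` believes an unapplied delta file exists, without touching the real filesystem. A self-contained sketch of that mocking trick (the filename is the illustrative one from the removed helper):

```python
import os
from typing import List
from unittest import mock


def fake_listdir(filepath: str) -> List[str]:
    # Pretend every directory holds one outstanding migration file.
    return ["99_add_unicorn_to_database.sql"]


with mock.patch("os.listdir", mock.Mock(side_effect=fake_listdir)):
    assert os.listdir("/any/schema/dir") == ["99_add_unicorn_to_database.sql"]
```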

View File

@@ -44,7 +44,6 @@ from twisted.python.threadpool import ThreadPool
 from twisted.test.proto_helpers import MemoryReactor
 from twisted.trial import unittest
 from twisted.web.resource import Resource
-from twisted.web.server import Request

 from synapse import events
 from synapse.api.constants import EventTypes, Membership
@@ -96,13 +95,16 @@ def around(target):
     return _around


+T = TypeVar("T")
+
+
 class TestCase(unittest.TestCase):
     """A subclass of twisted.trial's TestCase which looks for 'loglevel'
     attributes on both itself and its individual test methods, to override the
     root logger's logging level while that test (case|method) runs."""

-    def __init__(self, methodName: str):
-        super().__init__(methodName)
+    def __init__(self, methodName, *args, **kwargs):
+        super().__init__(methodName, *args, **kwargs)

         method = getattr(self, methodName)
@@ -218,16 +220,16 @@ class HomeserverTestCase(TestCase):
     Attributes:
         servlets: List of servlet registration function.
         user_id (str): The user ID to assume if auth is hijacked.
-        hijack_auth: Whether to hijack auth to return the user specified
+        hijack_auth (bool): Whether to hijack auth to return the user specified
         in user_id.
     """

-    hijack_auth: ClassVar[bool] = True
-    needs_threadpool: ClassVar[bool] = False
+    hijack_auth = True
+    needs_threadpool = False
     servlets: ClassVar[List[RegisterServletsFunc]] = []

-    def __init__(self, methodName: str):
-        super().__init__(methodName)
+    def __init__(self, methodName, *args, **kwargs):
+        super().__init__(methodName, *args, **kwargs)

         # see if we have any additional config for this test
         method = getattr(self, methodName)
@@ -299,10 +301,9 @@
                 None,
             )

-            # Type ignore: mypy doesn't like us assigning to methods.
-            self.hs.get_auth().get_user_by_req = get_user_by_req  # type: ignore[assignment]
-            self.hs.get_auth().get_user_by_access_token = get_user_by_access_token  # type: ignore[assignment]
-            self.hs.get_auth().get_access_token_from_request = Mock(  # type: ignore[assignment]
+            self.hs.get_auth().get_user_by_req = get_user_by_req
+            self.hs.get_auth().get_user_by_access_token = get_user_by_access_token
+            self.hs.get_auth().get_access_token_from_request = Mock(
                 return_value="1234"
             )
@@ -416,7 +417,7 @@
         path: Union[bytes, str],
         content: Union[bytes, str, JsonDict] = b"",
         access_token: Optional[str] = None,
-        request: Type[Request] = SynapseRequest,
+        request: Type[T] = SynapseRequest,
         shorthand: bool = True,
         federation_auth_origin: Optional[bytes] = None,
         content_is_form: bool = False,
@@ -595,7 +596,7 @@
             nonce_str += b"\x00notadmin"

         want_mac.update(nonce.encode("ascii") + b"\x00" + nonce_str)
-        want_mac_digest = want_mac.hexdigest()
+        want_mac = want_mac.hexdigest()

         body = json.dumps(
             {
@@ -604,7 +605,7 @@
                 "displayname": displayname,
                 "password": password,
                 "admin": admin,
-                "mac": want_mac_digest,
+                "mac": want_mac,
                 "inhibit_login": True,
             }
         )
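
The rename reverted in the last two hunks avoids rebinding `want_mac`, an `hmac.HMAC` object, to the `str` returned by `hexdigest()`, which mypy flags as changing a variable's type. A sketch mirroring the shape of the registration-MAC computation (an assumption about the surrounding helper, not the exact Synapse wire format):

```python
import hashlib
import hmac


def registration_mac(shared_secret: bytes, nonce: str, user: str, password: str) -> str:
    want_mac = hmac.new(key=shared_secret, digestmod=hashlib.sha1)
    nonce_str = b"\x00".join(
        [user.encode("utf8"), password.encode("utf8"), b"notadmin"]
    )
    want_mac.update(nonce.encode("ascii") + b"\x00" + nonce_str)
    # Bind the digest to a new name so `want_mac` keeps a single type.
    want_mac_digest = want_mac.hexdigest()
    return want_mac_digest
```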