1
0

Compare commits

..

13 Commits

Author SHA1 Message Date
Erik Johnston 5de571987e WIP: Type TreeCache 2022-07-18 13:05:07 +01:00
Erik Johnston 129691f190 Comment TreeCache 2022-07-18 10:49:12 +01:00
Erik Johnston 462db2a171 Comment LruCache 2022-07-18 10:48:24 +01:00
Erik Johnston 7f7b36d56d Comment LruCache 2022-07-18 10:40:36 +01:00
Erik Johnston 057ae8b61c Comments 2022-07-17 11:34:34 +01:00
Erik Johnston 7aceec3ed9 Fix up 2022-07-15 16:52:16 +01:00
Erik Johnston cad555f07c Better stuff 2022-07-15 16:31:53 +01:00
Erik Johnston 23c2f394a5 Fix mypy 2022-07-15 16:30:56 +01:00
Erik Johnston 602a81f5a2 don't update access 2022-07-15 16:24:10 +01:00
Erik Johnston f046366d2a Fix test 2022-07-15 15:54:01 +01:00
Erik Johnston a22716c5c5 Fix literal 2022-07-15 15:31:57 +01:00
Erik Johnston 40a8fba5f6 Newsfile 2022-07-15 15:27:03 +01:00
Erik Johnston 326a175987 Make DictionaryCache have better expiry properties 2022-07-15 15:26:02 +01:00
104 changed files with 983 additions and 1696 deletions
+1 -1
View File
@@ -69,7 +69,7 @@ with open('pyproject.toml', 'w') as f:
"
python3 -c "$REMOVE_DEV_DEPENDENCIES"
pipx install poetry==1.1.14
pipx install poetry==1.1.12
~/.local/bin/poetry lock
echo "::group::Patched pyproject.toml"
-13
View File
@@ -1,16 +1,3 @@
# Commits in this file will be removed from GitHub blame results.
#
# To use this file locally, use:
# git blame --ignore-revs-file="path/to/.git-blame-ignore-revs" <files>
#
# or configure the `blame.ignoreRevsFile` option in your git config.
#
# If ignoring a pull request that was not squash merged, only the merge
# commit needs to be put here. Child commits will be resolved from it.
# Run black (#3679).
8b3d9b6b199abb87246f982d5db356f1966db925
# Black reformatting (#5482).
32e7c9e7f20b57dd081023ac42d6931a8da9b3a3
+2 -2
View File
@@ -127,12 +127,12 @@ jobs:
run: |
set -x
DEBIAN_FRONTEND=noninteractive sudo apt-get install -yqq python3 pipx
pipx install poetry==1.1.14
pipx install poetry==1.1.12
poetry remove -n twisted
poetry add -n --extras tls git+https://github.com/twisted/twisted.git#trunk
poetry lock --no-update
# NOT IN 1.1.14 poetry lock --check
# NOT IN 1.1.12 poetry lock --check
working-directory: synapse
- run: |
+2 -11
View File
@@ -3,15 +3,6 @@ Synapse vNext
As of this release, Synapse no longer allows the tasks of verifying email address ownership, and password reset confirmation, to be delegated to an identity server. For more information, see the [upgrade notes](https://matrix-org.github.io/synapse/v1.64/upgrade.html#upgrading-to-v1640).
Synapse 1.63.0 (2022-07-19)
===========================
Improved Documentation
----------------------
- Clarify that homeserver server names are included in the reported data when the `report_stats` config option is enabled. ([\#13321](https://github.com/matrix-org/synapse/issues/13321))
Synapse 1.63.0rc1 (2022-07-12)
==============================
@@ -20,7 +11,7 @@ Features
- Add a rate limit for local users sending invites. ([\#13125](https://github.com/matrix-org/synapse/issues/13125))
- Implement [MSC3827](https://github.com/matrix-org/matrix-spec-proposals/pull/3827): Filtering of `/publicRooms` by room type. ([\#13031](https://github.com/matrix-org/synapse/issues/13031))
- Improve validation logic in the account data REST endpoints. ([\#13148](https://github.com/matrix-org/synapse/issues/13148))
- Improve validation logic in Synapse's REST endpoints. ([\#13148](https://github.com/matrix-org/synapse/issues/13148))
Bugfixes
@@ -48,7 +39,7 @@ Improved Documentation
- Add an explanation of the `--report-stats` argument to the docs. ([\#13029](https://github.com/matrix-org/synapse/issues/13029))
- Add a helpful example bash script to the contrib directory for creating multiple worker configuration files of the same type. Contributed by @villepeh. ([\#13032](https://github.com/matrix-org/synapse/issues/13032))
- Add missing links to config options. ([\#13166](https://github.com/matrix-org/synapse/issues/13166))
- Add documentation for homeserver usage statistics collection. ([\#13086](https://github.com/matrix-org/synapse/issues/13086))
- Add documentation for anonymised homeserver statistics collection. ([\#13086](https://github.com/matrix-org/synapse/issues/13086))
- Add documentation for the existing `databases` option in the homeserver configuration manual. ([\#13212](https://github.com/matrix-org/synapse/issues/13212))
- Clean up references to sample configuration and redirect users to the configuration manual instead. ([\#13077](https://github.com/matrix-org/synapse/issues/13077), [\#13139](https://github.com/matrix-org/synapse/issues/13139))
- Document how the Synapse team does reviews. ([\#13132](https://github.com/matrix-org/synapse/issues/13132))
-1
View File
@@ -1 +0,0 @@
Use lower isolation level when purging rooms to avoid serialization errors. Contributed by Nick @ Beeper.
-1
View File
@@ -1 +0,0 @@
Provide more info why we don't have any thumbnails to serve.
-1
View File
@@ -1 +0,0 @@
Always use a version of canonicaljson that supports the C implementation of frozendict.
-1
View File
@@ -1 +0,0 @@
Fix spurious warning when fetching state after a missing prev event.
-1
View File
@@ -1 +0,0 @@
Add another `contrib` script to help set up worker processes. Contributed by @villepeh.
-1
View File
@@ -1 +0,0 @@
Add per-room rate limiting for room joins. For each room, Synapse now monitors the rate of join events in that room, and throttle additional joins if that rate grows too large.
-1
View File
@@ -1 +0,0 @@
Don't pull out the full state when creating an event.
-1
View File
@@ -1 +0,0 @@
Upgrade from Poetry 1.1.14 to 1.1.12, to fix bugs when locking packages.
+1
View File
@@ -0,0 +1 @@
Make `DictionaryCache` expire full entries if they haven't been queried in a while, even if specific keys have been queried recently.
-1
View File
@@ -1 +0,0 @@
Fix a bug introduced in v1.18.0 where the `synapse_pushers` metric would overcount pushers when they are replaced.
-1
View File
@@ -1 +0,0 @@
Use `HTTPStatus` constants in place of literals in tests.
-1
View File
@@ -1 +0,0 @@
Improve performance of query `_get_subset_users_in_room_with_profiles`.
-1
View File
@@ -1 +0,0 @@
Up batch size of `bulk_get_push_rules` and `_get_joined_profiles_from_event_ids`.
-1
View File
@@ -1 +0,0 @@
Remove unnecessary `json.dumps` from tests.
-1
View File
@@ -1 +0,0 @@
Don't pull out the full state when creating an event.
-1
View File
@@ -1 +0,0 @@
Use an asynchronous cache wrapper for the get event cache. Contributed by Nick @ Beeper (@fizzadar).
-1
View File
@@ -1 +0,0 @@
Reduce memory usage of sending dummy events.
-1
View File
@@ -1 +0,0 @@
Prevent formatting changes of [#3679](https://github.com/matrix-org/synapse/pull/3679) from appearing in `git blame`.
-1
View File
@@ -1 +0,0 @@
Add notes when config options where changed. Contributed by @behrmann.
-1
View File
@@ -1 +0,0 @@
Reduce memory usage of state caches.
-1
View File
@@ -1 +0,0 @@
Stop builindg `.deb` packages for Ubuntu 21.10 (Impish Indri), which has reached end of life.
-1
View File
@@ -1 +0,0 @@
Add type hints to `trace` decorator.
-1
View File
@@ -1 +0,0 @@
Document the new `rc_invites.per_issuer` throttling option added in Synapse 1.63.
@@ -1,145 +0,0 @@
# Creating multiple stream writers with a bash script
This script creates multiple [stream writer](https://github.com/matrix-org/synapse/blob/develop/docs/workers.md#stream-writers) workers.
Stream writers require both replication and HTTP listeners.
It also prints out the example lines for Synapse main configuration file.
Remember to route necessary endpoints directly to a worker associated with it.
If you run the script as-is, it will create workers with the replication listener starting from port 8034 and another, regular http listener starting from 8044. If you don't need all of the stream writers listed in the script, just remove them from the ```STREAM_WRITERS``` array.
```sh
#!/bin/bash
# Start with these replication and http ports.
# The script loop starts with the exact port and then increments it by one.
REP_START_PORT=8034
HTTP_START_PORT=8044
# Stream writer workers to generate. Feel free to add or remove them as you wish.
# Event persister ("events") isn't included here as it does not require its
# own HTTP listener.
STREAM_WRITERS+=( "presence" "typing" "receipts" "to_device" "account_data" )
NUM_WRITERS=$(expr ${#STREAM_WRITERS[@]})
i=0
while [ $i -lt "$NUM_WRITERS" ]
do
cat << EOF > ${STREAM_WRITERS[$i]}_stream_writer.yaml
worker_app: synapse.app.generic_worker
worker_name: ${STREAM_WRITERS[$i]}_stream_writer
# The replication listener on the main synapse process.
worker_replication_host: 127.0.0.1
worker_replication_http_port: 9093
worker_listeners:
- type: http
port: $(expr $REP_START_PORT + $i)
resources:
- names: [replication]
- type: http
port: $(expr $HTTP_START_PORT + $i)
resources:
- names: [client]
worker_log_config: /etc/matrix-synapse/stream-writer-log.yaml
EOF
HOMESERVER_YAML_INSTANCE_MAP+=$" ${STREAM_WRITERS[$i]}_stream_writer:
host: 127.0.0.1
port: $(expr $REP_START_PORT + $i)
"
HOMESERVER_YAML_STREAM_WRITERS+=$" ${STREAM_WRITERS[$i]}: ${STREAM_WRITERS[$i]}_stream_writer
"
((i++))
done
cat << EXAMPLECONFIG
# Add these lines to your homeserver.yaml.
# Don't forget to configure your reverse proxy and
# necessary endpoints to their respective worker.
# See https://github.com/matrix-org/synapse/blob/develop/docs/workers.md
# for more information.
# Remember: Under NO circumstances should the replication
# listener be exposed to the public internet;
# it has no authentication and is unencrypted.
instance_map:
$HOMESERVER_YAML_INSTANCE_MAP
stream_writers:
$HOMESERVER_YAML_STREAM_WRITERS
EXAMPLECONFIG
```
Copy the code above save it to a file ```create_stream_writers.sh``` (for example).
Make the script executable by running ```chmod +x create_stream_writers.sh```.
## Run the script to create workers and print out a sample configuration
Simply run the script to create YAML files in the current folder and print out the required configuration for ```homeserver.yaml```.
```console
$ ./create_stream_writers.sh
# Add these lines to your homeserver.yaml.
# Don't forget to configure your reverse proxy and
# necessary endpoints to their respective worker.
# See https://github.com/matrix-org/synapse/blob/develop/docs/workers.md
# for more information
# Remember: Under NO circumstances should the replication
# listener be exposed to the public internet;
# it has no authentication and is unencrypted.
instance_map:
presence_stream_writer:
host: 127.0.0.1
port: 8034
typing_stream_writer:
host: 127.0.0.1
port: 8035
receipts_stream_writer:
host: 127.0.0.1
port: 8036
to_device_stream_writer:
host: 127.0.0.1
port: 8037
account_data_stream_writer:
host: 127.0.0.1
port: 8038
stream_writers:
presence: presence_stream_writer
typing: typing_stream_writer
receipts: receipts_stream_writer
to_device: to_device_stream_writer
account_data: account_data_stream_writer
```
Simply copy-and-paste the output to an appropriate place in your Synapse main configuration file.
## Write directly to Synapse configuration file
You could also write the output directly to homeserver main configuration file. **This, however, is not recommended** as even a small typo (such as replacing >> with >) can erase the entire ```homeserver.yaml```.
If you do this, back up your original configuration file first:
```console
# Back up homeserver.yaml first
cp /etc/matrix-synapse/homeserver.yaml /etc/matrix-synapse/homeserver.yaml.bak
# Create workers and write output to your homeserver.yaml
./create_stream_writers.sh >> /etc/matrix-synapse/homeserver.yaml
```
@@ -1,4 +1,4 @@
# Creating multiple generic workers with a bash script
# Creating multiple workers with a bash script
Setting up multiple worker configuration files manually can be time-consuming.
You can alternatively create multiple worker configuration files with a simple `bash` script. For example:
-8
View File
@@ -1,11 +1,3 @@
matrix-synapse-py3 (1.63.0) stable; urgency=medium
* Clarify that homeserver server names are included in the data reported
by opt-in server stats reporting (`report_stats` homeserver config option).
* New Synapse release 1.63.0.
-- Synapse Packaging team <packages@matrix.org> Tue, 19 Jul 2022 14:42:24 +0200
matrix-synapse-py3 (1.63.0~rc1) stable; urgency=medium
* New Synapse release 1.63.0rc1.
+1 -1
View File
@@ -31,7 +31,7 @@ EOF
# This file is autogenerated, and will be recreated on upgrade if it is deleted.
# Any changes you make will be preserved.
# Whether to report homeserver usage statistics.
# Whether to report anonymized homeserver usage statistics.
report_stats: false
EOF
fi
+6 -6
View File
@@ -37,7 +37,7 @@ msgstr ""
#. Type: boolean
#. Description
#: ../templates:2001
msgid "Report homeserver usage statistics?"
msgid "Report anonymous statistics?"
msgstr ""
#. Type: boolean
@@ -45,11 +45,11 @@ msgstr ""
#: ../templates:2001
msgid ""
"Developers of Matrix and Synapse really appreciate helping the project out "
"by reporting homeserver usage statistics from this homeserver. Your "
"homeserver's server name, along with very basic aggregate data (e.g. "
"number of users) will be reported. But it helps track the growth of the "
"Matrix community, and helps in making Matrix a success, as well as to "
"convince other networks that they should peer with Matrix."
"by reporting anonymized usage statistics from this homeserver. Only very "
"basic aggregate data (e.g. number of users) will be reported, but it helps "
"track the growth of the Matrix community, and helps in making Matrix a "
"success, as well as to convince other networks that they should peer with "
"Matrix."
msgstr ""
#. Type: boolean
+6 -7
View File
@@ -10,13 +10,12 @@ _Description: Name of the server:
Template: matrix-synapse/report-stats
Type: boolean
Default: false
_Description: Report homeserver usage statistics?
_Description: Report anonymous statistics?
Developers of Matrix and Synapse really appreciate helping the
project out by reporting homeserver usage statistics from this
homeserver. Your homeserver's server name, along with very basic
aggregate data (e.g. number of users) will be reported. But it
helps track the growth of the Matrix community, and helps in
making Matrix a success, as well as to convince other networks
that they should peer with Matrix.
project out by reporting anonymized usage statistics from this
homeserver. Only very basic aggregate data (e.g. number of users)
will be reported, but it helps track the growth of the Matrix
community, and helps in making Matrix a success, as well as to
convince other networks that they should peer with Matrix.
.
Thank you.
+1 -1
View File
@@ -45,7 +45,7 @@ RUN \
# We install poetry in its own build stage to avoid its dependencies conflicting with
# synapse's dependencies.
# We use a specific commit from poetry's master branch instead of our usual 1.1.14,
# We use a specific commit from poetry's master branch instead of our usual 1.1.12,
# to incorporate fixes to some bugs in `poetry export`. This commit corresponds to
# https://github.com/python-poetry/poetry/pull/5156 and
# https://github.com/python-poetry/poetry/issues/5141 ;
@@ -67,10 +67,6 @@ rc_joins:
per_second: 9999
burst_count: 9999
rc_joins_per_room:
per_second: 9999
burst_count: 9999
rc_3pid_validation:
per_second: 1000
burst_count: 1000
+1 -1
View File
@@ -68,7 +68,7 @@
- [Federation](usage/administration/admin_api/federation.md)
- [Manhole](manhole.md)
- [Monitoring](metrics-howto.md)
- [Reporting Homeserver Usage Statistics](usage/administration/monitoring/reporting_homeserver_usage_statistics.md)
- [Reporting Anonymised Statistics](usage/administration/monitoring/reporting_anonymised_statistics.md)
- [Understanding Synapse Through Grafana Graphs](usage/administration/understanding_synapse_through_grafana_graphs.md)
- [Useful SQL for Admins](usage/administration/useful_sql_for_admins.md)
- [Database Maintenance Tools](usage/administration/database_maintenance_tools.md)
-25
View File
@@ -237,28 +237,3 @@ poetry run pip install build && poetry run python -m build
because [`build`](https://github.com/pypa/build) is a standardish tool which
doesn't require poetry. (It's what we use in CI too). However, you could try
`poetry build` too.
# Troubleshooting
## Check the version of poetry with `poetry --version`.
At the time of writing, the 1.2 series is beta only. We have seen some examples
where the lockfiles generated by 1.2 prereleasese aren't interpreted correctly
by poetry 1.1.x. For now, use poetry 1.1.14, which includes a critical
[change](https://github.com/python-poetry/poetry/pull/5973) needed to remain
[compatible with PyPI](https://github.com/pypi/warehouse/pull/11775).
It can also be useful to check the version of `poetry-core` in use. If you've
installed `poetry` with `pipx`, try `pipx runpip poetry list | grep poetry-core`.
## Clear caches: `poetry cache clear --all pypi`.
Poetry caches a bunch of information about packages that isn't readily available
from PyPI. (This is what makes poetry seem slow when doing the first
`poetry install`.) Try `poetry cache list` and `poetry cache clear --all
<name of cache>` to see if that fixes things.
## Try `--verbose` or `--dry-run` arguments.
Sometimes useful to see what poetry's internal logic is.
-10
View File
@@ -104,16 +104,6 @@ minimum, a `notif_from` setting.)
Specifying an `email` setting under `account_threepid_delegates` will now cause
an error at startup.
## Changes to the event replication streams
Synapse now includes a flag indicating if an event is an outlier when
replicating it to other workers. This is a forwards- and backwards-incompatible
change: v1.63 and workers cannot process events replicated by v1.64 workers, and
vice versa.
Once all workers are upgraded to v1.64 (or downgraded to v1.63), event
replication will resume as normal.
# Upgrading to v1.62.0
## New signatures for spam checker callbacks
@@ -1,11 +1,11 @@
# Reporting Homeserver Usage Statistics
# Reporting Anonymised Statistics
When generating your Synapse configuration file, you are asked whether you
would like to report usage statistics to Matrix.org. These statistics
would like to report anonymised statistics to Matrix.org. These statistics
provide the foundation a glimpse into the number of Synapse homeservers
participating in the network, as well as statistics such as the number of
rooms being created and messages being sent. This feature is sometimes
affectionately called "phone home" stats. Reporting
affectionately called "phone-home" stats. Reporting
[is optional](../../configuration/config_documentation.md#report_stats)
and the reporting endpoint
[can be configured](../../configuration/config_documentation.md#report_stats_endpoint),
@@ -21,9 +21,9 @@ The following statistics are sent to the configured reporting endpoint:
| Statistic Name | Type | Description |
|----------------------------|--------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `homeserver` | string | The homeserver's server name. |
| `memory_rss` | int | The memory usage of the process (in kilobytes on Unix-based systems, bytes on MacOS). |
| `cpu_average` | int | CPU time in % of a single core (not % of all cores). |
| `homeserver` | string | The homeserver's server name. |
| `server_context` | string | An arbitrary string used to group statistics from a set of homeservers. |
| `timestamp` | int | The current time, represented as the number of seconds since the epoch. |
| `uptime_seconds` | int | The number of seconds since the homeserver was last started. |
@@ -239,8 +239,6 @@ If this option is provided, it parses the given yaml to json and
serves it on `/.well-known/matrix/client` endpoint
alongside the standard properties.
*Added in Synapse 1.62.0.*
Example configuration:
```yaml
extra_well_known_client_content :
@@ -1157,9 +1155,6 @@ Caching can be configured through the following sub-options:
with intermittent connections, at the cost of higher memory usage.
A value of zero means that sync responses are not cached.
Defaults to 2m.
*Changed in Synapse 1.62.0*: The default was changed from 0 to 2m.
* `cache_autotuning` and its sub-options `max_cache_memory_usage`, `target_cache_memory_usage`, and
`min_cache_ttl` work in conjunction with each other to maintain a balance between cache memory
usage and cache entry availability. You must be using [jemalloc](https://github.com/matrix-org/synapse#help-synapse-is-slow-and-eats-all-my-ramcpu)
@@ -1476,25 +1471,6 @@ rc_joins:
per_second: 0.03
burst_count: 12
```
---
### `rc_joins_per_room`
This option allows admins to ratelimit joins to a room based on the number of recent
joins (local or remote) to that room. It is intended to mitigate mass-join spam
waves which target multiple homeservers.
By default, one join is permitted to a room every second, with an accumulating
buffer of up to ten instantaneous joins.
Example configuration (default values):
```yaml
rc_joins_per_room:
per_second: 1
burst_count: 10
```
_Added in Synapse 1.64.0._
---
### `rc_3pid_validation`
@@ -1528,8 +1504,6 @@ The `rc_invites.per_user` limit applies to the *receiver* of the invite, rather
sender, meaning that a `rc_invite.per_user.burst_count` of 5 mandates that a single user
cannot *receive* more than a burst of 5 invites at a time.
In contrast, the `rc_invites.per_issuer` limit applies to the *issuer* of the invite, meaning that a `rc_invite.per_issuer.burst_count` of 5 mandates that single user cannot *send* more than a burst of 5 invites at a time.
Example configuration:
```yaml
rc_invites:
@@ -1539,13 +1513,7 @@ rc_invites:
per_user:
per_second: 0.004
burst_count: 3
per_issuer:
per_second: 0.5
burst_count: 5
```
_Changed in version 1.63:_ added the `per_issuer` limit.
---
### `rc_third_party_invite`
@@ -2437,14 +2405,9 @@ metrics_flags:
---
### `report_stats`
Whether or not to report homeserver usage statistics. This is originally
Whether or not to report anonymized homeserver usage statistics. This is originally
set when generating the config. Set this option to true or false to change the current
behavior. See
[Reporting Homeserver Usage Statistics](../administration/monitoring/reporting_homeserver_usage_statistics.md)
for information on what data is reported.
Statistics will be reported 5 minutes after Synapse starts, and then every 3 hours
after that.
behavior.
Example configuration:
```yaml
@@ -2453,7 +2416,7 @@ report_stats: true
---
### `report_stats_endpoint`
The endpoint to report homeserver usage statistics to.
The endpoint to report the anonymized homeserver usage statistics to.
Defaults to https://matrix.org/report-usage-stats/push
Example configuration:
Generated
+1 -1
View File
@@ -1563,7 +1563,7 @@ url_preview = ["lxml"]
[metadata]
lock-version = "1.1"
python-versions = "^3.7.1"
content-hash = "c24bbcee7e86dbbe7cdbf49f91a25b310bf21095452641e7440129f59b077f78"
content-hash = "e96625923122e29b6ea5964379828e321b6cede2b020fc32c6f86c09d86d1ae8"
[metadata.files]
attrs = [
+2 -4
View File
@@ -54,7 +54,7 @@ skip_gitignore = true
[tool.poetry]
name = "matrix-synapse"
version = "1.63.0"
version = "1.63.0rc1"
description = "Homeserver for the Matrix decentralised comms protocol"
authors = ["Matrix.org Team and Contributors <packages@matrix.org>"]
license = "Apache-2.0"
@@ -110,9 +110,7 @@ jsonschema = ">=3.0.0"
frozendict = ">=1,!=2.1.2"
# We require 2.1.0 or higher for type hints. Previous guard was >= 1.1.0
unpaddedbase64 = ">=2.1.0"
# We require 1.5.0 to work around an issue when running against the C implementation of
# frozendict: https://github.com/matrix-org/python-canonicaljson/issues/36
canonicaljson = "^1.5.0"
canonicaljson = "^1.4.0"
# we use the type definitions added in signedjson 1.1.
signedjson = "^1.1.0"
# validating SSL certs for IP addresses requires service_identity 18.1.
+1
View File
@@ -26,6 +26,7 @@ DISTS = (
"debian:bookworm",
"debian:sid",
"ubuntu:focal", # 20.04 LTS (our EOL forced by Py38 on 2024-10-14)
"ubuntu:impish", # 21.10 (EOL 2022-07)
"ubuntu:jammy", # 22.04 LTS (EOL 2027-04)
)
+1 -1
View File
@@ -33,7 +33,7 @@ def main() -> None:
parser.add_argument(
"--report-stats",
action="store",
help="Whether the generated config reports homeserver usage statistics",
help="Whether the generated config reports anonymized usage statistics",
choices=["yes", "no"],
)
+7 -7
View File
@@ -97,16 +97,16 @@ def format_config_error(e: ConfigError) -> Iterator[str]:
# We split these messages out to allow packages to override with package
# specific instructions.
MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS = """\
Please opt in or out of reporting homeserver usage statistics, by setting
the `report_stats` key in your config file to either True or False.
Please opt in or out of reporting anonymized homeserver usage statistics, by
setting the `report_stats` key in your config file to either True or False.
"""
MISSING_REPORT_STATS_SPIEL = """\
We would really appreciate it if you could help our project out by reporting
homeserver usage statistics from your homeserver. Your homeserver's server name,
along with very basic aggregate data (e.g. number of users) will be reported. But
it helps us to track the growth of the Matrix community, and helps us to make Matrix
a success, as well as to convince other networks that they should peer with us.
anonymized usage statistics from your homeserver. Only very basic aggregate
data (e.g. number of users) will be reported, but it helps us to track the
growth of the Matrix community, and helps us to make Matrix a success, as well
as to convince other networks that they should peer with us.
Thank you.
"""
@@ -621,7 +621,7 @@ class RootConfig:
generate_group.add_argument(
"--report-stats",
action="store",
help="Whether the generated config reports homeserver usage statistics.",
help="Whether the generated config reports anonymized usage statistics.",
choices=["yes", "no"],
)
generate_group.add_argument(
-7
View File
@@ -112,13 +112,6 @@ class RatelimitConfig(Config):
defaults={"per_second": 0.01, "burst_count": 10},
)
# Track the rate of joins to a given room. If there are too many, temporarily
# prevent local joins and remote joins via this server.
self.rc_joins_per_room = RateLimitConfig(
config.get("rc_joins_per_room", {}),
defaults={"per_second": 1, "burst_count": 10},
)
# Ratelimit cross-user key requests:
# * For local requests this is keyed by the sending device.
# * For requests received over federation this is keyed by the origin.
+7 -28
View File
@@ -42,18 +42,6 @@ THUMBNAIL_SIZE_YAML = """\
# method: %(method)s
"""
# A map from the given media type to the type of thumbnail we should generate
# for it.
THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP = {
"image/jpeg": "jpeg",
"image/jpg": "jpeg",
"image/webp": "jpeg",
# Thumbnails can only be jpeg or png. We choose png thumbnails for gif
# because it can have transparency.
"image/gif": "png",
"image/png": "png",
}
HTTP_PROXY_SET_WARNING = """\
The Synapse config url_preview_ip_range_blacklist will be ignored as an HTTP(s) proxy is configured."""
@@ -91,22 +79,13 @@ def parse_thumbnail_requirements(
width = size["width"]
height = size["height"]
method = size["method"]
for format, thumbnail_format in THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP.items():
requirement = requirements.setdefault(format, [])
if thumbnail_format == "jpeg":
requirement.append(
ThumbnailRequirement(width, height, method, "image/jpeg")
)
elif thumbnail_format == "png":
requirement.append(
ThumbnailRequirement(width, height, method, "image/png")
)
else:
raise Exception(
"Unknown thumbnail mapping from %s to %s. This is a Synapse problem, please report!"
% (format, thumbnail_format)
)
jpeg_thumbnail = ThumbnailRequirement(width, height, method, "image/jpeg")
png_thumbnail = ThumbnailRequirement(width, height, method, "image/png")
requirements.setdefault("image/jpeg", []).append(jpeg_thumbnail)
requirements.setdefault("image/jpg", []).append(jpeg_thumbnail)
requirements.setdefault("image/webp", []).append(jpeg_thumbnail)
requirements.setdefault("image/gif", []).append(png_thumbnail)
requirements.setdefault("image/png", []).append(png_thumbnail)
return {
media_type: tuple(thumbnails) for media_type, thumbnails in requirements.items()
}
+1 -7
View File
@@ -24,11 +24,9 @@ from synapse.api.room_versions import (
RoomVersion,
)
from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.event_auth import auth_types_for_event
from synapse.events import EventBase, _EventInternalMetadata, make_event_from_dict
from synapse.state import StateHandler
from synapse.storage.databases.main import DataStore
from synapse.storage.state import StateFilter
from synapse.types import EventID, JsonDict
from synapse.util import Clock
from synapse.util.stringutils import random_string
@@ -123,11 +121,7 @@ class EventBuilder:
"""
if auth_event_ids is None:
state_ids = await self._state.compute_state_after_events(
self.room_id,
prev_event_ids,
state_filter=StateFilter.from_types(
auth_types_for_event(self.room_version, self)
),
self.room_id, prev_event_ids
)
auth_event_ids = self._event_auth_handler.compute_auth_events(
self, state_ids
+1 -1
View File
@@ -217,7 +217,7 @@ class FederationClient(FederationBase):
)
async def claim_client_keys(
self, destination: str, content: JsonDict, timeout: Optional[int]
self, destination: str, content: JsonDict, timeout: int
) -> JsonDict:
"""Claims one-time keys for a device hosted on a remote server.
-16
View File
@@ -118,7 +118,6 @@ class FederationServer(FederationBase):
self._federation_event_handler = hs.get_federation_event_handler()
self.state = hs.get_state_handler()
self._event_auth_handler = hs.get_event_auth_handler()
self._room_member_handler = hs.get_room_member_handler()
self._state_storage_controller = hs.get_storage_controllers().state
@@ -622,15 +621,6 @@ class FederationServer(FederationBase):
)
raise IncompatibleRoomVersionError(room_version=room_version)
# Refuse the request if that room has seen too many joins recently.
# This is in addition to the HS-level rate limiting applied by
# BaseFederationServlet.
# type-ignore: mypy doesn't seem able to deduce the type of the limiter(!?)
await self._room_member_handler._join_rate_per_room_limiter.ratelimit( # type: ignore[has-type]
requester=None,
key=room_id,
update=False,
)
pdu = await self.handler.on_make_join_request(origin, room_id, user_id)
return {"event": pdu.get_templated_pdu_json(), "room_version": room_version}
@@ -665,12 +655,6 @@ class FederationServer(FederationBase):
room_id: str,
caller_supports_partial_state: bool = False,
) -> Dict[str, Any]:
await self._room_member_handler._join_rate_per_room_limiter.ratelimit( # type: ignore[has-type]
requester=None,
key=room_id,
update=False,
)
event, context = await self._on_send_membership_event(
origin, content, Membership.JOIN, room_id
)
+1 -1
View File
@@ -619,7 +619,7 @@ class TransportLayerClient:
)
async def claim_client_keys(
self, destination: str, query_content: JsonDict, timeout: Optional[int]
self, destination: str, query_content: JsonDict, timeout: int
) -> JsonDict:
"""Claim one-time keys for a list of devices hosted on a remote server.
+7 -9
View File
@@ -15,7 +15,7 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
import attr
from canonicaljson import encode_canonical_json
@@ -92,11 +92,7 @@ class E2eKeysHandler:
@trace
async def query_devices(
self,
query_body: JsonDict,
timeout: int,
from_user_id: str,
from_device_id: Optional[str],
self, query_body: JsonDict, timeout: int, from_user_id: str, from_device_id: str
) -> JsonDict:
"""Handle a device key query from a client
@@ -124,7 +120,9 @@ class E2eKeysHandler:
the number of in-flight queries at a time.
"""
async with self._query_devices_linearizer.queue((from_user_id, from_device_id)):
device_keys_query: Dict[str, List[str]] = query_body.get("device_keys", {})
device_keys_query: Dict[str, Iterable[str]] = query_body.get(
"device_keys", {}
)
# separate users by domain.
# make a map from domain to user_id to device_ids
@@ -394,7 +392,7 @@ class E2eKeysHandler:
@trace
async def query_local_devices(
self, query: Mapping[str, Optional[List[str]]]
self, query: Dict[str, Optional[List[str]]]
) -> Dict[str, Dict[str, dict]]:
"""Get E2E device keys for local users
@@ -463,7 +461,7 @@ class E2eKeysHandler:
@trace
async def claim_one_time_keys(
self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: Optional[int]
self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: int
) -> JsonDict:
local_query: List[Tuple[str, str, str]] = []
remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {}
-7
View File
@@ -1037,9 +1037,6 @@ class FederationEventHandler:
# XXX: this doesn't sound right? it means that we'll end up with incomplete
# state.
failed_to_fetch = desired_events - event_metadata.keys()
# `event_id` could be missing from `event_metadata` because it's not necessarily
# a state event. We've already checked that we've fetched it above.
failed_to_fetch.discard(event_id)
if failed_to_fetch:
logger.warning(
"Failed to fetch missing state events for %s %s",
@@ -1983,10 +1980,6 @@ class FederationEventHandler:
event, event_pos, max_stream_token, extra_users=extra_users
)
if event.type == EventTypes.Member and event.membership == Membership.JOIN:
# TODO retrieve the previous state, and exclude join -> join transitions
self._notifier.notify_user_joined_room(event.event_id, event.room_id)
def _sanity_check_event(self, ev: EventBase) -> None:
"""
Do some early sanity checks of a received event
+7 -12
View File
@@ -463,7 +463,6 @@ class EventCreationHandler:
)
self._events_shard_config = self.config.worker.events_shard_config
self._instance_name = hs.get_instance_name()
self._notifier = hs.get_notifier()
self.room_prejoin_state_types = self.hs.config.api.room_prejoin_state
@@ -1551,16 +1550,6 @@ class EventCreationHandler:
requester, is_admin_redaction=is_admin_redaction
)
if event.type == EventTypes.Member and event.membership == Membership.JOIN:
(
current_membership,
_,
) = await self.store.get_local_current_membership_for_user_in_room(
event.state_key, event.room_id
)
if current_membership != Membership.JOIN:
self._notifier.notify_user_joined_room(event.event_id, event.room_id)
await self._maybe_kick_guest_users(event, context)
if event.type == EventTypes.CanonicalAlias:
@@ -1860,8 +1849,13 @@ class EventCreationHandler:
# For each room we need to find a joined member we can use to send
# the dummy event with.
members = await self.store.get_local_users_in_room(room_id)
latest_event_ids = await self.store.get_prev_events_for_room(room_id)
members = await self.state.get_current_users_in_room(
room_id, latest_event_ids=latest_event_ids
)
for user_id in members:
if not self.hs.is_mine_id(user_id):
continue
requester = create_requester(user_id, authenticated_entity=self.server_name)
try:
event, context = await self.create_event(
@@ -1872,6 +1866,7 @@ class EventCreationHandler:
"room_id": room_id,
"sender": user_id,
},
prev_event_ids=latest_event_ids,
)
event.internal_metadata.proactively_send = False
-37
View File
@@ -94,29 +94,12 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
rate_hz=hs.config.ratelimiting.rc_joins_local.per_second,
burst_count=hs.config.ratelimiting.rc_joins_local.burst_count,
)
# Tracks joins from local users to rooms this server isn't a member of.
# I.e. joins this server makes by requesting /make_join /send_join from
# another server.
self._join_rate_limiter_remote = Ratelimiter(
store=self.store,
clock=self.clock,
rate_hz=hs.config.ratelimiting.rc_joins_remote.per_second,
burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count,
)
# TODO: find a better place to keep this Ratelimiter.
# It needs to be
# - written to by event persistence code
# - written to by something which can snoop on replication streams
# - read by the RoomMemberHandler to rate limit joins from local users
# - read by the FederationServer to rate limit make_joins and send_joins from
# other homeservers
# I wonder if a homeserver-wide collection of rate limiters might be cleaner?
self._join_rate_per_room_limiter = Ratelimiter(
store=self.store,
clock=self.clock,
rate_hz=hs.config.ratelimiting.rc_joins_per_room.per_second,
burst_count=hs.config.ratelimiting.rc_joins_per_room.burst_count,
)
# Ratelimiter for invites, keyed by room (across all issuers, all
# recipients).
@@ -153,18 +136,6 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
)
self.request_ratelimiter = hs.get_request_ratelimiter()
hs.get_notifier().add_new_join_in_room_callback(self._on_user_joined_room)
def _on_user_joined_room(self, event_id: str, room_id: str) -> None:
"""Notify the rate limiter that a room join has occurred.
Use this to inform the RoomMemberHandler about joins that have either
- taken place on another homeserver, or
- on another worker in this homeserver.
Joins actioned by this worker should use the usual `ratelimit` method, which
checks the limit and increments the counter in one go.
"""
self._join_rate_per_room_limiter.record_action(requester=None, key=room_id)
@abc.abstractmethod
async def _remote_join(
@@ -425,9 +396,6 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# up blocking profile updates.
if newly_joined and ratelimit:
await self._join_rate_limiter_local.ratelimit(requester)
await self._join_rate_per_room_limiter.ratelimit(
requester, key=room_id, update=False
)
result_event = await self.event_creation_handler.handle_new_client_event(
requester,
@@ -899,11 +867,6 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
await self._join_rate_limiter_remote.ratelimit(
requester,
)
await self._join_rate_per_room_limiter.ratelimit(
requester,
key=room_id,
update=False,
)
inviter = await self._get_inviter(target.to_string(), room_id)
if inviter and not self.hs.is_mine(inviter):
+22 -28
View File
@@ -84,13 +84,14 @@ the function becomes the operation name for the span.
return something_usual_and_useful
Operation names can be explicitly set for a function by using ``trace_with_opname``:
Operation names can be explicitly set for a function by passing the
operation name to ``trace``
.. code-block:: python
from synapse.logging.opentracing import trace_with_opname
from synapse.logging.opentracing import trace
@trace_with_opname("a_better_operation_name")
@trace(opname="a_better_operation_name")
def interesting_badly_named_function(*args, **kwargs):
# Does all kinds of cool and expected things
return something_usual_and_useful
@@ -797,31 +798,33 @@ def extract_text_map(carrier: Dict[str, str]) -> Optional["opentracing.SpanConte
# Tracing decorators
def trace_with_opname(opname: str) -> Callable[[Callable[P, R]], Callable[P, R]]:
def trace(func=None, opname: Optional[str] = None):
"""
Decorator to trace a function with a custom opname.
See the module's doc string for usage examples.
Decorator to trace a function.
Sets the operation name to that of the function's or that given
as operation_name. See the module's doc string for usage
examples.
"""
def decorator(func: Callable[P, R]) -> Callable[P, R]:
def decorator(func):
if opentracing is None:
return func # type: ignore[unreachable]
_opname = opname if opname else func.__name__
if inspect.iscoroutinefunction(func):
@wraps(func)
async def _trace_inner(*args: P.args, **kwargs: P.kwargs) -> R:
with start_active_span(opname):
return await func(*args, **kwargs) # type: ignore[misc]
async def _trace_inner(*args, **kwargs):
with start_active_span(_opname):
return await func(*args, **kwargs)
else:
# The other case here handles both sync functions and those
# decorated with inlineDeferred.
@wraps(func)
def _trace_inner(*args: P.args, **kwargs: P.kwargs) -> R:
scope = start_active_span(opname)
def _trace_inner(*args, **kwargs):
scope = start_active_span(_opname)
scope.__enter__()
try:
@@ -855,21 +858,12 @@ def trace_with_opname(opname: str) -> Callable[[Callable[P, R]], Callable[P, R]]
scope.__exit__(type(e), None, e.__traceback__)
raise
return _trace_inner # type: ignore[return-value]
return _trace_inner
return decorator
def trace(func: Callable[P, R]) -> Callable[P, R]:
"""
Decorator to trace a function.
Sets the operation name to that of the function's name.
See the module's doc string for usage examples.
"""
return trace_with_opname(func.__name__)(func)
if func:
return decorator(func)
else:
return decorator
def tag_args(func: Callable[P, R]) -> Callable[P, R]:
+11 -16
View File
@@ -328,7 +328,7 @@ class PusherPool:
return None
try:
pusher = self.pusher_factory.create_pusher(pusher_config)
p = self.pusher_factory.create_pusher(pusher_config)
except PusherConfigException as e:
logger.warning(
"Pusher incorrectly configured id=%i, user=%s, appid=%s, pushkey=%s: %s",
@@ -346,28 +346,23 @@ class PusherPool:
)
return None
if not pusher:
if not p:
return None
appid_pushkey = "%s:%s" % (pusher.app_id, pusher.pushkey)
appid_pushkey = "%s:%s" % (pusher_config.app_id, pusher_config.pushkey)
byuser = self.pushers.setdefault(pusher.user_id, {})
byuser = self.pushers.setdefault(pusher_config.user_name, {})
if appid_pushkey in byuser:
previous_pusher = byuser[appid_pushkey]
previous_pusher.on_stop()
byuser[appid_pushkey].on_stop()
byuser[appid_pushkey] = p
synapse_pushers.labels(
type(previous_pusher).__name__, previous_pusher.app_id
).dec()
byuser[appid_pushkey] = pusher
synapse_pushers.labels(type(pusher).__name__, pusher.app_id).inc()
synapse_pushers.labels(type(p).__name__, p.app_id).inc()
# Check if there *may* be push to process. We do this as this check is a
# lot cheaper to do than actually fetching the exact rows we need to
# push.
user_id = pusher.user_id
last_stream_ordering = pusher.last_stream_ordering
user_id = pusher_config.user_name
last_stream_ordering = pusher_config.last_stream_ordering
if last_stream_ordering:
have_notifs = await self.store.get_if_maybe_push_in_range_for_user(
user_id, last_stream_ordering
@@ -377,9 +372,9 @@ class PusherPool:
# risk missing push.
have_notifs = True
pusher.on_started(have_notifs)
p.on_started(have_notifs)
return pusher
return p
async def remove_pusher(self, app_id: str, pushkey: str, user_id: str) -> None:
appid_pushkey = "%s:%s" % (app_id, pushkey)
+2 -2
View File
@@ -29,7 +29,7 @@ from synapse.http import RequestTimedOutError
from synapse.http.server import HttpServer, is_method_cancellable
from synapse.http.site import SynapseRequest
from synapse.logging import opentracing
from synapse.logging.opentracing import trace_with_opname
from synapse.logging.opentracing import trace
from synapse.types import JsonDict
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.stringutils import random_string
@@ -196,7 +196,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
"ascii"
)
@trace_with_opname("outgoing_replication_request")
@trace(opname="outgoing_replication_request")
async def send_request(*, instance_name: str = "master", **kwargs: Any) -> Any:
with outgoing_gauge.track_inprogress():
if instance_name == local_instance_name:
+1 -16
View File
@@ -21,7 +21,7 @@ from twisted.internet.interfaces import IAddress, IConnector
from twisted.internet.protocol import ReconnectingClientFactory
from twisted.python.failure import Failure
from synapse.api.constants import EventTypes, Membership, ReceiptTypes
from synapse.api.constants import EventTypes, ReceiptTypes
from synapse.federation import send_queue
from synapse.federation.sender import FederationSender
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
@@ -219,21 +219,6 @@ class ReplicationDataHandler:
membership=row.data.membership,
)
# If this event is a join, make a note of it so we have an accurate
# cross-worker room rate limit.
# TODO: Erik said we should exclude rows that came from ex_outliers
# here, but I don't see how we can determine that. I guess we could
# add a flag to row.data?
if (
row.data.type == EventTypes.Member
and row.data.membership == Membership.JOIN
and not row.data.outlier
):
# TODO retrieve the previous state, and exclude join -> join transitions
self.notifier.notify_user_joined_room(
row.data.event_id, row.data.room_id
)
await self._presence_handler.process_replication_rows(
stream_name, instance_name, token, rows
)
@@ -98,7 +98,6 @@ class EventsStreamEventRow(BaseEventsStreamRow):
relates_to: Optional[str]
membership: Optional[str]
rejected: bool
outlier: bool
@attr.s(slots=True, frozen=True, auto_attribs=True)
+2 -2
View File
@@ -26,7 +26,7 @@ from synapse.http.servlet import (
parse_string,
)
from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import log_kv, set_tag, trace_with_opname
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.types import JsonDict, StreamToken
from ._base import client_patterns, interactive_auth_handler
@@ -71,7 +71,7 @@ class KeyUploadServlet(RestServlet):
self.e2e_keys_handler = hs.get_e2e_keys_handler()
self.device_handler = hs.get_device_handler()
@trace_with_opname("upload_keys")
@trace(opname="upload_keys")
async def on_POST(
self, request: SynapseRequest, device_id: Optional[str]
) -> Tuple[int, JsonDict]:
+5 -8
View File
@@ -13,7 +13,7 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Optional, Tuple, cast
from typing import TYPE_CHECKING, Optional, Tuple
from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.server import HttpServer
@@ -127,7 +127,7 @@ class RoomKeysServlet(RestServlet):
requester = await self.auth.get_user_by_req(request, allow_guest=False)
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
version = parse_string(request, "version", required=True)
version = parse_string(request, "version")
if session_id:
body = {"sessions": {session_id: body}}
@@ -196,11 +196,8 @@ class RoomKeysServlet(RestServlet):
user_id = requester.user.to_string()
version = parse_string(request, "version", required=True)
room_keys = cast(
JsonDict,
await self.e2e_room_keys_handler.get_room_keys(
user_id, version, room_id, session_id
),
room_keys = await self.e2e_room_keys_handler.get_room_keys(
user_id, version, room_id, session_id
)
# Convert room_keys to the right format to return.
@@ -243,7 +240,7 @@ class RoomKeysServlet(RestServlet):
requester = await self.auth.get_user_by_req(request, allow_guest=False)
user_id = requester.user.to_string()
version = parse_string(request, "version", required=True)
version = parse_string(request, "version")
ret = await self.e2e_room_keys_handler.delete_room_keys(
user_id, version, room_id, session_id
+2 -2
View File
@@ -19,7 +19,7 @@ from synapse.http import servlet
from synapse.http.server import HttpServer
from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request
from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import set_tag, trace_with_opname
from synapse.logging.opentracing import set_tag, trace
from synapse.rest.client.transactions import HttpTransactionCache
from synapse.types import JsonDict
@@ -43,7 +43,7 @@ class SendToDeviceRestServlet(servlet.RestServlet):
self.txns = HttpTransactionCache(hs)
self.device_message_handler = hs.get_device_message_handler()
@trace_with_opname("sendToDevice")
@trace(opname="sendToDevice")
def on_PUT(
self, request: SynapseRequest, message_type: str, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]:
+6 -6
View File
@@ -37,7 +37,7 @@ from synapse.handlers.sync import (
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import trace_with_opname
from synapse.logging.opentracing import trace
from synapse.types import JsonDict, StreamToken
from synapse.util import json_decoder
@@ -210,7 +210,7 @@ class SyncRestServlet(RestServlet):
logger.debug("Event formatting complete")
return 200, response_content
@trace_with_opname("sync.encode_response")
@trace(opname="sync.encode_response")
async def encode_response(
self,
time_now: int,
@@ -315,7 +315,7 @@ class SyncRestServlet(RestServlet):
]
}
@trace_with_opname("sync.encode_joined")
@trace(opname="sync.encode_joined")
async def encode_joined(
self,
rooms: List[JoinedSyncResult],
@@ -340,7 +340,7 @@ class SyncRestServlet(RestServlet):
return joined
@trace_with_opname("sync.encode_invited")
@trace(opname="sync.encode_invited")
async def encode_invited(
self,
rooms: List[InvitedSyncResult],
@@ -371,7 +371,7 @@ class SyncRestServlet(RestServlet):
return invited
@trace_with_opname("sync.encode_knocked")
@trace(opname="sync.encode_knocked")
async def encode_knocked(
self,
rooms: List[KnockedSyncResult],
@@ -420,7 +420,7 @@ class SyncRestServlet(RestServlet):
return knocked
@trace_with_opname("sync.encode_archived")
@trace(opname="sync.encode_archived")
async def encode_archived(
self,
rooms: List[ArchivedSyncResult],
+2 -38
View File
@@ -17,11 +17,9 @@
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from synapse.api.errors import Codes, SynapseError, cs_error
from synapse.config.repository import THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP
from synapse.api.errors import SynapseError
from synapse.http.server import (
DirectServeJsonResource,
respond_with_json,
set_corp_headers,
set_cors_headers,
)
@@ -311,19 +309,6 @@ class ThumbnailResource(DirectServeJsonResource):
url_cache: True if this is from a URL cache.
server_name: The server name, if this is a remote thumbnail.
"""
logger.debug(
"_select_and_respond_with_thumbnail: media_id=%s desired=%sx%s (%s) thumbnail_infos=%s",
media_id,
desired_width,
desired_height,
desired_method,
thumbnail_infos,
)
# If `dynamic_thumbnails` is enabled, we expect Synapse to go down a
# different code path to handle it.
assert not self.dynamic_thumbnails
if thumbnail_infos:
file_info = self._select_thumbnail(
desired_width,
@@ -399,29 +384,8 @@ class ThumbnailResource(DirectServeJsonResource):
file_info.thumbnail.length,
)
else:
# This might be because:
# 1. We can't create thumbnails for the given media (corrupted or
# unsupported file type), or
# 2. The thumbnailing process never ran or errored out initially
# when the media was first uploaded (these bugs should be
# reported and fixed).
# Note that we don't attempt to generate a thumbnail now because
# `dynamic_thumbnails` is disabled.
logger.info("Failed to find any generated thumbnails")
respond_with_json(
request,
400,
cs_error(
"Cannot find any thumbnails for the requested media (%r). This might mean the media is not a supported_media_format=(%s) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)"
% (
request.postpath,
", ".join(THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP.keys()),
),
code=Codes.UNKNOWN,
),
send_cors=True,
)
respond_404(request)
def _select_thumbnail(
self,
+1 -2
View File
@@ -157,7 +157,6 @@ class StateHandler:
self,
room_id: str,
event_ids: Collection[str],
state_filter: Optional[StateFilter] = None,
) -> StateMap[str]:
"""Fetch the state after each of the given event IDs. Resolve them and return.
@@ -175,7 +174,7 @@ class StateHandler:
"""
logger.debug("calling resolve_state_groups from compute_state_after_events")
ret = await self.resolve_state_groups_for_events(room_id, event_ids)
return await ret.get_state(self._state_storage_controller, state_filter)
return await ret.get_state(self._state_storage_controller, StateFilter.all())
async def get_current_users_in_room(
self, room_id: str, latest_event_ids: List[str]
+1 -8
View File
@@ -96,10 +96,6 @@ class SQLBaseStore(metaclass=ABCMeta):
cache doesn't exist. Mainly used for invalidating caches on workers,
where they may not have the cache.
Note that this function does not invalidate any remote caches, only the
local in-memory ones. Any remote invalidation must be performed before
calling this.
Args:
cache_name
key: Entry to invalidate. If None then invalidates the entire
@@ -116,10 +112,7 @@ class SQLBaseStore(metaclass=ABCMeta):
if key is None:
cache.invalidate_all()
else:
# Prefer any local-only invalidation method. Invalidating any non-local
# cache must be be done before this.
invalidate_method = getattr(cache, "invalidate_local", cache.invalidate)
invalidate_method(tuple(key))
cache.invalidate(tuple(key))
def db_to_json(db_content: Union[memoryview, bytes, bytearray, str]) -> Any:
+6 -48
View File
@@ -23,7 +23,6 @@ from time import monotonic as monotonic_time
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Collection,
Dict,
@@ -58,7 +57,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.background_updates import BackgroundUpdater
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.storage.types import Connection, Cursor
from synapse.util.async_helpers import delay_cancellation
from synapse.util.async_helpers import delay_cancellation, maybe_awaitable
from synapse.util.iterutils import batch_iter
if TYPE_CHECKING:
@@ -169,7 +168,6 @@ class LoggingDatabaseConnection:
*,
txn_name: Optional[str] = None,
after_callbacks: Optional[List["_CallbackListEntry"]] = None,
async_after_callbacks: Optional[List["_AsyncCallbackListEntry"]] = None,
exception_callbacks: Optional[List["_CallbackListEntry"]] = None,
) -> "LoggingTransaction":
if not txn_name:
@@ -180,7 +178,6 @@ class LoggingDatabaseConnection:
name=txn_name,
database_engine=self.engine,
after_callbacks=after_callbacks,
async_after_callbacks=async_after_callbacks,
exception_callbacks=exception_callbacks,
)
@@ -212,9 +209,6 @@ class LoggingDatabaseConnection:
# The type of entry which goes on our after_callbacks and exception_callbacks lists.
_CallbackListEntry = Tuple[Callable[..., object], Tuple[object, ...], Dict[str, object]]
_AsyncCallbackListEntry = Tuple[
Callable[..., Awaitable], Tuple[object, ...], Dict[str, object]
]
P = ParamSpec("P")
R = TypeVar("R")
@@ -233,10 +227,6 @@ class LoggingTransaction:
that have been added by `call_after` which should be run on
successful completion of the transaction. None indicates that no
callbacks should be allowed to be scheduled to run.
async_after_callbacks: A list that asynchronous callbacks will be appended
to by `async_call_after` which should run, before after_callbacks, on
successful completion of the transaction. None indicates that no
callbacks should be allowed to be scheduled to run.
exception_callbacks: A list that callbacks will be appended
to that have been added by `call_on_exception` which should be run
if transaction ends with an error. None indicates that no callbacks
@@ -248,7 +238,6 @@ class LoggingTransaction:
"name",
"database_engine",
"after_callbacks",
"async_after_callbacks",
"exception_callbacks",
]
@@ -258,14 +247,12 @@ class LoggingTransaction:
name: str,
database_engine: BaseDatabaseEngine,
after_callbacks: Optional[List[_CallbackListEntry]] = None,
async_after_callbacks: Optional[List[_AsyncCallbackListEntry]] = None,
exception_callbacks: Optional[List[_CallbackListEntry]] = None,
):
self.txn = txn
self.name = name
self.database_engine = database_engine
self.after_callbacks = after_callbacks
self.async_after_callbacks = async_after_callbacks
self.exception_callbacks = exception_callbacks
def call_after(
@@ -290,28 +277,6 @@ class LoggingTransaction:
# type-ignore: need mypy containing https://github.com/python/mypy/pull/12668
self.after_callbacks.append((callback, args, kwargs)) # type: ignore[arg-type]
def async_call_after(
self, callback: Callable[P, Awaitable], *args: P.args, **kwargs: P.kwargs
) -> None:
"""Call the given asynchronous callback on the main twisted thread after
the transaction has finished (but before those added in `call_after`).
Mostly used to invalidate remote caches after transactions.
Note that transactions may be retried a few times if they encounter database
errors such as serialization failures. Callbacks given to `async_call_after`
will accumulate across transaction attempts and will _all_ be called once a
transaction attempt succeeds, regardless of whether previous transaction
attempts failed. Otherwise, if all transaction attempts fail, all
`call_on_exception` callbacks will be run instead.
"""
# if self.async_after_callbacks is None, that means that whatever constructed the
# LoggingTransaction isn't expecting there to be any callbacks; assert that
# is not the case.
assert self.async_after_callbacks is not None
# type-ignore: need mypy containing https://github.com/python/mypy/pull/12668
self.async_after_callbacks.append((callback, args, kwargs)) # type: ignore[arg-type]
def call_on_exception(
self, callback: Callable[P, object], *args: P.args, **kwargs: P.kwargs
) -> None:
@@ -609,7 +574,6 @@ class DatabasePool:
conn: LoggingDatabaseConnection,
desc: str,
after_callbacks: List[_CallbackListEntry],
async_after_callbacks: List[_AsyncCallbackListEntry],
exception_callbacks: List[_CallbackListEntry],
func: Callable[Concatenate[LoggingTransaction, P], R],
*args: P.args,
@@ -633,7 +597,6 @@ class DatabasePool:
conn
desc
after_callbacks
async_after_callbacks
exception_callbacks
func
*args
@@ -696,7 +659,6 @@ class DatabasePool:
cursor = conn.cursor(
txn_name=name,
after_callbacks=after_callbacks,
async_after_callbacks=async_after_callbacks,
exception_callbacks=exception_callbacks,
)
try:
@@ -836,7 +798,6 @@ class DatabasePool:
async def _runInteraction() -> R:
after_callbacks: List[_CallbackListEntry] = []
async_after_callbacks: List[_AsyncCallbackListEntry] = []
exception_callbacks: List[_CallbackListEntry] = []
if not current_context():
@@ -848,7 +809,6 @@ class DatabasePool:
self.new_transaction,
desc,
after_callbacks,
async_after_callbacks,
exception_callbacks,
func,
*args,
@@ -857,17 +817,15 @@ class DatabasePool:
**kwargs,
)
# We order these assuming that async functions call out to external
# systems (e.g. to invalidate a cache) and the sync functions make these
# changes on any local in-memory caches/similar, and thus must be second.
for async_callback, async_args, async_kwargs in async_after_callbacks:
await async_callback(*async_args, **async_kwargs)
for after_callback, after_args, after_kwargs in after_callbacks:
after_callback(*after_args, **after_kwargs)
await maybe_awaitable(after_callback(*after_args, **after_kwargs))
return cast(R, result)
except Exception:
for exception_callback, after_args, after_kwargs in exception_callbacks:
exception_callback(*after_args, **after_kwargs)
await maybe_awaitable(
exception_callback(*after_args, **after_kwargs)
)
raise
# To handle cancellation, we ensure that `after_callback`s and
@@ -194,7 +194,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
# changed its content in the database. We can't call
# self._invalidate_cache_and_stream because self.get_event_cache isn't of the
# right type.
self.invalidate_get_event_cache_after_txn(txn, event.event_id)
txn.call_after(self._get_event_cache.invalidate, (event.event_id,))
# Send that invalidation to replication so that other workers also invalidate
# the event cache.
self._send_invalidation_to_replication(
+1 -1
View File
@@ -669,7 +669,7 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore):
@trace
async def get_user_devices_from_cache(
self, query_list: List[Tuple[str, Optional[str]]]
self, query_list: List[Tuple[str, str]]
) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]:
"""Get the devices (and keys if any) for remote users from the cache.
@@ -22,14 +22,11 @@ from typing import (
List,
Optional,
Tuple,
Union,
cast,
overload,
)
import attr
from canonicaljson import encode_canonical_json
from typing_extensions import Literal
from synapse.api.constants import DeviceKeyAlgorithms
from synapse.appservice import (
@@ -116,7 +113,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
user_devices = devices[user_id]
results = []
for device_id, device in user_devices.items():
result: JsonDict = {"device_id": device_id}
result = {"device_id": device_id}
keys = device.keys
if keys:
@@ -159,9 +156,6 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
rv[user_id] = {}
for device_id, device_info in device_keys.items():
r = device_info.keys
if r is None:
continue
r["unsigned"] = {}
display_name = device_info.display_name
if display_name is not None:
@@ -170,42 +164,13 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
return rv
@overload
async def get_e2e_device_keys_and_signatures(
self,
query_list: Collection[Tuple[str, Optional[str]]],
include_all_devices: Literal[False] = False,
) -> Dict[str, Dict[str, DeviceKeyLookupResult]]:
...
@overload
async def get_e2e_device_keys_and_signatures(
self,
query_list: Collection[Tuple[str, Optional[str]]],
include_all_devices: bool = False,
include_deleted_devices: Literal[False] = False,
) -> Dict[str, Dict[str, DeviceKeyLookupResult]]:
...
@overload
async def get_e2e_device_keys_and_signatures(
self,
query_list: Collection[Tuple[str, Optional[str]]],
include_all_devices: Literal[True],
include_deleted_devices: Literal[True],
) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
...
@trace
async def get_e2e_device_keys_and_signatures(
self,
query_list: Collection[Tuple[str, Optional[str]]],
query_list: List[Tuple[str, Optional[str]]],
include_all_devices: bool = False,
include_deleted_devices: bool = False,
) -> Union[
Dict[str, Dict[str, DeviceKeyLookupResult]],
Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]],
]:
) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
"""Fetch a list of device keys
Any cross-signatures made on the keys by the owner of the device are also
@@ -1079,7 +1044,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
_claim_e2e_one_time_key = _claim_e2e_one_time_key_simple
db_autocommit = False
claim_row = await self.db_pool.runInteraction(
row = await self.db_pool.runInteraction(
"claim_e2e_one_time_keys",
_claim_e2e_one_time_key,
user_id,
@@ -1087,11 +1052,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
algorithm,
db_autocommit=db_autocommit,
)
if claim_row:
if row:
device_results = results.setdefault(user_id, {}).setdefault(
device_id, {}
)
device_results[claim_row[0]] = claim_row[1]
device_results[row[0]] = row[1]
continue
# No one-time key available, so see if there's a fallback
+3 -3
View File
@@ -1293,7 +1293,7 @@ class PersistEventsStore:
depth_updates: Dict[str, int] = {}
for event, context in events_and_contexts:
# Remove the any existing cache entries for the event_ids
self.store.invalidate_get_event_cache_after_txn(txn, event.event_id)
txn.call_after(self.store._invalidate_get_event_cache, event.event_id)
# Then update the `stream_ordering` position to mark the latest
# event as the front of the room. This should not be done for
# backfilled events because backfilled events have negative
@@ -1675,7 +1675,7 @@ class PersistEventsStore:
(cache_entry.event.event_id,), cache_entry
)
txn.async_call_after(prefill)
txn.call_after(prefill)
def _store_redaction(self, txn: LoggingTransaction, event: EventBase) -> None:
"""Invalidate the caches for the redacted event.
@@ -1684,7 +1684,7 @@ class PersistEventsStore:
_invalidate_caches_for_event.
"""
assert event.redacts is not None
self.store.invalidate_get_event_cache_after_txn(txn, event.redacts)
txn.call_after(self.store._invalidate_get_event_cache, event.redacts)
txn.call_after(self.store.get_relations_for_event.invalidate, (event.redacts,))
txn.call_after(self.store.get_applicable_edit.invalidate, (event.redacts,))
+17 -53
View File
@@ -712,41 +712,17 @@ class EventsWorkerStore(SQLBaseStore):
return event_entry_map
def invalidate_get_event_cache_after_txn(
self, txn: LoggingTransaction, event_id: str
) -> None:
"""
Prepares a database transaction to invalidate the get event cache for a given
event ID when executed successfully. This is achieved by attaching two callbacks
to the transaction, one to invalidate the async cache and one for the in memory
sync cache (importantly called in that order).
Arguments:
txn: the database transaction to attach the callbacks to
event_id: the event ID to be invalidated from caches
"""
txn.async_call_after(self._invalidate_async_get_event_cache, event_id)
txn.call_after(self._invalidate_local_get_event_cache, event_id)
async def _invalidate_async_get_event_cache(self, event_id: str) -> None:
"""
Invalidates an event in the asyncronous get event cache, which may be remote.
Arguments:
event_id: the event ID to invalidate
"""
async def _invalidate_get_event_cache(self, event_id: str) -> None:
# First we invalidate the asynchronous cache instance. This may include
# out-of-process caches such as Redis/memcache. Once complete we can
# invalidate any in memory cache. The ordering is important here to
# ensure we don't pull in any remote invalid value after we invalidate
# the in-memory cache.
await self._get_event_cache.invalidate((event_id,))
self._event_ref.pop(event_id, None)
self._current_event_fetches.pop(event_id, None)
def _invalidate_local_get_event_cache(self, event_id: str) -> None:
"""
Invalidates an event in local in-memory get event caches.
Arguments:
event_id: the event ID to invalidate
"""
self._get_event_cache.invalidate_local((event_id,))
self._event_ref.pop(event_id, None)
self._current_event_fetches.pop(event_id, None)
@@ -982,13 +958,7 @@ class EventsWorkerStore(SQLBaseStore):
}
row_dict = self.db_pool.new_transaction(
conn,
"do_fetch",
[],
[],
[],
self._fetch_event_rows,
events_to_fetch,
conn, "do_fetch", [], [], self._fetch_event_rows, events_to_fetch
)
# We only want to resolve deferreds from the main thread
@@ -1490,7 +1460,7 @@ class EventsWorkerStore(SQLBaseStore):
async def get_all_new_forward_event_rows(
self, instance_name: str, last_id: int, current_id: int, limit: int
) -> List[Tuple[int, str, str, str, str, str, str, str, bool, bool]]:
) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
"""Returns new events, for the Events replication stream
Args:
@@ -1506,11 +1476,10 @@ class EventsWorkerStore(SQLBaseStore):
def get_all_new_forward_event_rows(
txn: LoggingTransaction,
) -> List[Tuple[int, str, str, str, str, str, str, str, bool, bool]]:
) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
sql = (
"SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
" se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL,"
" e.outlier"
" se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
" FROM events AS e"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events AS se USING (event_id)"
@@ -1524,8 +1493,7 @@ class EventsWorkerStore(SQLBaseStore):
)
txn.execute(sql, (last_id, current_id, instance_name, limit))
return cast(
List[Tuple[int, str, str, str, str, str, str, str, bool, bool]],
txn.fetchall(),
List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall()
)
return await self.db_pool.runInteraction(
@@ -1534,7 +1502,7 @@ class EventsWorkerStore(SQLBaseStore):
async def get_ex_outlier_stream_rows(
self, instance_name: str, last_id: int, current_id: int
) -> List[Tuple[int, str, str, str, str, str, str, str, bool, bool]]:
) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
"""Returns de-outliered events, for the Events replication stream
Args:
@@ -1549,14 +1517,11 @@ class EventsWorkerStore(SQLBaseStore):
def get_ex_outlier_stream_rows_txn(
txn: LoggingTransaction,
) -> List[Tuple[int, str, str, str, str, str, str, str, bool, bool]]:
) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
sql = (
"SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
" se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL,"
" e.outlier"
" se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
" FROM events AS e"
# NB: the next line (inner join) is what makes this query different from
# get_all_new_forward_event_rows.
" INNER JOIN ex_outlier_stream AS out USING (event_id)"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events AS se USING (event_id)"
@@ -1571,8 +1536,7 @@ class EventsWorkerStore(SQLBaseStore):
txn.execute(sql, (last_id, current_id, instance_name))
return cast(
List[Tuple[int, str, str, str, str, str, str, str, bool, bool]],
txn.fetchall(),
List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall()
)
return await self.db_pool.runInteraction(
@@ -66,7 +66,6 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore):
"initialise_mau_threepids",
[],
[],
[],
self._initialise_reserved_users,
hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value],
)
+3 -32
View File
@@ -19,8 +19,6 @@ from synapse.api.errors import SynapseError
from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main import CacheInvalidationWorkerStore
from synapse.storage.databases.main.state import StateGroupWorkerStore
from synapse.storage.engines import PostgresEngine
from synapse.storage.engines._base import IsolationLevel
from synapse.types import RoomStreamToken
logger = logging.getLogger(__name__)
@@ -304,7 +302,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
self._invalidate_cache_and_stream(
txn, self.have_seen_event, (room_id, event_id)
)
self.invalidate_get_event_cache_after_txn(txn, event_id)
txn.call_after(self._invalidate_get_event_cache, event_id)
logger.info("[purge] done")
@@ -319,38 +317,11 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
Returns:
The list of state groups to delete.
"""
# This first runs the purge transaction with READ_COMMITTED isolation level,
# meaning any new rows in the tables will not trigger a serialization error.
# We then run the same purge a second time without this isolation level to
# purge any of those rows which were added during the first.
state_groups_to_delete = await self.db_pool.runInteraction(
"purge_room",
self._purge_room_txn,
room_id=room_id,
isolation_level=IsolationLevel.READ_COMMITTED,
return await self.db_pool.runInteraction(
"purge_room", self._purge_room_txn, room_id
)
state_groups_to_delete.extend(
await self.db_pool.runInteraction(
"purge_room",
self._purge_room_txn,
room_id=room_id,
),
)
return state_groups_to_delete
def _purge_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[int]:
# This collides with event persistence so we cannot write new events and metadata into
# a room while deleting it or this transaction will fail.
if isinstance(self.database_engine, PostgresEngine):
txn.execute(
"SELECT room_version FROM rooms WHERE room_id = ? FOR UPDATE",
(room_id,),
)
# First, fetch all the state groups that should be deleted, before
# we delete that information.
txn.execute(
@@ -228,7 +228,6 @@ class PushRulesWorkerStore(
iterable=user_ids,
retcols=("*",),
desc="bulk_get_push_rules",
batch_size=1000,
)
rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
+2 -2
View File
@@ -243,7 +243,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
txn: LoggingTransaction,
) -> Dict[str, ProfileInfo]:
clause, ids = make_in_list_sql_clause(
self.database_engine, "c.state_key", user_ids
self.database_engine, "m.user_id", user_ids
)
sql = """
@@ -904,7 +904,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
iterable=event_ids,
retcols=("user_id", "display_name", "avatar_url", "event_id"),
keyvalues={"membership": Membership.JOIN},
batch_size=1000,
batch_size=500,
desc="_get_joined_profiles_from_event_ids",
)
@@ -24,7 +24,6 @@ from synapse.storage.database import (
from synapse.storage.engines import PostgresEngine
from synapse.storage.state import StateFilter
from synapse.types import MutableStateMap, StateMap
from synapse.util.caches import intern_string
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -137,7 +136,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
txn.execute(sql % (where_clause,), args)
for row in txn:
typ, state_key, event_id = row
key = (intern_string(typ), intern_string(state_key))
key = (typ, state_key)
results[group][key] = event_id
else:
max_entries_returned = state_filter.max_entries_returned()
+8 -1
View File
@@ -202,7 +202,14 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
requests state from the cache, if False we need to query the DB for the
missing state.
"""
cache_entry = cache.get(group)
# If we are asked explicitly for a subset of keys, we only ask for those
# from the cache. This ensures that the `DictionaryCache` can make
# better decisions about what to cache and what to expire.
dict_keys = None
if not state_filter.has_wildcards():
dict_keys = state_filter.concrete_types()
cache_entry = cache.get(group, dict_keys=dict_keys)
state_dict_ids = cache_entry.value
if cache_entry.full or state_filter.is_full():
+164 -29
View File
@@ -14,11 +14,13 @@
import enum
import logging
import threading
from typing import Any, Dict, Generic, Iterable, Optional, Set, TypeVar
from typing import Any, Dict, Generic, Iterable, Optional, Set, Tuple, TypeVar, Union
import attr
from typing_extensions import Literal
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_items
logger = logging.getLogger(__name__)
@@ -53,20 +55,67 @@ class DictionaryEntry: # should be: Generic[DKT, DV].
return len(self.value)
class _FullCacheKey(enum.Enum):
"""The key we use to cache the full dict."""
KEY = object()
class _Sentinel(enum.Enum):
# defining a sentinel in this way allows mypy to correctly handle the
# type of a dictionary lookup.
sentinel = object()
class _PerKeyValue(Generic[DV]):
"""The cached value of a dictionary key. If `value` is the sentinel,
indicates that the requested key is known to *not* be in the full dict.
"""
__slots__ = ["value"]
def __init__(self, value: Union[DV, Literal[_Sentinel.sentinel]]) -> None:
self.value = value
def __len__(self) -> int:
# We add a `__len__` implementation as we use this class in a cache
# where the values are variable length.
return 1
class DictionaryCache(Generic[KT, DKT, DV]):
"""Caches key -> dictionary lookups, supporting caching partial dicts, i.e.
fetching a subset of dictionary keys for a particular key.
"""
def __init__(self, name: str, max_entries: int = 1000):
self.cache: LruCache[KT, DictionaryEntry] = LruCache(
max_size=max_entries, cache_name=name, size_callback=len
# We use a single cache to cache two different types of entries:
# 1. Map from (key, dict_key) -> dict value (or sentinel, indicating
# the key doesn't exist in the dict); and
# 2. Map from (key, _FullCacheKey.KEY) -> full dict.
#
# The former is used when explicit keys of the dictionary are looked up,
# and the latter when the full dictionary is requested.
#
# If when explicit keys are requested and not in the cache, we then look
# to see if we have the full dict and use that if we do. If found in the
# full dict each key is added into the cache.
#
# This set up allows the `LruCache` to prune the full dict entries if
# they haven't been used in a while, even when there have been recent
# queries for subsets of the dict.
#
# Typing:
# * A key of `(KT, DKT)` has a value of `_PerKeyValue`
# * A key of `(KT, _FullCacheKey.KEY)` has a value of `Dict[DKT, DV]`
self.cache: LruCache[
Tuple[KT, Union[DKT, Literal[_FullCacheKey.KEY]]],
Union[_PerKeyValue, Dict[DKT, DV]],
] = LruCache(
max_size=max_entries,
cache_name=name,
cache_type=TreeCache,
size_callback=len,
)
self.name = name
@@ -96,20 +145,97 @@ class DictionaryCache(Generic[KT, DKT, DV]):
Returns:
DictionaryEntry
"""
entry = self.cache.get(key, _Sentinel.sentinel)
if entry is not _Sentinel.sentinel:
if dict_keys is None:
return DictionaryEntry(
entry.full, entry.known_absent, dict(entry.value)
)
else:
return DictionaryEntry(
entry.full,
entry.known_absent,
{k: entry.value[k] for k in dict_keys if k in entry.value},
)
return DictionaryEntry(False, set(), {})
if dict_keys is None:
# First we check if we have cached the full dict.
entry = self.cache.get((key, _FullCacheKey.KEY), _Sentinel.sentinel)
if entry is not _Sentinel.sentinel:
assert isinstance(entry, dict)
return DictionaryEntry(True, set(), entry)
# If not, check if we have cached any of dict keys.
all_entries = self.cache.get_multi(
(key,),
_Sentinel.sentinel,
)
if all_entries is _Sentinel.sentinel:
return DictionaryEntry(False, set(), {})
# If there are entries we need to unwrap the returned cache nodes
# and `_PerKeyValue` into the `DictionaryEntry`.
values = {}
known_absent = set()
for dict_key, dict_value in iterate_tree_cache_items((), all_entries):
dict_key = dict_key[0]
dict_value = dict_value.value
# We have explicitly looked for a full cache key, so we
# shouldn't see one.
assert dict_key != _FullCacheKey.KEY
# ... therefore the values must be `_PerKeyValue`
assert isinstance(dict_value, _PerKeyValue)
if dict_value.value is _Sentinel.sentinel:
known_absent.add(dict_key)
else:
values[dict_key] = dict_value.value
return DictionaryEntry(False, known_absent, values)
# We are being asked for a subset of keys.
# First got and check for each requested dict key in the cache, tracking
# which we couldn't find.
values = {}
known_absent = set()
missing = set()
for dict_key in dict_keys:
entry = self.cache.get((key, dict_key), _Sentinel.sentinel)
if entry is _Sentinel.sentinel:
missing.add(dict_key)
continue
assert isinstance(entry, _PerKeyValue)
if entry.value is _Sentinel.sentinel:
known_absent.add(dict_key)
else:
values[dict_key] = entry.value
# If we found everything we can return immediately.
if not missing:
return DictionaryEntry(False, known_absent, values)
# If we are missing any keys check if we happen to have the full dict in
# the cache.
#
# We don't update the last access time for this cache fetch, as we
# aren't explicitly interested in the full dict and so we don't want
# requests for explicit dict keys to keep the full dict in the cache.
entry = self.cache.get(
(key, _FullCacheKey.KEY),
_Sentinel.sentinel,
update_last_access=False,
)
if entry is _Sentinel.sentinel:
# Not in the cache, return the subset of keys we found.
return DictionaryEntry(False, known_absent, values)
# We have the full dict!
assert isinstance(entry, dict)
values = {}
for dict_key in dict_keys:
# We explicitly add each dict key to the cache, so that cache hit
# rates for each key can be tracked separately.
value = entry.get(dict_key, _Sentinel.sentinel) # type: ignore[arg-type]
self.cache[(key, dict_key)] = _PerKeyValue(value)
if value is not _Sentinel.sentinel:
values[dict_key] = value
return DictionaryEntry(True, set(), values)
def invalidate(self, key: KT) -> None:
self.check_thread()
@@ -117,7 +243,9 @@ class DictionaryCache(Generic[KT, DKT, DV]):
# Increment the sequence number so that any SELECT statements that
# raced with the INSERT don't update the cache (SYN-369)
self.sequence += 1
self.cache.pop(key, None)
# Del-multi accepts truncated tuples.
self.cache.del_multi((key,)) # type: ignore[arg-type]
def invalidate_all(self) -> None:
self.check_thread()
@@ -149,20 +277,27 @@ class DictionaryCache(Generic[KT, DKT, DV]):
# Only update the cache if the caches sequence number matches the
# number that the cache had before the SELECT was started (SYN-369)
if fetched_keys is None:
self._insert(key, value, set())
self.cache[(key, _FullCacheKey.KEY)] = value
else:
self._update_or_insert(key, value, fetched_keys)
self._update_subset(key, value, fetched_keys)
def _update_or_insert(
self, key: KT, value: Dict[DKT, DV], known_absent: Iterable[DKT]
def _update_subset(
self, key: KT, value: Dict[DKT, DV], fetched_keys: Iterable[DKT]
) -> None:
# We pop and reinsert as we need to tell the cache the size may have
# changed
"""Add the given dictionary values as explicit keys in the cache.
entry: DictionaryEntry = self.cache.pop(key, DictionaryEntry(False, set(), {}))
entry.value.update(value)
entry.known_absent.update(known_absent)
self.cache[key] = entry
Args:
key
value: The dictionary with all the values that we should cache
fetched_keys: The full set of keys that were looked up, any keys
here not in `value` should be marked as "known absent".
"""
def _insert(self, key: KT, value: Dict[DKT, DV], known_absent: Set[DKT]) -> None:
self.cache[key] = DictionaryEntry(True, known_absent, value)
for dict_key, dict_value in value.items():
self.cache[(key, dict_key)] = _PerKeyValue(dict_value)
for dict_key in fetched_keys:
if (key, dict_key) in self.cache:
continue
self.cache[(key, dict_key)] = _PerKeyValue(_Sentinel.sentinel)
+64 -3
View File
@@ -44,7 +44,11 @@ from synapse.metrics.background_process_metrics import wrap_as_background_proces
from synapse.metrics.jemalloc import get_jemalloc_stats
from synapse.util import Clock, caches
from synapse.util.caches import CacheMetric, EvictionReason, register_cache
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
from synapse.util.caches.treecache import (
TreeCache,
TreeCacheNode,
iterate_tree_cache_entry,
)
from synapse.util.linked_list import ListNode
if TYPE_CHECKING:
@@ -413,7 +417,7 @@ class LruCache(Generic[KT, VT]):
else:
real_clock = clock
cache: Union[Dict[KT, _Node[KT, VT]], TreeCache] = cache_type()
cache: Union[Dict[KT, _Node[KT, VT]], TreeCache[_Node[KT, VT]]] = cache_type()
self.cache = cache # Used for introspection.
self.apply_cache_factor_from_config = apply_cache_factor_from_config
@@ -537,6 +541,7 @@ class LruCache(Generic[KT, VT]):
default: Literal[None] = None,
callbacks: Collection[Callable[[], None]] = ...,
update_metrics: bool = ...,
update_last_access: bool = ...,
) -> Optional[VT]:
...
@@ -546,6 +551,7 @@ class LruCache(Generic[KT, VT]):
default: T,
callbacks: Collection[Callable[[], None]] = ...,
update_metrics: bool = ...,
update_last_access: bool = ...,
) -> Union[T, VT]:
...
@@ -555,10 +561,27 @@ class LruCache(Generic[KT, VT]):
default: Optional[T] = None,
callbacks: Collection[Callable[[], None]] = (),
update_metrics: bool = True,
update_last_access: bool = True,
) -> Union[None, T, VT]:
"""Lookup a key in the cache
Args:
key
default
callbacks: A collection of callbacks that will fire when the
node is removed from the cache (either due to invalidation
or expiry).
update_metrics: Whether to update the hit rate metrics
update_last_access: Whether to update the last access metrics
on a node if successfully fetched. These metrics are used
to determine when to remove the node from the cache. Set
to False if this fetch should *not* prevent a node from
being expired.
"""
node = cache.get(key, None)
if node is not None:
move_node_to_front(node)
if update_last_access:
move_node_to_front(node)
node.add_callbacks(callbacks)
if update_metrics and metrics:
metrics.inc_hits()
@@ -568,6 +591,42 @@ class LruCache(Generic[KT, VT]):
metrics.inc_misses()
return default
@overload
def cache_get_multi(
key: tuple,
default: Literal[None] = None,
update_metrics: bool = True,
) -> Union[None, TreeCacheNode]:
...
@overload
def cache_get_multi(
key: tuple,
default: T,
update_metrics: bool = True,
) -> Union[T, TreeCacheNode]:
...
@synchronized
def cache_get_multi(
key: tuple,
default: Optional[T] = None,
update_metrics: bool = True,
) -> Union[None, T, TreeCacheNode]:
"""Used only for `TreeCache` to fetch a subtree."""
assert isinstance(cache, TreeCache)
node = cache.get(key, None)
if node is not None:
if update_metrics and metrics:
metrics.inc_hits()
return node
else:
if update_metrics and metrics:
metrics.inc_misses()
return default
@synchronized
def cache_set(
key: KT, value: VT, callbacks: Collection[Callable[[], None]] = ()
@@ -674,6 +733,8 @@ class LruCache(Generic[KT, VT]):
self.setdefault = cache_set_default
self.pop = cache_pop
self.del_multi = cache_del_multi
if cache_type is TreeCache:
self.get_multi = cache_get_multi
# `invalidate` is exposed for consistency with DeferredCache, so that it can be
# invalidated by the cache invalidation replication stream.
self.invalidate = cache_del_multi
+134 -39
View File
@@ -12,18 +12,59 @@
# See the License for the specific language governing permissions and
# limitations under the License.
SENTINEL = object()
from enum import Enum
from typing import (
Any,
Dict,
Generator,
Generic,
List,
Literal,
Optional,
Tuple,
TypeVar,
Union,
overload,
)
class TreeCacheNode(dict):
class Sentinel(Enum):
sentinel = object()
V = TypeVar("V")
T = TypeVar("T")
class TreeCacheNode(Generic[V]):
"""The type of nodes in our tree.
Has its own type so we can distinguish it from real dicts that are stored at the
leaves.
Either a leaf node or a branch node.
"""
__slots__ = ["leaf_value", "sub_tree"]
class TreeCache:
def __init__(
self,
leaf_value: Union[V, Literal[Sentinel.sentinel]] = Sentinel.sentinel,
sub_tree: Optional[Dict[Any, "TreeCacheNode[V]"]] = None,
) -> None:
if leaf_value is Sentinel.sentinel and sub_tree is None:
raise Exception("One of leaf or sub tree must be set")
self.leaf_value: Union[V, Literal[Sentinel.sentinel]] = leaf_value
self.sub_tree: Optional[Dict[Any, "TreeCacheNode[V]"]] = sub_tree
@staticmethod
def leaf(value: V) -> "TreeCacheNode[V]":
return TreeCacheNode(leaf_value=value)
@staticmethod
def empty_branch() -> "TreeCacheNode[V]":
return TreeCacheNode(sub_tree={})
class TreeCache(Generic[V]):
"""
Tree-based backing store for LruCache. Allows subtrees of data to be deleted
efficiently.
@@ -35,15 +76,15 @@ class TreeCache:
def __init__(self) -> None:
self.size: int = 0
self.root = TreeCacheNode()
self.root: TreeCacheNode[V] = TreeCacheNode.empty_branch()
def __setitem__(self, key, value) -> None:
def __setitem__(self, key: tuple, value: V) -> None:
self.set(key, value)
def __contains__(self, key) -> bool:
return self.get(key, SENTINEL) is not SENTINEL
def __contains__(self, key: tuple) -> bool:
return self.get(key, None) is not None
def set(self, key, value) -> None:
def set(self, key: tuple, value: V) -> None:
if isinstance(value, TreeCacheNode):
# this would mean we couldn't tell where our tree ended and the value
# started.
@@ -51,31 +92,56 @@ class TreeCache:
node = self.root
for k in key[:-1]:
next_node = node.get(k, SENTINEL)
if next_node is SENTINEL:
next_node = node[k] = TreeCacheNode()
elif not isinstance(next_node, TreeCacheNode):
# this suggests that the caller is not being consistent with its key
# length.
sub_tree = node.sub_tree
if sub_tree is None:
raise ValueError("value conflicts with an existing subtree")
node = next_node
node[key[-1]] = value
next_node = sub_tree.get(k, None)
if next_node is None:
node = TreeCacheNode.empty_branch()
sub_tree[k] = node
else:
node = next_node
if node.sub_tree is None:
raise ValueError("value conflicts with an existing subtree")
node.sub_tree[key[-1]] = TreeCacheNode.leaf(value)
self.size += 1
def get(self, key, default=None):
@overload
def get(self, key: tuple, default: Literal[None] = None) -> Union[None, V]:
...
@overload
def get(self, key: tuple, default: T) -> Union[T, V]:
...
def get(self, key: tuple, default: Optional[T] = None) -> Union[None, T, V]:
node = self.root
for k in key[:-1]:
node = node.get(k, None)
if node is None:
for k in key:
sub_tree = node.sub_tree
if sub_tree is None:
raise ValueError("get() key too long")
next_node = sub_tree.get(k, None)
if next_node is None:
return default
return node.get(key[-1], default)
node = next_node
if node.leaf_value is Sentinel.sentinel:
raise ValueError("key points to a branch")
return node.leaf_value
def clear(self) -> None:
self.size = 0
self.root = TreeCacheNode()
def pop(self, key, default=None):
def pop(
self, key: tuple, default: Optional[T] = None
) -> Union[None, T, V, TreeCacheNode[V]]:
"""Remove the given key, or subkey, from the cache
Args:
@@ -91,20 +157,25 @@ class TreeCache:
raise TypeError("The cache key must be a tuple not %r" % (type(key),))
# a list of the nodes we have touched on the way down the tree
nodes = []
nodes: List[TreeCacheNode[V]] = []
node = self.root
for k in key[:-1]:
node = node.get(k, None)
if node is None:
return default
if not isinstance(node, TreeCacheNode):
# we've gone off the end of the tree
sub_tree = node.sub_tree
if sub_tree is None:
raise ValueError("pop() key too long")
nodes.append(node) # don't add the root node
popped = node.pop(key[-1], SENTINEL)
if popped is SENTINEL:
return default
next_node = sub_tree.get(k, None)
if next_node is None:
return default
node = next_node
nodes.append(node)
if node.sub_tree is None:
raise ValueError("pop() key too long")
popped = node.sub_tree.pop(key[-1])
# working back up the tree, clear out any nodes that are now empty
node_and_keys = list(zip(nodes, key))
@@ -116,8 +187,13 @@ class TreeCache:
if n:
break
# found an empty node: remove it from its parent, and loop.
node_and_keys[i + 1][0].pop(k)
node = node_and_keys[i + 1][0]
# We added it to the list so already know its a branch node.
assert node.sub_tree is not None
node.sub_tree.pop(k)
cnt = sum(1 for _ in iterate_tree_cache_entry(popped))
self.size -= cnt
@@ -130,12 +206,31 @@ class TreeCache:
return self.size
def iterate_tree_cache_entry(d):
def iterate_tree_cache_entry(d: TreeCacheNode[V]) -> Generator[V, None, None]:
"""Helper function to iterate over the leaves of a tree, i.e. a dict of that
can contain dicts.
"""
if isinstance(d, TreeCacheNode):
for value_d in d.values():
if d.sub_tree is not None:
for value_d in d.sub_tree.values():
yield from iterate_tree_cache_entry(value_d)
else:
yield d
assert d.leaf_value is not Sentinel.sentinel
yield d.leaf_value
def iterate_tree_cache_items(
key: tuple, value: TreeCacheNode[V]
) -> Generator[Tuple[tuple, V], None, None]:
"""Helper function to iterate over the leaves of a tree, i.e. a dict of that
can contain dicts.
Returns:
A generator yielding key/value pairs.
"""
if value.sub_tree is not None:
for sub_key, sub_value in value.sub_tree.items():
yield from iterate_tree_cache_items((*key, sub_key), sub_value)
else:
assert value.leaf_value is not Sentinel.sentinel
yield key, value.leaf_value
+2 -3
View File
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from http import HTTPStatus
from unittest.mock import Mock
from synapse.api.errors import Codes, SynapseError
@@ -51,7 +50,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
channel = self.make_signed_federation_request(
"GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,)
)
self.assertEqual(HTTPStatus.OK, channel.code)
self.assertEqual(200, channel.code)
complexity = channel.json_body["v1"]
self.assertTrue(complexity > 0, complexity)
@@ -63,7 +62,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
channel = self.make_signed_federation_request(
"GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,)
)
self.assertEqual(HTTPStatus.OK, channel.code)
self.assertEqual(200, channel.code)
complexity = channel.json_body["v1"]
self.assertEqual(complexity, 1.23)
+6 -68
View File
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from http import HTTPStatus
from parameterized import parameterized
@@ -59,7 +58,7 @@ class FederationServerTests(unittest.FederatingHomeserverTestCase):
"/_matrix/federation/v1/get_missing_events/%s" % (room_1,),
query_content,
)
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result)
self.assertEqual(400, channel.code, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_NOT_JSON")
@@ -120,7 +119,7 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase):
channel = self.make_signed_federation_request(
"GET", "/_matrix/federation/v1/state/%s?event_id=xyz" % (room_1,)
)
self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, channel.result)
self.assertEqual(403, channel.code, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
@@ -148,13 +147,13 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
tok2 = self.login("fozzie", "bear")
self.helper.join(self._room_id, second_member_user_id, tok=tok2)
def _make_join(self, user_id: str) -> JsonDict:
def _make_join(self, user_id) -> JsonDict:
channel = self.make_signed_federation_request(
"GET",
f"/_matrix/federation/v1/make_join/{self._room_id}/{user_id}"
f"?ver={DEFAULT_ROOM_VERSION}",
)
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
self.assertEqual(channel.code, 200, channel.json_body)
return channel.json_body
def test_send_join(self):
@@ -172,7 +171,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
f"/_matrix/federation/v2/send_join/{self._room_id}/x",
content=join_event_dict,
)
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
self.assertEqual(channel.code, 200, channel.json_body)
# we should get complete room state back
returned_state = [
@@ -227,7 +226,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
f"/_matrix/federation/v2/send_join/{self._room_id}/x?org.matrix.msc3706.partial_state=true",
content=join_event_dict,
)
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
self.assertEqual(channel.code, 200, channel.json_body)
# expect a reduced room state
returned_state = [
@@ -260,67 +259,6 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
)
self.assertEqual(r[("m.room.member", joining_user)].membership, "join")
@override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 3}})
def test_make_join_respects_room_join_rate_limit(self) -> None:
# In the test setup, two users join the room. Since the rate limiter burst
# count is 3, a new make_join request to the room should be accepted.
joining_user = "@ronniecorbett:" + self.OTHER_SERVER_NAME
self._make_join(joining_user)
# Now have a new local user join the room. This saturates the rate limiter
# bucket, so the next make_join should be denied.
new_local_user = self.register_user("animal", "animal")
token = self.login("animal", "animal")
self.helper.join(self._room_id, new_local_user, tok=token)
joining_user = "@ronniebarker:" + self.OTHER_SERVER_NAME
channel = self.make_signed_federation_request(
"GET",
f"/_matrix/federation/v1/make_join/{self._room_id}/{joining_user}"
f"?ver={DEFAULT_ROOM_VERSION}",
)
self.assertEqual(channel.code, HTTPStatus.TOO_MANY_REQUESTS, channel.json_body)
@override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 3}})
def test_send_join_contributes_to_room_join_rate_limit_and_is_limited(self) -> None:
# Make two make_join requests up front. (These are rate limited, but do not
# contribute to the rate limit.)
join_event_dicts = []
for i in range(2):
joining_user = f"@misspiggy{i}:{self.OTHER_SERVER_NAME}"
join_result = self._make_join(joining_user)
join_event_dict = join_result["event"]
self.add_hashes_and_signatures_from_other_server(
join_event_dict,
KNOWN_ROOM_VERSIONS[DEFAULT_ROOM_VERSION],
)
join_event_dicts.append(join_event_dict)
# In the test setup, two users join the room. Since the rate limiter burst
# count is 3, the first send_join should be accepted...
channel = self.make_signed_federation_request(
"PUT",
f"/_matrix/federation/v2/send_join/{self._room_id}/join0",
content=join_event_dicts[0],
)
self.assertEqual(channel.code, 200, channel.json_body)
# ... but the second should be denied.
channel = self.make_signed_federation_request(
"PUT",
f"/_matrix/federation/v2/send_join/{self._room_id}/join1",
content=join_event_dicts[1],
)
self.assertEqual(channel.code, HTTPStatus.TOO_MANY_REQUESTS, channel.json_body)
# NB: we could write a test which checks that the send_join event is seen
# by other workers over replication, and that they update their rate limit
# buckets accordingly. I'm going to assume that the join event gets sent over
# replication, at which point the tests.handlers.room_member test
# test_local_users_joining_on_another_worker_contribute_to_rate_limit
# is probably sufficient to reassure that the bucket is updated.
def _create_acl_event(content):
return make_event_from_dict(
+2 -3
View File
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
from http import HTTPStatus
from typing import Dict, List
from synapse.api.constants import EventTypes, JoinRules, Membership
@@ -256,7 +255,7 @@ class FederationKnockingTestCase(
RoomVersions.V7.identifier,
),
)
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
self.assertEqual(200, channel.code, channel.result)
# Note: We don't expect the knock membership event to be sent over federation as
# part of the stripped room state, as the knocking homeserver already has that
@@ -294,7 +293,7 @@ class FederationKnockingTestCase(
% (room_id, signed_knock_event.event_id),
signed_knock_event_json,
)
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
self.assertEqual(200, channel.code, channel.result)
# Check that we got the stripped room state in return
room_state_events = channel.json_body["knock_state_events"]
+20 -21
View File
@@ -14,7 +14,6 @@
"""Tests for the password_auth_provider interface"""
from http import HTTPStatus
from typing import Any, Type, Union
from unittest.mock import Mock
@@ -189,14 +188,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# check_password must return an awaitable
mock_password_provider.check_password.return_value = make_awaitable(True)
channel = self._send_password_login("u", "p")
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual("@u:test", channel.json_body["user_id"])
mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
mock_password_provider.reset_mock()
# login with mxid should work too
channel = self._send_password_login("@u:bz", "p")
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual("@u:bz", channel.json_body["user_id"])
mock_password_provider.check_password.assert_called_once_with("@u:bz", "p")
mock_password_provider.reset_mock()
@@ -205,7 +204,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# in these cases, but at least we can guard against the API changing
# unexpectedly
channel = self._send_password_login(" USER🙂NAME ", " pASS\U0001F622word ")
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual("@ USER🙂NAME :test", channel.json_body["user_id"])
mock_password_provider.check_password.assert_called_once_with(
"@ USER🙂NAME :test", " pASS😢word "
@@ -259,10 +258,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# check_password must return an awaitable
mock_password_provider.check_password.return_value = make_awaitable(False)
channel = self._send_password_login("u", "p")
self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result)
self.assertEqual(channel.code, 403, channel.result)
channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual("@localuser:test", channel.json_body["user_id"])
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
@@ -383,7 +382,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# login shouldn't work and should be rejected with a 400 ("unknown login type")
channel = self._send_password_login("u", "p")
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
self.assertEqual(channel.code, 400, channel.result)
mock_password_provider.check_password.assert_not_called()
@override_config(legacy_providers_config(LegacyCustomAuthProvider))
@@ -407,14 +406,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# login with missing param should be rejected
channel = self._send_login("test.login_type", "u")
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
self.assertEqual(channel.code, 400, channel.result)
mock_password_provider.check_auth.assert_not_called()
mock_password_provider.check_auth.return_value = make_awaitable(
("@user:bz", None)
)
channel = self._send_login("test.login_type", "u", test_field="y")
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual("@user:bz", channel.json_body["user_id"])
mock_password_provider.check_auth.assert_called_once_with(
"u", "test.login_type", {"test_field": "y"}
@@ -428,7 +427,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
("@ MALFORMED! :bz", None)
)
channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ")
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual("@ MALFORMED! :bz", channel.json_body["user_id"])
mock_password_provider.check_auth.assert_called_once_with(
" USER🙂NAME ", "test.login_type", {"test_field": " abc "}
@@ -511,7 +510,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
("@user:bz", callback)
)
channel = self._send_login("test.login_type", "u", test_field="y")
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual("@user:bz", channel.json_body["user_id"])
mock_password_provider.check_auth.assert_called_once_with(
"u", "test.login_type", {"test_field": "y"}
@@ -550,7 +549,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# login shouldn't work and should be rejected with a 400 ("unknown login type")
channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
self.assertEqual(channel.code, 400, channel.result)
mock_password_provider.check_auth.assert_not_called()
@override_config(
@@ -585,7 +584,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# login shouldn't work and should be rejected with a 400 ("unknown login type")
channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
self.assertEqual(channel.code, 400, channel.result)
mock_password_provider.check_auth.assert_not_called()
@override_config(
@@ -616,7 +615,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# login shouldn't work and should be rejected with a 400 ("unknown login type")
channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
self.assertEqual(channel.code, 400, channel.result)
mock_password_provider.check_auth.assert_not_called()
mock_password_provider.check_password.assert_not_called()
@@ -647,13 +646,13 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
("@localuser:test", None)
)
channel = self._send_login("test.login_type", "localuser", test_field="")
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual(channel.code, 200, channel.result)
tok1 = channel.json_body["access_token"]
channel = self._send_login(
"test.login_type", "localuser", test_field="", device_id="dev2"
)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual(channel.code, 200, channel.result)
# make the initial request which returns a 401
channel = self._delete_device(tok1, "dev2")
@@ -722,7 +721,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# password login shouldn't work and should be rejected with a 400
# ("unknown login type")
channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
self.assertEqual(channel.code, 400, channel.result)
def test_on_logged_out(self):
"""Tests that the on_logged_out callback is called when the user logs out."""
@@ -885,7 +884,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
},
access_token=tok,
)
self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result)
self.assertEqual(channel.code, 403, channel.result)
self.assertEqual(
channel.json_body["errcode"],
Codes.THREEPID_DENIED,
@@ -907,7 +906,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
},
access_token=tok,
)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual(channel.code, 200, channel.result)
self.assertIn("sid", channel.json_body)
m.assert_called_once_with("email", "bar@test.com", registration)
@@ -950,12 +949,12 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"register",
{"auth": {"session": session, "type": LoginType.DUMMY}},
)
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
self.assertEqual(channel.code, 200, channel.json_body)
return channel.json_body
def _get_login_flows(self) -> JsonDict:
channel = self.make_request("GET", "/_matrix/client/r0/login")
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual(channel.code, 200, channel.result)
return channel.json_body["flows"]
def _send_password_login(self, user: str, password: str) -> FakeChannel:
-290
View File
@@ -1,290 +0,0 @@
from http import HTTPStatus
from unittest.mock import Mock, patch
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
import synapse.rest.client.login
import synapse.rest.client.room
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import LimitExceededError
from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.events import FrozenEventV3
from synapse.federation.federation_client import SendJoinResult
from synapse.server import HomeServer
from synapse.types import UserID, create_requester
from synapse.util import Clock
from tests.replication._base import RedisMultiWorkerStreamTestCase
from tests.server import make_request
from tests.test_utils import make_awaitable
from tests.unittest import FederatingHomeserverTestCase, override_config
class TestJoinsLimitedByPerRoomRateLimiter(FederatingHomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets,
synapse.rest.client.login.register_servlets,
synapse.rest.client.room.register_servlets,
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.handler = hs.get_room_member_handler()
# Create three users.
self.alice = self.register_user("alice", "pass")
self.alice_token = self.login("alice", "pass")
self.bob = self.register_user("bob", "pass")
self.bob_token = self.login("bob", "pass")
self.chris = self.register_user("chris", "pass")
self.chris_token = self.login("chris", "pass")
# Create a room on this homeserver. Note that this counts as a join: it
# contributes to the rate limter's count of actions
self.room_id = self.helper.create_room_as(self.alice, tok=self.alice_token)
self.intially_unjoined_room_id = f"!example:{self.OTHER_SERVER_NAME}"
@override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 2}})
def test_local_user_local_joins_contribute_to_limit_and_are_limited(self) -> None:
# The rate limiter has accumulated one token from Alice's join after the create
# event.
# Try joining the room as Bob.
self.get_success(
self.handler.update_membership(
requester=create_requester(self.bob),
target=UserID.from_string(self.bob),
room_id=self.room_id,
action=Membership.JOIN,
)
)
# The rate limiter bucket is full. A second join should be denied.
self.get_failure(
self.handler.update_membership(
requester=create_requester(self.chris),
target=UserID.from_string(self.chris),
room_id=self.room_id,
action=Membership.JOIN,
),
LimitExceededError,
)
@override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 2}})
def test_local_user_profile_edits_dont_contribute_to_limit(self) -> None:
# The rate limiter has accumulated one token from Alice's join after the create
# event. Alice should still be able to change her displayname.
self.get_success(
self.handler.update_membership(
requester=create_requester(self.alice),
target=UserID.from_string(self.alice),
room_id=self.room_id,
action=Membership.JOIN,
content={"displayname": "Alice Cooper"},
)
)
# Still room in the limiter bucket. Chris's join should be accepted.
self.get_success(
self.handler.update_membership(
requester=create_requester(self.chris),
target=UserID.from_string(self.chris),
room_id=self.room_id,
action=Membership.JOIN,
)
)
@override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 1}})
def test_remote_joins_contribute_to_rate_limit(self) -> None:
# Join once, to fill the rate limiter bucket.
#
# To do this we have to mock the responses from the remote homeserver.
# We also patch out a bunch of event checks on our end. All we're really
# trying to check here is that remote joins will bump the rate limter when
# they are persisted.
create_event_source = {
"auth_events": [],
"content": {
"creator": f"@creator:{self.OTHER_SERVER_NAME}",
"room_version": self.hs.config.server.default_room_version.identifier,
},
"depth": 0,
"origin_server_ts": 0,
"prev_events": [],
"room_id": self.intially_unjoined_room_id,
"sender": f"@creator:{self.OTHER_SERVER_NAME}",
"state_key": "",
"type": EventTypes.Create,
}
self.add_hashes_and_signatures_from_other_server(
create_event_source,
self.hs.config.server.default_room_version,
)
create_event = FrozenEventV3(
create_event_source,
self.hs.config.server.default_room_version,
{},
None,
)
join_event_source = {
"auth_events": [create_event.event_id],
"content": {"membership": "join"},
"depth": 1,
"origin_server_ts": 100,
"prev_events": [create_event.event_id],
"sender": self.bob,
"state_key": self.bob,
"room_id": self.intially_unjoined_room_id,
"type": EventTypes.Member,
}
add_hashes_and_signatures(
self.hs.config.server.default_room_version,
join_event_source,
self.hs.hostname,
self.hs.signing_key,
)
join_event = FrozenEventV3(
join_event_source,
self.hs.config.server.default_room_version,
{},
None,
)
mock_make_membership_event = Mock(
return_value=make_awaitable(
(
self.OTHER_SERVER_NAME,
join_event,
self.hs.config.server.default_room_version,
)
)
)
mock_send_join = Mock(
return_value=make_awaitable(
SendJoinResult(
join_event,
self.OTHER_SERVER_NAME,
state=[create_event],
auth_chain=[create_event],
partial_state=False,
servers_in_room=[],
)
)
)
with patch.object(
self.handler.federation_handler.federation_client,
"make_membership_event",
mock_make_membership_event,
), patch.object(
self.handler.federation_handler.federation_client,
"send_join",
mock_send_join,
), patch(
"synapse.event_auth._is_membership_change_allowed",
return_value=None,
), patch(
"synapse.handlers.federation_event.check_state_dependent_auth_rules",
return_value=None,
):
self.get_success(
self.handler.update_membership(
requester=create_requester(self.bob),
target=UserID.from_string(self.bob),
room_id=self.intially_unjoined_room_id,
action=Membership.JOIN,
remote_room_hosts=[self.OTHER_SERVER_NAME],
)
)
# Try to join as Chris. Should get denied.
self.get_failure(
self.handler.update_membership(
requester=create_requester(self.chris),
target=UserID.from_string(self.chris),
room_id=self.intially_unjoined_room_id,
action=Membership.JOIN,
remote_room_hosts=[self.OTHER_SERVER_NAME],
),
LimitExceededError,
)
# TODO: test that remote joins to a room are rate limited.
# Could do this by setting the burst count to 1, then:
# - remote-joining a room
# - immediately leaving
# - trying to remote-join again.
class TestReplicatedJoinsLimitedByPerRoomRateLimiter(RedisMultiWorkerStreamTestCase):
servlets = [
synapse.rest.admin.register_servlets,
synapse.rest.client.login.register_servlets,
synapse.rest.client.room.register_servlets,
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.handler = hs.get_room_member_handler()
# Create three users.
self.alice = self.register_user("alice", "pass")
self.alice_token = self.login("alice", "pass")
self.bob = self.register_user("bob", "pass")
self.bob_token = self.login("bob", "pass")
self.chris = self.register_user("chris", "pass")
self.chris_token = self.login("chris", "pass")
# Create a room on this homeserver.
# Note that this counts as a
self.room_id = self.helper.create_room_as(self.alice, tok=self.alice_token)
self.intially_unjoined_room_id = "!example:otherhs"
@override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 2}})
def test_local_users_joining_on_another_worker_contribute_to_rate_limit(
self,
) -> None:
# The rate limiter has accumulated one token from Alice's join after the create
# event.
self.replicate()
# Spawn another worker and have bob join via it.
worker_app = self.make_worker_hs(
"synapse.app.generic_worker", extra_config={"worker_name": "other worker"}
)
worker_site = self._hs_to_site[worker_app]
channel = make_request(
self.reactor,
worker_site,
"POST",
f"/_matrix/client/v3/rooms/{self.room_id}/join",
access_token=self.bob_token,
)
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
# wait for join to arrive over replication
self.replicate()
# Try to join as Chris on the worker. Should get denied because Alice
# and Bob have both joined the room.
self.get_failure(
worker_app.get_room_member_handler().update_membership(
requester=create_requester(self.chris),
target=UserID.from_string(self.chris),
room_id=self.room_id,
action=Membership.JOIN,
),
LimitExceededError,
)
# Try to join as Chris on the original worker. Should get denied because Alice
# and Bob have both joined the room.
self.get_failure(
self.handler.update_membership(
requester=create_requester(self.chris),
target=UserID.from_string(self.chris),
room_id=self.room_id,
action=Membership.JOIN,
),
LimitExceededError,
)
+8 -8
View File
@@ -1379,7 +1379,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content=body,
)
self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
@@ -1434,7 +1434,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content=body,
)
self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
@@ -1512,7 +1512,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"password": "abc123", "admin": False},
)
self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertFalse(channel.json_body["admin"])
@@ -1550,7 +1550,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
)
# Admin user is not blocked by mau anymore
self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertFalse(channel.json_body["admin"])
@@ -1585,7 +1585,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content=body,
)
self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
@@ -1626,7 +1626,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content=body,
)
self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
@@ -1666,7 +1666,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content=body,
)
self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("msisdn", channel.json_body["threepids"][0]["medium"])
self.assertEqual("1234567890", channel.json_body["threepids"][0]["address"])
@@ -2407,7 +2407,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
content={"password": "abc123"},
)
self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
self.assertEqual(201, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["name"])
self.assertEqual("bob", channel.json_body["displayname"])
+44 -50
View File
@@ -11,10 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import re
from email.parser import Parser
from http import HTTPStatus
from typing import Any, Dict, List, Optional, Union
from unittest.mock import Mock
@@ -95,8 +95,10 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
"""
body = {"type": "m.login.password", "user": username, "password": password}
channel = self.make_request("POST", "/_matrix/client/r0/login", body)
self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result)
channel = self.make_request(
"POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8")
)
self.assertEqual(channel.code, 403, channel.result)
def test_basic_password_reset(self) -> None:
"""Test basic password reset flow"""
@@ -345,7 +347,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
shorthand=False,
)
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
self.assertEqual(200, channel.code, channel.result)
# Now POST to the same endpoint, mimicking the same behaviour as clicking the
# password reset confirm button
@@ -360,7 +362,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
shorthand=False,
content_is_form=True,
)
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
self.assertEqual(200, channel.code, channel.result)
def _get_link_from_email(self) -> str:
assert self.email_attempts, "No emails have been sent"
@@ -388,7 +390,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
new_password: str,
session_id: str,
client_secret: str,
expected_code: int = HTTPStatus.OK,
expected_code: int = 200,
) -> None:
channel = self.make_request(
"POST",
@@ -477,14 +479,16 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
self.assertEqual(memberships[0].room_id, room_id, memberships)
def deactivate(self, user_id: str, tok: str) -> None:
request_data = {
"auth": {
"type": "m.login.password",
"user": user_id,
"password": "test",
},
"erase": False,
}
request_data = json.dumps(
{
"auth": {
"type": "m.login.password",
"user": user_id,
"password": "test",
},
"erase": False,
}
)
channel = self.make_request(
"POST", "account/deactivate", request_data, access_token=tok
)
@@ -711,9 +715,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
},
access_token=self.user_id_tok,
)
self.assertEqual(
HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
)
self.assertEqual(400, channel.code, msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
# Get user
@@ -723,7 +725,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assertEqual(200, channel.code, msg=channel.result["body"])
self.assertFalse(channel.json_body["threepids"])
def test_delete_email(self) -> None:
@@ -745,7 +747,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
{"medium": "email", "address": self.email},
access_token=self.user_id_tok,
)
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assertEqual(200, channel.code, msg=channel.result["body"])
# Get user
channel = self.make_request(
@@ -754,7 +756,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assertEqual(200, channel.code, msg=channel.result["body"])
self.assertFalse(channel.json_body["threepids"])
def test_delete_email_if_disabled(self) -> None:
@@ -779,9 +781,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
self.assertEqual(
HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
)
self.assertEqual(400, channel.code, msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
# Get user
@@ -791,7 +791,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assertEqual(200, channel.code, msg=channel.result["body"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual(self.email, channel.json_body["threepids"][0]["address"])
@@ -817,9 +817,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
},
access_token=self.user_id_tok,
)
self.assertEqual(
HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
)
self.assertEqual(400, channel.code, msg=channel.result["body"])
self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"])
# Get user
@@ -829,7 +827,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assertEqual(200, channel.code, msg=channel.result["body"])
self.assertFalse(channel.json_body["threepids"])
def test_no_valid_token(self) -> None:
@@ -854,9 +852,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
},
access_token=self.user_id_tok,
)
self.assertEqual(
HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
)
self.assertEqual(400, channel.code, msg=channel.result["body"])
self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"])
# Get user
@@ -866,7 +862,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assertEqual(200, channel.code, msg=channel.result["body"])
self.assertFalse(channel.json_body["threepids"])
@override_config({"next_link_domain_whitelist": None})
@@ -876,7 +872,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
"something@example.com",
"some_secret",
next_link="https://example.com/a/good/site",
expect_code=HTTPStatus.OK,
expect_code=200,
)
@override_config({"next_link_domain_whitelist": None})
@@ -888,7 +884,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
"something@example.com",
"some_secret",
next_link="some-protocol://abcdefghijklmopqrstuvwxyz",
expect_code=HTTPStatus.OK,
expect_code=200,
)
@override_config({"next_link_domain_whitelist": None})
@@ -899,7 +895,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
"something@example.com",
"some_secret",
next_link="file:///host/path",
expect_code=HTTPStatus.BAD_REQUEST,
expect_code=400,
)
@override_config({"next_link_domain_whitelist": ["example.com", "example.org"]})
@@ -911,28 +907,28 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
"something@example.com",
"some_secret",
next_link=None,
expect_code=HTTPStatus.OK,
expect_code=200,
)
self._request_token(
"something@example.com",
"some_secret",
next_link="https://example.com/some/good/page",
expect_code=HTTPStatus.OK,
expect_code=200,
)
self._request_token(
"something@example.com",
"some_secret",
next_link="https://example.org/some/also/good/page",
expect_code=HTTPStatus.OK,
expect_code=200,
)
self._request_token(
"something@example.com",
"some_secret",
next_link="https://bad.example.org/some/bad/page",
expect_code=HTTPStatus.BAD_REQUEST,
expect_code=400,
)
@override_config({"next_link_domain_whitelist": []})
@@ -944,7 +940,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
"something@example.com",
"some_secret",
next_link="https://example.com/a/page",
expect_code=HTTPStatus.BAD_REQUEST,
expect_code=400,
)
def _request_token(
@@ -952,7 +948,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
email: str,
client_secret: str,
next_link: Optional[str] = None,
expect_code: int = HTTPStatus.OK,
expect_code: int = 200,
) -> Optional[str]:
"""Request a validation token to add an email address to a user's account
@@ -997,9 +993,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
b"account/3pid/email/requestToken",
{"client_secret": client_secret, "email": email, "send_attempt": 1},
)
self.assertEqual(
HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
)
self.assertEqual(400, channel.code, msg=channel.result["body"])
self.assertEqual(expected_errcode, channel.json_body["errcode"])
self.assertEqual(expected_error, channel.json_body["error"])
@@ -1008,7 +1002,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
path = link.replace("https://example.com", "")
channel = self.make_request("GET", path, shorthand=False)
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
self.assertEqual(200, channel.code, channel.result)
def _get_link_from_email(self) -> str:
assert self.email_attempts, "No emails have been sent"
@@ -1058,7 +1052,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assertEqual(200, channel.code, msg=channel.result["body"])
# Get user
channel = self.make_request(
@@ -1067,7 +1061,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assertEqual(200, channel.code, msg=channel.result["body"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
threepids = {threepid["address"] for threepid in channel.json_body["threepids"]}
@@ -1098,7 +1092,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
"""Tests that not providing any MXID raises an error."""
self._test_status(
users=None,
expected_status_code=HTTPStatus.BAD_REQUEST,
expected_status_code=400,
expected_errcode=Codes.MISSING_PARAM,
)
@@ -1106,7 +1100,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
"""Tests that providing an invalid MXID raises an error."""
self._test_status(
users=["bad:test"],
expected_status_code=HTTPStatus.BAD_REQUEST,
expected_status_code=400,
expected_errcode=Codes.INVALID_PARAM,
)
@@ -1292,7 +1286,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
def _test_status(
self,
users: Optional[List[str]],
expected_status_code: int = HTTPStatus.OK,
expected_status_code: int = 200,
expected_statuses: Optional[Dict[str, Dict[str, bool]]] = None,
expected_failures: Optional[List[str]] = None,
expected_errcode: Optional[str] = None,
+11 -5
View File
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from http import HTTPStatus
from twisted.test.proto_helpers import MemoryReactor
@@ -96,7 +97,8 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
# We use deliberately a localpart under the length threshold so
# that we can make sure that the check is done on the whole alias.
request_data = {"room_alias_name": random_string(256 - len(self.hs.hostname))}
data = {"room_alias_name": random_string(256 - len(self.hs.hostname))}
request_data = json.dumps(data)
channel = self.make_request(
"POST", url, request_data, access_token=self.user_tok
)
@@ -108,7 +110,8 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
# Check with an alias of allowed length. There should already be
# a test that ensures it works in test_register.py, but let's be
# as cautious as possible here.
request_data = {"room_alias_name": random_string(5)}
data = {"room_alias_name": random_string(5)}
request_data = json.dumps(data)
channel = self.make_request(
"POST", url, request_data, access_token=self.user_tok
)
@@ -141,7 +144,8 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
# Add an alias for the room, as the appservice
alias = RoomAlias(f"asns-{random_string(5)}", self.hs.hostname).to_string()
request_data = {"room_id": self.room_id}
data = {"room_id": self.room_id}
request_data = json.dumps(data)
channel = self.make_request(
"PUT",
@@ -189,7 +193,8 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
self.hs.hostname,
)
request_data = {"aliases": [self.random_alias(alias_length)]}
data = {"aliases": [self.random_alias(alias_length)]}
request_data = json.dumps(data)
channel = self.make_request(
"PUT", url, request_data, access_token=self.user_tok
@@ -201,7 +206,8 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
) -> str:
alias = self.random_alias(alias_length)
url = "/_matrix/client/r0/directory/room/%s" % alias
request_data = {"room_id": self.room_id}
data = {"room_id": self.room_id}
request_data = json.dumps(data)
channel = self.make_request(
"PUT", url, request_data, access_token=self.user_tok
+3 -1
View File
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from http import HTTPStatus
from twisted.test.proto_helpers import MemoryReactor
@@ -50,11 +51,12 @@ class IdentityTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
room_id = channel.json_body["room_id"]
request_data = {
params = {
"id_server": "testis",
"medium": "email",
"address": "test@example.com",
}
request_data = json.dumps(params)
request_url = ("/rooms/%s/invite" % (room_id)).encode("ascii")
channel = self.make_request(
b"POST", request_url, request_data, access_token=tok
+29 -29
View File
@@ -11,9 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import time
import urllib.parse
from http import HTTPStatus
from typing import Any, Dict, List, Optional
from unittest.mock import Mock
from urllib.parse import urlencode
@@ -261,20 +261,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
}
channel = self.make_request(b"POST", LOGIN_URL, params)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual(channel.code, 200, channel.result)
access_token = channel.json_body["access_token"]
device_id = channel.json_body["device_id"]
# we should now be able to make requests with the access token
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual(channel.code, 200, channel.result)
# time passes
self.reactor.advance(24 * 3600)
# ... and we should be soft-logouted
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
self.assertEqual(channel.code, 401, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEqual(channel.json_body["soft_logout"], True)
@@ -288,7 +288,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# more requests with the expired token should still return a soft-logout
self.reactor.advance(3600)
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
self.assertEqual(channel.code, 401, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEqual(channel.json_body["soft_logout"], True)
@@ -296,7 +296,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
self._delete_device(access_token_2, "kermit", "monkey", device_id)
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
self.assertEqual(channel.code, 401, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEqual(channel.json_body["soft_logout"], False)
@@ -307,7 +307,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
b"DELETE", "devices/" + device_id, access_token=access_token
)
self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
self.assertEqual(channel.code, 401, channel.result)
# check it's a UI-Auth fail
self.assertEqual(
set(channel.json_body.keys()),
@@ -330,7 +330,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
access_token=access_token,
content={"auth": auth},
)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual(channel.code, 200, channel.result)
@override_config({"session_lifetime": "24h"})
def test_session_can_hard_logout_after_being_soft_logged_out(self) -> None:
@@ -341,14 +341,14 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# we should now be able to make requests with the access token
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual(channel.code, 200, channel.result)
# time passes
self.reactor.advance(24 * 3600)
# ... and we should be soft-logouted
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
self.assertEqual(channel.code, 401, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEqual(channel.json_body["soft_logout"], True)
@@ -367,14 +367,14 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# we should now be able to make requests with the access token
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual(channel.code, 200, channel.result)
# time passes
self.reactor.advance(24 * 3600)
# ... and we should be soft-logouted
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
self.assertEqual(channel.code, 401, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
self.assertEqual(channel.json_body["soft_logout"], True)
@@ -399,7 +399,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST",
"/_matrix/client/v3/login",
body,
json.dumps(body).encode("utf8"),
custom_headers=None,
)
@@ -466,7 +466,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
def test_get_login_flows(self) -> None:
"""GET /login should return password and SSO flows"""
channel = self.make_request("GET", "/_matrix/client/r0/login")
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual(channel.code, 200, channel.result)
expected_flow_types = [
"m.login.cas",
@@ -494,14 +494,14 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
"""/login/sso/redirect should redirect to an identity picker"""
# first hit the redirect url, which should redirect to our idp picker
channel = self._make_sso_redirect_request(None)
self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
self.assertEqual(channel.code, 302, channel.result)
location_headers = channel.headers.getRawHeaders("Location")
assert location_headers
uri = location_headers[0]
# hitting that picker should give us some HTML
channel = self.make_request("GET", uri)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual(channel.code, 200, channel.result)
# parse the form to check it has fields assumed elsewhere in this class
html = channel.result["body"].decode("utf-8")
@@ -530,7 +530,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
+ "&idp=cas",
shorthand=False,
)
self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
self.assertEqual(channel.code, 302, channel.result)
location_headers = channel.headers.getRawHeaders("Location")
assert location_headers
cas_uri = location_headers[0]
@@ -555,7 +555,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
+ "&idp=saml",
)
self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
self.assertEqual(channel.code, 302, channel.result)
location_headers = channel.headers.getRawHeaders("Location")
assert location_headers
saml_uri = location_headers[0]
@@ -579,7 +579,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
+ "&idp=oidc",
)
self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
self.assertEqual(channel.code, 302, channel.result)
location_headers = channel.headers.getRawHeaders("Location")
assert location_headers
oidc_uri = location_headers[0]
@@ -606,7 +606,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
channel = self.helper.complete_oidc_auth(oidc_uri, cookies, {"sub": "user1"})
# that should serve a confirmation page
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual(channel.code, 200, channel.result)
content_type_headers = channel.headers.getRawHeaders("Content-Type")
assert content_type_headers
self.assertTrue(content_type_headers[-1].startswith("text/html"))
@@ -634,7 +634,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
"/login",
content={"type": "m.login.token", "token": login_token},
)
self.assertEqual(chan.code, HTTPStatus.OK, chan.result)
self.assertEqual(chan.code, 200, chan.result)
self.assertEqual(chan.json_body["user_id"], "@user1:test")
def test_multi_sso_redirect_to_unknown(self) -> None:
@@ -643,18 +643,18 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
"GET",
"/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz",
)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
self.assertEqual(channel.code, 400, channel.result)
def test_client_idp_redirect_to_unknown(self) -> None:
"""If the client tries to pick an unknown IdP, return a 404"""
channel = self._make_sso_redirect_request("xxx")
self.assertEqual(channel.code, HTTPStatus.NOT_FOUND, channel.result)
self.assertEqual(channel.code, 404, channel.result)
self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
def test_client_idp_redirect_to_oidc(self) -> None:
"""If the client pick a known IdP, redirect to it"""
channel = self._make_sso_redirect_request("oidc")
self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
self.assertEqual(channel.code, 302, channel.result)
location_headers = channel.headers.getRawHeaders("Location")
assert location_headers
oidc_uri = location_headers[0]
@@ -765,7 +765,7 @@ class CASTestCase(unittest.HomeserverTestCase):
channel = self.make_request("GET", cas_ticket_url)
# Test that the response is HTML.
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual(channel.code, 200, channel.result)
content_type_header_value = ""
for header in channel.result.get("headers", []):
if header[0] == b"Content-Type":
@@ -1246,7 +1246,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
)
# that should redirect to the username picker
self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
self.assertEqual(channel.code, 302, channel.result)
location_headers = channel.headers.getRawHeaders("Location")
assert location_headers
picker_url = location_headers[0]
@@ -1290,7 +1290,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
("Content-Length", str(len(content))),
],
)
self.assertEqual(chan.code, HTTPStatus.FOUND, chan.result)
self.assertEqual(chan.code, 302, chan.result)
location_headers = chan.headers.getRawHeaders("Location")
assert location_headers
@@ -1300,7 +1300,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
path=location_headers[0],
custom_headers=[("Cookie", "username_mapping_session=" + session_id)],
)
self.assertEqual(chan.code, HTTPStatus.FOUND, chan.result)
self.assertEqual(chan.code, 302, chan.result)
location_headers = chan.headers.getRawHeaders("Location")
assert location_headers
@@ -1325,5 +1325,5 @@ class UsernamePickerTestCase(HomeserverTestCase):
"/login",
content={"type": "m.login.token", "token": login_token},
)
self.assertEqual(chan.code, HTTPStatus.OK, chan.result)
self.assertEqual(chan.code, 200, chan.result)
self.assertEqual(chan.json_body["user_id"], "@bobby:test")
+17 -14
View File
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from http import HTTPStatus
from twisted.test.proto_helpers import MemoryReactor
@@ -88,7 +89,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
)
def test_password_too_short(self) -> None:
request_data = {"username": "kermit", "password": "shorty"}
request_data = json.dumps({"username": "kermit", "password": "shorty"})
channel = self.make_request("POST", self.register_url, request_data)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
@@ -99,7 +100,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
)
def test_password_no_digit(self) -> None:
request_data = {"username": "kermit", "password": "longerpassword"}
request_data = json.dumps({"username": "kermit", "password": "longerpassword"})
channel = self.make_request("POST", self.register_url, request_data)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
@@ -110,7 +111,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
)
def test_password_no_symbol(self) -> None:
request_data = {"username": "kermit", "password": "l0ngerpassword"}
request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword"})
channel = self.make_request("POST", self.register_url, request_data)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
@@ -121,7 +122,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
)
def test_password_no_uppercase(self) -> None:
request_data = {"username": "kermit", "password": "l0ngerpassword!"}
request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword!"})
channel = self.make_request("POST", self.register_url, request_data)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
@@ -132,7 +133,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
)
def test_password_no_lowercase(self) -> None:
request_data = {"username": "kermit", "password": "L0NGERPASSWORD!"}
request_data = json.dumps({"username": "kermit", "password": "L0NGERPASSWORD!"})
channel = self.make_request("POST", self.register_url, request_data)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
@@ -143,7 +144,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
)
def test_password_compliant(self) -> None:
request_data = {"username": "kermit", "password": "L0ngerpassword!"}
request_data = json.dumps({"username": "kermit", "password": "L0ngerpassword!"})
channel = self.make_request("POST", self.register_url, request_data)
# Getting a 401 here means the password has passed validation and the server has
@@ -160,14 +161,16 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
user_id = self.register_user("kermit", compliant_password)
tok = self.login("kermit", compliant_password)
request_data = {
"new_password": not_compliant_password,
"auth": {
"password": compliant_password,
"type": LoginType.PASSWORD,
"user": user_id,
},
}
request_data = json.dumps(
{
"new_password": not_compliant_password,
"auth": {
"password": compliant_password,
"type": LoginType.PASSWORD,
"user": user_id,
},
}
)
channel = self.make_request(
"POST",
"/_matrix/client/r0/account/password",
+57 -46
View File
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import json
import os
from typing import Any, Dict, List, Tuple
@@ -61,10 +62,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
)
self.hs.get_datastores().main.services_cache.append(appservice)
request_data = {
"username": "as_user_kermit",
"type": APP_SERVICE_REGISTRATION_TYPE,
}
request_data = json.dumps(
{"username": "as_user_kermit", "type": APP_SERVICE_REGISTRATION_TYPE}
)
channel = self.make_request(
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
@@ -85,7 +85,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
)
self.hs.get_datastores().main.services_cache.append(appservice)
request_data = {"username": "as_user_kermit"}
request_data = json.dumps({"username": "as_user_kermit"})
channel = self.make_request(
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
@@ -95,7 +95,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
def test_POST_appservice_registration_invalid(self) -> None:
self.appservice = None # no application service exists
request_data = {"username": "kermit", "type": APP_SERVICE_REGISTRATION_TYPE}
request_data = json.dumps(
{"username": "kermit", "type": APP_SERVICE_REGISTRATION_TYPE}
)
channel = self.make_request(
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
)
@@ -103,14 +105,14 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.result["code"], b"401", channel.result)
def test_POST_bad_password(self) -> None:
request_data = {"username": "kermit", "password": 666}
request_data = json.dumps({"username": "kermit", "password": 666})
channel = self.make_request(b"POST", self.url, request_data)
self.assertEqual(channel.result["code"], b"400", channel.result)
self.assertEqual(channel.json_body["error"], "Invalid password")
def test_POST_bad_username(self) -> None:
request_data = {"username": 777, "password": "monkey"}
request_data = json.dumps({"username": 777, "password": "monkey"})
channel = self.make_request(b"POST", self.url, request_data)
self.assertEqual(channel.result["code"], b"400", channel.result)
@@ -119,12 +121,13 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
def test_POST_user_valid(self) -> None:
user_id = "@kermit:test"
device_id = "frogfone"
request_data = {
params = {
"username": "kermit",
"password": "monkey",
"device_id": device_id,
"auth": {"type": LoginType.DUMMY},
}
request_data = json.dumps(params)
channel = self.make_request(b"POST", self.url, request_data)
det_data = {
@@ -137,7 +140,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
@override_config({"enable_registration": False})
def test_POST_disabled_registration(self) -> None:
request_data = {"username": "kermit", "password": "monkey"}
request_data = json.dumps({"username": "kermit", "password": "monkey"})
self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)
channel = self.make_request(b"POST", self.url, request_data)
@@ -185,12 +188,13 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
@override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}})
def test_POST_ratelimiting(self) -> None:
for i in range(0, 6):
request_data = {
params = {
"username": "kermit" + str(i),
"password": "monkey",
"device_id": "frogfone",
"auth": {"type": LoginType.DUMMY},
}
request_data = json.dumps(params)
channel = self.make_request(b"POST", self.url, request_data)
if i == 5:
@@ -230,7 +234,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
}
# Request without auth to get flows and session
channel = self.make_request(b"POST", self.url, params)
channel = self.make_request(b"POST", self.url, json.dumps(params))
self.assertEqual(channel.result["code"], b"401", channel.result)
flows = channel.json_body["flows"]
# Synapse adds a dummy stage to differentiate flows where otherwise one
@@ -247,7 +251,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"token": token,
"session": session,
}
channel = self.make_request(b"POST", self.url, params)
request_data = json.dumps(params)
channel = self.make_request(b"POST", self.url, request_data)
self.assertEqual(channel.result["code"], b"401", channel.result)
completed = channel.json_body["completed"]
self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed)
@@ -257,7 +262,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"type": LoginType.DUMMY,
"session": session,
}
channel = self.make_request(b"POST", self.url, params)
request_data = json.dumps(params)
channel = self.make_request(b"POST", self.url, request_data)
det_data = {
"user_id": f"@{username}:{self.hs.hostname}",
"home_server": self.hs.hostname,
@@ -284,7 +290,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"password": "monkey",
}
# Request without auth to get session
channel = self.make_request(b"POST", self.url, params)
channel = self.make_request(b"POST", self.url, json.dumps(params))
session = channel.json_body["session"]
# Test with token param missing (invalid)
@@ -292,21 +298,21 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"type": LoginType.REGISTRATION_TOKEN,
"session": session,
}
channel = self.make_request(b"POST", self.url, params)
channel = self.make_request(b"POST", self.url, json.dumps(params))
self.assertEqual(channel.result["code"], b"401", channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.MISSING_PARAM)
self.assertEqual(channel.json_body["completed"], [])
# Test with non-string (invalid)
params["auth"]["token"] = 1234
channel = self.make_request(b"POST", self.url, params)
channel = self.make_request(b"POST", self.url, json.dumps(params))
self.assertEqual(channel.result["code"], b"401", channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
self.assertEqual(channel.json_body["completed"], [])
# Test with unknown token (invalid)
params["auth"]["token"] = "1234"
channel = self.make_request(b"POST", self.url, params)
channel = self.make_request(b"POST", self.url, json.dumps(params))
self.assertEqual(channel.result["code"], b"401", channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED)
self.assertEqual(channel.json_body["completed"], [])
@@ -331,9 +337,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
params1: JsonDict = {"username": "bert", "password": "monkey"}
params2: JsonDict = {"username": "ernie", "password": "monkey"}
# Do 2 requests without auth to get two session IDs
channel1 = self.make_request(b"POST", self.url, params1)
channel1 = self.make_request(b"POST", self.url, json.dumps(params1))
session1 = channel1.json_body["session"]
channel2 = self.make_request(b"POST", self.url, params2)
channel2 = self.make_request(b"POST", self.url, json.dumps(params2))
session2 = channel2.json_body["session"]
# Use token with session1 and check `pending` is 1
@@ -342,9 +348,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"token": token,
"session": session1,
}
self.make_request(b"POST", self.url, params1)
self.make_request(b"POST", self.url, json.dumps(params1))
# Repeat request to make sure pending isn't increased again
self.make_request(b"POST", self.url, params1)
self.make_request(b"POST", self.url, json.dumps(params1))
pending = self.get_success(
store.db_pool.simple_select_one_onecol(
"registration_tokens",
@@ -360,14 +366,14 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"token": token,
"session": session2,
}
channel = self.make_request(b"POST", self.url, params2)
channel = self.make_request(b"POST", self.url, json.dumps(params2))
self.assertEqual(channel.result["code"], b"401", channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED)
self.assertEqual(channel.json_body["completed"], [])
# Complete registration with session1
params1["auth"]["type"] = LoginType.DUMMY
self.make_request(b"POST", self.url, params1)
self.make_request(b"POST", self.url, json.dumps(params1))
# Check pending=0 and completed=1
res = self.get_success(
store.db_pool.simple_select_one(
@@ -380,7 +386,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEqual(res["completed"], 1)
# Check auth still fails when using token with session2
channel = self.make_request(b"POST", self.url, params2)
channel = self.make_request(b"POST", self.url, json.dumps(params2))
self.assertEqual(channel.result["code"], b"401", channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED)
self.assertEqual(channel.json_body["completed"], [])
@@ -405,7 +411,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
)
params: JsonDict = {"username": "kermit", "password": "monkey"}
# Request without auth to get session
channel = self.make_request(b"POST", self.url, params)
channel = self.make_request(b"POST", self.url, json.dumps(params))
session = channel.json_body["session"]
# Check authentication fails with expired token
@@ -414,7 +420,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"token": token,
"session": session,
}
channel = self.make_request(b"POST", self.url, params)
channel = self.make_request(b"POST", self.url, json.dumps(params))
self.assertEqual(channel.result["code"], b"401", channel.result)
self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED)
self.assertEqual(channel.json_body["completed"], [])
@@ -429,7 +435,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
)
# Check authentication succeeds
channel = self.make_request(b"POST", self.url, params)
channel = self.make_request(b"POST", self.url, json.dumps(params))
completed = channel.json_body["completed"]
self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed)
@@ -454,9 +460,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
# Do 2 requests without auth to get two session IDs
params1: JsonDict = {"username": "bert", "password": "monkey"}
params2: JsonDict = {"username": "ernie", "password": "monkey"}
channel1 = self.make_request(b"POST", self.url, params1)
channel1 = self.make_request(b"POST", self.url, json.dumps(params1))
session1 = channel1.json_body["session"]
channel2 = self.make_request(b"POST", self.url, params2)
channel2 = self.make_request(b"POST", self.url, json.dumps(params2))
session2 = channel2.json_body["session"]
# Use token with both sessions
@@ -465,18 +471,18 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"token": token,
"session": session1,
}
self.make_request(b"POST", self.url, params1)
self.make_request(b"POST", self.url, json.dumps(params1))
params2["auth"] = {
"type": LoginType.REGISTRATION_TOKEN,
"token": token,
"session": session2,
}
self.make_request(b"POST", self.url, params2)
self.make_request(b"POST", self.url, json.dumps(params2))
# Complete registration with session1
params1["auth"]["type"] = LoginType.DUMMY
self.make_request(b"POST", self.url, params1)
self.make_request(b"POST", self.url, json.dumps(params1))
# Check `result` of registration token stage for session1 is `True`
result1 = self.get_success(
@@ -544,7 +550,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
# Do request without auth to get a session ID
params: JsonDict = {"username": "kermit", "password": "monkey"}
channel = self.make_request(b"POST", self.url, params)
channel = self.make_request(b"POST", self.url, json.dumps(params))
session = channel.json_body["session"]
# Use token
@@ -553,7 +559,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
"token": token,
"session": session,
}
self.make_request(b"POST", self.url, params)
self.make_request(b"POST", self.url, json.dumps(params))
# Delete token
self.get_success(
@@ -821,7 +827,8 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
admin_tok = self.login("admin", "adminpassword")
url = "/_synapse/admin/v1/account_validity/validity"
request_data = {"user_id": user_id}
params = {"user_id": user_id}
request_data = json.dumps(params)
channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
self.assertEqual(channel.result["code"], b"200", channel.result)
@@ -838,11 +845,12 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
admin_tok = self.login("admin", "adminpassword")
url = "/_synapse/admin/v1/account_validity/validity"
request_data = {
params = {
"user_id": user_id,
"expiration_ts": 0,
"enable_renewal_emails": False,
}
request_data = json.dumps(params)
channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
self.assertEqual(channel.result["code"], b"200", channel.result)
@@ -862,11 +870,12 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
admin_tok = self.login("admin", "adminpassword")
url = "/_synapse/admin/v1/account_validity/validity"
request_data = {
params = {
"user_id": user_id,
"expiration_ts": 0,
"enable_renewal_emails": False,
}
request_data = json.dumps(params)
channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
self.assertEqual(channel.result["code"], b"200", channel.result)
@@ -1032,14 +1041,16 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
(user_id, tok) = self.create_user()
request_data = {
"auth": {
"type": "m.login.password",
"user": user_id,
"password": "monkey",
},
"erase": False,
}
request_data = json.dumps(
{
"auth": {
"type": "m.login.password",
"user": user_id,
"password": "monkey",
},
"erase": False,
}
)
channel = self.make_request(
"POST", "account/deactivate", request_data, access_token=tok
)
+6 -1
View File
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
@@ -75,7 +77,10 @@ class ReportEventTestCase(unittest.HomeserverTestCase):
def _assert_status(self, response_status: int, data: JsonDict) -> None:
channel = self.make_request(
"POST", self.report_path, data, access_token=self.other_user_tok
"POST",
self.report_path,
json.dumps(data),
access_token=self.other_user_tok,
)
self.assertEqual(
response_status, int(channel.result["code"]), msg=channel.result["body"]
File diff suppressed because it is too large Load Diff
+2 -1
View File
@@ -606,10 +606,11 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
self._check_unread_count(1)
# Send a read receipt to tell the server we've read the latest event.
body = json.dumps({ReceiptTypes.READ: res["event_id"]}).encode("utf8")
channel = self.make_request(
"POST",
f"/rooms/{self.room_id}/read_markers",
{ReceiptTypes.READ: res["event_id"]},
body,
access_token=self.tok,
)
self.assertEqual(channel.code, 200, channel.json_body)
+4 -4
View File
@@ -136,7 +136,7 @@ class RestHelper:
self.site,
"POST",
path,
content,
json.dumps(content).encode("utf8"),
custom_headers=custom_headers,
)
@@ -210,7 +210,7 @@ class RestHelper:
self.site,
"POST",
path,
data,
json.dumps(data).encode("utf8"),
)
assert (
@@ -309,7 +309,7 @@ class RestHelper:
self.site,
"PUT",
path,
data,
json.dumps(data).encode("utf8"),
)
assert (
@@ -392,7 +392,7 @@ class RestHelper:
self.site,
"PUT",
path,
content or {},
json.dumps(content or {}).encode("utf8"),
custom_headers=custom_headers,
)
+8 -62
View File
@@ -126,9 +126,7 @@ class _TestImage:
expected_scaled: The expected bytes from scaled thumbnailing, or None if
test should just check for a valid image returned.
expected_found: True if the file should exist on the server, or False if
a 404/400 is expected.
unable_to_thumbnail: True if we expect the thumbnailing to fail (400), or
False if the thumbnailing should succeed or a normal 404 is expected.
a 404 is expected.
"""
data: bytes
@@ -137,7 +135,6 @@ class _TestImage:
expected_cropped: Optional[bytes] = None
expected_scaled: Optional[bytes] = None
expected_found: bool = True
unable_to_thumbnail: bool = False
@parameterized_class(
@@ -195,7 +192,6 @@ class _TestImage:
b"image/gif",
b".gif",
expected_found=False,
unable_to_thumbnail=True,
),
),
],
@@ -370,29 +366,18 @@ class MediaRepoTests(unittest.HomeserverTestCase):
def test_thumbnail_crop(self) -> None:
"""Test that a cropped remote thumbnail is available."""
self._test_thumbnail(
"crop",
self.test_image.expected_cropped,
expected_found=self.test_image.expected_found,
unable_to_thumbnail=self.test_image.unable_to_thumbnail,
"crop", self.test_image.expected_cropped, self.test_image.expected_found
)
def test_thumbnail_scale(self) -> None:
"""Test that a scaled remote thumbnail is available."""
self._test_thumbnail(
"scale",
self.test_image.expected_scaled,
expected_found=self.test_image.expected_found,
unable_to_thumbnail=self.test_image.unable_to_thumbnail,
"scale", self.test_image.expected_scaled, self.test_image.expected_found
)
def test_invalid_type(self) -> None:
"""An invalid thumbnail type is never available."""
self._test_thumbnail(
"invalid",
None,
expected_found=False,
unable_to_thumbnail=self.test_image.unable_to_thumbnail,
)
self._test_thumbnail("invalid", None, False)
@unittest.override_config(
{"thumbnail_sizes": [{"width": 32, "height": 32, "method": "scale"}]}
@@ -401,12 +386,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
"""
Override the config to generate only scaled thumbnails, but request a cropped one.
"""
self._test_thumbnail(
"crop",
None,
expected_found=False,
unable_to_thumbnail=self.test_image.unable_to_thumbnail,
)
self._test_thumbnail("crop", None, False)
@unittest.override_config(
{"thumbnail_sizes": [{"width": 32, "height": 32, "method": "crop"}]}
@@ -415,22 +395,14 @@ class MediaRepoTests(unittest.HomeserverTestCase):
"""
Override the config to generate only cropped thumbnails, but request a scaled one.
"""
self._test_thumbnail(
"scale",
None,
expected_found=False,
unable_to_thumbnail=self.test_image.unable_to_thumbnail,
)
self._test_thumbnail("scale", None, False)
def test_thumbnail_repeated_thumbnail(self) -> None:
"""Test that fetching the same thumbnail works, and deleting the on disk
thumbnail regenerates it.
"""
self._test_thumbnail(
"scale",
self.test_image.expected_scaled,
expected_found=self.test_image.expected_found,
unable_to_thumbnail=self.test_image.unable_to_thumbnail,
"scale", self.test_image.expected_scaled, self.test_image.expected_found
)
if not self.test_image.expected_found:
@@ -487,24 +459,8 @@ class MediaRepoTests(unittest.HomeserverTestCase):
)
def _test_thumbnail(
self,
method: str,
expected_body: Optional[bytes],
expected_found: bool,
unable_to_thumbnail: bool = False,
self, method: str, expected_body: Optional[bytes], expected_found: bool
) -> None:
"""Test the given thumbnailing method works as expected.
Args:
method: The thumbnailing method to use (crop, scale).
expected_body: The expected bytes from thumbnailing, or None if
test should just check for a valid image.
expected_found: True if the file should exist on the server, or False if
a 404/400 is expected.
unable_to_thumbnail: True if we expect the thumbnailing to fail (400), or
False if the thumbnailing should succeed or a normal 404 is expected.
"""
params = "?width=32&height=32&method=" + method
channel = make_request(
self.reactor,
@@ -540,16 +496,6 @@ class MediaRepoTests(unittest.HomeserverTestCase):
else:
# ensure that the result is at least some valid image
Image.open(BytesIO(channel.result["body"]))
elif unable_to_thumbnail:
# A 400 with a JSON body.
self.assertEqual(channel.code, 400)
self.assertEqual(
channel.json_body,
{
"errcode": "M_UNKNOWN",
"error": "Cannot find any thumbnails for the requested media ([b'example.com', b'12345']). This might mean the media is not a supported_media_format=(image/jpeg, image/jpg, image/webp, image/gif, image/png) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)",
},
)
else:
# A 404 with a JSON body.
self.assertEqual(channel.code, 404)
+1 -1
View File
@@ -369,7 +369,7 @@ class StateStoreTestCase(HomeserverTestCase):
state_dict_ids = cache_entry.value
self.assertEqual(cache_entry.full, False)
self.assertEqual(cache_entry.known_absent, {(e1.type, e1.state_key)})
self.assertEqual(cache_entry.known_absent, set())
self.assertDictEqual(state_dict_ids, {(e1.type, e1.state_key): e1.event_id})
############################################

Some files were not shown because too many files have changed in this diff Show More