Compare commits
13 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 5de571987e | |||
| 129691f190 | |||
| 462db2a171 | |||
| 7f7b36d56d | |||
| 057ae8b61c | |||
| 7aceec3ed9 | |||
| cad555f07c | |||
| 23c2f394a5 | |||
| 602a81f5a2 | |||
| f046366d2a | |||
| a22716c5c5 | |||
| 40a8fba5f6 | |||
| 326a175987 |
@@ -69,7 +69,7 @@ with open('pyproject.toml', 'w') as f:
|
||||
"
|
||||
python3 -c "$REMOVE_DEV_DEPENDENCIES"
|
||||
|
||||
pipx install poetry==1.1.14
|
||||
pipx install poetry==1.1.12
|
||||
~/.local/bin/poetry lock
|
||||
|
||||
echo "::group::Patched pyproject.toml"
|
||||
|
||||
@@ -1,16 +1,3 @@
|
||||
# Commits in this file will be removed from GitHub blame results.
|
||||
#
|
||||
# To use this file locally, use:
|
||||
# git blame --ignore-revs-file="path/to/.git-blame-ignore-revs" <files>
|
||||
#
|
||||
# or configure the `blame.ignoreRevsFile` option in your git config.
|
||||
#
|
||||
# If ignoring a pull request that was not squash merged, only the merge
|
||||
# commit needs to be put here. Child commits will be resolved from it.
|
||||
|
||||
# Run black (#3679).
|
||||
8b3d9b6b199abb87246f982d5db356f1966db925
|
||||
|
||||
# Black reformatting (#5482).
|
||||
32e7c9e7f20b57dd081023ac42d6931a8da9b3a3
|
||||
|
||||
|
||||
@@ -127,12 +127,12 @@ jobs:
|
||||
run: |
|
||||
set -x
|
||||
DEBIAN_FRONTEND=noninteractive sudo apt-get install -yqq python3 pipx
|
||||
pipx install poetry==1.1.14
|
||||
pipx install poetry==1.1.12
|
||||
|
||||
poetry remove -n twisted
|
||||
poetry add -n --extras tls git+https://github.com/twisted/twisted.git#trunk
|
||||
poetry lock --no-update
|
||||
# NOT IN 1.1.14 poetry lock --check
|
||||
# NOT IN 1.1.12 poetry lock --check
|
||||
working-directory: synapse
|
||||
|
||||
- run: |
|
||||
|
||||
+2
-11
@@ -3,15 +3,6 @@ Synapse vNext
|
||||
|
||||
As of this release, Synapse no longer allows the tasks of verifying email address ownership, and password reset confirmation, to be delegated to an identity server. For more information, see the [upgrade notes](https://matrix-org.github.io/synapse/v1.64/upgrade.html#upgrading-to-v1640).
|
||||
|
||||
Synapse 1.63.0 (2022-07-19)
|
||||
===========================
|
||||
|
||||
Improved Documentation
|
||||
----------------------
|
||||
|
||||
- Clarify that homeserver server names are included in the reported data when the `report_stats` config option is enabled. ([\#13321](https://github.com/matrix-org/synapse/issues/13321))
|
||||
|
||||
|
||||
Synapse 1.63.0rc1 (2022-07-12)
|
||||
==============================
|
||||
|
||||
@@ -20,7 +11,7 @@ Features
|
||||
|
||||
- Add a rate limit for local users sending invites. ([\#13125](https://github.com/matrix-org/synapse/issues/13125))
|
||||
- Implement [MSC3827](https://github.com/matrix-org/matrix-spec-proposals/pull/3827): Filtering of `/publicRooms` by room type. ([\#13031](https://github.com/matrix-org/synapse/issues/13031))
|
||||
- Improve validation logic in the account data REST endpoints. ([\#13148](https://github.com/matrix-org/synapse/issues/13148))
|
||||
- Improve validation logic in Synapse's REST endpoints. ([\#13148](https://github.com/matrix-org/synapse/issues/13148))
|
||||
|
||||
|
||||
Bugfixes
|
||||
@@ -48,7 +39,7 @@ Improved Documentation
|
||||
- Add an explanation of the `--report-stats` argument to the docs. ([\#13029](https://github.com/matrix-org/synapse/issues/13029))
|
||||
- Add a helpful example bash script to the contrib directory for creating multiple worker configuration files of the same type. Contributed by @villepeh. ([\#13032](https://github.com/matrix-org/synapse/issues/13032))
|
||||
- Add missing links to config options. ([\#13166](https://github.com/matrix-org/synapse/issues/13166))
|
||||
- Add documentation for homeserver usage statistics collection. ([\#13086](https://github.com/matrix-org/synapse/issues/13086))
|
||||
- Add documentation for anonymised homeserver statistics collection. ([\#13086](https://github.com/matrix-org/synapse/issues/13086))
|
||||
- Add documentation for the existing `databases` option in the homeserver configuration manual. ([\#13212](https://github.com/matrix-org/synapse/issues/13212))
|
||||
- Clean up references to sample configuration and redirect users to the configuration manual instead. ([\#13077](https://github.com/matrix-org/synapse/issues/13077), [\#13139](https://github.com/matrix-org/synapse/issues/13139))
|
||||
- Document how the Synapse team does reviews. ([\#13132](https://github.com/matrix-org/synapse/issues/13132))
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
Use lower isolation level when purging rooms to avoid serialization errors. Contributed by Nick @ Beeper.
|
||||
@@ -1 +0,0 @@
|
||||
Provide more info why we don't have any thumbnails to serve.
|
||||
@@ -1 +0,0 @@
|
||||
Always use a version of canonicaljson that supports the C implementation of frozendict.
|
||||
@@ -1 +0,0 @@
|
||||
Fix spurious warning when fetching state after a missing prev event.
|
||||
@@ -1 +0,0 @@
|
||||
Add another `contrib` script to help set up worker processes. Contributed by @villepeh.
|
||||
@@ -1 +0,0 @@
|
||||
Add per-room rate limiting for room joins. For each room, Synapse now monitors the rate of join events in that room, and throttle additional joins if that rate grows too large.
|
||||
@@ -1 +0,0 @@
|
||||
Don't pull out the full state when creating an event.
|
||||
@@ -1 +0,0 @@
|
||||
Upgrade from Poetry 1.1.14 to 1.1.12, to fix bugs when locking packages.
|
||||
@@ -0,0 +1 @@
|
||||
Make `DictionaryCache` expire full entries if they haven't been queried in a while, even if specific keys have been queried recently.
|
||||
@@ -1 +0,0 @@
|
||||
Fix a bug introduced in v1.18.0 where the `synapse_pushers` metric would overcount pushers when they are replaced.
|
||||
@@ -1 +0,0 @@
|
||||
Use `HTTPStatus` constants in place of literals in tests.
|
||||
@@ -1 +0,0 @@
|
||||
Improve performance of query `_get_subset_users_in_room_with_profiles`.
|
||||
@@ -1 +0,0 @@
|
||||
Up batch size of `bulk_get_push_rules` and `_get_joined_profiles_from_event_ids`.
|
||||
@@ -1 +0,0 @@
|
||||
Remove unnecessary `json.dumps` from tests.
|
||||
@@ -1 +0,0 @@
|
||||
Don't pull out the full state when creating an event.
|
||||
@@ -1 +0,0 @@
|
||||
Use an asynchronous cache wrapper for the get event cache. Contributed by Nick @ Beeper (@fizzadar).
|
||||
@@ -1 +0,0 @@
|
||||
Reduce memory usage of sending dummy events.
|
||||
@@ -1 +0,0 @@
|
||||
Prevent formatting changes of [#3679](https://github.com/matrix-org/synapse/pull/3679) from appearing in `git blame`.
|
||||
@@ -1 +0,0 @@
|
||||
Add notes when config options where changed. Contributed by @behrmann.
|
||||
@@ -1 +0,0 @@
|
||||
Reduce memory usage of state caches.
|
||||
@@ -1 +0,0 @@
|
||||
Stop builindg `.deb` packages for Ubuntu 21.10 (Impish Indri), which has reached end of life.
|
||||
@@ -1 +0,0 @@
|
||||
Add type hints to `trace` decorator.
|
||||
@@ -1 +0,0 @@
|
||||
Document the new `rc_invites.per_issuer` throttling option added in Synapse 1.63.
|
||||
@@ -1,145 +0,0 @@
|
||||
# Creating multiple stream writers with a bash script
|
||||
|
||||
This script creates multiple [stream writer](https://github.com/matrix-org/synapse/blob/develop/docs/workers.md#stream-writers) workers.
|
||||
|
||||
Stream writers require both replication and HTTP listeners.
|
||||
|
||||
It also prints out the example lines for Synapse main configuration file.
|
||||
|
||||
Remember to route necessary endpoints directly to a worker associated with it.
|
||||
|
||||
If you run the script as-is, it will create workers with the replication listener starting from port 8034 and another, regular http listener starting from 8044. If you don't need all of the stream writers listed in the script, just remove them from the ```STREAM_WRITERS``` array.
|
||||
|
||||
```sh
|
||||
#!/bin/bash
|
||||
|
||||
# Start with these replication and http ports.
|
||||
# The script loop starts with the exact port and then increments it by one.
|
||||
REP_START_PORT=8034
|
||||
HTTP_START_PORT=8044
|
||||
|
||||
# Stream writer workers to generate. Feel free to add or remove them as you wish.
|
||||
# Event persister ("events") isn't included here as it does not require its
|
||||
# own HTTP listener.
|
||||
|
||||
STREAM_WRITERS+=( "presence" "typing" "receipts" "to_device" "account_data" )
|
||||
|
||||
NUM_WRITERS=$(expr ${#STREAM_WRITERS[@]})
|
||||
|
||||
i=0
|
||||
|
||||
while [ $i -lt "$NUM_WRITERS" ]
|
||||
do
|
||||
cat << EOF > ${STREAM_WRITERS[$i]}_stream_writer.yaml
|
||||
worker_app: synapse.app.generic_worker
|
||||
worker_name: ${STREAM_WRITERS[$i]}_stream_writer
|
||||
|
||||
# The replication listener on the main synapse process.
|
||||
worker_replication_host: 127.0.0.1
|
||||
worker_replication_http_port: 9093
|
||||
|
||||
worker_listeners:
|
||||
- type: http
|
||||
port: $(expr $REP_START_PORT + $i)
|
||||
resources:
|
||||
- names: [replication]
|
||||
|
||||
- type: http
|
||||
port: $(expr $HTTP_START_PORT + $i)
|
||||
resources:
|
||||
- names: [client]
|
||||
|
||||
worker_log_config: /etc/matrix-synapse/stream-writer-log.yaml
|
||||
EOF
|
||||
HOMESERVER_YAML_INSTANCE_MAP+=$" ${STREAM_WRITERS[$i]}_stream_writer:
|
||||
host: 127.0.0.1
|
||||
port: $(expr $REP_START_PORT + $i)
|
||||
"
|
||||
|
||||
HOMESERVER_YAML_STREAM_WRITERS+=$" ${STREAM_WRITERS[$i]}: ${STREAM_WRITERS[$i]}_stream_writer
|
||||
"
|
||||
|
||||
((i++))
|
||||
done
|
||||
|
||||
cat << EXAMPLECONFIG
|
||||
# Add these lines to your homeserver.yaml.
|
||||
# Don't forget to configure your reverse proxy and
|
||||
# necessary endpoints to their respective worker.
|
||||
|
||||
# See https://github.com/matrix-org/synapse/blob/develop/docs/workers.md
|
||||
# for more information.
|
||||
|
||||
# Remember: Under NO circumstances should the replication
|
||||
# listener be exposed to the public internet;
|
||||
# it has no authentication and is unencrypted.
|
||||
|
||||
instance_map:
|
||||
$HOMESERVER_YAML_INSTANCE_MAP
|
||||
stream_writers:
|
||||
$HOMESERVER_YAML_STREAM_WRITERS
|
||||
EXAMPLECONFIG
|
||||
```
|
||||
|
||||
Copy the code above save it to a file ```create_stream_writers.sh``` (for example).
|
||||
|
||||
Make the script executable by running ```chmod +x create_stream_writers.sh```.
|
||||
|
||||
## Run the script to create workers and print out a sample configuration
|
||||
|
||||
Simply run the script to create YAML files in the current folder and print out the required configuration for ```homeserver.yaml```.
|
||||
|
||||
```console
|
||||
$ ./create_stream_writers.sh
|
||||
|
||||
# Add these lines to your homeserver.yaml.
|
||||
# Don't forget to configure your reverse proxy and
|
||||
# necessary endpoints to their respective worker.
|
||||
|
||||
# See https://github.com/matrix-org/synapse/blob/develop/docs/workers.md
|
||||
# for more information
|
||||
|
||||
# Remember: Under NO circumstances should the replication
|
||||
# listener be exposed to the public internet;
|
||||
# it has no authentication and is unencrypted.
|
||||
|
||||
instance_map:
|
||||
presence_stream_writer:
|
||||
host: 127.0.0.1
|
||||
port: 8034
|
||||
typing_stream_writer:
|
||||
host: 127.0.0.1
|
||||
port: 8035
|
||||
receipts_stream_writer:
|
||||
host: 127.0.0.1
|
||||
port: 8036
|
||||
to_device_stream_writer:
|
||||
host: 127.0.0.1
|
||||
port: 8037
|
||||
account_data_stream_writer:
|
||||
host: 127.0.0.1
|
||||
port: 8038
|
||||
|
||||
stream_writers:
|
||||
presence: presence_stream_writer
|
||||
typing: typing_stream_writer
|
||||
receipts: receipts_stream_writer
|
||||
to_device: to_device_stream_writer
|
||||
account_data: account_data_stream_writer
|
||||
```
|
||||
|
||||
Simply copy-and-paste the output to an appropriate place in your Synapse main configuration file.
|
||||
|
||||
## Write directly to Synapse configuration file
|
||||
|
||||
You could also write the output directly to homeserver main configuration file. **This, however, is not recommended** as even a small typo (such as replacing >> with >) can erase the entire ```homeserver.yaml```.
|
||||
|
||||
If you do this, back up your original configuration file first:
|
||||
|
||||
```console
|
||||
# Back up homeserver.yaml first
|
||||
cp /etc/matrix-synapse/homeserver.yaml /etc/matrix-synapse/homeserver.yaml.bak
|
||||
|
||||
# Create workers and write output to your homeserver.yaml
|
||||
./create_stream_writers.sh >> /etc/matrix-synapse/homeserver.yaml
|
||||
```
|
||||
+1
-1
@@ -1,4 +1,4 @@
|
||||
# Creating multiple generic workers with a bash script
|
||||
# Creating multiple workers with a bash script
|
||||
|
||||
Setting up multiple worker configuration files manually can be time-consuming.
|
||||
You can alternatively create multiple worker configuration files with a simple `bash` script. For example:
|
||||
Vendored
-8
@@ -1,11 +1,3 @@
|
||||
matrix-synapse-py3 (1.63.0) stable; urgency=medium
|
||||
|
||||
* Clarify that homeserver server names are included in the data reported
|
||||
by opt-in server stats reporting (`report_stats` homeserver config option).
|
||||
* New Synapse release 1.63.0.
|
||||
|
||||
-- Synapse Packaging team <packages@matrix.org> Tue, 19 Jul 2022 14:42:24 +0200
|
||||
|
||||
matrix-synapse-py3 (1.63.0~rc1) stable; urgency=medium
|
||||
|
||||
* New Synapse release 1.63.0rc1.
|
||||
|
||||
Vendored
+1
-1
@@ -31,7 +31,7 @@ EOF
|
||||
# This file is autogenerated, and will be recreated on upgrade if it is deleted.
|
||||
# Any changes you make will be preserved.
|
||||
|
||||
# Whether to report homeserver usage statistics.
|
||||
# Whether to report anonymized homeserver usage statistics.
|
||||
report_stats: false
|
||||
EOF
|
||||
fi
|
||||
|
||||
Vendored
+6
-6
@@ -37,7 +37,7 @@ msgstr ""
|
||||
#. Type: boolean
|
||||
#. Description
|
||||
#: ../templates:2001
|
||||
msgid "Report homeserver usage statistics?"
|
||||
msgid "Report anonymous statistics?"
|
||||
msgstr ""
|
||||
|
||||
#. Type: boolean
|
||||
@@ -45,11 +45,11 @@ msgstr ""
|
||||
#: ../templates:2001
|
||||
msgid ""
|
||||
"Developers of Matrix and Synapse really appreciate helping the project out "
|
||||
"by reporting homeserver usage statistics from this homeserver. Your "
|
||||
"homeserver's server name, along with very basic aggregate data (e.g. "
|
||||
"number of users) will be reported. But it helps track the growth of the "
|
||||
"Matrix community, and helps in making Matrix a success, as well as to "
|
||||
"convince other networks that they should peer with Matrix."
|
||||
"by reporting anonymized usage statistics from this homeserver. Only very "
|
||||
"basic aggregate data (e.g. number of users) will be reported, but it helps "
|
||||
"track the growth of the Matrix community, and helps in making Matrix a "
|
||||
"success, as well as to convince other networks that they should peer with "
|
||||
"Matrix."
|
||||
msgstr ""
|
||||
|
||||
#. Type: boolean
|
||||
|
||||
Vendored
+6
-7
@@ -10,13 +10,12 @@ _Description: Name of the server:
|
||||
Template: matrix-synapse/report-stats
|
||||
Type: boolean
|
||||
Default: false
|
||||
_Description: Report homeserver usage statistics?
|
||||
_Description: Report anonymous statistics?
|
||||
Developers of Matrix and Synapse really appreciate helping the
|
||||
project out by reporting homeserver usage statistics from this
|
||||
homeserver. Your homeserver's server name, along with very basic
|
||||
aggregate data (e.g. number of users) will be reported. But it
|
||||
helps track the growth of the Matrix community, and helps in
|
||||
making Matrix a success, as well as to convince other networks
|
||||
that they should peer with Matrix.
|
||||
project out by reporting anonymized usage statistics from this
|
||||
homeserver. Only very basic aggregate data (e.g. number of users)
|
||||
will be reported, but it helps track the growth of the Matrix
|
||||
community, and helps in making Matrix a success, as well as to
|
||||
convince other networks that they should peer with Matrix.
|
||||
.
|
||||
Thank you.
|
||||
|
||||
+1
-1
@@ -45,7 +45,7 @@ RUN \
|
||||
|
||||
# We install poetry in its own build stage to avoid its dependencies conflicting with
|
||||
# synapse's dependencies.
|
||||
# We use a specific commit from poetry's master branch instead of our usual 1.1.14,
|
||||
# We use a specific commit from poetry's master branch instead of our usual 1.1.12,
|
||||
# to incorporate fixes to some bugs in `poetry export`. This commit corresponds to
|
||||
# https://github.com/python-poetry/poetry/pull/5156 and
|
||||
# https://github.com/python-poetry/poetry/issues/5141 ;
|
||||
|
||||
@@ -67,10 +67,6 @@ rc_joins:
|
||||
per_second: 9999
|
||||
burst_count: 9999
|
||||
|
||||
rc_joins_per_room:
|
||||
per_second: 9999
|
||||
burst_count: 9999
|
||||
|
||||
rc_3pid_validation:
|
||||
per_second: 1000
|
||||
burst_count: 1000
|
||||
|
||||
+1
-1
@@ -68,7 +68,7 @@
|
||||
- [Federation](usage/administration/admin_api/federation.md)
|
||||
- [Manhole](manhole.md)
|
||||
- [Monitoring](metrics-howto.md)
|
||||
- [Reporting Homeserver Usage Statistics](usage/administration/monitoring/reporting_homeserver_usage_statistics.md)
|
||||
- [Reporting Anonymised Statistics](usage/administration/monitoring/reporting_anonymised_statistics.md)
|
||||
- [Understanding Synapse Through Grafana Graphs](usage/administration/understanding_synapse_through_grafana_graphs.md)
|
||||
- [Useful SQL for Admins](usage/administration/useful_sql_for_admins.md)
|
||||
- [Database Maintenance Tools](usage/administration/database_maintenance_tools.md)
|
||||
|
||||
@@ -237,28 +237,3 @@ poetry run pip install build && poetry run python -m build
|
||||
because [`build`](https://github.com/pypa/build) is a standardish tool which
|
||||
doesn't require poetry. (It's what we use in CI too). However, you could try
|
||||
`poetry build` too.
|
||||
|
||||
|
||||
# Troubleshooting
|
||||
|
||||
## Check the version of poetry with `poetry --version`.
|
||||
|
||||
At the time of writing, the 1.2 series is beta only. We have seen some examples
|
||||
where the lockfiles generated by 1.2 prereleasese aren't interpreted correctly
|
||||
by poetry 1.1.x. For now, use poetry 1.1.14, which includes a critical
|
||||
[change](https://github.com/python-poetry/poetry/pull/5973) needed to remain
|
||||
[compatible with PyPI](https://github.com/pypi/warehouse/pull/11775).
|
||||
|
||||
It can also be useful to check the version of `poetry-core` in use. If you've
|
||||
installed `poetry` with `pipx`, try `pipx runpip poetry list | grep poetry-core`.
|
||||
|
||||
## Clear caches: `poetry cache clear --all pypi`.
|
||||
|
||||
Poetry caches a bunch of information about packages that isn't readily available
|
||||
from PyPI. (This is what makes poetry seem slow when doing the first
|
||||
`poetry install`.) Try `poetry cache list` and `poetry cache clear --all
|
||||
<name of cache>` to see if that fixes things.
|
||||
|
||||
## Try `--verbose` or `--dry-run` arguments.
|
||||
|
||||
Sometimes useful to see what poetry's internal logic is.
|
||||
|
||||
@@ -104,16 +104,6 @@ minimum, a `notif_from` setting.)
|
||||
Specifying an `email` setting under `account_threepid_delegates` will now cause
|
||||
an error at startup.
|
||||
|
||||
## Changes to the event replication streams
|
||||
|
||||
Synapse now includes a flag indicating if an event is an outlier when
|
||||
replicating it to other workers. This is a forwards- and backwards-incompatible
|
||||
change: v1.63 and workers cannot process events replicated by v1.64 workers, and
|
||||
vice versa.
|
||||
|
||||
Once all workers are upgraded to v1.64 (or downgraded to v1.63), event
|
||||
replication will resume as normal.
|
||||
|
||||
# Upgrading to v1.62.0
|
||||
|
||||
## New signatures for spam checker callbacks
|
||||
|
||||
+4
-4
@@ -1,11 +1,11 @@
|
||||
# Reporting Homeserver Usage Statistics
|
||||
# Reporting Anonymised Statistics
|
||||
|
||||
When generating your Synapse configuration file, you are asked whether you
|
||||
would like to report usage statistics to Matrix.org. These statistics
|
||||
would like to report anonymised statistics to Matrix.org. These statistics
|
||||
provide the foundation a glimpse into the number of Synapse homeservers
|
||||
participating in the network, as well as statistics such as the number of
|
||||
rooms being created and messages being sent. This feature is sometimes
|
||||
affectionately called "phone home" stats. Reporting
|
||||
affectionately called "phone-home" stats. Reporting
|
||||
[is optional](../../configuration/config_documentation.md#report_stats)
|
||||
and the reporting endpoint
|
||||
[can be configured](../../configuration/config_documentation.md#report_stats_endpoint),
|
||||
@@ -21,9 +21,9 @@ The following statistics are sent to the configured reporting endpoint:
|
||||
|
||||
| Statistic Name | Type | Description |
|
||||
|----------------------------|--------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| `homeserver` | string | The homeserver's server name. |
|
||||
| `memory_rss` | int | The memory usage of the process (in kilobytes on Unix-based systems, bytes on MacOS). |
|
||||
| `cpu_average` | int | CPU time in % of a single core (not % of all cores). |
|
||||
| `homeserver` | string | The homeserver's server name. |
|
||||
| `server_context` | string | An arbitrary string used to group statistics from a set of homeservers. |
|
||||
| `timestamp` | int | The current time, represented as the number of seconds since the epoch. |
|
||||
| `uptime_seconds` | int | The number of seconds since the homeserver was last started. |
|
||||
@@ -239,8 +239,6 @@ If this option is provided, it parses the given yaml to json and
|
||||
serves it on `/.well-known/matrix/client` endpoint
|
||||
alongside the standard properties.
|
||||
|
||||
*Added in Synapse 1.62.0.*
|
||||
|
||||
Example configuration:
|
||||
```yaml
|
||||
extra_well_known_client_content :
|
||||
@@ -1157,9 +1155,6 @@ Caching can be configured through the following sub-options:
|
||||
with intermittent connections, at the cost of higher memory usage.
|
||||
A value of zero means that sync responses are not cached.
|
||||
Defaults to 2m.
|
||||
|
||||
*Changed in Synapse 1.62.0*: The default was changed from 0 to 2m.
|
||||
|
||||
* `cache_autotuning` and its sub-options `max_cache_memory_usage`, `target_cache_memory_usage`, and
|
||||
`min_cache_ttl` work in conjunction with each other to maintain a balance between cache memory
|
||||
usage and cache entry availability. You must be using [jemalloc](https://github.com/matrix-org/synapse#help-synapse-is-slow-and-eats-all-my-ramcpu)
|
||||
@@ -1476,25 +1471,6 @@ rc_joins:
|
||||
per_second: 0.03
|
||||
burst_count: 12
|
||||
```
|
||||
---
|
||||
### `rc_joins_per_room`
|
||||
|
||||
This option allows admins to ratelimit joins to a room based on the number of recent
|
||||
joins (local or remote) to that room. It is intended to mitigate mass-join spam
|
||||
waves which target multiple homeservers.
|
||||
|
||||
By default, one join is permitted to a room every second, with an accumulating
|
||||
buffer of up to ten instantaneous joins.
|
||||
|
||||
Example configuration (default values):
|
||||
```yaml
|
||||
rc_joins_per_room:
|
||||
per_second: 1
|
||||
burst_count: 10
|
||||
```
|
||||
|
||||
_Added in Synapse 1.64.0._
|
||||
|
||||
---
|
||||
### `rc_3pid_validation`
|
||||
|
||||
@@ -1528,8 +1504,6 @@ The `rc_invites.per_user` limit applies to the *receiver* of the invite, rather
|
||||
sender, meaning that a `rc_invite.per_user.burst_count` of 5 mandates that a single user
|
||||
cannot *receive* more than a burst of 5 invites at a time.
|
||||
|
||||
In contrast, the `rc_invites.per_issuer` limit applies to the *issuer* of the invite, meaning that a `rc_invite.per_issuer.burst_count` of 5 mandates that single user cannot *send* more than a burst of 5 invites at a time.
|
||||
|
||||
Example configuration:
|
||||
```yaml
|
||||
rc_invites:
|
||||
@@ -1539,13 +1513,7 @@ rc_invites:
|
||||
per_user:
|
||||
per_second: 0.004
|
||||
burst_count: 3
|
||||
per_issuer:
|
||||
per_second: 0.5
|
||||
burst_count: 5
|
||||
```
|
||||
|
||||
_Changed in version 1.63:_ added the `per_issuer` limit.
|
||||
|
||||
---
|
||||
### `rc_third_party_invite`
|
||||
|
||||
@@ -2437,14 +2405,9 @@ metrics_flags:
|
||||
---
|
||||
### `report_stats`
|
||||
|
||||
Whether or not to report homeserver usage statistics. This is originally
|
||||
Whether or not to report anonymized homeserver usage statistics. This is originally
|
||||
set when generating the config. Set this option to true or false to change the current
|
||||
behavior. See
|
||||
[Reporting Homeserver Usage Statistics](../administration/monitoring/reporting_homeserver_usage_statistics.md)
|
||||
for information on what data is reported.
|
||||
|
||||
Statistics will be reported 5 minutes after Synapse starts, and then every 3 hours
|
||||
after that.
|
||||
behavior.
|
||||
|
||||
Example configuration:
|
||||
```yaml
|
||||
@@ -2453,7 +2416,7 @@ report_stats: true
|
||||
---
|
||||
### `report_stats_endpoint`
|
||||
|
||||
The endpoint to report homeserver usage statistics to.
|
||||
The endpoint to report the anonymized homeserver usage statistics to.
|
||||
Defaults to https://matrix.org/report-usage-stats/push
|
||||
|
||||
Example configuration:
|
||||
|
||||
Generated
+1
-1
@@ -1563,7 +1563,7 @@ url_preview = ["lxml"]
|
||||
[metadata]
|
||||
lock-version = "1.1"
|
||||
python-versions = "^3.7.1"
|
||||
content-hash = "c24bbcee7e86dbbe7cdbf49f91a25b310bf21095452641e7440129f59b077f78"
|
||||
content-hash = "e96625923122e29b6ea5964379828e321b6cede2b020fc32c6f86c09d86d1ae8"
|
||||
|
||||
[metadata.files]
|
||||
attrs = [
|
||||
|
||||
+2
-4
@@ -54,7 +54,7 @@ skip_gitignore = true
|
||||
|
||||
[tool.poetry]
|
||||
name = "matrix-synapse"
|
||||
version = "1.63.0"
|
||||
version = "1.63.0rc1"
|
||||
description = "Homeserver for the Matrix decentralised comms protocol"
|
||||
authors = ["Matrix.org Team and Contributors <packages@matrix.org>"]
|
||||
license = "Apache-2.0"
|
||||
@@ -110,9 +110,7 @@ jsonschema = ">=3.0.0"
|
||||
frozendict = ">=1,!=2.1.2"
|
||||
# We require 2.1.0 or higher for type hints. Previous guard was >= 1.1.0
|
||||
unpaddedbase64 = ">=2.1.0"
|
||||
# We require 1.5.0 to work around an issue when running against the C implementation of
|
||||
# frozendict: https://github.com/matrix-org/python-canonicaljson/issues/36
|
||||
canonicaljson = "^1.5.0"
|
||||
canonicaljson = "^1.4.0"
|
||||
# we use the type definitions added in signedjson 1.1.
|
||||
signedjson = "^1.1.0"
|
||||
# validating SSL certs for IP addresses requires service_identity 18.1.
|
||||
|
||||
@@ -26,6 +26,7 @@ DISTS = (
|
||||
"debian:bookworm",
|
||||
"debian:sid",
|
||||
"ubuntu:focal", # 20.04 LTS (our EOL forced by Py38 on 2024-10-14)
|
||||
"ubuntu:impish", # 21.10 (EOL 2022-07)
|
||||
"ubuntu:jammy", # 22.04 LTS (EOL 2027-04)
|
||||
)
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ def main() -> None:
|
||||
parser.add_argument(
|
||||
"--report-stats",
|
||||
action="store",
|
||||
help="Whether the generated config reports homeserver usage statistics",
|
||||
help="Whether the generated config reports anonymized usage statistics",
|
||||
choices=["yes", "no"],
|
||||
)
|
||||
|
||||
|
||||
@@ -97,16 +97,16 @@ def format_config_error(e: ConfigError) -> Iterator[str]:
|
||||
# We split these messages out to allow packages to override with package
|
||||
# specific instructions.
|
||||
MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS = """\
|
||||
Please opt in or out of reporting homeserver usage statistics, by setting
|
||||
the `report_stats` key in your config file to either True or False.
|
||||
Please opt in or out of reporting anonymized homeserver usage statistics, by
|
||||
setting the `report_stats` key in your config file to either True or False.
|
||||
"""
|
||||
|
||||
MISSING_REPORT_STATS_SPIEL = """\
|
||||
We would really appreciate it if you could help our project out by reporting
|
||||
homeserver usage statistics from your homeserver. Your homeserver's server name,
|
||||
along with very basic aggregate data (e.g. number of users) will be reported. But
|
||||
it helps us to track the growth of the Matrix community, and helps us to make Matrix
|
||||
a success, as well as to convince other networks that they should peer with us.
|
||||
anonymized usage statistics from your homeserver. Only very basic aggregate
|
||||
data (e.g. number of users) will be reported, but it helps us to track the
|
||||
growth of the Matrix community, and helps us to make Matrix a success, as well
|
||||
as to convince other networks that they should peer with us.
|
||||
|
||||
Thank you.
|
||||
"""
|
||||
@@ -621,7 +621,7 @@ class RootConfig:
|
||||
generate_group.add_argument(
|
||||
"--report-stats",
|
||||
action="store",
|
||||
help="Whether the generated config reports homeserver usage statistics.",
|
||||
help="Whether the generated config reports anonymized usage statistics.",
|
||||
choices=["yes", "no"],
|
||||
)
|
||||
generate_group.add_argument(
|
||||
|
||||
@@ -112,13 +112,6 @@ class RatelimitConfig(Config):
|
||||
defaults={"per_second": 0.01, "burst_count": 10},
|
||||
)
|
||||
|
||||
# Track the rate of joins to a given room. If there are too many, temporarily
|
||||
# prevent local joins and remote joins via this server.
|
||||
self.rc_joins_per_room = RateLimitConfig(
|
||||
config.get("rc_joins_per_room", {}),
|
||||
defaults={"per_second": 1, "burst_count": 10},
|
||||
)
|
||||
|
||||
# Ratelimit cross-user key requests:
|
||||
# * For local requests this is keyed by the sending device.
|
||||
# * For requests received over federation this is keyed by the origin.
|
||||
|
||||
@@ -42,18 +42,6 @@ THUMBNAIL_SIZE_YAML = """\
|
||||
# method: %(method)s
|
||||
"""
|
||||
|
||||
# A map from the given media type to the type of thumbnail we should generate
|
||||
# for it.
|
||||
THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP = {
|
||||
"image/jpeg": "jpeg",
|
||||
"image/jpg": "jpeg",
|
||||
"image/webp": "jpeg",
|
||||
# Thumbnails can only be jpeg or png. We choose png thumbnails for gif
|
||||
# because it can have transparency.
|
||||
"image/gif": "png",
|
||||
"image/png": "png",
|
||||
}
|
||||
|
||||
HTTP_PROXY_SET_WARNING = """\
|
||||
The Synapse config url_preview_ip_range_blacklist will be ignored as an HTTP(s) proxy is configured."""
|
||||
|
||||
@@ -91,22 +79,13 @@ def parse_thumbnail_requirements(
|
||||
width = size["width"]
|
||||
height = size["height"]
|
||||
method = size["method"]
|
||||
|
||||
for format, thumbnail_format in THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP.items():
|
||||
requirement = requirements.setdefault(format, [])
|
||||
if thumbnail_format == "jpeg":
|
||||
requirement.append(
|
||||
ThumbnailRequirement(width, height, method, "image/jpeg")
|
||||
)
|
||||
elif thumbnail_format == "png":
|
||||
requirement.append(
|
||||
ThumbnailRequirement(width, height, method, "image/png")
|
||||
)
|
||||
else:
|
||||
raise Exception(
|
||||
"Unknown thumbnail mapping from %s to %s. This is a Synapse problem, please report!"
|
||||
% (format, thumbnail_format)
|
||||
)
|
||||
jpeg_thumbnail = ThumbnailRequirement(width, height, method, "image/jpeg")
|
||||
png_thumbnail = ThumbnailRequirement(width, height, method, "image/png")
|
||||
requirements.setdefault("image/jpeg", []).append(jpeg_thumbnail)
|
||||
requirements.setdefault("image/jpg", []).append(jpeg_thumbnail)
|
||||
requirements.setdefault("image/webp", []).append(jpeg_thumbnail)
|
||||
requirements.setdefault("image/gif", []).append(png_thumbnail)
|
||||
requirements.setdefault("image/png", []).append(png_thumbnail)
|
||||
return {
|
||||
media_type: tuple(thumbnails) for media_type, thumbnails in requirements.items()
|
||||
}
|
||||
|
||||
@@ -24,11 +24,9 @@ from synapse.api.room_versions import (
|
||||
RoomVersion,
|
||||
)
|
||||
from synapse.crypto.event_signing import add_hashes_and_signatures
|
||||
from synapse.event_auth import auth_types_for_event
|
||||
from synapse.events import EventBase, _EventInternalMetadata, make_event_from_dict
|
||||
from synapse.state import StateHandler
|
||||
from synapse.storage.databases.main import DataStore
|
||||
from synapse.storage.state import StateFilter
|
||||
from synapse.types import EventID, JsonDict
|
||||
from synapse.util import Clock
|
||||
from synapse.util.stringutils import random_string
|
||||
@@ -123,11 +121,7 @@ class EventBuilder:
|
||||
"""
|
||||
if auth_event_ids is None:
|
||||
state_ids = await self._state.compute_state_after_events(
|
||||
self.room_id,
|
||||
prev_event_ids,
|
||||
state_filter=StateFilter.from_types(
|
||||
auth_types_for_event(self.room_version, self)
|
||||
),
|
||||
self.room_id, prev_event_ids
|
||||
)
|
||||
auth_event_ids = self._event_auth_handler.compute_auth_events(
|
||||
self, state_ids
|
||||
|
||||
@@ -217,7 +217,7 @@ class FederationClient(FederationBase):
|
||||
)
|
||||
|
||||
async def claim_client_keys(
|
||||
self, destination: str, content: JsonDict, timeout: Optional[int]
|
||||
self, destination: str, content: JsonDict, timeout: int
|
||||
) -> JsonDict:
|
||||
"""Claims one-time keys for a device hosted on a remote server.
|
||||
|
||||
|
||||
@@ -118,7 +118,6 @@ class FederationServer(FederationBase):
|
||||
self._federation_event_handler = hs.get_federation_event_handler()
|
||||
self.state = hs.get_state_handler()
|
||||
self._event_auth_handler = hs.get_event_auth_handler()
|
||||
self._room_member_handler = hs.get_room_member_handler()
|
||||
|
||||
self._state_storage_controller = hs.get_storage_controllers().state
|
||||
|
||||
@@ -622,15 +621,6 @@ class FederationServer(FederationBase):
|
||||
)
|
||||
raise IncompatibleRoomVersionError(room_version=room_version)
|
||||
|
||||
# Refuse the request if that room has seen too many joins recently.
|
||||
# This is in addition to the HS-level rate limiting applied by
|
||||
# BaseFederationServlet.
|
||||
# type-ignore: mypy doesn't seem able to deduce the type of the limiter(!?)
|
||||
await self._room_member_handler._join_rate_per_room_limiter.ratelimit( # type: ignore[has-type]
|
||||
requester=None,
|
||||
key=room_id,
|
||||
update=False,
|
||||
)
|
||||
pdu = await self.handler.on_make_join_request(origin, room_id, user_id)
|
||||
return {"event": pdu.get_templated_pdu_json(), "room_version": room_version}
|
||||
|
||||
@@ -665,12 +655,6 @@ class FederationServer(FederationBase):
|
||||
room_id: str,
|
||||
caller_supports_partial_state: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
await self._room_member_handler._join_rate_per_room_limiter.ratelimit( # type: ignore[has-type]
|
||||
requester=None,
|
||||
key=room_id,
|
||||
update=False,
|
||||
)
|
||||
|
||||
event, context = await self._on_send_membership_event(
|
||||
origin, content, Membership.JOIN, room_id
|
||||
)
|
||||
|
||||
@@ -619,7 +619,7 @@ class TransportLayerClient:
|
||||
)
|
||||
|
||||
async def claim_client_keys(
|
||||
self, destination: str, query_content: JsonDict, timeout: Optional[int]
|
||||
self, destination: str, query_content: JsonDict, timeout: int
|
||||
) -> JsonDict:
|
||||
"""Claim one-time keys for a list of devices hosted on a remote server.
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
import attr
|
||||
from canonicaljson import encode_canonical_json
|
||||
@@ -92,11 +92,7 @@ class E2eKeysHandler:
|
||||
|
||||
@trace
|
||||
async def query_devices(
|
||||
self,
|
||||
query_body: JsonDict,
|
||||
timeout: int,
|
||||
from_user_id: str,
|
||||
from_device_id: Optional[str],
|
||||
self, query_body: JsonDict, timeout: int, from_user_id: str, from_device_id: str
|
||||
) -> JsonDict:
|
||||
"""Handle a device key query from a client
|
||||
|
||||
@@ -124,7 +120,9 @@ class E2eKeysHandler:
|
||||
the number of in-flight queries at a time.
|
||||
"""
|
||||
async with self._query_devices_linearizer.queue((from_user_id, from_device_id)):
|
||||
device_keys_query: Dict[str, List[str]] = query_body.get("device_keys", {})
|
||||
device_keys_query: Dict[str, Iterable[str]] = query_body.get(
|
||||
"device_keys", {}
|
||||
)
|
||||
|
||||
# separate users by domain.
|
||||
# make a map from domain to user_id to device_ids
|
||||
@@ -394,7 +392,7 @@ class E2eKeysHandler:
|
||||
|
||||
@trace
|
||||
async def query_local_devices(
|
||||
self, query: Mapping[str, Optional[List[str]]]
|
||||
self, query: Dict[str, Optional[List[str]]]
|
||||
) -> Dict[str, Dict[str, dict]]:
|
||||
"""Get E2E device keys for local users
|
||||
|
||||
@@ -463,7 +461,7 @@ class E2eKeysHandler:
|
||||
|
||||
@trace
|
||||
async def claim_one_time_keys(
|
||||
self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: Optional[int]
|
||||
self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: int
|
||||
) -> JsonDict:
|
||||
local_query: List[Tuple[str, str, str]] = []
|
||||
remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {}
|
||||
|
||||
@@ -1037,9 +1037,6 @@ class FederationEventHandler:
|
||||
# XXX: this doesn't sound right? it means that we'll end up with incomplete
|
||||
# state.
|
||||
failed_to_fetch = desired_events - event_metadata.keys()
|
||||
# `event_id` could be missing from `event_metadata` because it's not necessarily
|
||||
# a state event. We've already checked that we've fetched it above.
|
||||
failed_to_fetch.discard(event_id)
|
||||
if failed_to_fetch:
|
||||
logger.warning(
|
||||
"Failed to fetch missing state events for %s %s",
|
||||
@@ -1983,10 +1980,6 @@ class FederationEventHandler:
|
||||
event, event_pos, max_stream_token, extra_users=extra_users
|
||||
)
|
||||
|
||||
if event.type == EventTypes.Member and event.membership == Membership.JOIN:
|
||||
# TODO retrieve the previous state, and exclude join -> join transitions
|
||||
self._notifier.notify_user_joined_room(event.event_id, event.room_id)
|
||||
|
||||
def _sanity_check_event(self, ev: EventBase) -> None:
|
||||
"""
|
||||
Do some early sanity checks of a received event
|
||||
|
||||
@@ -463,7 +463,6 @@ class EventCreationHandler:
|
||||
)
|
||||
self._events_shard_config = self.config.worker.events_shard_config
|
||||
self._instance_name = hs.get_instance_name()
|
||||
self._notifier = hs.get_notifier()
|
||||
|
||||
self.room_prejoin_state_types = self.hs.config.api.room_prejoin_state
|
||||
|
||||
@@ -1551,16 +1550,6 @@ class EventCreationHandler:
|
||||
requester, is_admin_redaction=is_admin_redaction
|
||||
)
|
||||
|
||||
if event.type == EventTypes.Member and event.membership == Membership.JOIN:
|
||||
(
|
||||
current_membership,
|
||||
_,
|
||||
) = await self.store.get_local_current_membership_for_user_in_room(
|
||||
event.state_key, event.room_id
|
||||
)
|
||||
if current_membership != Membership.JOIN:
|
||||
self._notifier.notify_user_joined_room(event.event_id, event.room_id)
|
||||
|
||||
await self._maybe_kick_guest_users(event, context)
|
||||
|
||||
if event.type == EventTypes.CanonicalAlias:
|
||||
@@ -1860,8 +1849,13 @@ class EventCreationHandler:
|
||||
|
||||
# For each room we need to find a joined member we can use to send
|
||||
# the dummy event with.
|
||||
members = await self.store.get_local_users_in_room(room_id)
|
||||
latest_event_ids = await self.store.get_prev_events_for_room(room_id)
|
||||
members = await self.state.get_current_users_in_room(
|
||||
room_id, latest_event_ids=latest_event_ids
|
||||
)
|
||||
for user_id in members:
|
||||
if not self.hs.is_mine_id(user_id):
|
||||
continue
|
||||
requester = create_requester(user_id, authenticated_entity=self.server_name)
|
||||
try:
|
||||
event, context = await self.create_event(
|
||||
@@ -1872,6 +1866,7 @@ class EventCreationHandler:
|
||||
"room_id": room_id,
|
||||
"sender": user_id,
|
||||
},
|
||||
prev_event_ids=latest_event_ids,
|
||||
)
|
||||
|
||||
event.internal_metadata.proactively_send = False
|
||||
|
||||
@@ -94,29 +94,12 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
||||
rate_hz=hs.config.ratelimiting.rc_joins_local.per_second,
|
||||
burst_count=hs.config.ratelimiting.rc_joins_local.burst_count,
|
||||
)
|
||||
# Tracks joins from local users to rooms this server isn't a member of.
|
||||
# I.e. joins this server makes by requesting /make_join /send_join from
|
||||
# another server.
|
||||
self._join_rate_limiter_remote = Ratelimiter(
|
||||
store=self.store,
|
||||
clock=self.clock,
|
||||
rate_hz=hs.config.ratelimiting.rc_joins_remote.per_second,
|
||||
burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count,
|
||||
)
|
||||
# TODO: find a better place to keep this Ratelimiter.
|
||||
# It needs to be
|
||||
# - written to by event persistence code
|
||||
# - written to by something which can snoop on replication streams
|
||||
# - read by the RoomMemberHandler to rate limit joins from local users
|
||||
# - read by the FederationServer to rate limit make_joins and send_joins from
|
||||
# other homeservers
|
||||
# I wonder if a homeserver-wide collection of rate limiters might be cleaner?
|
||||
self._join_rate_per_room_limiter = Ratelimiter(
|
||||
store=self.store,
|
||||
clock=self.clock,
|
||||
rate_hz=hs.config.ratelimiting.rc_joins_per_room.per_second,
|
||||
burst_count=hs.config.ratelimiting.rc_joins_per_room.burst_count,
|
||||
)
|
||||
|
||||
# Ratelimiter for invites, keyed by room (across all issuers, all
|
||||
# recipients).
|
||||
@@ -153,18 +136,6 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
||||
)
|
||||
|
||||
self.request_ratelimiter = hs.get_request_ratelimiter()
|
||||
hs.get_notifier().add_new_join_in_room_callback(self._on_user_joined_room)
|
||||
|
||||
def _on_user_joined_room(self, event_id: str, room_id: str) -> None:
|
||||
"""Notify the rate limiter that a room join has occurred.
|
||||
|
||||
Use this to inform the RoomMemberHandler about joins that have either
|
||||
- taken place on another homeserver, or
|
||||
- on another worker in this homeserver.
|
||||
Joins actioned by this worker should use the usual `ratelimit` method, which
|
||||
checks the limit and increments the counter in one go.
|
||||
"""
|
||||
self._join_rate_per_room_limiter.record_action(requester=None, key=room_id)
|
||||
|
||||
@abc.abstractmethod
|
||||
async def _remote_join(
|
||||
@@ -425,9 +396,6 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
||||
# up blocking profile updates.
|
||||
if newly_joined and ratelimit:
|
||||
await self._join_rate_limiter_local.ratelimit(requester)
|
||||
await self._join_rate_per_room_limiter.ratelimit(
|
||||
requester, key=room_id, update=False
|
||||
)
|
||||
|
||||
result_event = await self.event_creation_handler.handle_new_client_event(
|
||||
requester,
|
||||
@@ -899,11 +867,6 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
||||
await self._join_rate_limiter_remote.ratelimit(
|
||||
requester,
|
||||
)
|
||||
await self._join_rate_per_room_limiter.ratelimit(
|
||||
requester,
|
||||
key=room_id,
|
||||
update=False,
|
||||
)
|
||||
|
||||
inviter = await self._get_inviter(target.to_string(), room_id)
|
||||
if inviter and not self.hs.is_mine(inviter):
|
||||
|
||||
@@ -84,13 +84,14 @@ the function becomes the operation name for the span.
|
||||
return something_usual_and_useful
|
||||
|
||||
|
||||
Operation names can be explicitly set for a function by using ``trace_with_opname``:
|
||||
Operation names can be explicitly set for a function by passing the
|
||||
operation name to ``trace``
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from synapse.logging.opentracing import trace_with_opname
|
||||
from synapse.logging.opentracing import trace
|
||||
|
||||
@trace_with_opname("a_better_operation_name")
|
||||
@trace(opname="a_better_operation_name")
|
||||
def interesting_badly_named_function(*args, **kwargs):
|
||||
# Does all kinds of cool and expected things
|
||||
return something_usual_and_useful
|
||||
@@ -797,31 +798,33 @@ def extract_text_map(carrier: Dict[str, str]) -> Optional["opentracing.SpanConte
|
||||
# Tracing decorators
|
||||
|
||||
|
||||
def trace_with_opname(opname: str) -> Callable[[Callable[P, R]], Callable[P, R]]:
|
||||
def trace(func=None, opname: Optional[str] = None):
|
||||
"""
|
||||
Decorator to trace a function with a custom opname.
|
||||
|
||||
See the module's doc string for usage examples.
|
||||
|
||||
Decorator to trace a function.
|
||||
Sets the operation name to that of the function's or that given
|
||||
as operation_name. See the module's doc string for usage
|
||||
examples.
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[P, R]) -> Callable[P, R]:
|
||||
def decorator(func):
|
||||
if opentracing is None:
|
||||
return func # type: ignore[unreachable]
|
||||
|
||||
_opname = opname if opname else func.__name__
|
||||
|
||||
if inspect.iscoroutinefunction(func):
|
||||
|
||||
@wraps(func)
|
||||
async def _trace_inner(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
with start_active_span(opname):
|
||||
return await func(*args, **kwargs) # type: ignore[misc]
|
||||
async def _trace_inner(*args, **kwargs):
|
||||
with start_active_span(_opname):
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
else:
|
||||
# The other case here handles both sync functions and those
|
||||
# decorated with inlineDeferred.
|
||||
@wraps(func)
|
||||
def _trace_inner(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
scope = start_active_span(opname)
|
||||
def _trace_inner(*args, **kwargs):
|
||||
scope = start_active_span(_opname)
|
||||
scope.__enter__()
|
||||
|
||||
try:
|
||||
@@ -855,21 +858,12 @@ def trace_with_opname(opname: str) -> Callable[[Callable[P, R]], Callable[P, R]]
|
||||
scope.__exit__(type(e), None, e.__traceback__)
|
||||
raise
|
||||
|
||||
return _trace_inner # type: ignore[return-value]
|
||||
return _trace_inner
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def trace(func: Callable[P, R]) -> Callable[P, R]:
|
||||
"""
|
||||
Decorator to trace a function.
|
||||
|
||||
Sets the operation name to that of the function's name.
|
||||
|
||||
See the module's doc string for usage examples.
|
||||
"""
|
||||
|
||||
return trace_with_opname(func.__name__)(func)
|
||||
if func:
|
||||
return decorator(func)
|
||||
else:
|
||||
return decorator
|
||||
|
||||
|
||||
def tag_args(func: Callable[P, R]) -> Callable[P, R]:
|
||||
|
||||
+11
-16
@@ -328,7 +328,7 @@ class PusherPool:
|
||||
return None
|
||||
|
||||
try:
|
||||
pusher = self.pusher_factory.create_pusher(pusher_config)
|
||||
p = self.pusher_factory.create_pusher(pusher_config)
|
||||
except PusherConfigException as e:
|
||||
logger.warning(
|
||||
"Pusher incorrectly configured id=%i, user=%s, appid=%s, pushkey=%s: %s",
|
||||
@@ -346,28 +346,23 @@ class PusherPool:
|
||||
)
|
||||
return None
|
||||
|
||||
if not pusher:
|
||||
if not p:
|
||||
return None
|
||||
|
||||
appid_pushkey = "%s:%s" % (pusher.app_id, pusher.pushkey)
|
||||
appid_pushkey = "%s:%s" % (pusher_config.app_id, pusher_config.pushkey)
|
||||
|
||||
byuser = self.pushers.setdefault(pusher.user_id, {})
|
||||
byuser = self.pushers.setdefault(pusher_config.user_name, {})
|
||||
if appid_pushkey in byuser:
|
||||
previous_pusher = byuser[appid_pushkey]
|
||||
previous_pusher.on_stop()
|
||||
byuser[appid_pushkey].on_stop()
|
||||
byuser[appid_pushkey] = p
|
||||
|
||||
synapse_pushers.labels(
|
||||
type(previous_pusher).__name__, previous_pusher.app_id
|
||||
).dec()
|
||||
byuser[appid_pushkey] = pusher
|
||||
|
||||
synapse_pushers.labels(type(pusher).__name__, pusher.app_id).inc()
|
||||
synapse_pushers.labels(type(p).__name__, p.app_id).inc()
|
||||
|
||||
# Check if there *may* be push to process. We do this as this check is a
|
||||
# lot cheaper to do than actually fetching the exact rows we need to
|
||||
# push.
|
||||
user_id = pusher.user_id
|
||||
last_stream_ordering = pusher.last_stream_ordering
|
||||
user_id = pusher_config.user_name
|
||||
last_stream_ordering = pusher_config.last_stream_ordering
|
||||
if last_stream_ordering:
|
||||
have_notifs = await self.store.get_if_maybe_push_in_range_for_user(
|
||||
user_id, last_stream_ordering
|
||||
@@ -377,9 +372,9 @@ class PusherPool:
|
||||
# risk missing push.
|
||||
have_notifs = True
|
||||
|
||||
pusher.on_started(have_notifs)
|
||||
p.on_started(have_notifs)
|
||||
|
||||
return pusher
|
||||
return p
|
||||
|
||||
async def remove_pusher(self, app_id: str, pushkey: str, user_id: str) -> None:
|
||||
appid_pushkey = "%s:%s" % (app_id, pushkey)
|
||||
|
||||
@@ -29,7 +29,7 @@ from synapse.http import RequestTimedOutError
|
||||
from synapse.http.server import HttpServer, is_method_cancellable
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.logging import opentracing
|
||||
from synapse.logging.opentracing import trace_with_opname
|
||||
from synapse.logging.opentracing import trace
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util.caches.response_cache import ResponseCache
|
||||
from synapse.util.stringutils import random_string
|
||||
@@ -196,7 +196,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
|
||||
"ascii"
|
||||
)
|
||||
|
||||
@trace_with_opname("outgoing_replication_request")
|
||||
@trace(opname="outgoing_replication_request")
|
||||
async def send_request(*, instance_name: str = "master", **kwargs: Any) -> Any:
|
||||
with outgoing_gauge.track_inprogress():
|
||||
if instance_name == local_instance_name:
|
||||
|
||||
@@ -21,7 +21,7 @@ from twisted.internet.interfaces import IAddress, IConnector
|
||||
from twisted.internet.protocol import ReconnectingClientFactory
|
||||
from twisted.python.failure import Failure
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership, ReceiptTypes
|
||||
from synapse.api.constants import EventTypes, ReceiptTypes
|
||||
from synapse.federation import send_queue
|
||||
from synapse.federation.sender import FederationSender
|
||||
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
|
||||
@@ -219,21 +219,6 @@ class ReplicationDataHandler:
|
||||
membership=row.data.membership,
|
||||
)
|
||||
|
||||
# If this event is a join, make a note of it so we have an accurate
|
||||
# cross-worker room rate limit.
|
||||
# TODO: Erik said we should exclude rows that came from ex_outliers
|
||||
# here, but I don't see how we can determine that. I guess we could
|
||||
# add a flag to row.data?
|
||||
if (
|
||||
row.data.type == EventTypes.Member
|
||||
and row.data.membership == Membership.JOIN
|
||||
and not row.data.outlier
|
||||
):
|
||||
# TODO retrieve the previous state, and exclude join -> join transitions
|
||||
self.notifier.notify_user_joined_room(
|
||||
row.data.event_id, row.data.room_id
|
||||
)
|
||||
|
||||
await self._presence_handler.process_replication_rows(
|
||||
stream_name, instance_name, token, rows
|
||||
)
|
||||
|
||||
@@ -98,7 +98,6 @@ class EventsStreamEventRow(BaseEventsStreamRow):
|
||||
relates_to: Optional[str]
|
||||
membership: Optional[str]
|
||||
rejected: bool
|
||||
outlier: bool
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||
|
||||
@@ -26,7 +26,7 @@ from synapse.http.servlet import (
|
||||
parse_string,
|
||||
)
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.logging.opentracing import log_kv, set_tag, trace_with_opname
|
||||
from synapse.logging.opentracing import log_kv, set_tag, trace
|
||||
from synapse.types import JsonDict, StreamToken
|
||||
|
||||
from ._base import client_patterns, interactive_auth_handler
|
||||
@@ -71,7 +71,7 @@ class KeyUploadServlet(RestServlet):
|
||||
self.e2e_keys_handler = hs.get_e2e_keys_handler()
|
||||
self.device_handler = hs.get_device_handler()
|
||||
|
||||
@trace_with_opname("upload_keys")
|
||||
@trace(opname="upload_keys")
|
||||
async def on_POST(
|
||||
self, request: SynapseRequest, device_id: Optional[str]
|
||||
) -> Tuple[int, JsonDict]:
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Optional, Tuple, cast
|
||||
from typing import TYPE_CHECKING, Optional, Tuple
|
||||
|
||||
from synapse.api.errors import Codes, NotFoundError, SynapseError
|
||||
from synapse.http.server import HttpServer
|
||||
@@ -127,7 +127,7 @@ class RoomKeysServlet(RestServlet):
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=False)
|
||||
user_id = requester.user.to_string()
|
||||
body = parse_json_object_from_request(request)
|
||||
version = parse_string(request, "version", required=True)
|
||||
version = parse_string(request, "version")
|
||||
|
||||
if session_id:
|
||||
body = {"sessions": {session_id: body}}
|
||||
@@ -196,11 +196,8 @@ class RoomKeysServlet(RestServlet):
|
||||
user_id = requester.user.to_string()
|
||||
version = parse_string(request, "version", required=True)
|
||||
|
||||
room_keys = cast(
|
||||
JsonDict,
|
||||
await self.e2e_room_keys_handler.get_room_keys(
|
||||
user_id, version, room_id, session_id
|
||||
),
|
||||
room_keys = await self.e2e_room_keys_handler.get_room_keys(
|
||||
user_id, version, room_id, session_id
|
||||
)
|
||||
|
||||
# Convert room_keys to the right format to return.
|
||||
@@ -243,7 +240,7 @@ class RoomKeysServlet(RestServlet):
|
||||
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=False)
|
||||
user_id = requester.user.to_string()
|
||||
version = parse_string(request, "version", required=True)
|
||||
version = parse_string(request, "version")
|
||||
|
||||
ret = await self.e2e_room_keys_handler.delete_room_keys(
|
||||
user_id, version, room_id, session_id
|
||||
|
||||
@@ -19,7 +19,7 @@ from synapse.http import servlet
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.logging.opentracing import set_tag, trace_with_opname
|
||||
from synapse.logging.opentracing import set_tag, trace
|
||||
from synapse.rest.client.transactions import HttpTransactionCache
|
||||
from synapse.types import JsonDict
|
||||
|
||||
@@ -43,7 +43,7 @@ class SendToDeviceRestServlet(servlet.RestServlet):
|
||||
self.txns = HttpTransactionCache(hs)
|
||||
self.device_message_handler = hs.get_device_message_handler()
|
||||
|
||||
@trace_with_opname("sendToDevice")
|
||||
@trace(opname="sendToDevice")
|
||||
def on_PUT(
|
||||
self, request: SynapseRequest, message_type: str, txn_id: str
|
||||
) -> Awaitable[Tuple[int, JsonDict]]:
|
||||
|
||||
@@ -37,7 +37,7 @@ from synapse.handlers.sync import (
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.logging.opentracing import trace_with_opname
|
||||
from synapse.logging.opentracing import trace
|
||||
from synapse.types import JsonDict, StreamToken
|
||||
from synapse.util import json_decoder
|
||||
|
||||
@@ -210,7 +210,7 @@ class SyncRestServlet(RestServlet):
|
||||
logger.debug("Event formatting complete")
|
||||
return 200, response_content
|
||||
|
||||
@trace_with_opname("sync.encode_response")
|
||||
@trace(opname="sync.encode_response")
|
||||
async def encode_response(
|
||||
self,
|
||||
time_now: int,
|
||||
@@ -315,7 +315,7 @@ class SyncRestServlet(RestServlet):
|
||||
]
|
||||
}
|
||||
|
||||
@trace_with_opname("sync.encode_joined")
|
||||
@trace(opname="sync.encode_joined")
|
||||
async def encode_joined(
|
||||
self,
|
||||
rooms: List[JoinedSyncResult],
|
||||
@@ -340,7 +340,7 @@ class SyncRestServlet(RestServlet):
|
||||
|
||||
return joined
|
||||
|
||||
@trace_with_opname("sync.encode_invited")
|
||||
@trace(opname="sync.encode_invited")
|
||||
async def encode_invited(
|
||||
self,
|
||||
rooms: List[InvitedSyncResult],
|
||||
@@ -371,7 +371,7 @@ class SyncRestServlet(RestServlet):
|
||||
|
||||
return invited
|
||||
|
||||
@trace_with_opname("sync.encode_knocked")
|
||||
@trace(opname="sync.encode_knocked")
|
||||
async def encode_knocked(
|
||||
self,
|
||||
rooms: List[KnockedSyncResult],
|
||||
@@ -420,7 +420,7 @@ class SyncRestServlet(RestServlet):
|
||||
|
||||
return knocked
|
||||
|
||||
@trace_with_opname("sync.encode_archived")
|
||||
@trace(opname="sync.encode_archived")
|
||||
async def encode_archived(
|
||||
self,
|
||||
rooms: List[ArchivedSyncResult],
|
||||
|
||||
@@ -17,11 +17,9 @@
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||
|
||||
from synapse.api.errors import Codes, SynapseError, cs_error
|
||||
from synapse.config.repository import THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.http.server import (
|
||||
DirectServeJsonResource,
|
||||
respond_with_json,
|
||||
set_corp_headers,
|
||||
set_cors_headers,
|
||||
)
|
||||
@@ -311,19 +309,6 @@ class ThumbnailResource(DirectServeJsonResource):
|
||||
url_cache: True if this is from a URL cache.
|
||||
server_name: The server name, if this is a remote thumbnail.
|
||||
"""
|
||||
logger.debug(
|
||||
"_select_and_respond_with_thumbnail: media_id=%s desired=%sx%s (%s) thumbnail_infos=%s",
|
||||
media_id,
|
||||
desired_width,
|
||||
desired_height,
|
||||
desired_method,
|
||||
thumbnail_infos,
|
||||
)
|
||||
|
||||
# If `dynamic_thumbnails` is enabled, we expect Synapse to go down a
|
||||
# different code path to handle it.
|
||||
assert not self.dynamic_thumbnails
|
||||
|
||||
if thumbnail_infos:
|
||||
file_info = self._select_thumbnail(
|
||||
desired_width,
|
||||
@@ -399,29 +384,8 @@ class ThumbnailResource(DirectServeJsonResource):
|
||||
file_info.thumbnail.length,
|
||||
)
|
||||
else:
|
||||
# This might be because:
|
||||
# 1. We can't create thumbnails for the given media (corrupted or
|
||||
# unsupported file type), or
|
||||
# 2. The thumbnailing process never ran or errored out initially
|
||||
# when the media was first uploaded (these bugs should be
|
||||
# reported and fixed).
|
||||
# Note that we don't attempt to generate a thumbnail now because
|
||||
# `dynamic_thumbnails` is disabled.
|
||||
logger.info("Failed to find any generated thumbnails")
|
||||
|
||||
respond_with_json(
|
||||
request,
|
||||
400,
|
||||
cs_error(
|
||||
"Cannot find any thumbnails for the requested media (%r). This might mean the media is not a supported_media_format=(%s) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)"
|
||||
% (
|
||||
request.postpath,
|
||||
", ".join(THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP.keys()),
|
||||
),
|
||||
code=Codes.UNKNOWN,
|
||||
),
|
||||
send_cors=True,
|
||||
)
|
||||
respond_404(request)
|
||||
|
||||
def _select_thumbnail(
|
||||
self,
|
||||
|
||||
@@ -157,7 +157,6 @@ class StateHandler:
|
||||
self,
|
||||
room_id: str,
|
||||
event_ids: Collection[str],
|
||||
state_filter: Optional[StateFilter] = None,
|
||||
) -> StateMap[str]:
|
||||
"""Fetch the state after each of the given event IDs. Resolve them and return.
|
||||
|
||||
@@ -175,7 +174,7 @@ class StateHandler:
|
||||
"""
|
||||
logger.debug("calling resolve_state_groups from compute_state_after_events")
|
||||
ret = await self.resolve_state_groups_for_events(room_id, event_ids)
|
||||
return await ret.get_state(self._state_storage_controller, state_filter)
|
||||
return await ret.get_state(self._state_storage_controller, StateFilter.all())
|
||||
|
||||
async def get_current_users_in_room(
|
||||
self, room_id: str, latest_event_ids: List[str]
|
||||
|
||||
@@ -96,10 +96,6 @@ class SQLBaseStore(metaclass=ABCMeta):
|
||||
cache doesn't exist. Mainly used for invalidating caches on workers,
|
||||
where they may not have the cache.
|
||||
|
||||
Note that this function does not invalidate any remote caches, only the
|
||||
local in-memory ones. Any remote invalidation must be performed before
|
||||
calling this.
|
||||
|
||||
Args:
|
||||
cache_name
|
||||
key: Entry to invalidate. If None then invalidates the entire
|
||||
@@ -116,10 +112,7 @@ class SQLBaseStore(metaclass=ABCMeta):
|
||||
if key is None:
|
||||
cache.invalidate_all()
|
||||
else:
|
||||
# Prefer any local-only invalidation method. Invalidating any non-local
|
||||
# cache must be be done before this.
|
||||
invalidate_method = getattr(cache, "invalidate_local", cache.invalidate)
|
||||
invalidate_method(tuple(key))
|
||||
cache.invalidate(tuple(key))
|
||||
|
||||
|
||||
def db_to_json(db_content: Union[memoryview, bytes, bytearray, str]) -> Any:
|
||||
|
||||
@@ -23,7 +23,6 @@ from time import monotonic as monotonic_time
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Collection,
|
||||
Dict,
|
||||
@@ -58,7 +57,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.storage.background_updates import BackgroundUpdater
|
||||
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
|
||||
from synapse.storage.types import Connection, Cursor
|
||||
from synapse.util.async_helpers import delay_cancellation
|
||||
from synapse.util.async_helpers import delay_cancellation, maybe_awaitable
|
||||
from synapse.util.iterutils import batch_iter
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -169,7 +168,6 @@ class LoggingDatabaseConnection:
|
||||
*,
|
||||
txn_name: Optional[str] = None,
|
||||
after_callbacks: Optional[List["_CallbackListEntry"]] = None,
|
||||
async_after_callbacks: Optional[List["_AsyncCallbackListEntry"]] = None,
|
||||
exception_callbacks: Optional[List["_CallbackListEntry"]] = None,
|
||||
) -> "LoggingTransaction":
|
||||
if not txn_name:
|
||||
@@ -180,7 +178,6 @@ class LoggingDatabaseConnection:
|
||||
name=txn_name,
|
||||
database_engine=self.engine,
|
||||
after_callbacks=after_callbacks,
|
||||
async_after_callbacks=async_after_callbacks,
|
||||
exception_callbacks=exception_callbacks,
|
||||
)
|
||||
|
||||
@@ -212,9 +209,6 @@ class LoggingDatabaseConnection:
|
||||
|
||||
# The type of entry which goes on our after_callbacks and exception_callbacks lists.
|
||||
_CallbackListEntry = Tuple[Callable[..., object], Tuple[object, ...], Dict[str, object]]
|
||||
_AsyncCallbackListEntry = Tuple[
|
||||
Callable[..., Awaitable], Tuple[object, ...], Dict[str, object]
|
||||
]
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
@@ -233,10 +227,6 @@ class LoggingTransaction:
|
||||
that have been added by `call_after` which should be run on
|
||||
successful completion of the transaction. None indicates that no
|
||||
callbacks should be allowed to be scheduled to run.
|
||||
async_after_callbacks: A list that asynchronous callbacks will be appended
|
||||
to by `async_call_after` which should run, before after_callbacks, on
|
||||
successful completion of the transaction. None indicates that no
|
||||
callbacks should be allowed to be scheduled to run.
|
||||
exception_callbacks: A list that callbacks will be appended
|
||||
to that have been added by `call_on_exception` which should be run
|
||||
if transaction ends with an error. None indicates that no callbacks
|
||||
@@ -248,7 +238,6 @@ class LoggingTransaction:
|
||||
"name",
|
||||
"database_engine",
|
||||
"after_callbacks",
|
||||
"async_after_callbacks",
|
||||
"exception_callbacks",
|
||||
]
|
||||
|
||||
@@ -258,14 +247,12 @@ class LoggingTransaction:
|
||||
name: str,
|
||||
database_engine: BaseDatabaseEngine,
|
||||
after_callbacks: Optional[List[_CallbackListEntry]] = None,
|
||||
async_after_callbacks: Optional[List[_AsyncCallbackListEntry]] = None,
|
||||
exception_callbacks: Optional[List[_CallbackListEntry]] = None,
|
||||
):
|
||||
self.txn = txn
|
||||
self.name = name
|
||||
self.database_engine = database_engine
|
||||
self.after_callbacks = after_callbacks
|
||||
self.async_after_callbacks = async_after_callbacks
|
||||
self.exception_callbacks = exception_callbacks
|
||||
|
||||
def call_after(
|
||||
@@ -290,28 +277,6 @@ class LoggingTransaction:
|
||||
# type-ignore: need mypy containing https://github.com/python/mypy/pull/12668
|
||||
self.after_callbacks.append((callback, args, kwargs)) # type: ignore[arg-type]
|
||||
|
||||
def async_call_after(
|
||||
self, callback: Callable[P, Awaitable], *args: P.args, **kwargs: P.kwargs
|
||||
) -> None:
|
||||
"""Call the given asynchronous callback on the main twisted thread after
|
||||
the transaction has finished (but before those added in `call_after`).
|
||||
|
||||
Mostly used to invalidate remote caches after transactions.
|
||||
|
||||
Note that transactions may be retried a few times if they encounter database
|
||||
errors such as serialization failures. Callbacks given to `async_call_after`
|
||||
will accumulate across transaction attempts and will _all_ be called once a
|
||||
transaction attempt succeeds, regardless of whether previous transaction
|
||||
attempts failed. Otherwise, if all transaction attempts fail, all
|
||||
`call_on_exception` callbacks will be run instead.
|
||||
"""
|
||||
# if self.async_after_callbacks is None, that means that whatever constructed the
|
||||
# LoggingTransaction isn't expecting there to be any callbacks; assert that
|
||||
# is not the case.
|
||||
assert self.async_after_callbacks is not None
|
||||
# type-ignore: need mypy containing https://github.com/python/mypy/pull/12668
|
||||
self.async_after_callbacks.append((callback, args, kwargs)) # type: ignore[arg-type]
|
||||
|
||||
def call_on_exception(
|
||||
self, callback: Callable[P, object], *args: P.args, **kwargs: P.kwargs
|
||||
) -> None:
|
||||
@@ -609,7 +574,6 @@ class DatabasePool:
|
||||
conn: LoggingDatabaseConnection,
|
||||
desc: str,
|
||||
after_callbacks: List[_CallbackListEntry],
|
||||
async_after_callbacks: List[_AsyncCallbackListEntry],
|
||||
exception_callbacks: List[_CallbackListEntry],
|
||||
func: Callable[Concatenate[LoggingTransaction, P], R],
|
||||
*args: P.args,
|
||||
@@ -633,7 +597,6 @@ class DatabasePool:
|
||||
conn
|
||||
desc
|
||||
after_callbacks
|
||||
async_after_callbacks
|
||||
exception_callbacks
|
||||
func
|
||||
*args
|
||||
@@ -696,7 +659,6 @@ class DatabasePool:
|
||||
cursor = conn.cursor(
|
||||
txn_name=name,
|
||||
after_callbacks=after_callbacks,
|
||||
async_after_callbacks=async_after_callbacks,
|
||||
exception_callbacks=exception_callbacks,
|
||||
)
|
||||
try:
|
||||
@@ -836,7 +798,6 @@ class DatabasePool:
|
||||
|
||||
async def _runInteraction() -> R:
|
||||
after_callbacks: List[_CallbackListEntry] = []
|
||||
async_after_callbacks: List[_AsyncCallbackListEntry] = []
|
||||
exception_callbacks: List[_CallbackListEntry] = []
|
||||
|
||||
if not current_context():
|
||||
@@ -848,7 +809,6 @@ class DatabasePool:
|
||||
self.new_transaction,
|
||||
desc,
|
||||
after_callbacks,
|
||||
async_after_callbacks,
|
||||
exception_callbacks,
|
||||
func,
|
||||
*args,
|
||||
@@ -857,17 +817,15 @@ class DatabasePool:
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# We order these assuming that async functions call out to external
|
||||
# systems (e.g. to invalidate a cache) and the sync functions make these
|
||||
# changes on any local in-memory caches/similar, and thus must be second.
|
||||
for async_callback, async_args, async_kwargs in async_after_callbacks:
|
||||
await async_callback(*async_args, **async_kwargs)
|
||||
for after_callback, after_args, after_kwargs in after_callbacks:
|
||||
after_callback(*after_args, **after_kwargs)
|
||||
await maybe_awaitable(after_callback(*after_args, **after_kwargs))
|
||||
|
||||
return cast(R, result)
|
||||
except Exception:
|
||||
for exception_callback, after_args, after_kwargs in exception_callbacks:
|
||||
exception_callback(*after_args, **after_kwargs)
|
||||
await maybe_awaitable(
|
||||
exception_callback(*after_args, **after_kwargs)
|
||||
)
|
||||
raise
|
||||
|
||||
# To handle cancellation, we ensure that `after_callback`s and
|
||||
|
||||
@@ -194,7 +194,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
|
||||
# changed its content in the database. We can't call
|
||||
# self._invalidate_cache_and_stream because self.get_event_cache isn't of the
|
||||
# right type.
|
||||
self.invalidate_get_event_cache_after_txn(txn, event.event_id)
|
||||
txn.call_after(self._get_event_cache.invalidate, (event.event_id,))
|
||||
# Send that invalidation to replication so that other workers also invalidate
|
||||
# the event cache.
|
||||
self._send_invalidation_to_replication(
|
||||
|
||||
@@ -669,7 +669,7 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore):
|
||||
|
||||
@trace
|
||||
async def get_user_devices_from_cache(
|
||||
self, query_list: List[Tuple[str, Optional[str]]]
|
||||
self, query_list: List[Tuple[str, str]]
|
||||
) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]:
|
||||
"""Get the devices (and keys if any) for remote users from the cache.
|
||||
|
||||
|
||||
@@ -22,14 +22,11 @@ from typing import (
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
cast,
|
||||
overload,
|
||||
)
|
||||
|
||||
import attr
|
||||
from canonicaljson import encode_canonical_json
|
||||
from typing_extensions import Literal
|
||||
|
||||
from synapse.api.constants import DeviceKeyAlgorithms
|
||||
from synapse.appservice import (
|
||||
@@ -116,7 +113,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
||||
user_devices = devices[user_id]
|
||||
results = []
|
||||
for device_id, device in user_devices.items():
|
||||
result: JsonDict = {"device_id": device_id}
|
||||
result = {"device_id": device_id}
|
||||
|
||||
keys = device.keys
|
||||
if keys:
|
||||
@@ -159,9 +156,6 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
||||
rv[user_id] = {}
|
||||
for device_id, device_info in device_keys.items():
|
||||
r = device_info.keys
|
||||
if r is None:
|
||||
continue
|
||||
|
||||
r["unsigned"] = {}
|
||||
display_name = device_info.display_name
|
||||
if display_name is not None:
|
||||
@@ -170,42 +164,13 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
||||
|
||||
return rv
|
||||
|
||||
@overload
|
||||
async def get_e2e_device_keys_and_signatures(
|
||||
self,
|
||||
query_list: Collection[Tuple[str, Optional[str]]],
|
||||
include_all_devices: Literal[False] = False,
|
||||
) -> Dict[str, Dict[str, DeviceKeyLookupResult]]:
|
||||
...
|
||||
|
||||
@overload
|
||||
async def get_e2e_device_keys_and_signatures(
|
||||
self,
|
||||
query_list: Collection[Tuple[str, Optional[str]]],
|
||||
include_all_devices: bool = False,
|
||||
include_deleted_devices: Literal[False] = False,
|
||||
) -> Dict[str, Dict[str, DeviceKeyLookupResult]]:
|
||||
...
|
||||
|
||||
@overload
|
||||
async def get_e2e_device_keys_and_signatures(
|
||||
self,
|
||||
query_list: Collection[Tuple[str, Optional[str]]],
|
||||
include_all_devices: Literal[True],
|
||||
include_deleted_devices: Literal[True],
|
||||
) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
|
||||
...
|
||||
|
||||
@trace
|
||||
async def get_e2e_device_keys_and_signatures(
|
||||
self,
|
||||
query_list: Collection[Tuple[str, Optional[str]]],
|
||||
query_list: List[Tuple[str, Optional[str]]],
|
||||
include_all_devices: bool = False,
|
||||
include_deleted_devices: bool = False,
|
||||
) -> Union[
|
||||
Dict[str, Dict[str, DeviceKeyLookupResult]],
|
||||
Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]],
|
||||
]:
|
||||
) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
|
||||
"""Fetch a list of device keys
|
||||
|
||||
Any cross-signatures made on the keys by the owner of the device are also
|
||||
@@ -1079,7 +1044,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
||||
_claim_e2e_one_time_key = _claim_e2e_one_time_key_simple
|
||||
db_autocommit = False
|
||||
|
||||
claim_row = await self.db_pool.runInteraction(
|
||||
row = await self.db_pool.runInteraction(
|
||||
"claim_e2e_one_time_keys",
|
||||
_claim_e2e_one_time_key,
|
||||
user_id,
|
||||
@@ -1087,11 +1052,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
||||
algorithm,
|
||||
db_autocommit=db_autocommit,
|
||||
)
|
||||
if claim_row:
|
||||
if row:
|
||||
device_results = results.setdefault(user_id, {}).setdefault(
|
||||
device_id, {}
|
||||
)
|
||||
device_results[claim_row[0]] = claim_row[1]
|
||||
device_results[row[0]] = row[1]
|
||||
continue
|
||||
|
||||
# No one-time key available, so see if there's a fallback
|
||||
|
||||
@@ -1293,7 +1293,7 @@ class PersistEventsStore:
|
||||
depth_updates: Dict[str, int] = {}
|
||||
for event, context in events_and_contexts:
|
||||
# Remove the any existing cache entries for the event_ids
|
||||
self.store.invalidate_get_event_cache_after_txn(txn, event.event_id)
|
||||
txn.call_after(self.store._invalidate_get_event_cache, event.event_id)
|
||||
# Then update the `stream_ordering` position to mark the latest
|
||||
# event as the front of the room. This should not be done for
|
||||
# backfilled events because backfilled events have negative
|
||||
@@ -1675,7 +1675,7 @@ class PersistEventsStore:
|
||||
(cache_entry.event.event_id,), cache_entry
|
||||
)
|
||||
|
||||
txn.async_call_after(prefill)
|
||||
txn.call_after(prefill)
|
||||
|
||||
def _store_redaction(self, txn: LoggingTransaction, event: EventBase) -> None:
|
||||
"""Invalidate the caches for the redacted event.
|
||||
@@ -1684,7 +1684,7 @@ class PersistEventsStore:
|
||||
_invalidate_caches_for_event.
|
||||
"""
|
||||
assert event.redacts is not None
|
||||
self.store.invalidate_get_event_cache_after_txn(txn, event.redacts)
|
||||
txn.call_after(self.store._invalidate_get_event_cache, event.redacts)
|
||||
txn.call_after(self.store.get_relations_for_event.invalidate, (event.redacts,))
|
||||
txn.call_after(self.store.get_applicable_edit.invalidate, (event.redacts,))
|
||||
|
||||
|
||||
@@ -712,41 +712,17 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
|
||||
return event_entry_map
|
||||
|
||||
def invalidate_get_event_cache_after_txn(
|
||||
self, txn: LoggingTransaction, event_id: str
|
||||
) -> None:
|
||||
"""
|
||||
Prepares a database transaction to invalidate the get event cache for a given
|
||||
event ID when executed successfully. This is achieved by attaching two callbacks
|
||||
to the transaction, one to invalidate the async cache and one for the in memory
|
||||
sync cache (importantly called in that order).
|
||||
|
||||
Arguments:
|
||||
txn: the database transaction to attach the callbacks to
|
||||
event_id: the event ID to be invalidated from caches
|
||||
"""
|
||||
|
||||
txn.async_call_after(self._invalidate_async_get_event_cache, event_id)
|
||||
txn.call_after(self._invalidate_local_get_event_cache, event_id)
|
||||
|
||||
async def _invalidate_async_get_event_cache(self, event_id: str) -> None:
|
||||
"""
|
||||
Invalidates an event in the asyncronous get event cache, which may be remote.
|
||||
|
||||
Arguments:
|
||||
event_id: the event ID to invalidate
|
||||
"""
|
||||
|
||||
async def _invalidate_get_event_cache(self, event_id: str) -> None:
|
||||
# First we invalidate the asynchronous cache instance. This may include
|
||||
# out-of-process caches such as Redis/memcache. Once complete we can
|
||||
# invalidate any in memory cache. The ordering is important here to
|
||||
# ensure we don't pull in any remote invalid value after we invalidate
|
||||
# the in-memory cache.
|
||||
await self._get_event_cache.invalidate((event_id,))
|
||||
self._event_ref.pop(event_id, None)
|
||||
self._current_event_fetches.pop(event_id, None)
|
||||
|
||||
def _invalidate_local_get_event_cache(self, event_id: str) -> None:
|
||||
"""
|
||||
Invalidates an event in local in-memory get event caches.
|
||||
|
||||
Arguments:
|
||||
event_id: the event ID to invalidate
|
||||
"""
|
||||
|
||||
self._get_event_cache.invalidate_local((event_id,))
|
||||
self._event_ref.pop(event_id, None)
|
||||
self._current_event_fetches.pop(event_id, None)
|
||||
@@ -982,13 +958,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
}
|
||||
|
||||
row_dict = self.db_pool.new_transaction(
|
||||
conn,
|
||||
"do_fetch",
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
self._fetch_event_rows,
|
||||
events_to_fetch,
|
||||
conn, "do_fetch", [], [], self._fetch_event_rows, events_to_fetch
|
||||
)
|
||||
|
||||
# We only want to resolve deferreds from the main thread
|
||||
@@ -1490,7 +1460,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
|
||||
async def get_all_new_forward_event_rows(
|
||||
self, instance_name: str, last_id: int, current_id: int, limit: int
|
||||
) -> List[Tuple[int, str, str, str, str, str, str, str, bool, bool]]:
|
||||
) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
|
||||
"""Returns new events, for the Events replication stream
|
||||
|
||||
Args:
|
||||
@@ -1506,11 +1476,10 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
|
||||
def get_all_new_forward_event_rows(
|
||||
txn: LoggingTransaction,
|
||||
) -> List[Tuple[int, str, str, str, str, str, str, str, bool, bool]]:
|
||||
) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
|
||||
sql = (
|
||||
"SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
|
||||
" se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL,"
|
||||
" e.outlier"
|
||||
" se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
|
||||
" FROM events AS e"
|
||||
" LEFT JOIN redactions USING (event_id)"
|
||||
" LEFT JOIN state_events AS se USING (event_id)"
|
||||
@@ -1524,8 +1493,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
)
|
||||
txn.execute(sql, (last_id, current_id, instance_name, limit))
|
||||
return cast(
|
||||
List[Tuple[int, str, str, str, str, str, str, str, bool, bool]],
|
||||
txn.fetchall(),
|
||||
List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall()
|
||||
)
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
@@ -1534,7 +1502,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
|
||||
async def get_ex_outlier_stream_rows(
|
||||
self, instance_name: str, last_id: int, current_id: int
|
||||
) -> List[Tuple[int, str, str, str, str, str, str, str, bool, bool]]:
|
||||
) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
|
||||
"""Returns de-outliered events, for the Events replication stream
|
||||
|
||||
Args:
|
||||
@@ -1549,14 +1517,11 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
|
||||
def get_ex_outlier_stream_rows_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> List[Tuple[int, str, str, str, str, str, str, str, bool, bool]]:
|
||||
) -> List[Tuple[int, str, str, str, str, str, str, str, str]]:
|
||||
sql = (
|
||||
"SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
|
||||
" se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL,"
|
||||
" e.outlier"
|
||||
" se.state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
|
||||
" FROM events AS e"
|
||||
# NB: the next line (inner join) is what makes this query different from
|
||||
# get_all_new_forward_event_rows.
|
||||
" INNER JOIN ex_outlier_stream AS out USING (event_id)"
|
||||
" LEFT JOIN redactions USING (event_id)"
|
||||
" LEFT JOIN state_events AS se USING (event_id)"
|
||||
@@ -1571,8 +1536,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
|
||||
txn.execute(sql, (last_id, current_id, instance_name))
|
||||
return cast(
|
||||
List[Tuple[int, str, str, str, str, str, str, str, bool, bool]],
|
||||
txn.fetchall(),
|
||||
List[Tuple[int, str, str, str, str, str, str, str, str]], txn.fetchall()
|
||||
)
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
|
||||
@@ -66,7 +66,6 @@ class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore):
|
||||
"initialise_mau_threepids",
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
self._initialise_reserved_users,
|
||||
hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value],
|
||||
)
|
||||
|
||||
@@ -19,8 +19,6 @@ from synapse.api.errors import SynapseError
|
||||
from synapse.storage.database import LoggingTransaction
|
||||
from synapse.storage.databases.main import CacheInvalidationWorkerStore
|
||||
from synapse.storage.databases.main.state import StateGroupWorkerStore
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
from synapse.storage.engines._base import IsolationLevel
|
||||
from synapse.types import RoomStreamToken
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -304,7 +302,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.have_seen_event, (room_id, event_id)
|
||||
)
|
||||
self.invalidate_get_event_cache_after_txn(txn, event_id)
|
||||
txn.call_after(self._invalidate_get_event_cache, event_id)
|
||||
|
||||
logger.info("[purge] done")
|
||||
|
||||
@@ -319,38 +317,11 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
|
||||
Returns:
|
||||
The list of state groups to delete.
|
||||
"""
|
||||
|
||||
# This first runs the purge transaction with READ_COMMITTED isolation level,
|
||||
# meaning any new rows in the tables will not trigger a serialization error.
|
||||
# We then run the same purge a second time without this isolation level to
|
||||
# purge any of those rows which were added during the first.
|
||||
|
||||
state_groups_to_delete = await self.db_pool.runInteraction(
|
||||
"purge_room",
|
||||
self._purge_room_txn,
|
||||
room_id=room_id,
|
||||
isolation_level=IsolationLevel.READ_COMMITTED,
|
||||
return await self.db_pool.runInteraction(
|
||||
"purge_room", self._purge_room_txn, room_id
|
||||
)
|
||||
|
||||
state_groups_to_delete.extend(
|
||||
await self.db_pool.runInteraction(
|
||||
"purge_room",
|
||||
self._purge_room_txn,
|
||||
room_id=room_id,
|
||||
),
|
||||
)
|
||||
|
||||
return state_groups_to_delete
|
||||
|
||||
def _purge_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[int]:
|
||||
# This collides with event persistence so we cannot write new events and metadata into
|
||||
# a room while deleting it or this transaction will fail.
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
txn.execute(
|
||||
"SELECT room_version FROM rooms WHERE room_id = ? FOR UPDATE",
|
||||
(room_id,),
|
||||
)
|
||||
|
||||
# First, fetch all the state groups that should be deleted, before
|
||||
# we delete that information.
|
||||
txn.execute(
|
||||
|
||||
@@ -228,7 +228,6 @@ class PushRulesWorkerStore(
|
||||
iterable=user_ids,
|
||||
retcols=("*",),
|
||||
desc="bulk_get_push_rules",
|
||||
batch_size=1000,
|
||||
)
|
||||
|
||||
rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
|
||||
|
||||
@@ -243,7 +243,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
txn: LoggingTransaction,
|
||||
) -> Dict[str, ProfileInfo]:
|
||||
clause, ids = make_in_list_sql_clause(
|
||||
self.database_engine, "c.state_key", user_ids
|
||||
self.database_engine, "m.user_id", user_ids
|
||||
)
|
||||
|
||||
sql = """
|
||||
@@ -904,7 +904,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
iterable=event_ids,
|
||||
retcols=("user_id", "display_name", "avatar_url", "event_id"),
|
||||
keyvalues={"membership": Membership.JOIN},
|
||||
batch_size=1000,
|
||||
batch_size=500,
|
||||
desc="_get_joined_profiles_from_event_ids",
|
||||
)
|
||||
|
||||
|
||||
@@ -24,7 +24,6 @@ from synapse.storage.database import (
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
from synapse.storage.state import StateFilter
|
||||
from synapse.types import MutableStateMap, StateMap
|
||||
from synapse.util.caches import intern_string
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
@@ -137,7 +136,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
|
||||
txn.execute(sql % (where_clause,), args)
|
||||
for row in txn:
|
||||
typ, state_key, event_id = row
|
||||
key = (intern_string(typ), intern_string(state_key))
|
||||
key = (typ, state_key)
|
||||
results[group][key] = event_id
|
||||
else:
|
||||
max_entries_returned = state_filter.max_entries_returned()
|
||||
|
||||
@@ -202,7 +202,14 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
||||
requests state from the cache, if False we need to query the DB for the
|
||||
missing state.
|
||||
"""
|
||||
cache_entry = cache.get(group)
|
||||
# If we are asked explicitly for a subset of keys, we only ask for those
|
||||
# from the cache. This ensures that the `DictionaryCache` can make
|
||||
# better decisions about what to cache and what to expire.
|
||||
dict_keys = None
|
||||
if not state_filter.has_wildcards():
|
||||
dict_keys = state_filter.concrete_types()
|
||||
|
||||
cache_entry = cache.get(group, dict_keys=dict_keys)
|
||||
state_dict_ids = cache_entry.value
|
||||
|
||||
if cache_entry.full or state_filter.is_full():
|
||||
|
||||
@@ -14,11 +14,13 @@
|
||||
import enum
|
||||
import logging
|
||||
import threading
|
||||
from typing import Any, Dict, Generic, Iterable, Optional, Set, TypeVar
|
||||
from typing import Any, Dict, Generic, Iterable, Optional, Set, Tuple, TypeVar, Union
|
||||
|
||||
import attr
|
||||
from typing_extensions import Literal
|
||||
|
||||
from synapse.util.caches.lrucache import LruCache
|
||||
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_items
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -53,20 +55,67 @@ class DictionaryEntry: # should be: Generic[DKT, DV].
|
||||
return len(self.value)
|
||||
|
||||
|
||||
class _FullCacheKey(enum.Enum):
|
||||
"""The key we use to cache the full dict."""
|
||||
|
||||
KEY = object()
|
||||
|
||||
|
||||
class _Sentinel(enum.Enum):
|
||||
# defining a sentinel in this way allows mypy to correctly handle the
|
||||
# type of a dictionary lookup.
|
||||
sentinel = object()
|
||||
|
||||
|
||||
class _PerKeyValue(Generic[DV]):
|
||||
"""The cached value of a dictionary key. If `value` is the sentinel,
|
||||
indicates that the requested key is known to *not* be in the full dict.
|
||||
"""
|
||||
|
||||
__slots__ = ["value"]
|
||||
|
||||
def __init__(self, value: Union[DV, Literal[_Sentinel.sentinel]]) -> None:
|
||||
self.value = value
|
||||
|
||||
def __len__(self) -> int:
|
||||
# We add a `__len__` implementation as we use this class in a cache
|
||||
# where the values are variable length.
|
||||
return 1
|
||||
|
||||
|
||||
class DictionaryCache(Generic[KT, DKT, DV]):
|
||||
"""Caches key -> dictionary lookups, supporting caching partial dicts, i.e.
|
||||
fetching a subset of dictionary keys for a particular key.
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, max_entries: int = 1000):
|
||||
self.cache: LruCache[KT, DictionaryEntry] = LruCache(
|
||||
max_size=max_entries, cache_name=name, size_callback=len
|
||||
# We use a single cache to cache two different types of entries:
|
||||
# 1. Map from (key, dict_key) -> dict value (or sentinel, indicating
|
||||
# the key doesn't exist in the dict); and
|
||||
# 2. Map from (key, _FullCacheKey.KEY) -> full dict.
|
||||
#
|
||||
# The former is used when explicit keys of the dictionary are looked up,
|
||||
# and the latter when the full dictionary is requested.
|
||||
#
|
||||
# If when explicit keys are requested and not in the cache, we then look
|
||||
# to see if we have the full dict and use that if we do. If found in the
|
||||
# full dict each key is added into the cache.
|
||||
#
|
||||
# This set up allows the `LruCache` to prune the full dict entries if
|
||||
# they haven't been used in a while, even when there have been recent
|
||||
# queries for subsets of the dict.
|
||||
#
|
||||
# Typing:
|
||||
# * A key of `(KT, DKT)` has a value of `_PerKeyValue`
|
||||
# * A key of `(KT, _FullCacheKey.KEY)` has a value of `Dict[DKT, DV]`
|
||||
self.cache: LruCache[
|
||||
Tuple[KT, Union[DKT, Literal[_FullCacheKey.KEY]]],
|
||||
Union[_PerKeyValue, Dict[DKT, DV]],
|
||||
] = LruCache(
|
||||
max_size=max_entries,
|
||||
cache_name=name,
|
||||
cache_type=TreeCache,
|
||||
size_callback=len,
|
||||
)
|
||||
|
||||
self.name = name
|
||||
@@ -96,20 +145,97 @@ class DictionaryCache(Generic[KT, DKT, DV]):
|
||||
Returns:
|
||||
DictionaryEntry
|
||||
"""
|
||||
entry = self.cache.get(key, _Sentinel.sentinel)
|
||||
if entry is not _Sentinel.sentinel:
|
||||
if dict_keys is None:
|
||||
return DictionaryEntry(
|
||||
entry.full, entry.known_absent, dict(entry.value)
|
||||
)
|
||||
else:
|
||||
return DictionaryEntry(
|
||||
entry.full,
|
||||
entry.known_absent,
|
||||
{k: entry.value[k] for k in dict_keys if k in entry.value},
|
||||
)
|
||||
|
||||
return DictionaryEntry(False, set(), {})
|
||||
if dict_keys is None:
|
||||
# First we check if we have cached the full dict.
|
||||
entry = self.cache.get((key, _FullCacheKey.KEY), _Sentinel.sentinel)
|
||||
if entry is not _Sentinel.sentinel:
|
||||
assert isinstance(entry, dict)
|
||||
return DictionaryEntry(True, set(), entry)
|
||||
|
||||
# If not, check if we have cached any of dict keys.
|
||||
all_entries = self.cache.get_multi(
|
||||
(key,),
|
||||
_Sentinel.sentinel,
|
||||
)
|
||||
if all_entries is _Sentinel.sentinel:
|
||||
return DictionaryEntry(False, set(), {})
|
||||
|
||||
# If there are entries we need to unwrap the returned cache nodes
|
||||
# and `_PerKeyValue` into the `DictionaryEntry`.
|
||||
values = {}
|
||||
known_absent = set()
|
||||
for dict_key, dict_value in iterate_tree_cache_items((), all_entries):
|
||||
dict_key = dict_key[0]
|
||||
dict_value = dict_value.value
|
||||
|
||||
# We have explicitly looked for a full cache key, so we
|
||||
# shouldn't see one.
|
||||
assert dict_key != _FullCacheKey.KEY
|
||||
|
||||
# ... therefore the values must be `_PerKeyValue`
|
||||
assert isinstance(dict_value, _PerKeyValue)
|
||||
|
||||
if dict_value.value is _Sentinel.sentinel:
|
||||
known_absent.add(dict_key)
|
||||
else:
|
||||
values[dict_key] = dict_value.value
|
||||
|
||||
return DictionaryEntry(False, known_absent, values)
|
||||
|
||||
# We are being asked for a subset of keys.
|
||||
|
||||
# First got and check for each requested dict key in the cache, tracking
|
||||
# which we couldn't find.
|
||||
values = {}
|
||||
known_absent = set()
|
||||
missing = set()
|
||||
for dict_key in dict_keys:
|
||||
entry = self.cache.get((key, dict_key), _Sentinel.sentinel)
|
||||
if entry is _Sentinel.sentinel:
|
||||
missing.add(dict_key)
|
||||
continue
|
||||
|
||||
assert isinstance(entry, _PerKeyValue)
|
||||
|
||||
if entry.value is _Sentinel.sentinel:
|
||||
known_absent.add(dict_key)
|
||||
else:
|
||||
values[dict_key] = entry.value
|
||||
|
||||
# If we found everything we can return immediately.
|
||||
if not missing:
|
||||
return DictionaryEntry(False, known_absent, values)
|
||||
|
||||
# If we are missing any keys check if we happen to have the full dict in
|
||||
# the cache.
|
||||
#
|
||||
# We don't update the last access time for this cache fetch, as we
|
||||
# aren't explicitly interested in the full dict and so we don't want
|
||||
# requests for explicit dict keys to keep the full dict in the cache.
|
||||
entry = self.cache.get(
|
||||
(key, _FullCacheKey.KEY),
|
||||
_Sentinel.sentinel,
|
||||
update_last_access=False,
|
||||
)
|
||||
if entry is _Sentinel.sentinel:
|
||||
# Not in the cache, return the subset of keys we found.
|
||||
return DictionaryEntry(False, known_absent, values)
|
||||
|
||||
# We have the full dict!
|
||||
assert isinstance(entry, dict)
|
||||
|
||||
values = {}
|
||||
for dict_key in dict_keys:
|
||||
# We explicitly add each dict key to the cache, so that cache hit
|
||||
# rates for each key can be tracked separately.
|
||||
value = entry.get(dict_key, _Sentinel.sentinel) # type: ignore[arg-type]
|
||||
self.cache[(key, dict_key)] = _PerKeyValue(value)
|
||||
|
||||
if value is not _Sentinel.sentinel:
|
||||
values[dict_key] = value
|
||||
|
||||
return DictionaryEntry(True, set(), values)
|
||||
|
||||
def invalidate(self, key: KT) -> None:
|
||||
self.check_thread()
|
||||
@@ -117,7 +243,9 @@ class DictionaryCache(Generic[KT, DKT, DV]):
|
||||
# Increment the sequence number so that any SELECT statements that
|
||||
# raced with the INSERT don't update the cache (SYN-369)
|
||||
self.sequence += 1
|
||||
self.cache.pop(key, None)
|
||||
|
||||
# Del-multi accepts truncated tuples.
|
||||
self.cache.del_multi((key,)) # type: ignore[arg-type]
|
||||
|
||||
def invalidate_all(self) -> None:
|
||||
self.check_thread()
|
||||
@@ -149,20 +277,27 @@ class DictionaryCache(Generic[KT, DKT, DV]):
|
||||
# Only update the cache if the caches sequence number matches the
|
||||
# number that the cache had before the SELECT was started (SYN-369)
|
||||
if fetched_keys is None:
|
||||
self._insert(key, value, set())
|
||||
self.cache[(key, _FullCacheKey.KEY)] = value
|
||||
else:
|
||||
self._update_or_insert(key, value, fetched_keys)
|
||||
self._update_subset(key, value, fetched_keys)
|
||||
|
||||
def _update_or_insert(
|
||||
self, key: KT, value: Dict[DKT, DV], known_absent: Iterable[DKT]
|
||||
def _update_subset(
|
||||
self, key: KT, value: Dict[DKT, DV], fetched_keys: Iterable[DKT]
|
||||
) -> None:
|
||||
# We pop and reinsert as we need to tell the cache the size may have
|
||||
# changed
|
||||
"""Add the given dictionary values as explicit keys in the cache.
|
||||
|
||||
entry: DictionaryEntry = self.cache.pop(key, DictionaryEntry(False, set(), {}))
|
||||
entry.value.update(value)
|
||||
entry.known_absent.update(known_absent)
|
||||
self.cache[key] = entry
|
||||
Args:
|
||||
key
|
||||
value: The dictionary with all the values that we should cache
|
||||
fetched_keys: The full set of keys that were looked up, any keys
|
||||
here not in `value` should be marked as "known absent".
|
||||
"""
|
||||
|
||||
def _insert(self, key: KT, value: Dict[DKT, DV], known_absent: Set[DKT]) -> None:
|
||||
self.cache[key] = DictionaryEntry(True, known_absent, value)
|
||||
for dict_key, dict_value in value.items():
|
||||
self.cache[(key, dict_key)] = _PerKeyValue(dict_value)
|
||||
|
||||
for dict_key in fetched_keys:
|
||||
if (key, dict_key) in self.cache:
|
||||
continue
|
||||
|
||||
self.cache[(key, dict_key)] = _PerKeyValue(_Sentinel.sentinel)
|
||||
|
||||
@@ -44,7 +44,11 @@ from synapse.metrics.background_process_metrics import wrap_as_background_proces
|
||||
from synapse.metrics.jemalloc import get_jemalloc_stats
|
||||
from synapse.util import Clock, caches
|
||||
from synapse.util.caches import CacheMetric, EvictionReason, register_cache
|
||||
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
|
||||
from synapse.util.caches.treecache import (
|
||||
TreeCache,
|
||||
TreeCacheNode,
|
||||
iterate_tree_cache_entry,
|
||||
)
|
||||
from synapse.util.linked_list import ListNode
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -413,7 +417,7 @@ class LruCache(Generic[KT, VT]):
|
||||
else:
|
||||
real_clock = clock
|
||||
|
||||
cache: Union[Dict[KT, _Node[KT, VT]], TreeCache] = cache_type()
|
||||
cache: Union[Dict[KT, _Node[KT, VT]], TreeCache[_Node[KT, VT]]] = cache_type()
|
||||
self.cache = cache # Used for introspection.
|
||||
self.apply_cache_factor_from_config = apply_cache_factor_from_config
|
||||
|
||||
@@ -537,6 +541,7 @@ class LruCache(Generic[KT, VT]):
|
||||
default: Literal[None] = None,
|
||||
callbacks: Collection[Callable[[], None]] = ...,
|
||||
update_metrics: bool = ...,
|
||||
update_last_access: bool = ...,
|
||||
) -> Optional[VT]:
|
||||
...
|
||||
|
||||
@@ -546,6 +551,7 @@ class LruCache(Generic[KT, VT]):
|
||||
default: T,
|
||||
callbacks: Collection[Callable[[], None]] = ...,
|
||||
update_metrics: bool = ...,
|
||||
update_last_access: bool = ...,
|
||||
) -> Union[T, VT]:
|
||||
...
|
||||
|
||||
@@ -555,10 +561,27 @@ class LruCache(Generic[KT, VT]):
|
||||
default: Optional[T] = None,
|
||||
callbacks: Collection[Callable[[], None]] = (),
|
||||
update_metrics: bool = True,
|
||||
update_last_access: bool = True,
|
||||
) -> Union[None, T, VT]:
|
||||
"""Lookup a key in the cache
|
||||
|
||||
Args:
|
||||
key
|
||||
default
|
||||
callbacks: A collection of callbacks that will fire when the
|
||||
node is removed from the cache (either due to invalidation
|
||||
or expiry).
|
||||
update_metrics: Whether to update the hit rate metrics
|
||||
update_last_access: Whether to update the last access metrics
|
||||
on a node if successfully fetched. These metrics are used
|
||||
to determine when to remove the node from the cache. Set
|
||||
to False if this fetch should *not* prevent a node from
|
||||
being expired.
|
||||
"""
|
||||
node = cache.get(key, None)
|
||||
if node is not None:
|
||||
move_node_to_front(node)
|
||||
if update_last_access:
|
||||
move_node_to_front(node)
|
||||
node.add_callbacks(callbacks)
|
||||
if update_metrics and metrics:
|
||||
metrics.inc_hits()
|
||||
@@ -568,6 +591,42 @@ class LruCache(Generic[KT, VT]):
|
||||
metrics.inc_misses()
|
||||
return default
|
||||
|
||||
@overload
|
||||
def cache_get_multi(
|
||||
key: tuple,
|
||||
default: Literal[None] = None,
|
||||
update_metrics: bool = True,
|
||||
) -> Union[None, TreeCacheNode]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def cache_get_multi(
|
||||
key: tuple,
|
||||
default: T,
|
||||
update_metrics: bool = True,
|
||||
) -> Union[T, TreeCacheNode]:
|
||||
...
|
||||
|
||||
@synchronized
|
||||
def cache_get_multi(
|
||||
key: tuple,
|
||||
default: Optional[T] = None,
|
||||
update_metrics: bool = True,
|
||||
) -> Union[None, T, TreeCacheNode]:
|
||||
"""Used only for `TreeCache` to fetch a subtree."""
|
||||
|
||||
assert isinstance(cache, TreeCache)
|
||||
|
||||
node = cache.get(key, None)
|
||||
if node is not None:
|
||||
if update_metrics and metrics:
|
||||
metrics.inc_hits()
|
||||
return node
|
||||
else:
|
||||
if update_metrics and metrics:
|
||||
metrics.inc_misses()
|
||||
return default
|
||||
|
||||
@synchronized
|
||||
def cache_set(
|
||||
key: KT, value: VT, callbacks: Collection[Callable[[], None]] = ()
|
||||
@@ -674,6 +733,8 @@ class LruCache(Generic[KT, VT]):
|
||||
self.setdefault = cache_set_default
|
||||
self.pop = cache_pop
|
||||
self.del_multi = cache_del_multi
|
||||
if cache_type is TreeCache:
|
||||
self.get_multi = cache_get_multi
|
||||
# `invalidate` is exposed for consistency with DeferredCache, so that it can be
|
||||
# invalidated by the cache invalidation replication stream.
|
||||
self.invalidate = cache_del_multi
|
||||
|
||||
@@ -12,18 +12,59 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
SENTINEL = object()
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
Generator,
|
||||
Generic,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
Union,
|
||||
overload,
|
||||
)
|
||||
|
||||
|
||||
class TreeCacheNode(dict):
|
||||
class Sentinel(Enum):
|
||||
sentinel = object()
|
||||
|
||||
|
||||
V = TypeVar("V")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class TreeCacheNode(Generic[V]):
|
||||
"""The type of nodes in our tree.
|
||||
|
||||
Has its own type so we can distinguish it from real dicts that are stored at the
|
||||
leaves.
|
||||
Either a leaf node or a branch node.
|
||||
"""
|
||||
|
||||
__slots__ = ["leaf_value", "sub_tree"]
|
||||
|
||||
class TreeCache:
|
||||
def __init__(
|
||||
self,
|
||||
leaf_value: Union[V, Literal[Sentinel.sentinel]] = Sentinel.sentinel,
|
||||
sub_tree: Optional[Dict[Any, "TreeCacheNode[V]"]] = None,
|
||||
) -> None:
|
||||
if leaf_value is Sentinel.sentinel and sub_tree is None:
|
||||
raise Exception("One of leaf or sub tree must be set")
|
||||
|
||||
self.leaf_value: Union[V, Literal[Sentinel.sentinel]] = leaf_value
|
||||
self.sub_tree: Optional[Dict[Any, "TreeCacheNode[V]"]] = sub_tree
|
||||
|
||||
@staticmethod
|
||||
def leaf(value: V) -> "TreeCacheNode[V]":
|
||||
return TreeCacheNode(leaf_value=value)
|
||||
|
||||
@staticmethod
|
||||
def empty_branch() -> "TreeCacheNode[V]":
|
||||
return TreeCacheNode(sub_tree={})
|
||||
|
||||
|
||||
class TreeCache(Generic[V]):
|
||||
"""
|
||||
Tree-based backing store for LruCache. Allows subtrees of data to be deleted
|
||||
efficiently.
|
||||
@@ -35,15 +76,15 @@ class TreeCache:
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.size: int = 0
|
||||
self.root = TreeCacheNode()
|
||||
self.root: TreeCacheNode[V] = TreeCacheNode.empty_branch()
|
||||
|
||||
def __setitem__(self, key, value) -> None:
|
||||
def __setitem__(self, key: tuple, value: V) -> None:
|
||||
self.set(key, value)
|
||||
|
||||
def __contains__(self, key) -> bool:
|
||||
return self.get(key, SENTINEL) is not SENTINEL
|
||||
def __contains__(self, key: tuple) -> bool:
|
||||
return self.get(key, None) is not None
|
||||
|
||||
def set(self, key, value) -> None:
|
||||
def set(self, key: tuple, value: V) -> None:
|
||||
if isinstance(value, TreeCacheNode):
|
||||
# this would mean we couldn't tell where our tree ended and the value
|
||||
# started.
|
||||
@@ -51,31 +92,56 @@ class TreeCache:
|
||||
|
||||
node = self.root
|
||||
for k in key[:-1]:
|
||||
next_node = node.get(k, SENTINEL)
|
||||
if next_node is SENTINEL:
|
||||
next_node = node[k] = TreeCacheNode()
|
||||
elif not isinstance(next_node, TreeCacheNode):
|
||||
# this suggests that the caller is not being consistent with its key
|
||||
# length.
|
||||
sub_tree = node.sub_tree
|
||||
if sub_tree is None:
|
||||
raise ValueError("value conflicts with an existing subtree")
|
||||
node = next_node
|
||||
|
||||
node[key[-1]] = value
|
||||
next_node = sub_tree.get(k, None)
|
||||
if next_node is None:
|
||||
node = TreeCacheNode.empty_branch()
|
||||
sub_tree[k] = node
|
||||
else:
|
||||
node = next_node
|
||||
|
||||
if node.sub_tree is None:
|
||||
raise ValueError("value conflicts with an existing subtree")
|
||||
|
||||
node.sub_tree[key[-1]] = TreeCacheNode.leaf(value)
|
||||
self.size += 1
|
||||
|
||||
def get(self, key, default=None):
|
||||
@overload
|
||||
def get(self, key: tuple, default: Literal[None] = None) -> Union[None, V]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def get(self, key: tuple, default: T) -> Union[T, V]:
|
||||
...
|
||||
|
||||
def get(self, key: tuple, default: Optional[T] = None) -> Union[None, T, V]:
|
||||
node = self.root
|
||||
for k in key[:-1]:
|
||||
node = node.get(k, None)
|
||||
if node is None:
|
||||
for k in key:
|
||||
sub_tree = node.sub_tree
|
||||
if sub_tree is None:
|
||||
raise ValueError("get() key too long")
|
||||
|
||||
next_node = sub_tree.get(k, None)
|
||||
if next_node is None:
|
||||
return default
|
||||
return node.get(key[-1], default)
|
||||
|
||||
node = next_node
|
||||
|
||||
if node.leaf_value is Sentinel.sentinel:
|
||||
raise ValueError("key points to a branch")
|
||||
|
||||
return node.leaf_value
|
||||
|
||||
def clear(self) -> None:
|
||||
self.size = 0
|
||||
self.root = TreeCacheNode()
|
||||
|
||||
def pop(self, key, default=None):
|
||||
def pop(
|
||||
self, key: tuple, default: Optional[T] = None
|
||||
) -> Union[None, T, V, TreeCacheNode[V]]:
|
||||
"""Remove the given key, or subkey, from the cache
|
||||
|
||||
Args:
|
||||
@@ -91,20 +157,25 @@ class TreeCache:
|
||||
raise TypeError("The cache key must be a tuple not %r" % (type(key),))
|
||||
|
||||
# a list of the nodes we have touched on the way down the tree
|
||||
nodes = []
|
||||
nodes: List[TreeCacheNode[V]] = []
|
||||
|
||||
node = self.root
|
||||
for k in key[:-1]:
|
||||
node = node.get(k, None)
|
||||
if node is None:
|
||||
return default
|
||||
if not isinstance(node, TreeCacheNode):
|
||||
# we've gone off the end of the tree
|
||||
sub_tree = node.sub_tree
|
||||
if sub_tree is None:
|
||||
raise ValueError("pop() key too long")
|
||||
nodes.append(node) # don't add the root node
|
||||
popped = node.pop(key[-1], SENTINEL)
|
||||
if popped is SENTINEL:
|
||||
return default
|
||||
|
||||
next_node = sub_tree.get(k, None)
|
||||
if next_node is None:
|
||||
return default
|
||||
|
||||
node = next_node
|
||||
nodes.append(node)
|
||||
|
||||
if node.sub_tree is None:
|
||||
raise ValueError("pop() key too long")
|
||||
|
||||
popped = node.sub_tree.pop(key[-1])
|
||||
|
||||
# working back up the tree, clear out any nodes that are now empty
|
||||
node_and_keys = list(zip(nodes, key))
|
||||
@@ -116,8 +187,13 @@ class TreeCache:
|
||||
|
||||
if n:
|
||||
break
|
||||
|
||||
# found an empty node: remove it from its parent, and loop.
|
||||
node_and_keys[i + 1][0].pop(k)
|
||||
node = node_and_keys[i + 1][0]
|
||||
|
||||
# We added it to the list so already know its a branch node.
|
||||
assert node.sub_tree is not None
|
||||
node.sub_tree.pop(k)
|
||||
|
||||
cnt = sum(1 for _ in iterate_tree_cache_entry(popped))
|
||||
self.size -= cnt
|
||||
@@ -130,12 +206,31 @@ class TreeCache:
|
||||
return self.size
|
||||
|
||||
|
||||
def iterate_tree_cache_entry(d):
|
||||
def iterate_tree_cache_entry(d: TreeCacheNode[V]) -> Generator[V, None, None]:
|
||||
"""Helper function to iterate over the leaves of a tree, i.e. a dict of that
|
||||
can contain dicts.
|
||||
"""
|
||||
if isinstance(d, TreeCacheNode):
|
||||
for value_d in d.values():
|
||||
|
||||
if d.sub_tree is not None:
|
||||
for value_d in d.sub_tree.values():
|
||||
yield from iterate_tree_cache_entry(value_d)
|
||||
else:
|
||||
yield d
|
||||
assert d.leaf_value is not Sentinel.sentinel
|
||||
yield d.leaf_value
|
||||
|
||||
|
||||
def iterate_tree_cache_items(
|
||||
key: tuple, value: TreeCacheNode[V]
|
||||
) -> Generator[Tuple[tuple, V], None, None]:
|
||||
"""Helper function to iterate over the leaves of a tree, i.e. a dict of that
|
||||
can contain dicts.
|
||||
|
||||
Returns:
|
||||
A generator yielding key/value pairs.
|
||||
"""
|
||||
if value.sub_tree is not None:
|
||||
for sub_key, sub_value in value.sub_tree.items():
|
||||
yield from iterate_tree_cache_items((*key, sub_key), sub_value)
|
||||
else:
|
||||
assert value.leaf_value is not Sentinel.sentinel
|
||||
yield key, value.leaf_value
|
||||
|
||||
@@ -12,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from http import HTTPStatus
|
||||
from unittest.mock import Mock
|
||||
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
@@ -51,7 +50,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
|
||||
channel = self.make_signed_federation_request(
|
||||
"GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,)
|
||||
)
|
||||
self.assertEqual(HTTPStatus.OK, channel.code)
|
||||
self.assertEqual(200, channel.code)
|
||||
complexity = channel.json_body["v1"]
|
||||
self.assertTrue(complexity > 0, complexity)
|
||||
|
||||
@@ -63,7 +62,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
|
||||
channel = self.make_signed_federation_request(
|
||||
"GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,)
|
||||
)
|
||||
self.assertEqual(HTTPStatus.OK, channel.code)
|
||||
self.assertEqual(200, channel.code)
|
||||
complexity = channel.json_body["v1"]
|
||||
self.assertEqual(complexity, 1.23)
|
||||
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from http import HTTPStatus
|
||||
|
||||
from parameterized import parameterized
|
||||
|
||||
@@ -59,7 +58,7 @@ class FederationServerTests(unittest.FederatingHomeserverTestCase):
|
||||
"/_matrix/federation/v1/get_missing_events/%s" % (room_1,),
|
||||
query_content,
|
||||
)
|
||||
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, channel.result)
|
||||
self.assertEqual(400, channel.code, channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], "M_NOT_JSON")
|
||||
|
||||
|
||||
@@ -120,7 +119,7 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase):
|
||||
channel = self.make_signed_federation_request(
|
||||
"GET", "/_matrix/federation/v1/state/%s?event_id=xyz" % (room_1,)
|
||||
)
|
||||
self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, channel.result)
|
||||
self.assertEqual(403, channel.code, channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
||||
|
||||
|
||||
@@ -148,13 +147,13 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
|
||||
tok2 = self.login("fozzie", "bear")
|
||||
self.helper.join(self._room_id, second_member_user_id, tok=tok2)
|
||||
|
||||
def _make_join(self, user_id: str) -> JsonDict:
|
||||
def _make_join(self, user_id) -> JsonDict:
|
||||
channel = self.make_signed_federation_request(
|
||||
"GET",
|
||||
f"/_matrix/federation/v1/make_join/{self._room_id}/{user_id}"
|
||||
f"?ver={DEFAULT_ROOM_VERSION}",
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
|
||||
self.assertEqual(channel.code, 200, channel.json_body)
|
||||
return channel.json_body
|
||||
|
||||
def test_send_join(self):
|
||||
@@ -172,7 +171,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
|
||||
f"/_matrix/federation/v2/send_join/{self._room_id}/x",
|
||||
content=join_event_dict,
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
|
||||
self.assertEqual(channel.code, 200, channel.json_body)
|
||||
|
||||
# we should get complete room state back
|
||||
returned_state = [
|
||||
@@ -227,7 +226,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
|
||||
f"/_matrix/federation/v2/send_join/{self._room_id}/x?org.matrix.msc3706.partial_state=true",
|
||||
content=join_event_dict,
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
|
||||
self.assertEqual(channel.code, 200, channel.json_body)
|
||||
|
||||
# expect a reduced room state
|
||||
returned_state = [
|
||||
@@ -260,67 +259,6 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
|
||||
)
|
||||
self.assertEqual(r[("m.room.member", joining_user)].membership, "join")
|
||||
|
||||
@override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 3}})
|
||||
def test_make_join_respects_room_join_rate_limit(self) -> None:
|
||||
# In the test setup, two users join the room. Since the rate limiter burst
|
||||
# count is 3, a new make_join request to the room should be accepted.
|
||||
|
||||
joining_user = "@ronniecorbett:" + self.OTHER_SERVER_NAME
|
||||
self._make_join(joining_user)
|
||||
|
||||
# Now have a new local user join the room. This saturates the rate limiter
|
||||
# bucket, so the next make_join should be denied.
|
||||
new_local_user = self.register_user("animal", "animal")
|
||||
token = self.login("animal", "animal")
|
||||
self.helper.join(self._room_id, new_local_user, tok=token)
|
||||
|
||||
joining_user = "@ronniebarker:" + self.OTHER_SERVER_NAME
|
||||
channel = self.make_signed_federation_request(
|
||||
"GET",
|
||||
f"/_matrix/federation/v1/make_join/{self._room_id}/{joining_user}"
|
||||
f"?ver={DEFAULT_ROOM_VERSION}",
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.TOO_MANY_REQUESTS, channel.json_body)
|
||||
|
||||
@override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 3}})
|
||||
def test_send_join_contributes_to_room_join_rate_limit_and_is_limited(self) -> None:
|
||||
# Make two make_join requests up front. (These are rate limited, but do not
|
||||
# contribute to the rate limit.)
|
||||
join_event_dicts = []
|
||||
for i in range(2):
|
||||
joining_user = f"@misspiggy{i}:{self.OTHER_SERVER_NAME}"
|
||||
join_result = self._make_join(joining_user)
|
||||
join_event_dict = join_result["event"]
|
||||
self.add_hashes_and_signatures_from_other_server(
|
||||
join_event_dict,
|
||||
KNOWN_ROOM_VERSIONS[DEFAULT_ROOM_VERSION],
|
||||
)
|
||||
join_event_dicts.append(join_event_dict)
|
||||
|
||||
# In the test setup, two users join the room. Since the rate limiter burst
|
||||
# count is 3, the first send_join should be accepted...
|
||||
channel = self.make_signed_federation_request(
|
||||
"PUT",
|
||||
f"/_matrix/federation/v2/send_join/{self._room_id}/join0",
|
||||
content=join_event_dicts[0],
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.json_body)
|
||||
|
||||
# ... but the second should be denied.
|
||||
channel = self.make_signed_federation_request(
|
||||
"PUT",
|
||||
f"/_matrix/federation/v2/send_join/{self._room_id}/join1",
|
||||
content=join_event_dicts[1],
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.TOO_MANY_REQUESTS, channel.json_body)
|
||||
|
||||
# NB: we could write a test which checks that the send_join event is seen
|
||||
# by other workers over replication, and that they update their rate limit
|
||||
# buckets accordingly. I'm going to assume that the join event gets sent over
|
||||
# replication, at which point the tests.handlers.room_member test
|
||||
# test_local_users_joining_on_another_worker_contribute_to_rate_limit
|
||||
# is probably sufficient to reassure that the bucket is updated.
|
||||
|
||||
|
||||
def _create_acl_event(content):
|
||||
return make_event_from_dict(
|
||||
|
||||
@@ -12,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from collections import OrderedDict
|
||||
from http import HTTPStatus
|
||||
from typing import Dict, List
|
||||
|
||||
from synapse.api.constants import EventTypes, JoinRules, Membership
|
||||
@@ -256,7 +255,7 @@ class FederationKnockingTestCase(
|
||||
RoomVersions.V7.identifier,
|
||||
),
|
||||
)
|
||||
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
||||
self.assertEqual(200, channel.code, channel.result)
|
||||
|
||||
# Note: We don't expect the knock membership event to be sent over federation as
|
||||
# part of the stripped room state, as the knocking homeserver already has that
|
||||
@@ -294,7 +293,7 @@ class FederationKnockingTestCase(
|
||||
% (room_id, signed_knock_event.event_id),
|
||||
signed_knock_event_json,
|
||||
)
|
||||
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
||||
self.assertEqual(200, channel.code, channel.result)
|
||||
|
||||
# Check that we got the stripped room state in return
|
||||
room_state_events = channel.json_body["knock_state_events"]
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
|
||||
"""Tests for the password_auth_provider interface"""
|
||||
|
||||
from http import HTTPStatus
|
||||
from typing import Any, Type, Union
|
||||
from unittest.mock import Mock
|
||||
|
||||
@@ -189,14 +188,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||
# check_password must return an awaitable
|
||||
mock_password_provider.check_password.return_value = make_awaitable(True)
|
||||
channel = self._send_password_login("u", "p")
|
||||
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
self.assertEqual("@u:test", channel.json_body["user_id"])
|
||||
mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
|
||||
mock_password_provider.reset_mock()
|
||||
|
||||
# login with mxid should work too
|
||||
channel = self._send_password_login("@u:bz", "p")
|
||||
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
self.assertEqual("@u:bz", channel.json_body["user_id"])
|
||||
mock_password_provider.check_password.assert_called_once_with("@u:bz", "p")
|
||||
mock_password_provider.reset_mock()
|
||||
@@ -205,7 +204,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||
# in these cases, but at least we can guard against the API changing
|
||||
# unexpectedly
|
||||
channel = self._send_password_login(" USER🙂NAME ", " pASS\U0001F622word ")
|
||||
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
self.assertEqual("@ USER🙂NAME :test", channel.json_body["user_id"])
|
||||
mock_password_provider.check_password.assert_called_once_with(
|
||||
"@ USER🙂NAME :test", " pASS😢word "
|
||||
@@ -259,10 +258,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||
# check_password must return an awaitable
|
||||
mock_password_provider.check_password.return_value = make_awaitable(False)
|
||||
channel = self._send_password_login("u", "p")
|
||||
self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result)
|
||||
self.assertEqual(channel.code, 403, channel.result)
|
||||
|
||||
channel = self._send_password_login("localuser", "localpass")
|
||||
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
self.assertEqual("@localuser:test", channel.json_body["user_id"])
|
||||
|
||||
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
|
||||
@@ -383,7 +382,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||
|
||||
# login shouldn't work and should be rejected with a 400 ("unknown login type")
|
||||
channel = self._send_password_login("u", "p")
|
||||
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
|
||||
self.assertEqual(channel.code, 400, channel.result)
|
||||
mock_password_provider.check_password.assert_not_called()
|
||||
|
||||
@override_config(legacy_providers_config(LegacyCustomAuthProvider))
|
||||
@@ -407,14 +406,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||
|
||||
# login with missing param should be rejected
|
||||
channel = self._send_login("test.login_type", "u")
|
||||
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
|
||||
self.assertEqual(channel.code, 400, channel.result)
|
||||
mock_password_provider.check_auth.assert_not_called()
|
||||
|
||||
mock_password_provider.check_auth.return_value = make_awaitable(
|
||||
("@user:bz", None)
|
||||
)
|
||||
channel = self._send_login("test.login_type", "u", test_field="y")
|
||||
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
self.assertEqual("@user:bz", channel.json_body["user_id"])
|
||||
mock_password_provider.check_auth.assert_called_once_with(
|
||||
"u", "test.login_type", {"test_field": "y"}
|
||||
@@ -428,7 +427,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||
("@ MALFORMED! :bz", None)
|
||||
)
|
||||
channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ")
|
||||
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
self.assertEqual("@ MALFORMED! :bz", channel.json_body["user_id"])
|
||||
mock_password_provider.check_auth.assert_called_once_with(
|
||||
" USER🙂NAME ", "test.login_type", {"test_field": " abc "}
|
||||
@@ -511,7 +510,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||
("@user:bz", callback)
|
||||
)
|
||||
channel = self._send_login("test.login_type", "u", test_field="y")
|
||||
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
self.assertEqual("@user:bz", channel.json_body["user_id"])
|
||||
mock_password_provider.check_auth.assert_called_once_with(
|
||||
"u", "test.login_type", {"test_field": "y"}
|
||||
@@ -550,7 +549,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||
|
||||
# login shouldn't work and should be rejected with a 400 ("unknown login type")
|
||||
channel = self._send_password_login("localuser", "localpass")
|
||||
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
|
||||
self.assertEqual(channel.code, 400, channel.result)
|
||||
mock_password_provider.check_auth.assert_not_called()
|
||||
|
||||
@override_config(
|
||||
@@ -585,7 +584,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||
|
||||
# login shouldn't work and should be rejected with a 400 ("unknown login type")
|
||||
channel = self._send_password_login("localuser", "localpass")
|
||||
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
|
||||
self.assertEqual(channel.code, 400, channel.result)
|
||||
mock_password_provider.check_auth.assert_not_called()
|
||||
|
||||
@override_config(
|
||||
@@ -616,7 +615,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||
|
||||
# login shouldn't work and should be rejected with a 400 ("unknown login type")
|
||||
channel = self._send_password_login("localuser", "localpass")
|
||||
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
|
||||
self.assertEqual(channel.code, 400, channel.result)
|
||||
mock_password_provider.check_auth.assert_not_called()
|
||||
mock_password_provider.check_password.assert_not_called()
|
||||
|
||||
@@ -647,13 +646,13 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||
("@localuser:test", None)
|
||||
)
|
||||
channel = self._send_login("test.login_type", "localuser", test_field="")
|
||||
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
tok1 = channel.json_body["access_token"]
|
||||
|
||||
channel = self._send_login(
|
||||
"test.login_type", "localuser", test_field="", device_id="dev2"
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
# make the initial request which returns a 401
|
||||
channel = self._delete_device(tok1, "dev2")
|
||||
@@ -722,7 +721,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||
# password login shouldn't work and should be rejected with a 400
|
||||
# ("unknown login type")
|
||||
channel = self._send_password_login("localuser", "localpass")
|
||||
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
|
||||
self.assertEqual(channel.code, 400, channel.result)
|
||||
|
||||
def test_on_logged_out(self):
|
||||
"""Tests that the on_logged_out callback is called when the user logs out."""
|
||||
@@ -885,7 +884,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||
},
|
||||
access_token=tok,
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result)
|
||||
self.assertEqual(channel.code, 403, channel.result)
|
||||
self.assertEqual(
|
||||
channel.json_body["errcode"],
|
||||
Codes.THREEPID_DENIED,
|
||||
@@ -907,7 +906,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||
},
|
||||
access_token=tok,
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
self.assertIn("sid", channel.json_body)
|
||||
|
||||
m.assert_called_once_with("email", "bar@test.com", registration)
|
||||
@@ -950,12 +949,12 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
|
||||
"register",
|
||||
{"auth": {"session": session, "type": LoginType.DUMMY}},
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
|
||||
self.assertEqual(channel.code, 200, channel.json_body)
|
||||
return channel.json_body
|
||||
|
||||
def _get_login_flows(self) -> JsonDict:
|
||||
channel = self.make_request("GET", "/_matrix/client/r0/login")
|
||||
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
return channel.json_body["flows"]
|
||||
|
||||
def _send_password_login(self, user: str, password: str) -> FakeChannel:
|
||||
|
||||
@@ -1,290 +0,0 @@
|
||||
from http import HTTPStatus
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
import synapse.rest.admin
|
||||
import synapse.rest.client.login
|
||||
import synapse.rest.client.room
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.api.errors import LimitExceededError
|
||||
from synapse.crypto.event_signing import add_hashes_and_signatures
|
||||
from synapse.events import FrozenEventV3
|
||||
from synapse.federation.federation_client import SendJoinResult
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import UserID, create_requester
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests.replication._base import RedisMultiWorkerStreamTestCase
|
||||
from tests.server import make_request
|
||||
from tests.test_utils import make_awaitable
|
||||
from tests.unittest import FederatingHomeserverTestCase, override_config
|
||||
|
||||
|
||||
class TestJoinsLimitedByPerRoomRateLimiter(FederatingHomeserverTestCase):
|
||||
servlets = [
|
||||
synapse.rest.admin.register_servlets,
|
||||
synapse.rest.client.login.register_servlets,
|
||||
synapse.rest.client.room.register_servlets,
|
||||
]
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.handler = hs.get_room_member_handler()
|
||||
|
||||
# Create three users.
|
||||
self.alice = self.register_user("alice", "pass")
|
||||
self.alice_token = self.login("alice", "pass")
|
||||
self.bob = self.register_user("bob", "pass")
|
||||
self.bob_token = self.login("bob", "pass")
|
||||
self.chris = self.register_user("chris", "pass")
|
||||
self.chris_token = self.login("chris", "pass")
|
||||
|
||||
# Create a room on this homeserver. Note that this counts as a join: it
|
||||
# contributes to the rate limter's count of actions
|
||||
self.room_id = self.helper.create_room_as(self.alice, tok=self.alice_token)
|
||||
|
||||
self.intially_unjoined_room_id = f"!example:{self.OTHER_SERVER_NAME}"
|
||||
|
||||
@override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 2}})
|
||||
def test_local_user_local_joins_contribute_to_limit_and_are_limited(self) -> None:
|
||||
# The rate limiter has accumulated one token from Alice's join after the create
|
||||
# event.
|
||||
# Try joining the room as Bob.
|
||||
self.get_success(
|
||||
self.handler.update_membership(
|
||||
requester=create_requester(self.bob),
|
||||
target=UserID.from_string(self.bob),
|
||||
room_id=self.room_id,
|
||||
action=Membership.JOIN,
|
||||
)
|
||||
)
|
||||
|
||||
# The rate limiter bucket is full. A second join should be denied.
|
||||
self.get_failure(
|
||||
self.handler.update_membership(
|
||||
requester=create_requester(self.chris),
|
||||
target=UserID.from_string(self.chris),
|
||||
room_id=self.room_id,
|
||||
action=Membership.JOIN,
|
||||
),
|
||||
LimitExceededError,
|
||||
)
|
||||
|
||||
@override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 2}})
|
||||
def test_local_user_profile_edits_dont_contribute_to_limit(self) -> None:
|
||||
# The rate limiter has accumulated one token from Alice's join after the create
|
||||
# event. Alice should still be able to change her displayname.
|
||||
self.get_success(
|
||||
self.handler.update_membership(
|
||||
requester=create_requester(self.alice),
|
||||
target=UserID.from_string(self.alice),
|
||||
room_id=self.room_id,
|
||||
action=Membership.JOIN,
|
||||
content={"displayname": "Alice Cooper"},
|
||||
)
|
||||
)
|
||||
|
||||
# Still room in the limiter bucket. Chris's join should be accepted.
|
||||
self.get_success(
|
||||
self.handler.update_membership(
|
||||
requester=create_requester(self.chris),
|
||||
target=UserID.from_string(self.chris),
|
||||
room_id=self.room_id,
|
||||
action=Membership.JOIN,
|
||||
)
|
||||
)
|
||||
|
||||
@override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 1}})
|
||||
def test_remote_joins_contribute_to_rate_limit(self) -> None:
|
||||
# Join once, to fill the rate limiter bucket.
|
||||
#
|
||||
# To do this we have to mock the responses from the remote homeserver.
|
||||
# We also patch out a bunch of event checks on our end. All we're really
|
||||
# trying to check here is that remote joins will bump the rate limter when
|
||||
# they are persisted.
|
||||
create_event_source = {
|
||||
"auth_events": [],
|
||||
"content": {
|
||||
"creator": f"@creator:{self.OTHER_SERVER_NAME}",
|
||||
"room_version": self.hs.config.server.default_room_version.identifier,
|
||||
},
|
||||
"depth": 0,
|
||||
"origin_server_ts": 0,
|
||||
"prev_events": [],
|
||||
"room_id": self.intially_unjoined_room_id,
|
||||
"sender": f"@creator:{self.OTHER_SERVER_NAME}",
|
||||
"state_key": "",
|
||||
"type": EventTypes.Create,
|
||||
}
|
||||
self.add_hashes_and_signatures_from_other_server(
|
||||
create_event_source,
|
||||
self.hs.config.server.default_room_version,
|
||||
)
|
||||
create_event = FrozenEventV3(
|
||||
create_event_source,
|
||||
self.hs.config.server.default_room_version,
|
||||
{},
|
||||
None,
|
||||
)
|
||||
|
||||
join_event_source = {
|
||||
"auth_events": [create_event.event_id],
|
||||
"content": {"membership": "join"},
|
||||
"depth": 1,
|
||||
"origin_server_ts": 100,
|
||||
"prev_events": [create_event.event_id],
|
||||
"sender": self.bob,
|
||||
"state_key": self.bob,
|
||||
"room_id": self.intially_unjoined_room_id,
|
||||
"type": EventTypes.Member,
|
||||
}
|
||||
add_hashes_and_signatures(
|
||||
self.hs.config.server.default_room_version,
|
||||
join_event_source,
|
||||
self.hs.hostname,
|
||||
self.hs.signing_key,
|
||||
)
|
||||
join_event = FrozenEventV3(
|
||||
join_event_source,
|
||||
self.hs.config.server.default_room_version,
|
||||
{},
|
||||
None,
|
||||
)
|
||||
|
||||
mock_make_membership_event = Mock(
|
||||
return_value=make_awaitable(
|
||||
(
|
||||
self.OTHER_SERVER_NAME,
|
||||
join_event,
|
||||
self.hs.config.server.default_room_version,
|
||||
)
|
||||
)
|
||||
)
|
||||
mock_send_join = Mock(
|
||||
return_value=make_awaitable(
|
||||
SendJoinResult(
|
||||
join_event,
|
||||
self.OTHER_SERVER_NAME,
|
||||
state=[create_event],
|
||||
auth_chain=[create_event],
|
||||
partial_state=False,
|
||||
servers_in_room=[],
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
with patch.object(
|
||||
self.handler.federation_handler.federation_client,
|
||||
"make_membership_event",
|
||||
mock_make_membership_event,
|
||||
), patch.object(
|
||||
self.handler.federation_handler.federation_client,
|
||||
"send_join",
|
||||
mock_send_join,
|
||||
), patch(
|
||||
"synapse.event_auth._is_membership_change_allowed",
|
||||
return_value=None,
|
||||
), patch(
|
||||
"synapse.handlers.federation_event.check_state_dependent_auth_rules",
|
||||
return_value=None,
|
||||
):
|
||||
self.get_success(
|
||||
self.handler.update_membership(
|
||||
requester=create_requester(self.bob),
|
||||
target=UserID.from_string(self.bob),
|
||||
room_id=self.intially_unjoined_room_id,
|
||||
action=Membership.JOIN,
|
||||
remote_room_hosts=[self.OTHER_SERVER_NAME],
|
||||
)
|
||||
)
|
||||
|
||||
# Try to join as Chris. Should get denied.
|
||||
self.get_failure(
|
||||
self.handler.update_membership(
|
||||
requester=create_requester(self.chris),
|
||||
target=UserID.from_string(self.chris),
|
||||
room_id=self.intially_unjoined_room_id,
|
||||
action=Membership.JOIN,
|
||||
remote_room_hosts=[self.OTHER_SERVER_NAME],
|
||||
),
|
||||
LimitExceededError,
|
||||
)
|
||||
|
||||
# TODO: test that remote joins to a room are rate limited.
|
||||
# Could do this by setting the burst count to 1, then:
|
||||
# - remote-joining a room
|
||||
# - immediately leaving
|
||||
# - trying to remote-join again.
|
||||
|
||||
|
||||
class TestReplicatedJoinsLimitedByPerRoomRateLimiter(RedisMultiWorkerStreamTestCase):
|
||||
servlets = [
|
||||
synapse.rest.admin.register_servlets,
|
||||
synapse.rest.client.login.register_servlets,
|
||||
synapse.rest.client.room.register_servlets,
|
||||
]
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.handler = hs.get_room_member_handler()
|
||||
|
||||
# Create three users.
|
||||
self.alice = self.register_user("alice", "pass")
|
||||
self.alice_token = self.login("alice", "pass")
|
||||
self.bob = self.register_user("bob", "pass")
|
||||
self.bob_token = self.login("bob", "pass")
|
||||
self.chris = self.register_user("chris", "pass")
|
||||
self.chris_token = self.login("chris", "pass")
|
||||
|
||||
# Create a room on this homeserver.
|
||||
# Note that this counts as a
|
||||
self.room_id = self.helper.create_room_as(self.alice, tok=self.alice_token)
|
||||
self.intially_unjoined_room_id = "!example:otherhs"
|
||||
|
||||
@override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 2}})
|
||||
def test_local_users_joining_on_another_worker_contribute_to_rate_limit(
|
||||
self,
|
||||
) -> None:
|
||||
# The rate limiter has accumulated one token from Alice's join after the create
|
||||
# event.
|
||||
self.replicate()
|
||||
|
||||
# Spawn another worker and have bob join via it.
|
||||
worker_app = self.make_worker_hs(
|
||||
"synapse.app.generic_worker", extra_config={"worker_name": "other worker"}
|
||||
)
|
||||
worker_site = self._hs_to_site[worker_app]
|
||||
channel = make_request(
|
||||
self.reactor,
|
||||
worker_site,
|
||||
"POST",
|
||||
f"/_matrix/client/v3/rooms/{self.room_id}/join",
|
||||
access_token=self.bob_token,
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
|
||||
|
||||
# wait for join to arrive over replication
|
||||
self.replicate()
|
||||
|
||||
# Try to join as Chris on the worker. Should get denied because Alice
|
||||
# and Bob have both joined the room.
|
||||
self.get_failure(
|
||||
worker_app.get_room_member_handler().update_membership(
|
||||
requester=create_requester(self.chris),
|
||||
target=UserID.from_string(self.chris),
|
||||
room_id=self.room_id,
|
||||
action=Membership.JOIN,
|
||||
),
|
||||
LimitExceededError,
|
||||
)
|
||||
|
||||
# Try to join as Chris on the original worker. Should get denied because Alice
|
||||
# and Bob have both joined the room.
|
||||
self.get_failure(
|
||||
self.handler.update_membership(
|
||||
requester=create_requester(self.chris),
|
||||
target=UserID.from_string(self.chris),
|
||||
room_id=self.room_id,
|
||||
action=Membership.JOIN,
|
||||
),
|
||||
LimitExceededError,
|
||||
)
|
||||
@@ -1379,7 +1379,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
||||
content=body,
|
||||
)
|
||||
|
||||
self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
|
||||
self.assertEqual(201, channel.code, msg=channel.json_body)
|
||||
self.assertEqual("@bob:test", channel.json_body["name"])
|
||||
self.assertEqual("Bob's name", channel.json_body["displayname"])
|
||||
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
|
||||
@@ -1434,7 +1434,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
||||
content=body,
|
||||
)
|
||||
|
||||
self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
|
||||
self.assertEqual(201, channel.code, msg=channel.json_body)
|
||||
self.assertEqual("@bob:test", channel.json_body["name"])
|
||||
self.assertEqual("Bob's name", channel.json_body["displayname"])
|
||||
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
|
||||
@@ -1512,7 +1512,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
||||
content={"password": "abc123", "admin": False},
|
||||
)
|
||||
|
||||
self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
|
||||
self.assertEqual(201, channel.code, msg=channel.json_body)
|
||||
self.assertEqual("@bob:test", channel.json_body["name"])
|
||||
self.assertFalse(channel.json_body["admin"])
|
||||
|
||||
@@ -1550,7 +1550,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
# Admin user is not blocked by mau anymore
|
||||
self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
|
||||
self.assertEqual(201, channel.code, msg=channel.json_body)
|
||||
self.assertEqual("@bob:test", channel.json_body["name"])
|
||||
self.assertFalse(channel.json_body["admin"])
|
||||
|
||||
@@ -1585,7 +1585,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
||||
content=body,
|
||||
)
|
||||
|
||||
self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
|
||||
self.assertEqual(201, channel.code, msg=channel.json_body)
|
||||
self.assertEqual("@bob:test", channel.json_body["name"])
|
||||
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
|
||||
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
|
||||
@@ -1626,7 +1626,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
||||
content=body,
|
||||
)
|
||||
|
||||
self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
|
||||
self.assertEqual(201, channel.code, msg=channel.json_body)
|
||||
self.assertEqual("@bob:test", channel.json_body["name"])
|
||||
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
|
||||
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
|
||||
@@ -1666,7 +1666,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
||||
content=body,
|
||||
)
|
||||
|
||||
self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
|
||||
self.assertEqual(201, channel.code, msg=channel.json_body)
|
||||
self.assertEqual("@bob:test", channel.json_body["name"])
|
||||
self.assertEqual("msisdn", channel.json_body["threepids"][0]["medium"])
|
||||
self.assertEqual("1234567890", channel.json_body["threepids"][0]["address"])
|
||||
@@ -2407,7 +2407,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
||||
content={"password": "abc123"},
|
||||
)
|
||||
|
||||
self.assertEqual(HTTPStatus.CREATED, channel.code, msg=channel.json_body)
|
||||
self.assertEqual(201, channel.code, msg=channel.json_body)
|
||||
self.assertEqual("@bob:test", channel.json_body["name"])
|
||||
self.assertEqual("bob", channel.json_body["displayname"])
|
||||
|
||||
|
||||
@@ -11,10 +11,10 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from email.parser import Parser
|
||||
from http import HTTPStatus
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from unittest.mock import Mock
|
||||
|
||||
@@ -95,8 +95,10 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
|
||||
"""
|
||||
body = {"type": "m.login.password", "user": username, "password": password}
|
||||
|
||||
channel = self.make_request("POST", "/_matrix/client/r0/login", body)
|
||||
self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result)
|
||||
channel = self.make_request(
|
||||
"POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8")
|
||||
)
|
||||
self.assertEqual(channel.code, 403, channel.result)
|
||||
|
||||
def test_basic_password_reset(self) -> None:
|
||||
"""Test basic password reset flow"""
|
||||
@@ -345,7 +347,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
|
||||
shorthand=False,
|
||||
)
|
||||
|
||||
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
||||
self.assertEqual(200, channel.code, channel.result)
|
||||
|
||||
# Now POST to the same endpoint, mimicking the same behaviour as clicking the
|
||||
# password reset confirm button
|
||||
@@ -360,7 +362,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
|
||||
shorthand=False,
|
||||
content_is_form=True,
|
||||
)
|
||||
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
||||
self.assertEqual(200, channel.code, channel.result)
|
||||
|
||||
def _get_link_from_email(self) -> str:
|
||||
assert self.email_attempts, "No emails have been sent"
|
||||
@@ -388,7 +390,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
|
||||
new_password: str,
|
||||
session_id: str,
|
||||
client_secret: str,
|
||||
expected_code: int = HTTPStatus.OK,
|
||||
expected_code: int = 200,
|
||||
) -> None:
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
@@ -477,14 +479,16 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEqual(memberships[0].room_id, room_id, memberships)
|
||||
|
||||
def deactivate(self, user_id: str, tok: str) -> None:
|
||||
request_data = {
|
||||
"auth": {
|
||||
"type": "m.login.password",
|
||||
"user": user_id,
|
||||
"password": "test",
|
||||
},
|
||||
"erase": False,
|
||||
}
|
||||
request_data = json.dumps(
|
||||
{
|
||||
"auth": {
|
||||
"type": "m.login.password",
|
||||
"user": user_id,
|
||||
"password": "test",
|
||||
},
|
||||
"erase": False,
|
||||
}
|
||||
)
|
||||
channel = self.make_request(
|
||||
"POST", "account/deactivate", request_data, access_token=tok
|
||||
)
|
||||
@@ -711,9 +715,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
||||
},
|
||||
access_token=self.user_id_tok,
|
||||
)
|
||||
self.assertEqual(
|
||||
HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
|
||||
)
|
||||
self.assertEqual(400, channel.code, msg=channel.result["body"])
|
||||
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
|
||||
|
||||
# Get user
|
||||
@@ -723,7 +725,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
||||
access_token=self.user_id_tok,
|
||||
)
|
||||
|
||||
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
|
||||
self.assertEqual(200, channel.code, msg=channel.result["body"])
|
||||
self.assertFalse(channel.json_body["threepids"])
|
||||
|
||||
def test_delete_email(self) -> None:
|
||||
@@ -745,7 +747,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
||||
{"medium": "email", "address": self.email},
|
||||
access_token=self.user_id_tok,
|
||||
)
|
||||
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
|
||||
self.assertEqual(200, channel.code, msg=channel.result["body"])
|
||||
|
||||
# Get user
|
||||
channel = self.make_request(
|
||||
@@ -754,7 +756,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
||||
access_token=self.user_id_tok,
|
||||
)
|
||||
|
||||
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
|
||||
self.assertEqual(200, channel.code, msg=channel.result["body"])
|
||||
self.assertFalse(channel.json_body["threepids"])
|
||||
|
||||
def test_delete_email_if_disabled(self) -> None:
|
||||
@@ -779,9 +781,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
||||
access_token=self.user_id_tok,
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
|
||||
)
|
||||
self.assertEqual(400, channel.code, msg=channel.result["body"])
|
||||
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
|
||||
|
||||
# Get user
|
||||
@@ -791,7 +791,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
||||
access_token=self.user_id_tok,
|
||||
)
|
||||
|
||||
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
|
||||
self.assertEqual(200, channel.code, msg=channel.result["body"])
|
||||
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
|
||||
self.assertEqual(self.email, channel.json_body["threepids"][0]["address"])
|
||||
|
||||
@@ -817,9 +817,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
||||
},
|
||||
access_token=self.user_id_tok,
|
||||
)
|
||||
self.assertEqual(
|
||||
HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
|
||||
)
|
||||
self.assertEqual(400, channel.code, msg=channel.result["body"])
|
||||
self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"])
|
||||
|
||||
# Get user
|
||||
@@ -829,7 +827,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
||||
access_token=self.user_id_tok,
|
||||
)
|
||||
|
||||
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
|
||||
self.assertEqual(200, channel.code, msg=channel.result["body"])
|
||||
self.assertFalse(channel.json_body["threepids"])
|
||||
|
||||
def test_no_valid_token(self) -> None:
|
||||
@@ -854,9 +852,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
||||
},
|
||||
access_token=self.user_id_tok,
|
||||
)
|
||||
self.assertEqual(
|
||||
HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
|
||||
)
|
||||
self.assertEqual(400, channel.code, msg=channel.result["body"])
|
||||
self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"])
|
||||
|
||||
# Get user
|
||||
@@ -866,7 +862,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
||||
access_token=self.user_id_tok,
|
||||
)
|
||||
|
||||
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
|
||||
self.assertEqual(200, channel.code, msg=channel.result["body"])
|
||||
self.assertFalse(channel.json_body["threepids"])
|
||||
|
||||
@override_config({"next_link_domain_whitelist": None})
|
||||
@@ -876,7 +872,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
||||
"something@example.com",
|
||||
"some_secret",
|
||||
next_link="https://example.com/a/good/site",
|
||||
expect_code=HTTPStatus.OK,
|
||||
expect_code=200,
|
||||
)
|
||||
|
||||
@override_config({"next_link_domain_whitelist": None})
|
||||
@@ -888,7 +884,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
||||
"something@example.com",
|
||||
"some_secret",
|
||||
next_link="some-protocol://abcdefghijklmopqrstuvwxyz",
|
||||
expect_code=HTTPStatus.OK,
|
||||
expect_code=200,
|
||||
)
|
||||
|
||||
@override_config({"next_link_domain_whitelist": None})
|
||||
@@ -899,7 +895,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
||||
"something@example.com",
|
||||
"some_secret",
|
||||
next_link="file:///host/path",
|
||||
expect_code=HTTPStatus.BAD_REQUEST,
|
||||
expect_code=400,
|
||||
)
|
||||
|
||||
@override_config({"next_link_domain_whitelist": ["example.com", "example.org"]})
|
||||
@@ -911,28 +907,28 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
||||
"something@example.com",
|
||||
"some_secret",
|
||||
next_link=None,
|
||||
expect_code=HTTPStatus.OK,
|
||||
expect_code=200,
|
||||
)
|
||||
|
||||
self._request_token(
|
||||
"something@example.com",
|
||||
"some_secret",
|
||||
next_link="https://example.com/some/good/page",
|
||||
expect_code=HTTPStatus.OK,
|
||||
expect_code=200,
|
||||
)
|
||||
|
||||
self._request_token(
|
||||
"something@example.com",
|
||||
"some_secret",
|
||||
next_link="https://example.org/some/also/good/page",
|
||||
expect_code=HTTPStatus.OK,
|
||||
expect_code=200,
|
||||
)
|
||||
|
||||
self._request_token(
|
||||
"something@example.com",
|
||||
"some_secret",
|
||||
next_link="https://bad.example.org/some/bad/page",
|
||||
expect_code=HTTPStatus.BAD_REQUEST,
|
||||
expect_code=400,
|
||||
)
|
||||
|
||||
@override_config({"next_link_domain_whitelist": []})
|
||||
@@ -944,7 +940,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
||||
"something@example.com",
|
||||
"some_secret",
|
||||
next_link="https://example.com/a/page",
|
||||
expect_code=HTTPStatus.BAD_REQUEST,
|
||||
expect_code=400,
|
||||
)
|
||||
|
||||
def _request_token(
|
||||
@@ -952,7 +948,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
||||
email: str,
|
||||
client_secret: str,
|
||||
next_link: Optional[str] = None,
|
||||
expect_code: int = HTTPStatus.OK,
|
||||
expect_code: int = 200,
|
||||
) -> Optional[str]:
|
||||
"""Request a validation token to add an email address to a user's account
|
||||
|
||||
@@ -997,9 +993,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
||||
b"account/3pid/email/requestToken",
|
||||
{"client_secret": client_secret, "email": email, "send_attempt": 1},
|
||||
)
|
||||
self.assertEqual(
|
||||
HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
|
||||
)
|
||||
self.assertEqual(400, channel.code, msg=channel.result["body"])
|
||||
self.assertEqual(expected_errcode, channel.json_body["errcode"])
|
||||
self.assertEqual(expected_error, channel.json_body["error"])
|
||||
|
||||
@@ -1008,7 +1002,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
||||
path = link.replace("https://example.com", "")
|
||||
|
||||
channel = self.make_request("GET", path, shorthand=False)
|
||||
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
|
||||
self.assertEqual(200, channel.code, channel.result)
|
||||
|
||||
def _get_link_from_email(self) -> str:
|
||||
assert self.email_attempts, "No emails have been sent"
|
||||
@@ -1058,7 +1052,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
||||
access_token=self.user_id_tok,
|
||||
)
|
||||
|
||||
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
|
||||
self.assertEqual(200, channel.code, msg=channel.result["body"])
|
||||
|
||||
# Get user
|
||||
channel = self.make_request(
|
||||
@@ -1067,7 +1061,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
||||
access_token=self.user_id_tok,
|
||||
)
|
||||
|
||||
self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
|
||||
self.assertEqual(200, channel.code, msg=channel.result["body"])
|
||||
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
|
||||
|
||||
threepids = {threepid["address"] for threepid in channel.json_body["threepids"]}
|
||||
@@ -1098,7 +1092,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
|
||||
"""Tests that not providing any MXID raises an error."""
|
||||
self._test_status(
|
||||
users=None,
|
||||
expected_status_code=HTTPStatus.BAD_REQUEST,
|
||||
expected_status_code=400,
|
||||
expected_errcode=Codes.MISSING_PARAM,
|
||||
)
|
||||
|
||||
@@ -1106,7 +1100,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
|
||||
"""Tests that providing an invalid MXID raises an error."""
|
||||
self._test_status(
|
||||
users=["bad:test"],
|
||||
expected_status_code=HTTPStatus.BAD_REQUEST,
|
||||
expected_status_code=400,
|
||||
expected_errcode=Codes.INVALID_PARAM,
|
||||
)
|
||||
|
||||
@@ -1292,7 +1286,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
|
||||
def _test_status(
|
||||
self,
|
||||
users: Optional[List[str]],
|
||||
expected_status_code: int = HTTPStatus.OK,
|
||||
expected_status_code: int = 200,
|
||||
expected_statuses: Optional[Dict[str, Dict[str, bool]]] = None,
|
||||
expected_failures: Optional[List[str]] = None,
|
||||
expected_errcode: Optional[str] = None,
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
from http import HTTPStatus
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
@@ -96,7 +97,8 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
# We use deliberately a localpart under the length threshold so
|
||||
# that we can make sure that the check is done on the whole alias.
|
||||
request_data = {"room_alias_name": random_string(256 - len(self.hs.hostname))}
|
||||
data = {"room_alias_name": random_string(256 - len(self.hs.hostname))}
|
||||
request_data = json.dumps(data)
|
||||
channel = self.make_request(
|
||||
"POST", url, request_data, access_token=self.user_tok
|
||||
)
|
||||
@@ -108,7 +110,8 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
|
||||
# Check with an alias of allowed length. There should already be
|
||||
# a test that ensures it works in test_register.py, but let's be
|
||||
# as cautious as possible here.
|
||||
request_data = {"room_alias_name": random_string(5)}
|
||||
data = {"room_alias_name": random_string(5)}
|
||||
request_data = json.dumps(data)
|
||||
channel = self.make_request(
|
||||
"POST", url, request_data, access_token=self.user_tok
|
||||
)
|
||||
@@ -141,7 +144,8 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
# Add an alias for the room, as the appservice
|
||||
alias = RoomAlias(f"asns-{random_string(5)}", self.hs.hostname).to_string()
|
||||
request_data = {"room_id": self.room_id}
|
||||
data = {"room_id": self.room_id}
|
||||
request_data = json.dumps(data)
|
||||
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
@@ -189,7 +193,8 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
|
||||
self.hs.hostname,
|
||||
)
|
||||
|
||||
request_data = {"aliases": [self.random_alias(alias_length)]}
|
||||
data = {"aliases": [self.random_alias(alias_length)]}
|
||||
request_data = json.dumps(data)
|
||||
|
||||
channel = self.make_request(
|
||||
"PUT", url, request_data, access_token=self.user_tok
|
||||
@@ -201,7 +206,8 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
|
||||
) -> str:
|
||||
alias = self.random_alias(alias_length)
|
||||
url = "/_matrix/client/r0/directory/room/%s" % alias
|
||||
request_data = {"room_id": self.room_id}
|
||||
data = {"room_id": self.room_id}
|
||||
request_data = json.dumps(data)
|
||||
|
||||
channel = self.make_request(
|
||||
"PUT", url, request_data, access_token=self.user_tok
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
from http import HTTPStatus
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
@@ -50,11 +51,12 @@ class IdentityTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
||||
room_id = channel.json_body["room_id"]
|
||||
|
||||
request_data = {
|
||||
params = {
|
||||
"id_server": "testis",
|
||||
"medium": "email",
|
||||
"address": "test@example.com",
|
||||
}
|
||||
request_data = json.dumps(params)
|
||||
request_url = ("/rooms/%s/invite" % (room_id)).encode("ascii")
|
||||
channel = self.make_request(
|
||||
b"POST", request_url, request_data, access_token=tok
|
||||
|
||||
@@ -11,9 +11,9 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
import time
|
||||
import urllib.parse
|
||||
from http import HTTPStatus
|
||||
from typing import Any, Dict, List, Optional
|
||||
from unittest.mock import Mock
|
||||
from urllib.parse import urlencode
|
||||
@@ -261,20 +261,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||
}
|
||||
channel = self.make_request(b"POST", LOGIN_URL, params)
|
||||
|
||||
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
access_token = channel.json_body["access_token"]
|
||||
device_id = channel.json_body["device_id"]
|
||||
|
||||
# we should now be able to make requests with the access token
|
||||
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
|
||||
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
# time passes
|
||||
self.reactor.advance(24 * 3600)
|
||||
|
||||
# ... and we should be soft-logouted
|
||||
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
|
||||
self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
|
||||
self.assertEqual(channel.code, 401, channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
|
||||
self.assertEqual(channel.json_body["soft_logout"], True)
|
||||
|
||||
@@ -288,7 +288,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||
# more requests with the expired token should still return a soft-logout
|
||||
self.reactor.advance(3600)
|
||||
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
|
||||
self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
|
||||
self.assertEqual(channel.code, 401, channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
|
||||
self.assertEqual(channel.json_body["soft_logout"], True)
|
||||
|
||||
@@ -296,7 +296,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||
self._delete_device(access_token_2, "kermit", "monkey", device_id)
|
||||
|
||||
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
|
||||
self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
|
||||
self.assertEqual(channel.code, 401, channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
|
||||
self.assertEqual(channel.json_body["soft_logout"], False)
|
||||
|
||||
@@ -307,7 +307,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||
channel = self.make_request(
|
||||
b"DELETE", "devices/" + device_id, access_token=access_token
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
|
||||
self.assertEqual(channel.code, 401, channel.result)
|
||||
# check it's a UI-Auth fail
|
||||
self.assertEqual(
|
||||
set(channel.json_body.keys()),
|
||||
@@ -330,7 +330,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||
access_token=access_token,
|
||||
content={"auth": auth},
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
@override_config({"session_lifetime": "24h"})
|
||||
def test_session_can_hard_logout_after_being_soft_logged_out(self) -> None:
|
||||
@@ -341,14 +341,14 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
# we should now be able to make requests with the access token
|
||||
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
|
||||
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
# time passes
|
||||
self.reactor.advance(24 * 3600)
|
||||
|
||||
# ... and we should be soft-logouted
|
||||
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
|
||||
self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
|
||||
self.assertEqual(channel.code, 401, channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
|
||||
self.assertEqual(channel.json_body["soft_logout"], True)
|
||||
|
||||
@@ -367,14 +367,14 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
# we should now be able to make requests with the access token
|
||||
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
|
||||
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
# time passes
|
||||
self.reactor.advance(24 * 3600)
|
||||
|
||||
# ... and we should be soft-logouted
|
||||
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
|
||||
self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
|
||||
self.assertEqual(channel.code, 401, channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
|
||||
self.assertEqual(channel.json_body["soft_logout"], True)
|
||||
|
||||
@@ -399,7 +399,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
"/_matrix/client/v3/login",
|
||||
body,
|
||||
json.dumps(body).encode("utf8"),
|
||||
custom_headers=None,
|
||||
)
|
||||
|
||||
@@ -466,7 +466,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
|
||||
def test_get_login_flows(self) -> None:
|
||||
"""GET /login should return password and SSO flows"""
|
||||
channel = self.make_request("GET", "/_matrix/client/r0/login")
|
||||
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
expected_flow_types = [
|
||||
"m.login.cas",
|
||||
@@ -494,14 +494,14 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
|
||||
"""/login/sso/redirect should redirect to an identity picker"""
|
||||
# first hit the redirect url, which should redirect to our idp picker
|
||||
channel = self._make_sso_redirect_request(None)
|
||||
self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
|
||||
self.assertEqual(channel.code, 302, channel.result)
|
||||
location_headers = channel.headers.getRawHeaders("Location")
|
||||
assert location_headers
|
||||
uri = location_headers[0]
|
||||
|
||||
# hitting that picker should give us some HTML
|
||||
channel = self.make_request("GET", uri)
|
||||
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
# parse the form to check it has fields assumed elsewhere in this class
|
||||
html = channel.result["body"].decode("utf-8")
|
||||
@@ -530,7 +530,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
|
||||
+ "&idp=cas",
|
||||
shorthand=False,
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
|
||||
self.assertEqual(channel.code, 302, channel.result)
|
||||
location_headers = channel.headers.getRawHeaders("Location")
|
||||
assert location_headers
|
||||
cas_uri = location_headers[0]
|
||||
@@ -555,7 +555,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
|
||||
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
|
||||
+ "&idp=saml",
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
|
||||
self.assertEqual(channel.code, 302, channel.result)
|
||||
location_headers = channel.headers.getRawHeaders("Location")
|
||||
assert location_headers
|
||||
saml_uri = location_headers[0]
|
||||
@@ -579,7 +579,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
|
||||
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
|
||||
+ "&idp=oidc",
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
|
||||
self.assertEqual(channel.code, 302, channel.result)
|
||||
location_headers = channel.headers.getRawHeaders("Location")
|
||||
assert location_headers
|
||||
oidc_uri = location_headers[0]
|
||||
@@ -606,7 +606,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
|
||||
channel = self.helper.complete_oidc_auth(oidc_uri, cookies, {"sub": "user1"})
|
||||
|
||||
# that should serve a confirmation page
|
||||
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
content_type_headers = channel.headers.getRawHeaders("Content-Type")
|
||||
assert content_type_headers
|
||||
self.assertTrue(content_type_headers[-1].startswith("text/html"))
|
||||
@@ -634,7 +634,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
|
||||
"/login",
|
||||
content={"type": "m.login.token", "token": login_token},
|
||||
)
|
||||
self.assertEqual(chan.code, HTTPStatus.OK, chan.result)
|
||||
self.assertEqual(chan.code, 200, chan.result)
|
||||
self.assertEqual(chan.json_body["user_id"], "@user1:test")
|
||||
|
||||
def test_multi_sso_redirect_to_unknown(self) -> None:
|
||||
@@ -643,18 +643,18 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
|
||||
"GET",
|
||||
"/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz",
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
|
||||
self.assertEqual(channel.code, 400, channel.result)
|
||||
|
||||
def test_client_idp_redirect_to_unknown(self) -> None:
|
||||
"""If the client tries to pick an unknown IdP, return a 404"""
|
||||
channel = self._make_sso_redirect_request("xxx")
|
||||
self.assertEqual(channel.code, HTTPStatus.NOT_FOUND, channel.result)
|
||||
self.assertEqual(channel.code, 404, channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
|
||||
|
||||
def test_client_idp_redirect_to_oidc(self) -> None:
|
||||
"""If the client pick a known IdP, redirect to it"""
|
||||
channel = self._make_sso_redirect_request("oidc")
|
||||
self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
|
||||
self.assertEqual(channel.code, 302, channel.result)
|
||||
location_headers = channel.headers.getRawHeaders("Location")
|
||||
assert location_headers
|
||||
oidc_uri = location_headers[0]
|
||||
@@ -765,7 +765,7 @@ class CASTestCase(unittest.HomeserverTestCase):
|
||||
channel = self.make_request("GET", cas_ticket_url)
|
||||
|
||||
# Test that the response is HTML.
|
||||
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
content_type_header_value = ""
|
||||
for header in channel.result.get("headers", []):
|
||||
if header[0] == b"Content-Type":
|
||||
@@ -1246,7 +1246,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
|
||||
)
|
||||
|
||||
# that should redirect to the username picker
|
||||
self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
|
||||
self.assertEqual(channel.code, 302, channel.result)
|
||||
location_headers = channel.headers.getRawHeaders("Location")
|
||||
assert location_headers
|
||||
picker_url = location_headers[0]
|
||||
@@ -1290,7 +1290,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
|
||||
("Content-Length", str(len(content))),
|
||||
],
|
||||
)
|
||||
self.assertEqual(chan.code, HTTPStatus.FOUND, chan.result)
|
||||
self.assertEqual(chan.code, 302, chan.result)
|
||||
location_headers = chan.headers.getRawHeaders("Location")
|
||||
assert location_headers
|
||||
|
||||
@@ -1300,7 +1300,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
|
||||
path=location_headers[0],
|
||||
custom_headers=[("Cookie", "username_mapping_session=" + session_id)],
|
||||
)
|
||||
self.assertEqual(chan.code, HTTPStatus.FOUND, chan.result)
|
||||
self.assertEqual(chan.code, 302, chan.result)
|
||||
location_headers = chan.headers.getRawHeaders("Location")
|
||||
assert location_headers
|
||||
|
||||
@@ -1325,5 +1325,5 @@ class UsernamePickerTestCase(HomeserverTestCase):
|
||||
"/login",
|
||||
content={"type": "m.login.token", "token": login_token},
|
||||
)
|
||||
self.assertEqual(chan.code, HTTPStatus.OK, chan.result)
|
||||
self.assertEqual(chan.code, 200, chan.result)
|
||||
self.assertEqual(chan.json_body["user_id"], "@bobby:test")
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
from http import HTTPStatus
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
@@ -88,7 +89,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
def test_password_too_short(self) -> None:
|
||||
request_data = {"username": "kermit", "password": "shorty"}
|
||||
request_data = json.dumps({"username": "kermit", "password": "shorty"})
|
||||
channel = self.make_request("POST", self.register_url, request_data)
|
||||
|
||||
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
|
||||
@@ -99,7 +100,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
def test_password_no_digit(self) -> None:
|
||||
request_data = {"username": "kermit", "password": "longerpassword"}
|
||||
request_data = json.dumps({"username": "kermit", "password": "longerpassword"})
|
||||
channel = self.make_request("POST", self.register_url, request_data)
|
||||
|
||||
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
|
||||
@@ -110,7 +111,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
def test_password_no_symbol(self) -> None:
|
||||
request_data = {"username": "kermit", "password": "l0ngerpassword"}
|
||||
request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword"})
|
||||
channel = self.make_request("POST", self.register_url, request_data)
|
||||
|
||||
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
|
||||
@@ -121,7 +122,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
def test_password_no_uppercase(self) -> None:
|
||||
request_data = {"username": "kermit", "password": "l0ngerpassword!"}
|
||||
request_data = json.dumps({"username": "kermit", "password": "l0ngerpassword!"})
|
||||
channel = self.make_request("POST", self.register_url, request_data)
|
||||
|
||||
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
|
||||
@@ -132,7 +133,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
def test_password_no_lowercase(self) -> None:
|
||||
request_data = {"username": "kermit", "password": "L0NGERPASSWORD!"}
|
||||
request_data = json.dumps({"username": "kermit", "password": "L0NGERPASSWORD!"})
|
||||
channel = self.make_request("POST", self.register_url, request_data)
|
||||
|
||||
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
|
||||
@@ -143,7 +144,7 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
def test_password_compliant(self) -> None:
|
||||
request_data = {"username": "kermit", "password": "L0ngerpassword!"}
|
||||
request_data = json.dumps({"username": "kermit", "password": "L0ngerpassword!"})
|
||||
channel = self.make_request("POST", self.register_url, request_data)
|
||||
|
||||
# Getting a 401 here means the password has passed validation and the server has
|
||||
@@ -160,14 +161,16 @@ class PasswordPolicyTestCase(unittest.HomeserverTestCase):
|
||||
user_id = self.register_user("kermit", compliant_password)
|
||||
tok = self.login("kermit", compliant_password)
|
||||
|
||||
request_data = {
|
||||
"new_password": not_compliant_password,
|
||||
"auth": {
|
||||
"password": compliant_password,
|
||||
"type": LoginType.PASSWORD,
|
||||
"user": user_id,
|
||||
},
|
||||
}
|
||||
request_data = json.dumps(
|
||||
{
|
||||
"new_password": not_compliant_password,
|
||||
"auth": {
|
||||
"password": compliant_password,
|
||||
"type": LoginType.PASSWORD,
|
||||
"user": user_id,
|
||||
},
|
||||
}
|
||||
)
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
"/_matrix/client/r0/account/password",
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
@@ -61,10 +62,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
self.hs.get_datastores().main.services_cache.append(appservice)
|
||||
request_data = {
|
||||
"username": "as_user_kermit",
|
||||
"type": APP_SERVICE_REGISTRATION_TYPE,
|
||||
}
|
||||
request_data = json.dumps(
|
||||
{"username": "as_user_kermit", "type": APP_SERVICE_REGISTRATION_TYPE}
|
||||
)
|
||||
|
||||
channel = self.make_request(
|
||||
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
|
||||
@@ -85,7 +85,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
self.hs.get_datastores().main.services_cache.append(appservice)
|
||||
request_data = {"username": "as_user_kermit"}
|
||||
request_data = json.dumps({"username": "as_user_kermit"})
|
||||
|
||||
channel = self.make_request(
|
||||
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
|
||||
@@ -95,7 +95,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
def test_POST_appservice_registration_invalid(self) -> None:
|
||||
self.appservice = None # no application service exists
|
||||
request_data = {"username": "kermit", "type": APP_SERVICE_REGISTRATION_TYPE}
|
||||
request_data = json.dumps(
|
||||
{"username": "kermit", "type": APP_SERVICE_REGISTRATION_TYPE}
|
||||
)
|
||||
channel = self.make_request(
|
||||
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
|
||||
)
|
||||
@@ -103,14 +105,14 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEqual(channel.result["code"], b"401", channel.result)
|
||||
|
||||
def test_POST_bad_password(self) -> None:
|
||||
request_data = {"username": "kermit", "password": 666}
|
||||
request_data = json.dumps({"username": "kermit", "password": 666})
|
||||
channel = self.make_request(b"POST", self.url, request_data)
|
||||
|
||||
self.assertEqual(channel.result["code"], b"400", channel.result)
|
||||
self.assertEqual(channel.json_body["error"], "Invalid password")
|
||||
|
||||
def test_POST_bad_username(self) -> None:
|
||||
request_data = {"username": 777, "password": "monkey"}
|
||||
request_data = json.dumps({"username": 777, "password": "monkey"})
|
||||
channel = self.make_request(b"POST", self.url, request_data)
|
||||
|
||||
self.assertEqual(channel.result["code"], b"400", channel.result)
|
||||
@@ -119,12 +121,13 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||
def test_POST_user_valid(self) -> None:
|
||||
user_id = "@kermit:test"
|
||||
device_id = "frogfone"
|
||||
request_data = {
|
||||
params = {
|
||||
"username": "kermit",
|
||||
"password": "monkey",
|
||||
"device_id": device_id,
|
||||
"auth": {"type": LoginType.DUMMY},
|
||||
}
|
||||
request_data = json.dumps(params)
|
||||
channel = self.make_request(b"POST", self.url, request_data)
|
||||
|
||||
det_data = {
|
||||
@@ -137,7 +140,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
@override_config({"enable_registration": False})
|
||||
def test_POST_disabled_registration(self) -> None:
|
||||
request_data = {"username": "kermit", "password": "monkey"}
|
||||
request_data = json.dumps({"username": "kermit", "password": "monkey"})
|
||||
self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)
|
||||
|
||||
channel = self.make_request(b"POST", self.url, request_data)
|
||||
@@ -185,12 +188,13 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||
@override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}})
|
||||
def test_POST_ratelimiting(self) -> None:
|
||||
for i in range(0, 6):
|
||||
request_data = {
|
||||
params = {
|
||||
"username": "kermit" + str(i),
|
||||
"password": "monkey",
|
||||
"device_id": "frogfone",
|
||||
"auth": {"type": LoginType.DUMMY},
|
||||
}
|
||||
request_data = json.dumps(params)
|
||||
channel = self.make_request(b"POST", self.url, request_data)
|
||||
|
||||
if i == 5:
|
||||
@@ -230,7 +234,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||
}
|
||||
|
||||
# Request without auth to get flows and session
|
||||
channel = self.make_request(b"POST", self.url, params)
|
||||
channel = self.make_request(b"POST", self.url, json.dumps(params))
|
||||
self.assertEqual(channel.result["code"], b"401", channel.result)
|
||||
flows = channel.json_body["flows"]
|
||||
# Synapse adds a dummy stage to differentiate flows where otherwise one
|
||||
@@ -247,7 +251,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||
"token": token,
|
||||
"session": session,
|
||||
}
|
||||
channel = self.make_request(b"POST", self.url, params)
|
||||
request_data = json.dumps(params)
|
||||
channel = self.make_request(b"POST", self.url, request_data)
|
||||
self.assertEqual(channel.result["code"], b"401", channel.result)
|
||||
completed = channel.json_body["completed"]
|
||||
self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed)
|
||||
@@ -257,7 +262,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||
"type": LoginType.DUMMY,
|
||||
"session": session,
|
||||
}
|
||||
channel = self.make_request(b"POST", self.url, params)
|
||||
request_data = json.dumps(params)
|
||||
channel = self.make_request(b"POST", self.url, request_data)
|
||||
det_data = {
|
||||
"user_id": f"@{username}:{self.hs.hostname}",
|
||||
"home_server": self.hs.hostname,
|
||||
@@ -284,7 +290,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||
"password": "monkey",
|
||||
}
|
||||
# Request without auth to get session
|
||||
channel = self.make_request(b"POST", self.url, params)
|
||||
channel = self.make_request(b"POST", self.url, json.dumps(params))
|
||||
session = channel.json_body["session"]
|
||||
|
||||
# Test with token param missing (invalid)
|
||||
@@ -292,21 +298,21 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||
"type": LoginType.REGISTRATION_TOKEN,
|
||||
"session": session,
|
||||
}
|
||||
channel = self.make_request(b"POST", self.url, params)
|
||||
channel = self.make_request(b"POST", self.url, json.dumps(params))
|
||||
self.assertEqual(channel.result["code"], b"401", channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.MISSING_PARAM)
|
||||
self.assertEqual(channel.json_body["completed"], [])
|
||||
|
||||
# Test with non-string (invalid)
|
||||
params["auth"]["token"] = 1234
|
||||
channel = self.make_request(b"POST", self.url, params)
|
||||
channel = self.make_request(b"POST", self.url, json.dumps(params))
|
||||
self.assertEqual(channel.result["code"], b"401", channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
|
||||
self.assertEqual(channel.json_body["completed"], [])
|
||||
|
||||
# Test with unknown token (invalid)
|
||||
params["auth"]["token"] = "1234"
|
||||
channel = self.make_request(b"POST", self.url, params)
|
||||
channel = self.make_request(b"POST", self.url, json.dumps(params))
|
||||
self.assertEqual(channel.result["code"], b"401", channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED)
|
||||
self.assertEqual(channel.json_body["completed"], [])
|
||||
@@ -331,9 +337,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||
params1: JsonDict = {"username": "bert", "password": "monkey"}
|
||||
params2: JsonDict = {"username": "ernie", "password": "monkey"}
|
||||
# Do 2 requests without auth to get two session IDs
|
||||
channel1 = self.make_request(b"POST", self.url, params1)
|
||||
channel1 = self.make_request(b"POST", self.url, json.dumps(params1))
|
||||
session1 = channel1.json_body["session"]
|
||||
channel2 = self.make_request(b"POST", self.url, params2)
|
||||
channel2 = self.make_request(b"POST", self.url, json.dumps(params2))
|
||||
session2 = channel2.json_body["session"]
|
||||
|
||||
# Use token with session1 and check `pending` is 1
|
||||
@@ -342,9 +348,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||
"token": token,
|
||||
"session": session1,
|
||||
}
|
||||
self.make_request(b"POST", self.url, params1)
|
||||
self.make_request(b"POST", self.url, json.dumps(params1))
|
||||
# Repeat request to make sure pending isn't increased again
|
||||
self.make_request(b"POST", self.url, params1)
|
||||
self.make_request(b"POST", self.url, json.dumps(params1))
|
||||
pending = self.get_success(
|
||||
store.db_pool.simple_select_one_onecol(
|
||||
"registration_tokens",
|
||||
@@ -360,14 +366,14 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||
"token": token,
|
||||
"session": session2,
|
||||
}
|
||||
channel = self.make_request(b"POST", self.url, params2)
|
||||
channel = self.make_request(b"POST", self.url, json.dumps(params2))
|
||||
self.assertEqual(channel.result["code"], b"401", channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED)
|
||||
self.assertEqual(channel.json_body["completed"], [])
|
||||
|
||||
# Complete registration with session1
|
||||
params1["auth"]["type"] = LoginType.DUMMY
|
||||
self.make_request(b"POST", self.url, params1)
|
||||
self.make_request(b"POST", self.url, json.dumps(params1))
|
||||
# Check pending=0 and completed=1
|
||||
res = self.get_success(
|
||||
store.db_pool.simple_select_one(
|
||||
@@ -380,7 +386,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEqual(res["completed"], 1)
|
||||
|
||||
# Check auth still fails when using token with session2
|
||||
channel = self.make_request(b"POST", self.url, params2)
|
||||
channel = self.make_request(b"POST", self.url, json.dumps(params2))
|
||||
self.assertEqual(channel.result["code"], b"401", channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED)
|
||||
self.assertEqual(channel.json_body["completed"], [])
|
||||
@@ -405,7 +411,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
params: JsonDict = {"username": "kermit", "password": "monkey"}
|
||||
# Request without auth to get session
|
||||
channel = self.make_request(b"POST", self.url, params)
|
||||
channel = self.make_request(b"POST", self.url, json.dumps(params))
|
||||
session = channel.json_body["session"]
|
||||
|
||||
# Check authentication fails with expired token
|
||||
@@ -414,7 +420,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||
"token": token,
|
||||
"session": session,
|
||||
}
|
||||
channel = self.make_request(b"POST", self.url, params)
|
||||
channel = self.make_request(b"POST", self.url, json.dumps(params))
|
||||
self.assertEqual(channel.result["code"], b"401", channel.result)
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED)
|
||||
self.assertEqual(channel.json_body["completed"], [])
|
||||
@@ -429,7 +435,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
# Check authentication succeeds
|
||||
channel = self.make_request(b"POST", self.url, params)
|
||||
channel = self.make_request(b"POST", self.url, json.dumps(params))
|
||||
completed = channel.json_body["completed"]
|
||||
self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed)
|
||||
|
||||
@@ -454,9 +460,9 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||
# Do 2 requests without auth to get two session IDs
|
||||
params1: JsonDict = {"username": "bert", "password": "monkey"}
|
||||
params2: JsonDict = {"username": "ernie", "password": "monkey"}
|
||||
channel1 = self.make_request(b"POST", self.url, params1)
|
||||
channel1 = self.make_request(b"POST", self.url, json.dumps(params1))
|
||||
session1 = channel1.json_body["session"]
|
||||
channel2 = self.make_request(b"POST", self.url, params2)
|
||||
channel2 = self.make_request(b"POST", self.url, json.dumps(params2))
|
||||
session2 = channel2.json_body["session"]
|
||||
|
||||
# Use token with both sessions
|
||||
@@ -465,18 +471,18 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||
"token": token,
|
||||
"session": session1,
|
||||
}
|
||||
self.make_request(b"POST", self.url, params1)
|
||||
self.make_request(b"POST", self.url, json.dumps(params1))
|
||||
|
||||
params2["auth"] = {
|
||||
"type": LoginType.REGISTRATION_TOKEN,
|
||||
"token": token,
|
||||
"session": session2,
|
||||
}
|
||||
self.make_request(b"POST", self.url, params2)
|
||||
self.make_request(b"POST", self.url, json.dumps(params2))
|
||||
|
||||
# Complete registration with session1
|
||||
params1["auth"]["type"] = LoginType.DUMMY
|
||||
self.make_request(b"POST", self.url, params1)
|
||||
self.make_request(b"POST", self.url, json.dumps(params1))
|
||||
|
||||
# Check `result` of registration token stage for session1 is `True`
|
||||
result1 = self.get_success(
|
||||
@@ -544,7 +550,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
# Do request without auth to get a session ID
|
||||
params: JsonDict = {"username": "kermit", "password": "monkey"}
|
||||
channel = self.make_request(b"POST", self.url, params)
|
||||
channel = self.make_request(b"POST", self.url, json.dumps(params))
|
||||
session = channel.json_body["session"]
|
||||
|
||||
# Use token
|
||||
@@ -553,7 +559,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||
"token": token,
|
||||
"session": session,
|
||||
}
|
||||
self.make_request(b"POST", self.url, params)
|
||||
self.make_request(b"POST", self.url, json.dumps(params))
|
||||
|
||||
# Delete token
|
||||
self.get_success(
|
||||
@@ -821,7 +827,8 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
|
||||
admin_tok = self.login("admin", "adminpassword")
|
||||
|
||||
url = "/_synapse/admin/v1/account_validity/validity"
|
||||
request_data = {"user_id": user_id}
|
||||
params = {"user_id": user_id}
|
||||
request_data = json.dumps(params)
|
||||
channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
|
||||
self.assertEqual(channel.result["code"], b"200", channel.result)
|
||||
|
||||
@@ -838,11 +845,12 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
|
||||
admin_tok = self.login("admin", "adminpassword")
|
||||
|
||||
url = "/_synapse/admin/v1/account_validity/validity"
|
||||
request_data = {
|
||||
params = {
|
||||
"user_id": user_id,
|
||||
"expiration_ts": 0,
|
||||
"enable_renewal_emails": False,
|
||||
}
|
||||
request_data = json.dumps(params)
|
||||
channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
|
||||
self.assertEqual(channel.result["code"], b"200", channel.result)
|
||||
|
||||
@@ -862,11 +870,12 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
|
||||
admin_tok = self.login("admin", "adminpassword")
|
||||
|
||||
url = "/_synapse/admin/v1/account_validity/validity"
|
||||
request_data = {
|
||||
params = {
|
||||
"user_id": user_id,
|
||||
"expiration_ts": 0,
|
||||
"enable_renewal_emails": False,
|
||||
}
|
||||
request_data = json.dumps(params)
|
||||
channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
|
||||
self.assertEqual(channel.result["code"], b"200", channel.result)
|
||||
|
||||
@@ -1032,14 +1041,16 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
(user_id, tok) = self.create_user()
|
||||
|
||||
request_data = {
|
||||
"auth": {
|
||||
"type": "m.login.password",
|
||||
"user": user_id,
|
||||
"password": "monkey",
|
||||
},
|
||||
"erase": False,
|
||||
}
|
||||
request_data = json.dumps(
|
||||
{
|
||||
"auth": {
|
||||
"type": "m.login.password",
|
||||
"user": user_id,
|
||||
"password": "monkey",
|
||||
},
|
||||
"erase": False,
|
||||
}
|
||||
)
|
||||
channel = self.make_request(
|
||||
"POST", "account/deactivate", request_data, access_token=tok
|
||||
)
|
||||
|
||||
@@ -12,6 +12,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
import synapse.rest.admin
|
||||
@@ -75,7 +77,10 @@ class ReportEventTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
def _assert_status(self, response_status: int, data: JsonDict) -> None:
|
||||
channel = self.make_request(
|
||||
"POST", self.report_path, data, access_token=self.other_user_tok
|
||||
"POST",
|
||||
self.report_path,
|
||||
json.dumps(data),
|
||||
access_token=self.other_user_tok,
|
||||
)
|
||||
self.assertEqual(
|
||||
response_status, int(channel.result["code"]), msg=channel.result["body"]
|
||||
|
||||
+189
-230
File diff suppressed because it is too large
Load Diff
@@ -606,10 +606,11 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
|
||||
self._check_unread_count(1)
|
||||
|
||||
# Send a read receipt to tell the server we've read the latest event.
|
||||
body = json.dumps({ReceiptTypes.READ: res["event_id"]}).encode("utf8")
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
f"/rooms/{self.room_id}/read_markers",
|
||||
{ReceiptTypes.READ: res["event_id"]},
|
||||
body,
|
||||
access_token=self.tok,
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.json_body)
|
||||
|
||||
@@ -136,7 +136,7 @@ class RestHelper:
|
||||
self.site,
|
||||
"POST",
|
||||
path,
|
||||
content,
|
||||
json.dumps(content).encode("utf8"),
|
||||
custom_headers=custom_headers,
|
||||
)
|
||||
|
||||
@@ -210,7 +210,7 @@ class RestHelper:
|
||||
self.site,
|
||||
"POST",
|
||||
path,
|
||||
data,
|
||||
json.dumps(data).encode("utf8"),
|
||||
)
|
||||
|
||||
assert (
|
||||
@@ -309,7 +309,7 @@ class RestHelper:
|
||||
self.site,
|
||||
"PUT",
|
||||
path,
|
||||
data,
|
||||
json.dumps(data).encode("utf8"),
|
||||
)
|
||||
|
||||
assert (
|
||||
@@ -392,7 +392,7 @@ class RestHelper:
|
||||
self.site,
|
||||
"PUT",
|
||||
path,
|
||||
content or {},
|
||||
json.dumps(content or {}).encode("utf8"),
|
||||
custom_headers=custom_headers,
|
||||
)
|
||||
|
||||
|
||||
@@ -126,9 +126,7 @@ class _TestImage:
|
||||
expected_scaled: The expected bytes from scaled thumbnailing, or None if
|
||||
test should just check for a valid image returned.
|
||||
expected_found: True if the file should exist on the server, or False if
|
||||
a 404/400 is expected.
|
||||
unable_to_thumbnail: True if we expect the thumbnailing to fail (400), or
|
||||
False if the thumbnailing should succeed or a normal 404 is expected.
|
||||
a 404 is expected.
|
||||
"""
|
||||
|
||||
data: bytes
|
||||
@@ -137,7 +135,6 @@ class _TestImage:
|
||||
expected_cropped: Optional[bytes] = None
|
||||
expected_scaled: Optional[bytes] = None
|
||||
expected_found: bool = True
|
||||
unable_to_thumbnail: bool = False
|
||||
|
||||
|
||||
@parameterized_class(
|
||||
@@ -195,7 +192,6 @@ class _TestImage:
|
||||
b"image/gif",
|
||||
b".gif",
|
||||
expected_found=False,
|
||||
unable_to_thumbnail=True,
|
||||
),
|
||||
),
|
||||
],
|
||||
@@ -370,29 +366,18 @@ class MediaRepoTests(unittest.HomeserverTestCase):
|
||||
def test_thumbnail_crop(self) -> None:
|
||||
"""Test that a cropped remote thumbnail is available."""
|
||||
self._test_thumbnail(
|
||||
"crop",
|
||||
self.test_image.expected_cropped,
|
||||
expected_found=self.test_image.expected_found,
|
||||
unable_to_thumbnail=self.test_image.unable_to_thumbnail,
|
||||
"crop", self.test_image.expected_cropped, self.test_image.expected_found
|
||||
)
|
||||
|
||||
def test_thumbnail_scale(self) -> None:
|
||||
"""Test that a scaled remote thumbnail is available."""
|
||||
self._test_thumbnail(
|
||||
"scale",
|
||||
self.test_image.expected_scaled,
|
||||
expected_found=self.test_image.expected_found,
|
||||
unable_to_thumbnail=self.test_image.unable_to_thumbnail,
|
||||
"scale", self.test_image.expected_scaled, self.test_image.expected_found
|
||||
)
|
||||
|
||||
def test_invalid_type(self) -> None:
|
||||
"""An invalid thumbnail type is never available."""
|
||||
self._test_thumbnail(
|
||||
"invalid",
|
||||
None,
|
||||
expected_found=False,
|
||||
unable_to_thumbnail=self.test_image.unable_to_thumbnail,
|
||||
)
|
||||
self._test_thumbnail("invalid", None, False)
|
||||
|
||||
@unittest.override_config(
|
||||
{"thumbnail_sizes": [{"width": 32, "height": 32, "method": "scale"}]}
|
||||
@@ -401,12 +386,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
|
||||
"""
|
||||
Override the config to generate only scaled thumbnails, but request a cropped one.
|
||||
"""
|
||||
self._test_thumbnail(
|
||||
"crop",
|
||||
None,
|
||||
expected_found=False,
|
||||
unable_to_thumbnail=self.test_image.unable_to_thumbnail,
|
||||
)
|
||||
self._test_thumbnail("crop", None, False)
|
||||
|
||||
@unittest.override_config(
|
||||
{"thumbnail_sizes": [{"width": 32, "height": 32, "method": "crop"}]}
|
||||
@@ -415,22 +395,14 @@ class MediaRepoTests(unittest.HomeserverTestCase):
|
||||
"""
|
||||
Override the config to generate only cropped thumbnails, but request a scaled one.
|
||||
"""
|
||||
self._test_thumbnail(
|
||||
"scale",
|
||||
None,
|
||||
expected_found=False,
|
||||
unable_to_thumbnail=self.test_image.unable_to_thumbnail,
|
||||
)
|
||||
self._test_thumbnail("scale", None, False)
|
||||
|
||||
def test_thumbnail_repeated_thumbnail(self) -> None:
|
||||
"""Test that fetching the same thumbnail works, and deleting the on disk
|
||||
thumbnail regenerates it.
|
||||
"""
|
||||
self._test_thumbnail(
|
||||
"scale",
|
||||
self.test_image.expected_scaled,
|
||||
expected_found=self.test_image.expected_found,
|
||||
unable_to_thumbnail=self.test_image.unable_to_thumbnail,
|
||||
"scale", self.test_image.expected_scaled, self.test_image.expected_found
|
||||
)
|
||||
|
||||
if not self.test_image.expected_found:
|
||||
@@ -487,24 +459,8 @@ class MediaRepoTests(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
def _test_thumbnail(
|
||||
self,
|
||||
method: str,
|
||||
expected_body: Optional[bytes],
|
||||
expected_found: bool,
|
||||
unable_to_thumbnail: bool = False,
|
||||
self, method: str, expected_body: Optional[bytes], expected_found: bool
|
||||
) -> None:
|
||||
"""Test the given thumbnailing method works as expected.
|
||||
|
||||
Args:
|
||||
method: The thumbnailing method to use (crop, scale).
|
||||
expected_body: The expected bytes from thumbnailing, or None if
|
||||
test should just check for a valid image.
|
||||
expected_found: True if the file should exist on the server, or False if
|
||||
a 404/400 is expected.
|
||||
unable_to_thumbnail: True if we expect the thumbnailing to fail (400), or
|
||||
False if the thumbnailing should succeed or a normal 404 is expected.
|
||||
"""
|
||||
|
||||
params = "?width=32&height=32&method=" + method
|
||||
channel = make_request(
|
||||
self.reactor,
|
||||
@@ -540,16 +496,6 @@ class MediaRepoTests(unittest.HomeserverTestCase):
|
||||
else:
|
||||
# ensure that the result is at least some valid image
|
||||
Image.open(BytesIO(channel.result["body"]))
|
||||
elif unable_to_thumbnail:
|
||||
# A 400 with a JSON body.
|
||||
self.assertEqual(channel.code, 400)
|
||||
self.assertEqual(
|
||||
channel.json_body,
|
||||
{
|
||||
"errcode": "M_UNKNOWN",
|
||||
"error": "Cannot find any thumbnails for the requested media ([b'example.com', b'12345']). This might mean the media is not a supported_media_format=(image/jpeg, image/jpg, image/webp, image/gif, image/png) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)",
|
||||
},
|
||||
)
|
||||
else:
|
||||
# A 404 with a JSON body.
|
||||
self.assertEqual(channel.code, 404)
|
||||
|
||||
@@ -369,7 +369,7 @@ class StateStoreTestCase(HomeserverTestCase):
|
||||
state_dict_ids = cache_entry.value
|
||||
|
||||
self.assertEqual(cache_entry.full, False)
|
||||
self.assertEqual(cache_entry.known_absent, {(e1.type, e1.state_key)})
|
||||
self.assertEqual(cache_entry.known_absent, set())
|
||||
self.assertDictEqual(state_dict_ids, {(e1.type, e1.state_key): e1.event_id})
|
||||
|
||||
############################################
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user