Compare commits
29 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 3bd26733d7 | |||
| 90fa2026ba | |||
| 55ac419b63 | |||
| 047db4da1c | |||
| 88c4e7369d | |||
| a77f35144f | |||
| 1671f8772d | |||
| b4aad3604a | |||
| 51be04b918 | |||
| 4b6711803d | |||
| 87c230c27c | |||
| d56202b038 | |||
| 8533c8b03d | |||
| fb0ffa9676 | |||
| 9297d040a7 | |||
| 7e91107be1 | |||
| 1d11b452b7 | |||
| a511a890d7 | |||
| 61fd2a8f59 | |||
| 31b125ccec | |||
| 11282ade1d | |||
| 1fbe0316a9 | |||
| 106959b3cf | |||
| 2ffaf30803 | |||
| b4461e7d8a | |||
| 594a07ede4 | |||
| 1103c5fe8a | |||
| f3f0ab10fe | |||
| 6adb89ff00 |
@@ -21,7 +21,7 @@ python -m synapse.app.homeserver --generate-keys -c .ci/sqlite-config.yaml
|
||||
echo "--- Prepare test database"
|
||||
|
||||
# Make sure the SQLite3 database is using the latest schema and has no pending background update.
|
||||
scripts/update_synapse_database --database-config .ci/sqlite-config.yaml --run-background-updates
|
||||
update_synapse_database --database-config .ci/sqlite-config.yaml --run-background-updates
|
||||
|
||||
# Run the export-data command on the sqlite test database
|
||||
python -m synapse.app.admin_cmd -c .ci/sqlite-config.yaml export-data @anon-20191002_181700-832:localhost:8800 \
|
||||
@@ -41,7 +41,7 @@ fi
|
||||
|
||||
# Port the SQLite databse to postgres so we can check command works against postgres
|
||||
echo "+++ Port SQLite3 databse to postgres"
|
||||
scripts/synapse_port_db --sqlite-database .ci/test_db.db --postgres-config .ci/postgres-config.yaml
|
||||
synapse_port_db --sqlite-database .ci/test_db.db --postgres-config .ci/postgres-config.yaml
|
||||
|
||||
# Run the export-data command on postgres database
|
||||
python -m synapse.app.admin_cmd -c .ci/postgres-config.yaml export-data @anon-20191002_181700-832:localhost:8800 \
|
||||
|
||||
@@ -25,17 +25,19 @@ python -m synapse.app.homeserver --generate-keys -c .ci/sqlite-config.yaml
|
||||
echo "--- Prepare test database"
|
||||
|
||||
# Make sure the SQLite3 database is using the latest schema and has no pending background update.
|
||||
scripts/update_synapse_database --database-config .ci/sqlite-config.yaml --run-background-updates
|
||||
update_synapse_database --database-config .ci/sqlite-config.yaml --run-background-updates
|
||||
|
||||
# Create the PostgreSQL database.
|
||||
.ci/scripts/postgres_exec.py "CREATE DATABASE synapse"
|
||||
|
||||
echo "+++ Run synapse_port_db against test database"
|
||||
coverage run scripts/synapse_port_db --sqlite-database .ci/test_db.db --postgres-config .ci/postgres-config.yaml
|
||||
# TODO: this invocation of synapse_port_db (and others below) used to be prepended with `coverage run`,
|
||||
# but coverage seems unable to find the entrypoints installed by `pip install -e .`.
|
||||
synapse_port_db --sqlite-database .ci/test_db.db --postgres-config .ci/postgres-config.yaml
|
||||
|
||||
# We should be able to run twice against the same database.
|
||||
echo "+++ Run synapse_port_db a second time"
|
||||
coverage run scripts/synapse_port_db --sqlite-database .ci/test_db.db --postgres-config .ci/postgres-config.yaml
|
||||
synapse_port_db --sqlite-database .ci/test_db.db --postgres-config .ci/postgres-config.yaml
|
||||
|
||||
#####
|
||||
|
||||
@@ -46,7 +48,7 @@ echo "--- Prepare empty SQLite database"
|
||||
# we do this by deleting the sqlite db, and then doing the same again.
|
||||
rm .ci/test_db.db
|
||||
|
||||
scripts/update_synapse_database --database-config .ci/sqlite-config.yaml --run-background-updates
|
||||
update_synapse_database --database-config .ci/sqlite-config.yaml --run-background-updates
|
||||
|
||||
# re-create the PostgreSQL database.
|
||||
.ci/scripts/postgres_exec.py \
|
||||
@@ -54,4 +56,4 @@ scripts/update_synapse_database --database-config .ci/sqlite-config.yaml --run-b
|
||||
"CREATE DATABASE synapse"
|
||||
|
||||
echo "+++ Run synapse_port_db against empty database"
|
||||
coverage run scripts/synapse_port_db --sqlite-database .ci/test_db.db --postgres-config .ci/postgres-config.yaml
|
||||
synapse_port_db --sqlite-database .ci/test_db.db --postgres-config .ci/postgres-config.yaml
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
|
||||
# things to include
|
||||
!docker
|
||||
!scripts
|
||||
!synapse
|
||||
!MANIFEST.in
|
||||
!README.rst
|
||||
|
||||
@@ -31,7 +31,7 @@ jobs:
|
||||
# if we're running from a tag, get the full list of distros; otherwise just use debian:sid
|
||||
dists='["debian:sid"]'
|
||||
if [[ $GITHUB_REF == refs/tags/* ]]; then
|
||||
dists=$(scripts-dev/build_debian_packages --show-dists-json)
|
||||
dists=$(scripts-dev/build_debian_packages.py --show-dists-json)
|
||||
fi
|
||||
echo "::set-output name=distros::$dists"
|
||||
# map the step outputs to job outputs
|
||||
@@ -74,7 +74,7 @@ jobs:
|
||||
# see https://github.com/docker/build-push-action/issues/252
|
||||
# for the cache magic here
|
||||
run: |
|
||||
./src/scripts-dev/build_debian_packages \
|
||||
./src/scripts-dev/build_debian_packages.py \
|
||||
--docker-build-arg=--cache-from=type=local,src=/tmp/.buildx-cache \
|
||||
--docker-build-arg=--cache-to=type=local,mode=max,dest=/tmp/.buildx-cache-new \
|
||||
--docker-build-arg=--progress=plain \
|
||||
|
||||
@@ -16,7 +16,7 @@ jobs:
|
||||
- uses: actions/checkout@v2
|
||||
- uses: actions/setup-python@v2
|
||||
- run: pip install -e .
|
||||
- run: scripts-dev/generate_sample_config --check
|
||||
- run: scripts-dev/generate_sample_config.sh --check
|
||||
|
||||
lint:
|
||||
runs-on: ubuntu-latest
|
||||
@@ -51,7 +51,7 @@ jobs:
|
||||
fetch-depth: 0
|
||||
- uses: actions/setup-python@v2
|
||||
- run: "pip install 'towncrier>=18.6.0rc1'"
|
||||
- run: scripts-dev/check-newsfragment
|
||||
- run: scripts-dev/check-newsfragment.sh
|
||||
env:
|
||||
PULL_REQUEST_NUMBER: ${{ github.event.number }}
|
||||
|
||||
@@ -376,7 +376,7 @@ jobs:
|
||||
# Run Complement
|
||||
- run: |
|
||||
set -o pipefail
|
||||
go test -v -json -p 1 -tags synapse_blacklist,msc2403 ./tests/... 2>&1 | gotestfmt
|
||||
go test -v -json -p 1 -tags synapse_blacklist,msc2403,msc2716,msc3030 ./tests/... 2>&1 | gotestfmt
|
||||
shell: bash
|
||||
name: Run Complement Tests
|
||||
env:
|
||||
|
||||
@@ -17,7 +17,6 @@ recursive-include synapse/storage *.txt
|
||||
recursive-include synapse/storage *.md
|
||||
|
||||
recursive-include docs *
|
||||
recursive-include scripts *
|
||||
recursive-include scripts-dev *
|
||||
recursive-include synapse *.pyi
|
||||
recursive-include tests *.py
|
||||
@@ -53,5 +52,4 @@ prune contrib
|
||||
prune debian
|
||||
prune demo/etc
|
||||
prune docker
|
||||
prune snap
|
||||
prune stubs
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
Simplify the `ApplicationService` class' set of public methods related to interest checking.
|
||||
@@ -0,0 +1 @@
|
||||
Fix complexity checking config example in [Resource Constrained Devices](https://matrix-org.github.io/synapse/v1.54/other/running_synapse_on_single_board_computers.html) docs page.
|
||||
@@ -0,0 +1 @@
|
||||
Use the proper serialization format for bundled thread aggregations. The bug has existed since Synapse v1.48.0.
|
||||
@@ -0,0 +1 @@
|
||||
Limit the size of `aggregation_key` on annotations.
|
||||
@@ -0,0 +1 @@
|
||||
Add type hints to `tests/rest/client`.
|
||||
@@ -0,0 +1 @@
|
||||
Refactor the tests for event relations.
|
||||
@@ -0,0 +1 @@
|
||||
Move scripts to Synapse package and expose as setuptools entry points.
|
||||
@@ -0,0 +1 @@
|
||||
Fix data validation to compare to lists, not sequences.
|
||||
@@ -0,0 +1 @@
|
||||
Remove unused mocks from `test_typing`.
|
||||
@@ -0,0 +1 @@
|
||||
Give `scripts-dev` scripts suffixes for neater CI config.
|
||||
@@ -0,0 +1 @@
|
||||
Move the snapcraft configuration file to `contrib`.
|
||||
@@ -0,0 +1 @@
|
||||
Enable [MSC3030](https://github.com/matrix-org/matrix-doc/pull/3030) Complement tests in CI.
|
||||
@@ -0,0 +1 @@
|
||||
Enable [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) Complement tests in CI.
|
||||
@@ -0,0 +1 @@
|
||||
Add type hints to `tests/rest`.
|
||||
@@ -0,0 +1 @@
|
||||
Prune unused jobs from `tox` config.
|
||||
@@ -0,0 +1 @@
|
||||
Avoid generating state groups for local out-of-band leaves.
|
||||
@@ -0,0 +1 @@
|
||||
Avoid trying to calculate the state at outlier events.
|
||||
@@ -0,0 +1 @@
|
||||
Fix some type annotations.
|
||||
@@ -20,7 +20,7 @@ apps:
|
||||
generate-config:
|
||||
command: generate_config
|
||||
generate-signing-key:
|
||||
command: generate_signing_key.py
|
||||
command: generate_signing_key
|
||||
register-new-matrix-user:
|
||||
command: register_new_matrix_user
|
||||
plugs: [network]
|
||||
@@ -46,7 +46,6 @@ RUN \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Copy just what we need to pip install
|
||||
COPY scripts /synapse/scripts/
|
||||
COPY MANIFEST.in README.rst setup.py synctl /synapse/
|
||||
COPY synapse/__init__.py /synapse/synapse/__init__.py
|
||||
COPY synapse/python_dependencies.py /synapse/synapse/python_dependencies.py
|
||||
|
||||
+1
-1
@@ -172,6 +172,6 @@ frobber:
|
||||
```
|
||||
|
||||
Note that the sample configuration is generated from the synapse code
|
||||
and is maintained by a script, `scripts-dev/generate_sample_config`.
|
||||
and is maintained by a script, `scripts-dev/generate_sample_config.sh`.
|
||||
Making sure that the output from this script matches the desired format
|
||||
is left as an exercise for the reader!
|
||||
|
||||
@@ -158,9 +158,9 @@ same as integers.
|
||||
There are three separate aspects to this:
|
||||
|
||||
* Any new boolean column must be added to the `BOOLEAN_COLUMNS` list in
|
||||
`scripts/synapse_port_db`. This tells the port script to cast the integer
|
||||
value from SQLite to a boolean before writing the value to the postgres
|
||||
database.
|
||||
`synapse/_scripts/synapse_port_db.py`. This tells the port script to cast
|
||||
the integer value from SQLite to a boolean before writing the value to the
|
||||
postgres database.
|
||||
|
||||
* Before SQLite 3.23, `TRUE` and `FALSE` were not recognised as constants by
|
||||
SQLite, and the `IS [NOT] TRUE`/`IS [NOT] FALSE` operators were not
|
||||
|
||||
@@ -31,28 +31,29 @@ Anything that requires modifying the device list [#7721](https://github.com/matr
|
||||
Put the below in a new file at /etc/matrix-synapse/conf.d/sbc.yaml to override the defaults in homeserver.yaml.
|
||||
|
||||
```
|
||||
# Set to false to disable presence tracking on this homeserver.
|
||||
# Disable presence tracking, which is currently fairly resource intensive
|
||||
# More info: https://github.com/matrix-org/synapse/issues/9478
|
||||
use_presence: false
|
||||
|
||||
# When this is enabled, the room "complexity" will be checked before a user
|
||||
# joins a new remote room. If it is above the complexity limit, the server will
|
||||
# disallow joining, or will instantly leave.
|
||||
# Set a small complexity limit, preventing users from joining large rooms
|
||||
# which may be resource-intensive to remain a part of.
|
||||
#
|
||||
# Note that this will not prevent users from joining smaller rooms that
|
||||
# eventually become complex.
|
||||
limit_remote_rooms:
|
||||
# Uncomment to enable room complexity checking.
|
||||
#enabled: true
|
||||
enabled: true
|
||||
complexity: 3.0
|
||||
|
||||
# Database configuration
|
||||
database:
|
||||
# Use postgres for the best performance
|
||||
name: psycopg2
|
||||
args:
|
||||
user: matrix-synapse
|
||||
# Generate a long, secure one with a password manager
|
||||
# Generate a long, secure password using a password manager
|
||||
password: hunter2
|
||||
database: matrix-synapse
|
||||
host: localhost
|
||||
cp_min: 5
|
||||
cp_max: 10
|
||||
```
|
||||
|
||||
Currently the complexity is measured by [current_state_events / 500](https://github.com/matrix-org/synapse/blob/v1.20.1/synapse/storage/databases/main/events_worker.py#L986). You can find join times and your most complex rooms like this:
|
||||
|
||||
@@ -12,7 +12,7 @@ UPDATE users SET admin = 1 WHERE name = '@foo:bar.com';
|
||||
```
|
||||
|
||||
A new server admin user can also be created using the `register_new_matrix_user`
|
||||
command. This is a script that is located in the `scripts/` directory, or possibly
|
||||
command. This is a script that is distributed as part of synapse. It is possibly
|
||||
already on your `$PATH` depending on how Synapse was installed.
|
||||
|
||||
Finding your user's `access_token` is client-dependent, but will usually be shown in the client's settings.
|
||||
|
||||
@@ -11,7 +11,7 @@ local_partial_types = True
|
||||
no_implicit_optional = True
|
||||
|
||||
files =
|
||||
scripts-dev/sign_json,
|
||||
scripts-dev/,
|
||||
setup.py,
|
||||
synapse/,
|
||||
tests/
|
||||
@@ -23,6 +23,20 @@ files =
|
||||
# https://docs.python.org/3/library/re.html#re.X
|
||||
exclude = (?x)
|
||||
^(
|
||||
|scripts-dev/build_debian_packages.py
|
||||
|scripts-dev/check_signature.py
|
||||
|scripts-dev/definitions.py
|
||||
|scripts-dev/federation_client.py
|
||||
|scripts-dev/hash_history.py
|
||||
|scripts-dev/list_url_patterns.py
|
||||
|scripts-dev/release.py
|
||||
|scripts-dev/tail-synapse.py
|
||||
|
||||
|synapse/_scripts/export_signing_key.py
|
||||
|synapse/_scripts/move_remote_media_to_new_store.py
|
||||
|synapse/_scripts/synapse_port_db.py
|
||||
|synapse/_scripts/update_synapse_database.py
|
||||
|
||||
|synapse/storage/databases/__init__.py
|
||||
|synapse/storage/databases/main/__init__.py
|
||||
|synapse/storage/databases/main/cache.py
|
||||
@@ -74,15 +88,7 @@ exclude = (?x)
|
||||
|tests/push/test_http.py
|
||||
|tests/push/test_presentable_names.py
|
||||
|tests/push/test_push_rule_evaluator.py
|
||||
|tests/rest/client/test_account.py
|
||||
|tests/rest/client/test_filter.py
|
||||
|tests/rest/client/test_report_event.py
|
||||
|tests/rest/client/test_rooms.py
|
||||
|tests/rest/client/test_third_party_rules.py
|
||||
|tests/rest/client/test_transactions.py
|
||||
|tests/rest/client/test_typing.py
|
||||
|tests/rest/key/v2/test_remote_key_resource.py
|
||||
|tests/rest/media/v1/test_base.py
|
||||
|tests/rest/media/v1/test_media_storage.py
|
||||
|tests/rest/media/v1/test_url_preview.py
|
||||
|tests/scripts/test_new_matrix_user.py
|
||||
@@ -246,10 +252,7 @@ disallow_untyped_defs = True
|
||||
[mypy-tests.storage.test_user_directory]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
[mypy-tests.rest.admin.*]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
[mypy-tests.rest.client.*]
|
||||
[mypy-tests.rest.*]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
[mypy-tests.federation.transport.test_client]
|
||||
|
||||
@@ -71,4 +71,4 @@ fi
|
||||
|
||||
# Run the tests!
|
||||
echo "Images built; running complement"
|
||||
go test -v -tags synapse_blacklist,msc2403 -count=1 $EXTRA_COMPLEMENT_ARGS ./tests/...
|
||||
go test -v -tags synapse_blacklist,msc2403,msc2716,msc3030 -count=1 $EXTRA_COMPLEMENT_ARGS ./tests/...
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
#
|
||||
# Update/check the docs/sample_config.yaml
|
||||
|
||||
set -e
|
||||
|
||||
cd "$(dirname "$0")/.."
|
||||
|
||||
SAMPLE_CONFIG="docs/sample_config.yaml"
|
||||
SAMPLE_LOG_CONFIG="docs/sample_log_config.yaml"
|
||||
|
||||
check() {
|
||||
diff -u "$SAMPLE_LOG_CONFIG" <(./scripts/generate_log_config) >/dev/null || return 1
|
||||
}
|
||||
|
||||
if [ "$1" == "--check" ]; then
|
||||
diff -u "$SAMPLE_CONFIG" <(./scripts/generate_config --header-file docs/.sample_config_header.yaml) >/dev/null || {
|
||||
echo -e "\e[1m\e[31m$SAMPLE_CONFIG is not up-to-date. Regenerate it with \`scripts-dev/generate_sample_config\`.\e[0m" >&2
|
||||
exit 1
|
||||
}
|
||||
diff -u "$SAMPLE_LOG_CONFIG" <(./scripts/generate_log_config) >/dev/null || {
|
||||
echo -e "\e[1m\e[31m$SAMPLE_LOG_CONFIG is not up-to-date. Regenerate it with \`scripts-dev/generate_sample_config\`.\e[0m" >&2
|
||||
exit 1
|
||||
}
|
||||
else
|
||||
./scripts/generate_config --header-file docs/.sample_config_header.yaml -o "$SAMPLE_CONFIG"
|
||||
./scripts/generate_log_config -o "$SAMPLE_LOG_CONFIG"
|
||||
fi
|
||||
Executable
+28
@@ -0,0 +1,28 @@
|
||||
#!/usr/bin/env bash
|
||||
#
|
||||
# Update/check the docs/sample_config.yaml
|
||||
|
||||
set -e
|
||||
|
||||
cd "$(dirname "$0")/.."
|
||||
|
||||
SAMPLE_CONFIG="docs/sample_config.yaml"
|
||||
SAMPLE_LOG_CONFIG="docs/sample_log_config.yaml"
|
||||
|
||||
check() {
|
||||
diff -u "$SAMPLE_LOG_CONFIG" <(synapse/_scripts/generate_log_config.py) >/dev/null || return 1
|
||||
}
|
||||
|
||||
if [ "$1" == "--check" ]; then
|
||||
diff -u "$SAMPLE_CONFIG" <(synapse/_scripts/generate_config.py --header-file docs/.sample_config_header.yaml) >/dev/null || {
|
||||
echo -e "\e[1m\e[31m$SAMPLE_CONFIG is not up-to-date. Regenerate it with \`scripts-dev/generate_sample_config.sh\`.\e[0m" >&2
|
||||
exit 1
|
||||
}
|
||||
diff -u "$SAMPLE_LOG_CONFIG" <(synapse/_scripts/generate_log_config.py) >/dev/null || {
|
||||
echo -e "\e[1m\e[31m$SAMPLE_LOG_CONFIG is not up-to-date. Regenerate it with \`scripts-dev/generate_sample_config.sh\`.\e[0m" >&2
|
||||
exit 1
|
||||
}
|
||||
else
|
||||
synapse/_scripts/generate_config.py --header-file docs/.sample_config_header.yaml -o "$SAMPLE_CONFIG"
|
||||
synapse/_scripts/generate_log_config.py -o "$SAMPLE_LOG_CONFIG"
|
||||
fi
|
||||
@@ -84,16 +84,7 @@ else
|
||||
files=(
|
||||
"synapse" "docker" "tests"
|
||||
# annoyingly, black doesn't find these so we have to list them
|
||||
"scripts/export_signing_key"
|
||||
"scripts/generate_config"
|
||||
"scripts/generate_log_config"
|
||||
"scripts/hash_password"
|
||||
"scripts/register_new_matrix_user"
|
||||
"scripts/synapse_port_db"
|
||||
"scripts/update_synapse_database"
|
||||
"scripts-dev"
|
||||
"scripts-dev/build_debian_packages"
|
||||
"scripts-dev/sign_json"
|
||||
"contrib" "synctl" "setup.py" "synmark" "stubs" ".ci"
|
||||
)
|
||||
fi
|
||||
|
||||
@@ -147,7 +147,7 @@ python -m synapse.app.homeserver --generate-keys -c "$SQLITE_CONFIG"
|
||||
|
||||
# Make sure the SQLite3 database is using the latest schema and has no pending background update.
|
||||
echo "Running db background jobs..."
|
||||
scripts/update_synapse_database --database-config --run-background-updates "$SQLITE_CONFIG"
|
||||
synapse/_scripts/update_synapse_database.py --database-config --run-background-updates "$SQLITE_CONFIG"
|
||||
|
||||
# Create the PostgreSQL database.
|
||||
echo "Creating postgres database..."
|
||||
@@ -156,10 +156,10 @@ createdb --lc-collate=C --lc-ctype=C --template=template0 "$POSTGRES_DB_NAME"
|
||||
echo "Copying data from SQLite3 to Postgres with synapse_port_db..."
|
||||
if [ -z "$COVERAGE" ]; then
|
||||
# No coverage needed
|
||||
scripts/synapse_port_db --sqlite-database "$SQLITE_DB" --postgres-config "$POSTGRES_CONFIG"
|
||||
synapse/_scripts/synapse_port_db.py --sqlite-database "$SQLITE_DB" --postgres-config "$POSTGRES_CONFIG"
|
||||
else
|
||||
# Coverage desired
|
||||
coverage run scripts/synapse_port_db --sqlite-database "$SQLITE_DB" --postgres-config "$POSTGRES_CONFIG"
|
||||
coverage run synapse/_scripts/synapse_port_db.py --sqlite-database "$SQLITE_DB" --postgres-config "$POSTGRES_CONFIG"
|
||||
fi
|
||||
|
||||
# Delete schema_version, applied_schema_deltas and applied_module_schemas tables
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright 2015, 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse._scripts.register_new_matrix_user import main
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,19 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse._scripts.review_recent_signups import main
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,45 +0,0 @@
|
||||
#!/usr/bin/env perl
|
||||
|
||||
use strict;
|
||||
use warnings;
|
||||
|
||||
use JSON::XS;
|
||||
use LWP::UserAgent;
|
||||
use URI::Escape;
|
||||
|
||||
if (@ARGV < 4) {
|
||||
die "usage: $0 <homeserver url> <access_token> <room_id|room_alias> <group_id>\n";
|
||||
}
|
||||
|
||||
my ($hs, $access_token, $room_id, $group_id) = @ARGV;
|
||||
my $ua = LWP::UserAgent->new();
|
||||
$ua->timeout(10);
|
||||
|
||||
if ($room_id =~ /^#/) {
|
||||
$room_id = uri_escape($room_id);
|
||||
$room_id = decode_json($ua->get("${hs}/_matrix/client/r0/directory/room/${room_id}?access_token=${access_token}")->decoded_content)->{room_id};
|
||||
}
|
||||
|
||||
my $room_users = [ keys %{decode_json($ua->get("${hs}/_matrix/client/r0/rooms/${room_id}/joined_members?access_token=${access_token}")->decoded_content)->{joined}} ];
|
||||
my $group_users = [
|
||||
(map { $_->{user_id} } @{decode_json($ua->get("${hs}/_matrix/client/unstable/groups/${group_id}/users?access_token=${access_token}" )->decoded_content)->{chunk}}),
|
||||
(map { $_->{user_id} } @{decode_json($ua->get("${hs}/_matrix/client/unstable/groups/${group_id}/invited_users?access_token=${access_token}" )->decoded_content)->{chunk}}),
|
||||
];
|
||||
|
||||
die "refusing to sync from empty room" unless (@$room_users);
|
||||
die "refusing to sync to empty group" unless (@$group_users);
|
||||
|
||||
my $diff = {};
|
||||
foreach my $user (@$room_users) { $diff->{$user}++ }
|
||||
foreach my $user (@$group_users) { $diff->{$user}-- }
|
||||
|
||||
foreach my $user (keys %$diff) {
|
||||
if ($diff->{$user} == 1) {
|
||||
warn "inviting $user";
|
||||
print STDERR $ua->put("${hs}/_matrix/client/unstable/groups/${group_id}/admin/users/invite/${user}?access_token=${access_token}", Content=>'{}')->status_line."\n";
|
||||
}
|
||||
elsif ($diff->{$user} == -1) {
|
||||
warn "removing $user";
|
||||
print STDERR $ua->put("${hs}/_matrix/client/unstable/groups/${group_id}/admin/users/remove/${user}?access_token=${access_token}", Content=>'{}')->status_line."\n";
|
||||
}
|
||||
}
|
||||
@@ -15,7 +15,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import glob
|
||||
import os
|
||||
from typing import Any, Dict
|
||||
|
||||
@@ -153,8 +152,19 @@ setup(
|
||||
python_requires="~=3.7",
|
||||
entry_points={
|
||||
"console_scripts": [
|
||||
# Application
|
||||
"synapse_homeserver = synapse.app.homeserver:main",
|
||||
"synapse_worker = synapse.app.generic_worker:main",
|
||||
# Scripts
|
||||
"export_signing_key = synapse._scripts.export_signing_key:main",
|
||||
"generate_config = synapse._scripts.generate_config:main",
|
||||
"generate_log_config = synapse._scripts.generate_log_config:main",
|
||||
"generate_signing_key = synapse._scripts.generate_signing_key:main",
|
||||
"hash_password = synapse._scripts.hash_password:main",
|
||||
"register_new_matrix_user = synapse._scripts.register_new_matrix_user:main",
|
||||
"synapse_port_db = synapse._scripts.synapse_port_db:main",
|
||||
"synapse_review_recent_signups = synapse._scripts.review_recent_signups:main",
|
||||
"update_synapse_database = synapse._scripts.update_synapse_database:main",
|
||||
]
|
||||
},
|
||||
classifiers=[
|
||||
@@ -167,6 +177,6 @@ setup(
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
],
|
||||
scripts=["synctl"] + glob.glob("scripts/*"),
|
||||
scripts=["synctl"],
|
||||
cmdclass={"test": TestCommand},
|
||||
)
|
||||
|
||||
@@ -50,7 +50,7 @@ def format_for_config(public_key: nacl.signing.VerifyKey, expiry_ts: int):
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
@@ -85,7 +85,6 @@ if __name__ == "__main__":
|
||||
else format_plain
|
||||
)
|
||||
|
||||
keys = []
|
||||
for file in args.key_file:
|
||||
try:
|
||||
res = read_signing_keys(file)
|
||||
@@ -98,3 +97,7 @@ if __name__ == "__main__":
|
||||
res = []
|
||||
for key in res:
|
||||
formatter(get_verify_key(key))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -6,7 +6,8 @@ import sys
|
||||
|
||||
from synapse.config.homeserver import HomeServerConfig
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--config-dir",
|
||||
@@ -76,3 +77,7 @@ if __name__ == "__main__":
|
||||
shutil.copyfileobj(args.header_file, args.output_file)
|
||||
|
||||
args.output_file.write(conf)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -19,7 +19,8 @@ import sys
|
||||
|
||||
from synapse.config.logger import DEFAULT_LOG_CONFIG
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
@@ -42,3 +43,7 @@ if __name__ == "__main__":
|
||||
out = args.output_file
|
||||
out.write(DEFAULT_LOG_CONFIG.substitute(log_file=args.log_file))
|
||||
out.flush()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -19,7 +19,8 @@ from signedjson.key import generate_signing_key, write_signing_keys
|
||||
|
||||
from synapse.util.stringutils import random_string
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
@@ -34,3 +35,7 @@ if __name__ == "__main__":
|
||||
key_id = "a_" + random_string(4)
|
||||
key = (generate_signing_key(key_id),)
|
||||
write_signing_keys(args.output_file, key)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -8,9 +8,6 @@ import unicodedata
|
||||
import bcrypt
|
||||
import yaml
|
||||
|
||||
bcrypt_rounds = 12
|
||||
password_pepper = ""
|
||||
|
||||
|
||||
def prompt_for_pass():
|
||||
password = getpass.getpass("Password: ")
|
||||
@@ -26,7 +23,10 @@ def prompt_for_pass():
|
||||
return password
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def main():
|
||||
bcrypt_rounds = 12
|
||||
password_pepper = ""
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description=(
|
||||
"Calculate the hash of a new password, so that passwords can be reset"
|
||||
@@ -77,3 +77,7 @@ if __name__ == "__main__":
|
||||
).decode("ascii")
|
||||
|
||||
print(hashed)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
+1
-1
@@ -28,7 +28,7 @@ This can be extracted from postgres with::
|
||||
|
||||
To use, pipe the above into::
|
||||
|
||||
PYTHON_PATH=. ./scripts/move_remote_media_to_new_store.py <source repo> <dest repo>
|
||||
PYTHON_PATH=. synapse/_scripts/move_remote_media_to_new_store.py <source repo> <dest repo>
|
||||
"""
|
||||
|
||||
import argparse
|
||||
@@ -1146,7 +1146,7 @@ class TerminalProgress(Progress):
|
||||
##############################################
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="A script to port an existing synapse SQLite database to"
|
||||
" a new PostgreSQL database."
|
||||
@@ -1251,3 +1251,7 @@ if __name__ == "__main__":
|
||||
sys.stderr.write(end_error)
|
||||
|
||||
sys.exit(5)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,4 +1,5 @@
|
||||
# Copyright 2015, 2016 OpenMarket Ltd
|
||||
# Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -22,7 +23,7 @@ from netaddr import IPSet
|
||||
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.events import EventBase
|
||||
from synapse.types import GroupID, JsonDict, UserID, get_domain_from_id
|
||||
from synapse.types import DeviceLists, GroupID, JsonDict, UserID, get_domain_from_id
|
||||
from synapse.util.caches.descriptors import _CacheContext, cached
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -175,27 +176,14 @@ class ApplicationService:
|
||||
return namespace.exclusive
|
||||
return False
|
||||
|
||||
async def _matches_user(self, event: EventBase, store: "DataStore") -> bool:
|
||||
if self.is_interested_in_user(event.sender):
|
||||
return True
|
||||
|
||||
# also check m.room.member state key
|
||||
if event.type == EventTypes.Member and self.is_interested_in_user(
|
||||
event.state_key
|
||||
):
|
||||
return True
|
||||
|
||||
does_match = await self.matches_user_in_member_list(event.room_id, store)
|
||||
return does_match
|
||||
|
||||
@cached(num_args=1, cache_context=True)
|
||||
async def matches_user_in_member_list(
|
||||
async def _matches_user_in_member_list(
|
||||
self,
|
||||
room_id: str,
|
||||
store: "DataStore",
|
||||
cache_context: _CacheContext,
|
||||
) -> bool:
|
||||
"""Check if this service is interested a room based upon it's membership
|
||||
"""Check if this service is interested a room based upon its membership
|
||||
|
||||
Args:
|
||||
room_id: The room to check.
|
||||
@@ -214,47 +202,110 @@ class ApplicationService:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _matches_room_id(self, event: EventBase) -> bool:
|
||||
if hasattr(event, "room_id"):
|
||||
return self.is_interested_in_room(event.room_id)
|
||||
return False
|
||||
def is_interested_in_user(
|
||||
self,
|
||||
user_id: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Returns whether the application is interested in a given user ID.
|
||||
|
||||
async def _matches_aliases(self, event: EventBase, store: "DataStore") -> bool:
|
||||
alias_list = await store.get_aliases_for_room(event.room_id)
|
||||
The appservice is considered to be interested in a user if either: the
|
||||
user ID is in the appservice's user namespace, or if the user is the
|
||||
appservice's configured sender_localpart.
|
||||
|
||||
Args:
|
||||
user_id: The ID of the user to check.
|
||||
|
||||
Returns:
|
||||
True if the application service is interested in the user, False if not.
|
||||
"""
|
||||
return (
|
||||
# User is the appservice's sender_localpart user
|
||||
user_id == self.sender
|
||||
# User is in the appservice's user namespace
|
||||
or self.is_user_in_namespace(user_id)
|
||||
)
|
||||
|
||||
@cached(num_args=1, cache_context=True)
|
||||
async def is_interested_in_room(
|
||||
self,
|
||||
room_id: str,
|
||||
store: "DataStore",
|
||||
cache_context: _CacheContext,
|
||||
) -> bool:
|
||||
"""
|
||||
Returns whether the application service is interested in a given room ID.
|
||||
|
||||
The appservice is considered to be interested in the room if either: the ID or one
|
||||
of the aliases of the room is in the appservice's room ID or alias namespace
|
||||
respectively, or if one of the members of the room fall into the appservice's user
|
||||
namespace.
|
||||
|
||||
Args:
|
||||
room_id: The ID of the room to check.
|
||||
store: The homeserver's datastore class.
|
||||
|
||||
Returns:
|
||||
True if the application service is interested in the room, False if not.
|
||||
"""
|
||||
# Check if we have interest in this room ID
|
||||
if self.is_room_id_in_namespace(room_id):
|
||||
return True
|
||||
|
||||
# likewise with the room's aliases (if it has any)
|
||||
alias_list = await store.get_aliases_for_room(room_id)
|
||||
for alias in alias_list:
|
||||
if self.is_interested_in_alias(alias):
|
||||
if self.is_room_alias_in_namespace(alias):
|
||||
return True
|
||||
|
||||
return False
|
||||
# And finally, perform an expensive check on whether any of the
|
||||
# users in the room match the appservice's user namespace
|
||||
return await self._matches_user_in_member_list(
|
||||
room_id, store, on_invalidate=cache_context.invalidate
|
||||
)
|
||||
|
||||
async def is_interested(self, event: EventBase, store: "DataStore") -> bool:
|
||||
@cached(num_args=1, cache_context=True)
|
||||
async def is_interested_in_event(
|
||||
self,
|
||||
event_id: str,
|
||||
event: EventBase,
|
||||
store: "DataStore",
|
||||
cache_context: _CacheContext,
|
||||
) -> bool:
|
||||
"""Check if this service is interested in this event.
|
||||
|
||||
Args:
|
||||
event_id: The ID of the event to check. This is purely used for simplifying the
|
||||
caching of calls to this method.
|
||||
event: The event to check.
|
||||
store: The datastore to query.
|
||||
|
||||
Returns:
|
||||
True if this service would like to know about this event.
|
||||
True if this service would like to know about this event, otherwise False.
|
||||
"""
|
||||
# Do cheap checks first
|
||||
if self._matches_room_id(event):
|
||||
# Check if we're interested in this event's sender by namespace (or if they're the
|
||||
# sender_localpart user)
|
||||
if self.is_interested_in_user(event.sender):
|
||||
return True
|
||||
|
||||
# This will check the namespaces first before
|
||||
# checking the store, so should be run before _matches_aliases
|
||||
if await self._matches_user(event, store):
|
||||
# additionally, if this is a membership event, perform the same checks on
|
||||
# the user it references
|
||||
if event.type == EventTypes.Member and self.is_interested_in_user(
|
||||
event.state_key
|
||||
):
|
||||
return True
|
||||
|
||||
# This will check the store, so should be run last
|
||||
if await self._matches_aliases(event, store):
|
||||
# This will check the datastore, so should be run last
|
||||
if await self.is_interested_in_room(
|
||||
event.room_id, store, on_invalidate=cache_context.invalidate
|
||||
):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@cached(num_args=1)
|
||||
@cached(num_args=1, cache_context=True)
|
||||
async def is_interested_in_presence(
|
||||
self, user_id: UserID, store: "DataStore"
|
||||
self, user_id: UserID, store: "DataStore", cache_context: _CacheContext
|
||||
) -> bool:
|
||||
"""Check if this service is interested a user's presence
|
||||
|
||||
@@ -272,20 +323,19 @@ class ApplicationService:
|
||||
|
||||
# Then find out if the appservice is interested in any of those rooms
|
||||
for room_id in room_ids:
|
||||
if await self.matches_user_in_member_list(room_id, store):
|
||||
if await self.is_interested_in_room(
|
||||
room_id, store, on_invalidate=cache_context.invalidate
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
def is_interested_in_user(self, user_id: str) -> bool:
|
||||
return (
|
||||
bool(self._matches_regex(ApplicationService.NS_USERS, user_id))
|
||||
or user_id == self.sender
|
||||
)
|
||||
def is_user_in_namespace(self, user_id: str) -> bool:
|
||||
return bool(self._matches_regex(ApplicationService.NS_USERS, user_id))
|
||||
|
||||
def is_interested_in_alias(self, alias: str) -> bool:
|
||||
def is_room_alias_in_namespace(self, alias: str) -> bool:
|
||||
return bool(self._matches_regex(ApplicationService.NS_ALIASES, alias))
|
||||
|
||||
def is_interested_in_room(self, room_id: str) -> bool:
|
||||
def is_room_id_in_namespace(self, room_id: str) -> bool:
|
||||
return bool(self._matches_regex(ApplicationService.NS_ROOMS, room_id))
|
||||
|
||||
def is_exclusive_user(self, user_id: str) -> bool:
|
||||
@@ -351,6 +401,7 @@ class AppServiceTransaction:
|
||||
to_device_messages: List[JsonDict],
|
||||
one_time_key_counts: TransactionOneTimeKeyCounts,
|
||||
unused_fallback_keys: TransactionUnusedFallbackKeys,
|
||||
device_list_summary: DeviceLists,
|
||||
):
|
||||
self.service = service
|
||||
self.id = id
|
||||
@@ -359,6 +410,7 @@ class AppServiceTransaction:
|
||||
self.to_device_messages = to_device_messages
|
||||
self.one_time_key_counts = one_time_key_counts
|
||||
self.unused_fallback_keys = unused_fallback_keys
|
||||
self.device_list_summary = device_list_summary
|
||||
|
||||
async def send(self, as_api: "ApplicationServiceApi") -> bool:
|
||||
"""Sends this transaction using the provided AS API interface.
|
||||
@@ -375,6 +427,7 @@ class AppServiceTransaction:
|
||||
to_device_messages=self.to_device_messages,
|
||||
one_time_key_counts=self.one_time_key_counts,
|
||||
unused_fallback_keys=self.unused_fallback_keys,
|
||||
device_list_summary=self.device_list_summary,
|
||||
txn_id=self.id,
|
||||
)
|
||||
|
||||
|
||||
+22
-12
@@ -1,4 +1,5 @@
|
||||
# Copyright 2015, 2016 OpenMarket Ltd
|
||||
# Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -25,9 +26,9 @@ from synapse.appservice import (
|
||||
TransactionUnusedFallbackKeys,
|
||||
)
|
||||
from synapse.events import EventBase
|
||||
from synapse.events.utils import serialize_event
|
||||
from synapse.events.utils import SerializeEventConfig, serialize_event
|
||||
from synapse.http.client import SimpleHttpClient
|
||||
from synapse.types import JsonDict, ThirdPartyInstanceID
|
||||
from synapse.types import DeviceLists, JsonDict, ThirdPartyInstanceID
|
||||
from synapse.util.caches.response_cache import ResponseCache
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -225,6 +226,7 @@ class ApplicationServiceApi(SimpleHttpClient):
|
||||
to_device_messages: List[JsonDict],
|
||||
one_time_key_counts: TransactionOneTimeKeyCounts,
|
||||
unused_fallback_keys: TransactionUnusedFallbackKeys,
|
||||
device_list_summary: DeviceLists,
|
||||
txn_id: Optional[int] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
@@ -268,6 +270,7 @@ class ApplicationServiceApi(SimpleHttpClient):
|
||||
}
|
||||
)
|
||||
|
||||
# TODO: Update to stable prefixes once MSC3202 completes FCP merge
|
||||
if service.msc3202_transaction_extensions:
|
||||
if one_time_key_counts:
|
||||
body[
|
||||
@@ -277,6 +280,11 @@ class ApplicationServiceApi(SimpleHttpClient):
|
||||
body[
|
||||
"org.matrix.msc3202.device_unused_fallback_keys"
|
||||
] = unused_fallback_keys
|
||||
if device_list_summary:
|
||||
body["org.matrix.msc3202.device_lists"] = {
|
||||
"changed": list(device_list_summary.changed),
|
||||
"left": list(device_list_summary.left),
|
||||
}
|
||||
|
||||
try:
|
||||
await self.put_json(
|
||||
@@ -321,16 +329,18 @@ class ApplicationServiceApi(SimpleHttpClient):
|
||||
serialize_event(
|
||||
e,
|
||||
time_now,
|
||||
as_client_event=True,
|
||||
# If this is an invite or a knock membership event, and we're interested
|
||||
# in this user, then include any stripped state alongside the event.
|
||||
include_stripped_room_state=(
|
||||
e.type == EventTypes.Member
|
||||
and (
|
||||
e.membership == Membership.INVITE
|
||||
or e.membership == Membership.KNOCK
|
||||
)
|
||||
and service.is_interested_in_user(e.state_key)
|
||||
config=SerializeEventConfig(
|
||||
as_client_event=True,
|
||||
# If this is an invite or a knock membership event, and we're interested
|
||||
# in this user, then include any stripped state alongside the event.
|
||||
include_stripped_room_state=(
|
||||
e.type == EventTypes.Member
|
||||
and (
|
||||
e.membership == Membership.INVITE
|
||||
or e.membership == Membership.KNOCK
|
||||
)
|
||||
and service.is_interested_in_user(e.state_key)
|
||||
),
|
||||
),
|
||||
)
|
||||
for e in events
|
||||
|
||||
@@ -72,7 +72,7 @@ from synapse.events import EventBase
|
||||
from synapse.logging.context import run_in_background
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.storage.databases.main import DataStore
|
||||
from synapse.types import JsonDict
|
||||
from synapse.types import DeviceLists, JsonDict
|
||||
from synapse.util import Clock
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -122,6 +122,7 @@ class ApplicationServiceScheduler:
|
||||
events: Optional[Collection[EventBase]] = None,
|
||||
ephemeral: Optional[Collection[JsonDict]] = None,
|
||||
to_device_messages: Optional[Collection[JsonDict]] = None,
|
||||
device_list_summary: Optional[DeviceLists] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Enqueue some data to be sent off to an application service.
|
||||
@@ -133,10 +134,18 @@ class ApplicationServiceScheduler:
|
||||
to_device_messages: The to-device messages to send. These differ from normal
|
||||
to-device messages sent to clients, as they have 'to_device_id' and
|
||||
'to_user_id' fields.
|
||||
device_list_summary: A summary of users that the application service either needs
|
||||
to refresh the device lists of, or those that the application service need no
|
||||
longer track the device lists of.
|
||||
"""
|
||||
# We purposefully allow this method to run with empty events/ephemeral
|
||||
# collections, so that callers do not need to check iterable size themselves.
|
||||
if not events and not ephemeral and not to_device_messages:
|
||||
if (
|
||||
not events
|
||||
and not ephemeral
|
||||
and not to_device_messages
|
||||
and not device_list_summary
|
||||
):
|
||||
return
|
||||
|
||||
if events:
|
||||
@@ -147,6 +156,10 @@ class ApplicationServiceScheduler:
|
||||
self.queuer.queued_to_device_messages.setdefault(appservice.id, []).extend(
|
||||
to_device_messages
|
||||
)
|
||||
if device_list_summary:
|
||||
self.queuer.queued_device_list_summaries.setdefault(
|
||||
appservice.id, []
|
||||
).append(device_list_summary)
|
||||
|
||||
# Kick off a new application service transaction
|
||||
self.queuer.start_background_request(appservice)
|
||||
@@ -169,6 +182,8 @@ class _ServiceQueuer:
|
||||
self.queued_ephemeral: Dict[str, List[JsonDict]] = {}
|
||||
# dict of {service_id: [to_device_message_json]}
|
||||
self.queued_to_device_messages: Dict[str, List[JsonDict]] = {}
|
||||
# dict of {service_id: [device_list_summary]}
|
||||
self.queued_device_list_summaries: Dict[str, List[DeviceLists]] = {}
|
||||
|
||||
# the appservices which currently have a transaction in flight
|
||||
self.requests_in_flight: Set[str] = set()
|
||||
@@ -212,7 +227,40 @@ class _ServiceQueuer:
|
||||
]
|
||||
del all_to_device_messages[:MAX_TO_DEVICE_MESSAGES_PER_TRANSACTION]
|
||||
|
||||
if not events and not ephemeral and not to_device_messages_to_send:
|
||||
# Consolidate any pending device list summaries into a single, up-to-date
|
||||
# summary.
|
||||
# Note: this code assumes that in a single DeviceLists, a user will
|
||||
# never be in both "changed" and "left" sets.
|
||||
device_list_summary = DeviceLists()
|
||||
while self.queued_device_list_summaries.get(service.id, []):
|
||||
# Pop a summary off the front of the queue
|
||||
summary = self.queued_device_list_summaries[service.id].pop(0)
|
||||
|
||||
# For every user in the incoming "changed" set:
|
||||
# * Remove them from the existing "left" set if necessary
|
||||
# (as we need to start tracking them again)
|
||||
# * Add them to the existing "changed" set if necessary.
|
||||
for user_id in summary.changed:
|
||||
if user_id in device_list_summary.left:
|
||||
device_list_summary.left.remove(user_id)
|
||||
device_list_summary.changed.add(user_id)
|
||||
|
||||
# For every user in the incoming "left" set:
|
||||
# * Remove them from the existing "changed" set if necessary
|
||||
# (we no longer need to track them)
|
||||
# * Add them to the existing "left" set if necessary.
|
||||
for user_id in summary.left:
|
||||
if user_id in device_list_summary.changed:
|
||||
device_list_summary.changed.remove(user_id)
|
||||
device_list_summary.left.add(user_id)
|
||||
|
||||
if (
|
||||
not events
|
||||
and not ephemeral
|
||||
and not to_device_messages_to_send
|
||||
# Note that DeviceLists implements __bool__
|
||||
and not device_list_summary
|
||||
):
|
||||
return
|
||||
|
||||
one_time_key_counts: Optional[TransactionOneTimeKeyCounts] = None
|
||||
@@ -240,6 +288,7 @@ class _ServiceQueuer:
|
||||
to_device_messages_to_send,
|
||||
one_time_key_counts,
|
||||
unused_fallback_keys,
|
||||
device_list_summary,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("AS request failed")
|
||||
@@ -322,6 +371,7 @@ class _TransactionController:
|
||||
to_device_messages: Optional[List[JsonDict]] = None,
|
||||
one_time_key_counts: Optional[TransactionOneTimeKeyCounts] = None,
|
||||
unused_fallback_keys: Optional[TransactionUnusedFallbackKeys] = None,
|
||||
device_list_summary: Optional[DeviceLists] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Create a transaction with the given data and send to the provided
|
||||
@@ -336,6 +386,7 @@ class _TransactionController:
|
||||
appservice devices in the transaction.
|
||||
unused_fallback_keys: Lists of unused fallback keys for relevant
|
||||
appservice devices in the transaction.
|
||||
device_list_summary: The device list summary to include in the transaction.
|
||||
"""
|
||||
try:
|
||||
txn = await self.store.create_appservice_txn(
|
||||
@@ -345,6 +396,7 @@ class _TransactionController:
|
||||
to_device_messages=to_device_messages or [],
|
||||
one_time_key_counts=one_time_key_counts or {},
|
||||
unused_fallback_keys=unused_fallback_keys or {},
|
||||
device_list_summary=device_list_summary or DeviceLists(),
|
||||
)
|
||||
service_is_up = await self._is_service_up(service)
|
||||
if service_is_up:
|
||||
|
||||
@@ -383,7 +383,7 @@ class RootConfig:
|
||||
Build a default configuration file
|
||||
|
||||
This is used when the user explicitly asks us to generate a config file
|
||||
(eg with --generate_config).
|
||||
(eg with --generate-config).
|
||||
|
||||
Args:
|
||||
config_dir_path: The path where the config files are kept. Used to
|
||||
|
||||
@@ -170,6 +170,7 @@ def _load_appservice(
|
||||
# When enabled, appservice transactions contain the following information:
|
||||
# - device One-Time Key counts
|
||||
# - device unused fallback key usage states
|
||||
# - device list changes
|
||||
msc3202_transaction_extensions = as_info.get("org.matrix.msc3202", False)
|
||||
if not isinstance(msc3202_transaction_extensions, bool):
|
||||
raise ValueError(
|
||||
|
||||
@@ -59,8 +59,9 @@ class ExperimentalConfig(Config):
|
||||
"msc3202_device_masquerading", False
|
||||
)
|
||||
|
||||
# Portion of MSC3202 related to transaction extensions:
|
||||
# sending one-time key counts and fallback key usage to application services.
|
||||
# The portion of MSC3202 related to transaction extensions:
|
||||
# sending device list changes, one-time key counts and fallback key
|
||||
# usage to application services.
|
||||
self.msc3202_transaction_extensions: bool = experimental.get(
|
||||
"msc3202_transaction_extensions", False
|
||||
)
|
||||
|
||||
+54
-27
@@ -26,6 +26,7 @@ from typing import (
|
||||
Union,
|
||||
)
|
||||
|
||||
import attr
|
||||
from frozendict import frozendict
|
||||
|
||||
from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
|
||||
@@ -303,29 +304,37 @@ def format_event_for_client_v2_without_room_id(d: JsonDict) -> JsonDict:
|
||||
return d
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||
class SerializeEventConfig:
|
||||
as_client_event: bool = True
|
||||
# Function to convert from federation format to client format
|
||||
event_format: Callable[[JsonDict], JsonDict] = format_event_for_client_v1
|
||||
# ID of the user's auth token - used for namespacing of transaction IDs
|
||||
token_id: Optional[int] = None
|
||||
# List of event fields to include. If empty, all fields will be returned.
|
||||
only_event_fields: Optional[List[str]] = None
|
||||
# Some events can have stripped room state stored in the `unsigned` field.
|
||||
# This is required for invite and knock functionality. If this option is
|
||||
# False, that state will be removed from the event before it is returned.
|
||||
# Otherwise, it will be kept.
|
||||
include_stripped_room_state: bool = False
|
||||
|
||||
|
||||
_DEFAULT_SERIALIZE_EVENT_CONFIG = SerializeEventConfig()
|
||||
|
||||
|
||||
def serialize_event(
|
||||
e: Union[JsonDict, EventBase],
|
||||
time_now_ms: int,
|
||||
*,
|
||||
as_client_event: bool = True,
|
||||
event_format: Callable[[JsonDict], JsonDict] = format_event_for_client_v1,
|
||||
token_id: Optional[str] = None,
|
||||
only_event_fields: Optional[List[str]] = None,
|
||||
include_stripped_room_state: bool = False,
|
||||
config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG,
|
||||
) -> JsonDict:
|
||||
"""Serialize event for clients
|
||||
|
||||
Args:
|
||||
e
|
||||
time_now_ms
|
||||
as_client_event
|
||||
event_format
|
||||
token_id
|
||||
only_event_fields
|
||||
include_stripped_room_state: Some events can have stripped room state
|
||||
stored in the `unsigned` field. This is required for invite and knock
|
||||
functionality. If this option is False, that state will be removed from the
|
||||
event before it is returned. Otherwise, it will be kept.
|
||||
config: Event serialization config
|
||||
|
||||
Returns:
|
||||
The serialized event dictionary.
|
||||
@@ -348,11 +357,11 @@ def serialize_event(
|
||||
|
||||
if "redacted_because" in e.unsigned:
|
||||
d["unsigned"]["redacted_because"] = serialize_event(
|
||||
e.unsigned["redacted_because"], time_now_ms, event_format=event_format
|
||||
e.unsigned["redacted_because"], time_now_ms, config=config
|
||||
)
|
||||
|
||||
if token_id is not None:
|
||||
if token_id == getattr(e.internal_metadata, "token_id", None):
|
||||
if config.token_id is not None:
|
||||
if config.token_id == getattr(e.internal_metadata, "token_id", None):
|
||||
txn_id = getattr(e.internal_metadata, "txn_id", None)
|
||||
if txn_id is not None:
|
||||
d["unsigned"]["transaction_id"] = txn_id
|
||||
@@ -361,13 +370,14 @@ def serialize_event(
|
||||
# that are meant to provide metadata about a room to an invitee/knocker. They are
|
||||
# intended to only be included in specific circumstances, such as down sync, and
|
||||
# should not be included in any other case.
|
||||
if not include_stripped_room_state:
|
||||
if not config.include_stripped_room_state:
|
||||
d["unsigned"].pop("invite_room_state", None)
|
||||
d["unsigned"].pop("knock_room_state", None)
|
||||
|
||||
if as_client_event:
|
||||
d = event_format(d)
|
||||
if config.as_client_event:
|
||||
d = config.event_format(d)
|
||||
|
||||
only_event_fields = config.only_event_fields
|
||||
if only_event_fields:
|
||||
if not isinstance(only_event_fields, list) or not all(
|
||||
isinstance(f, str) for f in only_event_fields
|
||||
@@ -390,18 +400,18 @@ class EventClientSerializer:
|
||||
event: Union[JsonDict, EventBase],
|
||||
time_now: int,
|
||||
*,
|
||||
config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG,
|
||||
bundle_aggregations: Optional[Dict[str, "BundledAggregations"]] = None,
|
||||
**kwargs: Any,
|
||||
) -> JsonDict:
|
||||
"""Serializes a single event.
|
||||
|
||||
Args:
|
||||
event: The event being serialized.
|
||||
time_now: The current time in milliseconds
|
||||
config: Event serialization config
|
||||
bundle_aggregations: Whether to include the bundled aggregations for this
|
||||
event. Only applies to non-state events. (State events never include
|
||||
bundled aggregations.)
|
||||
**kwargs: Arguments to pass to `serialize_event`
|
||||
|
||||
Returns:
|
||||
The serialized event
|
||||
@@ -410,7 +420,7 @@ class EventClientSerializer:
|
||||
if not isinstance(event, EventBase):
|
||||
return event
|
||||
|
||||
serialized_event = serialize_event(event, time_now, **kwargs)
|
||||
serialized_event = serialize_event(event, time_now, config=config)
|
||||
|
||||
# Check if there are any bundled aggregations to include with the event.
|
||||
if bundle_aggregations:
|
||||
@@ -419,6 +429,7 @@ class EventClientSerializer:
|
||||
self._inject_bundled_aggregations(
|
||||
event,
|
||||
time_now,
|
||||
config,
|
||||
bundle_aggregations[event.event_id],
|
||||
serialized_event,
|
||||
)
|
||||
@@ -456,6 +467,7 @@ class EventClientSerializer:
|
||||
self,
|
||||
event: EventBase,
|
||||
time_now: int,
|
||||
config: SerializeEventConfig,
|
||||
aggregations: "BundledAggregations",
|
||||
serialized_event: JsonDict,
|
||||
) -> None:
|
||||
@@ -466,6 +478,7 @@ class EventClientSerializer:
|
||||
time_now: The current time in milliseconds
|
||||
aggregations: The bundled aggregation to serialize.
|
||||
serialized_event: The serialized event which may be modified.
|
||||
config: Event serialization config
|
||||
|
||||
"""
|
||||
serialized_aggregations = {}
|
||||
@@ -493,8 +506,8 @@ class EventClientSerializer:
|
||||
thread = aggregations.thread
|
||||
|
||||
# Don't bundle aggregations as this could recurse forever.
|
||||
serialized_latest_event = self.serialize_event(
|
||||
thread.latest_event, time_now, bundle_aggregations=None
|
||||
serialized_latest_event = serialize_event(
|
||||
thread.latest_event, time_now, config=config
|
||||
)
|
||||
# Manually apply an edit, if one exists.
|
||||
if thread.latest_edit:
|
||||
@@ -515,20 +528,34 @@ class EventClientSerializer:
|
||||
)
|
||||
|
||||
def serialize_events(
|
||||
self, events: Iterable[Union[JsonDict, EventBase]], time_now: int, **kwargs: Any
|
||||
self,
|
||||
events: Iterable[Union[JsonDict, EventBase]],
|
||||
time_now: int,
|
||||
*,
|
||||
config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG,
|
||||
bundle_aggregations: Optional[Dict[str, "BundledAggregations"]] = None,
|
||||
) -> List[JsonDict]:
|
||||
"""Serializes multiple events.
|
||||
|
||||
Args:
|
||||
event
|
||||
time_now: The current time in milliseconds
|
||||
**kwargs: Arguments to pass to `serialize_event`
|
||||
config: Event serialization config
|
||||
bundle_aggregations: Whether to include the bundled aggregations for this
|
||||
event. Only applies to non-state events. (State events never include
|
||||
bundled aggregations.)
|
||||
|
||||
Returns:
|
||||
The list of serialized events
|
||||
"""
|
||||
return [
|
||||
self.serialize_event(event, time_now=time_now, **kwargs) for event in events
|
||||
self.serialize_event(
|
||||
event,
|
||||
time_now,
|
||||
config=config,
|
||||
bundle_aggregations=bundle_aggregations,
|
||||
)
|
||||
for event in events
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -1428,7 +1428,7 @@ class FederationClient(FederationBase):
|
||||
|
||||
# Validate children_state of the room.
|
||||
children_state = room.pop("children_state", [])
|
||||
if not isinstance(children_state, Sequence):
|
||||
if not isinstance(children_state, list):
|
||||
raise InvalidResponseError("'room.children_state' must be a list")
|
||||
if any(not isinstance(e, dict) for e in children_state):
|
||||
raise InvalidResponseError("Invalid event in 'children_state' list")
|
||||
@@ -1440,14 +1440,14 @@ class FederationClient(FederationBase):
|
||||
|
||||
# Validate the children rooms.
|
||||
children = res.get("children", [])
|
||||
if not isinstance(children, Sequence):
|
||||
if not isinstance(children, list):
|
||||
raise InvalidResponseError("'children' must be a list")
|
||||
if any(not isinstance(r, dict) for r in children):
|
||||
raise InvalidResponseError("Invalid room in 'children' list")
|
||||
|
||||
# Validate the inaccessible children.
|
||||
inaccessible_children = res.get("inaccessible_children", [])
|
||||
if not isinstance(inaccessible_children, Sequence):
|
||||
if not isinstance(inaccessible_children, list):
|
||||
raise InvalidResponseError("'inaccessible_children' must be a list")
|
||||
if any(not isinstance(r, str) for r in inaccessible_children):
|
||||
raise InvalidResponseError(
|
||||
@@ -1630,7 +1630,7 @@ def _validate_hierarchy_event(d: JsonDict) -> None:
|
||||
raise ValueError("Invalid event: 'content' must be a dict")
|
||||
|
||||
via = content.get("via")
|
||||
if not isinstance(via, Sequence):
|
||||
if not isinstance(via, list):
|
||||
raise ValueError("Invalid event: 'via' must be a list")
|
||||
if any(not isinstance(v, str) for v in via):
|
||||
raise ValueError("Invalid event: 'via' must be a list of strings")
|
||||
|
||||
@@ -33,7 +33,7 @@ from synapse.metrics.background_process_metrics import (
|
||||
wrap_as_background_process,
|
||||
)
|
||||
from synapse.storage.databases.main.directory import RoomAliasMapping
|
||||
from synapse.types import JsonDict, RoomAlias, RoomStreamToken, UserID
|
||||
from synapse.types import DeviceLists, JsonDict, RoomAlias, RoomStreamToken, UserID
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
from synapse.util.metrics import Measure
|
||||
|
||||
@@ -58,6 +58,9 @@ class ApplicationServicesHandler:
|
||||
self._msc2409_to_device_messages_enabled = (
|
||||
hs.config.experimental.msc2409_to_device_messages_enabled
|
||||
)
|
||||
self._msc3202_transaction_extensions_enabled = (
|
||||
hs.config.experimental.msc3202_transaction_extensions
|
||||
)
|
||||
|
||||
self.current_max = 0
|
||||
self.is_processing = False
|
||||
@@ -204,9 +207,9 @@ class ApplicationServicesHandler:
|
||||
Args:
|
||||
stream_key: The stream the event came from.
|
||||
|
||||
`stream_key` can be "typing_key", "receipt_key", "presence_key" or
|
||||
"to_device_key". Any other value for `stream_key` will cause this function
|
||||
to return early.
|
||||
`stream_key` can be "typing_key", "receipt_key", "presence_key",
|
||||
"to_device_key" or "device_list_key". Any other value for `stream_key`
|
||||
will cause this function to return early.
|
||||
|
||||
Ephemeral events will only be pushed to appservices that have opted into
|
||||
receiving them by setting `push_ephemeral` to true in their registration
|
||||
@@ -230,6 +233,7 @@ class ApplicationServicesHandler:
|
||||
"receipt_key",
|
||||
"presence_key",
|
||||
"to_device_key",
|
||||
"device_list_key",
|
||||
):
|
||||
return
|
||||
|
||||
@@ -253,15 +257,37 @@ class ApplicationServicesHandler:
|
||||
):
|
||||
return
|
||||
|
||||
# Ignore device lists if the feature flag is not enabled
|
||||
if (
|
||||
stream_key == "device_list_key"
|
||||
and not self._msc3202_transaction_extensions_enabled
|
||||
):
|
||||
return
|
||||
|
||||
# Check whether there are any appservices which have registered to receive
|
||||
# ephemeral events.
|
||||
#
|
||||
# Note that whether these events are actually relevant to these appservices
|
||||
# is decided later on.
|
||||
services = self.store.get_app_services()
|
||||
services = [
|
||||
service
|
||||
for service in self.store.get_app_services()
|
||||
if service.supports_ephemeral
|
||||
for service in services
|
||||
# Different stream keys require different support booleans
|
||||
if (
|
||||
stream_key
|
||||
in (
|
||||
"typing_key",
|
||||
"receipt_key",
|
||||
"presence_key",
|
||||
"to_device_key",
|
||||
)
|
||||
and service.supports_ephemeral
|
||||
)
|
||||
or (
|
||||
stream_key == "device_list_key"
|
||||
and service.msc3202_transaction_extensions
|
||||
)
|
||||
]
|
||||
if not services:
|
||||
# Bail out early if none of the target appservices have explicitly registered
|
||||
@@ -336,6 +362,20 @@ class ApplicationServicesHandler:
|
||||
service, "to_device", new_token
|
||||
)
|
||||
|
||||
elif stream_key == "device_list_key":
|
||||
device_list_summary = await self._get_device_list_summary(
|
||||
service, new_token
|
||||
)
|
||||
if device_list_summary:
|
||||
self.scheduler.enqueue_for_appservice(
|
||||
service, device_list_summary=device_list_summary
|
||||
)
|
||||
|
||||
# Persist the latest handled stream token for this appservice
|
||||
await self.store.set_appservice_stream_type_pos(
|
||||
service, "device_list", new_token
|
||||
)
|
||||
|
||||
async def _handle_typing(
|
||||
self, service: ApplicationService, new_token: int
|
||||
) -> List[JsonDict]:
|
||||
@@ -542,6 +582,98 @@ class ApplicationServicesHandler:
|
||||
|
||||
return message_payload
|
||||
|
||||
async def _get_device_list_summary(
|
||||
self,
|
||||
appservice: ApplicationService,
|
||||
new_key: int,
|
||||
) -> DeviceLists:
|
||||
"""
|
||||
Retrieve a list of users who have changed their device lists.
|
||||
|
||||
Args:
|
||||
appservice: The application service to retrieve device list changes for.
|
||||
new_key: The stream key of the device list change that triggered this method call.
|
||||
|
||||
Returns:
|
||||
A set of device list updates, comprised of users that the appservices needs to:
|
||||
* resync the device list of, and
|
||||
* stop tracking the device list of.
|
||||
"""
|
||||
# Fetch the last successfully processed device list update stream ID
|
||||
# for this appservice.
|
||||
from_key = await self.store.get_type_stream_id_for_appservice(
|
||||
appservice, "device_list"
|
||||
)
|
||||
|
||||
# Fetch the users who have modified their device list since then.
|
||||
users_with_changed_device_lists = (
|
||||
await self.store.get_users_whose_devices_changed(
|
||||
from_key, user_ids=None, to_key=new_key
|
||||
)
|
||||
)
|
||||
|
||||
# Filter out any users the application service is not interested in
|
||||
#
|
||||
# For each user who changed their device list, we want to check whether this
|
||||
# appservice would be interested in the change.
|
||||
filtered_users_with_changed_device_lists = {
|
||||
user_id
|
||||
for user_id in users_with_changed_device_lists
|
||||
if await self._is_appservice_interested_in_device_lists_of_user(
|
||||
appservice, user_id
|
||||
)
|
||||
}
|
||||
|
||||
# Create a summary of "changed" and "left" users.
|
||||
# TODO: Calculate "left" users.
|
||||
device_list_summary = DeviceLists(
|
||||
changed=filtered_users_with_changed_device_lists
|
||||
)
|
||||
|
||||
return device_list_summary
|
||||
|
||||
async def _is_appservice_interested_in_device_lists_of_user(
|
||||
self,
|
||||
appservice: ApplicationService,
|
||||
user_id: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Returns whether a given application service is interested in the device list
|
||||
updates of a given user.
|
||||
|
||||
The application service is interested in the user's device list updates if any
|
||||
of the following are true:
|
||||
* The user is the appservice's sender localpart user.
|
||||
* The user is in the appservice's user namespace.
|
||||
* At least one member of one room that the user is a part of is in the
|
||||
appservice's user namespace.
|
||||
* The appservice is explicitly (via room ID or alias) interested in at
|
||||
least one room that the user is in.
|
||||
|
||||
Args:
|
||||
appservice: The application service to gauge interest of.
|
||||
user_id: The ID of the user whose device list interest is in question.
|
||||
|
||||
Returns:
|
||||
True if the application service is interested in the user's device lists, False
|
||||
otherwise.
|
||||
"""
|
||||
# This method checks against both the sender localpart user as well as if the
|
||||
# user is in the appservice's user namespace.
|
||||
if appservice.is_interested_in_user(user_id):
|
||||
return True
|
||||
|
||||
# FIXME: This is quite an expensive check. This method is called per device
|
||||
# list change.
|
||||
room_ids = await self.store.get_rooms_for_user(user_id)
|
||||
for room_id in room_ids:
|
||||
# This method covers checking room members for appservice interest as well as
|
||||
# room ID and alias checks.
|
||||
if await appservice.is_interested_in_room(room_id, self.store):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def query_user_exists(self, user_id: str) -> bool:
|
||||
"""Check if any application service knows this user_id exists.
|
||||
|
||||
@@ -571,7 +703,7 @@ class ApplicationServicesHandler:
|
||||
room_alias_str = room_alias.to_string()
|
||||
services = self.store.get_app_services()
|
||||
alias_query_services = [
|
||||
s for s in services if (s.is_interested_in_alias(room_alias_str))
|
||||
s for s in services if (s.is_room_alias_in_namespace(room_alias_str))
|
||||
]
|
||||
for alias_service in alias_query_services:
|
||||
is_known_alias = await self.appservice_api.query_alias(
|
||||
@@ -660,7 +792,7 @@ class ApplicationServicesHandler:
|
||||
# inside of a list comprehension anymore.
|
||||
interested_list = []
|
||||
for s in services:
|
||||
if await s.is_interested(event, self.store):
|
||||
if await s.is_interested_in_event(event.event_id, event, self.store):
|
||||
interested_list.append(s)
|
||||
|
||||
return interested_list
|
||||
|
||||
@@ -119,7 +119,7 @@ class DirectoryHandler:
|
||||
|
||||
service = requester.app_service
|
||||
if service:
|
||||
if not service.is_interested_in_alias(room_alias_str):
|
||||
if not service.is_room_alias_in_namespace(room_alias_str):
|
||||
raise SynapseError(
|
||||
400,
|
||||
"This application service has not reserved this kind of alias.",
|
||||
@@ -221,7 +221,7 @@ class DirectoryHandler:
|
||||
async def delete_appservice_association(
|
||||
self, service: ApplicationService, room_alias: RoomAlias
|
||||
) -> None:
|
||||
if not service.is_interested_in_alias(room_alias.to_string()):
|
||||
if not service.is_room_alias_in_namespace(room_alias.to_string()):
|
||||
raise SynapseError(
|
||||
400,
|
||||
"This application service has not reserved this kind of alias",
|
||||
@@ -376,7 +376,7 @@ class DirectoryHandler:
|
||||
# non-exclusive locks on the alias (or there are no interested services)
|
||||
services = self.store.get_app_services()
|
||||
interested_services = [
|
||||
s for s in services if s.is_interested_in_alias(alias.to_string())
|
||||
s for s in services if s.is_room_alias_in_namespace(alias.to_string())
|
||||
]
|
||||
|
||||
for service in interested_services:
|
||||
|
||||
@@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, Iterable, List, Optional
|
||||
from synapse.api.constants import EduTypes, EventTypes, Membership
|
||||
from synapse.api.errors import AuthError, SynapseError
|
||||
from synapse.events import EventBase
|
||||
from synapse.events.utils import SerializeEventConfig
|
||||
from synapse.handlers.presence import format_user_presence_state
|
||||
from synapse.streams.config import PaginationConfig
|
||||
from synapse.types import JsonDict, UserID
|
||||
@@ -120,7 +121,7 @@ class EventStreamHandler:
|
||||
chunks = self._event_serializer.serialize_events(
|
||||
events,
|
||||
time_now,
|
||||
as_client_event=as_client_event,
|
||||
config=SerializeEventConfig(as_client_event=as_client_event),
|
||||
)
|
||||
|
||||
chunk = {
|
||||
|
||||
@@ -18,6 +18,7 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, cast
|
||||
from synapse.api.constants import EduTypes, EventTypes, Membership
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.events import EventBase
|
||||
from synapse.events.utils import SerializeEventConfig
|
||||
from synapse.events.validator import EventValidator
|
||||
from synapse.handlers.presence import format_user_presence_state
|
||||
from synapse.handlers.receipts import ReceiptEventSource
|
||||
@@ -156,6 +157,8 @@ class InitialSyncHandler:
|
||||
if limit is None:
|
||||
limit = 10
|
||||
|
||||
serializer_options = SerializeEventConfig(as_client_event=as_client_event)
|
||||
|
||||
async def handle_room(event: RoomsForUser) -> None:
|
||||
d: JsonDict = {
|
||||
"room_id": event.room_id,
|
||||
@@ -173,7 +176,7 @@ class InitialSyncHandler:
|
||||
d["invite"] = self._event_serializer.serialize_event(
|
||||
invite_event,
|
||||
time_now,
|
||||
as_client_event=as_client_event,
|
||||
config=serializer_options,
|
||||
)
|
||||
|
||||
rooms_ret.append(d)
|
||||
@@ -225,7 +228,7 @@ class InitialSyncHandler:
|
||||
self._event_serializer.serialize_events(
|
||||
messages,
|
||||
time_now=time_now,
|
||||
as_client_event=as_client_event,
|
||||
config=serializer_options,
|
||||
)
|
||||
),
|
||||
"start": await start_token.to_string(self.store),
|
||||
@@ -235,7 +238,7 @@ class InitialSyncHandler:
|
||||
d["state"] = self._event_serializer.serialize_events(
|
||||
current_state.values(),
|
||||
time_now=time_now,
|
||||
as_client_event=as_client_event,
|
||||
config=serializer_options,
|
||||
)
|
||||
|
||||
account_data_events = []
|
||||
|
||||
@@ -1069,6 +1069,9 @@ class EventCreationHandler:
|
||||
if relation_type == RelationTypes.ANNOTATION:
|
||||
aggregation_key = relation["key"]
|
||||
|
||||
if len(aggregation_key) > 500:
|
||||
raise SynapseError(400, "Aggregation key is too long")
|
||||
|
||||
already_exists = await self.store.has_user_annotated_event(
|
||||
relates_to, event.type, aggregation_key, event.sender
|
||||
)
|
||||
|
||||
@@ -22,6 +22,7 @@ from twisted.python.failure import Failure
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.api.filtering import Filter
|
||||
from synapse.events.utils import SerializeEventConfig
|
||||
from synapse.handlers.room import ShutdownRoomResponse
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.storage.state import StateFilter
|
||||
@@ -541,13 +542,15 @@ class PaginationHandler:
|
||||
|
||||
time_now = self.clock.time_msec()
|
||||
|
||||
serialize_options = SerializeEventConfig(as_client_event=as_client_event)
|
||||
|
||||
chunk = {
|
||||
"chunk": (
|
||||
self._event_serializer.serialize_events(
|
||||
events,
|
||||
time_now,
|
||||
config=serialize_options,
|
||||
bundle_aggregations=aggregations,
|
||||
as_client_event=as_client_event,
|
||||
)
|
||||
),
|
||||
"start": await from_token.to_string(self.store),
|
||||
@@ -556,7 +559,7 @@ class PaginationHandler:
|
||||
|
||||
if state:
|
||||
chunk["state"] = self._event_serializer.serialize_events(
|
||||
state, time_now, as_client_event=as_client_event
|
||||
state, time_now, config=serialize_options
|
||||
)
|
||||
|
||||
return chunk
|
||||
|
||||
@@ -269,7 +269,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
|
||||
# Then filter down to rooms that the AS can read
|
||||
events = []
|
||||
for room_id, event in rooms_to_events.items():
|
||||
if not await service.matches_user_in_member_list(room_id, self.store):
|
||||
if not await service.is_interested_in_room(room_id, self.store):
|
||||
continue
|
||||
|
||||
events.append(event)
|
||||
|
||||
@@ -1736,8 +1736,8 @@ class RoomMemberMasterHandler(RoomMemberHandler):
|
||||
txn_id=txn_id,
|
||||
prev_event_ids=prev_event_ids,
|
||||
auth_event_ids=auth_event_ids,
|
||||
outlier=True,
|
||||
)
|
||||
event.internal_metadata.outlier = True
|
||||
event.internal_metadata.out_of_band_membership = True
|
||||
|
||||
result_event = await self.event_creation_handler.handle_new_client_event(
|
||||
|
||||
@@ -857,7 +857,7 @@ class _RoomEntry:
|
||||
|
||||
def _has_valid_via(e: EventBase) -> bool:
|
||||
via = e.content.get("via")
|
||||
if not via or not isinstance(via, Sequence):
|
||||
if not via or not isinstance(via, list):
|
||||
return False
|
||||
for v in via:
|
||||
if not isinstance(v, str):
|
||||
|
||||
@@ -13,17 +13,7 @@
|
||||
# limitations under the License.
|
||||
import itertools
|
||||
import logging
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Collection,
|
||||
Dict,
|
||||
FrozenSet,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
)
|
||||
from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Optional, Set, Tuple
|
||||
|
||||
import attr
|
||||
from prometheus_client import Counter
|
||||
@@ -41,6 +31,7 @@ from synapse.storage.databases.main.relations import BundledAggregations
|
||||
from synapse.storage.roommember import MemberSummary
|
||||
from synapse.storage.state import StateFilter
|
||||
from synapse.types import (
|
||||
DeviceLists,
|
||||
JsonDict,
|
||||
MutableStateMap,
|
||||
Requester,
|
||||
@@ -184,21 +175,6 @@ class GroupsSyncResult:
|
||||
return bool(self.join or self.invite or self.leave)
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||
class DeviceLists:
|
||||
"""
|
||||
Attributes:
|
||||
changed: List of user_ids whose devices may have changed
|
||||
left: List of user_ids whose devices we no longer track
|
||||
"""
|
||||
|
||||
changed: Collection[str]
|
||||
left: Collection[str]
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return bool(self.changed or self.left)
|
||||
|
||||
|
||||
@attr.s(slots=True, auto_attribs=True)
|
||||
class _RoomChanges:
|
||||
"""The set of room entries to include in the sync, plus the set of joined
|
||||
@@ -1380,7 +1356,7 @@ class SyncHandler:
|
||||
|
||||
return DeviceLists(changed=users_that_have_changed, left=newly_left_users)
|
||||
else:
|
||||
return DeviceLists(changed=[], left=[])
|
||||
return DeviceLists()
|
||||
|
||||
async def _generate_sync_entry_for_to_device(
|
||||
self, sync_result_builder: "SyncResultBuilder"
|
||||
|
||||
@@ -486,9 +486,7 @@ class TypingNotificationEventSource(EventSource[int, JsonDict]):
|
||||
if handler._room_serials[room_id] <= from_key:
|
||||
continue
|
||||
|
||||
if not await service.matches_user_in_member_list(
|
||||
room_id, self._main_store
|
||||
):
|
||||
if not await service.is_interested_in_room(room_id, self._main_store):
|
||||
continue
|
||||
|
||||
events.append(self._make_event_for(room_id))
|
||||
|
||||
@@ -16,7 +16,10 @@ import logging
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from synapse.api.constants import ReceiptTypes
|
||||
from synapse.events.utils import format_event_for_client_v2_without_room_id
|
||||
from synapse.events.utils import (
|
||||
SerializeEventConfig,
|
||||
format_event_for_client_v2_without_room_id,
|
||||
)
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import RestServlet, parse_integer, parse_string
|
||||
from synapse.http.site import SynapseRequest
|
||||
@@ -75,7 +78,9 @@ class NotificationsServlet(RestServlet):
|
||||
self._event_serializer.serialize_event(
|
||||
notif_events[pa.event_id],
|
||||
self.clock.time_msec(),
|
||||
event_format=format_event_for_client_v2_without_room_id,
|
||||
config=SerializeEventConfig(
|
||||
event_format=format_event_for_client_v2_without_room_id
|
||||
),
|
||||
)
|
||||
),
|
||||
}
|
||||
|
||||
+38
-94
@@ -14,24 +14,14 @@
|
||||
import itertools
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from synapse.api.constants import Membership, PresenceState
|
||||
from synapse.api.errors import Codes, StoreError, SynapseError
|
||||
from synapse.api.filtering import FilterCollection
|
||||
from synapse.api.presence import UserPresenceState
|
||||
from synapse.events import EventBase
|
||||
from synapse.events.utils import (
|
||||
SerializeEventConfig,
|
||||
format_event_for_client_v2_without_room_id,
|
||||
format_event_raw,
|
||||
)
|
||||
@@ -48,7 +38,6 @@ from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.logging.opentracing import trace
|
||||
from synapse.storage.databases.main.relations import BundledAggregations
|
||||
from synapse.types import JsonDict, StreamToken
|
||||
from synapse.util import json_decoder
|
||||
|
||||
@@ -239,28 +228,31 @@ class SyncRestServlet(RestServlet):
|
||||
else:
|
||||
raise Exception("Unknown event format %s" % (filter.event_format,))
|
||||
|
||||
serialize_options = SerializeEventConfig(
|
||||
event_format=event_formatter,
|
||||
token_id=access_token_id,
|
||||
only_event_fields=filter.event_fields,
|
||||
)
|
||||
stripped_serialize_options = SerializeEventConfig(
|
||||
event_format=event_formatter,
|
||||
token_id=access_token_id,
|
||||
include_stripped_room_state=True,
|
||||
)
|
||||
|
||||
joined = await self.encode_joined(
|
||||
sync_result.joined,
|
||||
time_now,
|
||||
access_token_id,
|
||||
filter.event_fields,
|
||||
event_formatter,
|
||||
sync_result.joined, time_now, serialize_options
|
||||
)
|
||||
|
||||
invited = await self.encode_invited(
|
||||
sync_result.invited, time_now, access_token_id, event_formatter
|
||||
sync_result.invited, time_now, stripped_serialize_options
|
||||
)
|
||||
|
||||
knocked = await self.encode_knocked(
|
||||
sync_result.knocked, time_now, access_token_id, event_formatter
|
||||
sync_result.knocked, time_now, stripped_serialize_options
|
||||
)
|
||||
|
||||
archived = await self.encode_archived(
|
||||
sync_result.archived,
|
||||
time_now,
|
||||
access_token_id,
|
||||
filter.event_fields,
|
||||
event_formatter,
|
||||
sync_result.archived, time_now, serialize_options
|
||||
)
|
||||
|
||||
logger.debug("building sync response dict")
|
||||
@@ -339,9 +331,7 @@ class SyncRestServlet(RestServlet):
|
||||
self,
|
||||
rooms: List[JoinedSyncResult],
|
||||
time_now: int,
|
||||
token_id: Optional[int],
|
||||
event_fields: List[str],
|
||||
event_formatter: Callable[[JsonDict], JsonDict],
|
||||
serialize_options: SerializeEventConfig,
|
||||
) -> JsonDict:
|
||||
"""
|
||||
Encode the joined rooms in a sync result
|
||||
@@ -349,24 +339,14 @@ class SyncRestServlet(RestServlet):
|
||||
Args:
|
||||
rooms: list of sync results for rooms this user is joined to
|
||||
time_now: current time - used as a baseline for age calculations
|
||||
token_id: ID of the user's auth token - used for namespacing
|
||||
of transaction IDs
|
||||
event_fields: List of event fields to include. If empty,
|
||||
all fields will be returned.
|
||||
event_formatter: function to convert from federation format
|
||||
to client format
|
||||
serialize_options: Event serializer options
|
||||
Returns:
|
||||
The joined rooms list, in our response format
|
||||
"""
|
||||
joined = {}
|
||||
for room in rooms:
|
||||
joined[room.room_id] = await self.encode_room(
|
||||
room,
|
||||
time_now,
|
||||
token_id,
|
||||
joined=True,
|
||||
only_fields=event_fields,
|
||||
event_formatter=event_formatter,
|
||||
room, time_now, joined=True, serialize_options=serialize_options
|
||||
)
|
||||
|
||||
return joined
|
||||
@@ -376,8 +356,7 @@ class SyncRestServlet(RestServlet):
|
||||
self,
|
||||
rooms: List[InvitedSyncResult],
|
||||
time_now: int,
|
||||
token_id: Optional[int],
|
||||
event_formatter: Callable[[JsonDict], JsonDict],
|
||||
serialize_options: SerializeEventConfig,
|
||||
) -> JsonDict:
|
||||
"""
|
||||
Encode the invited rooms in a sync result
|
||||
@@ -385,10 +364,7 @@ class SyncRestServlet(RestServlet):
|
||||
Args:
|
||||
rooms: list of sync results for rooms this user is invited to
|
||||
time_now: current time - used as a baseline for age calculations
|
||||
token_id: ID of the user's auth token - used for namespacing
|
||||
of transaction IDs
|
||||
event_formatter: function to convert from federation format
|
||||
to client format
|
||||
serialize_options: Event serializer options
|
||||
|
||||
Returns:
|
||||
The invited rooms list, in our response format
|
||||
@@ -396,11 +372,7 @@ class SyncRestServlet(RestServlet):
|
||||
invited = {}
|
||||
for room in rooms:
|
||||
invite = self._event_serializer.serialize_event(
|
||||
room.invite,
|
||||
time_now,
|
||||
token_id=token_id,
|
||||
event_format=event_formatter,
|
||||
include_stripped_room_state=True,
|
||||
room.invite, time_now, config=serialize_options
|
||||
)
|
||||
unsigned = dict(invite.get("unsigned", {}))
|
||||
invite["unsigned"] = unsigned
|
||||
@@ -415,8 +387,7 @@ class SyncRestServlet(RestServlet):
|
||||
self,
|
||||
rooms: List[KnockedSyncResult],
|
||||
time_now: int,
|
||||
token_id: Optional[int],
|
||||
event_formatter: Callable[[Dict], Dict],
|
||||
serialize_options: SerializeEventConfig,
|
||||
) -> Dict[str, Dict[str, Any]]:
|
||||
"""
|
||||
Encode the rooms we've knocked on in a sync result.
|
||||
@@ -424,8 +395,7 @@ class SyncRestServlet(RestServlet):
|
||||
Args:
|
||||
rooms: list of sync results for rooms this user is knocking on
|
||||
time_now: current time - used as a baseline for age calculations
|
||||
token_id: ID of the user's auth token - used for namespacing of transaction IDs
|
||||
event_formatter: function to convert from federation format to client format
|
||||
serialize_options: Event serializer options
|
||||
|
||||
Returns:
|
||||
The list of rooms the user has knocked on, in our response format.
|
||||
@@ -433,11 +403,7 @@ class SyncRestServlet(RestServlet):
|
||||
knocked = {}
|
||||
for room in rooms:
|
||||
knock = self._event_serializer.serialize_event(
|
||||
room.knock,
|
||||
time_now,
|
||||
token_id=token_id,
|
||||
event_format=event_formatter,
|
||||
include_stripped_room_state=True,
|
||||
room.knock, time_now, config=serialize_options
|
||||
)
|
||||
|
||||
# Extract the `unsigned` key from the knock event.
|
||||
@@ -470,9 +436,7 @@ class SyncRestServlet(RestServlet):
|
||||
self,
|
||||
rooms: List[ArchivedSyncResult],
|
||||
time_now: int,
|
||||
token_id: Optional[int],
|
||||
event_fields: List[str],
|
||||
event_formatter: Callable[[JsonDict], JsonDict],
|
||||
serialize_options: SerializeEventConfig,
|
||||
) -> JsonDict:
|
||||
"""
|
||||
Encode the archived rooms in a sync result
|
||||
@@ -480,23 +444,14 @@ class SyncRestServlet(RestServlet):
|
||||
Args:
|
||||
rooms: list of sync results for rooms this user is joined to
|
||||
time_now: current time - used as a baseline for age calculations
|
||||
token_id: ID of the user's auth token - used for namespacing
|
||||
of transaction IDs
|
||||
event_fields: List of event fields to include. If empty,
|
||||
all fields will be returned.
|
||||
event_formatter: function to convert from federation format to client format
|
||||
serialize_options: Event serializer options
|
||||
Returns:
|
||||
The archived rooms list, in our response format
|
||||
"""
|
||||
joined = {}
|
||||
for room in rooms:
|
||||
joined[room.room_id] = await self.encode_room(
|
||||
room,
|
||||
time_now,
|
||||
token_id,
|
||||
joined=False,
|
||||
only_fields=event_fields,
|
||||
event_formatter=event_formatter,
|
||||
room, time_now, joined=False, serialize_options=serialize_options
|
||||
)
|
||||
|
||||
return joined
|
||||
@@ -505,10 +460,8 @@ class SyncRestServlet(RestServlet):
|
||||
self,
|
||||
room: Union[JoinedSyncResult, ArchivedSyncResult],
|
||||
time_now: int,
|
||||
token_id: Optional[int],
|
||||
joined: bool,
|
||||
only_fields: Optional[List[str]],
|
||||
event_formatter: Callable[[JsonDict], JsonDict],
|
||||
serialize_options: SerializeEventConfig,
|
||||
) -> JsonDict:
|
||||
"""
|
||||
Args:
|
||||
@@ -524,20 +477,6 @@ class SyncRestServlet(RestServlet):
|
||||
Returns:
|
||||
The room, encoded in our response format
|
||||
"""
|
||||
|
||||
def serialize(
|
||||
events: Iterable[EventBase],
|
||||
aggregations: Optional[Dict[str, BundledAggregations]] = None,
|
||||
) -> List[JsonDict]:
|
||||
return self._event_serializer.serialize_events(
|
||||
events,
|
||||
time_now=time_now,
|
||||
bundle_aggregations=aggregations,
|
||||
token_id=token_id,
|
||||
event_format=event_formatter,
|
||||
only_event_fields=only_fields,
|
||||
)
|
||||
|
||||
state_dict = room.state
|
||||
timeline_events = room.timeline.events
|
||||
|
||||
@@ -554,9 +493,14 @@ class SyncRestServlet(RestServlet):
|
||||
event.room_id,
|
||||
)
|
||||
|
||||
serialized_state = serialize(state_events)
|
||||
serialized_timeline = serialize(
|
||||
timeline_events, room.timeline.bundled_aggregations
|
||||
serialized_state = self._event_serializer.serialize_events(
|
||||
state_events, time_now, config=serialize_options
|
||||
)
|
||||
serialized_timeline = self._event_serializer.serialize_events(
|
||||
timeline_events,
|
||||
time_now,
|
||||
config=serialize_options,
|
||||
bundle_aggregations=room.timeline.bundled_aggregations,
|
||||
)
|
||||
|
||||
account_data = room.account_data
|
||||
|
||||
@@ -194,7 +194,7 @@ class StateHandler:
|
||||
}
|
||||
|
||||
async def get_current_state_ids(
|
||||
self, room_id: str, latest_event_ids: Optional[Iterable[str]] = None
|
||||
self, room_id: str, latest_event_ids: Optional[Collection[str]] = None
|
||||
) -> StateMap[str]:
|
||||
"""Get the current state, or the state at a set of events, for a room
|
||||
|
||||
@@ -243,7 +243,7 @@ class StateHandler:
|
||||
return await self.get_hosts_in_room_at_events(room_id, event_ids)
|
||||
|
||||
async def get_hosts_in_room_at_events(
|
||||
self, room_id: str, event_ids: Iterable[str]
|
||||
self, room_id: str, event_ids: Collection[str]
|
||||
) -> Set[str]:
|
||||
"""Get the hosts that were in a room at the given event ids
|
||||
|
||||
@@ -404,7 +404,7 @@ class StateHandler:
|
||||
|
||||
@measure_func()
|
||||
async def resolve_state_groups_for_events(
|
||||
self, room_id: str, event_ids: Iterable[str]
|
||||
self, room_id: str, event_ids: Collection[str]
|
||||
) -> _StateCacheEntry:
|
||||
"""Given a list of event_ids this method fetches the state at each
|
||||
event, resolves conflicts between them and returns them.
|
||||
|
||||
@@ -29,7 +29,7 @@ from synapse.storage._base import db_to_json
|
||||
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
|
||||
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
||||
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
|
||||
from synapse.types import JsonDict
|
||||
from synapse.types import DeviceLists, JsonDict
|
||||
from synapse.util import json_encoder
|
||||
from synapse.util.caches.descriptors import _CacheContext, cached
|
||||
|
||||
@@ -217,6 +217,7 @@ class ApplicationServiceTransactionWorkerStore(
|
||||
to_device_messages: List[JsonDict],
|
||||
one_time_key_counts: TransactionOneTimeKeyCounts,
|
||||
unused_fallback_keys: TransactionUnusedFallbackKeys,
|
||||
device_list_summary: DeviceLists,
|
||||
) -> AppServiceTransaction:
|
||||
"""Atomically creates a new transaction for this application service
|
||||
with the given list of events. Ephemeral events are NOT persisted to the
|
||||
@@ -231,6 +232,7 @@ class ApplicationServiceTransactionWorkerStore(
|
||||
appservice devices in the transaction.
|
||||
unused_fallback_keys: Lists of unused fallback keys for relevant
|
||||
appservice devices in the transaction.
|
||||
device_list_summary: The device list summary to include in the transaction.
|
||||
|
||||
Returns:
|
||||
A new transaction.
|
||||
@@ -268,6 +270,7 @@ class ApplicationServiceTransactionWorkerStore(
|
||||
to_device_messages=to_device_messages,
|
||||
one_time_key_counts=one_time_key_counts,
|
||||
unused_fallback_keys=unused_fallback_keys,
|
||||
device_list_summary=device_list_summary,
|
||||
)
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
@@ -359,8 +362,8 @@ class ApplicationServiceTransactionWorkerStore(
|
||||
|
||||
events = await self.get_events_as_list(event_ids)
|
||||
|
||||
# TODO: to-device messages, one-time key counts and unused fallback keys
|
||||
# are not yet populated for catch-up transactions.
|
||||
# TODO: to-device messages, one-time key counts, device list summaries and unused
|
||||
# fallback keys are not yet populated for catch-up transactions.
|
||||
# We likely want to populate those for reliability.
|
||||
return AppServiceTransaction(
|
||||
service=service,
|
||||
@@ -370,6 +373,7 @@ class ApplicationServiceTransactionWorkerStore(
|
||||
to_device_messages=[],
|
||||
one_time_key_counts={},
|
||||
unused_fallback_keys={},
|
||||
device_list_summary=DeviceLists(),
|
||||
)
|
||||
|
||||
def _get_last_txn(self, txn, service_id: Optional[str]) -> int:
|
||||
@@ -430,7 +434,7 @@ class ApplicationServiceTransactionWorkerStore(
|
||||
async def get_type_stream_id_for_appservice(
|
||||
self, service: ApplicationService, type: str
|
||||
) -> int:
|
||||
if type not in ("read_receipt", "presence", "to_device"):
|
||||
if type not in ("read_receipt", "presence", "to_device", "device_list"):
|
||||
raise ValueError(
|
||||
"Expected type to be a valid application stream id type, got %s"
|
||||
% (type,)
|
||||
@@ -446,7 +450,7 @@ class ApplicationServiceTransactionWorkerStore(
|
||||
)
|
||||
last_stream_id = txn.fetchone()
|
||||
if last_stream_id is None or last_stream_id[0] is None: # no row exists
|
||||
return 0
|
||||
return 1
|
||||
else:
|
||||
return int(last_stream_id[0])
|
||||
|
||||
@@ -457,7 +461,7 @@ class ApplicationServiceTransactionWorkerStore(
|
||||
async def set_appservice_stream_type_pos(
|
||||
self, service: ApplicationService, stream_type: str, pos: Optional[int]
|
||||
) -> None:
|
||||
if stream_type not in ("read_receipt", "presence", "to_device"):
|
||||
if stream_type not in ("read_receipt", "presence", "to_device", "device_list"):
|
||||
raise ValueError(
|
||||
"Expected type to be a valid application stream id type, got %s"
|
||||
% (stream_type,)
|
||||
|
||||
@@ -681,42 +681,64 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||
return self._device_list_stream_cache.get_all_entities_changed(from_key)
|
||||
|
||||
async def get_users_whose_devices_changed(
|
||||
self, from_key: int, user_ids: Iterable[str]
|
||||
self,
|
||||
from_key: int,
|
||||
user_ids: Optional[Iterable[str]] = None,
|
||||
to_key: Optional[int] = None,
|
||||
) -> Set[str]:
|
||||
"""Get set of users whose devices have changed since `from_key` that
|
||||
are in the given list of user_ids.
|
||||
|
||||
Args:
|
||||
from_key: The device lists stream token
|
||||
user_ids: The user IDs to query for devices.
|
||||
|
||||
from_key: The minimum device lists stream token to query device list changes for,
|
||||
exclusive.
|
||||
user_ids: If provided, only check if these users have changed their device lists.
|
||||
Otherwise changes from all users are returned.
|
||||
to_key: The maximum device lists stream token to query device list changes for,
|
||||
inclusive.
|
||||
Returns:
|
||||
The set of user_ids whose devices have changed since `from_key`
|
||||
The set of user_ids whose devices have changed since `from_key` (exclusive)
|
||||
until `to_key` (inclusive).
|
||||
"""
|
||||
|
||||
# Get set of users who *may* have changed. Users not in the returned
|
||||
# list have definitely not changed.
|
||||
to_check = self._device_list_stream_cache.get_entities_changed(
|
||||
user_ids, from_key
|
||||
)
|
||||
if user_ids is None:
|
||||
# Get set of all users that have had device list changes since 'from_key'
|
||||
user_ids_to_check = self._device_list_stream_cache.get_all_entities_changed(
|
||||
from_key
|
||||
)
|
||||
else:
|
||||
# The same as above, but filter results to only those users in 'user_ids'
|
||||
user_ids_to_check = self._device_list_stream_cache.get_entities_changed(
|
||||
user_ids, from_key
|
||||
)
|
||||
|
||||
if not to_check:
|
||||
if not user_ids_to_check:
|
||||
return set()
|
||||
|
||||
def _get_users_whose_devices_changed_txn(txn):
|
||||
changes = set()
|
||||
|
||||
sql = """
|
||||
stream_id_where_clause = "stream_id > ?"
|
||||
sql_args = [from_key]
|
||||
|
||||
if to_key:
|
||||
stream_id_where_clause += " AND stream_id <= ?"
|
||||
sql_args += [to_key]
|
||||
|
||||
sql = f"""
|
||||
SELECT DISTINCT user_id FROM device_lists_stream
|
||||
WHERE stream_id > ?
|
||||
WHERE {stream_id_where_clause}
|
||||
AND
|
||||
"""
|
||||
|
||||
for chunk in batch_iter(to_check, 100):
|
||||
# Query device changes with a batch of users at a time
|
||||
for chunk in batch_iter(user_ids_to_check, 100):
|
||||
clause, args = make_in_list_sql_clause(
|
||||
txn.database_engine, "user_id", chunk
|
||||
)
|
||||
txn.execute(sql + clause, (from_key,) + tuple(args))
|
||||
sql_args += args
|
||||
|
||||
txn.execute(sql + clause, sql_args)
|
||||
changes.update(user_id for user_id, in txn)
|
||||
|
||||
return changes
|
||||
|
||||
+18
@@ -0,0 +1,18 @@
|
||||
/* Copyright 2022 The Matrix.org Foundation C.I.C
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
-- Add a column to track what device list changes stream id that this application
|
||||
-- service has been caught up to.
|
||||
ALTER TABLE application_services_state ADD COLUMN device_list_stream_id BIGINT;
|
||||
@@ -561,7 +561,7 @@ class StateGroupStorage:
|
||||
return state_group_delta.prev_group, state_group_delta.delta_ids
|
||||
|
||||
async def get_state_groups_ids(
|
||||
self, _room_id: str, event_ids: Iterable[str]
|
||||
self, _room_id: str, event_ids: Collection[str]
|
||||
) -> Dict[int, MutableStateMap[str]]:
|
||||
"""Get the event IDs of all the state for the state groups for the given events
|
||||
|
||||
@@ -596,7 +596,7 @@ class StateGroupStorage:
|
||||
return group_to_state[state_group]
|
||||
|
||||
async def get_state_groups(
|
||||
self, room_id: str, event_ids: Iterable[str]
|
||||
self, room_id: str, event_ids: Collection[str]
|
||||
) -> Dict[int, List[EventBase]]:
|
||||
"""Get the state groups for the given list of event_ids
|
||||
|
||||
@@ -648,7 +648,7 @@ class StateGroupStorage:
|
||||
return self.stores.state._get_state_groups_from_groups(groups, state_filter)
|
||||
|
||||
async def get_state_for_events(
|
||||
self, event_ids: Iterable[str], state_filter: Optional[StateFilter] = None
|
||||
self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None
|
||||
) -> Dict[str, StateMap[EventBase]]:
|
||||
"""Given a list of event_ids and type tuples, return a list of state
|
||||
dicts for each event.
|
||||
@@ -684,7 +684,7 @@ class StateGroupStorage:
|
||||
return {event: event_to_state[event] for event in event_ids}
|
||||
|
||||
async def get_state_ids_for_events(
|
||||
self, event_ids: Iterable[str], state_filter: Optional[StateFilter] = None
|
||||
self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None
|
||||
) -> Dict[str, StateMap[str]]:
|
||||
"""
|
||||
Get the state dicts corresponding to a list of events, containing the event_ids
|
||||
|
||||
@@ -25,6 +25,7 @@ from typing import (
|
||||
Match,
|
||||
MutableMapping,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
@@ -743,6 +744,26 @@ class ReadReceipt:
|
||||
data: JsonDict
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||
class DeviceLists:
|
||||
"""
|
||||
Attributes:
|
||||
changed: user_ids whose devices may have changed
|
||||
left: user_ids whose devices we no longer track
|
||||
"""
|
||||
|
||||
# We need to use a factory here, otherwise `set` is not evaluated at
|
||||
# object instantiation, but instead at class definition instantiation.
|
||||
# The latter happening only once, thus always giving you the same sets
|
||||
# across multiple DeviceLists instances.
|
||||
# Also see: don't define mutable default arguments.
|
||||
changed: Set[str] = attr.ib(factory=set)
|
||||
left: Set[str] = attr.ib(factory=set)
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return bool(self.changed or self.left)
|
||||
|
||||
|
||||
def get_verify_key_from_cross_signing_key(key_info):
|
||||
"""Get the key ID and signedjson verify key from a cross-signing key dict
|
||||
|
||||
|
||||
+16
-1
@@ -81,8 +81,9 @@ async def filter_events_for_client(
|
||||
|
||||
types = ((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, user_id))
|
||||
|
||||
# we exclude outliers at this point, and then handle them separately later
|
||||
event_id_to_state = await storage.state.get_state_for_events(
|
||||
frozenset(e.event_id for e in events),
|
||||
frozenset(e.event_id for e in events if not e.internal_metadata.outlier),
|
||||
state_filter=StateFilter.from_types(types),
|
||||
)
|
||||
|
||||
@@ -154,6 +155,17 @@ async def filter_events_for_client(
|
||||
if event.event_id in always_include_ids:
|
||||
return event
|
||||
|
||||
# we need to handle outliers separately, since we don't have the room state.
|
||||
if event.internal_metadata.outlier:
|
||||
# Normally these can't be seen by clients, but we make an exception for
|
||||
# for out-of-band membership events (eg, incoming invites, or rejections of
|
||||
# said invite) for the user themselves.
|
||||
if event.type == EventTypes.Member and event.state_key == user_id:
|
||||
logger.debug("Returning out-of-band-membership event %s", event)
|
||||
return event
|
||||
|
||||
return None
|
||||
|
||||
state = event_id_to_state[event.event_id]
|
||||
|
||||
# get the room_visibility at the time of the event.
|
||||
@@ -198,6 +210,9 @@ async def filter_events_for_client(
|
||||
|
||||
# Always allow the user to see their own leave events, otherwise
|
||||
# they won't see the room disappear if they reject the invite
|
||||
#
|
||||
# (Note this doesn't work for out-of-band invite rejections, which don't
|
||||
# have prev_state populated. They are handled above in the outlier code.)
|
||||
if membership == "leave" and (
|
||||
prev_membership == "join" or prev_membership == "invite"
|
||||
):
|
||||
|
||||
@@ -36,7 +36,10 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
||||
hostname="matrix.org", # only used by get_groups_for_user
|
||||
)
|
||||
self.event = Mock(
|
||||
type="m.something", room_id="!foo:bar", sender="@someone:somewhere"
|
||||
event_id="$abc:xyz",
|
||||
type="m.something",
|
||||
room_id="!foo:bar",
|
||||
sender="@someone:somewhere",
|
||||
)
|
||||
|
||||
self.store = Mock()
|
||||
@@ -50,7 +53,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
||||
self.assertTrue(
|
||||
(
|
||||
yield defer.ensureDeferred(
|
||||
self.service.is_interested(self.event, self.store)
|
||||
self.service.is_interested_in_event(
|
||||
self.event.event_id, self.event, self.store
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
@@ -62,7 +67,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
||||
self.assertFalse(
|
||||
(
|
||||
yield defer.ensureDeferred(
|
||||
self.service.is_interested(self.event, self.store)
|
||||
self.service.is_interested_in_event(
|
||||
self.event.event_id, self.event, self.store
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
@@ -76,7 +83,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
||||
self.assertTrue(
|
||||
(
|
||||
yield defer.ensureDeferred(
|
||||
self.service.is_interested(self.event, self.store)
|
||||
self.service.is_interested_in_event(
|
||||
self.event.event_id, self.event, self.store
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
@@ -90,7 +99,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
||||
self.assertTrue(
|
||||
(
|
||||
yield defer.ensureDeferred(
|
||||
self.service.is_interested(self.event, self.store)
|
||||
self.service.is_interested_in_event(
|
||||
self.event.event_id, self.event, self.store
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
@@ -104,7 +115,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
||||
self.assertFalse(
|
||||
(
|
||||
yield defer.ensureDeferred(
|
||||
self.service.is_interested(self.event, self.store)
|
||||
self.service.is_interested_in_event(
|
||||
self.event.event_id, self.event, self.store
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
@@ -121,7 +134,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
||||
self.assertTrue(
|
||||
(
|
||||
yield defer.ensureDeferred(
|
||||
self.service.is_interested(self.event, self.store)
|
||||
self.service.is_interested_in_event(
|
||||
self.event.event_id, self.event, self.store
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
@@ -174,7 +189,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
||||
self.assertFalse(
|
||||
(
|
||||
yield defer.ensureDeferred(
|
||||
self.service.is_interested(self.event, self.store)
|
||||
self.service.is_interested_in_event(
|
||||
self.event.event_id, self.event, self.store
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
@@ -191,7 +208,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
||||
self.assertTrue(
|
||||
(
|
||||
yield defer.ensureDeferred(
|
||||
self.service.is_interested(self.event, self.store)
|
||||
self.service.is_interested_in_event(
|
||||
self.event.event_id, self.event, self.store
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
@@ -207,7 +226,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
||||
self.assertTrue(
|
||||
(
|
||||
yield defer.ensureDeferred(
|
||||
self.service.is_interested(self.event, self.store)
|
||||
self.service.is_interested_in_event(
|
||||
self.event.event_id, self.event, self.store
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
@@ -225,7 +246,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
||||
self.assertTrue(
|
||||
(
|
||||
yield defer.ensureDeferred(
|
||||
self.service.is_interested(event=self.event, store=self.store)
|
||||
self.service.is_interested_in_event(
|
||||
self.event.event_id, self.event, self.store
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
@@ -24,6 +24,7 @@ from synapse.appservice.scheduler import (
|
||||
)
|
||||
from synapse.logging.context import make_deferred_yieldable
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import DeviceLists
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
@@ -70,6 +71,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
|
||||
to_device_messages=[], # txn made and saved
|
||||
one_time_key_counts={},
|
||||
unused_fallback_keys={},
|
||||
device_list_summary=DeviceLists(),
|
||||
)
|
||||
self.assertEqual(0, len(self.txnctrl.recoverers)) # no recoverer made
|
||||
txn.complete.assert_called_once_with(self.store) # txn completed
|
||||
@@ -96,6 +98,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
|
||||
to_device_messages=[], # txn made and saved
|
||||
one_time_key_counts={},
|
||||
unused_fallback_keys={},
|
||||
device_list_summary=DeviceLists(),
|
||||
)
|
||||
self.assertEqual(0, txn.send.call_count) # txn not sent though
|
||||
self.assertEqual(0, txn.complete.call_count) # or completed
|
||||
@@ -124,6 +127,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
|
||||
to_device_messages=[],
|
||||
one_time_key_counts={},
|
||||
unused_fallback_keys={},
|
||||
device_list_summary=DeviceLists(),
|
||||
)
|
||||
self.assertEqual(1, self.recoverer_fn.call_count) # recoverer made
|
||||
self.assertEqual(1, self.recoverer.recover.call_count) # and invoked
|
||||
@@ -225,7 +229,9 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
|
||||
service = Mock(id=4)
|
||||
event = Mock()
|
||||
self.scheduler.enqueue_for_appservice(service, events=[event])
|
||||
self.txn_ctrl.send.assert_called_once_with(service, [event], [], [], None, None)
|
||||
self.txn_ctrl.send.assert_called_once_with(
|
||||
service, [event], [], [], None, None, DeviceLists()
|
||||
)
|
||||
|
||||
def test_send_single_event_with_queue(self):
|
||||
d = defer.Deferred()
|
||||
@@ -240,12 +246,14 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
|
||||
# (call enqueue_for_appservice multiple times deliberately)
|
||||
self.scheduler.enqueue_for_appservice(service, events=[event2])
|
||||
self.scheduler.enqueue_for_appservice(service, events=[event3])
|
||||
self.txn_ctrl.send.assert_called_with(service, [event], [], [], None, None)
|
||||
self.txn_ctrl.send.assert_called_with(
|
||||
service, [event], [], [], None, None, DeviceLists()
|
||||
)
|
||||
self.assertEqual(1, self.txn_ctrl.send.call_count)
|
||||
# Resolve the send event: expect the queued events to be sent
|
||||
d.callback(service)
|
||||
self.txn_ctrl.send.assert_called_with(
|
||||
service, [event2, event3], [], [], None, None
|
||||
service, [event2, event3], [], [], None, None, DeviceLists()
|
||||
)
|
||||
self.assertEqual(2, self.txn_ctrl.send.call_count)
|
||||
|
||||
@@ -272,15 +280,21 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
|
||||
# send events for different ASes and make sure they are sent
|
||||
self.scheduler.enqueue_for_appservice(srv1, events=[srv_1_event])
|
||||
self.scheduler.enqueue_for_appservice(srv1, events=[srv_1_event2])
|
||||
self.txn_ctrl.send.assert_called_with(srv1, [srv_1_event], [], [], None, None)
|
||||
self.txn_ctrl.send.assert_called_with(
|
||||
srv1, [srv_1_event], [], [], None, None, DeviceLists()
|
||||
)
|
||||
self.scheduler.enqueue_for_appservice(srv2, events=[srv_2_event])
|
||||
self.scheduler.enqueue_for_appservice(srv2, events=[srv_2_event2])
|
||||
self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event], [], [], None, None)
|
||||
self.txn_ctrl.send.assert_called_with(
|
||||
srv2, [srv_2_event], [], [], None, None, DeviceLists()
|
||||
)
|
||||
|
||||
# make sure callbacks for a service only send queued events for THAT
|
||||
# service
|
||||
srv_2_defer.callback(srv2)
|
||||
self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2], [], [], None, None)
|
||||
self.txn_ctrl.send.assert_called_with(
|
||||
srv2, [srv_2_event2], [], [], None, None, DeviceLists()
|
||||
)
|
||||
self.assertEqual(3, self.txn_ctrl.send.call_count)
|
||||
|
||||
def test_send_large_txns(self):
|
||||
@@ -300,17 +314,17 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
# Expect the first event to be sent immediately.
|
||||
self.txn_ctrl.send.assert_called_with(
|
||||
service, [event_list[0]], [], [], None, None
|
||||
service, [event_list[0]], [], [], None, None, DeviceLists()
|
||||
)
|
||||
srv_1_defer.callback(service)
|
||||
# Then send the next 100 events
|
||||
self.txn_ctrl.send.assert_called_with(
|
||||
service, event_list[1:101], [], [], None, None
|
||||
service, event_list[1:101], [], [], None, None, DeviceLists()
|
||||
)
|
||||
srv_2_defer.callback(service)
|
||||
# Then the final 99 events
|
||||
self.txn_ctrl.send.assert_called_with(
|
||||
service, event_list[101:], [], [], None, None
|
||||
service, event_list[101:], [], [], None, None, DeviceLists()
|
||||
)
|
||||
self.assertEqual(3, self.txn_ctrl.send.call_count)
|
||||
|
||||
@@ -320,7 +334,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
|
||||
event_list = [Mock(name="event")]
|
||||
self.scheduler.enqueue_for_appservice(service, ephemeral=event_list)
|
||||
self.txn_ctrl.send.assert_called_once_with(
|
||||
service, [], event_list, [], None, None
|
||||
service, [], event_list, [], None, None, DeviceLists()
|
||||
)
|
||||
|
||||
def test_send_multiple_ephemeral_no_queue(self):
|
||||
@@ -329,7 +343,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
|
||||
event_list = [Mock(name="event1"), Mock(name="event2"), Mock(name="event3")]
|
||||
self.scheduler.enqueue_for_appservice(service, ephemeral=event_list)
|
||||
self.txn_ctrl.send.assert_called_once_with(
|
||||
service, [], event_list, [], None, None
|
||||
service, [], event_list, [], None, None, DeviceLists()
|
||||
)
|
||||
|
||||
def test_send_single_ephemeral_with_queue(self):
|
||||
@@ -345,13 +359,15 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
|
||||
# Send more events: expect send() to NOT be called multiple times.
|
||||
self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_2)
|
||||
self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_3)
|
||||
self.txn_ctrl.send.assert_called_with(service, [], event_list_1, [], None, None)
|
||||
self.txn_ctrl.send.assert_called_with(
|
||||
service, [], event_list_1, [], None, None, DeviceLists()
|
||||
)
|
||||
self.assertEqual(1, self.txn_ctrl.send.call_count)
|
||||
# Resolve txn_ctrl.send
|
||||
d.callback(service)
|
||||
# Expect the queued events to be sent
|
||||
self.txn_ctrl.send.assert_called_with(
|
||||
service, [], event_list_2 + event_list_3, [], None, None
|
||||
service, [], event_list_2 + event_list_3, [], None, None, DeviceLists()
|
||||
)
|
||||
self.assertEqual(2, self.txn_ctrl.send.call_count)
|
||||
|
||||
@@ -365,8 +381,10 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
|
||||
event_list = first_chunk + second_chunk
|
||||
self.scheduler.enqueue_for_appservice(service, ephemeral=event_list)
|
||||
self.txn_ctrl.send.assert_called_once_with(
|
||||
service, [], first_chunk, [], None, None
|
||||
service, [], first_chunk, [], None, None, DeviceLists()
|
||||
)
|
||||
d.callback(service)
|
||||
self.txn_ctrl.send.assert_called_with(service, [], second_chunk, [], None, None)
|
||||
self.txn_ctrl.send.assert_called_with(
|
||||
service, [], second_chunk, [], None, None, DeviceLists()
|
||||
)
|
||||
self.assertEqual(2, self.txn_ctrl.send.call_count)
|
||||
|
||||
@@ -16,6 +16,7 @@ from synapse.api.constants import EventContentFields
|
||||
from synapse.api.room_versions import RoomVersions
|
||||
from synapse.events import make_event_from_dict
|
||||
from synapse.events.utils import (
|
||||
SerializeEventConfig,
|
||||
copy_power_levels_contents,
|
||||
prune_event,
|
||||
serialize_event,
|
||||
@@ -392,7 +393,9 @@ class PruneEventTestCase(unittest.TestCase):
|
||||
|
||||
class SerializeEventTestCase(unittest.TestCase):
|
||||
def serialize(self, ev, fields):
|
||||
return serialize_event(ev, 1479807801915, only_event_fields=fields)
|
||||
return serialize_event(
|
||||
ev, 1479807801915, config=SerializeEventConfig(only_event_fields=fields)
|
||||
)
|
||||
|
||||
def test_event_fields_works_with_keys(self):
|
||||
self.assertEqual(
|
||||
|
||||
@@ -15,6 +15,8 @@
|
||||
from typing import Dict, Iterable, List, Optional
|
||||
from unittest.mock import Mock
|
||||
|
||||
from parameterized import parameterized
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
@@ -59,11 +61,11 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
||||
self.event_source = hs.get_event_sources()
|
||||
|
||||
def test_notify_interested_services(self):
|
||||
interested_service = self._mkservice(is_interested=True)
|
||||
interested_service = self._mkservice(is_interested_in_event=True)
|
||||
services = [
|
||||
self._mkservice(is_interested=False),
|
||||
self._mkservice(is_interested_in_event=False),
|
||||
interested_service,
|
||||
self._mkservice(is_interested=False),
|
||||
self._mkservice(is_interested_in_event=False),
|
||||
]
|
||||
|
||||
self.mock_as_api.query_user.return_value = make_awaitable(True)
|
||||
@@ -85,7 +87,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
||||
|
||||
def test_query_user_exists_unknown_user(self):
|
||||
user_id = "@someone:anywhere"
|
||||
services = [self._mkservice(is_interested=True)]
|
||||
services = [self._mkservice(is_interested_in_event=True)]
|
||||
services[0].is_interested_in_user.return_value = True
|
||||
self.mock_store.get_app_services.return_value = services
|
||||
self.mock_store.get_user_by_id.return_value = make_awaitable(None)
|
||||
@@ -102,7 +104,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
||||
|
||||
def test_query_user_exists_known_user(self):
|
||||
user_id = "@someone:anywhere"
|
||||
services = [self._mkservice(is_interested=True)]
|
||||
services = [self._mkservice(is_interested_in_event=True)]
|
||||
services[0].is_interested_in_user.return_value = True
|
||||
self.mock_store.get_app_services.return_value = services
|
||||
self.mock_store.get_user_by_id.return_value = make_awaitable({"name": user_id})
|
||||
@@ -127,11 +129,11 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
||||
|
||||
room_id = "!alpha:bet"
|
||||
servers = ["aperture"]
|
||||
interested_service = self._mkservice_alias(is_interested_in_alias=True)
|
||||
interested_service = self._mkservice_alias(is_room_alias_in_namespace=True)
|
||||
services = [
|
||||
self._mkservice_alias(is_interested_in_alias=False),
|
||||
self._mkservice_alias(is_room_alias_in_namespace=False),
|
||||
interested_service,
|
||||
self._mkservice_alias(is_interested_in_alias=False),
|
||||
self._mkservice_alias(is_room_alias_in_namespace=False),
|
||||
]
|
||||
|
||||
self.mock_as_api.query_alias.return_value = make_awaitable(True)
|
||||
@@ -275,7 +277,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
||||
to be pushed out to interested appservices, and that the stream ID is
|
||||
updated accordingly.
|
||||
"""
|
||||
interested_service = self._mkservice(is_interested=True)
|
||||
interested_service = self._mkservice(is_interested_in_event=True)
|
||||
services = [interested_service]
|
||||
self.mock_store.get_app_services.return_value = services
|
||||
self.mock_store.get_type_stream_id_for_appservice.return_value = make_awaitable(
|
||||
@@ -304,7 +306,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
||||
Test sending out of order ephemeral events to the appservice handler
|
||||
are ignored.
|
||||
"""
|
||||
interested_service = self._mkservice(is_interested=True)
|
||||
interested_service = self._mkservice(is_interested_in_event=True)
|
||||
services = [interested_service]
|
||||
|
||||
self.mock_store.get_app_services.return_value = services
|
||||
@@ -325,17 +327,45 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
||||
interested_service, ephemeral=[]
|
||||
)
|
||||
|
||||
def _mkservice(self, is_interested, protocols=None):
|
||||
def _mkservice(
|
||||
self, is_interested_in_event: bool, protocols: Optional[Iterable] = None
|
||||
) -> Mock:
|
||||
"""
|
||||
Create a new mock representing an ApplicationService.
|
||||
|
||||
Args:
|
||||
is_interested_in_event: Whether this application service will be considered
|
||||
interested in all events.
|
||||
protocols: The third-party protocols that this application service claims to
|
||||
support.
|
||||
|
||||
Returns:
|
||||
A mock representing the ApplicationService.
|
||||
"""
|
||||
service = Mock()
|
||||
service.is_interested.return_value = make_awaitable(is_interested)
|
||||
service.is_interested_in_event.return_value = make_awaitable(
|
||||
is_interested_in_event
|
||||
)
|
||||
service.token = "mock_service_token"
|
||||
service.url = "mock_service_url"
|
||||
service.protocols = protocols
|
||||
return service
|
||||
|
||||
def _mkservice_alias(self, is_interested_in_alias):
|
||||
def _mkservice_alias(self, is_room_alias_in_namespace: bool) -> Mock:
|
||||
"""
|
||||
Create a new mock representing an ApplicationService that is or is not interested
|
||||
any given room aliase.
|
||||
|
||||
Args:
|
||||
is_room_alias_in_namespace: If true, the application service will be interested
|
||||
in all room aliases that are queried against it. If false, the application
|
||||
service will not be interested in any room aliases.
|
||||
|
||||
Returns:
|
||||
A mock representing the ApplicationService.
|
||||
"""
|
||||
service = Mock()
|
||||
service.is_interested_in_alias.return_value = is_interested_in_alias
|
||||
service.is_room_alias_in_namespace.return_value = is_room_alias_in_namespace
|
||||
service.token = "mock_service_token"
|
||||
service.url = "mock_service_url"
|
||||
return service
|
||||
@@ -443,6 +473,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
|
||||
to_device_messages,
|
||||
_otks,
|
||||
_fbks,
|
||||
_device_list_summary,
|
||||
) = self.send_mock.call_args[0]
|
||||
|
||||
# Assert that this was the same to-device message that local_user sent
|
||||
@@ -555,7 +586,15 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
|
||||
service_id_to_message_count: Dict[str, int] = {}
|
||||
|
||||
for call in self.send_mock.call_args_list:
|
||||
service, _events, _ephemeral, to_device_messages, _otks, _fbks = call[0]
|
||||
(
|
||||
service,
|
||||
_events,
|
||||
_ephemeral,
|
||||
to_device_messages,
|
||||
_otks,
|
||||
_fbks,
|
||||
_device_list_summary,
|
||||
) = call[0]
|
||||
|
||||
# Check that this was made to an interested service
|
||||
self.assertIn(service, interested_appservices)
|
||||
@@ -599,6 +638,115 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
|
||||
return appservice
|
||||
|
||||
|
||||
class ApplicationServicesHandlerDeviceListsTestCase(unittest.HomeserverTestCase):
|
||||
"""
|
||||
Tests that the ApplicationServicesHandler sends device list updates to application
|
||||
services correctly.
|
||||
"""
|
||||
|
||||
servlets = [
|
||||
synapse.rest.admin.register_servlets_for_client_rest_resource,
|
||||
login.register_servlets,
|
||||
room.register_servlets,
|
||||
]
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
# Allow us to modify cached feature flags mid-test
|
||||
self.as_handler = hs.get_application_service_handler()
|
||||
|
||||
# Mock ApplicationServiceApi's put_json, so we can verify the raw JSON that
|
||||
# will be sent over the wire
|
||||
self.put_json = simple_async_mock()
|
||||
hs.get_application_service_api().put_json = self.put_json # type: ignore[assignment]
|
||||
|
||||
# Mock out application services, and allow defining our own in tests
|
||||
self._services: List[ApplicationService] = []
|
||||
self.hs.get_datastores().main.get_app_services = Mock(
|
||||
return_value=self._services
|
||||
)
|
||||
|
||||
# Test across a variety of configuration values
|
||||
@parameterized.expand(
|
||||
[
|
||||
(True, True, True),
|
||||
(True, False, False),
|
||||
(False, True, False),
|
||||
(False, False, False),
|
||||
]
|
||||
)
|
||||
@unittest.override_config({"experimental_features": {"": False}})
|
||||
def test_application_service_receives_device_list_updates(
|
||||
self,
|
||||
experimental_feature_enabled: bool,
|
||||
as_supports_txn_extensions: bool,
|
||||
as_should_receive_device_list_updates: bool,
|
||||
):
|
||||
"""
|
||||
Tests that an application service receives notice of changed device
|
||||
lists for a user, when a user changes their device lists.
|
||||
|
||||
Arguments above are populated by parameterized.
|
||||
|
||||
Args:
|
||||
as_should_receive_device_list_updates: Whether we expect the AS to receive the
|
||||
device list changes.
|
||||
experimental_feature_enabled: Whether the "msc3202_transaction_extensions" experimental
|
||||
feature is enabled. This feature must be enabled for device lists to ASs to work.
|
||||
as_supports_txn_extensions: Whether the application service has explicitly registered
|
||||
to receive information defined by MSC3202 - which includes device list changes.
|
||||
"""
|
||||
# Change whether the experimental feature is enabled or disabled before making
|
||||
# device list changes
|
||||
self.as_handler._msc3202_transaction_extensions_enabled = (
|
||||
experimental_feature_enabled
|
||||
)
|
||||
|
||||
# Create an appservice that is interested in "local_user"
|
||||
appservice = ApplicationService(
|
||||
token=random_string(10),
|
||||
hostname="example.com",
|
||||
id=random_string(10),
|
||||
sender="@as:example.com",
|
||||
rate_limited=False,
|
||||
namespaces={
|
||||
ApplicationService.NS_USERS: [
|
||||
{
|
||||
"regex": "@local_user:.+",
|
||||
"exclusive": False,
|
||||
}
|
||||
],
|
||||
},
|
||||
supports_ephemeral=True,
|
||||
msc3202_transaction_extensions=as_supports_txn_extensions,
|
||||
# Must be set for Synapse to try pushing data to the AS
|
||||
hs_token="abcde",
|
||||
url="some_url",
|
||||
)
|
||||
|
||||
# Register the application service
|
||||
self._services.append(appservice)
|
||||
|
||||
# Register a user on the homeserver
|
||||
self.local_user = self.register_user("local_user", "password")
|
||||
self.local_user_token = self.login("local_user", "password")
|
||||
|
||||
if as_should_receive_device_list_updates:
|
||||
# Ensure that the resulting JSON uses the unstable prefix and contains the
|
||||
# expected users
|
||||
self.put_json.assert_called_once()
|
||||
json_body = self.put_json.call_args.kwargs["json_body"]
|
||||
|
||||
# Our application service should have received a device list update with
|
||||
# "local_user" in the "changed" list
|
||||
device_list_dict = json_body.get("org.matrix.msc3202.device_lists", {})
|
||||
self.assertEqual([], device_list_dict["left"])
|
||||
self.assertEqual([self.local_user], device_list_dict["changed"])
|
||||
|
||||
else:
|
||||
# No device list changes should have been sent out
|
||||
self.put_json.assert_not_called()
|
||||
|
||||
|
||||
class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase):
|
||||
# Argument indices for pulling out arguments from a `send_mock`.
|
||||
ARG_OTK_COUNTS = 4
|
||||
|
||||
+155
-131
@@ -15,11 +15,12 @@ import json
|
||||
import os
|
||||
import re
|
||||
from email.parser import Parser
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pkg_resources
|
||||
|
||||
from twisted.internet.interfaces import IReactorTCP
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
import synapse.rest.admin
|
||||
@@ -30,6 +31,7 @@ from synapse.rest import admin
|
||||
from synapse.rest.client import account, login, register, room
|
||||
from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
@@ -46,7 +48,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
|
||||
login.register_servlets,
|
||||
]
|
||||
|
||||
def make_homeserver(self, reactor, clock):
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
config = self.default_config()
|
||||
|
||||
# Email config.
|
||||
@@ -67,20 +69,27 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
|
||||
hs = self.setup_test_homeserver(config=config)
|
||||
|
||||
async def sendmail(
|
||||
reactor, smtphost, smtpport, from_addr, to_addrs, msg, **kwargs
|
||||
):
|
||||
self.email_attempts.append(msg)
|
||||
reactor: IReactorTCP,
|
||||
smtphost: str,
|
||||
smtpport: int,
|
||||
from_addr: str,
|
||||
to_addr: str,
|
||||
msg_bytes: bytes,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.email_attempts.append(msg_bytes)
|
||||
|
||||
self.email_attempts = []
|
||||
self.email_attempts: List[bytes] = []
|
||||
hs.get_send_email_handler()._sendmail = sendmail
|
||||
|
||||
return hs
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.store = hs.get_datastores().main
|
||||
self.submit_token_resource = PasswordResetSubmitTokenResource(hs)
|
||||
|
||||
def test_basic_password_reset(self):
|
||||
def test_basic_password_reset(self) -> None:
|
||||
"""Test basic password reset flow"""
|
||||
old_password = "monkey"
|
||||
new_password = "kangeroo"
|
||||
@@ -118,7 +127,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
|
||||
self.attempt_wrong_password_login("kermit", old_password)
|
||||
|
||||
@override_config({"rc_3pid_validation": {"burst_count": 3}})
|
||||
def test_ratelimit_by_email(self):
|
||||
def test_ratelimit_by_email(self) -> None:
|
||||
"""Test that we ratelimit /requestToken for the same email."""
|
||||
old_password = "monkey"
|
||||
new_password = "kangeroo"
|
||||
@@ -139,7 +148,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
)
|
||||
|
||||
def reset(ip):
|
||||
def reset(ip: str) -> None:
|
||||
client_secret = "foobar"
|
||||
session_id = self._request_token(email, client_secret, ip)
|
||||
|
||||
@@ -166,7 +175,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
self.assertEqual(cm.exception.code, 429)
|
||||
|
||||
def test_basic_password_reset_canonicalise_email(self):
|
||||
def test_basic_password_reset_canonicalise_email(self) -> None:
|
||||
"""Test basic password reset flow
|
||||
Request password reset with different spelling
|
||||
"""
|
||||
@@ -206,7 +215,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
|
||||
# Assert we can't log in with the old password
|
||||
self.attempt_wrong_password_login("kermit", old_password)
|
||||
|
||||
def test_cant_reset_password_without_clicking_link(self):
|
||||
def test_cant_reset_password_without_clicking_link(self) -> None:
|
||||
"""Test that we do actually need to click the link in the email"""
|
||||
old_password = "monkey"
|
||||
new_password = "kangeroo"
|
||||
@@ -241,7 +250,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
|
||||
# Assert we can't log in with the new password
|
||||
self.attempt_wrong_password_login("kermit", new_password)
|
||||
|
||||
def test_no_valid_token(self):
|
||||
def test_no_valid_token(self) -> None:
|
||||
"""Test that we do actually need to request a token and can't just
|
||||
make a session up.
|
||||
"""
|
||||
@@ -277,7 +286,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
|
||||
self.attempt_wrong_password_login("kermit", new_password)
|
||||
|
||||
@unittest.override_config({"request_token_inhibit_3pid_errors": True})
|
||||
def test_password_reset_bad_email_inhibit_error(self):
|
||||
def test_password_reset_bad_email_inhibit_error(self) -> None:
|
||||
"""Test that triggering a password reset with an email address that isn't bound
|
||||
to an account doesn't leak the lack of binding for that address if configured
|
||||
that way.
|
||||
@@ -292,7 +301,12 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
self.assertIsNotNone(session_id)
|
||||
|
||||
def _request_token(self, email, client_secret, ip="127.0.0.1"):
|
||||
def _request_token(
|
||||
self,
|
||||
email: str,
|
||||
client_secret: str,
|
||||
ip: str = "127.0.0.1",
|
||||
) -> str:
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
b"account/password/email/requestToken",
|
||||
@@ -309,7 +323,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
return channel.json_body["sid"]
|
||||
|
||||
def _validate_token(self, link):
|
||||
def _validate_token(self, link: str) -> None:
|
||||
# Remove the host
|
||||
path = link.replace("https://example.com", "")
|
||||
|
||||
@@ -339,7 +353,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
self.assertEqual(200, channel.code, channel.result)
|
||||
|
||||
def _get_link_from_email(self):
|
||||
def _get_link_from_email(self) -> str:
|
||||
assert self.email_attempts, "No emails have been sent"
|
||||
|
||||
raw_msg = self.email_attempts[-1].decode("UTF-8")
|
||||
@@ -354,14 +368,19 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
|
||||
if not text:
|
||||
self.fail("Could not find text portion of email to parse")
|
||||
|
||||
assert text is not None
|
||||
match = re.search(r"https://example.com\S+", text)
|
||||
assert match, "Could not find link in email"
|
||||
|
||||
return match.group(0)
|
||||
|
||||
def _reset_password(
|
||||
self, new_password, session_id, client_secret, expected_code=200
|
||||
):
|
||||
self,
|
||||
new_password: str,
|
||||
session_id: str,
|
||||
client_secret: str,
|
||||
expected_code: int = 200,
|
||||
) -> None:
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
b"account/password",
|
||||
@@ -388,11 +407,11 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
|
||||
room.register_servlets,
|
||||
]
|
||||
|
||||
def make_homeserver(self, reactor, clock):
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
self.hs = self.setup_test_homeserver()
|
||||
return self.hs
|
||||
|
||||
def test_deactivate_account(self):
|
||||
def test_deactivate_account(self) -> None:
|
||||
user_id = self.register_user("kermit", "test")
|
||||
tok = self.login("kermit", "test")
|
||||
|
||||
@@ -407,7 +426,7 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
|
||||
channel = self.make_request("GET", "account/whoami", access_token=tok)
|
||||
self.assertEqual(channel.code, 401)
|
||||
|
||||
def test_pending_invites(self):
|
||||
def test_pending_invites(self) -> None:
|
||||
"""Tests that deactivating a user rejects every pending invite for them."""
|
||||
store = self.hs.get_datastores().main
|
||||
|
||||
@@ -448,7 +467,7 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEqual(len(memberships), 1, memberships)
|
||||
self.assertEqual(memberships[0].room_id, room_id, memberships)
|
||||
|
||||
def deactivate(self, user_id, tok):
|
||||
def deactivate(self, user_id: str, tok: str) -> None:
|
||||
request_data = json.dumps(
|
||||
{
|
||||
"auth": {
|
||||
@@ -474,12 +493,12 @@ class WhoamiTestCase(unittest.HomeserverTestCase):
|
||||
register.register_servlets,
|
||||
]
|
||||
|
||||
def default_config(self):
|
||||
def default_config(self) -> Dict[str, Any]:
|
||||
config = super().default_config()
|
||||
config["allow_guest_access"] = True
|
||||
return config
|
||||
|
||||
def test_GET_whoami(self):
|
||||
def test_GET_whoami(self) -> None:
|
||||
device_id = "wouldgohere"
|
||||
user_id = self.register_user("kermit", "test")
|
||||
tok = self.login("kermit", "test", device_id=device_id)
|
||||
@@ -496,7 +515,7 @@ class WhoamiTestCase(unittest.HomeserverTestCase):
|
||||
},
|
||||
)
|
||||
|
||||
def test_GET_whoami_guests(self):
|
||||
def test_GET_whoami_guests(self) -> None:
|
||||
channel = self.make_request(
|
||||
b"POST", b"/_matrix/client/r0/register?kind=guest", b"{}"
|
||||
)
|
||||
@@ -516,7 +535,7 @@ class WhoamiTestCase(unittest.HomeserverTestCase):
|
||||
},
|
||||
)
|
||||
|
||||
def test_GET_whoami_appservices(self):
|
||||
def test_GET_whoami_appservices(self) -> None:
|
||||
user_id = "@as:test"
|
||||
as_token = "i_am_an_app_service"
|
||||
|
||||
@@ -541,7 +560,7 @@ class WhoamiTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
self.assertFalse(hasattr(whoami, "device_id"))
|
||||
|
||||
def _whoami(self, tok):
|
||||
def _whoami(self, tok: str) -> JsonDict:
|
||||
channel = self.make_request("GET", "account/whoami", {}, access_token=tok)
|
||||
self.assertEqual(channel.code, 200)
|
||||
return channel.json_body
|
||||
@@ -555,7 +574,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
||||
synapse.rest.admin.register_servlets_for_client_rest_resource,
|
||||
]
|
||||
|
||||
def make_homeserver(self, reactor, clock):
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
config = self.default_config()
|
||||
|
||||
# Email config.
|
||||
@@ -576,16 +595,23 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
||||
self.hs = self.setup_test_homeserver(config=config)
|
||||
|
||||
async def sendmail(
|
||||
reactor, smtphost, smtpport, from_addr, to_addrs, msg, **kwargs
|
||||
):
|
||||
self.email_attempts.append(msg)
|
||||
reactor: IReactorTCP,
|
||||
smtphost: str,
|
||||
smtpport: int,
|
||||
from_addr: str,
|
||||
to_addr: str,
|
||||
msg_bytes: bytes,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.email_attempts.append(msg_bytes)
|
||||
|
||||
self.email_attempts = []
|
||||
self.email_attempts: List[bytes] = []
|
||||
self.hs.get_send_email_handler()._sendmail = sendmail
|
||||
|
||||
return self.hs
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.store = hs.get_datastores().main
|
||||
|
||||
self.user_id = self.register_user("kermit", "test")
|
||||
@@ -593,83 +619,73 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
||||
self.email = "test@example.com"
|
||||
self.url_3pid = b"account/3pid"
|
||||
|
||||
def test_add_valid_email(self):
|
||||
self.get_success(self._add_email(self.email, self.email))
|
||||
def test_add_valid_email(self) -> None:
|
||||
self._add_email(self.email, self.email)
|
||||
|
||||
def test_add_valid_email_second_time(self):
|
||||
self.get_success(self._add_email(self.email, self.email))
|
||||
self.get_success(
|
||||
self._request_token_invalid_email(
|
||||
self.email,
|
||||
expected_errcode=Codes.THREEPID_IN_USE,
|
||||
expected_error="Email is already in use",
|
||||
)
|
||||
def test_add_valid_email_second_time(self) -> None:
|
||||
self._add_email(self.email, self.email)
|
||||
self._request_token_invalid_email(
|
||||
self.email,
|
||||
expected_errcode=Codes.THREEPID_IN_USE,
|
||||
expected_error="Email is already in use",
|
||||
)
|
||||
|
||||
def test_add_valid_email_second_time_canonicalise(self):
|
||||
self.get_success(self._add_email(self.email, self.email))
|
||||
self.get_success(
|
||||
self._request_token_invalid_email(
|
||||
"TEST@EXAMPLE.COM",
|
||||
expected_errcode=Codes.THREEPID_IN_USE,
|
||||
expected_error="Email is already in use",
|
||||
)
|
||||
def test_add_valid_email_second_time_canonicalise(self) -> None:
|
||||
self._add_email(self.email, self.email)
|
||||
self._request_token_invalid_email(
|
||||
"TEST@EXAMPLE.COM",
|
||||
expected_errcode=Codes.THREEPID_IN_USE,
|
||||
expected_error="Email is already in use",
|
||||
)
|
||||
|
||||
def test_add_email_no_at(self):
|
||||
self.get_success(
|
||||
self._request_token_invalid_email(
|
||||
"address-without-at.bar",
|
||||
expected_errcode=Codes.UNKNOWN,
|
||||
expected_error="Unable to parse email address",
|
||||
)
|
||||
def test_add_email_no_at(self) -> None:
|
||||
self._request_token_invalid_email(
|
||||
"address-without-at.bar",
|
||||
expected_errcode=Codes.UNKNOWN,
|
||||
expected_error="Unable to parse email address",
|
||||
)
|
||||
|
||||
def test_add_email_two_at(self):
|
||||
self.get_success(
|
||||
self._request_token_invalid_email(
|
||||
"foo@foo@test.bar",
|
||||
expected_errcode=Codes.UNKNOWN,
|
||||
expected_error="Unable to parse email address",
|
||||
)
|
||||
def test_add_email_two_at(self) -> None:
|
||||
self._request_token_invalid_email(
|
||||
"foo@foo@test.bar",
|
||||
expected_errcode=Codes.UNKNOWN,
|
||||
expected_error="Unable to parse email address",
|
||||
)
|
||||
|
||||
def test_add_email_bad_format(self):
|
||||
self.get_success(
|
||||
self._request_token_invalid_email(
|
||||
"user@bad.example.net@good.example.com",
|
||||
expected_errcode=Codes.UNKNOWN,
|
||||
expected_error="Unable to parse email address",
|
||||
)
|
||||
def test_add_email_bad_format(self) -> None:
|
||||
self._request_token_invalid_email(
|
||||
"user@bad.example.net@good.example.com",
|
||||
expected_errcode=Codes.UNKNOWN,
|
||||
expected_error="Unable to parse email address",
|
||||
)
|
||||
|
||||
def test_add_email_domain_to_lower(self):
|
||||
self.get_success(self._add_email("foo@TEST.BAR", "foo@test.bar"))
|
||||
def test_add_email_domain_to_lower(self) -> None:
|
||||
self._add_email("foo@TEST.BAR", "foo@test.bar")
|
||||
|
||||
def test_add_email_domain_with_umlaut(self):
|
||||
self.get_success(self._add_email("foo@Öumlaut.com", "foo@öumlaut.com"))
|
||||
def test_add_email_domain_with_umlaut(self) -> None:
|
||||
self._add_email("foo@Öumlaut.com", "foo@öumlaut.com")
|
||||
|
||||
def test_add_email_address_casefold(self):
|
||||
self.get_success(self._add_email("Strauß@Example.com", "strauss@example.com"))
|
||||
def test_add_email_address_casefold(self) -> None:
|
||||
self._add_email("Strauß@Example.com", "strauss@example.com")
|
||||
|
||||
def test_address_trim(self):
|
||||
self.get_success(self._add_email(" foo@test.bar ", "foo@test.bar"))
|
||||
def test_address_trim(self) -> None:
|
||||
self._add_email(" foo@test.bar ", "foo@test.bar")
|
||||
|
||||
@override_config({"rc_3pid_validation": {"burst_count": 3}})
|
||||
def test_ratelimit_by_ip(self):
|
||||
def test_ratelimit_by_ip(self) -> None:
|
||||
"""Tests that adding emails is ratelimited by IP"""
|
||||
|
||||
# We expect to be able to set three emails before getting ratelimited.
|
||||
self.get_success(self._add_email("foo1@test.bar", "foo1@test.bar"))
|
||||
self.get_success(self._add_email("foo2@test.bar", "foo2@test.bar"))
|
||||
self.get_success(self._add_email("foo3@test.bar", "foo3@test.bar"))
|
||||
self._add_email("foo1@test.bar", "foo1@test.bar")
|
||||
self._add_email("foo2@test.bar", "foo2@test.bar")
|
||||
self._add_email("foo3@test.bar", "foo3@test.bar")
|
||||
|
||||
with self.assertRaises(HttpResponseException) as cm:
|
||||
self.get_success(self._add_email("foo4@test.bar", "foo4@test.bar"))
|
||||
self._add_email("foo4@test.bar", "foo4@test.bar")
|
||||
|
||||
self.assertEqual(cm.exception.code, 429)
|
||||
|
||||
def test_add_email_if_disabled(self):
|
||||
def test_add_email_if_disabled(self) -> None:
|
||||
"""Test adding email to profile when doing so is disallowed"""
|
||||
self.hs.config.registration.enable_3pid_changes = False
|
||||
|
||||
@@ -695,7 +711,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
||||
},
|
||||
access_token=self.user_id_tok,
|
||||
)
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(400, channel.code, msg=channel.result["body"])
|
||||
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
|
||||
|
||||
# Get user
|
||||
@@ -705,10 +721,10 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
||||
access_token=self.user_id_tok,
|
||||
)
|
||||
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(200, channel.code, msg=channel.result["body"])
|
||||
self.assertFalse(channel.json_body["threepids"])
|
||||
|
||||
def test_delete_email(self):
|
||||
def test_delete_email(self) -> None:
|
||||
"""Test deleting an email from profile"""
|
||||
# Add a threepid
|
||||
self.get_success(
|
||||
@@ -727,7 +743,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
||||
{"medium": "email", "address": self.email},
|
||||
access_token=self.user_id_tok,
|
||||
)
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(200, channel.code, msg=channel.result["body"])
|
||||
|
||||
# Get user
|
||||
channel = self.make_request(
|
||||
@@ -736,10 +752,10 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
||||
access_token=self.user_id_tok,
|
||||
)
|
||||
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(200, channel.code, msg=channel.result["body"])
|
||||
self.assertFalse(channel.json_body["threepids"])
|
||||
|
||||
def test_delete_email_if_disabled(self):
|
||||
def test_delete_email_if_disabled(self) -> None:
|
||||
"""Test deleting an email from profile when disallowed"""
|
||||
self.hs.config.registration.enable_3pid_changes = False
|
||||
|
||||
@@ -761,7 +777,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
||||
access_token=self.user_id_tok,
|
||||
)
|
||||
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(400, channel.code, msg=channel.result["body"])
|
||||
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
|
||||
|
||||
# Get user
|
||||
@@ -771,11 +787,11 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
||||
access_token=self.user_id_tok,
|
||||
)
|
||||
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(200, channel.code, msg=channel.result["body"])
|
||||
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
|
||||
self.assertEqual(self.email, channel.json_body["threepids"][0]["address"])
|
||||
|
||||
def test_cant_add_email_without_clicking_link(self):
|
||||
def test_cant_add_email_without_clicking_link(self) -> None:
|
||||
"""Test that we do actually need to click the link in the email"""
|
||||
client_secret = "foobar"
|
||||
session_id = self._request_token(self.email, client_secret)
|
||||
@@ -797,7 +813,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
||||
},
|
||||
access_token=self.user_id_tok,
|
||||
)
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(400, channel.code, msg=channel.result["body"])
|
||||
self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"])
|
||||
|
||||
# Get user
|
||||
@@ -807,10 +823,10 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
||||
access_token=self.user_id_tok,
|
||||
)
|
||||
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(200, channel.code, msg=channel.result["body"])
|
||||
self.assertFalse(channel.json_body["threepids"])
|
||||
|
||||
def test_no_valid_token(self):
|
||||
def test_no_valid_token(self) -> None:
|
||||
"""Test that we do actually need to request a token and can't just
|
||||
make a session up.
|
||||
"""
|
||||
@@ -832,7 +848,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
||||
},
|
||||
access_token=self.user_id_tok,
|
||||
)
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(400, channel.code, msg=channel.result["body"])
|
||||
self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"])
|
||||
|
||||
# Get user
|
||||
@@ -842,11 +858,11 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
||||
access_token=self.user_id_tok,
|
||||
)
|
||||
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(200, channel.code, msg=channel.result["body"])
|
||||
self.assertFalse(channel.json_body["threepids"])
|
||||
|
||||
@override_config({"next_link_domain_whitelist": None})
|
||||
def test_next_link(self):
|
||||
def test_next_link(self) -> None:
|
||||
"""Tests a valid next_link parameter value with no whitelist (good case)"""
|
||||
self._request_token(
|
||||
"something@example.com",
|
||||
@@ -856,7 +872,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
@override_config({"next_link_domain_whitelist": None})
|
||||
def test_next_link_exotic_protocol(self):
|
||||
def test_next_link_exotic_protocol(self) -> None:
|
||||
"""Tests using a esoteric protocol as a next_link parameter value.
|
||||
Someone may be hosting a client on IPFS etc.
|
||||
"""
|
||||
@@ -868,7 +884,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
@override_config({"next_link_domain_whitelist": None})
|
||||
def test_next_link_file_uri(self):
|
||||
def test_next_link_file_uri(self) -> None:
|
||||
"""Tests next_link parameters cannot be file URI"""
|
||||
# Attempt to use a next_link value that points to the local disk
|
||||
self._request_token(
|
||||
@@ -879,7 +895,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
@override_config({"next_link_domain_whitelist": ["example.com", "example.org"]})
|
||||
def test_next_link_domain_whitelist(self):
|
||||
def test_next_link_domain_whitelist(self) -> None:
|
||||
"""Tests next_link parameters must fit the whitelist if provided"""
|
||||
|
||||
# Ensure not providing a next_link parameter still works
|
||||
@@ -912,7 +928,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
@override_config({"next_link_domain_whitelist": []})
|
||||
def test_empty_next_link_domain_whitelist(self):
|
||||
def test_empty_next_link_domain_whitelist(self) -> None:
|
||||
"""Tests an empty next_lint_domain_whitelist value, meaning next_link is essentially
|
||||
disallowed
|
||||
"""
|
||||
@@ -962,28 +978,28 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
def _request_token_invalid_email(
|
||||
self,
|
||||
email,
|
||||
expected_errcode,
|
||||
expected_error,
|
||||
client_secret="foobar",
|
||||
):
|
||||
email: str,
|
||||
expected_errcode: str,
|
||||
expected_error: str,
|
||||
client_secret: str = "foobar",
|
||||
) -> None:
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
b"account/3pid/email/requestToken",
|
||||
{"client_secret": client_secret, "email": email, "send_attempt": 1},
|
||||
)
|
||||
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(400, channel.code, msg=channel.result["body"])
|
||||
self.assertEqual(expected_errcode, channel.json_body["errcode"])
|
||||
self.assertEqual(expected_error, channel.json_body["error"])
|
||||
|
||||
def _validate_token(self, link):
|
||||
def _validate_token(self, link: str) -> None:
|
||||
# Remove the host
|
||||
path = link.replace("https://example.com", "")
|
||||
|
||||
channel = self.make_request("GET", path, shorthand=False)
|
||||
self.assertEqual(200, channel.code, channel.result)
|
||||
|
||||
def _get_link_from_email(self):
|
||||
def _get_link_from_email(self) -> str:
|
||||
assert self.email_attempts, "No emails have been sent"
|
||||
|
||||
raw_msg = self.email_attempts[-1].decode("UTF-8")
|
||||
@@ -998,12 +1014,13 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
||||
if not text:
|
||||
self.fail("Could not find text portion of email to parse")
|
||||
|
||||
assert text is not None
|
||||
match = re.search(r"https://example.com\S+", text)
|
||||
assert match, "Could not find link in email"
|
||||
|
||||
return match.group(0)
|
||||
|
||||
def _add_email(self, request_email, expected_email):
|
||||
def _add_email(self, request_email: str, expected_email: str) -> None:
|
||||
"""Test adding an email to profile"""
|
||||
previous_email_attempts = len(self.email_attempts)
|
||||
|
||||
@@ -1030,7 +1047,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
||||
access_token=self.user_id_tok,
|
||||
)
|
||||
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(200, channel.code, msg=channel.result["body"])
|
||||
|
||||
# Get user
|
||||
channel = self.make_request(
|
||||
@@ -1039,7 +1056,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
|
||||
access_token=self.user_id_tok,
|
||||
)
|
||||
|
||||
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||
self.assertEqual(200, channel.code, msg=channel.result["body"])
|
||||
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
|
||||
|
||||
threepids = {threepid["address"] for threepid in channel.json_body["threepids"]}
|
||||
@@ -1055,18 +1072,18 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
url = "/_matrix/client/unstable/org.matrix.msc3720/account_status"
|
||||
|
||||
def make_homeserver(self, reactor, clock):
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
config = self.default_config()
|
||||
config["experimental_features"] = {"msc3720_enabled": True}
|
||||
|
||||
return self.setup_test_homeserver(config=config)
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.requester = self.register_user("requester", "password")
|
||||
self.requester_tok = self.login("requester", "password")
|
||||
self.server_name = homeserver.config.server.server_name
|
||||
self.server_name = hs.config.server.server_name
|
||||
|
||||
def test_missing_mxid(self):
|
||||
def test_missing_mxid(self) -> None:
|
||||
"""Tests that not providing any MXID raises an error."""
|
||||
self._test_status(
|
||||
users=None,
|
||||
@@ -1074,7 +1091,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
|
||||
expected_errcode=Codes.MISSING_PARAM,
|
||||
)
|
||||
|
||||
def test_invalid_mxid(self):
|
||||
def test_invalid_mxid(self) -> None:
|
||||
"""Tests that providing an invalid MXID raises an error."""
|
||||
self._test_status(
|
||||
users=["bad:test"],
|
||||
@@ -1082,7 +1099,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
|
||||
expected_errcode=Codes.INVALID_PARAM,
|
||||
)
|
||||
|
||||
def test_local_user_not_exists(self):
|
||||
def test_local_user_not_exists(self) -> None:
|
||||
"""Tests that the account status endpoints correctly reports that a user doesn't
|
||||
exist.
|
||||
"""
|
||||
@@ -1098,7 +1115,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
|
||||
expected_failures=[],
|
||||
)
|
||||
|
||||
def test_local_user_exists(self):
|
||||
def test_local_user_exists(self) -> None:
|
||||
"""Tests that the account status endpoint correctly reports that a user doesn't
|
||||
exist.
|
||||
"""
|
||||
@@ -1115,7 +1132,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
|
||||
expected_failures=[],
|
||||
)
|
||||
|
||||
def test_local_user_deactivated(self):
|
||||
def test_local_user_deactivated(self) -> None:
|
||||
"""Tests that the account status endpoint correctly reports a deactivated user."""
|
||||
user = self.register_user("someuser", "password")
|
||||
self.get_success(
|
||||
@@ -1135,7 +1152,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
|
||||
expected_failures=[],
|
||||
)
|
||||
|
||||
def test_mixed_local_and_remote_users(self):
|
||||
def test_mixed_local_and_remote_users(self) -> None:
|
||||
"""Tests that if some users are remote the account status endpoint correctly
|
||||
merges the remote responses with the local result.
|
||||
"""
|
||||
@@ -1150,7 +1167,13 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
|
||||
"@bad:badremote",
|
||||
]
|
||||
|
||||
async def post_json(destination, path, data, *a, **kwa):
|
||||
async def post_json(
|
||||
destination: str,
|
||||
path: str,
|
||||
data: Optional[JsonDict] = None,
|
||||
*a: Any,
|
||||
**kwa: Any,
|
||||
) -> Union[JsonDict, list]:
|
||||
if destination == "remote":
|
||||
return {
|
||||
"account_statuses": {
|
||||
@@ -1160,9 +1183,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
|
||||
},
|
||||
}
|
||||
}
|
||||
if destination == "otherremote":
|
||||
return {}
|
||||
if destination == "badremote":
|
||||
elif destination == "badremote":
|
||||
# badremote tries to overwrite the status of a user that doesn't belong
|
||||
# to it (i.e. users[1]) with false data, which Synapse is expected to
|
||||
# ignore.
|
||||
@@ -1176,6 +1197,9 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
|
||||
},
|
||||
}
|
||||
}
|
||||
# if destination == "otherremote"
|
||||
else:
|
||||
return {}
|
||||
|
||||
# Register a mock that will return the expected result depending on the remote.
|
||||
self.hs.get_federation_http_client().post_json = Mock(side_effect=post_json)
|
||||
@@ -1205,7 +1229,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
|
||||
expected_statuses: Optional[Dict[str, Dict[str, bool]]] = None,
|
||||
expected_failures: Optional[List[str]] = None,
|
||||
expected_errcode: Optional[str] = None,
|
||||
):
|
||||
) -> None:
|
||||
"""Send a request to the account status endpoint and check that the response
|
||||
matches with what's expected.
|
||||
|
||||
|
||||
@@ -12,10 +12,12 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
from synapse.api.errors import Codes
|
||||
from synapse.rest.client import filter
|
||||
from synapse.server import HomeServer
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
|
||||
@@ -30,11 +32,11 @@ class FilterTestCase(unittest.HomeserverTestCase):
|
||||
EXAMPLE_FILTER_JSON = b'{"room": {"timeline": {"types": ["m.room.message"]}}}'
|
||||
servlets = [filter.register_servlets]
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.filtering = hs.get_filtering()
|
||||
self.store = hs.get_datastores().main
|
||||
|
||||
def test_add_filter(self):
|
||||
def test_add_filter(self) -> None:
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
"/_matrix/client/r0/user/%s/filter" % (self.user_id),
|
||||
@@ -43,11 +45,13 @@ class FilterTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
self.assertEqual(channel.result["code"], b"200")
|
||||
self.assertEqual(channel.json_body, {"filter_id": "0"})
|
||||
filter = self.store.get_user_filter(user_localpart="apple", filter_id=0)
|
||||
filter = self.get_success(
|
||||
self.store.get_user_filter(user_localpart="apple", filter_id=0)
|
||||
)
|
||||
self.pump()
|
||||
self.assertEqual(filter.result, self.EXAMPLE_FILTER)
|
||||
self.assertEqual(filter, self.EXAMPLE_FILTER)
|
||||
|
||||
def test_add_filter_for_other_user(self):
|
||||
def test_add_filter_for_other_user(self) -> None:
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
"/_matrix/client/r0/user/%s/filter" % ("@watermelon:test"),
|
||||
@@ -57,7 +61,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEqual(channel.result["code"], b"403")
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN)
|
||||
|
||||
def test_add_filter_non_local_user(self):
|
||||
def test_add_filter_non_local_user(self) -> None:
|
||||
_is_mine = self.hs.is_mine
|
||||
self.hs.is_mine = lambda target_user: False
|
||||
channel = self.make_request(
|
||||
@@ -70,14 +74,13 @@ class FilterTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEqual(channel.result["code"], b"403")
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN)
|
||||
|
||||
def test_get_filter(self):
|
||||
filter_id = defer.ensureDeferred(
|
||||
def test_get_filter(self) -> None:
|
||||
filter_id = self.get_success(
|
||||
self.filtering.add_user_filter(
|
||||
user_localpart="apple", user_filter=self.EXAMPLE_FILTER
|
||||
)
|
||||
)
|
||||
self.reactor.advance(1)
|
||||
filter_id = filter_id.result
|
||||
channel = self.make_request(
|
||||
"GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.user_id, filter_id)
|
||||
)
|
||||
@@ -85,7 +88,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEqual(channel.result["code"], b"200")
|
||||
self.assertEqual(channel.json_body, self.EXAMPLE_FILTER)
|
||||
|
||||
def test_get_filter_non_existant(self):
|
||||
def test_get_filter_non_existant(self) -> None:
|
||||
channel = self.make_request(
|
||||
"GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.user_id)
|
||||
)
|
||||
@@ -95,7 +98,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
# Currently invalid params do not have an appropriate errcode
|
||||
# in errors.py
|
||||
def test_get_filter_invalid_id(self):
|
||||
def test_get_filter_invalid_id(self) -> None:
|
||||
channel = self.make_request(
|
||||
"GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.user_id)
|
||||
)
|
||||
@@ -103,7 +106,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEqual(channel.result["code"], b"400")
|
||||
|
||||
# No ID also returns an invalid_id error
|
||||
def test_get_filter_no_id(self):
|
||||
def test_get_filter_no_id(self) -> None:
|
||||
channel = self.make_request(
|
||||
"GET", "/_matrix/client/r0/user/%s/filter/" % (self.user_id)
|
||||
)
|
||||
|
||||
+180
-212
@@ -15,7 +15,7 @@
|
||||
|
||||
import itertools
|
||||
import urllib.parse
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from unittest.mock import patch
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
@@ -34,7 +34,7 @@ from tests.test_utils import make_awaitable
|
||||
from tests.test_utils.event_injection import inject_event
|
||||
|
||||
|
||||
class RelationsTestCase(unittest.HomeserverTestCase):
|
||||
class BaseRelationsTestCase(unittest.HomeserverTestCase):
|
||||
servlets = [
|
||||
relations.register_servlets,
|
||||
room.register_servlets,
|
||||
@@ -45,10 +45,9 @@ class RelationsTestCase(unittest.HomeserverTestCase):
|
||||
]
|
||||
hijack_auth = False
|
||||
|
||||
def default_config(self) -> dict:
|
||||
def default_config(self) -> Dict[str, Any]:
|
||||
# We need to enable msc1849 support for aggregations
|
||||
config = super().default_config()
|
||||
config["experimental_msc1849_support_enabled"] = True
|
||||
|
||||
# We enable frozen dicts as relations/edits change event contents, so we
|
||||
# want to test that we don't modify the events in the caches.
|
||||
@@ -67,10 +66,62 @@ class RelationsTestCase(unittest.HomeserverTestCase):
|
||||
res = self.helper.send(self.room, body="Hi!", tok=self.user_token)
|
||||
self.parent_id = res["event_id"]
|
||||
|
||||
def test_send_relation(self) -> None:
|
||||
"""Tests that sending a relation using the new /send_relation works
|
||||
creates the right shape of event.
|
||||
def _create_user(self, localpart: str) -> Tuple[str, str]:
|
||||
user_id = self.register_user(localpart, "abc123")
|
||||
access_token = self.login(localpart, "abc123")
|
||||
|
||||
return user_id, access_token
|
||||
|
||||
def _send_relation(
|
||||
self,
|
||||
relation_type: str,
|
||||
event_type: str,
|
||||
key: Optional[str] = None,
|
||||
content: Optional[dict] = None,
|
||||
access_token: Optional[str] = None,
|
||||
parent_id: Optional[str] = None,
|
||||
) -> FakeChannel:
|
||||
"""Helper function to send a relation pointing at `self.parent_id`
|
||||
|
||||
Args:
|
||||
relation_type: One of `RelationTypes`
|
||||
event_type: The type of the event to create
|
||||
key: The aggregation key used for m.annotation relation type.
|
||||
content: The content of the created event. Will be modified to configure
|
||||
the m.relates_to key based on the other provided parameters.
|
||||
access_token: The access token used to send the relation, defaults
|
||||
to `self.user_token`
|
||||
parent_id: The event_id this relation relates to. If None, then self.parent_id
|
||||
|
||||
Returns:
|
||||
FakeChannel
|
||||
"""
|
||||
if not access_token:
|
||||
access_token = self.user_token
|
||||
|
||||
original_id = parent_id if parent_id else self.parent_id
|
||||
|
||||
if content is None:
|
||||
content = {}
|
||||
content["m.relates_to"] = {
|
||||
"event_id": original_id,
|
||||
"rel_type": relation_type,
|
||||
}
|
||||
if key is not None:
|
||||
content["m.relates_to"]["key"] = key
|
||||
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
f"/_matrix/client/v3/rooms/{self.room}/send/{event_type}",
|
||||
content,
|
||||
access_token=access_token,
|
||||
)
|
||||
return channel
|
||||
|
||||
|
||||
class RelationsTestCase(BaseRelationsTestCase):
|
||||
def test_send_relation(self) -> None:
|
||||
"""Tests that sending a relation works."""
|
||||
|
||||
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
|
||||
self.assertEqual(200, channel.code, channel.json_body)
|
||||
@@ -79,7 +130,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
"/rooms/%s/event/%s" % (self.room, event_id),
|
||||
f"/rooms/{self.room}/event/{event_id}",
|
||||
access_token=self.user_token,
|
||||
)
|
||||
self.assertEqual(200, channel.code, channel.json_body)
|
||||
@@ -317,9 +368,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
# Request /sync, limiting it such that only the latest event is returned
|
||||
# (and not the relation).
|
||||
filter = urllib.parse.quote_plus(
|
||||
'{"room": {"timeline": {"limit": 1}}}'.encode()
|
||||
)
|
||||
filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 1}}}')
|
||||
channel = self.make_request(
|
||||
"GET", f"/sync?filter={filter}", access_token=self.user_token
|
||||
)
|
||||
@@ -404,8 +453,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
"/_matrix/client/unstable/rooms/%s/aggregations/%s?limit=1%s"
|
||||
% (self.room, self.parent_id, from_token),
|
||||
f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}?limit=1{from_token}",
|
||||
access_token=self.user_token,
|
||||
)
|
||||
self.assertEqual(200, channel.code, channel.json_body)
|
||||
@@ -544,8 +592,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
"/_matrix/client/unstable/rooms/%s/aggregations/%s"
|
||||
% (self.room, self.parent_id),
|
||||
f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}",
|
||||
access_token=self.user_token,
|
||||
)
|
||||
self.assertEqual(200, channel.code, channel.json_body)
|
||||
@@ -560,47 +607,13 @@ class RelationsTestCase(unittest.HomeserverTestCase):
|
||||
},
|
||||
)
|
||||
|
||||
def test_aggregation_redactions(self) -> None:
|
||||
"""Test that annotations get correctly aggregated after a redaction."""
|
||||
|
||||
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
|
||||
self.assertEqual(200, channel.code, channel.json_body)
|
||||
to_redact_event_id = channel.json_body["event_id"]
|
||||
|
||||
channel = self._send_relation(
|
||||
RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
|
||||
)
|
||||
self.assertEqual(200, channel.code, channel.json_body)
|
||||
|
||||
# Now lets redact one of the 'a' reactions
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
"/_matrix/client/r0/rooms/%s/redact/%s" % (self.room, to_redact_event_id),
|
||||
access_token=self.user_token,
|
||||
content={},
|
||||
)
|
||||
self.assertEqual(200, channel.code, channel.json_body)
|
||||
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
"/_matrix/client/unstable/rooms/%s/aggregations/%s"
|
||||
% (self.room, self.parent_id),
|
||||
access_token=self.user_token,
|
||||
)
|
||||
self.assertEqual(200, channel.code, channel.json_body)
|
||||
|
||||
self.assertEqual(
|
||||
channel.json_body,
|
||||
{"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]},
|
||||
)
|
||||
|
||||
def test_aggregation_must_be_annotation(self) -> None:
|
||||
"""Test that aggregations must be annotations."""
|
||||
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
"/_matrix/client/unstable/rooms/%s/aggregations/%s/%s?limit=1"
|
||||
% (self.room, self.parent_id, RelationTypes.REPLACE),
|
||||
f"/_matrix/client/unstable/rooms/{self.room}/aggregations"
|
||||
f"/{self.parent_id}/{RelationTypes.REPLACE}?limit=1",
|
||||
access_token=self.user_token,
|
||||
)
|
||||
self.assertEqual(400, channel.code, channel.json_body)
|
||||
@@ -691,10 +704,8 @@ class RelationsTestCase(unittest.HomeserverTestCase):
|
||||
}
|
||||
},
|
||||
"event_id": thread_2,
|
||||
"room_id": self.room,
|
||||
"sender": self.user_id,
|
||||
"type": "m.room.test",
|
||||
"user_id": self.user_id,
|
||||
},
|
||||
relations_dict[RelationTypes.THREAD].get("latest_event"),
|
||||
)
|
||||
@@ -986,9 +997,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
# Request sync, but limit the timeline so it becomes limited (and includes
|
||||
# bundled aggregations).
|
||||
filter = urllib.parse.quote_plus(
|
||||
'{"room": {"timeline": {"limit": 2}}}'.encode()
|
||||
)
|
||||
filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 2}}}')
|
||||
channel = self.make_request(
|
||||
"GET", f"/sync?filter={filter}", access_token=self.user_token
|
||||
)
|
||||
@@ -1053,7 +1062,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
"/rooms/%s/event/%s" % (self.room, self.parent_id),
|
||||
f"/rooms/{self.room}/event/{self.parent_id}",
|
||||
access_token=self.user_token,
|
||||
)
|
||||
self.assertEqual(200, channel.code, channel.json_body)
|
||||
@@ -1096,7 +1105,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
"/rooms/%s/event/%s" % (self.room, reply),
|
||||
f"/rooms/{self.room}/event/{reply}",
|
||||
access_token=self.user_token,
|
||||
)
|
||||
self.assertEqual(200, channel.code, channel.json_body)
|
||||
@@ -1198,7 +1207,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
|
||||
# Request the original event.
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
"/rooms/%s/event/%s" % (self.room, self.parent_id),
|
||||
f"/rooms/{self.room}/event/{self.parent_id}",
|
||||
access_token=self.user_token,
|
||||
)
|
||||
self.assertEqual(200, channel.code, channel.json_body)
|
||||
@@ -1217,102 +1226,6 @@ class RelationsTestCase(unittest.HomeserverTestCase):
|
||||
{"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict
|
||||
)
|
||||
|
||||
def test_relations_redaction_redacts_edits(self) -> None:
|
||||
"""Test that edits of an event are redacted when the original event
|
||||
is redacted.
|
||||
"""
|
||||
# Send a new event
|
||||
res = self.helper.send(self.room, body="Heyo!", tok=self.user_token)
|
||||
original_event_id = res["event_id"]
|
||||
|
||||
# Add a relation
|
||||
channel = self._send_relation(
|
||||
RelationTypes.REPLACE,
|
||||
"m.room.message",
|
||||
parent_id=original_event_id,
|
||||
content={
|
||||
"msgtype": "m.text",
|
||||
"body": "Wibble",
|
||||
"m.new_content": {"msgtype": "m.text", "body": "First edit"},
|
||||
},
|
||||
)
|
||||
self.assertEqual(200, channel.code, channel.json_body)
|
||||
|
||||
# Check the relation is returned
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
"/_matrix/client/unstable/rooms/%s/relations/%s/m.replace/m.room.message"
|
||||
% (self.room, original_event_id),
|
||||
access_token=self.user_token,
|
||||
)
|
||||
self.assertEqual(200, channel.code, channel.json_body)
|
||||
|
||||
self.assertIn("chunk", channel.json_body)
|
||||
self.assertEqual(len(channel.json_body["chunk"]), 1)
|
||||
|
||||
# Redact the original event
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
"/rooms/%s/redact/%s/%s"
|
||||
% (self.room, original_event_id, "test_relations_redaction_redacts_edits"),
|
||||
access_token=self.user_token,
|
||||
content="{}",
|
||||
)
|
||||
self.assertEqual(200, channel.code, channel.json_body)
|
||||
|
||||
# Try to check for remaining m.replace relations
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
"/_matrix/client/unstable/rooms/%s/relations/%s/m.replace/m.room.message"
|
||||
% (self.room, original_event_id),
|
||||
access_token=self.user_token,
|
||||
)
|
||||
self.assertEqual(200, channel.code, channel.json_body)
|
||||
|
||||
# Check that no relations are returned
|
||||
self.assertIn("chunk", channel.json_body)
|
||||
self.assertEqual(channel.json_body["chunk"], [])
|
||||
|
||||
def test_aggregations_redaction_prevents_access_to_aggregations(self) -> None:
|
||||
"""Test that annotations of an event are redacted when the original event
|
||||
is redacted.
|
||||
"""
|
||||
# Send a new event
|
||||
res = self.helper.send(self.room, body="Hello!", tok=self.user_token)
|
||||
original_event_id = res["event_id"]
|
||||
|
||||
# Add a relation
|
||||
channel = self._send_relation(
|
||||
RelationTypes.ANNOTATION, "m.reaction", key="👍", parent_id=original_event_id
|
||||
)
|
||||
self.assertEqual(200, channel.code, channel.json_body)
|
||||
|
||||
# Redact the original
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
"/rooms/%s/redact/%s/%s"
|
||||
% (
|
||||
self.room,
|
||||
original_event_id,
|
||||
"test_aggregations_redaction_prevents_access_to_aggregations",
|
||||
),
|
||||
access_token=self.user_token,
|
||||
content="{}",
|
||||
)
|
||||
self.assertEqual(200, channel.code, channel.json_body)
|
||||
|
||||
# Check that aggregations returns zero
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
"/_matrix/client/unstable/rooms/%s/aggregations/%s/m.annotation/m.reaction"
|
||||
% (self.room, original_event_id),
|
||||
access_token=self.user_token,
|
||||
)
|
||||
self.assertEqual(200, channel.code, channel.json_body)
|
||||
|
||||
self.assertIn("chunk", channel.json_body)
|
||||
self.assertEqual(channel.json_body["chunk"], [])
|
||||
|
||||
def test_unknown_relations(self) -> None:
|
||||
"""Unknown relations should be accepted."""
|
||||
channel = self._send_relation("m.relation.test", "m.room.test")
|
||||
@@ -1321,8 +1234,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
"/_matrix/client/unstable/rooms/%s/relations/%s?limit=1"
|
||||
% (self.room, self.parent_id),
|
||||
f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1",
|
||||
access_token=self.user_token,
|
||||
)
|
||||
self.assertEqual(200, channel.code, channel.json_body)
|
||||
@@ -1343,7 +1255,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
|
||||
# When bundling the unknown relation is not included.
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
"/rooms/%s/event/%s" % (self.room, self.parent_id),
|
||||
f"/rooms/{self.room}/event/{self.parent_id}",
|
||||
access_token=self.user_token,
|
||||
)
|
||||
self.assertEqual(200, channel.code, channel.json_body)
|
||||
@@ -1352,8 +1264,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
|
||||
# But unknown relations can be directly queried.
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
"/_matrix/client/unstable/rooms/%s/aggregations/%s?limit=1"
|
||||
% (self.room, self.parent_id),
|
||||
f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}?limit=1",
|
||||
access_token=self.user_token,
|
||||
)
|
||||
self.assertEqual(200, channel.code, channel.json_body)
|
||||
@@ -1369,58 +1280,6 @@ class RelationsTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
raise AssertionError(f"Event {self.parent_id} not found in chunk")
|
||||
|
||||
def _send_relation(
|
||||
self,
|
||||
relation_type: str,
|
||||
event_type: str,
|
||||
key: Optional[str] = None,
|
||||
content: Optional[dict] = None,
|
||||
access_token: Optional[str] = None,
|
||||
parent_id: Optional[str] = None,
|
||||
) -> FakeChannel:
|
||||
"""Helper function to send a relation pointing at `self.parent_id`
|
||||
|
||||
Args:
|
||||
relation_type: One of `RelationTypes`
|
||||
event_type: The type of the event to create
|
||||
key: The aggregation key used for m.annotation relation type.
|
||||
content: The content of the created event. Will be modified to configure
|
||||
the m.relates_to key based on the other provided parameters.
|
||||
access_token: The access token used to send the relation, defaults
|
||||
to `self.user_token`
|
||||
parent_id: The event_id this relation relates to. If None, then self.parent_id
|
||||
|
||||
Returns:
|
||||
FakeChannel
|
||||
"""
|
||||
if not access_token:
|
||||
access_token = self.user_token
|
||||
|
||||
original_id = parent_id if parent_id else self.parent_id
|
||||
|
||||
if content is None:
|
||||
content = {}
|
||||
content["m.relates_to"] = {
|
||||
"event_id": original_id,
|
||||
"rel_type": relation_type,
|
||||
}
|
||||
if key is not None:
|
||||
content["m.relates_to"]["key"] = key
|
||||
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
f"/_matrix/client/v3/rooms/{self.room}/send/{event_type}",
|
||||
content,
|
||||
access_token=access_token,
|
||||
)
|
||||
return channel
|
||||
|
||||
def _create_user(self, localpart: str) -> Tuple[str, str]:
|
||||
user_id = self.register_user(localpart, "abc123")
|
||||
access_token = self.login(localpart, "abc123")
|
||||
|
||||
return user_id, access_token
|
||||
|
||||
def test_background_update(self) -> None:
|
||||
"""Test the event_arbitrary_relations background update."""
|
||||
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
|
||||
@@ -1482,3 +1341,112 @@ class RelationsTestCase(unittest.HomeserverTestCase):
|
||||
[ev["event_id"] for ev in channel.json_body["chunk"]],
|
||||
[annotation_event_id_good, thread_event_id],
|
||||
)
|
||||
|
||||
|
||||
class RelationRedactionTestCase(BaseRelationsTestCase):
|
||||
"""Test the behaviour of relations when the parent or child event is redacted."""
|
||||
|
||||
def _redact(self, event_id: str) -> None:
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
f"/_matrix/client/r0/rooms/{self.room}/redact/{event_id}",
|
||||
access_token=self.user_token,
|
||||
content={},
|
||||
)
|
||||
self.assertEqual(200, channel.code, channel.json_body)
|
||||
|
||||
def test_redact_relation_annotation(self) -> None:
|
||||
"""Test that annotations of an event are properly handled after the
|
||||
annotation is redacted.
|
||||
"""
|
||||
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
|
||||
self.assertEqual(200, channel.code, channel.json_body)
|
||||
to_redact_event_id = channel.json_body["event_id"]
|
||||
|
||||
channel = self._send_relation(
|
||||
RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
|
||||
)
|
||||
self.assertEqual(200, channel.code, channel.json_body)
|
||||
|
||||
# Redact one of the reactions.
|
||||
self._redact(to_redact_event_id)
|
||||
|
||||
# Ensure that the aggregations are correct.
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}",
|
||||
access_token=self.user_token,
|
||||
)
|
||||
self.assertEqual(200, channel.code, channel.json_body)
|
||||
|
||||
self.assertEqual(
|
||||
channel.json_body,
|
||||
{"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]},
|
||||
)
|
||||
|
||||
def test_redact_relation_edit(self) -> None:
|
||||
"""Test that edits of an event are redacted when the original event
|
||||
is redacted.
|
||||
"""
|
||||
# Add a relation
|
||||
channel = self._send_relation(
|
||||
RelationTypes.REPLACE,
|
||||
"m.room.message",
|
||||
parent_id=self.parent_id,
|
||||
content={
|
||||
"msgtype": "m.text",
|
||||
"body": "Wibble",
|
||||
"m.new_content": {"msgtype": "m.text", "body": "First edit"},
|
||||
},
|
||||
)
|
||||
self.assertEqual(200, channel.code, channel.json_body)
|
||||
|
||||
# Check the relation is returned
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
f"/_matrix/client/unstable/rooms/{self.room}/relations"
|
||||
f"/{self.parent_id}/m.replace/m.room.message",
|
||||
access_token=self.user_token,
|
||||
)
|
||||
self.assertEqual(200, channel.code, channel.json_body)
|
||||
|
||||
self.assertIn("chunk", channel.json_body)
|
||||
self.assertEqual(len(channel.json_body["chunk"]), 1)
|
||||
|
||||
# Redact the original event
|
||||
self._redact(self.parent_id)
|
||||
|
||||
# Try to check for remaining m.replace relations
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
f"/_matrix/client/unstable/rooms/{self.room}/relations"
|
||||
f"/{self.parent_id}/m.replace/m.room.message",
|
||||
access_token=self.user_token,
|
||||
)
|
||||
self.assertEqual(200, channel.code, channel.json_body)
|
||||
|
||||
# Check that no relations are returned
|
||||
self.assertIn("chunk", channel.json_body)
|
||||
self.assertEqual(channel.json_body["chunk"], [])
|
||||
|
||||
def test_redact_parent(self) -> None:
|
||||
"""Test that annotations of an event are redacted when the original event
|
||||
is redacted.
|
||||
"""
|
||||
# Add a relation
|
||||
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
|
||||
self.assertEqual(200, channel.code, channel.json_body)
|
||||
|
||||
# Redact the original event.
|
||||
self._redact(self.parent_id)
|
||||
|
||||
# Check that aggregations returns zero
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}/m.annotation/m.reaction",
|
||||
access_token=self.user_token,
|
||||
)
|
||||
self.assertEqual(200, channel.code, channel.json_body)
|
||||
|
||||
self.assertIn("chunk", channel.json_body)
|
||||
self.assertEqual(channel.json_body["chunk"], [])
|
||||
|
||||
@@ -14,8 +14,13 @@
|
||||
|
||||
import json
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
import synapse.rest.admin
|
||||
from synapse.rest.client import login, report_event, room
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
|
||||
@@ -28,7 +33,7 @@ class ReportEventTestCase(unittest.HomeserverTestCase):
|
||||
report_event.register_servlets,
|
||||
]
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.admin_user = self.register_user("admin", "pass", admin=True)
|
||||
self.admin_user_tok = self.login("admin", "pass")
|
||||
self.other_user = self.register_user("user", "pass")
|
||||
@@ -42,35 +47,35 @@ class ReportEventTestCase(unittest.HomeserverTestCase):
|
||||
self.event_id = resp["event_id"]
|
||||
self.report_path = f"rooms/{self.room_id}/report/{self.event_id}"
|
||||
|
||||
def test_reason_str_and_score_int(self):
|
||||
def test_reason_str_and_score_int(self) -> None:
|
||||
data = {"reason": "this makes me sad", "score": -100}
|
||||
self._assert_status(200, data)
|
||||
|
||||
def test_no_reason(self):
|
||||
def test_no_reason(self) -> None:
|
||||
data = {"score": 0}
|
||||
self._assert_status(200, data)
|
||||
|
||||
def test_no_score(self):
|
||||
def test_no_score(self) -> None:
|
||||
data = {"reason": "this makes me sad"}
|
||||
self._assert_status(200, data)
|
||||
|
||||
def test_no_reason_and_no_score(self):
|
||||
data = {}
|
||||
def test_no_reason_and_no_score(self) -> None:
|
||||
data: JsonDict = {}
|
||||
self._assert_status(200, data)
|
||||
|
||||
def test_reason_int_and_score_str(self):
|
||||
def test_reason_int_and_score_str(self) -> None:
|
||||
data = {"reason": 10, "score": "string"}
|
||||
self._assert_status(400, data)
|
||||
|
||||
def test_reason_zero_and_score_blank(self):
|
||||
def test_reason_zero_and_score_blank(self) -> None:
|
||||
data = {"reason": 0, "score": ""}
|
||||
self._assert_status(400, data)
|
||||
|
||||
def test_reason_and_score_null(self):
|
||||
def test_reason_and_score_null(self) -> None:
|
||||
data = {"reason": None, "score": None}
|
||||
self._assert_status(400, data)
|
||||
|
||||
def _assert_status(self, response_status, data):
|
||||
def _assert_status(self, response_status: int, data: JsonDict) -> None:
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
self.report_path,
|
||||
|
||||
+138
-133
@@ -18,11 +18,12 @@
|
||||
"""Tests REST events for /rooms paths."""
|
||||
|
||||
import json
|
||||
from typing import Iterable, List
|
||||
from typing import Any, Dict, Iterable, List, Optional
|
||||
from unittest.mock import Mock, call
|
||||
from urllib import parse as urlparse
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
import synapse.rest.admin
|
||||
from synapse.api.constants import (
|
||||
@@ -35,7 +36,9 @@ from synapse.api.errors import Codes, HttpResponseException
|
||||
from synapse.handlers.pagination import PurgeStatus
|
||||
from synapse.rest import admin
|
||||
from synapse.rest.client import account, directory, login, profile, room, sync
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import JsonDict, RoomAlias, UserID, create_requester
|
||||
from synapse.util import Clock
|
||||
from synapse.util.stringutils import random_string
|
||||
|
||||
from tests import unittest
|
||||
@@ -45,11 +48,11 @@ PATH_PREFIX = b"/_matrix/client/api/v1"
|
||||
|
||||
|
||||
class RoomBase(unittest.HomeserverTestCase):
|
||||
rmcreator_id = None
|
||||
rmcreator_id: Optional[str] = None
|
||||
|
||||
servlets = [room.register_servlets, room.register_deprecated_servlets]
|
||||
|
||||
def make_homeserver(self, reactor, clock):
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
|
||||
self.hs = self.setup_test_homeserver(
|
||||
"red",
|
||||
@@ -57,15 +60,15 @@ class RoomBase(unittest.HomeserverTestCase):
|
||||
federation_client=Mock(),
|
||||
)
|
||||
|
||||
self.hs.get_federation_handler = Mock()
|
||||
self.hs.get_federation_handler = Mock() # type: ignore[assignment]
|
||||
self.hs.get_federation_handler.return_value.maybe_backfill = Mock(
|
||||
return_value=make_awaitable(None)
|
||||
)
|
||||
|
||||
async def _insert_client_ip(*args, **kwargs):
|
||||
async def _insert_client_ip(*args: Any, **kwargs: Any) -> None:
|
||||
return None
|
||||
|
||||
self.hs.get_datastores().main.insert_client_ip = _insert_client_ip
|
||||
self.hs.get_datastores().main.insert_client_ip = _insert_client_ip # type: ignore[assignment]
|
||||
|
||||
return self.hs
|
||||
|
||||
@@ -76,7 +79,7 @@ class RoomPermissionsTestCase(RoomBase):
|
||||
user_id = "@sid1:red"
|
||||
rmcreator_id = "@notme:red"
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
|
||||
self.helper.auth_user_id = self.rmcreator_id
|
||||
# create some rooms under the name rmcreator_id
|
||||
@@ -108,12 +111,12 @@ class RoomPermissionsTestCase(RoomBase):
|
||||
# auth as user_id now
|
||||
self.helper.auth_user_id = self.user_id
|
||||
|
||||
def test_can_do_action(self):
|
||||
def test_can_do_action(self) -> None:
|
||||
msg_content = b'{"msgtype":"m.text","body":"hello"}'
|
||||
|
||||
seq = iter(range(100))
|
||||
|
||||
def send_msg_path():
|
||||
def send_msg_path() -> str:
|
||||
return "/rooms/%s/send/m.room.message/mid%s" % (
|
||||
self.created_rmid,
|
||||
str(next(seq)),
|
||||
@@ -148,7 +151,7 @@ class RoomPermissionsTestCase(RoomBase):
|
||||
channel = self.make_request("PUT", send_msg_path(), msg_content)
|
||||
self.assertEqual(403, channel.code, msg=channel.result["body"])
|
||||
|
||||
def test_topic_perms(self):
|
||||
def test_topic_perms(self) -> None:
|
||||
topic_content = b'{"topic":"My Topic Name"}'
|
||||
topic_path = "/rooms/%s/state/m.room.topic" % self.created_rmid
|
||||
|
||||
@@ -214,14 +217,14 @@ class RoomPermissionsTestCase(RoomBase):
|
||||
self.assertEqual(403, channel.code, msg=channel.result["body"])
|
||||
|
||||
def _test_get_membership(
|
||||
self, room=None, members: Iterable = frozenset(), expect_code=None
|
||||
):
|
||||
self, room: str, members: Iterable = frozenset(), expect_code: int = 200
|
||||
) -> None:
|
||||
for member in members:
|
||||
path = "/rooms/%s/state/m.room.member/%s" % (room, member)
|
||||
channel = self.make_request("GET", path)
|
||||
self.assertEqual(expect_code, channel.code)
|
||||
|
||||
def test_membership_basic_room_perms(self):
|
||||
def test_membership_basic_room_perms(self) -> None:
|
||||
# === room does not exist ===
|
||||
room = self.uncreated_rmid
|
||||
# get membership of self, get membership of other, uncreated room
|
||||
@@ -241,7 +244,7 @@ class RoomPermissionsTestCase(RoomBase):
|
||||
self.helper.join(room=room, user=usr, expect_code=404)
|
||||
self.helper.leave(room=room, user=usr, expect_code=404)
|
||||
|
||||
def test_membership_private_room_perms(self):
|
||||
def test_membership_private_room_perms(self) -> None:
|
||||
room = self.created_rmid
|
||||
# get membership of self, get membership of other, private room + invite
|
||||
# expect all 403s
|
||||
@@ -264,7 +267,7 @@ class RoomPermissionsTestCase(RoomBase):
|
||||
members=[self.user_id, self.rmcreator_id], room=room, expect_code=200
|
||||
)
|
||||
|
||||
def test_membership_public_room_perms(self):
|
||||
def test_membership_public_room_perms(self) -> None:
|
||||
room = self.created_public_rmid
|
||||
# get membership of self, get membership of other, public room + invite
|
||||
# expect 403
|
||||
@@ -287,7 +290,7 @@ class RoomPermissionsTestCase(RoomBase):
|
||||
members=[self.user_id, self.rmcreator_id], room=room, expect_code=200
|
||||
)
|
||||
|
||||
def test_invited_permissions(self):
|
||||
def test_invited_permissions(self) -> None:
|
||||
room = self.created_rmid
|
||||
self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id)
|
||||
|
||||
@@ -310,7 +313,7 @@ class RoomPermissionsTestCase(RoomBase):
|
||||
expect_code=403,
|
||||
)
|
||||
|
||||
def test_joined_permissions(self):
|
||||
def test_joined_permissions(self) -> None:
|
||||
room = self.created_rmid
|
||||
self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id)
|
||||
self.helper.join(room=room, user=self.user_id)
|
||||
@@ -348,7 +351,7 @@ class RoomPermissionsTestCase(RoomBase):
|
||||
# set left of self, expect 200
|
||||
self.helper.leave(room=room, user=self.user_id)
|
||||
|
||||
def test_leave_permissions(self):
|
||||
def test_leave_permissions(self) -> None:
|
||||
room = self.created_rmid
|
||||
self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id)
|
||||
self.helper.join(room=room, user=self.user_id)
|
||||
@@ -383,7 +386,7 @@ class RoomPermissionsTestCase(RoomBase):
|
||||
)
|
||||
|
||||
# tests the "from banned" line from the table in https://spec.matrix.org/unstable/client-server-api/#mroommember
|
||||
def test_member_event_from_ban(self):
|
||||
def test_member_event_from_ban(self) -> None:
|
||||
room = self.created_rmid
|
||||
self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id)
|
||||
self.helper.join(room=room, user=self.user_id)
|
||||
@@ -475,21 +478,21 @@ class RoomsMemberListTestCase(RoomBase):
|
||||
|
||||
user_id = "@sid1:red"
|
||||
|
||||
def test_get_member_list(self):
|
||||
def test_get_member_list(self) -> None:
|
||||
room_id = self.helper.create_room_as(self.user_id)
|
||||
channel = self.make_request("GET", "/rooms/%s/members" % room_id)
|
||||
self.assertEqual(200, channel.code, msg=channel.result["body"])
|
||||
|
||||
def test_get_member_list_no_room(self):
|
||||
def test_get_member_list_no_room(self) -> None:
|
||||
channel = self.make_request("GET", "/rooms/roomdoesnotexist/members")
|
||||
self.assertEqual(403, channel.code, msg=channel.result["body"])
|
||||
|
||||
def test_get_member_list_no_permission(self):
|
||||
def test_get_member_list_no_permission(self) -> None:
|
||||
room_id = self.helper.create_room_as("@some_other_guy:red")
|
||||
channel = self.make_request("GET", "/rooms/%s/members" % room_id)
|
||||
self.assertEqual(403, channel.code, msg=channel.result["body"])
|
||||
|
||||
def test_get_member_list_no_permission_with_at_token(self):
|
||||
def test_get_member_list_no_permission_with_at_token(self) -> None:
|
||||
"""
|
||||
Tests that a stranger to the room cannot get the member list
|
||||
(in the case that they use an at token).
|
||||
@@ -509,7 +512,7 @@ class RoomsMemberListTestCase(RoomBase):
|
||||
)
|
||||
self.assertEqual(403, channel.code, msg=channel.result["body"])
|
||||
|
||||
def test_get_member_list_no_permission_former_member(self):
|
||||
def test_get_member_list_no_permission_former_member(self) -> None:
|
||||
"""
|
||||
Tests that a former member of the room can not get the member list.
|
||||
"""
|
||||
@@ -529,7 +532,7 @@ class RoomsMemberListTestCase(RoomBase):
|
||||
channel = self.make_request("GET", "/rooms/%s/members" % room_id)
|
||||
self.assertEqual(403, channel.code, msg=channel.result["body"])
|
||||
|
||||
def test_get_member_list_no_permission_former_member_with_at_token(self):
|
||||
def test_get_member_list_no_permission_former_member_with_at_token(self) -> None:
|
||||
"""
|
||||
Tests that a former member of the room can not get the member list
|
||||
(in the case that they use an at token).
|
||||
@@ -569,7 +572,7 @@ class RoomsMemberListTestCase(RoomBase):
|
||||
)
|
||||
self.assertEqual(403, channel.code, msg=channel.result["body"])
|
||||
|
||||
def test_get_member_list_mixed_memberships(self):
|
||||
def test_get_member_list_mixed_memberships(self) -> None:
|
||||
room_creator = "@some_other_guy:red"
|
||||
room_id = self.helper.create_room_as(room_creator)
|
||||
room_path = "/rooms/%s/members" % room_id
|
||||
@@ -594,26 +597,26 @@ class RoomsCreateTestCase(RoomBase):
|
||||
|
||||
user_id = "@sid1:red"
|
||||
|
||||
def test_post_room_no_keys(self):
|
||||
def test_post_room_no_keys(self) -> None:
|
||||
# POST with no config keys, expect new room id
|
||||
channel = self.make_request("POST", "/createRoom", "{}")
|
||||
|
||||
self.assertEqual(200, channel.code, channel.result)
|
||||
self.assertTrue("room_id" in channel.json_body)
|
||||
|
||||
def test_post_room_visibility_key(self):
|
||||
def test_post_room_visibility_key(self) -> None:
|
||||
# POST with visibility config key, expect new room id
|
||||
channel = self.make_request("POST", "/createRoom", b'{"visibility":"private"}')
|
||||
self.assertEqual(200, channel.code)
|
||||
self.assertTrue("room_id" in channel.json_body)
|
||||
|
||||
def test_post_room_custom_key(self):
|
||||
def test_post_room_custom_key(self) -> None:
|
||||
# POST with custom config keys, expect new room id
|
||||
channel = self.make_request("POST", "/createRoom", b'{"custom":"stuff"}')
|
||||
self.assertEqual(200, channel.code)
|
||||
self.assertTrue("room_id" in channel.json_body)
|
||||
|
||||
def test_post_room_known_and_unknown_keys(self):
|
||||
def test_post_room_known_and_unknown_keys(self) -> None:
|
||||
# POST with custom + known config keys, expect new room id
|
||||
channel = self.make_request(
|
||||
"POST", "/createRoom", b'{"visibility":"private","custom":"things"}'
|
||||
@@ -621,7 +624,7 @@ class RoomsCreateTestCase(RoomBase):
|
||||
self.assertEqual(200, channel.code)
|
||||
self.assertTrue("room_id" in channel.json_body)
|
||||
|
||||
def test_post_room_invalid_content(self):
|
||||
def test_post_room_invalid_content(self) -> None:
|
||||
# POST with invalid content / paths, expect 400
|
||||
channel = self.make_request("POST", "/createRoom", b'{"visibili')
|
||||
self.assertEqual(400, channel.code)
|
||||
@@ -629,7 +632,7 @@ class RoomsCreateTestCase(RoomBase):
|
||||
channel = self.make_request("POST", "/createRoom", b'["hello"]')
|
||||
self.assertEqual(400, channel.code)
|
||||
|
||||
def test_post_room_invitees_invalid_mxid(self):
|
||||
def test_post_room_invitees_invalid_mxid(self) -> None:
|
||||
# POST with invalid invitee, see https://github.com/matrix-org/synapse/issues/4088
|
||||
# Note the trailing space in the MXID here!
|
||||
channel = self.make_request(
|
||||
@@ -638,7 +641,7 @@ class RoomsCreateTestCase(RoomBase):
|
||||
self.assertEqual(400, channel.code)
|
||||
|
||||
@unittest.override_config({"rc_invites": {"per_room": {"burst_count": 3}}})
|
||||
def test_post_room_invitees_ratelimit(self):
|
||||
def test_post_room_invitees_ratelimit(self) -> None:
|
||||
"""Test that invites sent when creating a room are ratelimited by a RateLimiter,
|
||||
which ratelimits them correctly, including by not limiting when the requester is
|
||||
exempt from ratelimiting.
|
||||
@@ -674,7 +677,7 @@ class RoomsCreateTestCase(RoomBase):
|
||||
channel = self.make_request("POST", "/createRoom", content)
|
||||
self.assertEqual(200, channel.code)
|
||||
|
||||
def test_spam_checker_may_join_room(self):
|
||||
def test_spam_checker_may_join_room(self) -> None:
|
||||
"""Tests that the user_may_join_room spam checker callback is correctly bypassed
|
||||
when creating a new room.
|
||||
"""
|
||||
@@ -704,12 +707,12 @@ class RoomTopicTestCase(RoomBase):
|
||||
|
||||
user_id = "@sid1:red"
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
# create the room
|
||||
self.room_id = self.helper.create_room_as(self.user_id)
|
||||
self.path = "/rooms/%s/state/m.room.topic" % (self.room_id,)
|
||||
|
||||
def test_invalid_puts(self):
|
||||
def test_invalid_puts(self) -> None:
|
||||
# missing keys or invalid json
|
||||
channel = self.make_request("PUT", self.path, "{}")
|
||||
self.assertEqual(400, channel.code, msg=channel.result["body"])
|
||||
@@ -736,7 +739,7 @@ class RoomTopicTestCase(RoomBase):
|
||||
channel = self.make_request("PUT", self.path, content)
|
||||
self.assertEqual(400, channel.code, msg=channel.result["body"])
|
||||
|
||||
def test_rooms_topic(self):
|
||||
def test_rooms_topic(self) -> None:
|
||||
# nothing should be there
|
||||
channel = self.make_request("GET", self.path)
|
||||
self.assertEqual(404, channel.code, msg=channel.result["body"])
|
||||
@@ -751,7 +754,7 @@ class RoomTopicTestCase(RoomBase):
|
||||
self.assertEqual(200, channel.code, msg=channel.result["body"])
|
||||
self.assert_dict(json.loads(content), channel.json_body)
|
||||
|
||||
def test_rooms_topic_with_extra_keys(self):
|
||||
def test_rooms_topic_with_extra_keys(self) -> None:
|
||||
# valid put with extra keys
|
||||
content = '{"topic":"Seasons","subtopic":"Summer"}'
|
||||
channel = self.make_request("PUT", self.path, content)
|
||||
@@ -768,10 +771,10 @@ class RoomMemberStateTestCase(RoomBase):
|
||||
|
||||
user_id = "@sid1:red"
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.room_id = self.helper.create_room_as(self.user_id)
|
||||
|
||||
def test_invalid_puts(self):
|
||||
def test_invalid_puts(self) -> None:
|
||||
path = "/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id)
|
||||
# missing keys or invalid json
|
||||
channel = self.make_request("PUT", path, "{}")
|
||||
@@ -801,7 +804,7 @@ class RoomMemberStateTestCase(RoomBase):
|
||||
channel = self.make_request("PUT", path, content.encode("ascii"))
|
||||
self.assertEqual(400, channel.code, msg=channel.result["body"])
|
||||
|
||||
def test_rooms_members_self(self):
|
||||
def test_rooms_members_self(self) -> None:
|
||||
path = "/rooms/%s/state/m.room.member/%s" % (
|
||||
urlparse.quote(self.room_id),
|
||||
self.user_id,
|
||||
@@ -812,13 +815,13 @@ class RoomMemberStateTestCase(RoomBase):
|
||||
channel = self.make_request("PUT", path, content.encode("ascii"))
|
||||
self.assertEqual(200, channel.code, msg=channel.result["body"])
|
||||
|
||||
channel = self.make_request("GET", path, None)
|
||||
channel = self.make_request("GET", path, content=b"")
|
||||
self.assertEqual(200, channel.code, msg=channel.result["body"])
|
||||
|
||||
expected_response = {"membership": Membership.JOIN}
|
||||
self.assertEqual(expected_response, channel.json_body)
|
||||
|
||||
def test_rooms_members_other(self):
|
||||
def test_rooms_members_other(self) -> None:
|
||||
self.other_id = "@zzsid1:red"
|
||||
path = "/rooms/%s/state/m.room.member/%s" % (
|
||||
urlparse.quote(self.room_id),
|
||||
@@ -830,11 +833,11 @@ class RoomMemberStateTestCase(RoomBase):
|
||||
channel = self.make_request("PUT", path, content)
|
||||
self.assertEqual(200, channel.code, msg=channel.result["body"])
|
||||
|
||||
channel = self.make_request("GET", path, None)
|
||||
channel = self.make_request("GET", path, content=b"")
|
||||
self.assertEqual(200, channel.code, msg=channel.result["body"])
|
||||
self.assertEqual(json.loads(content), channel.json_body)
|
||||
|
||||
def test_rooms_members_other_custom_keys(self):
|
||||
def test_rooms_members_other_custom_keys(self) -> None:
|
||||
self.other_id = "@zzsid1:red"
|
||||
path = "/rooms/%s/state/m.room.member/%s" % (
|
||||
urlparse.quote(self.room_id),
|
||||
@@ -849,7 +852,7 @@ class RoomMemberStateTestCase(RoomBase):
|
||||
channel = self.make_request("PUT", path, content)
|
||||
self.assertEqual(200, channel.code, msg=channel.result["body"])
|
||||
|
||||
channel = self.make_request("GET", path, None)
|
||||
channel = self.make_request("GET", path, content=b"")
|
||||
self.assertEqual(200, channel.code, msg=channel.result["body"])
|
||||
self.assertEqual(json.loads(content), channel.json_body)
|
||||
|
||||
@@ -866,7 +869,7 @@ class RoomInviteRatelimitTestCase(RoomBase):
|
||||
@unittest.override_config(
|
||||
{"rc_invites": {"per_room": {"per_second": 0.5, "burst_count": 3}}}
|
||||
)
|
||||
def test_invites_by_rooms_ratelimit(self):
|
||||
def test_invites_by_rooms_ratelimit(self) -> None:
|
||||
"""Tests that invites in a room are actually rate-limited."""
|
||||
room_id = self.helper.create_room_as(self.user_id)
|
||||
|
||||
@@ -878,7 +881,7 @@ class RoomInviteRatelimitTestCase(RoomBase):
|
||||
@unittest.override_config(
|
||||
{"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}}
|
||||
)
|
||||
def test_invites_by_users_ratelimit(self):
|
||||
def test_invites_by_users_ratelimit(self) -> None:
|
||||
"""Tests that invites to a specific user are actually rate-limited."""
|
||||
|
||||
for _ in range(3):
|
||||
@@ -897,7 +900,7 @@ class RoomJoinTestCase(RoomBase):
|
||||
room.register_servlets,
|
||||
]
|
||||
|
||||
def prepare(self, reactor, clock, homeserver):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.user1 = self.register_user("thomas", "hackme")
|
||||
self.tok1 = self.login("thomas", "hackme")
|
||||
|
||||
@@ -908,7 +911,7 @@ class RoomJoinTestCase(RoomBase):
|
||||
self.room2 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1)
|
||||
self.room3 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1)
|
||||
|
||||
def test_spam_checker_may_join_room(self):
|
||||
def test_spam_checker_may_join_room(self) -> None:
|
||||
"""Tests that the user_may_join_room spam checker callback is correctly called
|
||||
and blocks room joins when needed.
|
||||
"""
|
||||
@@ -975,8 +978,8 @@ class RoomJoinRatelimitTestCase(RoomBase):
|
||||
room.register_servlets,
|
||||
]
|
||||
|
||||
def prepare(self, reactor, clock, homeserver):
|
||||
super().prepare(reactor, clock, homeserver)
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
super().prepare(reactor, clock, hs)
|
||||
# profile changes expect that the user is actually registered
|
||||
user = UserID.from_string(self.user_id)
|
||||
self.get_success(self.register_user(user.localpart, "supersecretpassword"))
|
||||
@@ -984,7 +987,7 @@ class RoomJoinRatelimitTestCase(RoomBase):
|
||||
@unittest.override_config(
|
||||
{"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
|
||||
)
|
||||
def test_join_local_ratelimit(self):
|
||||
def test_join_local_ratelimit(self) -> None:
|
||||
"""Tests that local joins are actually rate-limited."""
|
||||
for _ in range(3):
|
||||
self.helper.create_room_as(self.user_id)
|
||||
@@ -994,7 +997,7 @@ class RoomJoinRatelimitTestCase(RoomBase):
|
||||
@unittest.override_config(
|
||||
{"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
|
||||
)
|
||||
def test_join_local_ratelimit_profile_change(self):
|
||||
def test_join_local_ratelimit_profile_change(self) -> None:
|
||||
"""Tests that sending a profile update into all of the user's joined rooms isn't
|
||||
rate-limited by the rate-limiter on joins."""
|
||||
|
||||
@@ -1031,7 +1034,7 @@ class RoomJoinRatelimitTestCase(RoomBase):
|
||||
@unittest.override_config(
|
||||
{"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
|
||||
)
|
||||
def test_join_local_ratelimit_idempotent(self):
|
||||
def test_join_local_ratelimit_idempotent(self) -> None:
|
||||
"""Tests that the room join endpoints remain idempotent despite rate-limiting
|
||||
on room joins."""
|
||||
room_id = self.helper.create_room_as(self.user_id)
|
||||
@@ -1056,7 +1059,7 @@ class RoomJoinRatelimitTestCase(RoomBase):
|
||||
"autocreate_auto_join_rooms": True,
|
||||
},
|
||||
)
|
||||
def test_autojoin_rooms(self):
|
||||
def test_autojoin_rooms(self) -> None:
|
||||
user_id = self.register_user("testuser", "password")
|
||||
|
||||
# Check that the new user successfully joined the four rooms
|
||||
@@ -1071,10 +1074,10 @@ class RoomMessagesTestCase(RoomBase):
|
||||
|
||||
user_id = "@sid1:red"
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.room_id = self.helper.create_room_as(self.user_id)
|
||||
|
||||
def test_invalid_puts(self):
|
||||
def test_invalid_puts(self) -> None:
|
||||
path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id))
|
||||
# missing keys or invalid json
|
||||
channel = self.make_request("PUT", path, b"{}")
|
||||
@@ -1095,7 +1098,7 @@ class RoomMessagesTestCase(RoomBase):
|
||||
channel = self.make_request("PUT", path, b"")
|
||||
self.assertEqual(400, channel.code, msg=channel.result["body"])
|
||||
|
||||
def test_rooms_messages_sent(self):
|
||||
def test_rooms_messages_sent(self) -> None:
|
||||
path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id))
|
||||
|
||||
content = b'{"body":"test","msgtype":{"type":"a"}}'
|
||||
@@ -1119,11 +1122,11 @@ class RoomInitialSyncTestCase(RoomBase):
|
||||
|
||||
user_id = "@sid1:red"
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
# create the room
|
||||
self.room_id = self.helper.create_room_as(self.user_id)
|
||||
|
||||
def test_initial_sync(self):
|
||||
def test_initial_sync(self) -> None:
|
||||
channel = self.make_request("GET", "/rooms/%s/initialSync" % self.room_id)
|
||||
self.assertEqual(200, channel.code)
|
||||
|
||||
@@ -1131,7 +1134,7 @@ class RoomInitialSyncTestCase(RoomBase):
|
||||
self.assertEqual("join", channel.json_body["membership"])
|
||||
|
||||
# Room state is easier to assert on if we unpack it into a dict
|
||||
state = {}
|
||||
state: JsonDict = {}
|
||||
for event in channel.json_body["state"]:
|
||||
if "state_key" not in event:
|
||||
continue
|
||||
@@ -1160,10 +1163,10 @@ class RoomMessageListTestCase(RoomBase):
|
||||
|
||||
user_id = "@sid1:red"
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.room_id = self.helper.create_room_as(self.user_id)
|
||||
|
||||
def test_topo_token_is_accepted(self):
|
||||
def test_topo_token_is_accepted(self) -> None:
|
||||
token = "t1-0_0_0_0_0_0_0_0_0"
|
||||
channel = self.make_request(
|
||||
"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
|
||||
@@ -1174,7 +1177,7 @@ class RoomMessageListTestCase(RoomBase):
|
||||
self.assertTrue("chunk" in channel.json_body)
|
||||
self.assertTrue("end" in channel.json_body)
|
||||
|
||||
def test_stream_token_is_accepted_for_fwd_pagianation(self):
|
||||
def test_stream_token_is_accepted_for_fwd_pagianation(self) -> None:
|
||||
token = "s0_0_0_0_0_0_0_0_0"
|
||||
channel = self.make_request(
|
||||
"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
|
||||
@@ -1185,7 +1188,7 @@ class RoomMessageListTestCase(RoomBase):
|
||||
self.assertTrue("chunk" in channel.json_body)
|
||||
self.assertTrue("end" in channel.json_body)
|
||||
|
||||
def test_room_messages_purge(self):
|
||||
def test_room_messages_purge(self) -> None:
|
||||
store = self.hs.get_datastores().main
|
||||
pagination_handler = self.hs.get_pagination_handler()
|
||||
|
||||
@@ -1278,10 +1281,10 @@ class RoomSearchTestCase(unittest.HomeserverTestCase):
|
||||
user_id = True
|
||||
hijack_auth = False
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
|
||||
# Register the user who does the searching
|
||||
self.user_id = self.register_user("user", "pass")
|
||||
self.user_id2 = self.register_user("user", "pass")
|
||||
self.access_token = self.login("user", "pass")
|
||||
|
||||
# Register the user who sends the message
|
||||
@@ -1289,12 +1292,12 @@ class RoomSearchTestCase(unittest.HomeserverTestCase):
|
||||
self.other_access_token = self.login("otheruser", "pass")
|
||||
|
||||
# Create a room
|
||||
self.room = self.helper.create_room_as(self.user_id, tok=self.access_token)
|
||||
self.room = self.helper.create_room_as(self.user_id2, tok=self.access_token)
|
||||
|
||||
# Invite the other person
|
||||
self.helper.invite(
|
||||
room=self.room,
|
||||
src=self.user_id,
|
||||
src=self.user_id2,
|
||||
tok=self.access_token,
|
||||
targ=self.other_user_id,
|
||||
)
|
||||
@@ -1304,7 +1307,7 @@ class RoomSearchTestCase(unittest.HomeserverTestCase):
|
||||
room=self.room, user=self.other_user_id, tok=self.other_access_token
|
||||
)
|
||||
|
||||
def test_finds_message(self):
|
||||
def test_finds_message(self) -> None:
|
||||
"""
|
||||
The search functionality will search for content in messages if asked to
|
||||
do so.
|
||||
@@ -1333,7 +1336,7 @@ class RoomSearchTestCase(unittest.HomeserverTestCase):
|
||||
# No context was requested, so we should get none.
|
||||
self.assertEqual(results["results"][0]["context"], {})
|
||||
|
||||
def test_include_context(self):
|
||||
def test_include_context(self) -> None:
|
||||
"""
|
||||
When event_context includes include_profile, profile information will be
|
||||
included in the search response.
|
||||
@@ -1379,7 +1382,7 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase):
|
||||
login.register_servlets,
|
||||
]
|
||||
|
||||
def make_homeserver(self, reactor, clock):
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
|
||||
self.url = b"/_matrix/client/r0/publicRooms"
|
||||
|
||||
@@ -1389,11 +1392,11 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
return self.hs
|
||||
|
||||
def test_restricted_no_auth(self):
|
||||
def test_restricted_no_auth(self) -> None:
|
||||
channel = self.make_request("GET", self.url)
|
||||
self.assertEqual(channel.code, 401, channel.result)
|
||||
|
||||
def test_restricted_auth(self):
|
||||
def test_restricted_auth(self) -> None:
|
||||
self.register_user("user", "pass")
|
||||
tok = self.login("user", "pass")
|
||||
|
||||
@@ -1412,19 +1415,19 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
|
||||
login.register_servlets,
|
||||
]
|
||||
|
||||
def make_homeserver(self, reactor, clock):
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
return self.setup_test_homeserver(federation_client=Mock())
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.register_user("user", "pass")
|
||||
self.token = self.login("user", "pass")
|
||||
|
||||
self.federation_client = hs.get_federation_client()
|
||||
|
||||
def test_simple(self):
|
||||
def test_simple(self) -> None:
|
||||
"Simple test for searching rooms over federation"
|
||||
self.federation_client.get_public_rooms.side_effect = (
|
||||
lambda *a, **k: defer.succeed({})
|
||||
self.federation_client.get_public_rooms.side_effect = lambda *a, **k: defer.succeed( # type: ignore[attr-defined]
|
||||
{}
|
||||
)
|
||||
|
||||
search_filter = {"generic_search_term": "foobar"}
|
||||
@@ -1437,7 +1440,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
self.federation_client.get_public_rooms.assert_called_once_with(
|
||||
self.federation_client.get_public_rooms.assert_called_once_with( # type: ignore[attr-defined]
|
||||
"testserv",
|
||||
limit=100,
|
||||
since_token=None,
|
||||
@@ -1446,12 +1449,12 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
|
||||
third_party_instance_id=None,
|
||||
)
|
||||
|
||||
def test_fallback(self):
|
||||
def test_fallback(self) -> None:
|
||||
"Test that searching public rooms over federation falls back if it gets a 404"
|
||||
|
||||
# The `get_public_rooms` should be called again if the first call fails
|
||||
# with a 404, when using search filters.
|
||||
self.federation_client.get_public_rooms.side_effect = (
|
||||
self.federation_client.get_public_rooms.side_effect = ( # type: ignore[attr-defined]
|
||||
HttpResponseException(404, "Not Found", b""),
|
||||
defer.succeed({}),
|
||||
)
|
||||
@@ -1466,7 +1469,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
self.federation_client.get_public_rooms.assert_has_calls(
|
||||
self.federation_client.get_public_rooms.assert_has_calls( # type: ignore[attr-defined]
|
||||
[
|
||||
call(
|
||||
"testserv",
|
||||
@@ -1497,14 +1500,14 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase):
|
||||
profile.register_servlets,
|
||||
]
|
||||
|
||||
def make_homeserver(self, reactor, clock):
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
config = self.default_config()
|
||||
config["allow_per_room_profiles"] = False
|
||||
self.hs = self.setup_test_homeserver(config=config)
|
||||
|
||||
return self.hs
|
||||
|
||||
def prepare(self, reactor, clock, homeserver):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.user_id = self.register_user("test", "test")
|
||||
self.tok = self.login("test", "test")
|
||||
|
||||
@@ -1522,7 +1525,7 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
|
||||
|
||||
def test_per_room_profile_forbidden(self):
|
||||
def test_per_room_profile_forbidden(self) -> None:
|
||||
data = {"membership": "join", "displayname": "other test user"}
|
||||
request_data = json.dumps(data)
|
||||
channel = self.make_request(
|
||||
@@ -1557,7 +1560,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
|
||||
login.register_servlets,
|
||||
]
|
||||
|
||||
def prepare(self, reactor, clock, homeserver):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.creator = self.register_user("creator", "test")
|
||||
self.creator_tok = self.login("creator", "test")
|
||||
|
||||
@@ -1566,7 +1569,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
self.room_id = self.helper.create_room_as(self.creator, tok=self.creator_tok)
|
||||
|
||||
def test_join_reason(self):
|
||||
def test_join_reason(self) -> None:
|
||||
reason = "hello"
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
@@ -1578,7 +1581,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
self._check_for_reason(reason)
|
||||
|
||||
def test_leave_reason(self):
|
||||
def test_leave_reason(self) -> None:
|
||||
self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok)
|
||||
|
||||
reason = "hello"
|
||||
@@ -1592,7 +1595,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
self._check_for_reason(reason)
|
||||
|
||||
def test_kick_reason(self):
|
||||
def test_kick_reason(self) -> None:
|
||||
self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok)
|
||||
|
||||
reason = "hello"
|
||||
@@ -1606,7 +1609,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
self._check_for_reason(reason)
|
||||
|
||||
def test_ban_reason(self):
|
||||
def test_ban_reason(self) -> None:
|
||||
self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok)
|
||||
|
||||
reason = "hello"
|
||||
@@ -1620,7 +1623,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
self._check_for_reason(reason)
|
||||
|
||||
def test_unban_reason(self):
|
||||
def test_unban_reason(self) -> None:
|
||||
reason = "hello"
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
@@ -1632,7 +1635,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
self._check_for_reason(reason)
|
||||
|
||||
def test_invite_reason(self):
|
||||
def test_invite_reason(self) -> None:
|
||||
reason = "hello"
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
@@ -1644,7 +1647,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
self._check_for_reason(reason)
|
||||
|
||||
def test_reject_invite_reason(self):
|
||||
def test_reject_invite_reason(self) -> None:
|
||||
self.helper.invite(
|
||||
self.room_id,
|
||||
src=self.creator,
|
||||
@@ -1663,7 +1666,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
self._check_for_reason(reason)
|
||||
|
||||
def _check_for_reason(self, reason):
|
||||
def _check_for_reason(self, reason: str) -> None:
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
"/_matrix/client/r0/rooms/{}/state/m.room.member/{}".format(
|
||||
@@ -1704,12 +1707,12 @@ class LabelsTestCase(unittest.HomeserverTestCase):
|
||||
"org.matrix.not_labels": ["#notfun"],
|
||||
}
|
||||
|
||||
def prepare(self, reactor, clock, homeserver):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.user_id = self.register_user("test", "test")
|
||||
self.tok = self.login("test", "test")
|
||||
self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
|
||||
|
||||
def test_context_filter_labels(self):
|
||||
def test_context_filter_labels(self) -> None:
|
||||
"""Test that we can filter by a label on a /context request."""
|
||||
event_id = self._send_labelled_messages_in_room()
|
||||
|
||||
@@ -1739,7 +1742,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
|
||||
events_after[0]["content"]["body"], "with right label", events_after[0]
|
||||
)
|
||||
|
||||
def test_context_filter_not_labels(self):
|
||||
def test_context_filter_not_labels(self) -> None:
|
||||
"""Test that we can filter by the absence of a label on a /context request."""
|
||||
event_id = self._send_labelled_messages_in_room()
|
||||
|
||||
@@ -1772,7 +1775,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
|
||||
events_after[1]["content"]["body"], "with two wrong labels", events_after[1]
|
||||
)
|
||||
|
||||
def test_context_filter_labels_not_labels(self):
|
||||
def test_context_filter_labels_not_labels(self) -> None:
|
||||
"""Test that we can filter by both a label and the absence of another label on a
|
||||
/context request.
|
||||
"""
|
||||
@@ -1801,7 +1804,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
|
||||
events_after[0]["content"]["body"], "with wrong label", events_after[0]
|
||||
)
|
||||
|
||||
def test_messages_filter_labels(self):
|
||||
def test_messages_filter_labels(self) -> None:
|
||||
"""Test that we can filter by a label on a /messages request."""
|
||||
self._send_labelled_messages_in_room()
|
||||
|
||||
@@ -1818,7 +1821,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEqual(events[0]["content"]["body"], "with right label", events[0])
|
||||
self.assertEqual(events[1]["content"]["body"], "with right label", events[1])
|
||||
|
||||
def test_messages_filter_not_labels(self):
|
||||
def test_messages_filter_not_labels(self) -> None:
|
||||
"""Test that we can filter by the absence of a label on a /messages request."""
|
||||
self._send_labelled_messages_in_room()
|
||||
|
||||
@@ -1839,7 +1842,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
|
||||
events[3]["content"]["body"], "with two wrong labels", events[3]
|
||||
)
|
||||
|
||||
def test_messages_filter_labels_not_labels(self):
|
||||
def test_messages_filter_labels_not_labels(self) -> None:
|
||||
"""Test that we can filter by both a label and the absence of another label on a
|
||||
/messages request.
|
||||
"""
|
||||
@@ -1862,7 +1865,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEqual(len(events), 1, [event["content"] for event in events])
|
||||
self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0])
|
||||
|
||||
def test_search_filter_labels(self):
|
||||
def test_search_filter_labels(self) -> None:
|
||||
"""Test that we can filter by a label on a /search request."""
|
||||
request_data = json.dumps(
|
||||
{
|
||||
@@ -1899,7 +1902,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
|
||||
results[1]["result"]["content"]["body"],
|
||||
)
|
||||
|
||||
def test_search_filter_not_labels(self):
|
||||
def test_search_filter_not_labels(self) -> None:
|
||||
"""Test that we can filter by the absence of a label on a /search request."""
|
||||
request_data = json.dumps(
|
||||
{
|
||||
@@ -1946,7 +1949,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
|
||||
results[3]["result"]["content"]["body"],
|
||||
)
|
||||
|
||||
def test_search_filter_labels_not_labels(self):
|
||||
def test_search_filter_labels_not_labels(self) -> None:
|
||||
"""Test that we can filter by both a label and the absence of another label on a
|
||||
/search request.
|
||||
"""
|
||||
@@ -1980,7 +1983,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
|
||||
results[0]["result"]["content"]["body"],
|
||||
)
|
||||
|
||||
def _send_labelled_messages_in_room(self):
|
||||
def _send_labelled_messages_in_room(self) -> str:
|
||||
"""Sends several messages to a room with different labels (or without any) to test
|
||||
filtering by label.
|
||||
Returns:
|
||||
@@ -2056,12 +2059,12 @@ class RelationsTestCase(unittest.HomeserverTestCase):
|
||||
login.register_servlets,
|
||||
]
|
||||
|
||||
def default_config(self):
|
||||
def default_config(self) -> Dict[str, Any]:
|
||||
config = super().default_config()
|
||||
config["experimental_features"] = {"msc3440_enabled": True}
|
||||
return config
|
||||
|
||||
def prepare(self, reactor, clock, homeserver):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.user_id = self.register_user("test", "test")
|
||||
self.tok = self.login("test", "test")
|
||||
self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
|
||||
@@ -2136,7 +2139,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
return channel.json_body["chunk"]
|
||||
|
||||
def test_filter_relation_senders(self):
|
||||
def test_filter_relation_senders(self) -> None:
|
||||
# Messages which second user reacted to.
|
||||
filter = {"io.element.relation_senders": [self.second_user_id]}
|
||||
chunk = self._filter_messages(filter)
|
||||
@@ -2159,7 +2162,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
|
||||
[c["event_id"] for c in chunk], [self.event_id_1, self.event_id_2]
|
||||
)
|
||||
|
||||
def test_filter_relation_type(self):
|
||||
def test_filter_relation_type(self) -> None:
|
||||
# Messages which have annotations.
|
||||
filter = {"io.element.relation_types": [RelationTypes.ANNOTATION]}
|
||||
chunk = self._filter_messages(filter)
|
||||
@@ -2185,7 +2188,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
|
||||
[c["event_id"] for c in chunk], [self.event_id_1, self.event_id_2]
|
||||
)
|
||||
|
||||
def test_filter_relation_senders_and_type(self):
|
||||
def test_filter_relation_senders_and_type(self) -> None:
|
||||
# Messages which second user reacted to.
|
||||
filter = {
|
||||
"io.element.relation_senders": [self.second_user_id],
|
||||
@@ -2205,7 +2208,7 @@ class ContextTestCase(unittest.HomeserverTestCase):
|
||||
account.register_servlets,
|
||||
]
|
||||
|
||||
def prepare(self, reactor, clock, homeserver):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.user_id = self.register_user("user", "password")
|
||||
self.tok = self.login("user", "password")
|
||||
self.room_id = self.helper.create_room_as(
|
||||
@@ -2218,7 +2221,7 @@ class ContextTestCase(unittest.HomeserverTestCase):
|
||||
self.helper.invite(self.room_id, self.user_id, self.other_user_id, tok=self.tok)
|
||||
self.helper.join(self.room_id, self.other_user_id, tok=self.other_tok)
|
||||
|
||||
def test_erased_sender(self):
|
||||
def test_erased_sender(self) -> None:
|
||||
"""Test that an erasure request results in the requester's events being hidden
|
||||
from any new member of the room.
|
||||
"""
|
||||
@@ -2332,7 +2335,7 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase):
|
||||
room.register_servlets,
|
||||
]
|
||||
|
||||
def prepare(self, reactor, clock, homeserver):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.room_owner = self.register_user("room_owner", "test")
|
||||
self.room_owner_tok = self.login("room_owner", "test")
|
||||
|
||||
@@ -2340,17 +2343,17 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase):
|
||||
self.room_owner, tok=self.room_owner_tok
|
||||
)
|
||||
|
||||
def test_no_aliases(self):
|
||||
def test_no_aliases(self) -> None:
|
||||
res = self._get_aliases(self.room_owner_tok)
|
||||
self.assertEqual(res["aliases"], [])
|
||||
|
||||
def test_not_in_room(self):
|
||||
def test_not_in_room(self) -> None:
|
||||
self.register_user("user", "test")
|
||||
user_tok = self.login("user", "test")
|
||||
res = self._get_aliases(user_tok, expected_code=403)
|
||||
self.assertEqual(res["errcode"], "M_FORBIDDEN")
|
||||
|
||||
def test_admin_user(self):
|
||||
def test_admin_user(self) -> None:
|
||||
alias1 = self._random_alias()
|
||||
self._set_alias_via_directory(alias1)
|
||||
|
||||
@@ -2360,7 +2363,7 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase):
|
||||
res = self._get_aliases(user_tok)
|
||||
self.assertEqual(res["aliases"], [alias1])
|
||||
|
||||
def test_with_aliases(self):
|
||||
def test_with_aliases(self) -> None:
|
||||
alias1 = self._random_alias()
|
||||
alias2 = self._random_alias()
|
||||
|
||||
@@ -2370,7 +2373,7 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase):
|
||||
res = self._get_aliases(self.room_owner_tok)
|
||||
self.assertEqual(set(res["aliases"]), {alias1, alias2})
|
||||
|
||||
def test_peekable_room(self):
|
||||
def test_peekable_room(self) -> None:
|
||||
alias1 = self._random_alias()
|
||||
self._set_alias_via_directory(alias1)
|
||||
|
||||
@@ -2404,7 +2407,7 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase):
|
||||
def _random_alias(self) -> str:
|
||||
return RoomAlias(random_string(5), self.hs.hostname).to_string()
|
||||
|
||||
def _set_alias_via_directory(self, alias: str, expected_code: int = 200):
|
||||
def _set_alias_via_directory(self, alias: str, expected_code: int = 200) -> None:
|
||||
url = "/_matrix/client/r0/directory/room/" + alias
|
||||
data = {"room_id": self.room_id}
|
||||
request_data = json.dumps(data)
|
||||
@@ -2423,7 +2426,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
|
||||
room.register_servlets,
|
||||
]
|
||||
|
||||
def prepare(self, reactor, clock, homeserver):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.room_owner = self.register_user("room_owner", "test")
|
||||
self.room_owner_tok = self.login("room_owner", "test")
|
||||
|
||||
@@ -2434,7 +2437,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
|
||||
self.alias = "#alias:test"
|
||||
self._set_alias_via_directory(self.alias)
|
||||
|
||||
def _set_alias_via_directory(self, alias: str, expected_code: int = 200):
|
||||
def _set_alias_via_directory(self, alias: str, expected_code: int = 200) -> None:
|
||||
url = "/_matrix/client/r0/directory/room/" + alias
|
||||
data = {"room_id": self.room_id}
|
||||
request_data = json.dumps(data)
|
||||
@@ -2456,7 +2459,9 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
|
||||
self.assertIsInstance(res, dict)
|
||||
return res
|
||||
|
||||
def _set_canonical_alias(self, content: str, expected_code: int = 200) -> JsonDict:
|
||||
def _set_canonical_alias(
|
||||
self, content: JsonDict, expected_code: int = 200
|
||||
) -> JsonDict:
|
||||
"""Calls the endpoint under test. returns the json response object."""
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
@@ -2469,7 +2474,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
|
||||
self.assertIsInstance(res, dict)
|
||||
return res
|
||||
|
||||
def test_canonical_alias(self):
|
||||
def test_canonical_alias(self) -> None:
|
||||
"""Test a basic alias message."""
|
||||
# There is no canonical alias to start with.
|
||||
self._get_canonical_alias(expected_code=404)
|
||||
@@ -2488,7 +2493,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
|
||||
res = self._get_canonical_alias()
|
||||
self.assertEqual(res, {})
|
||||
|
||||
def test_alt_aliases(self):
|
||||
def test_alt_aliases(self) -> None:
|
||||
"""Test a canonical alias message with alt_aliases."""
|
||||
# Create an alias.
|
||||
self._set_canonical_alias({"alt_aliases": [self.alias]})
|
||||
@@ -2504,7 +2509,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
|
||||
res = self._get_canonical_alias()
|
||||
self.assertEqual(res, {})
|
||||
|
||||
def test_alias_alt_aliases(self):
|
||||
def test_alias_alt_aliases(self) -> None:
|
||||
"""Test a canonical alias message with an alias and alt_aliases."""
|
||||
# Create an alias.
|
||||
self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]})
|
||||
@@ -2520,7 +2525,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
|
||||
res = self._get_canonical_alias()
|
||||
self.assertEqual(res, {})
|
||||
|
||||
def test_partial_modify(self):
|
||||
def test_partial_modify(self) -> None:
|
||||
"""Test removing only the alt_aliases."""
|
||||
# Create an alias.
|
||||
self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]})
|
||||
@@ -2536,7 +2541,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
|
||||
res = self._get_canonical_alias()
|
||||
self.assertEqual(res, {"alias": self.alias})
|
||||
|
||||
def test_add_alias(self):
|
||||
def test_add_alias(self) -> None:
|
||||
"""Test removing only the alt_aliases."""
|
||||
# Create an additional alias.
|
||||
second_alias = "#second:test"
|
||||
@@ -2556,7 +2561,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
|
||||
res, {"alias": self.alias, "alt_aliases": [self.alias, second_alias]}
|
||||
)
|
||||
|
||||
def test_bad_data(self):
|
||||
def test_bad_data(self) -> None:
|
||||
"""Invalid data for alt_aliases should cause errors."""
|
||||
self._set_canonical_alias({"alt_aliases": "@bad:test"}, expected_code=400)
|
||||
self._set_canonical_alias({"alt_aliases": None}, expected_code=400)
|
||||
@@ -2566,7 +2571,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
|
||||
self._set_canonical_alias({"alt_aliases": True}, expected_code=400)
|
||||
self._set_canonical_alias({"alt_aliases": {}}, expected_code=400)
|
||||
|
||||
def test_bad_alias(self):
|
||||
def test_bad_alias(self) -> None:
|
||||
"""An alias which does not point to the room raises a SynapseError."""
|
||||
self._set_canonical_alias({"alias": "@unknown:test"}, expected_code=400)
|
||||
self._set_canonical_alias({"alt_aliases": ["@unknown:test"]}, expected_code=400)
|
||||
@@ -2580,13 +2585,13 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
|
||||
room.register_servlets,
|
||||
]
|
||||
|
||||
def prepare(self, reactor, clock, homeserver):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.user_id = self.register_user("thomas", "hackme")
|
||||
self.tok = self.login("thomas", "hackme")
|
||||
|
||||
self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
|
||||
|
||||
def test_threepid_invite_spamcheck(self):
|
||||
def test_threepid_invite_spamcheck(self) -> None:
|
||||
# Mock a few functions to prevent the test from failing due to failing to talk to
|
||||
# a remote IS. We keep the mock for _mock_make_and_store_3pid_invite around so we
|
||||
# can check its call_count later on during the test.
|
||||
|
||||
@@ -12,16 +12,22 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import threading
|
||||
from typing import TYPE_CHECKING, Dict, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
|
||||
from unittest.mock import Mock
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
from synapse.api.constants import EventTypes, LoginType, Membership
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.api.room_versions import RoomVersion
|
||||
from synapse.events import EventBase
|
||||
from synapse.events.snapshot import EventContext
|
||||
from synapse.events.third_party_rules import load_legacy_third_party_event_rules
|
||||
from synapse.rest import admin
|
||||
from synapse.rest.client import account, login, profile, room
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import JsonDict, Requester, StateMap
|
||||
from synapse.util import Clock
|
||||
from synapse.util.frozenutils import unfreeze
|
||||
|
||||
from tests import unittest
|
||||
@@ -34,7 +40,7 @@ thread_local = threading.local()
|
||||
|
||||
|
||||
class LegacyThirdPartyRulesTestModule:
|
||||
def __init__(self, config: Dict, module_api: "ModuleApi"):
|
||||
def __init__(self, config: Dict, module_api: "ModuleApi") -> None:
|
||||
# keep a record of the "current" rules module, so that the test can patch
|
||||
# it if desired.
|
||||
thread_local.rules_module = self
|
||||
@@ -42,32 +48,36 @@ class LegacyThirdPartyRulesTestModule:
|
||||
|
||||
async def on_create_room(
|
||||
self, requester: Requester, config: dict, is_requester_admin: bool
|
||||
):
|
||||
) -> bool:
|
||||
return True
|
||||
|
||||
async def check_event_allowed(self, event: EventBase, state: StateMap[EventBase]):
|
||||
async def check_event_allowed(
|
||||
self, event: EventBase, state: StateMap[EventBase]
|
||||
) -> Union[bool, dict]:
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def parse_config(config):
|
||||
def parse_config(config: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return config
|
||||
|
||||
|
||||
class LegacyDenyNewRooms(LegacyThirdPartyRulesTestModule):
|
||||
def __init__(self, config: Dict, module_api: "ModuleApi"):
|
||||
def __init__(self, config: Dict, module_api: "ModuleApi") -> None:
|
||||
super().__init__(config, module_api)
|
||||
|
||||
def on_create_room(
|
||||
async def on_create_room(
|
||||
self, requester: Requester, config: dict, is_requester_admin: bool
|
||||
):
|
||||
) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
class LegacyChangeEvents(LegacyThirdPartyRulesTestModule):
|
||||
def __init__(self, config: Dict, module_api: "ModuleApi"):
|
||||
def __init__(self, config: Dict, module_api: "ModuleApi") -> None:
|
||||
super().__init__(config, module_api)
|
||||
|
||||
async def check_event_allowed(self, event: EventBase, state: StateMap[EventBase]):
|
||||
async def check_event_allowed(
|
||||
self, event: EventBase, state: StateMap[EventBase]
|
||||
) -> JsonDict:
|
||||
d = event.get_dict()
|
||||
content = unfreeze(event.content)
|
||||
content["foo"] = "bar"
|
||||
@@ -84,7 +94,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
|
||||
account.register_servlets,
|
||||
]
|
||||
|
||||
def make_homeserver(self, reactor, clock):
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
hs = self.setup_test_homeserver()
|
||||
|
||||
load_legacy_third_party_event_rules(hs)
|
||||
@@ -94,22 +104,30 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
|
||||
# Note that these checks are not relevant to this test case.
|
||||
|
||||
# Have this homeserver auto-approve all event signature checking.
|
||||
async def approve_all_signature_checking(_, pdu):
|
||||
async def approve_all_signature_checking(
|
||||
_: RoomVersion, pdu: EventBase
|
||||
) -> EventBase:
|
||||
return pdu
|
||||
|
||||
hs.get_federation_server()._check_sigs_and_hash = approve_all_signature_checking
|
||||
hs.get_federation_server()._check_sigs_and_hash = approve_all_signature_checking # type: ignore[assignment]
|
||||
|
||||
# Have this homeserver skip event auth checks. This is necessary due to
|
||||
# event auth checks ensuring that events were signed by the sender's homeserver.
|
||||
async def _check_event_auth(origin, event, context, *args, **kwargs):
|
||||
async def _check_event_auth(
|
||||
origin: str,
|
||||
event: EventBase,
|
||||
context: EventContext,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> EventContext:
|
||||
return context
|
||||
|
||||
hs.get_federation_event_handler()._check_event_auth = _check_event_auth
|
||||
hs.get_federation_event_handler()._check_event_auth = _check_event_auth # type: ignore[assignment]
|
||||
|
||||
return hs
|
||||
|
||||
def prepare(self, reactor, clock, homeserver):
|
||||
super().prepare(reactor, clock, homeserver)
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
super().prepare(reactor, clock, hs)
|
||||
# Create some users and a room to play with during the tests
|
||||
self.user_id = self.register_user("kermit", "monkey")
|
||||
self.invitee = self.register_user("invitee", "hackme")
|
||||
@@ -121,13 +139,15 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def test_third_party_rules(self):
|
||||
def test_third_party_rules(self) -> None:
|
||||
"""Tests that a forbidden event is forbidden from being sent, but an allowed one
|
||||
can be sent.
|
||||
"""
|
||||
# patch the rules module with a Mock which will return False for some event
|
||||
# types
|
||||
async def check(ev, state):
|
||||
async def check(
|
||||
ev: EventBase, state: StateMap[EventBase]
|
||||
) -> Tuple[bool, Optional[JsonDict]]:
|
||||
return ev.type != "foo.bar.forbidden", None
|
||||
|
||||
callback = Mock(spec=[], side_effect=check)
|
||||
@@ -161,7 +181,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
|
||||
)
|
||||
self.assertEqual(channel.result["code"], b"403", channel.result)
|
||||
|
||||
def test_third_party_rules_workaround_synapse_errors_pass_through(self):
|
||||
def test_third_party_rules_workaround_synapse_errors_pass_through(self) -> None:
|
||||
"""
|
||||
Tests that the workaround introduced by https://github.com/matrix-org/synapse/pull/11042
|
||||
is functional: that SynapseErrors are passed through from check_event_allowed
|
||||
@@ -172,7 +192,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
|
||||
"""
|
||||
|
||||
class NastyHackException(SynapseError):
|
||||
def error_dict(self):
|
||||
def error_dict(self) -> JsonDict:
|
||||
"""
|
||||
This overrides SynapseError's `error_dict` to nastily inject
|
||||
JSON into the error response.
|
||||
@@ -182,7 +202,9 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
|
||||
return result
|
||||
|
||||
# add a callback that will raise our hacky exception
|
||||
async def check(ev, state) -> Tuple[bool, Optional[JsonDict]]:
|
||||
async def check(
|
||||
ev: EventBase, state: StateMap[EventBase]
|
||||
) -> Tuple[bool, Optional[JsonDict]]:
|
||||
raise NastyHackException(429, "message")
|
||||
|
||||
self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check]
|
||||
@@ -202,11 +224,13 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
|
||||
{"errcode": "M_UNKNOWN", "error": "message", "nasty": "very"},
|
||||
)
|
||||
|
||||
def test_cannot_modify_event(self):
|
||||
def test_cannot_modify_event(self) -> None:
|
||||
"""cannot accidentally modify an event before it is persisted"""
|
||||
|
||||
# first patch the event checker so that it will try to modify the event
|
||||
async def check(ev: EventBase, state):
|
||||
async def check(
|
||||
ev: EventBase, state: StateMap[EventBase]
|
||||
) -> Tuple[bool, Optional[JsonDict]]:
|
||||
ev.content = {"x": "y"}
|
||||
return True, None
|
||||
|
||||
@@ -223,10 +247,12 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
|
||||
# 500 Internal Server Error
|
||||
self.assertEqual(channel.code, 500, channel.result)
|
||||
|
||||
def test_modify_event(self):
|
||||
def test_modify_event(self) -> None:
|
||||
"""The module can return a modified version of the event"""
|
||||
# first patch the event checker so that it will modify the event
|
||||
async def check(ev: EventBase, state):
|
||||
async def check(
|
||||
ev: EventBase, state: StateMap[EventBase]
|
||||
) -> Tuple[bool, Optional[JsonDict]]:
|
||||
d = ev.get_dict()
|
||||
d["content"] = {"x": "y"}
|
||||
return True, d
|
||||
@@ -253,10 +279,12 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
|
||||
ev = channel.json_body
|
||||
self.assertEqual(ev["content"]["x"], "y")
|
||||
|
||||
def test_message_edit(self):
|
||||
def test_message_edit(self) -> None:
|
||||
"""Ensure that the module doesn't cause issues with edited messages."""
|
||||
# first patch the event checker so that it will modify the event
|
||||
async def check(ev: EventBase, state):
|
||||
async def check(
|
||||
ev: EventBase, state: StateMap[EventBase]
|
||||
) -> Tuple[bool, Optional[JsonDict]]:
|
||||
d = ev.get_dict()
|
||||
d["content"] = {
|
||||
"msgtype": "m.text",
|
||||
@@ -315,7 +343,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
|
||||
ev = channel.json_body
|
||||
self.assertEqual(ev["content"]["body"], "EDITED BODY")
|
||||
|
||||
def test_send_event(self):
|
||||
def test_send_event(self) -> None:
|
||||
"""Tests that a module can send an event into a room via the module api"""
|
||||
content = {
|
||||
"msgtype": "m.text",
|
||||
@@ -344,7 +372,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
|
||||
}
|
||||
}
|
||||
)
|
||||
def test_legacy_check_event_allowed(self):
|
||||
def test_legacy_check_event_allowed(self) -> None:
|
||||
"""Tests that the wrapper for legacy check_event_allowed callbacks works
|
||||
correctly.
|
||||
"""
|
||||
@@ -379,13 +407,13 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
|
||||
}
|
||||
}
|
||||
)
|
||||
def test_legacy_on_create_room(self):
|
||||
def test_legacy_on_create_room(self) -> None:
|
||||
"""Tests that the wrapper for legacy on_create_room callbacks works
|
||||
correctly.
|
||||
"""
|
||||
self.helper.create_room_as(self.user_id, tok=self.tok, expect_code=403)
|
||||
|
||||
def test_sent_event_end_up_in_room_state(self):
|
||||
def test_sent_event_end_up_in_room_state(self) -> None:
|
||||
"""Tests that a state event sent by a module while processing another state event
|
||||
doesn't get dropped from the state of the room. This is to guard against a bug
|
||||
where Synapse has been observed doing so, see https://github.com/matrix-org/synapse/issues/10830
|
||||
@@ -400,7 +428,9 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
|
||||
api = self.hs.get_module_api()
|
||||
|
||||
# Define a callback that sends a custom event on power levels update.
|
||||
async def test_fn(event: EventBase, state_events):
|
||||
async def test_fn(
|
||||
event: EventBase, state_events: StateMap[EventBase]
|
||||
) -> Tuple[bool, Optional[JsonDict]]:
|
||||
if event.is_state and event.type == EventTypes.PowerLevels:
|
||||
await api.create_and_send_event_into_room(
|
||||
{
|
||||
@@ -436,7 +466,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
|
||||
self.assertEqual(channel.code, 200)
|
||||
self.assertEqual(channel.json_body["i"], i)
|
||||
|
||||
def test_on_new_event(self):
|
||||
def test_on_new_event(self) -> None:
|
||||
"""Test that the on_new_event callback is called on new events"""
|
||||
on_new_event = Mock(make_awaitable(None))
|
||||
self.hs.get_third_party_event_rules()._on_new_event_callbacks.append(
|
||||
@@ -501,7 +531,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
|
||||
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
def _update_power_levels(self, event_default: int = 0):
|
||||
def _update_power_levels(self, event_default: int = 0) -> None:
|
||||
"""Updates the room's power levels.
|
||||
|
||||
Args:
|
||||
@@ -533,7 +563,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
|
||||
tok=self.tok,
|
||||
)
|
||||
|
||||
def test_on_profile_update(self):
|
||||
def test_on_profile_update(self) -> None:
|
||||
"""Tests that the on_profile_update module callback is correctly called on
|
||||
profile updates.
|
||||
"""
|
||||
@@ -592,7 +622,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
|
||||
self.assertEqual(profile_info.display_name, displayname)
|
||||
self.assertEqual(profile_info.avatar_url, avatar_url)
|
||||
|
||||
def test_on_profile_update_admin(self):
|
||||
def test_on_profile_update_admin(self) -> None:
|
||||
"""Tests that the on_profile_update module callback is correctly called on
|
||||
profile updates triggered by a server admin.
|
||||
"""
|
||||
@@ -634,7 +664,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
|
||||
self.assertEqual(profile_info.display_name, displayname)
|
||||
self.assertEqual(profile_info.avatar_url, avatar_url)
|
||||
|
||||
def test_on_user_deactivation_status_changed(self):
|
||||
def test_on_user_deactivation_status_changed(self) -> None:
|
||||
"""Tests that the on_user_deactivation_status_changed module callback is called
|
||||
correctly when processing a user's deactivation.
|
||||
"""
|
||||
@@ -691,7 +721,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
|
||||
args = profile_mock.call_args[0]
|
||||
self.assertTrue(args[3])
|
||||
|
||||
def test_on_user_deactivation_status_changed_admin(self):
|
||||
def test_on_user_deactivation_status_changed_admin(self) -> None:
|
||||
"""Tests that the on_user_deactivation_status_changed module callback is called
|
||||
correctly when processing a user's deactivation triggered by a server admin as
|
||||
well as a reactivation.
|
||||
|
||||
@@ -15,10 +15,12 @@
|
||||
|
||||
"""Tests REST events for /rooms paths."""
|
||||
|
||||
from unittest.mock import Mock
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
from synapse.rest.client import room
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import UserID
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
|
||||
@@ -33,40 +35,17 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
|
||||
user = UserID.from_string(user_id)
|
||||
servlets = [room.register_servlets]
|
||||
|
||||
def make_homeserver(self, reactor, clock):
|
||||
|
||||
hs = self.setup_test_homeserver(
|
||||
"red",
|
||||
federation_http_client=None,
|
||||
federation_client=Mock(),
|
||||
)
|
||||
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
hs = self.setup_test_homeserver("red")
|
||||
self.event_source = hs.get_event_sources().sources.typing
|
||||
|
||||
hs.get_federation_handler = Mock()
|
||||
|
||||
async def get_user_by_access_token(token=None, allow_guest=False):
|
||||
return {
|
||||
"user": UserID.from_string(self.auth_user_id),
|
||||
"token_id": 1,
|
||||
"is_guest": False,
|
||||
}
|
||||
|
||||
hs.get_auth().get_user_by_access_token = get_user_by_access_token
|
||||
|
||||
async def _insert_client_ip(*args, **kwargs):
|
||||
return None
|
||||
|
||||
hs.get_datastores().main.insert_client_ip = _insert_client_ip
|
||||
|
||||
return hs
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.room_id = self.helper.create_room_as(self.user_id)
|
||||
# Need another user to make notifications actually work
|
||||
self.helper.join(self.room_id, user="@jim:red")
|
||||
|
||||
def test_set_typing(self):
|
||||
def test_set_typing(self) -> None:
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
"/rooms/%s/typing/%s" % (self.room_id, self.user_id),
|
||||
@@ -95,7 +74,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
|
||||
],
|
||||
)
|
||||
|
||||
def test_set_not_typing(self):
|
||||
def test_set_not_typing(self) -> None:
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
"/rooms/%s/typing/%s" % (self.room_id, self.user_id),
|
||||
@@ -103,7 +82,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
self.assertEqual(200, channel.code)
|
||||
|
||||
def test_typing_timeout(self):
|
||||
def test_typing_timeout(self) -> None:
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
"/rooms/%s/typing/%s" % (self.room_id, self.user_id),
|
||||
|
||||
@@ -13,19 +13,24 @@
|
||||
# limitations under the License.
|
||||
import urllib.parse
|
||||
from io import BytesIO, StringIO
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from unittest.mock import Mock
|
||||
|
||||
import signedjson.key
|
||||
from canonicaljson import encode_canonical_json
|
||||
from nacl.signing import SigningKey
|
||||
from signedjson.sign import sign_json
|
||||
from signedjson.types import SigningKey
|
||||
|
||||
from twisted.web.resource import NoResource
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
from twisted.web.resource import NoResource, Resource
|
||||
|
||||
from synapse.crypto.keyring import PerspectivesKeyFetcher
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.rest.key.v2 import KeyApiV2Resource
|
||||
from synapse.server import HomeServer
|
||||
from synapse.storage.keys import FetchKeyResult
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util import Clock
|
||||
from synapse.util.httpresourcetree import create_resource_tree
|
||||
from synapse.util.stringutils import random_string
|
||||
|
||||
@@ -35,11 +40,11 @@ from tests.utils import default_config
|
||||
|
||||
|
||||
class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase):
|
||||
def make_homeserver(self, reactor, clock):
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
self.http_client = Mock()
|
||||
return self.setup_test_homeserver(federation_http_client=self.http_client)
|
||||
|
||||
def create_test_resource(self):
|
||||
def create_test_resource(self) -> Resource:
|
||||
return create_resource_tree(
|
||||
{"/_matrix/key/v2": KeyApiV2Resource(self.hs)}, root_resource=NoResource()
|
||||
)
|
||||
@@ -51,7 +56,12 @@ class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase):
|
||||
Tell the mock http client to expect an outgoing GET request for the given key
|
||||
"""
|
||||
|
||||
async def get_json(destination, path, ignore_backoff=False, **kwargs):
|
||||
async def get_json(
|
||||
destination: str,
|
||||
path: str,
|
||||
ignore_backoff: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Union[JsonDict, list]:
|
||||
self.assertTrue(ignore_backoff)
|
||||
self.assertEqual(destination, server_name)
|
||||
key_id = "%s:%s" % (signing_key.alg, signing_key.version)
|
||||
@@ -84,7 +94,8 @@ class RemoteKeyResourceTestCase(BaseRemoteKeyResourceTestCase):
|
||||
Checks that the response is a 200 and returns the decoded json body.
|
||||
"""
|
||||
channel = FakeChannel(self.site, self.reactor)
|
||||
req = SynapseRequest(channel, self.site)
|
||||
# channel is a `FakeChannel` but `HTTPChannel` is expected
|
||||
req = SynapseRequest(channel, self.site) # type: ignore[arg-type]
|
||||
req.content = BytesIO(b"")
|
||||
req.requestReceived(
|
||||
b"GET",
|
||||
@@ -97,7 +108,7 @@ class RemoteKeyResourceTestCase(BaseRemoteKeyResourceTestCase):
|
||||
resp = channel.json_body
|
||||
return resp
|
||||
|
||||
def test_get_key(self):
|
||||
def test_get_key(self) -> None:
|
||||
"""Fetch a remote key"""
|
||||
SERVER_NAME = "remote.server"
|
||||
testkey = signedjson.key.generate_signing_key("ver1")
|
||||
@@ -114,7 +125,7 @@ class RemoteKeyResourceTestCase(BaseRemoteKeyResourceTestCase):
|
||||
self.assertIn(SERVER_NAME, keys[0]["signatures"])
|
||||
self.assertIn(self.hs.hostname, keys[0]["signatures"])
|
||||
|
||||
def test_get_own_key(self):
|
||||
def test_get_own_key(self) -> None:
|
||||
"""Fetch our own key"""
|
||||
testkey = signedjson.key.generate_signing_key("ver1")
|
||||
self.expect_outgoing_key_request(self.hs.hostname, testkey)
|
||||
@@ -141,7 +152,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
|
||||
endpoint, to check that the two implementations are compatible.
|
||||
"""
|
||||
|
||||
def default_config(self):
|
||||
def default_config(self) -> Dict[str, Any]:
|
||||
config = super().default_config()
|
||||
|
||||
# replace the signing key with our own
|
||||
@@ -152,7 +163,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
|
||||
|
||||
return config
|
||||
|
||||
def prepare(self, reactor, clock, homeserver):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
# make a second homeserver, configured to use the first one as a key notary
|
||||
self.http_client2 = Mock()
|
||||
config = default_config(name="keyclient")
|
||||
@@ -175,7 +186,9 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
|
||||
|
||||
# wire up outbound POST /key/v2/query requests from hs2 so that they
|
||||
# will be forwarded to hs1
|
||||
async def post_json(destination, path, data):
|
||||
async def post_json(
|
||||
destination: str, path: str, data: Optional[JsonDict] = None
|
||||
) -> Union[JsonDict, list]:
|
||||
self.assertEqual(destination, self.hs.hostname)
|
||||
self.assertEqual(
|
||||
path,
|
||||
@@ -183,7 +196,8 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
|
||||
)
|
||||
|
||||
channel = FakeChannel(self.site, self.reactor)
|
||||
req = SynapseRequest(channel, self.site)
|
||||
# channel is a `FakeChannel` but `HTTPChannel` is expected
|
||||
req = SynapseRequest(channel, self.site) # type: ignore[arg-type]
|
||||
req.content = BytesIO(encode_canonical_json(data))
|
||||
|
||||
req.requestReceived(
|
||||
@@ -198,7 +212,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
|
||||
|
||||
self.http_client2.post_json.side_effect = post_json
|
||||
|
||||
def test_get_key(self):
|
||||
def test_get_key(self) -> None:
|
||||
"""Fetch a key belonging to a random server"""
|
||||
# make up a key to be fetched.
|
||||
testkey = signedjson.key.generate_signing_key("abc")
|
||||
@@ -218,7 +232,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
|
||||
signedjson.key.encode_verify_key_base64(testkey.verify_key),
|
||||
)
|
||||
|
||||
def test_get_notary_key(self):
|
||||
def test_get_notary_key(self) -> None:
|
||||
"""Fetch a key belonging to the notary server"""
|
||||
# make up a key to be fetched. We randomise the keyid to try to get it to
|
||||
# appear before the key server signing key sometimes (otherwise we bail out
|
||||
@@ -240,7 +254,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
|
||||
signedjson.key.encode_verify_key_base64(testkey.verify_key),
|
||||
)
|
||||
|
||||
def test_get_notary_keyserver_key(self):
|
||||
def test_get_notary_keyserver_key(self) -> None:
|
||||
"""Fetch the notary's keyserver key"""
|
||||
# we expect hs1 to make a regular key request to itself
|
||||
self.expect_outgoing_key_request(self.hs.hostname, self.hs_signing_key)
|
||||
|
||||
@@ -28,11 +28,11 @@ class GetFileNameFromHeadersTests(unittest.TestCase):
|
||||
b"inline; filename*=utf-8''foo%C2%A3bar": "foo£bar",
|
||||
}
|
||||
|
||||
def tests(self):
|
||||
def tests(self) -> None:
|
||||
for hdr, expected in self.TEST_CASES.items():
|
||||
res = get_filename_from_headers({b"Content-Disposition": [hdr]})
|
||||
self.assertEqual(
|
||||
res,
|
||||
expected,
|
||||
"expected output for %s to be %s but was %s" % (hdr, expected, res),
|
||||
f"expected output for {hdr!r} to be {expected} but was {res}",
|
||||
)
|
||||
|
||||
@@ -21,12 +21,12 @@ from tests import unittest
|
||||
|
||||
|
||||
class MediaFilePathsTestCase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
|
||||
self.filepaths = MediaFilePaths("/media_store")
|
||||
|
||||
def test_local_media_filepath(self):
|
||||
def test_local_media_filepath(self) -> None:
|
||||
"""Test local media paths"""
|
||||
self.assertEqual(
|
||||
self.filepaths.local_media_filepath_rel("GerZNDnDZVjsOtardLuwfIBg"),
|
||||
@@ -37,7 +37,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
|
||||
"/media_store/local_content/Ge/rZ/NDnDZVjsOtardLuwfIBg",
|
||||
)
|
||||
|
||||
def test_local_media_thumbnail(self):
|
||||
def test_local_media_thumbnail(self) -> None:
|
||||
"""Test local media thumbnail paths"""
|
||||
self.assertEqual(
|
||||
self.filepaths.local_media_thumbnail_rel(
|
||||
@@ -52,14 +52,14 @@ class MediaFilePathsTestCase(unittest.TestCase):
|
||||
"/media_store/local_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale",
|
||||
)
|
||||
|
||||
def test_local_media_thumbnail_dir(self):
|
||||
def test_local_media_thumbnail_dir(self) -> None:
|
||||
"""Test local media thumbnail directory paths"""
|
||||
self.assertEqual(
|
||||
self.filepaths.local_media_thumbnail_dir("GerZNDnDZVjsOtardLuwfIBg"),
|
||||
"/media_store/local_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg",
|
||||
)
|
||||
|
||||
def test_remote_media_filepath(self):
|
||||
def test_remote_media_filepath(self) -> None:
|
||||
"""Test remote media paths"""
|
||||
self.assertEqual(
|
||||
self.filepaths.remote_media_filepath_rel(
|
||||
@@ -74,7 +74,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
|
||||
"/media_store/remote_content/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg",
|
||||
)
|
||||
|
||||
def test_remote_media_thumbnail(self):
|
||||
def test_remote_media_thumbnail(self) -> None:
|
||||
"""Test remote media thumbnail paths"""
|
||||
self.assertEqual(
|
||||
self.filepaths.remote_media_thumbnail_rel(
|
||||
@@ -99,7 +99,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
|
||||
"/media_store/remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale",
|
||||
)
|
||||
|
||||
def test_remote_media_thumbnail_legacy(self):
|
||||
def test_remote_media_thumbnail_legacy(self) -> None:
|
||||
"""Test old-style remote media thumbnail paths"""
|
||||
self.assertEqual(
|
||||
self.filepaths.remote_media_thumbnail_rel_legacy(
|
||||
@@ -108,7 +108,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
|
||||
"remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg",
|
||||
)
|
||||
|
||||
def test_remote_media_thumbnail_dir(self):
|
||||
def test_remote_media_thumbnail_dir(self) -> None:
|
||||
"""Test remote media thumbnail directory paths"""
|
||||
self.assertEqual(
|
||||
self.filepaths.remote_media_thumbnail_dir(
|
||||
@@ -117,7 +117,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
|
||||
"/media_store/remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg",
|
||||
)
|
||||
|
||||
def test_url_cache_filepath(self):
|
||||
def test_url_cache_filepath(self) -> None:
|
||||
"""Test URL cache paths"""
|
||||
self.assertEqual(
|
||||
self.filepaths.url_cache_filepath_rel("2020-01-02_GerZNDnDZVjsOtar"),
|
||||
@@ -128,7 +128,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
|
||||
"/media_store/url_cache/2020-01-02/GerZNDnDZVjsOtar",
|
||||
)
|
||||
|
||||
def test_url_cache_filepath_legacy(self):
|
||||
def test_url_cache_filepath_legacy(self) -> None:
|
||||
"""Test old-style URL cache paths"""
|
||||
self.assertEqual(
|
||||
self.filepaths.url_cache_filepath_rel("GerZNDnDZVjsOtardLuwfIBg"),
|
||||
@@ -139,7 +139,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
|
||||
"/media_store/url_cache/Ge/rZ/NDnDZVjsOtardLuwfIBg",
|
||||
)
|
||||
|
||||
def test_url_cache_filepath_dirs_to_delete(self):
|
||||
def test_url_cache_filepath_dirs_to_delete(self) -> None:
|
||||
"""Test URL cache cleanup paths"""
|
||||
self.assertEqual(
|
||||
self.filepaths.url_cache_filepath_dirs_to_delete(
|
||||
@@ -148,7 +148,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
|
||||
["/media_store/url_cache/2020-01-02"],
|
||||
)
|
||||
|
||||
def test_url_cache_filepath_dirs_to_delete_legacy(self):
|
||||
def test_url_cache_filepath_dirs_to_delete_legacy(self) -> None:
|
||||
"""Test old-style URL cache cleanup paths"""
|
||||
self.assertEqual(
|
||||
self.filepaths.url_cache_filepath_dirs_to_delete(
|
||||
@@ -160,7 +160,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
|
||||
],
|
||||
)
|
||||
|
||||
def test_url_cache_thumbnail(self):
|
||||
def test_url_cache_thumbnail(self) -> None:
|
||||
"""Test URL cache thumbnail paths"""
|
||||
self.assertEqual(
|
||||
self.filepaths.url_cache_thumbnail_rel(
|
||||
@@ -175,7 +175,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
|
||||
"/media_store/url_cache_thumbnails/2020-01-02/GerZNDnDZVjsOtar/800-600-image-jpeg-scale",
|
||||
)
|
||||
|
||||
def test_url_cache_thumbnail_legacy(self):
|
||||
def test_url_cache_thumbnail_legacy(self) -> None:
|
||||
"""Test old-style URL cache thumbnail paths"""
|
||||
self.assertEqual(
|
||||
self.filepaths.url_cache_thumbnail_rel(
|
||||
@@ -190,7 +190,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
|
||||
"/media_store/url_cache_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale",
|
||||
)
|
||||
|
||||
def test_url_cache_thumbnail_directory(self):
|
||||
def test_url_cache_thumbnail_directory(self) -> None:
|
||||
"""Test URL cache thumbnail directory paths"""
|
||||
self.assertEqual(
|
||||
self.filepaths.url_cache_thumbnail_directory_rel(
|
||||
@@ -203,7 +203,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
|
||||
"/media_store/url_cache_thumbnails/2020-01-02/GerZNDnDZVjsOtar",
|
||||
)
|
||||
|
||||
def test_url_cache_thumbnail_directory_legacy(self):
|
||||
def test_url_cache_thumbnail_directory_legacy(self) -> None:
|
||||
"""Test old-style URL cache thumbnail directory paths"""
|
||||
self.assertEqual(
|
||||
self.filepaths.url_cache_thumbnail_directory_rel(
|
||||
@@ -216,7 +216,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
|
||||
"/media_store/url_cache_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg",
|
||||
)
|
||||
|
||||
def test_url_cache_thumbnail_dirs_to_delete(self):
|
||||
def test_url_cache_thumbnail_dirs_to_delete(self) -> None:
|
||||
"""Test URL cache thumbnail cleanup paths"""
|
||||
self.assertEqual(
|
||||
self.filepaths.url_cache_thumbnail_dirs_to_delete(
|
||||
@@ -228,7 +228,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
|
||||
],
|
||||
)
|
||||
|
||||
def test_url_cache_thumbnail_dirs_to_delete_legacy(self):
|
||||
def test_url_cache_thumbnail_dirs_to_delete_legacy(self) -> None:
|
||||
"""Test old-style URL cache thumbnail cleanup paths"""
|
||||
self.assertEqual(
|
||||
self.filepaths.url_cache_thumbnail_dirs_to_delete(
|
||||
@@ -241,7 +241,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
|
||||
],
|
||||
)
|
||||
|
||||
def test_server_name_validation(self):
|
||||
def test_server_name_validation(self) -> None:
|
||||
"""Test validation of server names"""
|
||||
self._test_path_validation(
|
||||
[
|
||||
@@ -274,7 +274,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
|
||||
],
|
||||
)
|
||||
|
||||
def test_file_id_validation(self):
|
||||
def test_file_id_validation(self) -> None:
|
||||
"""Test validation of local, remote and legacy URL cache file / media IDs"""
|
||||
# File / media IDs get split into three parts to form paths, consisting of the
|
||||
# first two characters, next two characters and rest of the ID.
|
||||
@@ -357,7 +357,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
|
||||
invalid_values=invalid_file_ids,
|
||||
)
|
||||
|
||||
def test_url_cache_media_id_validation(self):
|
||||
def test_url_cache_media_id_validation(self) -> None:
|
||||
"""Test validation of URL cache media IDs"""
|
||||
self._test_path_validation(
|
||||
[
|
||||
@@ -387,7 +387,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
|
||||
],
|
||||
)
|
||||
|
||||
def test_content_type_validation(self):
|
||||
def test_content_type_validation(self) -> None:
|
||||
"""Test validation of thumbnail content types"""
|
||||
self._test_path_validation(
|
||||
[
|
||||
@@ -410,7 +410,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
|
||||
],
|
||||
)
|
||||
|
||||
def test_thumbnail_method_validation(self):
|
||||
def test_thumbnail_method_validation(self) -> None:
|
||||
"""Test validation of thumbnail methods"""
|
||||
self._test_path_validation(
|
||||
[
|
||||
@@ -440,7 +440,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
|
||||
parameter: str,
|
||||
valid_values: Iterable[str],
|
||||
invalid_values: Iterable[str],
|
||||
):
|
||||
) -> None:
|
||||
"""Test that the specified methods validate the named parameter as expected
|
||||
|
||||
Args:
|
||||
|
||||
@@ -32,7 +32,7 @@ class SummarizeTestCase(unittest.TestCase):
|
||||
if not lxml:
|
||||
skip = "url preview feature requires lxml"
|
||||
|
||||
def test_long_summarize(self):
|
||||
def test_long_summarize(self) -> None:
|
||||
example_paras = [
|
||||
"""Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:
|
||||
Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in
|
||||
@@ -90,7 +90,7 @@ class SummarizeTestCase(unittest.TestCase):
|
||||
" Tromsøya had a population of 36,088. Substantial parts of the urban…",
|
||||
)
|
||||
|
||||
def test_short_summarize(self):
|
||||
def test_short_summarize(self) -> None:
|
||||
example_paras = [
|
||||
"Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
|
||||
" Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
|
||||
@@ -117,7 +117,7 @@ class SummarizeTestCase(unittest.TestCase):
|
||||
" most of the year.",
|
||||
)
|
||||
|
||||
def test_small_then_large_summarize(self):
|
||||
def test_small_then_large_summarize(self) -> None:
|
||||
example_paras = [
|
||||
"Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
|
||||
" Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
|
||||
@@ -150,7 +150,7 @@ class CalcOgTestCase(unittest.TestCase):
|
||||
if not lxml:
|
||||
skip = "url preview feature requires lxml"
|
||||
|
||||
def test_simple(self):
|
||||
def test_simple(self) -> None:
|
||||
html = b"""
|
||||
<html>
|
||||
<head><title>Foo</title></head>
|
||||
@@ -165,7 +165,7 @@ class CalcOgTestCase(unittest.TestCase):
|
||||
|
||||
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
|
||||
|
||||
def test_comment(self):
|
||||
def test_comment(self) -> None:
|
||||
html = b"""
|
||||
<html>
|
||||
<head><title>Foo</title></head>
|
||||
@@ -181,7 +181,7 @@ class CalcOgTestCase(unittest.TestCase):
|
||||
|
||||
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
|
||||
|
||||
def test_comment2(self):
|
||||
def test_comment2(self) -> None:
|
||||
html = b"""
|
||||
<html>
|
||||
<head><title>Foo</title></head>
|
||||
@@ -206,7 +206,7 @@ class CalcOgTestCase(unittest.TestCase):
|
||||
},
|
||||
)
|
||||
|
||||
def test_script(self):
|
||||
def test_script(self) -> None:
|
||||
html = b"""
|
||||
<html>
|
||||
<head><title>Foo</title></head>
|
||||
@@ -222,7 +222,7 @@ class CalcOgTestCase(unittest.TestCase):
|
||||
|
||||
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
|
||||
|
||||
def test_missing_title(self):
|
||||
def test_missing_title(self) -> None:
|
||||
html = b"""
|
||||
<html>
|
||||
<body>
|
||||
@@ -236,7 +236,7 @@ class CalcOgTestCase(unittest.TestCase):
|
||||
|
||||
self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
|
||||
|
||||
def test_h1_as_title(self):
|
||||
def test_h1_as_title(self) -> None:
|
||||
html = b"""
|
||||
<html>
|
||||
<meta property="og:description" content="Some text."/>
|
||||
@@ -251,7 +251,7 @@ class CalcOgTestCase(unittest.TestCase):
|
||||
|
||||
self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."})
|
||||
|
||||
def test_missing_title_and_broken_h1(self):
|
||||
def test_missing_title_and_broken_h1(self) -> None:
|
||||
html = b"""
|
||||
<html>
|
||||
<body>
|
||||
@@ -266,19 +266,19 @@ class CalcOgTestCase(unittest.TestCase):
|
||||
|
||||
self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
|
||||
|
||||
def test_empty(self):
|
||||
def test_empty(self) -> None:
|
||||
"""Test a body with no data in it."""
|
||||
html = b""
|
||||
tree = decode_body(html, "http://example.com/test.html")
|
||||
self.assertIsNone(tree)
|
||||
|
||||
def test_no_tree(self):
|
||||
def test_no_tree(self) -> None:
|
||||
"""A valid body with no tree in it."""
|
||||
html = b"\x00"
|
||||
tree = decode_body(html, "http://example.com/test.html")
|
||||
self.assertIsNone(tree)
|
||||
|
||||
def test_xml(self):
|
||||
def test_xml(self) -> None:
|
||||
"""Test decoding XML and ensure it works properly."""
|
||||
# Note that the strip() call is important to ensure the xml tag starts
|
||||
# at the initial byte.
|
||||
@@ -293,7 +293,7 @@ class CalcOgTestCase(unittest.TestCase):
|
||||
og = parse_html_to_open_graph(tree, "http://example.com/test.html")
|
||||
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
|
||||
|
||||
def test_invalid_encoding(self):
|
||||
def test_invalid_encoding(self) -> None:
|
||||
"""An invalid character encoding should be ignored and treated as UTF-8, if possible."""
|
||||
html = b"""
|
||||
<html>
|
||||
@@ -307,7 +307,7 @@ class CalcOgTestCase(unittest.TestCase):
|
||||
og = parse_html_to_open_graph(tree, "http://example.com/test.html")
|
||||
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
|
||||
|
||||
def test_invalid_encoding2(self):
|
||||
def test_invalid_encoding2(self) -> None:
|
||||
"""A body which doesn't match the sent character encoding."""
|
||||
# Note that this contains an invalid UTF-8 sequence in the title.
|
||||
html = b"""
|
||||
@@ -322,7 +322,7 @@ class CalcOgTestCase(unittest.TestCase):
|
||||
og = parse_html_to_open_graph(tree, "http://example.com/test.html")
|
||||
self.assertEqual(og, {"og:title": "ÿÿ Foo", "og:description": "Some text."})
|
||||
|
||||
def test_windows_1252(self):
|
||||
def test_windows_1252(self) -> None:
|
||||
"""A body which uses cp1252, but doesn't declare that."""
|
||||
html = b"""
|
||||
<html>
|
||||
@@ -338,7 +338,7 @@ class CalcOgTestCase(unittest.TestCase):
|
||||
|
||||
|
||||
class MediaEncodingTestCase(unittest.TestCase):
|
||||
def test_meta_charset(self):
|
||||
def test_meta_charset(self) -> None:
|
||||
"""A character encoding is found via the meta tag."""
|
||||
encodings = _get_html_media_encodings(
|
||||
b"""
|
||||
@@ -363,7 +363,7 @@ class MediaEncodingTestCase(unittest.TestCase):
|
||||
)
|
||||
self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"])
|
||||
|
||||
def test_meta_charset_underscores(self):
|
||||
def test_meta_charset_underscores(self) -> None:
|
||||
"""A character encoding contains underscore."""
|
||||
encodings = _get_html_media_encodings(
|
||||
b"""
|
||||
@@ -376,7 +376,7 @@ class MediaEncodingTestCase(unittest.TestCase):
|
||||
)
|
||||
self.assertEqual(list(encodings), ["shift_jis", "utf-8", "cp1252"])
|
||||
|
||||
def test_xml_encoding(self):
|
||||
def test_xml_encoding(self) -> None:
|
||||
"""A character encoding is found via the meta tag."""
|
||||
encodings = _get_html_media_encodings(
|
||||
b"""
|
||||
@@ -388,7 +388,7 @@ class MediaEncodingTestCase(unittest.TestCase):
|
||||
)
|
||||
self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"])
|
||||
|
||||
def test_meta_xml_encoding(self):
|
||||
def test_meta_xml_encoding(self) -> None:
|
||||
"""Meta tags take precedence over XML encoding."""
|
||||
encodings = _get_html_media_encodings(
|
||||
b"""
|
||||
@@ -402,7 +402,7 @@ class MediaEncodingTestCase(unittest.TestCase):
|
||||
)
|
||||
self.assertEqual(list(encodings), ["utf-16", "ascii", "utf-8", "cp1252"])
|
||||
|
||||
def test_content_type(self):
|
||||
def test_content_type(self) -> None:
|
||||
"""A character encoding is found via the Content-Type header."""
|
||||
# Test a few variations of the header.
|
||||
headers = (
|
||||
@@ -417,12 +417,12 @@ class MediaEncodingTestCase(unittest.TestCase):
|
||||
encodings = _get_html_media_encodings(b"", header)
|
||||
self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"])
|
||||
|
||||
def test_fallback(self):
|
||||
def test_fallback(self) -> None:
|
||||
"""A character encoding cannot be found in the body or header."""
|
||||
encodings = _get_html_media_encodings(b"", "text/html")
|
||||
self.assertEqual(list(encodings), ["utf-8", "cp1252"])
|
||||
|
||||
def test_duplicates(self):
|
||||
def test_duplicates(self) -> None:
|
||||
"""Ensure each encoding is only attempted once."""
|
||||
encodings = _get_html_media_encodings(
|
||||
b"""
|
||||
@@ -436,7 +436,7 @@ class MediaEncodingTestCase(unittest.TestCase):
|
||||
)
|
||||
self.assertEqual(list(encodings), ["utf-8", "cp1252"])
|
||||
|
||||
def test_unknown_invalid(self):
|
||||
def test_unknown_invalid(self) -> None:
|
||||
"""A character encoding should be ignored if it is unknown or invalid."""
|
||||
encodings = _get_html_media_encodings(
|
||||
b"""
|
||||
@@ -451,7 +451,7 @@ class MediaEncodingTestCase(unittest.TestCase):
|
||||
|
||||
|
||||
class RebaseUrlTestCase(unittest.TestCase):
|
||||
def test_relative(self):
|
||||
def test_relative(self) -> None:
|
||||
"""Relative URLs should be resolved based on the context of the base URL."""
|
||||
self.assertEqual(
|
||||
rebase_url("subpage", "https://example.com/foo/"),
|
||||
@@ -466,14 +466,14 @@ class RebaseUrlTestCase(unittest.TestCase):
|
||||
"https://example.com/bar",
|
||||
)
|
||||
|
||||
def test_absolute(self):
|
||||
def test_absolute(self) -> None:
|
||||
"""Absolute URLs should not be modified."""
|
||||
self.assertEqual(
|
||||
rebase_url("https://alice.com/a/", "https://example.com/foo/"),
|
||||
"https://alice.com/a/",
|
||||
)
|
||||
|
||||
def test_data(self):
|
||||
def test_data(self) -> None:
|
||||
"""Data URLs should not be modified."""
|
||||
self.assertEqual(
|
||||
rebase_url("data:,Hello%2C%20World%21", "https://example.com/foo/"),
|
||||
|
||||
@@ -16,7 +16,7 @@ import json
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
from synapse.rest.media.v1.oembed import OEmbedProvider
|
||||
from synapse.rest.media.v1.oembed import OEmbedProvider, OEmbedResult
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util import Clock
|
||||
@@ -25,15 +25,15 @@ from tests.unittest import HomeserverTestCase
|
||||
|
||||
|
||||
class OEmbedTests(HomeserverTestCase):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer):
|
||||
self.oembed = OEmbedProvider(homeserver)
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.oembed = OEmbedProvider(hs)
|
||||
|
||||
def parse_response(self, response: JsonDict):
|
||||
def parse_response(self, response: JsonDict) -> OEmbedResult:
|
||||
return self.oembed.parse_oembed_response(
|
||||
"https://test", json.dumps(response).encode("utf-8")
|
||||
)
|
||||
|
||||
def test_version(self):
|
||||
def test_version(self) -> None:
|
||||
"""Accept versions that are similar to 1.0 as a string or int (or missing)."""
|
||||
for version in ("1.0", 1.0, 1):
|
||||
result = self.parse_response({"version": version, "type": "link"})
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from http import HTTPStatus
|
||||
|
||||
from synapse.rest.health import HealthResource
|
||||
|
||||
@@ -19,12 +19,12 @@ from tests import unittest
|
||||
|
||||
|
||||
class HealthCheckTests(unittest.HomeserverTestCase):
|
||||
def create_test_resource(self):
|
||||
def create_test_resource(self) -> HealthResource:
|
||||
# replace the JsonResource with a HealthResource.
|
||||
return HealthResource()
|
||||
|
||||
def test_health(self):
|
||||
def test_health(self) -> None:
|
||||
channel = self.make_request("GET", "/health", shorthand=False)
|
||||
|
||||
self.assertEqual(channel.code, 200)
|
||||
self.assertEqual(channel.code, HTTPStatus.OK)
|
||||
self.assertEqual(channel.result["body"], b"OK")
|
||||
|
||||
@@ -11,6 +11,8 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from http import HTTPStatus
|
||||
|
||||
from twisted.web.resource import Resource
|
||||
|
||||
from synapse.rest.well_known import well_known_resource
|
||||
@@ -19,7 +21,7 @@ from tests import unittest
|
||||
|
||||
|
||||
class WellKnownTests(unittest.HomeserverTestCase):
|
||||
def create_test_resource(self):
|
||||
def create_test_resource(self) -> Resource:
|
||||
# replace the JsonResource with a Resource wrapping the WellKnownResource
|
||||
res = Resource()
|
||||
res.putChild(b".well-known", well_known_resource(self.hs))
|
||||
@@ -31,12 +33,12 @@ class WellKnownTests(unittest.HomeserverTestCase):
|
||||
"default_identity_server": "https://testis",
|
||||
}
|
||||
)
|
||||
def test_client_well_known(self):
|
||||
def test_client_well_known(self) -> None:
|
||||
channel = self.make_request(
|
||||
"GET", "/.well-known/matrix/client", shorthand=False
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 200)
|
||||
self.assertEqual(channel.code, HTTPStatus.OK)
|
||||
self.assertEqual(
|
||||
channel.json_body,
|
||||
{
|
||||
@@ -50,27 +52,27 @@ class WellKnownTests(unittest.HomeserverTestCase):
|
||||
"public_baseurl": None,
|
||||
}
|
||||
)
|
||||
def test_client_well_known_no_public_baseurl(self):
|
||||
def test_client_well_known_no_public_baseurl(self) -> None:
|
||||
channel = self.make_request(
|
||||
"GET", "/.well-known/matrix/client", shorthand=False
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 404)
|
||||
self.assertEqual(channel.code, HTTPStatus.NOT_FOUND)
|
||||
|
||||
@unittest.override_config({"serve_server_wellknown": True})
|
||||
def test_server_well_known(self):
|
||||
def test_server_well_known(self) -> None:
|
||||
channel = self.make_request(
|
||||
"GET", "/.well-known/matrix/server", shorthand=False
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 200)
|
||||
self.assertEqual(channel.code, HTTPStatus.OK)
|
||||
self.assertEqual(
|
||||
channel.json_body,
|
||||
{"m.server": "test:443"},
|
||||
)
|
||||
|
||||
def test_server_well_known_disabled(self):
|
||||
def test_server_well_known_disabled(self) -> None:
|
||||
channel = self.make_request(
|
||||
"GET", "/.well-known/matrix/server", shorthand=False
|
||||
)
|
||||
self.assertEqual(channel.code, 404)
|
||||
self.assertEqual(channel.code, HTTPStatus.NOT_FOUND)
|
||||
|
||||
@@ -31,6 +31,7 @@ from synapse.storage.databases.main.appservice import (
|
||||
ApplicationServiceStore,
|
||||
ApplicationServiceTransactionStore,
|
||||
)
|
||||
from synapse.types import DeviceLists
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
@@ -267,7 +268,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
|
||||
events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")])
|
||||
txn = self.get_success(
|
||||
defer.ensureDeferred(
|
||||
self.store.create_appservice_txn(service, events, [], [], {}, {})
|
||||
self.store.create_appservice_txn(
|
||||
service, events, [], [], {}, {}, DeviceLists()
|
||||
)
|
||||
)
|
||||
)
|
||||
self.assertEqual(txn.id, 1)
|
||||
@@ -283,7 +286,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
|
||||
self.get_success(self._insert_txn(service.id, 9644, events))
|
||||
self.get_success(self._insert_txn(service.id, 9645, events))
|
||||
txn = self.get_success(
|
||||
self.store.create_appservice_txn(service, events, [], [], {}, {})
|
||||
self.store.create_appservice_txn(
|
||||
service, events, [], [], {}, {}, DeviceLists()
|
||||
)
|
||||
)
|
||||
self.assertEqual(txn.id, 9646)
|
||||
self.assertEqual(txn.events, events)
|
||||
@@ -296,7 +301,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
|
||||
events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")])
|
||||
self.get_success(self._set_last_txn(service.id, 9643))
|
||||
txn = self.get_success(
|
||||
self.store.create_appservice_txn(service, events, [], [], {}, {})
|
||||
self.store.create_appservice_txn(
|
||||
service, events, [], [], {}, {}, DeviceLists()
|
||||
)
|
||||
)
|
||||
self.assertEqual(txn.id, 9644)
|
||||
self.assertEqual(txn.events, events)
|
||||
@@ -320,7 +327,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
|
||||
self.get_success(self._insert_txn(self.as_list[3]["id"], 9643, events))
|
||||
|
||||
txn = self.get_success(
|
||||
self.store.create_appservice_txn(service, events, [], [], {}, {})
|
||||
self.store.create_appservice_txn(
|
||||
service, events, [], [], {}, {}, DeviceLists()
|
||||
)
|
||||
)
|
||||
self.assertEqual(txn.id, 9644)
|
||||
self.assertEqual(txn.events, events)
|
||||
@@ -476,12 +485,12 @@ class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase):
|
||||
value = self.get_success(
|
||||
self.store.get_type_stream_id_for_appservice(self.service, "read_receipt")
|
||||
)
|
||||
self.assertEqual(value, 0)
|
||||
self.assertEqual(value, 1)
|
||||
|
||||
value = self.get_success(
|
||||
self.store.get_type_stream_id_for_appservice(self.service, "presence")
|
||||
)
|
||||
self.assertEqual(value, 0)
|
||||
self.assertEqual(value, 1)
|
||||
|
||||
def test_get_type_stream_id_for_appservice_invalid_type(self) -> None:
|
||||
self.get_failure(
|
||||
|
||||
@@ -13,11 +13,12 @@
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from typing import Optional
|
||||
from unittest.mock import patch
|
||||
|
||||
from synapse.api.room_versions import RoomVersions
|
||||
from synapse.events import EventBase
|
||||
from synapse.types import JsonDict
|
||||
from synapse.visibility import filter_events_for_server
|
||||
from synapse.events import EventBase, make_event_from_dict
|
||||
from synapse.types import JsonDict, create_requester
|
||||
from synapse.visibility import filter_events_for_client, filter_events_for_server
|
||||
|
||||
from tests import unittest
|
||||
from tests.utils import create_room
|
||||
@@ -185,3 +186,72 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
self.get_success(self.storage.persistence.persist_event(event, context))
|
||||
return event
|
||||
|
||||
|
||||
class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase):
|
||||
def test_out_of_band_invite_rejection(self):
|
||||
# this is where we have received an invite event over federation, and then
|
||||
# rejected it.
|
||||
invite_pdu = {
|
||||
"room_id": "!room:id",
|
||||
"depth": 1,
|
||||
"auth_events": [],
|
||||
"prev_events": [],
|
||||
"origin_server_ts": 1,
|
||||
"sender": "@someone:" + self.OTHER_SERVER_NAME,
|
||||
"type": "m.room.member",
|
||||
"state_key": "@user:test",
|
||||
"content": {"membership": "invite"},
|
||||
}
|
||||
self.add_hashes_and_signatures(invite_pdu)
|
||||
invite_event_id = make_event_from_dict(invite_pdu, RoomVersions.V9).event_id
|
||||
|
||||
self.get_success(
|
||||
self.hs.get_federation_server().on_invite_request(
|
||||
self.OTHER_SERVER_NAME,
|
||||
invite_pdu,
|
||||
"9",
|
||||
)
|
||||
)
|
||||
|
||||
# stub out do_remotely_reject_invite so that we fall back to a locally-
|
||||
# generated rejection
|
||||
with patch.object(
|
||||
self.hs.get_federation_handler(),
|
||||
"do_remotely_reject_invite",
|
||||
side_effect=Exception(),
|
||||
):
|
||||
reject_event_id, _ = self.get_success(
|
||||
self.hs.get_room_member_handler().remote_reject_invite(
|
||||
invite_event_id,
|
||||
txn_id=None,
|
||||
requester=create_requester("@user:test"),
|
||||
content={},
|
||||
)
|
||||
)
|
||||
|
||||
invite_event, reject_event = self.get_success(
|
||||
self.hs.get_datastores().main.get_events_as_list(
|
||||
[invite_event_id, reject_event_id]
|
||||
)
|
||||
)
|
||||
|
||||
# the invited user should be able to see both the invite and the rejection
|
||||
self.assertEqual(
|
||||
self.get_success(
|
||||
filter_events_for_client(
|
||||
self.hs.get_storage(), "@user:test", [invite_event, reject_event]
|
||||
)
|
||||
),
|
||||
[invite_event, reject_event],
|
||||
)
|
||||
|
||||
# other users should see neither
|
||||
self.assertEqual(
|
||||
self.get_success(
|
||||
filter_events_for_client(
|
||||
self.hs.get_storage(), "@other:test", [invite_event, reject_event]
|
||||
)
|
||||
),
|
||||
[],
|
||||
)
|
||||
|
||||
@@ -38,18 +38,8 @@ lint_targets =
|
||||
setup.py
|
||||
synapse
|
||||
tests
|
||||
scripts
|
||||
# annoyingly, black doesn't find these so we have to list them
|
||||
scripts/export_signing_key
|
||||
scripts/generate_config
|
||||
scripts/generate_log_config
|
||||
scripts/hash_password
|
||||
scripts/register_new_matrix_user
|
||||
scripts/synapse_port_db
|
||||
scripts/update_synapse_database
|
||||
scripts-dev
|
||||
scripts-dev/build_debian_packages
|
||||
scripts-dev/sign_json
|
||||
stubs
|
||||
contrib
|
||||
synctl
|
||||
@@ -168,32 +158,6 @@ commands =
|
||||
extras = lint
|
||||
commands = isort -c --df {[base]lint_targets}
|
||||
|
||||
[testenv:combine]
|
||||
skip_install = true
|
||||
usedevelop = false
|
||||
deps =
|
||||
coverage
|
||||
pip>=10
|
||||
commands=
|
||||
coverage combine
|
||||
coverage report
|
||||
|
||||
[testenv:cov-erase]
|
||||
skip_install = true
|
||||
usedevelop = false
|
||||
deps =
|
||||
coverage
|
||||
commands=
|
||||
coverage erase
|
||||
|
||||
[testenv:cov-html]
|
||||
skip_install = true
|
||||
usedevelop = false
|
||||
deps =
|
||||
coverage
|
||||
commands=
|
||||
coverage html
|
||||
|
||||
[testenv:mypy]
|
||||
deps =
|
||||
{[base]deps}
|
||||
|
||||
Reference in New Issue
Block a user