1
0

Compare commits

...

29 Commits

Author SHA1 Message Date
Andrew Morgan 3bd26733d7 Add tests
We add a series of tests that check whether device list sending works
across a variety of possible configurations.
2022-03-10 15:50:58 +00:00
Andrew Morgan 90fa2026ba fix tests for device lists 2022-03-10 15:50:58 +00:00
Andrew Morgan 55ac419b63 Add device lists to AS txns, thread thru the AS scheduler methods
Here we implement code that adds support for device list changes all
the way from our enqueue_for_appservice method down to where AS
transactions are actually built and sent out.
2022-03-10 15:50:58 +00:00
Andrew Morgan 047db4da1c Use get_users_whose_devices_changed to pull device list changes for given AS
When a new device list change occurs, we're now:

1. For each appservice, checking the last device list stream key that was
   processed up until.
2. Getting any users with changed device list between the last device list
   stream key and the stream key of the triggering update.
3. Filtering out those users based on those that are actually relevant
   to this application service.
4. Passing those changes to enqueue_for_appservice and saving the device list
   stream key that we've just processed up to for later reference.
2022-03-10 15:50:58 +00:00
Andrew Morgan 88c4e7369d Switch DeviceLists to containing Sets, which allows item deletes
In the next commit, we'll be merging one DeviceList into another. This
will require the ability to remove items by value, which Collection does
not provide, while a mutable structure such as Set does. Set was chosen to
to remove duplicate user IDs.
2022-03-10 15:50:58 +00:00
Andrew Morgan a77f35144f Move DeviceLists type to synapse.types
So that we can use it elsewhere.
2022-03-10 15:50:58 +00:00
Andrew Morgan 1671f8772d Add migration delta to track device_list stream id per appservice 2022-03-10 15:50:58 +00:00
Andrew Morgan b4aad3604a Add to_key arg, user_ids optional for get_users_whose_devices_changed
to_key prevents overlapping bounds when pulling out device list updates.

user_ids needs to be optional as we won't have a list of user_ids to
filter with when calling this function from a triggered device_list
change.
2022-03-10 15:50:58 +00:00
Andrew Morgan 51be04b918 Guard processing device list updates with experimental option 2022-03-10 15:50:58 +00:00
Andrew Morgan 4b6711803d Set min application service stream_id to 1
Factored out into #12193.
2022-03-09 17:27:52 +00:00
Richard van der Hoff 87c230c27c Update client-visibility filtering for outlier events (#12155)
Avoid trying to get the state for outliers, which isn't a sensible thing to do.
2022-03-04 10:31:19 +00:00
Richard van der Hoff d56202b038 Fix type of events in StateGroupStorage and StateHandler (#12156)
We make multiple passes over this, so a regular iterable won't do.
2022-03-04 10:25:18 +00:00
Richard van der Hoff 8533c8b03d Avoid generating state groups for local out-of-band leaves (#12154)
If we locally generate a rejection for an invite received over federation, it
is stored as an outlier (because we probably don't have the state for the
room). However, currently we still generate a state group for it (even though
the state in that state group will be nonsense).

By setting the `outlier` param on `create_event`, we avoid the nonsensical
state.
2022-03-03 19:58:08 +00:00
Andrew Morgan fb0ffa9676 Rename various ApplicationServices interested methods (#11915) 2022-03-03 18:14:09 +00:00
David Robertson 9297d040a7 Detox, part 2 of N (#12152)
I've argued in #11537 that poetry and tox don't cooperate well at the
moment. (See also #12119.) Therefore I'm pruning away bits of tox to make the transition to poetry easier. This change removes the commands for coverage.

We don't use coverage in anger at the moment. It shouldn't be too hard to add coverage as a dev-dependency and reintroduce this if we really want it.
2022-03-03 17:14:09 +00:00
Dirk Klimpel 7e91107be1 Add type hints to tests/rest (#12146)
* Add type hints to `tests/rest`

* newsfile

* change import from `SigningKey`
2022-03-03 16:05:44 +00:00
Patrick Cloke 1d11b452b7 Use the proper serialization format when bundling aggregations. (#12090)
This ensures that the `latest_event` field of the bundled aggregation
for threads uses the same format as the other events in the response.
2022-03-03 10:43:06 -05:00
Eric Eastwood a511a890d7 Enable MSC2716 Complement tests in Synapse (#12145)
Co-authored-by: Brendan Abolivier <babolivier@matrix.org>
2022-03-03 11:19:20 +00:00
Erik Johnston 61fd2a8f59 Limit the size of the aggregation_key (#12101)
There's no reason to let people use long keys.
2022-03-03 10:52:35 +00:00
Eric Eastwood 31b125ccec Enable MSC3030 Complement tests in Synapse (#12144)
The Complement tests for MSC3030 are now merged, https://github.com/matrix-org/complement/pull/178

Synapse implmentation: https://github.com/matrix-org/synapse/pull/9445
2022-03-03 11:45:23 +01:00
David Robertson 11282ade1d Move the snapcraft configuration to contrib. (#12142)
* Move the `snapcraft` configuration to `contrib`.

We're happy for people to package this as a snap image if it's useful,
but we don't support or maintain it. I'd like to move the config to
`contrib` to reflect this state of affairs.

* Changelog
2022-03-02 19:22:44 +00:00
David Robertson 1fbe0316a9 Add suffices to scripts in scripts-dev (#12137)
* Rename scripts-dev to have suffices

* Update references to `scripts-dev`

* Changelog

* These scripts don't pass mypy
2022-03-02 18:00:26 +00:00
David Robertson 106959b3cf Remove unused mocks from test_typing (#12136)
* Remove unused mocks from `test_typing`

It's not clear what these do. `get_user_by_access_token` has the wrong
signature, including the return type. Tests all pass without these. I
think we should nuke them.

* Changelog

* Fixup imports
2022-03-02 17:24:52 +00:00
Dirk Klimpel 2ffaf30803 Add type hints to tests/rest/client (#12108)
* Add type hints to `tests/rest/client`

* newsfile

* fix imports

* add `test_account.py`

* Remove one type hint in `test_report_event.py`

* change `on_create_room` to `async`

* update new functions in `test_third_party_rules.py`

* Add `test_filter.py`

* add `test_rooms.py`

* change to `assertEquals` to `assertEqual`

* lint
2022-03-02 16:34:14 +00:00
Andrew Morgan b4461e7d8a Enable complexity checking in complexity checking docs example (#11998) 2022-03-02 16:11:16 +00:00
Olivier Wilkinson (reivilibre) 594a07ede4 Merge tag 'v1.54.0rc1' into develop
Synapse 1.54.0rc1 (2022-03-02)
==============================

Please note that this will be the last release of Synapse that is compatible with Mjolnir 1.3.1 and earlier.
Administrators of servers which have the Mjolnir module installed are advised to upgrade Mjolnir to version 1.3.2 or later.

Features
--------

- Add support for [MSC3202](https://github.com/matrix-org/matrix-doc/pull/3202): sending one-time key counts and fallback key usage states to Application Services. ([\#11617](https://github.com/matrix-org/synapse/issues/11617))
- Improve the generated URL previews for some web pages. Contributed by @AndrewRyanChama. ([\#11985](https://github.com/matrix-org/synapse/issues/11985))
- Track cache invalidations in Prometheus metrics, as already happens for cache eviction based on size or time. ([\#12000](https://github.com/matrix-org/synapse/issues/12000))
- Implement experimental support for [MSC3720](https://github.com/matrix-org/matrix-doc/pull/3720) (account status endpoints). ([\#12001](https://github.com/matrix-org/synapse/issues/12001), [\#12067](https://github.com/matrix-org/synapse/issues/12067))
- Enable modules to set a custom display name when registering a user. ([\#12009](https://github.com/matrix-org/synapse/issues/12009))
- Advertise Matrix 1.1 and 1.2 support on `/_matrix/client/versions`. ([\#12020](https://github.com/matrix-org/synapse/issues/12020), ([\#12022](https://github.com/matrix-org/synapse/issues/12022))
- Support only the stable identifier for [MSC3069](https://github.com/matrix-org/matrix-doc/pull/3069)'s `is_guest` on `/_matrix/client/v3/account/whoami`. ([\#12021](https://github.com/matrix-org/synapse/issues/12021))
- Use room version 9 as the default room version (per [MSC3589](https://github.com/matrix-org/matrix-doc/pull/3589)). ([\#12058](https://github.com/matrix-org/synapse/issues/12058))
- Add module callbacks to react to user deactivation status changes (i.e. deactivations and reactivations) and profile updates. ([\#12062](https://github.com/matrix-org/synapse/issues/12062))

Bugfixes
--------

- Fix a bug introduced in Synapse 1.48.0 where an edit of the latest event in a thread would not be properly applied to the thread summary. ([\#11992](https://github.com/matrix-org/synapse/issues/11992))
- Fix long-standing bug where the `get_rooms_for_user` cache was not correctly invalidated for remote users when the server left a room. ([\#11999](https://github.com/matrix-org/synapse/issues/11999))
- Fix a 500 error with Postgres when looking backwards with the [MSC3030](https://github.com/matrix-org/matrix-doc/pull/3030) `/timestamp_to_event?dir=b` endpoint. ([\#12024](https://github.com/matrix-org/synapse/issues/12024))
- Properly fix a long-standing bug where wrong data could be inserted into the `event_search` table when using SQLite. This could block running `synapse_port_db` with an `argument of type 'int' is not iterable` error. This bug was partially fixed by a change in Synapse 1.44.0. ([\#12037](https://github.com/matrix-org/synapse/issues/12037))
- Fix slow performance of `/logout` in some cases where refresh tokens are in use. The slowness existed since the initial implementation of refresh tokens in version 1.38.0. ([\#12056](https://github.com/matrix-org/synapse/issues/12056))
- Fix a long-standing bug where Synapse would make additional failing requests over federation for missing data. ([\#12077](https://github.com/matrix-org/synapse/issues/12077))
- Fix occasional `Unhandled error in Deferred` error message. ([\#12089](https://github.com/matrix-org/synapse/issues/12089))
- Fix a bug introduced in Synapse 1.51.0 where incoming federation transactions containing at least one EDU would be dropped if debug logging was enabled for `synapse.8631_debug`. ([\#12098](https://github.com/matrix-org/synapse/issues/12098))
- Fix a long-standing bug which could cause push notifications to malfunction if `use_frozen_dicts` was set in the configuration. ([\#12100](https://github.com/matrix-org/synapse/issues/12100))
- Fix an extremely rare, long-standing bug in `ReadWriteLock` that would cause an error when a newly unblocked writer completes instantly. ([\#12105](https://github.com/matrix-org/synapse/issues/12105))
- Make a `POST` to `/rooms/<room_id>/receipt/m.read/<event_id>` only trigger a push notification if the count of unread messages is different to the one in the last successfully sent push. This reduces server load and load on the receiving device. ([\#11835](https://github.com/matrix-org/synapse/issues/11835))

Updates to the Docker image
---------------------------

- The Docker image no longer automatically creates a temporary volume at `/data`. This is not expected to affect normal usage. ([\#11997](https://github.com/matrix-org/synapse/issues/11997))
- Use Python 3.9 in Docker images by default. ([\#12112](https://github.com/matrix-org/synapse/issues/12112))

Improved Documentation
----------------------

- Document support for the `to_device`, `account_data`, `receipts`, and `presence` stream writers for workers. ([\#11599](https://github.com/matrix-org/synapse/issues/11599))
- Explain the meaning of spam checker callbacks' return values. ([\#12003](https://github.com/matrix-org/synapse/issues/12003))
- Clarify information about external Identity Provider IDs. ([\#12004](https://github.com/matrix-org/synapse/issues/12004))

Deprecations and Removals
-------------------------

- Deprecate using `synctl` with the config option `synctl_cache_factor` and print a warning if a user still uses this option. ([\#11865](https://github.com/matrix-org/synapse/issues/11865))
- Remove support for the legacy structured logging configuration (please see the the [upgrade notes](https://matrix-org.github.io/synapse/develop/upgrade#legacy-structured-logging-configuration-removal) if you are using `structured: true` in the Synapse configuration). ([\#12008](https://github.com/matrix-org/synapse/issues/12008))
- Drop support for [MSC3283](https://github.com/matrix-org/matrix-doc/pull/3283) unstable flags now that the stable flags are supported. ([\#12018](https://github.com/matrix-org/synapse/issues/12018))
- Remove the unstable `/spaces` endpoint from [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946). ([\#12073](https://github.com/matrix-org/synapse/issues/12073))

Internal Changes
----------------

- Make the `get_room_version` method use `get_room_version_id` to benefit from caching. ([\#11808](https://github.com/matrix-org/synapse/issues/11808))
- Remove unnecessary condition on knock -> leave auth rule check. ([\#11900](https://github.com/matrix-org/synapse/issues/11900))
- Add tests for device list changes between local users. ([\#11972](https://github.com/matrix-org/synapse/issues/11972))
- Optimise calculating `device_list` changes in `/sync`. ([\#11974](https://github.com/matrix-org/synapse/issues/11974))
- Add missing type hints to storage classes. ([\#11984](https://github.com/matrix-org/synapse/issues/11984))
- Refactor the search code for improved readability. ([\#11991](https://github.com/matrix-org/synapse/issues/11991))
- Move common deduplication code down into `_auth_and_persist_outliers`. ([\#11994](https://github.com/matrix-org/synapse/issues/11994))
- Limit concurrent joins from applications services. ([\#11996](https://github.com/matrix-org/synapse/issues/11996))
- Preparation for faster-room-join work: when parsing the `send_join` response, get the `m.room.create` event from `state`, not `auth_chain`. ([\#12005](https://github.com/matrix-org/synapse/issues/12005), [\#12039](https://github.com/matrix-org/synapse/issues/12039))
- Preparation for faster-room-join work: parse MSC3706 fields in send_join response. ([\#12011](https://github.com/matrix-org/synapse/issues/12011))
- Preparation for faster-room-join work: persist information on which events and rooms have partial state to the database. ([\#12012](https://github.com/matrix-org/synapse/issues/12012))
- Preparation for faster-room-join work: Support for calling `/federation/v1/state` on a remote server. ([\#12013](https://github.com/matrix-org/synapse/issues/12013))
- Configure `tox` to use `venv` rather than `virtualenv`. ([\#12015](https://github.com/matrix-org/synapse/issues/12015))
- Fix bug in `StateFilter.return_expanded()` and add some tests. ([\#12016](https://github.com/matrix-org/synapse/issues/12016))
- Use Matrix v1.1 endpoints (`/_matrix/client/v3/auth/...`) in fallback auth HTML forms. ([\#12019](https://github.com/matrix-org/synapse/issues/12019))
- Update the `olddeps` CI job to use an old version of `markupsafe`. ([\#12025](https://github.com/matrix-org/synapse/issues/12025))
- Upgrade Mypy to version 0.931. ([\#12030](https://github.com/matrix-org/synapse/issues/12030))
- Remove legacy `HomeServer.get_datastore()`. ([\#12031](https://github.com/matrix-org/synapse/issues/12031), [\#12070](https://github.com/matrix-org/synapse/issues/12070))
- Minor typing fixes. ([\#12034](https://github.com/matrix-org/synapse/issues/12034), [\#12069](https://github.com/matrix-org/synapse/issues/12069))
- After joining a room, create a dedicated logcontext to process the queued events. ([\#12041](https://github.com/matrix-org/synapse/issues/12041))
- Tidy up GitHub Actions config which builds distributions for PyPI. ([\#12051](https://github.com/matrix-org/synapse/issues/12051))
- Move configuration out of `setup.cfg`. ([\#12052](https://github.com/matrix-org/synapse/issues/12052), [\#12059](https://github.com/matrix-org/synapse/issues/12059))
- Fix error message when a worker process fails to talk to another worker process. ([\#12060](https://github.com/matrix-org/synapse/issues/12060))
- Fix using the `complement.sh` script without specifying a directory or a branch. Contributed by Nico on behalf of Famedly. ([\#12063](https://github.com/matrix-org/synapse/issues/12063))
- Add type hints to `tests/rest/client`. ([\#12066](https://github.com/matrix-org/synapse/issues/12066), [\#12072](https://github.com/matrix-org/synapse/issues/12072), [\#12084](https://github.com/matrix-org/synapse/issues/12084), [\#12094](https://github.com/matrix-org/synapse/issues/12094))
- Add some logging to `/sync` to try and track down #11916. ([\#12068](https://github.com/matrix-org/synapse/issues/12068))
- Inspect application dependencies using `importlib.metadata` or its backport. ([\#12088](https://github.com/matrix-org/synapse/issues/12088))
- Use `assertEqual` instead of the deprecated `assertEquals` in test code. ([\#12092](https://github.com/matrix-org/synapse/issues/12092))
- Move experimental support for [MSC3440](https://github.com/matrix-org/matrix-doc/pull/3440) to `/versions`. ([\#12099](https://github.com/matrix-org/synapse/issues/12099))
- Add `stop_cancellation` utility function to stop `Deferred`s from being cancelled. ([\#12106](https://github.com/matrix-org/synapse/issues/12106))
- Improve exception handling for concurrent execution. ([\#12109](https://github.com/matrix-org/synapse/issues/12109))
- Advertise support for Python 3.10 in packaging files. ([\#12111](https://github.com/matrix-org/synapse/issues/12111))
- Move CI checks out of tox, to facilitate a move to using poetry. ([\#12119](https://github.com/matrix-org/synapse/issues/12119))
2022-03-02 15:26:43 +00:00
Patrick Cloke 1103c5fe8a Check if instances are lists, not sequences. (#12128)
As a str is a sequence, the checks were not granular
enough and would allow lists or strings, when only
lists were valid.
2022-03-02 13:18:51 +00:00
David Robertson f3f0ab10fe Move scripts directory inside synapse, exposing as setuptools entry_points (#12118)
* Two scripts are basically entry_points already
* Move and rename scripts/* to synapse/_scripts/*.py
* Delete sync_room_to_group.pl
* Expose entry points in setup.py
* Update linter script and config
* Fixup scripts & docs mentioning scripts that moved

Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com>
2022-03-02 13:00:16 +00:00
Patrick Cloke 6adb89ff00 Improve and refactor the tests for relations. (#12113)
* Modernizes code (f-strings, etc.)
* Fixes incorrect comments.
* Splits the test case into two.
* Factors out some duplicated code.
2022-03-02 06:56:16 -05:00
100 changed files with 1707 additions and 1189 deletions
+2 -2
View File
@@ -21,7 +21,7 @@ python -m synapse.app.homeserver --generate-keys -c .ci/sqlite-config.yaml
echo "--- Prepare test database"
# Make sure the SQLite3 database is using the latest schema and has no pending background update.
scripts/update_synapse_database --database-config .ci/sqlite-config.yaml --run-background-updates
update_synapse_database --database-config .ci/sqlite-config.yaml --run-background-updates
# Run the export-data command on the sqlite test database
python -m synapse.app.admin_cmd -c .ci/sqlite-config.yaml export-data @anon-20191002_181700-832:localhost:8800 \
@@ -41,7 +41,7 @@ fi
# Port the SQLite databse to postgres so we can check command works against postgres
echo "+++ Port SQLite3 databse to postgres"
scripts/synapse_port_db --sqlite-database .ci/test_db.db --postgres-config .ci/postgres-config.yaml
synapse_port_db --sqlite-database .ci/test_db.db --postgres-config .ci/postgres-config.yaml
# Run the export-data command on postgres database
python -m synapse.app.admin_cmd -c .ci/postgres-config.yaml export-data @anon-20191002_181700-832:localhost:8800 \
+7 -5
View File
@@ -25,17 +25,19 @@ python -m synapse.app.homeserver --generate-keys -c .ci/sqlite-config.yaml
echo "--- Prepare test database"
# Make sure the SQLite3 database is using the latest schema and has no pending background update.
scripts/update_synapse_database --database-config .ci/sqlite-config.yaml --run-background-updates
update_synapse_database --database-config .ci/sqlite-config.yaml --run-background-updates
# Create the PostgreSQL database.
.ci/scripts/postgres_exec.py "CREATE DATABASE synapse"
echo "+++ Run synapse_port_db against test database"
coverage run scripts/synapse_port_db --sqlite-database .ci/test_db.db --postgres-config .ci/postgres-config.yaml
# TODO: this invocation of synapse_port_db (and others below) used to be prepended with `coverage run`,
# but coverage seems unable to find the entrypoints installed by `pip install -e .`.
synapse_port_db --sqlite-database .ci/test_db.db --postgres-config .ci/postgres-config.yaml
# We should be able to run twice against the same database.
echo "+++ Run synapse_port_db a second time"
coverage run scripts/synapse_port_db --sqlite-database .ci/test_db.db --postgres-config .ci/postgres-config.yaml
synapse_port_db --sqlite-database .ci/test_db.db --postgres-config .ci/postgres-config.yaml
#####
@@ -46,7 +48,7 @@ echo "--- Prepare empty SQLite database"
# we do this by deleting the sqlite db, and then doing the same again.
rm .ci/test_db.db
scripts/update_synapse_database --database-config .ci/sqlite-config.yaml --run-background-updates
update_synapse_database --database-config .ci/sqlite-config.yaml --run-background-updates
# re-create the PostgreSQL database.
.ci/scripts/postgres_exec.py \
@@ -54,4 +56,4 @@ scripts/update_synapse_database --database-config .ci/sqlite-config.yaml --run-b
"CREATE DATABASE synapse"
echo "+++ Run synapse_port_db against empty database"
coverage run scripts/synapse_port_db --sqlite-database .ci/test_db.db --postgres-config .ci/postgres-config.yaml
synapse_port_db --sqlite-database .ci/test_db.db --postgres-config .ci/postgres-config.yaml
-1
View File
@@ -3,7 +3,6 @@
# things to include
!docker
!scripts
!synapse
!MANIFEST.in
!README.rst
+2 -2
View File
@@ -31,7 +31,7 @@ jobs:
# if we're running from a tag, get the full list of distros; otherwise just use debian:sid
dists='["debian:sid"]'
if [[ $GITHUB_REF == refs/tags/* ]]; then
dists=$(scripts-dev/build_debian_packages --show-dists-json)
dists=$(scripts-dev/build_debian_packages.py --show-dists-json)
fi
echo "::set-output name=distros::$dists"
# map the step outputs to job outputs
@@ -74,7 +74,7 @@ jobs:
# see https://github.com/docker/build-push-action/issues/252
# for the cache magic here
run: |
./src/scripts-dev/build_debian_packages \
./src/scripts-dev/build_debian_packages.py \
--docker-build-arg=--cache-from=type=local,src=/tmp/.buildx-cache \
--docker-build-arg=--cache-to=type=local,mode=max,dest=/tmp/.buildx-cache-new \
--docker-build-arg=--progress=plain \
+3 -3
View File
@@ -16,7 +16,7 @@ jobs:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
- run: pip install -e .
- run: scripts-dev/generate_sample_config --check
- run: scripts-dev/generate_sample_config.sh --check
lint:
runs-on: ubuntu-latest
@@ -51,7 +51,7 @@ jobs:
fetch-depth: 0
- uses: actions/setup-python@v2
- run: "pip install 'towncrier>=18.6.0rc1'"
- run: scripts-dev/check-newsfragment
- run: scripts-dev/check-newsfragment.sh
env:
PULL_REQUEST_NUMBER: ${{ github.event.number }}
@@ -376,7 +376,7 @@ jobs:
# Run Complement
- run: |
set -o pipefail
go test -v -json -p 1 -tags synapse_blacklist,msc2403 ./tests/... 2>&1 | gotestfmt
go test -v -json -p 1 -tags synapse_blacklist,msc2403,msc2716,msc3030 ./tests/... 2>&1 | gotestfmt
shell: bash
name: Run Complement Tests
env:
-2
View File
@@ -17,7 +17,6 @@ recursive-include synapse/storage *.txt
recursive-include synapse/storage *.md
recursive-include docs *
recursive-include scripts *
recursive-include scripts-dev *
recursive-include synapse *.pyi
recursive-include tests *.py
@@ -53,5 +52,4 @@ prune contrib
prune debian
prune demo/etc
prune docker
prune snap
prune stubs
+1
View File
@@ -0,0 +1 @@
Simplify the `ApplicationService` class' set of public methods related to interest checking.
+1
View File
@@ -0,0 +1 @@
Fix complexity checking config example in [Resource Constrained Devices](https://matrix-org.github.io/synapse/v1.54/other/running_synapse_on_single_board_computers.html) docs page.
+1
View File
@@ -0,0 +1 @@
Use the proper serialization format for bundled thread aggregations. The bug has existed since Synapse v1.48.0.
+1
View File
@@ -0,0 +1 @@
Limit the size of `aggregation_key` on annotations.
+1
View File
@@ -0,0 +1 @@
Add type hints to `tests/rest/client`.
+1
View File
@@ -0,0 +1 @@
Refactor the tests for event relations.
+1
View File
@@ -0,0 +1 @@
Move scripts to Synapse package and expose as setuptools entry points.
+1
View File
@@ -0,0 +1 @@
Fix data validation to compare to lists, not sequences.
+1
View File
@@ -0,0 +1 @@
Remove unused mocks from `test_typing`.
+1
View File
@@ -0,0 +1 @@
Give `scripts-dev` scripts suffixes for neater CI config.
+1
View File
@@ -0,0 +1 @@
Move the snapcraft configuration file to `contrib`.
+1
View File
@@ -0,0 +1 @@
Enable [MSC3030](https://github.com/matrix-org/matrix-doc/pull/3030) Complement tests in CI.
+1
View File
@@ -0,0 +1 @@
Enable [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) Complement tests in CI.
+1
View File
@@ -0,0 +1 @@
Add type hints to `tests/rest`.
+1
View File
@@ -0,0 +1 @@
Prune unused jobs from `tox` config.
+1
View File
@@ -0,0 +1 @@
Avoid generating state groups for local out-of-band leaves.
+1
View File
@@ -0,0 +1 @@
Avoid trying to calculate the state at outlier events.
+1
View File
@@ -0,0 +1 @@
Fix some type annotations.
@@ -20,7 +20,7 @@ apps:
generate-config:
command: generate_config
generate-signing-key:
command: generate_signing_key.py
command: generate_signing_key
register-new-matrix-user:
command: register_new_matrix_user
plugs: [network]
-1
View File
@@ -46,7 +46,6 @@ RUN \
&& rm -rf /var/lib/apt/lists/*
# Copy just what we need to pip install
COPY scripts /synapse/scripts/
COPY MANIFEST.in README.rst setup.py synctl /synapse/
COPY synapse/__init__.py /synapse/synapse/__init__.py
COPY synapse/python_dependencies.py /synapse/synapse/python_dependencies.py
+1 -1
View File
@@ -172,6 +172,6 @@ frobber:
```
Note that the sample configuration is generated from the synapse code
and is maintained by a script, `scripts-dev/generate_sample_config`.
and is maintained by a script, `scripts-dev/generate_sample_config.sh`.
Making sure that the output from this script matches the desired format
is left as an exercise for the reader!
+3 -3
View File
@@ -158,9 +158,9 @@ same as integers.
There are three separate aspects to this:
* Any new boolean column must be added to the `BOOLEAN_COLUMNS` list in
`scripts/synapse_port_db`. This tells the port script to cast the integer
value from SQLite to a boolean before writing the value to the postgres
database.
`synapse/_scripts/synapse_port_db.py`. This tells the port script to cast
the integer value from SQLite to a boolean before writing the value to the
postgres database.
* Before SQLite 3.23, `TRUE` and `FALSE` were not recognised as constants by
SQLite, and the `IS [NOT] TRUE`/`IS [NOT] FALSE` operators were not
@@ -31,28 +31,29 @@ Anything that requires modifying the device list [#7721](https://github.com/matr
Put the below in a new file at /etc/matrix-synapse/conf.d/sbc.yaml to override the defaults in homeserver.yaml.
```
# Set to false to disable presence tracking on this homeserver.
# Disable presence tracking, which is currently fairly resource intensive
# More info: https://github.com/matrix-org/synapse/issues/9478
use_presence: false
# When this is enabled, the room "complexity" will be checked before a user
# joins a new remote room. If it is above the complexity limit, the server will
# disallow joining, or will instantly leave.
# Set a small complexity limit, preventing users from joining large rooms
# which may be resource-intensive to remain a part of.
#
# Note that this will not prevent users from joining smaller rooms that
# eventually become complex.
limit_remote_rooms:
# Uncomment to enable room complexity checking.
#enabled: true
enabled: true
complexity: 3.0
# Database configuration
database:
# Use postgres for the best performance
name: psycopg2
args:
user: matrix-synapse
# Generate a long, secure one with a password manager
# Generate a long, secure password using a password manager
password: hunter2
database: matrix-synapse
host: localhost
cp_min: 5
cp_max: 10
```
Currently the complexity is measured by [current_state_events / 500](https://github.com/matrix-org/synapse/blob/v1.20.1/synapse/storage/databases/main/events_worker.py#L986). You can find join times and your most complex rooms like this:
@@ -12,7 +12,7 @@ UPDATE users SET admin = 1 WHERE name = '@foo:bar.com';
```
A new server admin user can also be created using the `register_new_matrix_user`
command. This is a script that is located in the `scripts/` directory, or possibly
command. This is a script that is distributed as part of synapse. It is possibly
already on your `$PATH` depending on how Synapse was installed.
Finding your user's `access_token` is client-dependent, but will usually be shown in the client's settings.
+16 -13
View File
@@ -11,7 +11,7 @@ local_partial_types = True
no_implicit_optional = True
files =
scripts-dev/sign_json,
scripts-dev/,
setup.py,
synapse/,
tests/
@@ -23,6 +23,20 @@ files =
# https://docs.python.org/3/library/re.html#re.X
exclude = (?x)
^(
|scripts-dev/build_debian_packages.py
|scripts-dev/check_signature.py
|scripts-dev/definitions.py
|scripts-dev/federation_client.py
|scripts-dev/hash_history.py
|scripts-dev/list_url_patterns.py
|scripts-dev/release.py
|scripts-dev/tail-synapse.py
|synapse/_scripts/export_signing_key.py
|synapse/_scripts/move_remote_media_to_new_store.py
|synapse/_scripts/synapse_port_db.py
|synapse/_scripts/update_synapse_database.py
|synapse/storage/databases/__init__.py
|synapse/storage/databases/main/__init__.py
|synapse/storage/databases/main/cache.py
@@ -74,15 +88,7 @@ exclude = (?x)
|tests/push/test_http.py
|tests/push/test_presentable_names.py
|tests/push/test_push_rule_evaluator.py
|tests/rest/client/test_account.py
|tests/rest/client/test_filter.py
|tests/rest/client/test_report_event.py
|tests/rest/client/test_rooms.py
|tests/rest/client/test_third_party_rules.py
|tests/rest/client/test_transactions.py
|tests/rest/client/test_typing.py
|tests/rest/key/v2/test_remote_key_resource.py
|tests/rest/media/v1/test_base.py
|tests/rest/media/v1/test_media_storage.py
|tests/rest/media/v1/test_url_preview.py
|tests/scripts/test_new_matrix_user.py
@@ -246,10 +252,7 @@ disallow_untyped_defs = True
[mypy-tests.storage.test_user_directory]
disallow_untyped_defs = True
[mypy-tests.rest.admin.*]
disallow_untyped_defs = True
[mypy-tests.rest.client.*]
[mypy-tests.rest.*]
disallow_untyped_defs = True
[mypy-tests.federation.transport.test_client]
+1 -1
View File
@@ -71,4 +71,4 @@ fi
# Run the tests!
echo "Images built; running complement"
go test -v -tags synapse_blacklist,msc2403 -count=1 $EXTRA_COMPLEMENT_ARGS ./tests/...
go test -v -tags synapse_blacklist,msc2403,msc2716,msc3030 -count=1 $EXTRA_COMPLEMENT_ARGS ./tests/...
-28
View File
@@ -1,28 +0,0 @@
#!/usr/bin/env bash
#
# Update/check the docs/sample_config.yaml
set -e
cd "$(dirname "$0")/.."
SAMPLE_CONFIG="docs/sample_config.yaml"
SAMPLE_LOG_CONFIG="docs/sample_log_config.yaml"
check() {
diff -u "$SAMPLE_LOG_CONFIG" <(./scripts/generate_log_config) >/dev/null || return 1
}
if [ "$1" == "--check" ]; then
diff -u "$SAMPLE_CONFIG" <(./scripts/generate_config --header-file docs/.sample_config_header.yaml) >/dev/null || {
echo -e "\e[1m\e[31m$SAMPLE_CONFIG is not up-to-date. Regenerate it with \`scripts-dev/generate_sample_config\`.\e[0m" >&2
exit 1
}
diff -u "$SAMPLE_LOG_CONFIG" <(./scripts/generate_log_config) >/dev/null || {
echo -e "\e[1m\e[31m$SAMPLE_LOG_CONFIG is not up-to-date. Regenerate it with \`scripts-dev/generate_sample_config\`.\e[0m" >&2
exit 1
}
else
./scripts/generate_config --header-file docs/.sample_config_header.yaml -o "$SAMPLE_CONFIG"
./scripts/generate_log_config -o "$SAMPLE_LOG_CONFIG"
fi
+28
View File
@@ -0,0 +1,28 @@
#!/usr/bin/env bash
#
# Update/check the docs/sample_config.yaml
set -e
cd "$(dirname "$0")/.."
SAMPLE_CONFIG="docs/sample_config.yaml"
SAMPLE_LOG_CONFIG="docs/sample_log_config.yaml"
check() {
diff -u "$SAMPLE_LOG_CONFIG" <(synapse/_scripts/generate_log_config.py) >/dev/null || return 1
}
if [ "$1" == "--check" ]; then
diff -u "$SAMPLE_CONFIG" <(synapse/_scripts/generate_config.py --header-file docs/.sample_config_header.yaml) >/dev/null || {
echo -e "\e[1m\e[31m$SAMPLE_CONFIG is not up-to-date. Regenerate it with \`scripts-dev/generate_sample_config.sh\`.\e[0m" >&2
exit 1
}
diff -u "$SAMPLE_LOG_CONFIG" <(synapse/_scripts/generate_log_config.py) >/dev/null || {
echo -e "\e[1m\e[31m$SAMPLE_LOG_CONFIG is not up-to-date. Regenerate it with \`scripts-dev/generate_sample_config.sh\`.\e[0m" >&2
exit 1
}
else
synapse/_scripts/generate_config.py --header-file docs/.sample_config_header.yaml -o "$SAMPLE_CONFIG"
synapse/_scripts/generate_log_config.py -o "$SAMPLE_LOG_CONFIG"
fi
-9
View File
@@ -84,16 +84,7 @@ else
files=(
"synapse" "docker" "tests"
# annoyingly, black doesn't find these so we have to list them
"scripts/export_signing_key"
"scripts/generate_config"
"scripts/generate_log_config"
"scripts/hash_password"
"scripts/register_new_matrix_user"
"scripts/synapse_port_db"
"scripts/update_synapse_database"
"scripts-dev"
"scripts-dev/build_debian_packages"
"scripts-dev/sign_json"
"contrib" "synctl" "setup.py" "synmark" "stubs" ".ci"
)
fi
+3 -3
View File
@@ -147,7 +147,7 @@ python -m synapse.app.homeserver --generate-keys -c "$SQLITE_CONFIG"
# Make sure the SQLite3 database is using the latest schema and has no pending background update.
echo "Running db background jobs..."
scripts/update_synapse_database --database-config --run-background-updates "$SQLITE_CONFIG"
synapse/_scripts/update_synapse_database.py --database-config --run-background-updates "$SQLITE_CONFIG"
# Create the PostgreSQL database.
echo "Creating postgres database..."
@@ -156,10 +156,10 @@ createdb --lc-collate=C --lc-ctype=C --template=template0 "$POSTGRES_DB_NAME"
echo "Copying data from SQLite3 to Postgres with synapse_port_db..."
if [ -z "$COVERAGE" ]; then
# No coverage needed
scripts/synapse_port_db --sqlite-database "$SQLITE_DB" --postgres-config "$POSTGRES_CONFIG"
synapse/_scripts/synapse_port_db.py --sqlite-database "$SQLITE_DB" --postgres-config "$POSTGRES_CONFIG"
else
# Coverage desired
coverage run scripts/synapse_port_db --sqlite-database "$SQLITE_DB" --postgres-config "$POSTGRES_CONFIG"
coverage run synapse/_scripts/synapse_port_db.py --sqlite-database "$SQLITE_DB" --postgres-config "$POSTGRES_CONFIG"
fi
# Delete schema_version, applied_schema_deltas and applied_module_schemas tables
-19
View File
@@ -1,19 +0,0 @@
#!/usr/bin/env python
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse._scripts.register_new_matrix_user import main
if __name__ == "__main__":
main()
-19
View File
@@ -1,19 +0,0 @@
#!/usr/bin/env python
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse._scripts.review_recent_signups import main
if __name__ == "__main__":
main()
-45
View File
@@ -1,45 +0,0 @@
#!/usr/bin/env perl
use strict;
use warnings;
use JSON::XS;
use LWP::UserAgent;
use URI::Escape;
if (@ARGV < 4) {
die "usage: $0 <homeserver url> <access_token> <room_id|room_alias> <group_id>\n";
}
my ($hs, $access_token, $room_id, $group_id) = @ARGV;
my $ua = LWP::UserAgent->new();
$ua->timeout(10);
if ($room_id =~ /^#/) {
$room_id = uri_escape($room_id);
$room_id = decode_json($ua->get("${hs}/_matrix/client/r0/directory/room/${room_id}?access_token=${access_token}")->decoded_content)->{room_id};
}
my $room_users = [ keys %{decode_json($ua->get("${hs}/_matrix/client/r0/rooms/${room_id}/joined_members?access_token=${access_token}")->decoded_content)->{joined}} ];
my $group_users = [
(map { $_->{user_id} } @{decode_json($ua->get("${hs}/_matrix/client/unstable/groups/${group_id}/users?access_token=${access_token}" )->decoded_content)->{chunk}}),
(map { $_->{user_id} } @{decode_json($ua->get("${hs}/_matrix/client/unstable/groups/${group_id}/invited_users?access_token=${access_token}" )->decoded_content)->{chunk}}),
];
die "refusing to sync from empty room" unless (@$room_users);
die "refusing to sync to empty group" unless (@$group_users);
my $diff = {};
foreach my $user (@$room_users) { $diff->{$user}++ }
foreach my $user (@$group_users) { $diff->{$user}-- }
foreach my $user (keys %$diff) {
if ($diff->{$user} == 1) {
warn "inviting $user";
print STDERR $ua->put("${hs}/_matrix/client/unstable/groups/${group_id}/admin/users/invite/${user}?access_token=${access_token}", Content=>'{}')->status_line."\n";
}
elsif ($diff->{$user} == -1) {
warn "removing $user";
print STDERR $ua->put("${hs}/_matrix/client/unstable/groups/${group_id}/admin/users/remove/${user}?access_token=${access_token}", Content=>'{}')->status_line."\n";
}
}
+12 -2
View File
@@ -15,7 +15,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import os
from typing import Any, Dict
@@ -153,8 +152,19 @@ setup(
python_requires="~=3.7",
entry_points={
"console_scripts": [
# Application
"synapse_homeserver = synapse.app.homeserver:main",
"synapse_worker = synapse.app.generic_worker:main",
# Scripts
"export_signing_key = synapse._scripts.export_signing_key:main",
"generate_config = synapse._scripts.generate_config:main",
"generate_log_config = synapse._scripts.generate_log_config:main",
"generate_signing_key = synapse._scripts.generate_signing_key:main",
"hash_password = synapse._scripts.hash_password:main",
"register_new_matrix_user = synapse._scripts.register_new_matrix_user:main",
"synapse_port_db = synapse._scripts.synapse_port_db:main",
"synapse_review_recent_signups = synapse._scripts.review_recent_signups:main",
"update_synapse_database = synapse._scripts.update_synapse_database:main",
]
},
classifiers=[
@@ -167,6 +177,6 @@ setup(
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
],
scripts=["synctl"] + glob.glob("scripts/*"),
scripts=["synctl"],
cmdclass={"test": TestCommand},
)
@@ -50,7 +50,7 @@ def format_for_config(public_key: nacl.signing.VerifyKey, expiry_ts: int):
)
if __name__ == "__main__":
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
@@ -85,7 +85,6 @@ if __name__ == "__main__":
else format_plain
)
keys = []
for file in args.key_file:
try:
res = read_signing_keys(file)
@@ -98,3 +97,7 @@ if __name__ == "__main__":
res = []
for key in res:
formatter(get_verify_key(key))
if __name__ == "__main__":
main()
@@ -6,7 +6,8 @@ import sys
from synapse.config.homeserver import HomeServerConfig
if __name__ == "__main__":
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--config-dir",
@@ -76,3 +77,7 @@ if __name__ == "__main__":
shutil.copyfileobj(args.header_file, args.output_file)
args.output_file.write(conf)
if __name__ == "__main__":
main()
@@ -19,7 +19,8 @@ import sys
from synapse.config.logger import DEFAULT_LOG_CONFIG
if __name__ == "__main__":
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
@@ -42,3 +43,7 @@ if __name__ == "__main__":
out = args.output_file
out.write(DEFAULT_LOG_CONFIG.substitute(log_file=args.log_file))
out.flush()
if __name__ == "__main__":
main()
@@ -19,7 +19,8 @@ from signedjson.key import generate_signing_key, write_signing_keys
from synapse.util.stringutils import random_string
if __name__ == "__main__":
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
@@ -34,3 +35,7 @@ if __name__ == "__main__":
key_id = "a_" + random_string(4)
key = (generate_signing_key(key_id),)
write_signing_keys(args.output_file, key)
if __name__ == "__main__":
main()
@@ -8,9 +8,6 @@ import unicodedata
import bcrypt
import yaml
bcrypt_rounds = 12
password_pepper = ""
def prompt_for_pass():
password = getpass.getpass("Password: ")
@@ -26,7 +23,10 @@ def prompt_for_pass():
return password
if __name__ == "__main__":
def main():
bcrypt_rounds = 12
password_pepper = ""
parser = argparse.ArgumentParser(
description=(
"Calculate the hash of a new password, so that passwords can be reset"
@@ -77,3 +77,7 @@ if __name__ == "__main__":
).decode("ascii")
print(hashed)
if __name__ == "__main__":
main()
@@ -28,7 +28,7 @@ This can be extracted from postgres with::
To use, pipe the above into::
PYTHON_PATH=. ./scripts/move_remote_media_to_new_store.py <source repo> <dest repo>
PYTHON_PATH=. synapse/_scripts/move_remote_media_to_new_store.py <source repo> <dest repo>
"""
import argparse
@@ -1146,7 +1146,7 @@ class TerminalProgress(Progress):
##############################################
if __name__ == "__main__":
def main():
parser = argparse.ArgumentParser(
description="A script to port an existing synapse SQLite database to"
" a new PostgreSQL database."
@@ -1251,3 +1251,7 @@ if __name__ == "__main__":
sys.stderr.write(end_error)
sys.exit(5)
if __name__ == "__main__":
main()
+96 -43
View File
@@ -1,4 +1,5 @@
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -22,7 +23,7 @@ from netaddr import IPSet
from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.types import GroupID, JsonDict, UserID, get_domain_from_id
from synapse.types import DeviceLists, GroupID, JsonDict, UserID, get_domain_from_id
from synapse.util.caches.descriptors import _CacheContext, cached
if TYPE_CHECKING:
@@ -175,27 +176,14 @@ class ApplicationService:
return namespace.exclusive
return False
async def _matches_user(self, event: EventBase, store: "DataStore") -> bool:
if self.is_interested_in_user(event.sender):
return True
# also check m.room.member state key
if event.type == EventTypes.Member and self.is_interested_in_user(
event.state_key
):
return True
does_match = await self.matches_user_in_member_list(event.room_id, store)
return does_match
@cached(num_args=1, cache_context=True)
async def matches_user_in_member_list(
async def _matches_user_in_member_list(
self,
room_id: str,
store: "DataStore",
cache_context: _CacheContext,
) -> bool:
"""Check if this service is interested a room based upon it's membership
"""Check if this service is interested a room based upon its membership
Args:
room_id: The room to check.
@@ -214,47 +202,110 @@ class ApplicationService:
return True
return False
def _matches_room_id(self, event: EventBase) -> bool:
if hasattr(event, "room_id"):
return self.is_interested_in_room(event.room_id)
return False
def is_interested_in_user(
self,
user_id: str,
) -> bool:
"""
Returns whether the application is interested in a given user ID.
async def _matches_aliases(self, event: EventBase, store: "DataStore") -> bool:
alias_list = await store.get_aliases_for_room(event.room_id)
The appservice is considered to be interested in a user if either: the
user ID is in the appservice's user namespace, or if the user is the
appservice's configured sender_localpart.
Args:
user_id: The ID of the user to check.
Returns:
True if the application service is interested in the user, False if not.
"""
return (
# User is the appservice's sender_localpart user
user_id == self.sender
# User is in the appservice's user namespace
or self.is_user_in_namespace(user_id)
)
@cached(num_args=1, cache_context=True)
async def is_interested_in_room(
self,
room_id: str,
store: "DataStore",
cache_context: _CacheContext,
) -> bool:
"""
Returns whether the application service is interested in a given room ID.
The appservice is considered to be interested in the room if either: the ID or one
of the aliases of the room is in the appservice's room ID or alias namespace
respectively, or if one of the members of the room fall into the appservice's user
namespace.
Args:
room_id: The ID of the room to check.
store: The homeserver's datastore class.
Returns:
True if the application service is interested in the room, False if not.
"""
# Check if we have interest in this room ID
if self.is_room_id_in_namespace(room_id):
return True
# likewise with the room's aliases (if it has any)
alias_list = await store.get_aliases_for_room(room_id)
for alias in alias_list:
if self.is_interested_in_alias(alias):
if self.is_room_alias_in_namespace(alias):
return True
return False
# And finally, perform an expensive check on whether any of the
# users in the room match the appservice's user namespace
return await self._matches_user_in_member_list(
room_id, store, on_invalidate=cache_context.invalidate
)
async def is_interested(self, event: EventBase, store: "DataStore") -> bool:
@cached(num_args=1, cache_context=True)
async def is_interested_in_event(
self,
event_id: str,
event: EventBase,
store: "DataStore",
cache_context: _CacheContext,
) -> bool:
"""Check if this service is interested in this event.
Args:
event_id: The ID of the event to check. This is purely used for simplifying the
caching of calls to this method.
event: The event to check.
store: The datastore to query.
Returns:
True if this service would like to know about this event.
True if this service would like to know about this event, otherwise False.
"""
# Do cheap checks first
if self._matches_room_id(event):
# Check if we're interested in this event's sender by namespace (or if they're the
# sender_localpart user)
if self.is_interested_in_user(event.sender):
return True
# This will check the namespaces first before
# checking the store, so should be run before _matches_aliases
if await self._matches_user(event, store):
# additionally, if this is a membership event, perform the same checks on
# the user it references
if event.type == EventTypes.Member and self.is_interested_in_user(
event.state_key
):
return True
# This will check the store, so should be run last
if await self._matches_aliases(event, store):
# This will check the datastore, so should be run last
if await self.is_interested_in_room(
event.room_id, store, on_invalidate=cache_context.invalidate
):
return True
return False
@cached(num_args=1)
@cached(num_args=1, cache_context=True)
async def is_interested_in_presence(
self, user_id: UserID, store: "DataStore"
self, user_id: UserID, store: "DataStore", cache_context: _CacheContext
) -> bool:
"""Check if this service is interested a user's presence
@@ -272,20 +323,19 @@ class ApplicationService:
# Then find out if the appservice is interested in any of those rooms
for room_id in room_ids:
if await self.matches_user_in_member_list(room_id, store):
if await self.is_interested_in_room(
room_id, store, on_invalidate=cache_context.invalidate
):
return True
return False
def is_interested_in_user(self, user_id: str) -> bool:
return (
bool(self._matches_regex(ApplicationService.NS_USERS, user_id))
or user_id == self.sender
)
def is_user_in_namespace(self, user_id: str) -> bool:
return bool(self._matches_regex(ApplicationService.NS_USERS, user_id))
def is_interested_in_alias(self, alias: str) -> bool:
def is_room_alias_in_namespace(self, alias: str) -> bool:
return bool(self._matches_regex(ApplicationService.NS_ALIASES, alias))
def is_interested_in_room(self, room_id: str) -> bool:
def is_room_id_in_namespace(self, room_id: str) -> bool:
return bool(self._matches_regex(ApplicationService.NS_ROOMS, room_id))
def is_exclusive_user(self, user_id: str) -> bool:
@@ -351,6 +401,7 @@ class AppServiceTransaction:
to_device_messages: List[JsonDict],
one_time_key_counts: TransactionOneTimeKeyCounts,
unused_fallback_keys: TransactionUnusedFallbackKeys,
device_list_summary: DeviceLists,
):
self.service = service
self.id = id
@@ -359,6 +410,7 @@ class AppServiceTransaction:
self.to_device_messages = to_device_messages
self.one_time_key_counts = one_time_key_counts
self.unused_fallback_keys = unused_fallback_keys
self.device_list_summary = device_list_summary
async def send(self, as_api: "ApplicationServiceApi") -> bool:
"""Sends this transaction using the provided AS API interface.
@@ -375,6 +427,7 @@ class AppServiceTransaction:
to_device_messages=self.to_device_messages,
one_time_key_counts=self.one_time_key_counts,
unused_fallback_keys=self.unused_fallback_keys,
device_list_summary=self.device_list_summary,
txn_id=self.id,
)
+22 -12
View File
@@ -1,4 +1,5 @@
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -25,9 +26,9 @@ from synapse.appservice import (
TransactionUnusedFallbackKeys,
)
from synapse.events import EventBase
from synapse.events.utils import serialize_event
from synapse.events.utils import SerializeEventConfig, serialize_event
from synapse.http.client import SimpleHttpClient
from synapse.types import JsonDict, ThirdPartyInstanceID
from synapse.types import DeviceLists, JsonDict, ThirdPartyInstanceID
from synapse.util.caches.response_cache import ResponseCache
if TYPE_CHECKING:
@@ -225,6 +226,7 @@ class ApplicationServiceApi(SimpleHttpClient):
to_device_messages: List[JsonDict],
one_time_key_counts: TransactionOneTimeKeyCounts,
unused_fallback_keys: TransactionUnusedFallbackKeys,
device_list_summary: DeviceLists,
txn_id: Optional[int] = None,
) -> bool:
"""
@@ -268,6 +270,7 @@ class ApplicationServiceApi(SimpleHttpClient):
}
)
# TODO: Update to stable prefixes once MSC3202 completes FCP merge
if service.msc3202_transaction_extensions:
if one_time_key_counts:
body[
@@ -277,6 +280,11 @@ class ApplicationServiceApi(SimpleHttpClient):
body[
"org.matrix.msc3202.device_unused_fallback_keys"
] = unused_fallback_keys
if device_list_summary:
body["org.matrix.msc3202.device_lists"] = {
"changed": list(device_list_summary.changed),
"left": list(device_list_summary.left),
}
try:
await self.put_json(
@@ -321,16 +329,18 @@ class ApplicationServiceApi(SimpleHttpClient):
serialize_event(
e,
time_now,
as_client_event=True,
# If this is an invite or a knock membership event, and we're interested
# in this user, then include any stripped state alongside the event.
include_stripped_room_state=(
e.type == EventTypes.Member
and (
e.membership == Membership.INVITE
or e.membership == Membership.KNOCK
)
and service.is_interested_in_user(e.state_key)
config=SerializeEventConfig(
as_client_event=True,
# If this is an invite or a knock membership event, and we're interested
# in this user, then include any stripped state alongside the event.
include_stripped_room_state=(
e.type == EventTypes.Member
and (
e.membership == Membership.INVITE
or e.membership == Membership.KNOCK
)
and service.is_interested_in_user(e.state_key)
),
),
)
for e in events
+55 -3
View File
@@ -72,7 +72,7 @@ from synapse.events import EventBase
from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.databases.main import DataStore
from synapse.types import JsonDict
from synapse.types import DeviceLists, JsonDict
from synapse.util import Clock
if TYPE_CHECKING:
@@ -122,6 +122,7 @@ class ApplicationServiceScheduler:
events: Optional[Collection[EventBase]] = None,
ephemeral: Optional[Collection[JsonDict]] = None,
to_device_messages: Optional[Collection[JsonDict]] = None,
device_list_summary: Optional[DeviceLists] = None,
) -> None:
"""
Enqueue some data to be sent off to an application service.
@@ -133,10 +134,18 @@ class ApplicationServiceScheduler:
to_device_messages: The to-device messages to send. These differ from normal
to-device messages sent to clients, as they have 'to_device_id' and
'to_user_id' fields.
device_list_summary: A summary of users that the application service either needs
to refresh the device lists of, or those that the application service need no
longer track the device lists of.
"""
# We purposefully allow this method to run with empty events/ephemeral
# collections, so that callers do not need to check iterable size themselves.
if not events and not ephemeral and not to_device_messages:
if (
not events
and not ephemeral
and not to_device_messages
and not device_list_summary
):
return
if events:
@@ -147,6 +156,10 @@ class ApplicationServiceScheduler:
self.queuer.queued_to_device_messages.setdefault(appservice.id, []).extend(
to_device_messages
)
if device_list_summary:
self.queuer.queued_device_list_summaries.setdefault(
appservice.id, []
).append(device_list_summary)
# Kick off a new application service transaction
self.queuer.start_background_request(appservice)
@@ -169,6 +182,8 @@ class _ServiceQueuer:
self.queued_ephemeral: Dict[str, List[JsonDict]] = {}
# dict of {service_id: [to_device_message_json]}
self.queued_to_device_messages: Dict[str, List[JsonDict]] = {}
# dict of {service_id: [device_list_summary]}
self.queued_device_list_summaries: Dict[str, List[DeviceLists]] = {}
# the appservices which currently have a transaction in flight
self.requests_in_flight: Set[str] = set()
@@ -212,7 +227,40 @@ class _ServiceQueuer:
]
del all_to_device_messages[:MAX_TO_DEVICE_MESSAGES_PER_TRANSACTION]
if not events and not ephemeral and not to_device_messages_to_send:
# Consolidate any pending device list summaries into a single, up-to-date
# summary.
# Note: this code assumes that in a single DeviceLists, a user will
# never be in both "changed" and "left" sets.
device_list_summary = DeviceLists()
while self.queued_device_list_summaries.get(service.id, []):
# Pop a summary off the front of the queue
summary = self.queued_device_list_summaries[service.id].pop(0)
# For every user in the incoming "changed" set:
# * Remove them from the existing "left" set if necessary
# (as we need to start tracking them again)
# * Add them to the existing "changed" set if necessary.
for user_id in summary.changed:
if user_id in device_list_summary.left:
device_list_summary.left.remove(user_id)
device_list_summary.changed.add(user_id)
# For every user in the incoming "left" set:
# * Remove them from the existing "changed" set if necessary
# (we no longer need to track them)
# * Add them to the existing "left" set if necessary.
for user_id in summary.left:
if user_id in device_list_summary.changed:
device_list_summary.changed.remove(user_id)
device_list_summary.left.add(user_id)
if (
not events
and not ephemeral
and not to_device_messages_to_send
# Note that DeviceLists implements __bool__
and not device_list_summary
):
return
one_time_key_counts: Optional[TransactionOneTimeKeyCounts] = None
@@ -240,6 +288,7 @@ class _ServiceQueuer:
to_device_messages_to_send,
one_time_key_counts,
unused_fallback_keys,
device_list_summary,
)
except Exception:
logger.exception("AS request failed")
@@ -322,6 +371,7 @@ class _TransactionController:
to_device_messages: Optional[List[JsonDict]] = None,
one_time_key_counts: Optional[TransactionOneTimeKeyCounts] = None,
unused_fallback_keys: Optional[TransactionUnusedFallbackKeys] = None,
device_list_summary: Optional[DeviceLists] = None,
) -> None:
"""
Create a transaction with the given data and send to the provided
@@ -336,6 +386,7 @@ class _TransactionController:
appservice devices in the transaction.
unused_fallback_keys: Lists of unused fallback keys for relevant
appservice devices in the transaction.
device_list_summary: The device list summary to include in the transaction.
"""
try:
txn = await self.store.create_appservice_txn(
@@ -345,6 +396,7 @@ class _TransactionController:
to_device_messages=to_device_messages or [],
one_time_key_counts=one_time_key_counts or {},
unused_fallback_keys=unused_fallback_keys or {},
device_list_summary=device_list_summary or DeviceLists(),
)
service_is_up = await self._is_service_up(service)
if service_is_up:
+1 -1
View File
@@ -383,7 +383,7 @@ class RootConfig:
Build a default configuration file
This is used when the user explicitly asks us to generate a config file
(eg with --generate_config).
(eg with --generate-config).
Args:
config_dir_path: The path where the config files are kept. Used to
+1
View File
@@ -170,6 +170,7 @@ def _load_appservice(
# When enabled, appservice transactions contain the following information:
# - device One-Time Key counts
# - device unused fallback key usage states
# - device list changes
msc3202_transaction_extensions = as_info.get("org.matrix.msc3202", False)
if not isinstance(msc3202_transaction_extensions, bool):
raise ValueError(
+3 -2
View File
@@ -59,8 +59,9 @@ class ExperimentalConfig(Config):
"msc3202_device_masquerading", False
)
# Portion of MSC3202 related to transaction extensions:
# sending one-time key counts and fallback key usage to application services.
# The portion of MSC3202 related to transaction extensions:
# sending device list changes, one-time key counts and fallback key
# usage to application services.
self.msc3202_transaction_extensions: bool = experimental.get(
"msc3202_transaction_extensions", False
)
+54 -27
View File
@@ -26,6 +26,7 @@ from typing import (
Union,
)
import attr
from frozendict import frozendict
from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
@@ -303,29 +304,37 @@ def format_event_for_client_v2_without_room_id(d: JsonDict) -> JsonDict:
return d
@attr.s(slots=True, frozen=True, auto_attribs=True)
class SerializeEventConfig:
as_client_event: bool = True
# Function to convert from federation format to client format
event_format: Callable[[JsonDict], JsonDict] = format_event_for_client_v1
# ID of the user's auth token - used for namespacing of transaction IDs
token_id: Optional[int] = None
# List of event fields to include. If empty, all fields will be returned.
only_event_fields: Optional[List[str]] = None
# Some events can have stripped room state stored in the `unsigned` field.
# This is required for invite and knock functionality. If this option is
# False, that state will be removed from the event before it is returned.
# Otherwise, it will be kept.
include_stripped_room_state: bool = False
_DEFAULT_SERIALIZE_EVENT_CONFIG = SerializeEventConfig()
def serialize_event(
e: Union[JsonDict, EventBase],
time_now_ms: int,
*,
as_client_event: bool = True,
event_format: Callable[[JsonDict], JsonDict] = format_event_for_client_v1,
token_id: Optional[str] = None,
only_event_fields: Optional[List[str]] = None,
include_stripped_room_state: bool = False,
config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG,
) -> JsonDict:
"""Serialize event for clients
Args:
e
time_now_ms
as_client_event
event_format
token_id
only_event_fields
include_stripped_room_state: Some events can have stripped room state
stored in the `unsigned` field. This is required for invite and knock
functionality. If this option is False, that state will be removed from the
event before it is returned. Otherwise, it will be kept.
config: Event serialization config
Returns:
The serialized event dictionary.
@@ -348,11 +357,11 @@ def serialize_event(
if "redacted_because" in e.unsigned:
d["unsigned"]["redacted_because"] = serialize_event(
e.unsigned["redacted_because"], time_now_ms, event_format=event_format
e.unsigned["redacted_because"], time_now_ms, config=config
)
if token_id is not None:
if token_id == getattr(e.internal_metadata, "token_id", None):
if config.token_id is not None:
if config.token_id == getattr(e.internal_metadata, "token_id", None):
txn_id = getattr(e.internal_metadata, "txn_id", None)
if txn_id is not None:
d["unsigned"]["transaction_id"] = txn_id
@@ -361,13 +370,14 @@ def serialize_event(
# that are meant to provide metadata about a room to an invitee/knocker. They are
# intended to only be included in specific circumstances, such as down sync, and
# should not be included in any other case.
if not include_stripped_room_state:
if not config.include_stripped_room_state:
d["unsigned"].pop("invite_room_state", None)
d["unsigned"].pop("knock_room_state", None)
if as_client_event:
d = event_format(d)
if config.as_client_event:
d = config.event_format(d)
only_event_fields = config.only_event_fields
if only_event_fields:
if not isinstance(only_event_fields, list) or not all(
isinstance(f, str) for f in only_event_fields
@@ -390,18 +400,18 @@ class EventClientSerializer:
event: Union[JsonDict, EventBase],
time_now: int,
*,
config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG,
bundle_aggregations: Optional[Dict[str, "BundledAggregations"]] = None,
**kwargs: Any,
) -> JsonDict:
"""Serializes a single event.
Args:
event: The event being serialized.
time_now: The current time in milliseconds
config: Event serialization config
bundle_aggregations: Whether to include the bundled aggregations for this
event. Only applies to non-state events. (State events never include
bundled aggregations.)
**kwargs: Arguments to pass to `serialize_event`
Returns:
The serialized event
@@ -410,7 +420,7 @@ class EventClientSerializer:
if not isinstance(event, EventBase):
return event
serialized_event = serialize_event(event, time_now, **kwargs)
serialized_event = serialize_event(event, time_now, config=config)
# Check if there are any bundled aggregations to include with the event.
if bundle_aggregations:
@@ -419,6 +429,7 @@ class EventClientSerializer:
self._inject_bundled_aggregations(
event,
time_now,
config,
bundle_aggregations[event.event_id],
serialized_event,
)
@@ -456,6 +467,7 @@ class EventClientSerializer:
self,
event: EventBase,
time_now: int,
config: SerializeEventConfig,
aggregations: "BundledAggregations",
serialized_event: JsonDict,
) -> None:
@@ -466,6 +478,7 @@ class EventClientSerializer:
time_now: The current time in milliseconds
aggregations: The bundled aggregation to serialize.
serialized_event: The serialized event which may be modified.
config: Event serialization config
"""
serialized_aggregations = {}
@@ -493,8 +506,8 @@ class EventClientSerializer:
thread = aggregations.thread
# Don't bundle aggregations as this could recurse forever.
serialized_latest_event = self.serialize_event(
thread.latest_event, time_now, bundle_aggregations=None
serialized_latest_event = serialize_event(
thread.latest_event, time_now, config=config
)
# Manually apply an edit, if one exists.
if thread.latest_edit:
@@ -515,20 +528,34 @@ class EventClientSerializer:
)
def serialize_events(
self, events: Iterable[Union[JsonDict, EventBase]], time_now: int, **kwargs: Any
self,
events: Iterable[Union[JsonDict, EventBase]],
time_now: int,
*,
config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG,
bundle_aggregations: Optional[Dict[str, "BundledAggregations"]] = None,
) -> List[JsonDict]:
"""Serializes multiple events.
Args:
event
time_now: The current time in milliseconds
**kwargs: Arguments to pass to `serialize_event`
config: Event serialization config
bundle_aggregations: Whether to include the bundled aggregations for this
event. Only applies to non-state events. (State events never include
bundled aggregations.)
Returns:
The list of serialized events
"""
return [
self.serialize_event(event, time_now=time_now, **kwargs) for event in events
self.serialize_event(
event,
time_now,
config=config,
bundle_aggregations=bundle_aggregations,
)
for event in events
]
+4 -4
View File
@@ -1428,7 +1428,7 @@ class FederationClient(FederationBase):
# Validate children_state of the room.
children_state = room.pop("children_state", [])
if not isinstance(children_state, Sequence):
if not isinstance(children_state, list):
raise InvalidResponseError("'room.children_state' must be a list")
if any(not isinstance(e, dict) for e in children_state):
raise InvalidResponseError("Invalid event in 'children_state' list")
@@ -1440,14 +1440,14 @@ class FederationClient(FederationBase):
# Validate the children rooms.
children = res.get("children", [])
if not isinstance(children, Sequence):
if not isinstance(children, list):
raise InvalidResponseError("'children' must be a list")
if any(not isinstance(r, dict) for r in children):
raise InvalidResponseError("Invalid room in 'children' list")
# Validate the inaccessible children.
inaccessible_children = res.get("inaccessible_children", [])
if not isinstance(inaccessible_children, Sequence):
if not isinstance(inaccessible_children, list):
raise InvalidResponseError("'inaccessible_children' must be a list")
if any(not isinstance(r, str) for r in inaccessible_children):
raise InvalidResponseError(
@@ -1630,7 +1630,7 @@ def _validate_hierarchy_event(d: JsonDict) -> None:
raise ValueError("Invalid event: 'content' must be a dict")
via = content.get("via")
if not isinstance(via, Sequence):
if not isinstance(via, list):
raise ValueError("Invalid event: 'via' must be a list")
if any(not isinstance(v, str) for v in via):
raise ValueError("Invalid event: 'via' must be a list of strings")
+140 -8
View File
@@ -33,7 +33,7 @@ from synapse.metrics.background_process_metrics import (
wrap_as_background_process,
)
from synapse.storage.databases.main.directory import RoomAliasMapping
from synapse.types import JsonDict, RoomAlias, RoomStreamToken, UserID
from synapse.types import DeviceLists, JsonDict, RoomAlias, RoomStreamToken, UserID
from synapse.util.async_helpers import Linearizer
from synapse.util.metrics import Measure
@@ -58,6 +58,9 @@ class ApplicationServicesHandler:
self._msc2409_to_device_messages_enabled = (
hs.config.experimental.msc2409_to_device_messages_enabled
)
self._msc3202_transaction_extensions_enabled = (
hs.config.experimental.msc3202_transaction_extensions
)
self.current_max = 0
self.is_processing = False
@@ -204,9 +207,9 @@ class ApplicationServicesHandler:
Args:
stream_key: The stream the event came from.
`stream_key` can be "typing_key", "receipt_key", "presence_key" or
"to_device_key". Any other value for `stream_key` will cause this function
to return early.
`stream_key` can be "typing_key", "receipt_key", "presence_key",
"to_device_key" or "device_list_key". Any other value for `stream_key`
will cause this function to return early.
Ephemeral events will only be pushed to appservices that have opted into
receiving them by setting `push_ephemeral` to true in their registration
@@ -230,6 +233,7 @@ class ApplicationServicesHandler:
"receipt_key",
"presence_key",
"to_device_key",
"device_list_key",
):
return
@@ -253,15 +257,37 @@ class ApplicationServicesHandler:
):
return
# Ignore device lists if the feature flag is not enabled
if (
stream_key == "device_list_key"
and not self._msc3202_transaction_extensions_enabled
):
return
# Check whether there are any appservices which have registered to receive
# ephemeral events.
#
# Note that whether these events are actually relevant to these appservices
# is decided later on.
services = self.store.get_app_services()
services = [
service
for service in self.store.get_app_services()
if service.supports_ephemeral
for service in services
# Different stream keys require different support booleans
if (
stream_key
in (
"typing_key",
"receipt_key",
"presence_key",
"to_device_key",
)
and service.supports_ephemeral
)
or (
stream_key == "device_list_key"
and service.msc3202_transaction_extensions
)
]
if not services:
# Bail out early if none of the target appservices have explicitly registered
@@ -336,6 +362,20 @@ class ApplicationServicesHandler:
service, "to_device", new_token
)
elif stream_key == "device_list_key":
device_list_summary = await self._get_device_list_summary(
service, new_token
)
if device_list_summary:
self.scheduler.enqueue_for_appservice(
service, device_list_summary=device_list_summary
)
# Persist the latest handled stream token for this appservice
await self.store.set_appservice_stream_type_pos(
service, "device_list", new_token
)
async def _handle_typing(
self, service: ApplicationService, new_token: int
) -> List[JsonDict]:
@@ -542,6 +582,98 @@ class ApplicationServicesHandler:
return message_payload
async def _get_device_list_summary(
self,
appservice: ApplicationService,
new_key: int,
) -> DeviceLists:
"""
Retrieve a list of users who have changed their device lists.
Args:
appservice: The application service to retrieve device list changes for.
new_key: The stream key of the device list change that triggered this method call.
Returns:
A set of device list updates, comprised of users that the appservices needs to:
* resync the device list of, and
* stop tracking the device list of.
"""
# Fetch the last successfully processed device list update stream ID
# for this appservice.
from_key = await self.store.get_type_stream_id_for_appservice(
appservice, "device_list"
)
# Fetch the users who have modified their device list since then.
users_with_changed_device_lists = (
await self.store.get_users_whose_devices_changed(
from_key, user_ids=None, to_key=new_key
)
)
# Filter out any users the application service is not interested in
#
# For each user who changed their device list, we want to check whether this
# appservice would be interested in the change.
filtered_users_with_changed_device_lists = {
user_id
for user_id in users_with_changed_device_lists
if await self._is_appservice_interested_in_device_lists_of_user(
appservice, user_id
)
}
# Create a summary of "changed" and "left" users.
# TODO: Calculate "left" users.
device_list_summary = DeviceLists(
changed=filtered_users_with_changed_device_lists
)
return device_list_summary
async def _is_appservice_interested_in_device_lists_of_user(
self,
appservice: ApplicationService,
user_id: str,
) -> bool:
"""
Returns whether a given application service is interested in the device list
updates of a given user.
The application service is interested in the user's device list updates if any
of the following are true:
* The user is the appservice's sender localpart user.
* The user is in the appservice's user namespace.
* At least one member of one room that the user is a part of is in the
appservice's user namespace.
* The appservice is explicitly (via room ID or alias) interested in at
least one room that the user is in.
Args:
appservice: The application service to gauge interest of.
user_id: The ID of the user whose device list interest is in question.
Returns:
True if the application service is interested in the user's device lists, False
otherwise.
"""
# This method checks against both the sender localpart user as well as if the
# user is in the appservice's user namespace.
if appservice.is_interested_in_user(user_id):
return True
# FIXME: This is quite an expensive check. This method is called per device
# list change.
room_ids = await self.store.get_rooms_for_user(user_id)
for room_id in room_ids:
# This method covers checking room members for appservice interest as well as
# room ID and alias checks.
if await appservice.is_interested_in_room(room_id, self.store):
return True
return False
async def query_user_exists(self, user_id: str) -> bool:
"""Check if any application service knows this user_id exists.
@@ -571,7 +703,7 @@ class ApplicationServicesHandler:
room_alias_str = room_alias.to_string()
services = self.store.get_app_services()
alias_query_services = [
s for s in services if (s.is_interested_in_alias(room_alias_str))
s for s in services if (s.is_room_alias_in_namespace(room_alias_str))
]
for alias_service in alias_query_services:
is_known_alias = await self.appservice_api.query_alias(
@@ -660,7 +792,7 @@ class ApplicationServicesHandler:
# inside of a list comprehension anymore.
interested_list = []
for s in services:
if await s.is_interested(event, self.store):
if await s.is_interested_in_event(event.event_id, event, self.store):
interested_list.append(s)
return interested_list
+3 -3
View File
@@ -119,7 +119,7 @@ class DirectoryHandler:
service = requester.app_service
if service:
if not service.is_interested_in_alias(room_alias_str):
if not service.is_room_alias_in_namespace(room_alias_str):
raise SynapseError(
400,
"This application service has not reserved this kind of alias.",
@@ -221,7 +221,7 @@ class DirectoryHandler:
async def delete_appservice_association(
self, service: ApplicationService, room_alias: RoomAlias
) -> None:
if not service.is_interested_in_alias(room_alias.to_string()):
if not service.is_room_alias_in_namespace(room_alias.to_string()):
raise SynapseError(
400,
"This application service has not reserved this kind of alias",
@@ -376,7 +376,7 @@ class DirectoryHandler:
# non-exclusive locks on the alias (or there are no interested services)
services = self.store.get_app_services()
interested_services = [
s for s in services if s.is_interested_in_alias(alias.to_string())
s for s in services if s.is_room_alias_in_namespace(alias.to_string())
]
for service in interested_services:
+2 -1
View File
@@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, Iterable, List, Optional
from synapse.api.constants import EduTypes, EventTypes, Membership
from synapse.api.errors import AuthError, SynapseError
from synapse.events import EventBase
from synapse.events.utils import SerializeEventConfig
from synapse.handlers.presence import format_user_presence_state
from synapse.streams.config import PaginationConfig
from synapse.types import JsonDict, UserID
@@ -120,7 +121,7 @@ class EventStreamHandler:
chunks = self._event_serializer.serialize_events(
events,
time_now,
as_client_event=as_client_event,
config=SerializeEventConfig(as_client_event=as_client_event),
)
chunk = {
+6 -3
View File
@@ -18,6 +18,7 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, cast
from synapse.api.constants import EduTypes, EventTypes, Membership
from synapse.api.errors import SynapseError
from synapse.events import EventBase
from synapse.events.utils import SerializeEventConfig
from synapse.events.validator import EventValidator
from synapse.handlers.presence import format_user_presence_state
from synapse.handlers.receipts import ReceiptEventSource
@@ -156,6 +157,8 @@ class InitialSyncHandler:
if limit is None:
limit = 10
serializer_options = SerializeEventConfig(as_client_event=as_client_event)
async def handle_room(event: RoomsForUser) -> None:
d: JsonDict = {
"room_id": event.room_id,
@@ -173,7 +176,7 @@ class InitialSyncHandler:
d["invite"] = self._event_serializer.serialize_event(
invite_event,
time_now,
as_client_event=as_client_event,
config=serializer_options,
)
rooms_ret.append(d)
@@ -225,7 +228,7 @@ class InitialSyncHandler:
self._event_serializer.serialize_events(
messages,
time_now=time_now,
as_client_event=as_client_event,
config=serializer_options,
)
),
"start": await start_token.to_string(self.store),
@@ -235,7 +238,7 @@ class InitialSyncHandler:
d["state"] = self._event_serializer.serialize_events(
current_state.values(),
time_now=time_now,
as_client_event=as_client_event,
config=serializer_options,
)
account_data_events = []
+3
View File
@@ -1069,6 +1069,9 @@ class EventCreationHandler:
if relation_type == RelationTypes.ANNOTATION:
aggregation_key = relation["key"]
if len(aggregation_key) > 500:
raise SynapseError(400, "Aggregation key is too long")
already_exists = await self.store.has_user_annotated_event(
relates_to, event.type, aggregation_key, event.sender
)
+5 -2
View File
@@ -22,6 +22,7 @@ from twisted.python.failure import Failure
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import SynapseError
from synapse.api.filtering import Filter
from synapse.events.utils import SerializeEventConfig
from synapse.handlers.room import ShutdownRoomResponse
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.state import StateFilter
@@ -541,13 +542,15 @@ class PaginationHandler:
time_now = self.clock.time_msec()
serialize_options = SerializeEventConfig(as_client_event=as_client_event)
chunk = {
"chunk": (
self._event_serializer.serialize_events(
events,
time_now,
config=serialize_options,
bundle_aggregations=aggregations,
as_client_event=as_client_event,
)
),
"start": await from_token.to_string(self.store),
@@ -556,7 +559,7 @@ class PaginationHandler:
if state:
chunk["state"] = self._event_serializer.serialize_events(
state, time_now, as_client_event=as_client_event
state, time_now, config=serialize_options
)
return chunk
+1 -1
View File
@@ -269,7 +269,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
# Then filter down to rooms that the AS can read
events = []
for room_id, event in rooms_to_events.items():
if not await service.matches_user_in_member_list(room_id, self.store):
if not await service.is_interested_in_room(room_id, self.store):
continue
events.append(event)
+1 -1
View File
@@ -1736,8 +1736,8 @@ class RoomMemberMasterHandler(RoomMemberHandler):
txn_id=txn_id,
prev_event_ids=prev_event_ids,
auth_event_ids=auth_event_ids,
outlier=True,
)
event.internal_metadata.outlier = True
event.internal_metadata.out_of_band_membership = True
result_event = await self.event_creation_handler.handle_new_client_event(
+1 -1
View File
@@ -857,7 +857,7 @@ class _RoomEntry:
def _has_valid_via(e: EventBase) -> bool:
via = e.content.get("via")
if not via or not isinstance(via, Sequence):
if not via or not isinstance(via, list):
return False
for v in via:
if not isinstance(v, str):
+3 -27
View File
@@ -13,17 +13,7 @@
# limitations under the License.
import itertools
import logging
from typing import (
TYPE_CHECKING,
Any,
Collection,
Dict,
FrozenSet,
List,
Optional,
Set,
Tuple,
)
from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Optional, Set, Tuple
import attr
from prometheus_client import Counter
@@ -41,6 +31,7 @@ from synapse.storage.databases.main.relations import BundledAggregations
from synapse.storage.roommember import MemberSummary
from synapse.storage.state import StateFilter
from synapse.types import (
DeviceLists,
JsonDict,
MutableStateMap,
Requester,
@@ -184,21 +175,6 @@ class GroupsSyncResult:
return bool(self.join or self.invite or self.leave)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class DeviceLists:
"""
Attributes:
changed: List of user_ids whose devices may have changed
left: List of user_ids whose devices we no longer track
"""
changed: Collection[str]
left: Collection[str]
def __bool__(self) -> bool:
return bool(self.changed or self.left)
@attr.s(slots=True, auto_attribs=True)
class _RoomChanges:
"""The set of room entries to include in the sync, plus the set of joined
@@ -1380,7 +1356,7 @@ class SyncHandler:
return DeviceLists(changed=users_that_have_changed, left=newly_left_users)
else:
return DeviceLists(changed=[], left=[])
return DeviceLists()
async def _generate_sync_entry_for_to_device(
self, sync_result_builder: "SyncResultBuilder"
+1 -3
View File
@@ -486,9 +486,7 @@ class TypingNotificationEventSource(EventSource[int, JsonDict]):
if handler._room_serials[room_id] <= from_key:
continue
if not await service.matches_user_in_member_list(
room_id, self._main_store
):
if not await service.is_interested_in_room(room_id, self._main_store):
continue
events.append(self._make_event_for(room_id))
+7 -2
View File
@@ -16,7 +16,10 @@ import logging
from typing import TYPE_CHECKING, Tuple
from synapse.api.constants import ReceiptTypes
from synapse.events.utils import format_event_for_client_v2_without_room_id
from synapse.events.utils import (
SerializeEventConfig,
format_event_for_client_v2_without_room_id,
)
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_integer, parse_string
from synapse.http.site import SynapseRequest
@@ -75,7 +78,9 @@ class NotificationsServlet(RestServlet):
self._event_serializer.serialize_event(
notif_events[pa.event_id],
self.clock.time_msec(),
event_format=format_event_for_client_v2_without_room_id,
config=SerializeEventConfig(
event_format=format_event_for_client_v2_without_room_id
),
)
),
}
+38 -94
View File
@@ -14,24 +14,14 @@
import itertools
import logging
from collections import defaultdict
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Tuple,
Union,
)
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from synapse.api.constants import Membership, PresenceState
from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.api.filtering import FilterCollection
from synapse.api.presence import UserPresenceState
from synapse.events import EventBase
from synapse.events.utils import (
SerializeEventConfig,
format_event_for_client_v2_without_room_id,
format_event_raw,
)
@@ -48,7 +38,6 @@ from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import trace
from synapse.storage.databases.main.relations import BundledAggregations
from synapse.types import JsonDict, StreamToken
from synapse.util import json_decoder
@@ -239,28 +228,31 @@ class SyncRestServlet(RestServlet):
else:
raise Exception("Unknown event format %s" % (filter.event_format,))
serialize_options = SerializeEventConfig(
event_format=event_formatter,
token_id=access_token_id,
only_event_fields=filter.event_fields,
)
stripped_serialize_options = SerializeEventConfig(
event_format=event_formatter,
token_id=access_token_id,
include_stripped_room_state=True,
)
joined = await self.encode_joined(
sync_result.joined,
time_now,
access_token_id,
filter.event_fields,
event_formatter,
sync_result.joined, time_now, serialize_options
)
invited = await self.encode_invited(
sync_result.invited, time_now, access_token_id, event_formatter
sync_result.invited, time_now, stripped_serialize_options
)
knocked = await self.encode_knocked(
sync_result.knocked, time_now, access_token_id, event_formatter
sync_result.knocked, time_now, stripped_serialize_options
)
archived = await self.encode_archived(
sync_result.archived,
time_now,
access_token_id,
filter.event_fields,
event_formatter,
sync_result.archived, time_now, serialize_options
)
logger.debug("building sync response dict")
@@ -339,9 +331,7 @@ class SyncRestServlet(RestServlet):
self,
rooms: List[JoinedSyncResult],
time_now: int,
token_id: Optional[int],
event_fields: List[str],
event_formatter: Callable[[JsonDict], JsonDict],
serialize_options: SerializeEventConfig,
) -> JsonDict:
"""
Encode the joined rooms in a sync result
@@ -349,24 +339,14 @@ class SyncRestServlet(RestServlet):
Args:
rooms: list of sync results for rooms this user is joined to
time_now: current time - used as a baseline for age calculations
token_id: ID of the user's auth token - used for namespacing
of transaction IDs
event_fields: List of event fields to include. If empty,
all fields will be returned.
event_formatter: function to convert from federation format
to client format
serialize_options: Event serializer options
Returns:
The joined rooms list, in our response format
"""
joined = {}
for room in rooms:
joined[room.room_id] = await self.encode_room(
room,
time_now,
token_id,
joined=True,
only_fields=event_fields,
event_formatter=event_formatter,
room, time_now, joined=True, serialize_options=serialize_options
)
return joined
@@ -376,8 +356,7 @@ class SyncRestServlet(RestServlet):
self,
rooms: List[InvitedSyncResult],
time_now: int,
token_id: Optional[int],
event_formatter: Callable[[JsonDict], JsonDict],
serialize_options: SerializeEventConfig,
) -> JsonDict:
"""
Encode the invited rooms in a sync result
@@ -385,10 +364,7 @@ class SyncRestServlet(RestServlet):
Args:
rooms: list of sync results for rooms this user is invited to
time_now: current time - used as a baseline for age calculations
token_id: ID of the user's auth token - used for namespacing
of transaction IDs
event_formatter: function to convert from federation format
to client format
serialize_options: Event serializer options
Returns:
The invited rooms list, in our response format
@@ -396,11 +372,7 @@ class SyncRestServlet(RestServlet):
invited = {}
for room in rooms:
invite = self._event_serializer.serialize_event(
room.invite,
time_now,
token_id=token_id,
event_format=event_formatter,
include_stripped_room_state=True,
room.invite, time_now, config=serialize_options
)
unsigned = dict(invite.get("unsigned", {}))
invite["unsigned"] = unsigned
@@ -415,8 +387,7 @@ class SyncRestServlet(RestServlet):
self,
rooms: List[KnockedSyncResult],
time_now: int,
token_id: Optional[int],
event_formatter: Callable[[Dict], Dict],
serialize_options: SerializeEventConfig,
) -> Dict[str, Dict[str, Any]]:
"""
Encode the rooms we've knocked on in a sync result.
@@ -424,8 +395,7 @@ class SyncRestServlet(RestServlet):
Args:
rooms: list of sync results for rooms this user is knocking on
time_now: current time - used as a baseline for age calculations
token_id: ID of the user's auth token - used for namespacing of transaction IDs
event_formatter: function to convert from federation format to client format
serialize_options: Event serializer options
Returns:
The list of rooms the user has knocked on, in our response format.
@@ -433,11 +403,7 @@ class SyncRestServlet(RestServlet):
knocked = {}
for room in rooms:
knock = self._event_serializer.serialize_event(
room.knock,
time_now,
token_id=token_id,
event_format=event_formatter,
include_stripped_room_state=True,
room.knock, time_now, config=serialize_options
)
# Extract the `unsigned` key from the knock event.
@@ -470,9 +436,7 @@ class SyncRestServlet(RestServlet):
self,
rooms: List[ArchivedSyncResult],
time_now: int,
token_id: Optional[int],
event_fields: List[str],
event_formatter: Callable[[JsonDict], JsonDict],
serialize_options: SerializeEventConfig,
) -> JsonDict:
"""
Encode the archived rooms in a sync result
@@ -480,23 +444,14 @@ class SyncRestServlet(RestServlet):
Args:
rooms: list of sync results for rooms this user is joined to
time_now: current time - used as a baseline for age calculations
token_id: ID of the user's auth token - used for namespacing
of transaction IDs
event_fields: List of event fields to include. If empty,
all fields will be returned.
event_formatter: function to convert from federation format to client format
serialize_options: Event serializer options
Returns:
The archived rooms list, in our response format
"""
joined = {}
for room in rooms:
joined[room.room_id] = await self.encode_room(
room,
time_now,
token_id,
joined=False,
only_fields=event_fields,
event_formatter=event_formatter,
room, time_now, joined=False, serialize_options=serialize_options
)
return joined
@@ -505,10 +460,8 @@ class SyncRestServlet(RestServlet):
self,
room: Union[JoinedSyncResult, ArchivedSyncResult],
time_now: int,
token_id: Optional[int],
joined: bool,
only_fields: Optional[List[str]],
event_formatter: Callable[[JsonDict], JsonDict],
serialize_options: SerializeEventConfig,
) -> JsonDict:
"""
Args:
@@ -524,20 +477,6 @@ class SyncRestServlet(RestServlet):
Returns:
The room, encoded in our response format
"""
def serialize(
events: Iterable[EventBase],
aggregations: Optional[Dict[str, BundledAggregations]] = None,
) -> List[JsonDict]:
return self._event_serializer.serialize_events(
events,
time_now=time_now,
bundle_aggregations=aggregations,
token_id=token_id,
event_format=event_formatter,
only_event_fields=only_fields,
)
state_dict = room.state
timeline_events = room.timeline.events
@@ -554,9 +493,14 @@ class SyncRestServlet(RestServlet):
event.room_id,
)
serialized_state = serialize(state_events)
serialized_timeline = serialize(
timeline_events, room.timeline.bundled_aggregations
serialized_state = self._event_serializer.serialize_events(
state_events, time_now, config=serialize_options
)
serialized_timeline = self._event_serializer.serialize_events(
timeline_events,
time_now,
config=serialize_options,
bundle_aggregations=room.timeline.bundled_aggregations,
)
account_data = room.account_data
+3 -3
View File
@@ -194,7 +194,7 @@ class StateHandler:
}
async def get_current_state_ids(
self, room_id: str, latest_event_ids: Optional[Iterable[str]] = None
self, room_id: str, latest_event_ids: Optional[Collection[str]] = None
) -> StateMap[str]:
"""Get the current state, or the state at a set of events, for a room
@@ -243,7 +243,7 @@ class StateHandler:
return await self.get_hosts_in_room_at_events(room_id, event_ids)
async def get_hosts_in_room_at_events(
self, room_id: str, event_ids: Iterable[str]
self, room_id: str, event_ids: Collection[str]
) -> Set[str]:
"""Get the hosts that were in a room at the given event ids
@@ -404,7 +404,7 @@ class StateHandler:
@measure_func()
async def resolve_state_groups_for_events(
self, room_id: str, event_ids: Iterable[str]
self, room_id: str, event_ids: Collection[str]
) -> _StateCacheEntry:
"""Given a list of event_ids this method fetches the state at each
event, resolves conflicts between them and returns them.
+10 -6
View File
@@ -29,7 +29,7 @@ from synapse.storage._base import db_to_json
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.types import JsonDict
from synapse.types import DeviceLists, JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import _CacheContext, cached
@@ -217,6 +217,7 @@ class ApplicationServiceTransactionWorkerStore(
to_device_messages: List[JsonDict],
one_time_key_counts: TransactionOneTimeKeyCounts,
unused_fallback_keys: TransactionUnusedFallbackKeys,
device_list_summary: DeviceLists,
) -> AppServiceTransaction:
"""Atomically creates a new transaction for this application service
with the given list of events. Ephemeral events are NOT persisted to the
@@ -231,6 +232,7 @@ class ApplicationServiceTransactionWorkerStore(
appservice devices in the transaction.
unused_fallback_keys: Lists of unused fallback keys for relevant
appservice devices in the transaction.
device_list_summary: The device list summary to include in the transaction.
Returns:
A new transaction.
@@ -268,6 +270,7 @@ class ApplicationServiceTransactionWorkerStore(
to_device_messages=to_device_messages,
one_time_key_counts=one_time_key_counts,
unused_fallback_keys=unused_fallback_keys,
device_list_summary=device_list_summary,
)
return await self.db_pool.runInteraction(
@@ -359,8 +362,8 @@ class ApplicationServiceTransactionWorkerStore(
events = await self.get_events_as_list(event_ids)
# TODO: to-device messages, one-time key counts and unused fallback keys
# are not yet populated for catch-up transactions.
# TODO: to-device messages, one-time key counts, device list summaries and unused
# fallback keys are not yet populated for catch-up transactions.
# We likely want to populate those for reliability.
return AppServiceTransaction(
service=service,
@@ -370,6 +373,7 @@ class ApplicationServiceTransactionWorkerStore(
to_device_messages=[],
one_time_key_counts={},
unused_fallback_keys={},
device_list_summary=DeviceLists(),
)
def _get_last_txn(self, txn, service_id: Optional[str]) -> int:
@@ -430,7 +434,7 @@ class ApplicationServiceTransactionWorkerStore(
async def get_type_stream_id_for_appservice(
self, service: ApplicationService, type: str
) -> int:
if type not in ("read_receipt", "presence", "to_device"):
if type not in ("read_receipt", "presence", "to_device", "device_list"):
raise ValueError(
"Expected type to be a valid application stream id type, got %s"
% (type,)
@@ -446,7 +450,7 @@ class ApplicationServiceTransactionWorkerStore(
)
last_stream_id = txn.fetchone()
if last_stream_id is None or last_stream_id[0] is None: # no row exists
return 0
return 1
else:
return int(last_stream_id[0])
@@ -457,7 +461,7 @@ class ApplicationServiceTransactionWorkerStore(
async def set_appservice_stream_type_pos(
self, service: ApplicationService, stream_type: str, pos: Optional[int]
) -> None:
if stream_type not in ("read_receipt", "presence", "to_device"):
if stream_type not in ("read_receipt", "presence", "to_device", "device_list"):
raise ValueError(
"Expected type to be a valid application stream id type, got %s"
% (stream_type,)
+37 -15
View File
@@ -681,42 +681,64 @@ class DeviceWorkerStore(SQLBaseStore):
return self._device_list_stream_cache.get_all_entities_changed(from_key)
async def get_users_whose_devices_changed(
self, from_key: int, user_ids: Iterable[str]
self,
from_key: int,
user_ids: Optional[Iterable[str]] = None,
to_key: Optional[int] = None,
) -> Set[str]:
"""Get set of users whose devices have changed since `from_key` that
are in the given list of user_ids.
Args:
from_key: The device lists stream token
user_ids: The user IDs to query for devices.
from_key: The minimum device lists stream token to query device list changes for,
exclusive.
user_ids: If provided, only check if these users have changed their device lists.
Otherwise changes from all users are returned.
to_key: The maximum device lists stream token to query device list changes for,
inclusive.
Returns:
The set of user_ids whose devices have changed since `from_key`
The set of user_ids whose devices have changed since `from_key` (exclusive)
until `to_key` (inclusive).
"""
# Get set of users who *may* have changed. Users not in the returned
# list have definitely not changed.
to_check = self._device_list_stream_cache.get_entities_changed(
user_ids, from_key
)
if user_ids is None:
# Get set of all users that have had device list changes since 'from_key'
user_ids_to_check = self._device_list_stream_cache.get_all_entities_changed(
from_key
)
else:
# The same as above, but filter results to only those users in 'user_ids'
user_ids_to_check = self._device_list_stream_cache.get_entities_changed(
user_ids, from_key
)
if not to_check:
if not user_ids_to_check:
return set()
def _get_users_whose_devices_changed_txn(txn):
changes = set()
sql = """
stream_id_where_clause = "stream_id > ?"
sql_args = [from_key]
if to_key:
stream_id_where_clause += " AND stream_id <= ?"
sql_args += [to_key]
sql = f"""
SELECT DISTINCT user_id FROM device_lists_stream
WHERE stream_id > ?
WHERE {stream_id_where_clause}
AND
"""
for chunk in batch_iter(to_check, 100):
# Query device changes with a batch of users at a time
for chunk in batch_iter(user_ids_to_check, 100):
clause, args = make_in_list_sql_clause(
txn.database_engine, "user_id", chunk
)
txn.execute(sql + clause, (from_key,) + tuple(args))
sql_args += args
txn.execute(sql + clause, sql_args)
changes.update(user_id for user_id, in txn)
return changes
@@ -0,0 +1,18 @@
/* Copyright 2022 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Add a column to track what device list changes stream id that this application
-- service has been caught up to.
ALTER TABLE application_services_state ADD COLUMN device_list_stream_id BIGINT;
+4 -4
View File
@@ -561,7 +561,7 @@ class StateGroupStorage:
return state_group_delta.prev_group, state_group_delta.delta_ids
async def get_state_groups_ids(
self, _room_id: str, event_ids: Iterable[str]
self, _room_id: str, event_ids: Collection[str]
) -> Dict[int, MutableStateMap[str]]:
"""Get the event IDs of all the state for the state groups for the given events
@@ -596,7 +596,7 @@ class StateGroupStorage:
return group_to_state[state_group]
async def get_state_groups(
self, room_id: str, event_ids: Iterable[str]
self, room_id: str, event_ids: Collection[str]
) -> Dict[int, List[EventBase]]:
"""Get the state groups for the given list of event_ids
@@ -648,7 +648,7 @@ class StateGroupStorage:
return self.stores.state._get_state_groups_from_groups(groups, state_filter)
async def get_state_for_events(
self, event_ids: Iterable[str], state_filter: Optional[StateFilter] = None
self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None
) -> Dict[str, StateMap[EventBase]]:
"""Given a list of event_ids and type tuples, return a list of state
dicts for each event.
@@ -684,7 +684,7 @@ class StateGroupStorage:
return {event: event_to_state[event] for event in event_ids}
async def get_state_ids_for_events(
self, event_ids: Iterable[str], state_filter: Optional[StateFilter] = None
self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None
) -> Dict[str, StateMap[str]]:
"""
Get the state dicts corresponding to a list of events, containing the event_ids
+21
View File
@@ -25,6 +25,7 @@ from typing import (
Match,
MutableMapping,
Optional,
Set,
Tuple,
Type,
TypeVar,
@@ -743,6 +744,26 @@ class ReadReceipt:
data: JsonDict
@attr.s(slots=True, frozen=True, auto_attribs=True)
class DeviceLists:
"""
Attributes:
changed: user_ids whose devices may have changed
left: user_ids whose devices we no longer track
"""
# We need to use a factory here, otherwise `set` is not evaluated at
# object instantiation, but instead at class definition instantiation.
# The latter happening only once, thus always giving you the same sets
# across multiple DeviceLists instances.
# Also see: don't define mutable default arguments.
changed: Set[str] = attr.ib(factory=set)
left: Set[str] = attr.ib(factory=set)
def __bool__(self) -> bool:
return bool(self.changed or self.left)
def get_verify_key_from_cross_signing_key(key_info):
"""Get the key ID and signedjson verify key from a cross-signing key dict
+16 -1
View File
@@ -81,8 +81,9 @@ async def filter_events_for_client(
types = ((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, user_id))
# we exclude outliers at this point, and then handle them separately later
event_id_to_state = await storage.state.get_state_for_events(
frozenset(e.event_id for e in events),
frozenset(e.event_id for e in events if not e.internal_metadata.outlier),
state_filter=StateFilter.from_types(types),
)
@@ -154,6 +155,17 @@ async def filter_events_for_client(
if event.event_id in always_include_ids:
return event
# we need to handle outliers separately, since we don't have the room state.
if event.internal_metadata.outlier:
# Normally these can't be seen by clients, but we make an exception for
# for out-of-band membership events (eg, incoming invites, or rejections of
# said invite) for the user themselves.
if event.type == EventTypes.Member and event.state_key == user_id:
logger.debug("Returning out-of-band-membership event %s", event)
return event
return None
state = event_id_to_state[event.event_id]
# get the room_visibility at the time of the event.
@@ -198,6 +210,9 @@ async def filter_events_for_client(
# Always allow the user to see their own leave events, otherwise
# they won't see the room disappear if they reject the invite
#
# (Note this doesn't work for out-of-band invite rejections, which don't
# have prev_state populated. They are handled above in the outlier code.)
if membership == "leave" and (
prev_membership == "join" or prev_membership == "invite"
):
+34 -11
View File
@@ -36,7 +36,10 @@ class ApplicationServiceTestCase(unittest.TestCase):
hostname="matrix.org", # only used by get_groups_for_user
)
self.event = Mock(
type="m.something", room_id="!foo:bar", sender="@someone:somewhere"
event_id="$abc:xyz",
type="m.something",
room_id="!foo:bar",
sender="@someone:somewhere",
)
self.store = Mock()
@@ -50,7 +53,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.assertTrue(
(
yield defer.ensureDeferred(
self.service.is_interested(self.event, self.store)
self.service.is_interested_in_event(
self.event.event_id, self.event, self.store
)
)
)
)
@@ -62,7 +67,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.assertFalse(
(
yield defer.ensureDeferred(
self.service.is_interested(self.event, self.store)
self.service.is_interested_in_event(
self.event.event_id, self.event, self.store
)
)
)
)
@@ -76,7 +83,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.assertTrue(
(
yield defer.ensureDeferred(
self.service.is_interested(self.event, self.store)
self.service.is_interested_in_event(
self.event.event_id, self.event, self.store
)
)
)
)
@@ -90,7 +99,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.assertTrue(
(
yield defer.ensureDeferred(
self.service.is_interested(self.event, self.store)
self.service.is_interested_in_event(
self.event.event_id, self.event, self.store
)
)
)
)
@@ -104,7 +115,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.assertFalse(
(
yield defer.ensureDeferred(
self.service.is_interested(self.event, self.store)
self.service.is_interested_in_event(
self.event.event_id, self.event, self.store
)
)
)
)
@@ -121,7 +134,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.assertTrue(
(
yield defer.ensureDeferred(
self.service.is_interested(self.event, self.store)
self.service.is_interested_in_event(
self.event.event_id, self.event, self.store
)
)
)
)
@@ -174,7 +189,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.assertFalse(
(
yield defer.ensureDeferred(
self.service.is_interested(self.event, self.store)
self.service.is_interested_in_event(
self.event.event_id, self.event, self.store
)
)
)
)
@@ -191,7 +208,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.assertTrue(
(
yield defer.ensureDeferred(
self.service.is_interested(self.event, self.store)
self.service.is_interested_in_event(
self.event.event_id, self.event, self.store
)
)
)
)
@@ -207,7 +226,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.assertTrue(
(
yield defer.ensureDeferred(
self.service.is_interested(self.event, self.store)
self.service.is_interested_in_event(
self.event.event_id, self.event, self.store
)
)
)
)
@@ -225,7 +246,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
self.assertTrue(
(
yield defer.ensureDeferred(
self.service.is_interested(event=self.event, store=self.store)
self.service.is_interested_in_event(
self.event.event_id, self.event, self.store
)
)
)
)
+33 -15
View File
@@ -24,6 +24,7 @@ from synapse.appservice.scheduler import (
)
from synapse.logging.context import make_deferred_yieldable
from synapse.server import HomeServer
from synapse.types import DeviceLists
from synapse.util import Clock
from tests import unittest
@@ -70,6 +71,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
to_device_messages=[], # txn made and saved
one_time_key_counts={},
unused_fallback_keys={},
device_list_summary=DeviceLists(),
)
self.assertEqual(0, len(self.txnctrl.recoverers)) # no recoverer made
txn.complete.assert_called_once_with(self.store) # txn completed
@@ -96,6 +98,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
to_device_messages=[], # txn made and saved
one_time_key_counts={},
unused_fallback_keys={},
device_list_summary=DeviceLists(),
)
self.assertEqual(0, txn.send.call_count) # txn not sent though
self.assertEqual(0, txn.complete.call_count) # or completed
@@ -124,6 +127,7 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
to_device_messages=[],
one_time_key_counts={},
unused_fallback_keys={},
device_list_summary=DeviceLists(),
)
self.assertEqual(1, self.recoverer_fn.call_count) # recoverer made
self.assertEqual(1, self.recoverer.recover.call_count) # and invoked
@@ -225,7 +229,9 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
service = Mock(id=4)
event = Mock()
self.scheduler.enqueue_for_appservice(service, events=[event])
self.txn_ctrl.send.assert_called_once_with(service, [event], [], [], None, None)
self.txn_ctrl.send.assert_called_once_with(
service, [event], [], [], None, None, DeviceLists()
)
def test_send_single_event_with_queue(self):
d = defer.Deferred()
@@ -240,12 +246,14 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
# (call enqueue_for_appservice multiple times deliberately)
self.scheduler.enqueue_for_appservice(service, events=[event2])
self.scheduler.enqueue_for_appservice(service, events=[event3])
self.txn_ctrl.send.assert_called_with(service, [event], [], [], None, None)
self.txn_ctrl.send.assert_called_with(
service, [event], [], [], None, None, DeviceLists()
)
self.assertEqual(1, self.txn_ctrl.send.call_count)
# Resolve the send event: expect the queued events to be sent
d.callback(service)
self.txn_ctrl.send.assert_called_with(
service, [event2, event3], [], [], None, None
service, [event2, event3], [], [], None, None, DeviceLists()
)
self.assertEqual(2, self.txn_ctrl.send.call_count)
@@ -272,15 +280,21 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
# send events for different ASes and make sure they are sent
self.scheduler.enqueue_for_appservice(srv1, events=[srv_1_event])
self.scheduler.enqueue_for_appservice(srv1, events=[srv_1_event2])
self.txn_ctrl.send.assert_called_with(srv1, [srv_1_event], [], [], None, None)
self.txn_ctrl.send.assert_called_with(
srv1, [srv_1_event], [], [], None, None, DeviceLists()
)
self.scheduler.enqueue_for_appservice(srv2, events=[srv_2_event])
self.scheduler.enqueue_for_appservice(srv2, events=[srv_2_event2])
self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event], [], [], None, None)
self.txn_ctrl.send.assert_called_with(
srv2, [srv_2_event], [], [], None, None, DeviceLists()
)
# make sure callbacks for a service only send queued events for THAT
# service
srv_2_defer.callback(srv2)
self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2], [], [], None, None)
self.txn_ctrl.send.assert_called_with(
srv2, [srv_2_event2], [], [], None, None, DeviceLists()
)
self.assertEqual(3, self.txn_ctrl.send.call_count)
def test_send_large_txns(self):
@@ -300,17 +314,17 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
# Expect the first event to be sent immediately.
self.txn_ctrl.send.assert_called_with(
service, [event_list[0]], [], [], None, None
service, [event_list[0]], [], [], None, None, DeviceLists()
)
srv_1_defer.callback(service)
# Then send the next 100 events
self.txn_ctrl.send.assert_called_with(
service, event_list[1:101], [], [], None, None
service, event_list[1:101], [], [], None, None, DeviceLists()
)
srv_2_defer.callback(service)
# Then the final 99 events
self.txn_ctrl.send.assert_called_with(
service, event_list[101:], [], [], None, None
service, event_list[101:], [], [], None, None, DeviceLists()
)
self.assertEqual(3, self.txn_ctrl.send.call_count)
@@ -320,7 +334,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
event_list = [Mock(name="event")]
self.scheduler.enqueue_for_appservice(service, ephemeral=event_list)
self.txn_ctrl.send.assert_called_once_with(
service, [], event_list, [], None, None
service, [], event_list, [], None, None, DeviceLists()
)
def test_send_multiple_ephemeral_no_queue(self):
@@ -329,7 +343,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
event_list = [Mock(name="event1"), Mock(name="event2"), Mock(name="event3")]
self.scheduler.enqueue_for_appservice(service, ephemeral=event_list)
self.txn_ctrl.send.assert_called_once_with(
service, [], event_list, [], None, None
service, [], event_list, [], None, None, DeviceLists()
)
def test_send_single_ephemeral_with_queue(self):
@@ -345,13 +359,15 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
# Send more events: expect send() to NOT be called multiple times.
self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_2)
self.scheduler.enqueue_for_appservice(service, ephemeral=event_list_3)
self.txn_ctrl.send.assert_called_with(service, [], event_list_1, [], None, None)
self.txn_ctrl.send.assert_called_with(
service, [], event_list_1, [], None, None, DeviceLists()
)
self.assertEqual(1, self.txn_ctrl.send.call_count)
# Resolve txn_ctrl.send
d.callback(service)
# Expect the queued events to be sent
self.txn_ctrl.send.assert_called_with(
service, [], event_list_2 + event_list_3, [], None, None
service, [], event_list_2 + event_list_3, [], None, None, DeviceLists()
)
self.assertEqual(2, self.txn_ctrl.send.call_count)
@@ -365,8 +381,10 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
event_list = first_chunk + second_chunk
self.scheduler.enqueue_for_appservice(service, ephemeral=event_list)
self.txn_ctrl.send.assert_called_once_with(
service, [], first_chunk, [], None, None
service, [], first_chunk, [], None, None, DeviceLists()
)
d.callback(service)
self.txn_ctrl.send.assert_called_with(service, [], second_chunk, [], None, None)
self.txn_ctrl.send.assert_called_with(
service, [], second_chunk, [], None, None, DeviceLists()
)
self.assertEqual(2, self.txn_ctrl.send.call_count)
+4 -1
View File
@@ -16,6 +16,7 @@ from synapse.api.constants import EventContentFields
from synapse.api.room_versions import RoomVersions
from synapse.events import make_event_from_dict
from synapse.events.utils import (
SerializeEventConfig,
copy_power_levels_contents,
prune_event,
serialize_event,
@@ -392,7 +393,9 @@ class PruneEventTestCase(unittest.TestCase):
class SerializeEventTestCase(unittest.TestCase):
def serialize(self, ev, fields):
return serialize_event(ev, 1479807801915, only_event_fields=fields)
return serialize_event(
ev, 1479807801915, config=SerializeEventConfig(only_event_fields=fields)
)
def test_event_fields_works_with_keys(self):
self.assertEqual(
+163 -15
View File
@@ -15,6 +15,8 @@
from typing import Dict, Iterable, List, Optional
from unittest.mock import Mock
from parameterized import parameterized
from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor
@@ -59,11 +61,11 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self.event_source = hs.get_event_sources()
def test_notify_interested_services(self):
interested_service = self._mkservice(is_interested=True)
interested_service = self._mkservice(is_interested_in_event=True)
services = [
self._mkservice(is_interested=False),
self._mkservice(is_interested_in_event=False),
interested_service,
self._mkservice(is_interested=False),
self._mkservice(is_interested_in_event=False),
]
self.mock_as_api.query_user.return_value = make_awaitable(True)
@@ -85,7 +87,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
def test_query_user_exists_unknown_user(self):
user_id = "@someone:anywhere"
services = [self._mkservice(is_interested=True)]
services = [self._mkservice(is_interested_in_event=True)]
services[0].is_interested_in_user.return_value = True
self.mock_store.get_app_services.return_value = services
self.mock_store.get_user_by_id.return_value = make_awaitable(None)
@@ -102,7 +104,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
def test_query_user_exists_known_user(self):
user_id = "@someone:anywhere"
services = [self._mkservice(is_interested=True)]
services = [self._mkservice(is_interested_in_event=True)]
services[0].is_interested_in_user.return_value = True
self.mock_store.get_app_services.return_value = services
self.mock_store.get_user_by_id.return_value = make_awaitable({"name": user_id})
@@ -127,11 +129,11 @@ class AppServiceHandlerTestCase(unittest.TestCase):
room_id = "!alpha:bet"
servers = ["aperture"]
interested_service = self._mkservice_alias(is_interested_in_alias=True)
interested_service = self._mkservice_alias(is_room_alias_in_namespace=True)
services = [
self._mkservice_alias(is_interested_in_alias=False),
self._mkservice_alias(is_room_alias_in_namespace=False),
interested_service,
self._mkservice_alias(is_interested_in_alias=False),
self._mkservice_alias(is_room_alias_in_namespace=False),
]
self.mock_as_api.query_alias.return_value = make_awaitable(True)
@@ -275,7 +277,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
to be pushed out to interested appservices, and that the stream ID is
updated accordingly.
"""
interested_service = self._mkservice(is_interested=True)
interested_service = self._mkservice(is_interested_in_event=True)
services = [interested_service]
self.mock_store.get_app_services.return_value = services
self.mock_store.get_type_stream_id_for_appservice.return_value = make_awaitable(
@@ -304,7 +306,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
Test sending out of order ephemeral events to the appservice handler
are ignored.
"""
interested_service = self._mkservice(is_interested=True)
interested_service = self._mkservice(is_interested_in_event=True)
services = [interested_service]
self.mock_store.get_app_services.return_value = services
@@ -325,17 +327,45 @@ class AppServiceHandlerTestCase(unittest.TestCase):
interested_service, ephemeral=[]
)
def _mkservice(self, is_interested, protocols=None):
def _mkservice(
self, is_interested_in_event: bool, protocols: Optional[Iterable] = None
) -> Mock:
"""
Create a new mock representing an ApplicationService.
Args:
is_interested_in_event: Whether this application service will be considered
interested in all events.
protocols: The third-party protocols that this application service claims to
support.
Returns:
A mock representing the ApplicationService.
"""
service = Mock()
service.is_interested.return_value = make_awaitable(is_interested)
service.is_interested_in_event.return_value = make_awaitable(
is_interested_in_event
)
service.token = "mock_service_token"
service.url = "mock_service_url"
service.protocols = protocols
return service
def _mkservice_alias(self, is_interested_in_alias):
def _mkservice_alias(self, is_room_alias_in_namespace: bool) -> Mock:
"""
Create a new mock representing an ApplicationService that is or is not interested
any given room aliase.
Args:
is_room_alias_in_namespace: If true, the application service will be interested
in all room aliases that are queried against it. If false, the application
service will not be interested in any room aliases.
Returns:
A mock representing the ApplicationService.
"""
service = Mock()
service.is_interested_in_alias.return_value = is_interested_in_alias
service.is_room_alias_in_namespace.return_value = is_room_alias_in_namespace
service.token = "mock_service_token"
service.url = "mock_service_url"
return service
@@ -443,6 +473,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
to_device_messages,
_otks,
_fbks,
_device_list_summary,
) = self.send_mock.call_args[0]
# Assert that this was the same to-device message that local_user sent
@@ -555,7 +586,15 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
service_id_to_message_count: Dict[str, int] = {}
for call in self.send_mock.call_args_list:
service, _events, _ephemeral, to_device_messages, _otks, _fbks = call[0]
(
service,
_events,
_ephemeral,
to_device_messages,
_otks,
_fbks,
_device_list_summary,
) = call[0]
# Check that this was made to an interested service
self.assertIn(service, interested_appservices)
@@ -599,6 +638,115 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
return appservice
class ApplicationServicesHandlerDeviceListsTestCase(unittest.HomeserverTestCase):
"""
Tests that the ApplicationServicesHandler sends device list updates to application
services correctly.
"""
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
login.register_servlets,
room.register_servlets,
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# Allow us to modify cached feature flags mid-test
self.as_handler = hs.get_application_service_handler()
# Mock ApplicationServiceApi's put_json, so we can verify the raw JSON that
# will be sent over the wire
self.put_json = simple_async_mock()
hs.get_application_service_api().put_json = self.put_json # type: ignore[assignment]
# Mock out application services, and allow defining our own in tests
self._services: List[ApplicationService] = []
self.hs.get_datastores().main.get_app_services = Mock(
return_value=self._services
)
# Test across a variety of configuration values
@parameterized.expand(
[
(True, True, True),
(True, False, False),
(False, True, False),
(False, False, False),
]
)
@unittest.override_config({"experimental_features": {"": False}})
def test_application_service_receives_device_list_updates(
self,
experimental_feature_enabled: bool,
as_supports_txn_extensions: bool,
as_should_receive_device_list_updates: bool,
):
"""
Tests that an application service receives notice of changed device
lists for a user, when a user changes their device lists.
Arguments above are populated by parameterized.
Args:
as_should_receive_device_list_updates: Whether we expect the AS to receive the
device list changes.
experimental_feature_enabled: Whether the "msc3202_transaction_extensions" experimental
feature is enabled. This feature must be enabled for device lists to ASs to work.
as_supports_txn_extensions: Whether the application service has explicitly registered
to receive information defined by MSC3202 - which includes device list changes.
"""
# Change whether the experimental feature is enabled or disabled before making
# device list changes
self.as_handler._msc3202_transaction_extensions_enabled = (
experimental_feature_enabled
)
# Create an appservice that is interested in "local_user"
appservice = ApplicationService(
token=random_string(10),
hostname="example.com",
id=random_string(10),
sender="@as:example.com",
rate_limited=False,
namespaces={
ApplicationService.NS_USERS: [
{
"regex": "@local_user:.+",
"exclusive": False,
}
],
},
supports_ephemeral=True,
msc3202_transaction_extensions=as_supports_txn_extensions,
# Must be set for Synapse to try pushing data to the AS
hs_token="abcde",
url="some_url",
)
# Register the application service
self._services.append(appservice)
# Register a user on the homeserver
self.local_user = self.register_user("local_user", "password")
self.local_user_token = self.login("local_user", "password")
if as_should_receive_device_list_updates:
# Ensure that the resulting JSON uses the unstable prefix and contains the
# expected users
self.put_json.assert_called_once()
json_body = self.put_json.call_args.kwargs["json_body"]
# Our application service should have received a device list update with
# "local_user" in the "changed" list
device_list_dict = json_body.get("org.matrix.msc3202.device_lists", {})
self.assertEqual([], device_list_dict["left"])
self.assertEqual([self.local_user], device_list_dict["changed"])
else:
# No device list changes should have been sent out
self.put_json.assert_not_called()
class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase):
# Argument indices for pulling out arguments from a `send_mock`.
ARG_OTK_COUNTS = 4
+155 -131
View File
@@ -15,11 +15,12 @@ import json
import os
import re
from email.parser import Parser
from typing import Dict, List, Optional
from typing import Any, Dict, List, Optional, Union
from unittest.mock import Mock
import pkg_resources
from twisted.internet.interfaces import IReactorTCP
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
@@ -30,6 +31,7 @@ from synapse.rest import admin
from synapse.rest.client import account, login, register, room
from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
from tests import unittest
@@ -46,7 +48,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]
def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
# Email config.
@@ -67,20 +69,27 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
hs = self.setup_test_homeserver(config=config)
async def sendmail(
reactor, smtphost, smtpport, from_addr, to_addrs, msg, **kwargs
):
self.email_attempts.append(msg)
reactor: IReactorTCP,
smtphost: str,
smtpport: int,
from_addr: str,
to_addr: str,
msg_bytes: bytes,
*args: Any,
**kwargs: Any,
) -> None:
self.email_attempts.append(msg_bytes)
self.email_attempts = []
self.email_attempts: List[bytes] = []
hs.get_send_email_handler()._sendmail = sendmail
return hs
def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.submit_token_resource = PasswordResetSubmitTokenResource(hs)
def test_basic_password_reset(self):
def test_basic_password_reset(self) -> None:
"""Test basic password reset flow"""
old_password = "monkey"
new_password = "kangeroo"
@@ -118,7 +127,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
self.attempt_wrong_password_login("kermit", old_password)
@override_config({"rc_3pid_validation": {"burst_count": 3}})
def test_ratelimit_by_email(self):
def test_ratelimit_by_email(self) -> None:
"""Test that we ratelimit /requestToken for the same email."""
old_password = "monkey"
new_password = "kangeroo"
@@ -139,7 +148,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
)
)
def reset(ip):
def reset(ip: str) -> None:
client_secret = "foobar"
session_id = self._request_token(email, client_secret, ip)
@@ -166,7 +175,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
self.assertEqual(cm.exception.code, 429)
def test_basic_password_reset_canonicalise_email(self):
def test_basic_password_reset_canonicalise_email(self) -> None:
"""Test basic password reset flow
Request password reset with different spelling
"""
@@ -206,7 +215,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
# Assert we can't log in with the old password
self.attempt_wrong_password_login("kermit", old_password)
def test_cant_reset_password_without_clicking_link(self):
def test_cant_reset_password_without_clicking_link(self) -> None:
"""Test that we do actually need to click the link in the email"""
old_password = "monkey"
new_password = "kangeroo"
@@ -241,7 +250,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
# Assert we can't log in with the new password
self.attempt_wrong_password_login("kermit", new_password)
def test_no_valid_token(self):
def test_no_valid_token(self) -> None:
"""Test that we do actually need to request a token and can't just
make a session up.
"""
@@ -277,7 +286,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
self.attempt_wrong_password_login("kermit", new_password)
@unittest.override_config({"request_token_inhibit_3pid_errors": True})
def test_password_reset_bad_email_inhibit_error(self):
def test_password_reset_bad_email_inhibit_error(self) -> None:
"""Test that triggering a password reset with an email address that isn't bound
to an account doesn't leak the lack of binding for that address if configured
that way.
@@ -292,7 +301,12 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
self.assertIsNotNone(session_id)
def _request_token(self, email, client_secret, ip="127.0.0.1"):
def _request_token(
self,
email: str,
client_secret: str,
ip: str = "127.0.0.1",
) -> str:
channel = self.make_request(
"POST",
b"account/password/email/requestToken",
@@ -309,7 +323,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
return channel.json_body["sid"]
def _validate_token(self, link):
def _validate_token(self, link: str) -> None:
# Remove the host
path = link.replace("https://example.com", "")
@@ -339,7 +353,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(200, channel.code, channel.result)
def _get_link_from_email(self):
def _get_link_from_email(self) -> str:
assert self.email_attempts, "No emails have been sent"
raw_msg = self.email_attempts[-1].decode("UTF-8")
@@ -354,14 +368,19 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
if not text:
self.fail("Could not find text portion of email to parse")
assert text is not None
match = re.search(r"https://example.com\S+", text)
assert match, "Could not find link in email"
return match.group(0)
def _reset_password(
self, new_password, session_id, client_secret, expected_code=200
):
self,
new_password: str,
session_id: str,
client_secret: str,
expected_code: int = 200,
) -> None:
channel = self.make_request(
"POST",
b"account/password",
@@ -388,11 +407,11 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
room.register_servlets,
]
def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.hs = self.setup_test_homeserver()
return self.hs
def test_deactivate_account(self):
def test_deactivate_account(self) -> None:
user_id = self.register_user("kermit", "test")
tok = self.login("kermit", "test")
@@ -407,7 +426,7 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
channel = self.make_request("GET", "account/whoami", access_token=tok)
self.assertEqual(channel.code, 401)
def test_pending_invites(self):
def test_pending_invites(self) -> None:
"""Tests that deactivating a user rejects every pending invite for them."""
store = self.hs.get_datastores().main
@@ -448,7 +467,7 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(memberships), 1, memberships)
self.assertEqual(memberships[0].room_id, room_id, memberships)
def deactivate(self, user_id, tok):
def deactivate(self, user_id: str, tok: str) -> None:
request_data = json.dumps(
{
"auth": {
@@ -474,12 +493,12 @@ class WhoamiTestCase(unittest.HomeserverTestCase):
register.register_servlets,
]
def default_config(self):
def default_config(self) -> Dict[str, Any]:
config = super().default_config()
config["allow_guest_access"] = True
return config
def test_GET_whoami(self):
def test_GET_whoami(self) -> None:
device_id = "wouldgohere"
user_id = self.register_user("kermit", "test")
tok = self.login("kermit", "test", device_id=device_id)
@@ -496,7 +515,7 @@ class WhoamiTestCase(unittest.HomeserverTestCase):
},
)
def test_GET_whoami_guests(self):
def test_GET_whoami_guests(self) -> None:
channel = self.make_request(
b"POST", b"/_matrix/client/r0/register?kind=guest", b"{}"
)
@@ -516,7 +535,7 @@ class WhoamiTestCase(unittest.HomeserverTestCase):
},
)
def test_GET_whoami_appservices(self):
def test_GET_whoami_appservices(self) -> None:
user_id = "@as:test"
as_token = "i_am_an_app_service"
@@ -541,7 +560,7 @@ class WhoamiTestCase(unittest.HomeserverTestCase):
)
self.assertFalse(hasattr(whoami, "device_id"))
def _whoami(self, tok):
def _whoami(self, tok: str) -> JsonDict:
channel = self.make_request("GET", "account/whoami", {}, access_token=tok)
self.assertEqual(channel.code, 200)
return channel.json_body
@@ -555,7 +574,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
synapse.rest.admin.register_servlets_for_client_rest_resource,
]
def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
# Email config.
@@ -576,16 +595,23 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
self.hs = self.setup_test_homeserver(config=config)
async def sendmail(
reactor, smtphost, smtpport, from_addr, to_addrs, msg, **kwargs
):
self.email_attempts.append(msg)
reactor: IReactorTCP,
smtphost: str,
smtpport: int,
from_addr: str,
to_addr: str,
msg_bytes: bytes,
*args: Any,
**kwargs: Any,
) -> None:
self.email_attempts.append(msg_bytes)
self.email_attempts = []
self.email_attempts: List[bytes] = []
self.hs.get_send_email_handler()._sendmail = sendmail
return self.hs
def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.user_id = self.register_user("kermit", "test")
@@ -593,83 +619,73 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
self.email = "test@example.com"
self.url_3pid = b"account/3pid"
def test_add_valid_email(self):
self.get_success(self._add_email(self.email, self.email))
def test_add_valid_email(self) -> None:
self._add_email(self.email, self.email)
def test_add_valid_email_second_time(self):
self.get_success(self._add_email(self.email, self.email))
self.get_success(
self._request_token_invalid_email(
self.email,
expected_errcode=Codes.THREEPID_IN_USE,
expected_error="Email is already in use",
)
def test_add_valid_email_second_time(self) -> None:
self._add_email(self.email, self.email)
self._request_token_invalid_email(
self.email,
expected_errcode=Codes.THREEPID_IN_USE,
expected_error="Email is already in use",
)
def test_add_valid_email_second_time_canonicalise(self):
self.get_success(self._add_email(self.email, self.email))
self.get_success(
self._request_token_invalid_email(
"TEST@EXAMPLE.COM",
expected_errcode=Codes.THREEPID_IN_USE,
expected_error="Email is already in use",
)
def test_add_valid_email_second_time_canonicalise(self) -> None:
self._add_email(self.email, self.email)
self._request_token_invalid_email(
"TEST@EXAMPLE.COM",
expected_errcode=Codes.THREEPID_IN_USE,
expected_error="Email is already in use",
)
def test_add_email_no_at(self):
self.get_success(
self._request_token_invalid_email(
"address-without-at.bar",
expected_errcode=Codes.UNKNOWN,
expected_error="Unable to parse email address",
)
def test_add_email_no_at(self) -> None:
self._request_token_invalid_email(
"address-without-at.bar",
expected_errcode=Codes.UNKNOWN,
expected_error="Unable to parse email address",
)
def test_add_email_two_at(self):
self.get_success(
self._request_token_invalid_email(
"foo@foo@test.bar",
expected_errcode=Codes.UNKNOWN,
expected_error="Unable to parse email address",
)
def test_add_email_two_at(self) -> None:
self._request_token_invalid_email(
"foo@foo@test.bar",
expected_errcode=Codes.UNKNOWN,
expected_error="Unable to parse email address",
)
def test_add_email_bad_format(self):
self.get_success(
self._request_token_invalid_email(
"user@bad.example.net@good.example.com",
expected_errcode=Codes.UNKNOWN,
expected_error="Unable to parse email address",
)
def test_add_email_bad_format(self) -> None:
self._request_token_invalid_email(
"user@bad.example.net@good.example.com",
expected_errcode=Codes.UNKNOWN,
expected_error="Unable to parse email address",
)
def test_add_email_domain_to_lower(self):
self.get_success(self._add_email("foo@TEST.BAR", "foo@test.bar"))
def test_add_email_domain_to_lower(self) -> None:
self._add_email("foo@TEST.BAR", "foo@test.bar")
def test_add_email_domain_with_umlaut(self):
self.get_success(self._add_email("foo@Öumlaut.com", "foo@öumlaut.com"))
def test_add_email_domain_with_umlaut(self) -> None:
self._add_email("foo@Öumlaut.com", "foo@öumlaut.com")
def test_add_email_address_casefold(self):
self.get_success(self._add_email("Strauß@Example.com", "strauss@example.com"))
def test_add_email_address_casefold(self) -> None:
self._add_email("Strauß@Example.com", "strauss@example.com")
def test_address_trim(self):
self.get_success(self._add_email(" foo@test.bar ", "foo@test.bar"))
def test_address_trim(self) -> None:
self._add_email(" foo@test.bar ", "foo@test.bar")
@override_config({"rc_3pid_validation": {"burst_count": 3}})
def test_ratelimit_by_ip(self):
def test_ratelimit_by_ip(self) -> None:
"""Tests that adding emails is ratelimited by IP"""
# We expect to be able to set three emails before getting ratelimited.
self.get_success(self._add_email("foo1@test.bar", "foo1@test.bar"))
self.get_success(self._add_email("foo2@test.bar", "foo2@test.bar"))
self.get_success(self._add_email("foo3@test.bar", "foo3@test.bar"))
self._add_email("foo1@test.bar", "foo1@test.bar")
self._add_email("foo2@test.bar", "foo2@test.bar")
self._add_email("foo3@test.bar", "foo3@test.bar")
with self.assertRaises(HttpResponseException) as cm:
self.get_success(self._add_email("foo4@test.bar", "foo4@test.bar"))
self._add_email("foo4@test.bar", "foo4@test.bar")
self.assertEqual(cm.exception.code, 429)
def test_add_email_if_disabled(self):
def test_add_email_if_disabled(self) -> None:
"""Test adding email to profile when doing so is disallowed"""
self.hs.config.registration.enable_3pid_changes = False
@@ -695,7 +711,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
},
access_token=self.user_id_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(400, channel.code, msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
# Get user
@@ -705,10 +721,10 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(200, channel.code, msg=channel.result["body"])
self.assertFalse(channel.json_body["threepids"])
def test_delete_email(self):
def test_delete_email(self) -> None:
"""Test deleting an email from profile"""
# Add a threepid
self.get_success(
@@ -727,7 +743,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
{"medium": "email", "address": self.email},
access_token=self.user_id_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(200, channel.code, msg=channel.result["body"])
# Get user
channel = self.make_request(
@@ -736,10 +752,10 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(200, channel.code, msg=channel.result["body"])
self.assertFalse(channel.json_body["threepids"])
def test_delete_email_if_disabled(self):
def test_delete_email_if_disabled(self) -> None:
"""Test deleting an email from profile when disallowed"""
self.hs.config.registration.enable_3pid_changes = False
@@ -761,7 +777,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(400, channel.code, msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
# Get user
@@ -771,11 +787,11 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(200, channel.code, msg=channel.result["body"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual(self.email, channel.json_body["threepids"][0]["address"])
def test_cant_add_email_without_clicking_link(self):
def test_cant_add_email_without_clicking_link(self) -> None:
"""Test that we do actually need to click the link in the email"""
client_secret = "foobar"
session_id = self._request_token(self.email, client_secret)
@@ -797,7 +813,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
},
access_token=self.user_id_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(400, channel.code, msg=channel.result["body"])
self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"])
# Get user
@@ -807,10 +823,10 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(200, channel.code, msg=channel.result["body"])
self.assertFalse(channel.json_body["threepids"])
def test_no_valid_token(self):
def test_no_valid_token(self) -> None:
"""Test that we do actually need to request a token and can't just
make a session up.
"""
@@ -832,7 +848,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
},
access_token=self.user_id_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(400, channel.code, msg=channel.result["body"])
self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"])
# Get user
@@ -842,11 +858,11 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(200, channel.code, msg=channel.result["body"])
self.assertFalse(channel.json_body["threepids"])
@override_config({"next_link_domain_whitelist": None})
def test_next_link(self):
def test_next_link(self) -> None:
"""Tests a valid next_link parameter value with no whitelist (good case)"""
self._request_token(
"something@example.com",
@@ -856,7 +872,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
)
@override_config({"next_link_domain_whitelist": None})
def test_next_link_exotic_protocol(self):
def test_next_link_exotic_protocol(self) -> None:
"""Tests using a esoteric protocol as a next_link parameter value.
Someone may be hosting a client on IPFS etc.
"""
@@ -868,7 +884,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
)
@override_config({"next_link_domain_whitelist": None})
def test_next_link_file_uri(self):
def test_next_link_file_uri(self) -> None:
"""Tests next_link parameters cannot be file URI"""
# Attempt to use a next_link value that points to the local disk
self._request_token(
@@ -879,7 +895,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
)
@override_config({"next_link_domain_whitelist": ["example.com", "example.org"]})
def test_next_link_domain_whitelist(self):
def test_next_link_domain_whitelist(self) -> None:
"""Tests next_link parameters must fit the whitelist if provided"""
# Ensure not providing a next_link parameter still works
@@ -912,7 +928,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
)
@override_config({"next_link_domain_whitelist": []})
def test_empty_next_link_domain_whitelist(self):
def test_empty_next_link_domain_whitelist(self) -> None:
"""Tests an empty next_lint_domain_whitelist value, meaning next_link is essentially
disallowed
"""
@@ -962,28 +978,28 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
def _request_token_invalid_email(
self,
email,
expected_errcode,
expected_error,
client_secret="foobar",
):
email: str,
expected_errcode: str,
expected_error: str,
client_secret: str = "foobar",
) -> None:
channel = self.make_request(
"POST",
b"account/3pid/email/requestToken",
{"client_secret": client_secret, "email": email, "send_attempt": 1},
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(400, channel.code, msg=channel.result["body"])
self.assertEqual(expected_errcode, channel.json_body["errcode"])
self.assertEqual(expected_error, channel.json_body["error"])
def _validate_token(self, link):
def _validate_token(self, link: str) -> None:
# Remove the host
path = link.replace("https://example.com", "")
channel = self.make_request("GET", path, shorthand=False)
self.assertEqual(200, channel.code, channel.result)
def _get_link_from_email(self):
def _get_link_from_email(self) -> str:
assert self.email_attempts, "No emails have been sent"
raw_msg = self.email_attempts[-1].decode("UTF-8")
@@ -998,12 +1014,13 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
if not text:
self.fail("Could not find text portion of email to parse")
assert text is not None
match = re.search(r"https://example.com\S+", text)
assert match, "Could not find link in email"
return match.group(0)
def _add_email(self, request_email, expected_email):
def _add_email(self, request_email: str, expected_email: str) -> None:
"""Test adding an email to profile"""
previous_email_attempts = len(self.email_attempts)
@@ -1030,7 +1047,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(200, channel.code, msg=channel.result["body"])
# Get user
channel = self.make_request(
@@ -1039,7 +1056,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(200, channel.code, msg=channel.result["body"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
threepids = {threepid["address"] for threepid in channel.json_body["threepids"]}
@@ -1055,18 +1072,18 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
url = "/_matrix/client/unstable/org.matrix.msc3720/account_status"
def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
config["experimental_features"] = {"msc3720_enabled": True}
return self.setup_test_homeserver(config=config)
def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.requester = self.register_user("requester", "password")
self.requester_tok = self.login("requester", "password")
self.server_name = homeserver.config.server.server_name
self.server_name = hs.config.server.server_name
def test_missing_mxid(self):
def test_missing_mxid(self) -> None:
"""Tests that not providing any MXID raises an error."""
self._test_status(
users=None,
@@ -1074,7 +1091,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
expected_errcode=Codes.MISSING_PARAM,
)
def test_invalid_mxid(self):
def test_invalid_mxid(self) -> None:
"""Tests that providing an invalid MXID raises an error."""
self._test_status(
users=["bad:test"],
@@ -1082,7 +1099,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
expected_errcode=Codes.INVALID_PARAM,
)
def test_local_user_not_exists(self):
def test_local_user_not_exists(self) -> None:
"""Tests that the account status endpoints correctly reports that a user doesn't
exist.
"""
@@ -1098,7 +1115,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
expected_failures=[],
)
def test_local_user_exists(self):
def test_local_user_exists(self) -> None:
"""Tests that the account status endpoint correctly reports that a user doesn't
exist.
"""
@@ -1115,7 +1132,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
expected_failures=[],
)
def test_local_user_deactivated(self):
def test_local_user_deactivated(self) -> None:
"""Tests that the account status endpoint correctly reports a deactivated user."""
user = self.register_user("someuser", "password")
self.get_success(
@@ -1135,7 +1152,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
expected_failures=[],
)
def test_mixed_local_and_remote_users(self):
def test_mixed_local_and_remote_users(self) -> None:
"""Tests that if some users are remote the account status endpoint correctly
merges the remote responses with the local result.
"""
@@ -1150,7 +1167,13 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
"@bad:badremote",
]
async def post_json(destination, path, data, *a, **kwa):
async def post_json(
destination: str,
path: str,
data: Optional[JsonDict] = None,
*a: Any,
**kwa: Any,
) -> Union[JsonDict, list]:
if destination == "remote":
return {
"account_statuses": {
@@ -1160,9 +1183,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
},
}
}
if destination == "otherremote":
return {}
if destination == "badremote":
elif destination == "badremote":
# badremote tries to overwrite the status of a user that doesn't belong
# to it (i.e. users[1]) with false data, which Synapse is expected to
# ignore.
@@ -1176,6 +1197,9 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
},
}
}
# if destination == "otherremote"
else:
return {}
# Register a mock that will return the expected result depending on the remote.
self.hs.get_federation_http_client().post_json = Mock(side_effect=post_json)
@@ -1205,7 +1229,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
expected_statuses: Optional[Dict[str, Dict[str, bool]]] = None,
expected_failures: Optional[List[str]] = None,
expected_errcode: Optional[str] = None,
):
) -> None:
"""Send a request to the account status endpoint and check that the response
matches with what's expected.
+16 -13
View File
@@ -12,10 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import Codes
from synapse.rest.client import filter
from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest
@@ -30,11 +32,11 @@ class FilterTestCase(unittest.HomeserverTestCase):
EXAMPLE_FILTER_JSON = b'{"room": {"timeline": {"types": ["m.room.message"]}}}'
servlets = [filter.register_servlets]
def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.filtering = hs.get_filtering()
self.store = hs.get_datastores().main
def test_add_filter(self):
def test_add_filter(self) -> None:
channel = self.make_request(
"POST",
"/_matrix/client/r0/user/%s/filter" % (self.user_id),
@@ -43,11 +45,13 @@ class FilterTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.result["code"], b"200")
self.assertEqual(channel.json_body, {"filter_id": "0"})
filter = self.store.get_user_filter(user_localpart="apple", filter_id=0)
filter = self.get_success(
self.store.get_user_filter(user_localpart="apple", filter_id=0)
)
self.pump()
self.assertEqual(filter.result, self.EXAMPLE_FILTER)
self.assertEqual(filter, self.EXAMPLE_FILTER)
def test_add_filter_for_other_user(self):
def test_add_filter_for_other_user(self) -> None:
channel = self.make_request(
"POST",
"/_matrix/client/r0/user/%s/filter" % ("@watermelon:test"),
@@ -57,7 +61,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.result["code"], b"403")
self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN)
def test_add_filter_non_local_user(self):
def test_add_filter_non_local_user(self) -> None:
_is_mine = self.hs.is_mine
self.hs.is_mine = lambda target_user: False
channel = self.make_request(
@@ -70,14 +74,13 @@ class FilterTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.result["code"], b"403")
self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN)
def test_get_filter(self):
filter_id = defer.ensureDeferred(
def test_get_filter(self) -> None:
filter_id = self.get_success(
self.filtering.add_user_filter(
user_localpart="apple", user_filter=self.EXAMPLE_FILTER
)
)
self.reactor.advance(1)
filter_id = filter_id.result
channel = self.make_request(
"GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.user_id, filter_id)
)
@@ -85,7 +88,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.result["code"], b"200")
self.assertEqual(channel.json_body, self.EXAMPLE_FILTER)
def test_get_filter_non_existant(self):
def test_get_filter_non_existant(self) -> None:
channel = self.make_request(
"GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.user_id)
)
@@ -95,7 +98,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
# Currently invalid params do not have an appropriate errcode
# in errors.py
def test_get_filter_invalid_id(self):
def test_get_filter_invalid_id(self) -> None:
channel = self.make_request(
"GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.user_id)
)
@@ -103,7 +106,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.result["code"], b"400")
# No ID also returns an invalid_id error
def test_get_filter_no_id(self):
def test_get_filter_no_id(self) -> None:
channel = self.make_request(
"GET", "/_matrix/client/r0/user/%s/filter/" % (self.user_id)
)
+180 -212
View File
@@ -15,7 +15,7 @@
import itertools
import urllib.parse
from typing import Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple
from unittest.mock import patch
from twisted.test.proto_helpers import MemoryReactor
@@ -34,7 +34,7 @@ from tests.test_utils import make_awaitable
from tests.test_utils.event_injection import inject_event
class RelationsTestCase(unittest.HomeserverTestCase):
class BaseRelationsTestCase(unittest.HomeserverTestCase):
servlets = [
relations.register_servlets,
room.register_servlets,
@@ -45,10 +45,9 @@ class RelationsTestCase(unittest.HomeserverTestCase):
]
hijack_auth = False
def default_config(self) -> dict:
def default_config(self) -> Dict[str, Any]:
# We need to enable msc1849 support for aggregations
config = super().default_config()
config["experimental_msc1849_support_enabled"] = True
# We enable frozen dicts as relations/edits change event contents, so we
# want to test that we don't modify the events in the caches.
@@ -67,10 +66,62 @@ class RelationsTestCase(unittest.HomeserverTestCase):
res = self.helper.send(self.room, body="Hi!", tok=self.user_token)
self.parent_id = res["event_id"]
def test_send_relation(self) -> None:
"""Tests that sending a relation using the new /send_relation works
creates the right shape of event.
def _create_user(self, localpart: str) -> Tuple[str, str]:
user_id = self.register_user(localpart, "abc123")
access_token = self.login(localpart, "abc123")
return user_id, access_token
def _send_relation(
self,
relation_type: str,
event_type: str,
key: Optional[str] = None,
content: Optional[dict] = None,
access_token: Optional[str] = None,
parent_id: Optional[str] = None,
) -> FakeChannel:
"""Helper function to send a relation pointing at `self.parent_id`
Args:
relation_type: One of `RelationTypes`
event_type: The type of the event to create
key: The aggregation key used for m.annotation relation type.
content: The content of the created event. Will be modified to configure
the m.relates_to key based on the other provided parameters.
access_token: The access token used to send the relation, defaults
to `self.user_token`
parent_id: The event_id this relation relates to. If None, then self.parent_id
Returns:
FakeChannel
"""
if not access_token:
access_token = self.user_token
original_id = parent_id if parent_id else self.parent_id
if content is None:
content = {}
content["m.relates_to"] = {
"event_id": original_id,
"rel_type": relation_type,
}
if key is not None:
content["m.relates_to"]["key"] = key
channel = self.make_request(
"POST",
f"/_matrix/client/v3/rooms/{self.room}/send/{event_type}",
content,
access_token=access_token,
)
return channel
class RelationsTestCase(BaseRelationsTestCase):
def test_send_relation(self) -> None:
"""Tests that sending a relation works."""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
self.assertEqual(200, channel.code, channel.json_body)
@@ -79,7 +130,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"GET",
"/rooms/%s/event/%s" % (self.room, event_id),
f"/rooms/{self.room}/event/{event_id}",
access_token=self.user_token,
)
self.assertEqual(200, channel.code, channel.json_body)
@@ -317,9 +368,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
# Request /sync, limiting it such that only the latest event is returned
# (and not the relation).
filter = urllib.parse.quote_plus(
'{"room": {"timeline": {"limit": 1}}}'.encode()
)
filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 1}}}')
channel = self.make_request(
"GET", f"/sync?filter={filter}", access_token=self.user_token
)
@@ -404,8 +453,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"GET",
"/_matrix/client/unstable/rooms/%s/aggregations/%s?limit=1%s"
% (self.room, self.parent_id, from_token),
f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}?limit=1{from_token}",
access_token=self.user_token,
)
self.assertEqual(200, channel.code, channel.json_body)
@@ -544,8 +592,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"GET",
"/_matrix/client/unstable/rooms/%s/aggregations/%s"
% (self.room, self.parent_id),
f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}",
access_token=self.user_token,
)
self.assertEqual(200, channel.code, channel.json_body)
@@ -560,47 +607,13 @@ class RelationsTestCase(unittest.HomeserverTestCase):
},
)
def test_aggregation_redactions(self) -> None:
"""Test that annotations get correctly aggregated after a redaction."""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
self.assertEqual(200, channel.code, channel.json_body)
to_redact_event_id = channel.json_body["event_id"]
channel = self._send_relation(
RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
)
self.assertEqual(200, channel.code, channel.json_body)
# Now lets redact one of the 'a' reactions
channel = self.make_request(
"POST",
"/_matrix/client/r0/rooms/%s/redact/%s" % (self.room, to_redact_event_id),
access_token=self.user_token,
content={},
)
self.assertEqual(200, channel.code, channel.json_body)
channel = self.make_request(
"GET",
"/_matrix/client/unstable/rooms/%s/aggregations/%s"
% (self.room, self.parent_id),
access_token=self.user_token,
)
self.assertEqual(200, channel.code, channel.json_body)
self.assertEqual(
channel.json_body,
{"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]},
)
def test_aggregation_must_be_annotation(self) -> None:
"""Test that aggregations must be annotations."""
channel = self.make_request(
"GET",
"/_matrix/client/unstable/rooms/%s/aggregations/%s/%s?limit=1"
% (self.room, self.parent_id, RelationTypes.REPLACE),
f"/_matrix/client/unstable/rooms/{self.room}/aggregations"
f"/{self.parent_id}/{RelationTypes.REPLACE}?limit=1",
access_token=self.user_token,
)
self.assertEqual(400, channel.code, channel.json_body)
@@ -691,10 +704,8 @@ class RelationsTestCase(unittest.HomeserverTestCase):
}
},
"event_id": thread_2,
"room_id": self.room,
"sender": self.user_id,
"type": "m.room.test",
"user_id": self.user_id,
},
relations_dict[RelationTypes.THREAD].get("latest_event"),
)
@@ -986,9 +997,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
# Request sync, but limit the timeline so it becomes limited (and includes
# bundled aggregations).
filter = urllib.parse.quote_plus(
'{"room": {"timeline": {"limit": 2}}}'.encode()
)
filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 2}}}')
channel = self.make_request(
"GET", f"/sync?filter={filter}", access_token=self.user_token
)
@@ -1053,7 +1062,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"GET",
"/rooms/%s/event/%s" % (self.room, self.parent_id),
f"/rooms/{self.room}/event/{self.parent_id}",
access_token=self.user_token,
)
self.assertEqual(200, channel.code, channel.json_body)
@@ -1096,7 +1105,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"GET",
"/rooms/%s/event/%s" % (self.room, reply),
f"/rooms/{self.room}/event/{reply}",
access_token=self.user_token,
)
self.assertEqual(200, channel.code, channel.json_body)
@@ -1198,7 +1207,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
# Request the original event.
channel = self.make_request(
"GET",
"/rooms/%s/event/%s" % (self.room, self.parent_id),
f"/rooms/{self.room}/event/{self.parent_id}",
access_token=self.user_token,
)
self.assertEqual(200, channel.code, channel.json_body)
@@ -1217,102 +1226,6 @@ class RelationsTestCase(unittest.HomeserverTestCase):
{"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict
)
def test_relations_redaction_redacts_edits(self) -> None:
"""Test that edits of an event are redacted when the original event
is redacted.
"""
# Send a new event
res = self.helper.send(self.room, body="Heyo!", tok=self.user_token)
original_event_id = res["event_id"]
# Add a relation
channel = self._send_relation(
RelationTypes.REPLACE,
"m.room.message",
parent_id=original_event_id,
content={
"msgtype": "m.text",
"body": "Wibble",
"m.new_content": {"msgtype": "m.text", "body": "First edit"},
},
)
self.assertEqual(200, channel.code, channel.json_body)
# Check the relation is returned
channel = self.make_request(
"GET",
"/_matrix/client/unstable/rooms/%s/relations/%s/m.replace/m.room.message"
% (self.room, original_event_id),
access_token=self.user_token,
)
self.assertEqual(200, channel.code, channel.json_body)
self.assertIn("chunk", channel.json_body)
self.assertEqual(len(channel.json_body["chunk"]), 1)
# Redact the original event
channel = self.make_request(
"PUT",
"/rooms/%s/redact/%s/%s"
% (self.room, original_event_id, "test_relations_redaction_redacts_edits"),
access_token=self.user_token,
content="{}",
)
self.assertEqual(200, channel.code, channel.json_body)
# Try to check for remaining m.replace relations
channel = self.make_request(
"GET",
"/_matrix/client/unstable/rooms/%s/relations/%s/m.replace/m.room.message"
% (self.room, original_event_id),
access_token=self.user_token,
)
self.assertEqual(200, channel.code, channel.json_body)
# Check that no relations are returned
self.assertIn("chunk", channel.json_body)
self.assertEqual(channel.json_body["chunk"], [])
def test_aggregations_redaction_prevents_access_to_aggregations(self) -> None:
"""Test that annotations of an event are redacted when the original event
is redacted.
"""
# Send a new event
res = self.helper.send(self.room, body="Hello!", tok=self.user_token)
original_event_id = res["event_id"]
# Add a relation
channel = self._send_relation(
RelationTypes.ANNOTATION, "m.reaction", key="👍", parent_id=original_event_id
)
self.assertEqual(200, channel.code, channel.json_body)
# Redact the original
channel = self.make_request(
"PUT",
"/rooms/%s/redact/%s/%s"
% (
self.room,
original_event_id,
"test_aggregations_redaction_prevents_access_to_aggregations",
),
access_token=self.user_token,
content="{}",
)
self.assertEqual(200, channel.code, channel.json_body)
# Check that aggregations returns zero
channel = self.make_request(
"GET",
"/_matrix/client/unstable/rooms/%s/aggregations/%s/m.annotation/m.reaction"
% (self.room, original_event_id),
access_token=self.user_token,
)
self.assertEqual(200, channel.code, channel.json_body)
self.assertIn("chunk", channel.json_body)
self.assertEqual(channel.json_body["chunk"], [])
def test_unknown_relations(self) -> None:
"""Unknown relations should be accepted."""
channel = self._send_relation("m.relation.test", "m.room.test")
@@ -1321,8 +1234,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"GET",
"/_matrix/client/unstable/rooms/%s/relations/%s?limit=1"
% (self.room, self.parent_id),
f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1",
access_token=self.user_token,
)
self.assertEqual(200, channel.code, channel.json_body)
@@ -1343,7 +1255,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
# When bundling the unknown relation is not included.
channel = self.make_request(
"GET",
"/rooms/%s/event/%s" % (self.room, self.parent_id),
f"/rooms/{self.room}/event/{self.parent_id}",
access_token=self.user_token,
)
self.assertEqual(200, channel.code, channel.json_body)
@@ -1352,8 +1264,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
# But unknown relations can be directly queried.
channel = self.make_request(
"GET",
"/_matrix/client/unstable/rooms/%s/aggregations/%s?limit=1"
% (self.room, self.parent_id),
f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}?limit=1",
access_token=self.user_token,
)
self.assertEqual(200, channel.code, channel.json_body)
@@ -1369,58 +1280,6 @@ class RelationsTestCase(unittest.HomeserverTestCase):
raise AssertionError(f"Event {self.parent_id} not found in chunk")
def _send_relation(
self,
relation_type: str,
event_type: str,
key: Optional[str] = None,
content: Optional[dict] = None,
access_token: Optional[str] = None,
parent_id: Optional[str] = None,
) -> FakeChannel:
"""Helper function to send a relation pointing at `self.parent_id`
Args:
relation_type: One of `RelationTypes`
event_type: The type of the event to create
key: The aggregation key used for m.annotation relation type.
content: The content of the created event. Will be modified to configure
the m.relates_to key based on the other provided parameters.
access_token: The access token used to send the relation, defaults
to `self.user_token`
parent_id: The event_id this relation relates to. If None, then self.parent_id
Returns:
FakeChannel
"""
if not access_token:
access_token = self.user_token
original_id = parent_id if parent_id else self.parent_id
if content is None:
content = {}
content["m.relates_to"] = {
"event_id": original_id,
"rel_type": relation_type,
}
if key is not None:
content["m.relates_to"]["key"] = key
channel = self.make_request(
"POST",
f"/_matrix/client/v3/rooms/{self.room}/send/{event_type}",
content,
access_token=access_token,
)
return channel
def _create_user(self, localpart: str) -> Tuple[str, str]:
user_id = self.register_user(localpart, "abc123")
access_token = self.login(localpart, "abc123")
return user_id, access_token
def test_background_update(self) -> None:
"""Test the event_arbitrary_relations background update."""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
@@ -1482,3 +1341,112 @@ class RelationsTestCase(unittest.HomeserverTestCase):
[ev["event_id"] for ev in channel.json_body["chunk"]],
[annotation_event_id_good, thread_event_id],
)
class RelationRedactionTestCase(BaseRelationsTestCase):
"""Test the behaviour of relations when the parent or child event is redacted."""
def _redact(self, event_id: str) -> None:
channel = self.make_request(
"POST",
f"/_matrix/client/r0/rooms/{self.room}/redact/{event_id}",
access_token=self.user_token,
content={},
)
self.assertEqual(200, channel.code, channel.json_body)
def test_redact_relation_annotation(self) -> None:
"""Test that annotations of an event are properly handled after the
annotation is redacted.
"""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
self.assertEqual(200, channel.code, channel.json_body)
to_redact_event_id = channel.json_body["event_id"]
channel = self._send_relation(
RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
)
self.assertEqual(200, channel.code, channel.json_body)
# Redact one of the reactions.
self._redact(to_redact_event_id)
# Ensure that the aggregations are correct.
channel = self.make_request(
"GET",
f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}",
access_token=self.user_token,
)
self.assertEqual(200, channel.code, channel.json_body)
self.assertEqual(
channel.json_body,
{"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]},
)
def test_redact_relation_edit(self) -> None:
"""Test that edits of an event are redacted when the original event
is redacted.
"""
# Add a relation
channel = self._send_relation(
RelationTypes.REPLACE,
"m.room.message",
parent_id=self.parent_id,
content={
"msgtype": "m.text",
"body": "Wibble",
"m.new_content": {"msgtype": "m.text", "body": "First edit"},
},
)
self.assertEqual(200, channel.code, channel.json_body)
# Check the relation is returned
channel = self.make_request(
"GET",
f"/_matrix/client/unstable/rooms/{self.room}/relations"
f"/{self.parent_id}/m.replace/m.room.message",
access_token=self.user_token,
)
self.assertEqual(200, channel.code, channel.json_body)
self.assertIn("chunk", channel.json_body)
self.assertEqual(len(channel.json_body["chunk"]), 1)
# Redact the original event
self._redact(self.parent_id)
# Try to check for remaining m.replace relations
channel = self.make_request(
"GET",
f"/_matrix/client/unstable/rooms/{self.room}/relations"
f"/{self.parent_id}/m.replace/m.room.message",
access_token=self.user_token,
)
self.assertEqual(200, channel.code, channel.json_body)
# Check that no relations are returned
self.assertIn("chunk", channel.json_body)
self.assertEqual(channel.json_body["chunk"], [])
def test_redact_parent(self) -> None:
"""Test that annotations of an event are redacted when the original event
is redacted.
"""
# Add a relation
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
self.assertEqual(200, channel.code, channel.json_body)
# Redact the original event.
self._redact(self.parent_id)
# Check that aggregations returns zero
channel = self.make_request(
"GET",
f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}/m.annotation/m.reaction",
access_token=self.user_token,
)
self.assertEqual(200, channel.code, channel.json_body)
self.assertIn("chunk", channel.json_body)
self.assertEqual(channel.json_body["chunk"], [])
+15 -10
View File
@@ -14,8 +14,13 @@
import json
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.rest.client import login, report_event, room
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
from tests import unittest
@@ -28,7 +33,7 @@ class ReportEventTestCase(unittest.HomeserverTestCase):
report_event.register_servlets,
]
def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
self.other_user = self.register_user("user", "pass")
@@ -42,35 +47,35 @@ class ReportEventTestCase(unittest.HomeserverTestCase):
self.event_id = resp["event_id"]
self.report_path = f"rooms/{self.room_id}/report/{self.event_id}"
def test_reason_str_and_score_int(self):
def test_reason_str_and_score_int(self) -> None:
data = {"reason": "this makes me sad", "score": -100}
self._assert_status(200, data)
def test_no_reason(self):
def test_no_reason(self) -> None:
data = {"score": 0}
self._assert_status(200, data)
def test_no_score(self):
def test_no_score(self) -> None:
data = {"reason": "this makes me sad"}
self._assert_status(200, data)
def test_no_reason_and_no_score(self):
data = {}
def test_no_reason_and_no_score(self) -> None:
data: JsonDict = {}
self._assert_status(200, data)
def test_reason_int_and_score_str(self):
def test_reason_int_and_score_str(self) -> None:
data = {"reason": 10, "score": "string"}
self._assert_status(400, data)
def test_reason_zero_and_score_blank(self):
def test_reason_zero_and_score_blank(self) -> None:
data = {"reason": 0, "score": ""}
self._assert_status(400, data)
def test_reason_and_score_null(self):
def test_reason_and_score_null(self) -> None:
data = {"reason": None, "score": None}
self._assert_status(400, data)
def _assert_status(self, response_status, data):
def _assert_status(self, response_status: int, data: JsonDict) -> None:
channel = self.make_request(
"POST",
self.report_path,
+138 -133
View File
@@ -18,11 +18,12 @@
"""Tests REST events for /rooms paths."""
import json
from typing import Iterable, List
from typing import Any, Dict, Iterable, List, Optional
from unittest.mock import Mock, call
from urllib import parse as urlparse
from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.api.constants import (
@@ -35,7 +36,9 @@ from synapse.api.errors import Codes, HttpResponseException
from synapse.handlers.pagination import PurgeStatus
from synapse.rest import admin
from synapse.rest.client import account, directory, login, profile, room, sync
from synapse.server import HomeServer
from synapse.types import JsonDict, RoomAlias, UserID, create_requester
from synapse.util import Clock
from synapse.util.stringutils import random_string
from tests import unittest
@@ -45,11 +48,11 @@ PATH_PREFIX = b"/_matrix/client/api/v1"
class RoomBase(unittest.HomeserverTestCase):
rmcreator_id = None
rmcreator_id: Optional[str] = None
servlets = [room.register_servlets, room.register_deprecated_servlets]
def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.hs = self.setup_test_homeserver(
"red",
@@ -57,15 +60,15 @@ class RoomBase(unittest.HomeserverTestCase):
federation_client=Mock(),
)
self.hs.get_federation_handler = Mock()
self.hs.get_federation_handler = Mock() # type: ignore[assignment]
self.hs.get_federation_handler.return_value.maybe_backfill = Mock(
return_value=make_awaitable(None)
)
async def _insert_client_ip(*args, **kwargs):
async def _insert_client_ip(*args: Any, **kwargs: Any) -> None:
return None
self.hs.get_datastores().main.insert_client_ip = _insert_client_ip
self.hs.get_datastores().main.insert_client_ip = _insert_client_ip # type: ignore[assignment]
return self.hs
@@ -76,7 +79,7 @@ class RoomPermissionsTestCase(RoomBase):
user_id = "@sid1:red"
rmcreator_id = "@notme:red"
def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.helper.auth_user_id = self.rmcreator_id
# create some rooms under the name rmcreator_id
@@ -108,12 +111,12 @@ class RoomPermissionsTestCase(RoomBase):
# auth as user_id now
self.helper.auth_user_id = self.user_id
def test_can_do_action(self):
def test_can_do_action(self) -> None:
msg_content = b'{"msgtype":"m.text","body":"hello"}'
seq = iter(range(100))
def send_msg_path():
def send_msg_path() -> str:
return "/rooms/%s/send/m.room.message/mid%s" % (
self.created_rmid,
str(next(seq)),
@@ -148,7 +151,7 @@ class RoomPermissionsTestCase(RoomBase):
channel = self.make_request("PUT", send_msg_path(), msg_content)
self.assertEqual(403, channel.code, msg=channel.result["body"])
def test_topic_perms(self):
def test_topic_perms(self) -> None:
topic_content = b'{"topic":"My Topic Name"}'
topic_path = "/rooms/%s/state/m.room.topic" % self.created_rmid
@@ -214,14 +217,14 @@ class RoomPermissionsTestCase(RoomBase):
self.assertEqual(403, channel.code, msg=channel.result["body"])
def _test_get_membership(
self, room=None, members: Iterable = frozenset(), expect_code=None
):
self, room: str, members: Iterable = frozenset(), expect_code: int = 200
) -> None:
for member in members:
path = "/rooms/%s/state/m.room.member/%s" % (room, member)
channel = self.make_request("GET", path)
self.assertEqual(expect_code, channel.code)
def test_membership_basic_room_perms(self):
def test_membership_basic_room_perms(self) -> None:
# === room does not exist ===
room = self.uncreated_rmid
# get membership of self, get membership of other, uncreated room
@@ -241,7 +244,7 @@ class RoomPermissionsTestCase(RoomBase):
self.helper.join(room=room, user=usr, expect_code=404)
self.helper.leave(room=room, user=usr, expect_code=404)
def test_membership_private_room_perms(self):
def test_membership_private_room_perms(self) -> None:
room = self.created_rmid
# get membership of self, get membership of other, private room + invite
# expect all 403s
@@ -264,7 +267,7 @@ class RoomPermissionsTestCase(RoomBase):
members=[self.user_id, self.rmcreator_id], room=room, expect_code=200
)
def test_membership_public_room_perms(self):
def test_membership_public_room_perms(self) -> None:
room = self.created_public_rmid
# get membership of self, get membership of other, public room + invite
# expect 403
@@ -287,7 +290,7 @@ class RoomPermissionsTestCase(RoomBase):
members=[self.user_id, self.rmcreator_id], room=room, expect_code=200
)
def test_invited_permissions(self):
def test_invited_permissions(self) -> None:
room = self.created_rmid
self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id)
@@ -310,7 +313,7 @@ class RoomPermissionsTestCase(RoomBase):
expect_code=403,
)
def test_joined_permissions(self):
def test_joined_permissions(self) -> None:
room = self.created_rmid
self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id)
self.helper.join(room=room, user=self.user_id)
@@ -348,7 +351,7 @@ class RoomPermissionsTestCase(RoomBase):
# set left of self, expect 200
self.helper.leave(room=room, user=self.user_id)
def test_leave_permissions(self):
def test_leave_permissions(self) -> None:
room = self.created_rmid
self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id)
self.helper.join(room=room, user=self.user_id)
@@ -383,7 +386,7 @@ class RoomPermissionsTestCase(RoomBase):
)
# tests the "from banned" line from the table in https://spec.matrix.org/unstable/client-server-api/#mroommember
def test_member_event_from_ban(self):
def test_member_event_from_ban(self) -> None:
room = self.created_rmid
self.helper.invite(room=room, src=self.rmcreator_id, targ=self.user_id)
self.helper.join(room=room, user=self.user_id)
@@ -475,21 +478,21 @@ class RoomsMemberListTestCase(RoomBase):
user_id = "@sid1:red"
def test_get_member_list(self):
def test_get_member_list(self) -> None:
room_id = self.helper.create_room_as(self.user_id)
channel = self.make_request("GET", "/rooms/%s/members" % room_id)
self.assertEqual(200, channel.code, msg=channel.result["body"])
def test_get_member_list_no_room(self):
def test_get_member_list_no_room(self) -> None:
channel = self.make_request("GET", "/rooms/roomdoesnotexist/members")
self.assertEqual(403, channel.code, msg=channel.result["body"])
def test_get_member_list_no_permission(self):
def test_get_member_list_no_permission(self) -> None:
room_id = self.helper.create_room_as("@some_other_guy:red")
channel = self.make_request("GET", "/rooms/%s/members" % room_id)
self.assertEqual(403, channel.code, msg=channel.result["body"])
def test_get_member_list_no_permission_with_at_token(self):
def test_get_member_list_no_permission_with_at_token(self) -> None:
"""
Tests that a stranger to the room cannot get the member list
(in the case that they use an at token).
@@ -509,7 +512,7 @@ class RoomsMemberListTestCase(RoomBase):
)
self.assertEqual(403, channel.code, msg=channel.result["body"])
def test_get_member_list_no_permission_former_member(self):
def test_get_member_list_no_permission_former_member(self) -> None:
"""
Tests that a former member of the room can not get the member list.
"""
@@ -529,7 +532,7 @@ class RoomsMemberListTestCase(RoomBase):
channel = self.make_request("GET", "/rooms/%s/members" % room_id)
self.assertEqual(403, channel.code, msg=channel.result["body"])
def test_get_member_list_no_permission_former_member_with_at_token(self):
def test_get_member_list_no_permission_former_member_with_at_token(self) -> None:
"""
Tests that a former member of the room can not get the member list
(in the case that they use an at token).
@@ -569,7 +572,7 @@ class RoomsMemberListTestCase(RoomBase):
)
self.assertEqual(403, channel.code, msg=channel.result["body"])
def test_get_member_list_mixed_memberships(self):
def test_get_member_list_mixed_memberships(self) -> None:
room_creator = "@some_other_guy:red"
room_id = self.helper.create_room_as(room_creator)
room_path = "/rooms/%s/members" % room_id
@@ -594,26 +597,26 @@ class RoomsCreateTestCase(RoomBase):
user_id = "@sid1:red"
def test_post_room_no_keys(self):
def test_post_room_no_keys(self) -> None:
# POST with no config keys, expect new room id
channel = self.make_request("POST", "/createRoom", "{}")
self.assertEqual(200, channel.code, channel.result)
self.assertTrue("room_id" in channel.json_body)
def test_post_room_visibility_key(self):
def test_post_room_visibility_key(self) -> None:
# POST with visibility config key, expect new room id
channel = self.make_request("POST", "/createRoom", b'{"visibility":"private"}')
self.assertEqual(200, channel.code)
self.assertTrue("room_id" in channel.json_body)
def test_post_room_custom_key(self):
def test_post_room_custom_key(self) -> None:
# POST with custom config keys, expect new room id
channel = self.make_request("POST", "/createRoom", b'{"custom":"stuff"}')
self.assertEqual(200, channel.code)
self.assertTrue("room_id" in channel.json_body)
def test_post_room_known_and_unknown_keys(self):
def test_post_room_known_and_unknown_keys(self) -> None:
# POST with custom + known config keys, expect new room id
channel = self.make_request(
"POST", "/createRoom", b'{"visibility":"private","custom":"things"}'
@@ -621,7 +624,7 @@ class RoomsCreateTestCase(RoomBase):
self.assertEqual(200, channel.code)
self.assertTrue("room_id" in channel.json_body)
def test_post_room_invalid_content(self):
def test_post_room_invalid_content(self) -> None:
# POST with invalid content / paths, expect 400
channel = self.make_request("POST", "/createRoom", b'{"visibili')
self.assertEqual(400, channel.code)
@@ -629,7 +632,7 @@ class RoomsCreateTestCase(RoomBase):
channel = self.make_request("POST", "/createRoom", b'["hello"]')
self.assertEqual(400, channel.code)
def test_post_room_invitees_invalid_mxid(self):
def test_post_room_invitees_invalid_mxid(self) -> None:
# POST with invalid invitee, see https://github.com/matrix-org/synapse/issues/4088
# Note the trailing space in the MXID here!
channel = self.make_request(
@@ -638,7 +641,7 @@ class RoomsCreateTestCase(RoomBase):
self.assertEqual(400, channel.code)
@unittest.override_config({"rc_invites": {"per_room": {"burst_count": 3}}})
def test_post_room_invitees_ratelimit(self):
def test_post_room_invitees_ratelimit(self) -> None:
"""Test that invites sent when creating a room are ratelimited by a RateLimiter,
which ratelimits them correctly, including by not limiting when the requester is
exempt from ratelimiting.
@@ -674,7 +677,7 @@ class RoomsCreateTestCase(RoomBase):
channel = self.make_request("POST", "/createRoom", content)
self.assertEqual(200, channel.code)
def test_spam_checker_may_join_room(self):
def test_spam_checker_may_join_room(self) -> None:
"""Tests that the user_may_join_room spam checker callback is correctly bypassed
when creating a new room.
"""
@@ -704,12 +707,12 @@ class RoomTopicTestCase(RoomBase):
user_id = "@sid1:red"
def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# create the room
self.room_id = self.helper.create_room_as(self.user_id)
self.path = "/rooms/%s/state/m.room.topic" % (self.room_id,)
def test_invalid_puts(self):
def test_invalid_puts(self) -> None:
# missing keys or invalid json
channel = self.make_request("PUT", self.path, "{}")
self.assertEqual(400, channel.code, msg=channel.result["body"])
@@ -736,7 +739,7 @@ class RoomTopicTestCase(RoomBase):
channel = self.make_request("PUT", self.path, content)
self.assertEqual(400, channel.code, msg=channel.result["body"])
def test_rooms_topic(self):
def test_rooms_topic(self) -> None:
# nothing should be there
channel = self.make_request("GET", self.path)
self.assertEqual(404, channel.code, msg=channel.result["body"])
@@ -751,7 +754,7 @@ class RoomTopicTestCase(RoomBase):
self.assertEqual(200, channel.code, msg=channel.result["body"])
self.assert_dict(json.loads(content), channel.json_body)
def test_rooms_topic_with_extra_keys(self):
def test_rooms_topic_with_extra_keys(self) -> None:
# valid put with extra keys
content = '{"topic":"Seasons","subtopic":"Summer"}'
channel = self.make_request("PUT", self.path, content)
@@ -768,10 +771,10 @@ class RoomMemberStateTestCase(RoomBase):
user_id = "@sid1:red"
def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.room_id = self.helper.create_room_as(self.user_id)
def test_invalid_puts(self):
def test_invalid_puts(self) -> None:
path = "/rooms/%s/state/m.room.member/%s" % (self.room_id, self.user_id)
# missing keys or invalid json
channel = self.make_request("PUT", path, "{}")
@@ -801,7 +804,7 @@ class RoomMemberStateTestCase(RoomBase):
channel = self.make_request("PUT", path, content.encode("ascii"))
self.assertEqual(400, channel.code, msg=channel.result["body"])
def test_rooms_members_self(self):
def test_rooms_members_self(self) -> None:
path = "/rooms/%s/state/m.room.member/%s" % (
urlparse.quote(self.room_id),
self.user_id,
@@ -812,13 +815,13 @@ class RoomMemberStateTestCase(RoomBase):
channel = self.make_request("PUT", path, content.encode("ascii"))
self.assertEqual(200, channel.code, msg=channel.result["body"])
channel = self.make_request("GET", path, None)
channel = self.make_request("GET", path, content=b"")
self.assertEqual(200, channel.code, msg=channel.result["body"])
expected_response = {"membership": Membership.JOIN}
self.assertEqual(expected_response, channel.json_body)
def test_rooms_members_other(self):
def test_rooms_members_other(self) -> None:
self.other_id = "@zzsid1:red"
path = "/rooms/%s/state/m.room.member/%s" % (
urlparse.quote(self.room_id),
@@ -830,11 +833,11 @@ class RoomMemberStateTestCase(RoomBase):
channel = self.make_request("PUT", path, content)
self.assertEqual(200, channel.code, msg=channel.result["body"])
channel = self.make_request("GET", path, None)
channel = self.make_request("GET", path, content=b"")
self.assertEqual(200, channel.code, msg=channel.result["body"])
self.assertEqual(json.loads(content), channel.json_body)
def test_rooms_members_other_custom_keys(self):
def test_rooms_members_other_custom_keys(self) -> None:
self.other_id = "@zzsid1:red"
path = "/rooms/%s/state/m.room.member/%s" % (
urlparse.quote(self.room_id),
@@ -849,7 +852,7 @@ class RoomMemberStateTestCase(RoomBase):
channel = self.make_request("PUT", path, content)
self.assertEqual(200, channel.code, msg=channel.result["body"])
channel = self.make_request("GET", path, None)
channel = self.make_request("GET", path, content=b"")
self.assertEqual(200, channel.code, msg=channel.result["body"])
self.assertEqual(json.loads(content), channel.json_body)
@@ -866,7 +869,7 @@ class RoomInviteRatelimitTestCase(RoomBase):
@unittest.override_config(
{"rc_invites": {"per_room": {"per_second": 0.5, "burst_count": 3}}}
)
def test_invites_by_rooms_ratelimit(self):
def test_invites_by_rooms_ratelimit(self) -> None:
"""Tests that invites in a room are actually rate-limited."""
room_id = self.helper.create_room_as(self.user_id)
@@ -878,7 +881,7 @@ class RoomInviteRatelimitTestCase(RoomBase):
@unittest.override_config(
{"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}}
)
def test_invites_by_users_ratelimit(self):
def test_invites_by_users_ratelimit(self) -> None:
"""Tests that invites to a specific user are actually rate-limited."""
for _ in range(3):
@@ -897,7 +900,7 @@ class RoomJoinTestCase(RoomBase):
room.register_servlets,
]
def prepare(self, reactor, clock, homeserver):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.user1 = self.register_user("thomas", "hackme")
self.tok1 = self.login("thomas", "hackme")
@@ -908,7 +911,7 @@ class RoomJoinTestCase(RoomBase):
self.room2 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1)
self.room3 = self.helper.create_room_as(room_creator=self.user1, tok=self.tok1)
def test_spam_checker_may_join_room(self):
def test_spam_checker_may_join_room(self) -> None:
"""Tests that the user_may_join_room spam checker callback is correctly called
and blocks room joins when needed.
"""
@@ -975,8 +978,8 @@ class RoomJoinRatelimitTestCase(RoomBase):
room.register_servlets,
]
def prepare(self, reactor, clock, homeserver):
super().prepare(reactor, clock, homeserver)
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
super().prepare(reactor, clock, hs)
# profile changes expect that the user is actually registered
user = UserID.from_string(self.user_id)
self.get_success(self.register_user(user.localpart, "supersecretpassword"))
@@ -984,7 +987,7 @@ class RoomJoinRatelimitTestCase(RoomBase):
@unittest.override_config(
{"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
)
def test_join_local_ratelimit(self):
def test_join_local_ratelimit(self) -> None:
"""Tests that local joins are actually rate-limited."""
for _ in range(3):
self.helper.create_room_as(self.user_id)
@@ -994,7 +997,7 @@ class RoomJoinRatelimitTestCase(RoomBase):
@unittest.override_config(
{"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
)
def test_join_local_ratelimit_profile_change(self):
def test_join_local_ratelimit_profile_change(self) -> None:
"""Tests that sending a profile update into all of the user's joined rooms isn't
rate-limited by the rate-limiter on joins."""
@@ -1031,7 +1034,7 @@ class RoomJoinRatelimitTestCase(RoomBase):
@unittest.override_config(
{"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
)
def test_join_local_ratelimit_idempotent(self):
def test_join_local_ratelimit_idempotent(self) -> None:
"""Tests that the room join endpoints remain idempotent despite rate-limiting
on room joins."""
room_id = self.helper.create_room_as(self.user_id)
@@ -1056,7 +1059,7 @@ class RoomJoinRatelimitTestCase(RoomBase):
"autocreate_auto_join_rooms": True,
},
)
def test_autojoin_rooms(self):
def test_autojoin_rooms(self) -> None:
user_id = self.register_user("testuser", "password")
# Check that the new user successfully joined the four rooms
@@ -1071,10 +1074,10 @@ class RoomMessagesTestCase(RoomBase):
user_id = "@sid1:red"
def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.room_id = self.helper.create_room_as(self.user_id)
def test_invalid_puts(self):
def test_invalid_puts(self) -> None:
path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id))
# missing keys or invalid json
channel = self.make_request("PUT", path, b"{}")
@@ -1095,7 +1098,7 @@ class RoomMessagesTestCase(RoomBase):
channel = self.make_request("PUT", path, b"")
self.assertEqual(400, channel.code, msg=channel.result["body"])
def test_rooms_messages_sent(self):
def test_rooms_messages_sent(self) -> None:
path = "/rooms/%s/send/m.room.message/mid1" % (urlparse.quote(self.room_id))
content = b'{"body":"test","msgtype":{"type":"a"}}'
@@ -1119,11 +1122,11 @@ class RoomInitialSyncTestCase(RoomBase):
user_id = "@sid1:red"
def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# create the room
self.room_id = self.helper.create_room_as(self.user_id)
def test_initial_sync(self):
def test_initial_sync(self) -> None:
channel = self.make_request("GET", "/rooms/%s/initialSync" % self.room_id)
self.assertEqual(200, channel.code)
@@ -1131,7 +1134,7 @@ class RoomInitialSyncTestCase(RoomBase):
self.assertEqual("join", channel.json_body["membership"])
# Room state is easier to assert on if we unpack it into a dict
state = {}
state: JsonDict = {}
for event in channel.json_body["state"]:
if "state_key" not in event:
continue
@@ -1160,10 +1163,10 @@ class RoomMessageListTestCase(RoomBase):
user_id = "@sid1:red"
def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.room_id = self.helper.create_room_as(self.user_id)
def test_topo_token_is_accepted(self):
def test_topo_token_is_accepted(self) -> None:
token = "t1-0_0_0_0_0_0_0_0_0"
channel = self.make_request(
"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
@@ -1174,7 +1177,7 @@ class RoomMessageListTestCase(RoomBase):
self.assertTrue("chunk" in channel.json_body)
self.assertTrue("end" in channel.json_body)
def test_stream_token_is_accepted_for_fwd_pagianation(self):
def test_stream_token_is_accepted_for_fwd_pagianation(self) -> None:
token = "s0_0_0_0_0_0_0_0_0"
channel = self.make_request(
"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
@@ -1185,7 +1188,7 @@ class RoomMessageListTestCase(RoomBase):
self.assertTrue("chunk" in channel.json_body)
self.assertTrue("end" in channel.json_body)
def test_room_messages_purge(self):
def test_room_messages_purge(self) -> None:
store = self.hs.get_datastores().main
pagination_handler = self.hs.get_pagination_handler()
@@ -1278,10 +1281,10 @@ class RoomSearchTestCase(unittest.HomeserverTestCase):
user_id = True
hijack_auth = False
def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# Register the user who does the searching
self.user_id = self.register_user("user", "pass")
self.user_id2 = self.register_user("user", "pass")
self.access_token = self.login("user", "pass")
# Register the user who sends the message
@@ -1289,12 +1292,12 @@ class RoomSearchTestCase(unittest.HomeserverTestCase):
self.other_access_token = self.login("otheruser", "pass")
# Create a room
self.room = self.helper.create_room_as(self.user_id, tok=self.access_token)
self.room = self.helper.create_room_as(self.user_id2, tok=self.access_token)
# Invite the other person
self.helper.invite(
room=self.room,
src=self.user_id,
src=self.user_id2,
tok=self.access_token,
targ=self.other_user_id,
)
@@ -1304,7 +1307,7 @@ class RoomSearchTestCase(unittest.HomeserverTestCase):
room=self.room, user=self.other_user_id, tok=self.other_access_token
)
def test_finds_message(self):
def test_finds_message(self) -> None:
"""
The search functionality will search for content in messages if asked to
do so.
@@ -1333,7 +1336,7 @@ class RoomSearchTestCase(unittest.HomeserverTestCase):
# No context was requested, so we should get none.
self.assertEqual(results["results"][0]["context"], {})
def test_include_context(self):
def test_include_context(self) -> None:
"""
When event_context includes include_profile, profile information will be
included in the search response.
@@ -1379,7 +1382,7 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]
def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.url = b"/_matrix/client/r0/publicRooms"
@@ -1389,11 +1392,11 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase):
return self.hs
def test_restricted_no_auth(self):
def test_restricted_no_auth(self) -> None:
channel = self.make_request("GET", self.url)
self.assertEqual(channel.code, 401, channel.result)
def test_restricted_auth(self):
def test_restricted_auth(self) -> None:
self.register_user("user", "pass")
tok = self.login("user", "pass")
@@ -1412,19 +1415,19 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]
def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
return self.setup_test_homeserver(federation_client=Mock())
def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.register_user("user", "pass")
self.token = self.login("user", "pass")
self.federation_client = hs.get_federation_client()
def test_simple(self):
def test_simple(self) -> None:
"Simple test for searching rooms over federation"
self.federation_client.get_public_rooms.side_effect = (
lambda *a, **k: defer.succeed({})
self.federation_client.get_public_rooms.side_effect = lambda *a, **k: defer.succeed( # type: ignore[attr-defined]
{}
)
search_filter = {"generic_search_term": "foobar"}
@@ -1437,7 +1440,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, 200, channel.result)
self.federation_client.get_public_rooms.assert_called_once_with(
self.federation_client.get_public_rooms.assert_called_once_with( # type: ignore[attr-defined]
"testserv",
limit=100,
since_token=None,
@@ -1446,12 +1449,12 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
third_party_instance_id=None,
)
def test_fallback(self):
def test_fallback(self) -> None:
"Test that searching public rooms over federation falls back if it gets a 404"
# The `get_public_rooms` should be called again if the first call fails
# with a 404, when using search filters.
self.federation_client.get_public_rooms.side_effect = (
self.federation_client.get_public_rooms.side_effect = ( # type: ignore[attr-defined]
HttpResponseException(404, "Not Found", b""),
defer.succeed({}),
)
@@ -1466,7 +1469,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, 200, channel.result)
self.federation_client.get_public_rooms.assert_has_calls(
self.federation_client.get_public_rooms.assert_has_calls( # type: ignore[attr-defined]
[
call(
"testserv",
@@ -1497,14 +1500,14 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase):
profile.register_servlets,
]
def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
config["allow_per_room_profiles"] = False
self.hs = self.setup_test_homeserver(config=config)
return self.hs
def prepare(self, reactor, clock, homeserver):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.user_id = self.register_user("test", "test")
self.tok = self.login("test", "test")
@@ -1522,7 +1525,7 @@ class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase):
self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
def test_per_room_profile_forbidden(self):
def test_per_room_profile_forbidden(self) -> None:
data = {"membership": "join", "displayname": "other test user"}
request_data = json.dumps(data)
channel = self.make_request(
@@ -1557,7 +1560,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]
def prepare(self, reactor, clock, homeserver):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.creator = self.register_user("creator", "test")
self.creator_tok = self.login("creator", "test")
@@ -1566,7 +1569,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
self.room_id = self.helper.create_room_as(self.creator, tok=self.creator_tok)
def test_join_reason(self):
def test_join_reason(self) -> None:
reason = "hello"
channel = self.make_request(
"POST",
@@ -1578,7 +1581,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
self._check_for_reason(reason)
def test_leave_reason(self):
def test_leave_reason(self) -> None:
self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok)
reason = "hello"
@@ -1592,7 +1595,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
self._check_for_reason(reason)
def test_kick_reason(self):
def test_kick_reason(self) -> None:
self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok)
reason = "hello"
@@ -1606,7 +1609,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
self._check_for_reason(reason)
def test_ban_reason(self):
def test_ban_reason(self) -> None:
self.helper.join(self.room_id, user=self.second_user_id, tok=self.second_tok)
reason = "hello"
@@ -1620,7 +1623,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
self._check_for_reason(reason)
def test_unban_reason(self):
def test_unban_reason(self) -> None:
reason = "hello"
channel = self.make_request(
"POST",
@@ -1632,7 +1635,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
self._check_for_reason(reason)
def test_invite_reason(self):
def test_invite_reason(self) -> None:
reason = "hello"
channel = self.make_request(
"POST",
@@ -1644,7 +1647,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
self._check_for_reason(reason)
def test_reject_invite_reason(self):
def test_reject_invite_reason(self) -> None:
self.helper.invite(
self.room_id,
src=self.creator,
@@ -1663,7 +1666,7 @@ class RoomMembershipReasonTestCase(unittest.HomeserverTestCase):
self._check_for_reason(reason)
def _check_for_reason(self, reason):
def _check_for_reason(self, reason: str) -> None:
channel = self.make_request(
"GET",
"/_matrix/client/r0/rooms/{}/state/m.room.member/{}".format(
@@ -1704,12 +1707,12 @@ class LabelsTestCase(unittest.HomeserverTestCase):
"org.matrix.not_labels": ["#notfun"],
}
def prepare(self, reactor, clock, homeserver):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.user_id = self.register_user("test", "test")
self.tok = self.login("test", "test")
self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
def test_context_filter_labels(self):
def test_context_filter_labels(self) -> None:
"""Test that we can filter by a label on a /context request."""
event_id = self._send_labelled_messages_in_room()
@@ -1739,7 +1742,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
events_after[0]["content"]["body"], "with right label", events_after[0]
)
def test_context_filter_not_labels(self):
def test_context_filter_not_labels(self) -> None:
"""Test that we can filter by the absence of a label on a /context request."""
event_id = self._send_labelled_messages_in_room()
@@ -1772,7 +1775,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
events_after[1]["content"]["body"], "with two wrong labels", events_after[1]
)
def test_context_filter_labels_not_labels(self):
def test_context_filter_labels_not_labels(self) -> None:
"""Test that we can filter by both a label and the absence of another label on a
/context request.
"""
@@ -1801,7 +1804,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
events_after[0]["content"]["body"], "with wrong label", events_after[0]
)
def test_messages_filter_labels(self):
def test_messages_filter_labels(self) -> None:
"""Test that we can filter by a label on a /messages request."""
self._send_labelled_messages_in_room()
@@ -1818,7 +1821,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
self.assertEqual(events[0]["content"]["body"], "with right label", events[0])
self.assertEqual(events[1]["content"]["body"], "with right label", events[1])
def test_messages_filter_not_labels(self):
def test_messages_filter_not_labels(self) -> None:
"""Test that we can filter by the absence of a label on a /messages request."""
self._send_labelled_messages_in_room()
@@ -1839,7 +1842,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
events[3]["content"]["body"], "with two wrong labels", events[3]
)
def test_messages_filter_labels_not_labels(self):
def test_messages_filter_labels_not_labels(self) -> None:
"""Test that we can filter by both a label and the absence of another label on a
/messages request.
"""
@@ -1862,7 +1865,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(events), 1, [event["content"] for event in events])
self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0])
def test_search_filter_labels(self):
def test_search_filter_labels(self) -> None:
"""Test that we can filter by a label on a /search request."""
request_data = json.dumps(
{
@@ -1899,7 +1902,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
results[1]["result"]["content"]["body"],
)
def test_search_filter_not_labels(self):
def test_search_filter_not_labels(self) -> None:
"""Test that we can filter by the absence of a label on a /search request."""
request_data = json.dumps(
{
@@ -1946,7 +1949,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
results[3]["result"]["content"]["body"],
)
def test_search_filter_labels_not_labels(self):
def test_search_filter_labels_not_labels(self) -> None:
"""Test that we can filter by both a label and the absence of another label on a
/search request.
"""
@@ -1980,7 +1983,7 @@ class LabelsTestCase(unittest.HomeserverTestCase):
results[0]["result"]["content"]["body"],
)
def _send_labelled_messages_in_room(self):
def _send_labelled_messages_in_room(self) -> str:
"""Sends several messages to a room with different labels (or without any) to test
filtering by label.
Returns:
@@ -2056,12 +2059,12 @@ class RelationsTestCase(unittest.HomeserverTestCase):
login.register_servlets,
]
def default_config(self):
def default_config(self) -> Dict[str, Any]:
config = super().default_config()
config["experimental_features"] = {"msc3440_enabled": True}
return config
def prepare(self, reactor, clock, homeserver):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.user_id = self.register_user("test", "test")
self.tok = self.login("test", "test")
self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
@@ -2136,7 +2139,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
return channel.json_body["chunk"]
def test_filter_relation_senders(self):
def test_filter_relation_senders(self) -> None:
# Messages which second user reacted to.
filter = {"io.element.relation_senders": [self.second_user_id]}
chunk = self._filter_messages(filter)
@@ -2159,7 +2162,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
[c["event_id"] for c in chunk], [self.event_id_1, self.event_id_2]
)
def test_filter_relation_type(self):
def test_filter_relation_type(self) -> None:
# Messages which have annotations.
filter = {"io.element.relation_types": [RelationTypes.ANNOTATION]}
chunk = self._filter_messages(filter)
@@ -2185,7 +2188,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
[c["event_id"] for c in chunk], [self.event_id_1, self.event_id_2]
)
def test_filter_relation_senders_and_type(self):
def test_filter_relation_senders_and_type(self) -> None:
# Messages which second user reacted to.
filter = {
"io.element.relation_senders": [self.second_user_id],
@@ -2205,7 +2208,7 @@ class ContextTestCase(unittest.HomeserverTestCase):
account.register_servlets,
]
def prepare(self, reactor, clock, homeserver):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.user_id = self.register_user("user", "password")
self.tok = self.login("user", "password")
self.room_id = self.helper.create_room_as(
@@ -2218,7 +2221,7 @@ class ContextTestCase(unittest.HomeserverTestCase):
self.helper.invite(self.room_id, self.user_id, self.other_user_id, tok=self.tok)
self.helper.join(self.room_id, self.other_user_id, tok=self.other_tok)
def test_erased_sender(self):
def test_erased_sender(self) -> None:
"""Test that an erasure request results in the requester's events being hidden
from any new member of the room.
"""
@@ -2332,7 +2335,7 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase):
room.register_servlets,
]
def prepare(self, reactor, clock, homeserver):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.room_owner = self.register_user("room_owner", "test")
self.room_owner_tok = self.login("room_owner", "test")
@@ -2340,17 +2343,17 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase):
self.room_owner, tok=self.room_owner_tok
)
def test_no_aliases(self):
def test_no_aliases(self) -> None:
res = self._get_aliases(self.room_owner_tok)
self.assertEqual(res["aliases"], [])
def test_not_in_room(self):
def test_not_in_room(self) -> None:
self.register_user("user", "test")
user_tok = self.login("user", "test")
res = self._get_aliases(user_tok, expected_code=403)
self.assertEqual(res["errcode"], "M_FORBIDDEN")
def test_admin_user(self):
def test_admin_user(self) -> None:
alias1 = self._random_alias()
self._set_alias_via_directory(alias1)
@@ -2360,7 +2363,7 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase):
res = self._get_aliases(user_tok)
self.assertEqual(res["aliases"], [alias1])
def test_with_aliases(self):
def test_with_aliases(self) -> None:
alias1 = self._random_alias()
alias2 = self._random_alias()
@@ -2370,7 +2373,7 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase):
res = self._get_aliases(self.room_owner_tok)
self.assertEqual(set(res["aliases"]), {alias1, alias2})
def test_peekable_room(self):
def test_peekable_room(self) -> None:
alias1 = self._random_alias()
self._set_alias_via_directory(alias1)
@@ -2404,7 +2407,7 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase):
def _random_alias(self) -> str:
return RoomAlias(random_string(5), self.hs.hostname).to_string()
def _set_alias_via_directory(self, alias: str, expected_code: int = 200):
def _set_alias_via_directory(self, alias: str, expected_code: int = 200) -> None:
url = "/_matrix/client/r0/directory/room/" + alias
data = {"room_id": self.room_id}
request_data = json.dumps(data)
@@ -2423,7 +2426,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
room.register_servlets,
]
def prepare(self, reactor, clock, homeserver):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.room_owner = self.register_user("room_owner", "test")
self.room_owner_tok = self.login("room_owner", "test")
@@ -2434,7 +2437,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
self.alias = "#alias:test"
self._set_alias_via_directory(self.alias)
def _set_alias_via_directory(self, alias: str, expected_code: int = 200):
def _set_alias_via_directory(self, alias: str, expected_code: int = 200) -> None:
url = "/_matrix/client/r0/directory/room/" + alias
data = {"room_id": self.room_id}
request_data = json.dumps(data)
@@ -2456,7 +2459,9 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
self.assertIsInstance(res, dict)
return res
def _set_canonical_alias(self, content: str, expected_code: int = 200) -> JsonDict:
def _set_canonical_alias(
self, content: JsonDict, expected_code: int = 200
) -> JsonDict:
"""Calls the endpoint under test. returns the json response object."""
channel = self.make_request(
"PUT",
@@ -2469,7 +2474,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
self.assertIsInstance(res, dict)
return res
def test_canonical_alias(self):
def test_canonical_alias(self) -> None:
"""Test a basic alias message."""
# There is no canonical alias to start with.
self._get_canonical_alias(expected_code=404)
@@ -2488,7 +2493,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
res = self._get_canonical_alias()
self.assertEqual(res, {})
def test_alt_aliases(self):
def test_alt_aliases(self) -> None:
"""Test a canonical alias message with alt_aliases."""
# Create an alias.
self._set_canonical_alias({"alt_aliases": [self.alias]})
@@ -2504,7 +2509,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
res = self._get_canonical_alias()
self.assertEqual(res, {})
def test_alias_alt_aliases(self):
def test_alias_alt_aliases(self) -> None:
"""Test a canonical alias message with an alias and alt_aliases."""
# Create an alias.
self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]})
@@ -2520,7 +2525,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
res = self._get_canonical_alias()
self.assertEqual(res, {})
def test_partial_modify(self):
def test_partial_modify(self) -> None:
"""Test removing only the alt_aliases."""
# Create an alias.
self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]})
@@ -2536,7 +2541,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
res = self._get_canonical_alias()
self.assertEqual(res, {"alias": self.alias})
def test_add_alias(self):
def test_add_alias(self) -> None:
"""Test removing only the alt_aliases."""
# Create an additional alias.
second_alias = "#second:test"
@@ -2556,7 +2561,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
res, {"alias": self.alias, "alt_aliases": [self.alias, second_alias]}
)
def test_bad_data(self):
def test_bad_data(self) -> None:
"""Invalid data for alt_aliases should cause errors."""
self._set_canonical_alias({"alt_aliases": "@bad:test"}, expected_code=400)
self._set_canonical_alias({"alt_aliases": None}, expected_code=400)
@@ -2566,7 +2571,7 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
self._set_canonical_alias({"alt_aliases": True}, expected_code=400)
self._set_canonical_alias({"alt_aliases": {}}, expected_code=400)
def test_bad_alias(self):
def test_bad_alias(self) -> None:
"""An alias which does not point to the room raises a SynapseError."""
self._set_canonical_alias({"alias": "@unknown:test"}, expected_code=400)
self._set_canonical_alias({"alt_aliases": ["@unknown:test"]}, expected_code=400)
@@ -2580,13 +2585,13 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
room.register_servlets,
]
def prepare(self, reactor, clock, homeserver):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.user_id = self.register_user("thomas", "hackme")
self.tok = self.login("thomas", "hackme")
self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
def test_threepid_invite_spamcheck(self):
def test_threepid_invite_spamcheck(self) -> None:
# Mock a few functions to prevent the test from failing due to failing to talk to
# a remote IS. We keep the mock for _mock_make_and_store_3pid_invite around so we
# can check its call_count later on during the test.
+69 -39
View File
@@ -12,16 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
from typing import TYPE_CHECKING, Dict, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
from unittest.mock import Mock
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes, LoginType, Membership
from synapse.api.errors import SynapseError
from synapse.api.room_versions import RoomVersion
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.events.third_party_rules import load_legacy_third_party_event_rules
from synapse.rest import admin
from synapse.rest.client import account, login, profile, room
from synapse.server import HomeServer
from synapse.types import JsonDict, Requester, StateMap
from synapse.util import Clock
from synapse.util.frozenutils import unfreeze
from tests import unittest
@@ -34,7 +40,7 @@ thread_local = threading.local()
class LegacyThirdPartyRulesTestModule:
def __init__(self, config: Dict, module_api: "ModuleApi"):
def __init__(self, config: Dict, module_api: "ModuleApi") -> None:
# keep a record of the "current" rules module, so that the test can patch
# it if desired.
thread_local.rules_module = self
@@ -42,32 +48,36 @@ class LegacyThirdPartyRulesTestModule:
async def on_create_room(
self, requester: Requester, config: dict, is_requester_admin: bool
):
) -> bool:
return True
async def check_event_allowed(self, event: EventBase, state: StateMap[EventBase]):
async def check_event_allowed(
self, event: EventBase, state: StateMap[EventBase]
) -> Union[bool, dict]:
return True
@staticmethod
def parse_config(config):
def parse_config(config: Dict[str, Any]) -> Dict[str, Any]:
return config
class LegacyDenyNewRooms(LegacyThirdPartyRulesTestModule):
def __init__(self, config: Dict, module_api: "ModuleApi"):
def __init__(self, config: Dict, module_api: "ModuleApi") -> None:
super().__init__(config, module_api)
def on_create_room(
async def on_create_room(
self, requester: Requester, config: dict, is_requester_admin: bool
):
) -> bool:
return False
class LegacyChangeEvents(LegacyThirdPartyRulesTestModule):
def __init__(self, config: Dict, module_api: "ModuleApi"):
def __init__(self, config: Dict, module_api: "ModuleApi") -> None:
super().__init__(config, module_api)
async def check_event_allowed(self, event: EventBase, state: StateMap[EventBase]):
async def check_event_allowed(
self, event: EventBase, state: StateMap[EventBase]
) -> JsonDict:
d = event.get_dict()
content = unfreeze(event.content)
content["foo"] = "bar"
@@ -84,7 +94,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
account.register_servlets,
]
def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver()
load_legacy_third_party_event_rules(hs)
@@ -94,22 +104,30 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
# Note that these checks are not relevant to this test case.
# Have this homeserver auto-approve all event signature checking.
async def approve_all_signature_checking(_, pdu):
async def approve_all_signature_checking(
_: RoomVersion, pdu: EventBase
) -> EventBase:
return pdu
hs.get_federation_server()._check_sigs_and_hash = approve_all_signature_checking
hs.get_federation_server()._check_sigs_and_hash = approve_all_signature_checking # type: ignore[assignment]
# Have this homeserver skip event auth checks. This is necessary due to
# event auth checks ensuring that events were signed by the sender's homeserver.
async def _check_event_auth(origin, event, context, *args, **kwargs):
async def _check_event_auth(
origin: str,
event: EventBase,
context: EventContext,
*args: Any,
**kwargs: Any,
) -> EventContext:
return context
hs.get_federation_event_handler()._check_event_auth = _check_event_auth
hs.get_federation_event_handler()._check_event_auth = _check_event_auth # type: ignore[assignment]
return hs
def prepare(self, reactor, clock, homeserver):
super().prepare(reactor, clock, homeserver)
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
super().prepare(reactor, clock, hs)
# Create some users and a room to play with during the tests
self.user_id = self.register_user("kermit", "monkey")
self.invitee = self.register_user("invitee", "hackme")
@@ -121,13 +139,15 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
except Exception:
pass
def test_third_party_rules(self):
def test_third_party_rules(self) -> None:
"""Tests that a forbidden event is forbidden from being sent, but an allowed one
can be sent.
"""
# patch the rules module with a Mock which will return False for some event
# types
async def check(ev, state):
async def check(
ev: EventBase, state: StateMap[EventBase]
) -> Tuple[bool, Optional[JsonDict]]:
return ev.type != "foo.bar.forbidden", None
callback = Mock(spec=[], side_effect=check)
@@ -161,7 +181,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
)
self.assertEqual(channel.result["code"], b"403", channel.result)
def test_third_party_rules_workaround_synapse_errors_pass_through(self):
def test_third_party_rules_workaround_synapse_errors_pass_through(self) -> None:
"""
Tests that the workaround introduced by https://github.com/matrix-org/synapse/pull/11042
is functional: that SynapseErrors are passed through from check_event_allowed
@@ -172,7 +192,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
"""
class NastyHackException(SynapseError):
def error_dict(self):
def error_dict(self) -> JsonDict:
"""
This overrides SynapseError's `error_dict` to nastily inject
JSON into the error response.
@@ -182,7 +202,9 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
return result
# add a callback that will raise our hacky exception
async def check(ev, state) -> Tuple[bool, Optional[JsonDict]]:
async def check(
ev: EventBase, state: StateMap[EventBase]
) -> Tuple[bool, Optional[JsonDict]]:
raise NastyHackException(429, "message")
self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check]
@@ -202,11 +224,13 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
{"errcode": "M_UNKNOWN", "error": "message", "nasty": "very"},
)
def test_cannot_modify_event(self):
def test_cannot_modify_event(self) -> None:
"""cannot accidentally modify an event before it is persisted"""
# first patch the event checker so that it will try to modify the event
async def check(ev: EventBase, state):
async def check(
ev: EventBase, state: StateMap[EventBase]
) -> Tuple[bool, Optional[JsonDict]]:
ev.content = {"x": "y"}
return True, None
@@ -223,10 +247,12 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
# 500 Internal Server Error
self.assertEqual(channel.code, 500, channel.result)
def test_modify_event(self):
def test_modify_event(self) -> None:
"""The module can return a modified version of the event"""
# first patch the event checker so that it will modify the event
async def check(ev: EventBase, state):
async def check(
ev: EventBase, state: StateMap[EventBase]
) -> Tuple[bool, Optional[JsonDict]]:
d = ev.get_dict()
d["content"] = {"x": "y"}
return True, d
@@ -253,10 +279,12 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
ev = channel.json_body
self.assertEqual(ev["content"]["x"], "y")
def test_message_edit(self):
def test_message_edit(self) -> None:
"""Ensure that the module doesn't cause issues with edited messages."""
# first patch the event checker so that it will modify the event
async def check(ev: EventBase, state):
async def check(
ev: EventBase, state: StateMap[EventBase]
) -> Tuple[bool, Optional[JsonDict]]:
d = ev.get_dict()
d["content"] = {
"msgtype": "m.text",
@@ -315,7 +343,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
ev = channel.json_body
self.assertEqual(ev["content"]["body"], "EDITED BODY")
def test_send_event(self):
def test_send_event(self) -> None:
"""Tests that a module can send an event into a room via the module api"""
content = {
"msgtype": "m.text",
@@ -344,7 +372,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
}
}
)
def test_legacy_check_event_allowed(self):
def test_legacy_check_event_allowed(self) -> None:
"""Tests that the wrapper for legacy check_event_allowed callbacks works
correctly.
"""
@@ -379,13 +407,13 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
}
}
)
def test_legacy_on_create_room(self):
def test_legacy_on_create_room(self) -> None:
"""Tests that the wrapper for legacy on_create_room callbacks works
correctly.
"""
self.helper.create_room_as(self.user_id, tok=self.tok, expect_code=403)
def test_sent_event_end_up_in_room_state(self):
def test_sent_event_end_up_in_room_state(self) -> None:
"""Tests that a state event sent by a module while processing another state event
doesn't get dropped from the state of the room. This is to guard against a bug
where Synapse has been observed doing so, see https://github.com/matrix-org/synapse/issues/10830
@@ -400,7 +428,9 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
api = self.hs.get_module_api()
# Define a callback that sends a custom event on power levels update.
async def test_fn(event: EventBase, state_events):
async def test_fn(
event: EventBase, state_events: StateMap[EventBase]
) -> Tuple[bool, Optional[JsonDict]]:
if event.is_state and event.type == EventTypes.PowerLevels:
await api.create_and_send_event_into_room(
{
@@ -436,7 +466,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body["i"], i)
def test_on_new_event(self):
def test_on_new_event(self) -> None:
"""Test that the on_new_event callback is called on new events"""
on_new_event = Mock(make_awaitable(None))
self.hs.get_third_party_event_rules()._on_new_event_callbacks.append(
@@ -501,7 +531,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
self.assertEqual(channel.code, 200, channel.result)
def _update_power_levels(self, event_default: int = 0):
def _update_power_levels(self, event_default: int = 0) -> None:
"""Updates the room's power levels.
Args:
@@ -533,7 +563,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
tok=self.tok,
)
def test_on_profile_update(self):
def test_on_profile_update(self) -> None:
"""Tests that the on_profile_update module callback is correctly called on
profile updates.
"""
@@ -592,7 +622,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
self.assertEqual(profile_info.display_name, displayname)
self.assertEqual(profile_info.avatar_url, avatar_url)
def test_on_profile_update_admin(self):
def test_on_profile_update_admin(self) -> None:
"""Tests that the on_profile_update module callback is correctly called on
profile updates triggered by a server admin.
"""
@@ -634,7 +664,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
self.assertEqual(profile_info.display_name, displayname)
self.assertEqual(profile_info.avatar_url, avatar_url)
def test_on_user_deactivation_status_changed(self):
def test_on_user_deactivation_status_changed(self) -> None:
"""Tests that the on_user_deactivation_status_changed module callback is called
correctly when processing a user's deactivation.
"""
@@ -691,7 +721,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
args = profile_mock.call_args[0]
self.assertTrue(args[3])
def test_on_user_deactivation_status_changed_admin(self):
def test_on_user_deactivation_status_changed_admin(self) -> None:
"""Tests that the on_user_deactivation_status_changed module callback is called
correctly when processing a user's deactivation triggered by a server admin as
well as a reactivation.
+9 -30
View File
@@ -15,10 +15,12 @@
"""Tests REST events for /rooms paths."""
from unittest.mock import Mock
from twisted.test.proto_helpers import MemoryReactor
from synapse.rest.client import room
from synapse.server import HomeServer
from synapse.types import UserID
from synapse.util import Clock
from tests import unittest
@@ -33,40 +35,17 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
user = UserID.from_string(user_id)
servlets = [room.register_servlets]
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(
"red",
federation_http_client=None,
federation_client=Mock(),
)
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver("red")
self.event_source = hs.get_event_sources().sources.typing
hs.get_federation_handler = Mock()
async def get_user_by_access_token(token=None, allow_guest=False):
return {
"user": UserID.from_string(self.auth_user_id),
"token_id": 1,
"is_guest": False,
}
hs.get_auth().get_user_by_access_token = get_user_by_access_token
async def _insert_client_ip(*args, **kwargs):
return None
hs.get_datastores().main.insert_client_ip = _insert_client_ip
return hs
def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.room_id = self.helper.create_room_as(self.user_id)
# Need another user to make notifications actually work
self.helper.join(self.room_id, user="@jim:red")
def test_set_typing(self):
def test_set_typing(self) -> None:
channel = self.make_request(
"PUT",
"/rooms/%s/typing/%s" % (self.room_id, self.user_id),
@@ -95,7 +74,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
],
)
def test_set_not_typing(self):
def test_set_not_typing(self) -> None:
channel = self.make_request(
"PUT",
"/rooms/%s/typing/%s" % (self.room_id, self.user_id),
@@ -103,7 +82,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(200, channel.code)
def test_typing_timeout(self):
def test_typing_timeout(self) -> None:
channel = self.make_request(
"PUT",
"/rooms/%s/typing/%s" % (self.room_id, self.user_id),
+29 -15
View File
@@ -13,19 +13,24 @@
# limitations under the License.
import urllib.parse
from io import BytesIO, StringIO
from typing import Any, Dict, Optional, Union
from unittest.mock import Mock
import signedjson.key
from canonicaljson import encode_canonical_json
from nacl.signing import SigningKey
from signedjson.sign import sign_json
from signedjson.types import SigningKey
from twisted.web.resource import NoResource
from twisted.test.proto_helpers import MemoryReactor
from twisted.web.resource import NoResource, Resource
from synapse.crypto.keyring import PerspectivesKeyFetcher
from synapse.http.site import SynapseRequest
from synapse.rest.key.v2 import KeyApiV2Resource
from synapse.server import HomeServer
from synapse.storage.keys import FetchKeyResult
from synapse.types import JsonDict
from synapse.util import Clock
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.stringutils import random_string
@@ -35,11 +40,11 @@ from tests.utils import default_config
class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.http_client = Mock()
return self.setup_test_homeserver(federation_http_client=self.http_client)
def create_test_resource(self):
def create_test_resource(self) -> Resource:
return create_resource_tree(
{"/_matrix/key/v2": KeyApiV2Resource(self.hs)}, root_resource=NoResource()
)
@@ -51,7 +56,12 @@ class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase):
Tell the mock http client to expect an outgoing GET request for the given key
"""
async def get_json(destination, path, ignore_backoff=False, **kwargs):
async def get_json(
destination: str,
path: str,
ignore_backoff: bool = False,
**kwargs: Any,
) -> Union[JsonDict, list]:
self.assertTrue(ignore_backoff)
self.assertEqual(destination, server_name)
key_id = "%s:%s" % (signing_key.alg, signing_key.version)
@@ -84,7 +94,8 @@ class RemoteKeyResourceTestCase(BaseRemoteKeyResourceTestCase):
Checks that the response is a 200 and returns the decoded json body.
"""
channel = FakeChannel(self.site, self.reactor)
req = SynapseRequest(channel, self.site)
# channel is a `FakeChannel` but `HTTPChannel` is expected
req = SynapseRequest(channel, self.site) # type: ignore[arg-type]
req.content = BytesIO(b"")
req.requestReceived(
b"GET",
@@ -97,7 +108,7 @@ class RemoteKeyResourceTestCase(BaseRemoteKeyResourceTestCase):
resp = channel.json_body
return resp
def test_get_key(self):
def test_get_key(self) -> None:
"""Fetch a remote key"""
SERVER_NAME = "remote.server"
testkey = signedjson.key.generate_signing_key("ver1")
@@ -114,7 +125,7 @@ class RemoteKeyResourceTestCase(BaseRemoteKeyResourceTestCase):
self.assertIn(SERVER_NAME, keys[0]["signatures"])
self.assertIn(self.hs.hostname, keys[0]["signatures"])
def test_get_own_key(self):
def test_get_own_key(self) -> None:
"""Fetch our own key"""
testkey = signedjson.key.generate_signing_key("ver1")
self.expect_outgoing_key_request(self.hs.hostname, testkey)
@@ -141,7 +152,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
endpoint, to check that the two implementations are compatible.
"""
def default_config(self):
def default_config(self) -> Dict[str, Any]:
config = super().default_config()
# replace the signing key with our own
@@ -152,7 +163,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
return config
def prepare(self, reactor, clock, homeserver):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# make a second homeserver, configured to use the first one as a key notary
self.http_client2 = Mock()
config = default_config(name="keyclient")
@@ -175,7 +186,9 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
# wire up outbound POST /key/v2/query requests from hs2 so that they
# will be forwarded to hs1
async def post_json(destination, path, data):
async def post_json(
destination: str, path: str, data: Optional[JsonDict] = None
) -> Union[JsonDict, list]:
self.assertEqual(destination, self.hs.hostname)
self.assertEqual(
path,
@@ -183,7 +196,8 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
)
channel = FakeChannel(self.site, self.reactor)
req = SynapseRequest(channel, self.site)
# channel is a `FakeChannel` but `HTTPChannel` is expected
req = SynapseRequest(channel, self.site) # type: ignore[arg-type]
req.content = BytesIO(encode_canonical_json(data))
req.requestReceived(
@@ -198,7 +212,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
self.http_client2.post_json.side_effect = post_json
def test_get_key(self):
def test_get_key(self) -> None:
"""Fetch a key belonging to a random server"""
# make up a key to be fetched.
testkey = signedjson.key.generate_signing_key("abc")
@@ -218,7 +232,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
signedjson.key.encode_verify_key_base64(testkey.verify_key),
)
def test_get_notary_key(self):
def test_get_notary_key(self) -> None:
"""Fetch a key belonging to the notary server"""
# make up a key to be fetched. We randomise the keyid to try to get it to
# appear before the key server signing key sometimes (otherwise we bail out
@@ -240,7 +254,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
signedjson.key.encode_verify_key_base64(testkey.verify_key),
)
def test_get_notary_keyserver_key(self):
def test_get_notary_keyserver_key(self) -> None:
"""Fetch the notary's keyserver key"""
# we expect hs1 to make a regular key request to itself
self.expect_outgoing_key_request(self.hs.hostname, self.hs_signing_key)
+2 -2
View File
@@ -28,11 +28,11 @@ class GetFileNameFromHeadersTests(unittest.TestCase):
b"inline; filename*=utf-8''foo%C2%A3bar": "foo£bar",
}
def tests(self):
def tests(self) -> None:
for hdr, expected in self.TEST_CASES.items():
res = get_filename_from_headers({b"Content-Disposition": [hdr]})
self.assertEqual(
res,
expected,
"expected output for %s to be %s but was %s" % (hdr, expected, res),
f"expected output for {hdr!r} to be {expected} but was {res}",
)
+24 -24
View File
@@ -21,12 +21,12 @@ from tests import unittest
class MediaFilePathsTestCase(unittest.TestCase):
def setUp(self):
def setUp(self) -> None:
super().setUp()
self.filepaths = MediaFilePaths("/media_store")
def test_local_media_filepath(self):
def test_local_media_filepath(self) -> None:
"""Test local media paths"""
self.assertEqual(
self.filepaths.local_media_filepath_rel("GerZNDnDZVjsOtardLuwfIBg"),
@@ -37,7 +37,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
"/media_store/local_content/Ge/rZ/NDnDZVjsOtardLuwfIBg",
)
def test_local_media_thumbnail(self):
def test_local_media_thumbnail(self) -> None:
"""Test local media thumbnail paths"""
self.assertEqual(
self.filepaths.local_media_thumbnail_rel(
@@ -52,14 +52,14 @@ class MediaFilePathsTestCase(unittest.TestCase):
"/media_store/local_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale",
)
def test_local_media_thumbnail_dir(self):
def test_local_media_thumbnail_dir(self) -> None:
"""Test local media thumbnail directory paths"""
self.assertEqual(
self.filepaths.local_media_thumbnail_dir("GerZNDnDZVjsOtardLuwfIBg"),
"/media_store/local_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg",
)
def test_remote_media_filepath(self):
def test_remote_media_filepath(self) -> None:
"""Test remote media paths"""
self.assertEqual(
self.filepaths.remote_media_filepath_rel(
@@ -74,7 +74,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
"/media_store/remote_content/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg",
)
def test_remote_media_thumbnail(self):
def test_remote_media_thumbnail(self) -> None:
"""Test remote media thumbnail paths"""
self.assertEqual(
self.filepaths.remote_media_thumbnail_rel(
@@ -99,7 +99,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
"/media_store/remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale",
)
def test_remote_media_thumbnail_legacy(self):
def test_remote_media_thumbnail_legacy(self) -> None:
"""Test old-style remote media thumbnail paths"""
self.assertEqual(
self.filepaths.remote_media_thumbnail_rel_legacy(
@@ -108,7 +108,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
"remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg",
)
def test_remote_media_thumbnail_dir(self):
def test_remote_media_thumbnail_dir(self) -> None:
"""Test remote media thumbnail directory paths"""
self.assertEqual(
self.filepaths.remote_media_thumbnail_dir(
@@ -117,7 +117,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
"/media_store/remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg",
)
def test_url_cache_filepath(self):
def test_url_cache_filepath(self) -> None:
"""Test URL cache paths"""
self.assertEqual(
self.filepaths.url_cache_filepath_rel("2020-01-02_GerZNDnDZVjsOtar"),
@@ -128,7 +128,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
"/media_store/url_cache/2020-01-02/GerZNDnDZVjsOtar",
)
def test_url_cache_filepath_legacy(self):
def test_url_cache_filepath_legacy(self) -> None:
"""Test old-style URL cache paths"""
self.assertEqual(
self.filepaths.url_cache_filepath_rel("GerZNDnDZVjsOtardLuwfIBg"),
@@ -139,7 +139,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
"/media_store/url_cache/Ge/rZ/NDnDZVjsOtardLuwfIBg",
)
def test_url_cache_filepath_dirs_to_delete(self):
def test_url_cache_filepath_dirs_to_delete(self) -> None:
"""Test URL cache cleanup paths"""
self.assertEqual(
self.filepaths.url_cache_filepath_dirs_to_delete(
@@ -148,7 +148,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
["/media_store/url_cache/2020-01-02"],
)
def test_url_cache_filepath_dirs_to_delete_legacy(self):
def test_url_cache_filepath_dirs_to_delete_legacy(self) -> None:
"""Test old-style URL cache cleanup paths"""
self.assertEqual(
self.filepaths.url_cache_filepath_dirs_to_delete(
@@ -160,7 +160,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
],
)
def test_url_cache_thumbnail(self):
def test_url_cache_thumbnail(self) -> None:
"""Test URL cache thumbnail paths"""
self.assertEqual(
self.filepaths.url_cache_thumbnail_rel(
@@ -175,7 +175,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
"/media_store/url_cache_thumbnails/2020-01-02/GerZNDnDZVjsOtar/800-600-image-jpeg-scale",
)
def test_url_cache_thumbnail_legacy(self):
def test_url_cache_thumbnail_legacy(self) -> None:
"""Test old-style URL cache thumbnail paths"""
self.assertEqual(
self.filepaths.url_cache_thumbnail_rel(
@@ -190,7 +190,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
"/media_store/url_cache_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale",
)
def test_url_cache_thumbnail_directory(self):
def test_url_cache_thumbnail_directory(self) -> None:
"""Test URL cache thumbnail directory paths"""
self.assertEqual(
self.filepaths.url_cache_thumbnail_directory_rel(
@@ -203,7 +203,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
"/media_store/url_cache_thumbnails/2020-01-02/GerZNDnDZVjsOtar",
)
def test_url_cache_thumbnail_directory_legacy(self):
def test_url_cache_thumbnail_directory_legacy(self) -> None:
"""Test old-style URL cache thumbnail directory paths"""
self.assertEqual(
self.filepaths.url_cache_thumbnail_directory_rel(
@@ -216,7 +216,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
"/media_store/url_cache_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg",
)
def test_url_cache_thumbnail_dirs_to_delete(self):
def test_url_cache_thumbnail_dirs_to_delete(self) -> None:
"""Test URL cache thumbnail cleanup paths"""
self.assertEqual(
self.filepaths.url_cache_thumbnail_dirs_to_delete(
@@ -228,7 +228,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
],
)
def test_url_cache_thumbnail_dirs_to_delete_legacy(self):
def test_url_cache_thumbnail_dirs_to_delete_legacy(self) -> None:
"""Test old-style URL cache thumbnail cleanup paths"""
self.assertEqual(
self.filepaths.url_cache_thumbnail_dirs_to_delete(
@@ -241,7 +241,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
],
)
def test_server_name_validation(self):
def test_server_name_validation(self) -> None:
"""Test validation of server names"""
self._test_path_validation(
[
@@ -274,7 +274,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
],
)
def test_file_id_validation(self):
def test_file_id_validation(self) -> None:
"""Test validation of local, remote and legacy URL cache file / media IDs"""
# File / media IDs get split into three parts to form paths, consisting of the
# first two characters, next two characters and rest of the ID.
@@ -357,7 +357,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
invalid_values=invalid_file_ids,
)
def test_url_cache_media_id_validation(self):
def test_url_cache_media_id_validation(self) -> None:
"""Test validation of URL cache media IDs"""
self._test_path_validation(
[
@@ -387,7 +387,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
],
)
def test_content_type_validation(self):
def test_content_type_validation(self) -> None:
"""Test validation of thumbnail content types"""
self._test_path_validation(
[
@@ -410,7 +410,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
],
)
def test_thumbnail_method_validation(self):
def test_thumbnail_method_validation(self) -> None:
"""Test validation of thumbnail methods"""
self._test_path_validation(
[
@@ -440,7 +440,7 @@ class MediaFilePathsTestCase(unittest.TestCase):
parameter: str,
valid_values: Iterable[str],
invalid_values: Iterable[str],
):
) -> None:
"""Test that the specified methods validate the named parameter as expected
Args:
+27 -27
View File
@@ -32,7 +32,7 @@ class SummarizeTestCase(unittest.TestCase):
if not lxml:
skip = "url preview feature requires lxml"
def test_long_summarize(self):
def test_long_summarize(self) -> None:
example_paras = [
"""Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:
Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in
@@ -90,7 +90,7 @@ class SummarizeTestCase(unittest.TestCase):
" Tromsøya had a population of 36,088. Substantial parts of the urban…",
)
def test_short_summarize(self):
def test_short_summarize(self) -> None:
example_paras = [
"Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
" Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
@@ -117,7 +117,7 @@ class SummarizeTestCase(unittest.TestCase):
" most of the year.",
)
def test_small_then_large_summarize(self):
def test_small_then_large_summarize(self) -> None:
example_paras = [
"Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:"
" Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in"
@@ -150,7 +150,7 @@ class CalcOgTestCase(unittest.TestCase):
if not lxml:
skip = "url preview feature requires lxml"
def test_simple(self):
def test_simple(self) -> None:
html = b"""
<html>
<head><title>Foo</title></head>
@@ -165,7 +165,7 @@ class CalcOgTestCase(unittest.TestCase):
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
def test_comment(self):
def test_comment(self) -> None:
html = b"""
<html>
<head><title>Foo</title></head>
@@ -181,7 +181,7 @@ class CalcOgTestCase(unittest.TestCase):
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
def test_comment2(self):
def test_comment2(self) -> None:
html = b"""
<html>
<head><title>Foo</title></head>
@@ -206,7 +206,7 @@ class CalcOgTestCase(unittest.TestCase):
},
)
def test_script(self):
def test_script(self) -> None:
html = b"""
<html>
<head><title>Foo</title></head>
@@ -222,7 +222,7 @@ class CalcOgTestCase(unittest.TestCase):
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
def test_missing_title(self):
def test_missing_title(self) -> None:
html = b"""
<html>
<body>
@@ -236,7 +236,7 @@ class CalcOgTestCase(unittest.TestCase):
self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
def test_h1_as_title(self):
def test_h1_as_title(self) -> None:
html = b"""
<html>
<meta property="og:description" content="Some text."/>
@@ -251,7 +251,7 @@ class CalcOgTestCase(unittest.TestCase):
self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."})
def test_missing_title_and_broken_h1(self):
def test_missing_title_and_broken_h1(self) -> None:
html = b"""
<html>
<body>
@@ -266,19 +266,19 @@ class CalcOgTestCase(unittest.TestCase):
self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
def test_empty(self):
def test_empty(self) -> None:
"""Test a body with no data in it."""
html = b""
tree = decode_body(html, "http://example.com/test.html")
self.assertIsNone(tree)
def test_no_tree(self):
def test_no_tree(self) -> None:
"""A valid body with no tree in it."""
html = b"\x00"
tree = decode_body(html, "http://example.com/test.html")
self.assertIsNone(tree)
def test_xml(self):
def test_xml(self) -> None:
"""Test decoding XML and ensure it works properly."""
# Note that the strip() call is important to ensure the xml tag starts
# at the initial byte.
@@ -293,7 +293,7 @@ class CalcOgTestCase(unittest.TestCase):
og = parse_html_to_open_graph(tree, "http://example.com/test.html")
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
def test_invalid_encoding(self):
def test_invalid_encoding(self) -> None:
"""An invalid character encoding should be ignored and treated as UTF-8, if possible."""
html = b"""
<html>
@@ -307,7 +307,7 @@ class CalcOgTestCase(unittest.TestCase):
og = parse_html_to_open_graph(tree, "http://example.com/test.html")
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
def test_invalid_encoding2(self):
def test_invalid_encoding2(self) -> None:
"""A body which doesn't match the sent character encoding."""
# Note that this contains an invalid UTF-8 sequence in the title.
html = b"""
@@ -322,7 +322,7 @@ class CalcOgTestCase(unittest.TestCase):
og = parse_html_to_open_graph(tree, "http://example.com/test.html")
self.assertEqual(og, {"og:title": "ÿÿ Foo", "og:description": "Some text."})
def test_windows_1252(self):
def test_windows_1252(self) -> None:
"""A body which uses cp1252, but doesn't declare that."""
html = b"""
<html>
@@ -338,7 +338,7 @@ class CalcOgTestCase(unittest.TestCase):
class MediaEncodingTestCase(unittest.TestCase):
def test_meta_charset(self):
def test_meta_charset(self) -> None:
"""A character encoding is found via the meta tag."""
encodings = _get_html_media_encodings(
b"""
@@ -363,7 +363,7 @@ class MediaEncodingTestCase(unittest.TestCase):
)
self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"])
def test_meta_charset_underscores(self):
def test_meta_charset_underscores(self) -> None:
"""A character encoding contains underscore."""
encodings = _get_html_media_encodings(
b"""
@@ -376,7 +376,7 @@ class MediaEncodingTestCase(unittest.TestCase):
)
self.assertEqual(list(encodings), ["shift_jis", "utf-8", "cp1252"])
def test_xml_encoding(self):
def test_xml_encoding(self) -> None:
"""A character encoding is found via the meta tag."""
encodings = _get_html_media_encodings(
b"""
@@ -388,7 +388,7 @@ class MediaEncodingTestCase(unittest.TestCase):
)
self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"])
def test_meta_xml_encoding(self):
def test_meta_xml_encoding(self) -> None:
"""Meta tags take precedence over XML encoding."""
encodings = _get_html_media_encodings(
b"""
@@ -402,7 +402,7 @@ class MediaEncodingTestCase(unittest.TestCase):
)
self.assertEqual(list(encodings), ["utf-16", "ascii", "utf-8", "cp1252"])
def test_content_type(self):
def test_content_type(self) -> None:
"""A character encoding is found via the Content-Type header."""
# Test a few variations of the header.
headers = (
@@ -417,12 +417,12 @@ class MediaEncodingTestCase(unittest.TestCase):
encodings = _get_html_media_encodings(b"", header)
self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"])
def test_fallback(self):
def test_fallback(self) -> None:
"""A character encoding cannot be found in the body or header."""
encodings = _get_html_media_encodings(b"", "text/html")
self.assertEqual(list(encodings), ["utf-8", "cp1252"])
def test_duplicates(self):
def test_duplicates(self) -> None:
"""Ensure each encoding is only attempted once."""
encodings = _get_html_media_encodings(
b"""
@@ -436,7 +436,7 @@ class MediaEncodingTestCase(unittest.TestCase):
)
self.assertEqual(list(encodings), ["utf-8", "cp1252"])
def test_unknown_invalid(self):
def test_unknown_invalid(self) -> None:
"""A character encoding should be ignored if it is unknown or invalid."""
encodings = _get_html_media_encodings(
b"""
@@ -451,7 +451,7 @@ class MediaEncodingTestCase(unittest.TestCase):
class RebaseUrlTestCase(unittest.TestCase):
def test_relative(self):
def test_relative(self) -> None:
"""Relative URLs should be resolved based on the context of the base URL."""
self.assertEqual(
rebase_url("subpage", "https://example.com/foo/"),
@@ -466,14 +466,14 @@ class RebaseUrlTestCase(unittest.TestCase):
"https://example.com/bar",
)
def test_absolute(self):
def test_absolute(self) -> None:
"""Absolute URLs should not be modified."""
self.assertEqual(
rebase_url("https://alice.com/a/", "https://example.com/foo/"),
"https://alice.com/a/",
)
def test_data(self):
def test_data(self) -> None:
"""Data URLs should not be modified."""
self.assertEqual(
rebase_url("data:,Hello%2C%20World%21", "https://example.com/foo/"),
+5 -5
View File
@@ -16,7 +16,7 @@ import json
from twisted.test.proto_helpers import MemoryReactor
from synapse.rest.media.v1.oembed import OEmbedProvider
from synapse.rest.media.v1.oembed import OEmbedProvider, OEmbedResult
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
@@ -25,15 +25,15 @@ from tests.unittest import HomeserverTestCase
class OEmbedTests(HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer):
self.oembed = OEmbedProvider(homeserver)
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.oembed = OEmbedProvider(hs)
def parse_response(self, response: JsonDict):
def parse_response(self, response: JsonDict) -> OEmbedResult:
return self.oembed.parse_oembed_response(
"https://test", json.dumps(response).encode("utf-8")
)
def test_version(self):
def test_version(self) -> None:
"""Accept versions that are similar to 1.0 as a string or int (or missing)."""
for version in ("1.0", 1.0, 1):
result = self.parse_response({"version": version, "type": "link"})
+4 -4
View File
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from http import HTTPStatus
from synapse.rest.health import HealthResource
@@ -19,12 +19,12 @@ from tests import unittest
class HealthCheckTests(unittest.HomeserverTestCase):
def create_test_resource(self):
def create_test_resource(self) -> HealthResource:
# replace the JsonResource with a HealthResource.
return HealthResource()
def test_health(self):
def test_health(self) -> None:
channel = self.make_request("GET", "/health", shorthand=False)
self.assertEqual(channel.code, 200)
self.assertEqual(channel.code, HTTPStatus.OK)
self.assertEqual(channel.result["body"], b"OK")
+11 -9
View File
@@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from http import HTTPStatus
from twisted.web.resource import Resource
from synapse.rest.well_known import well_known_resource
@@ -19,7 +21,7 @@ from tests import unittest
class WellKnownTests(unittest.HomeserverTestCase):
def create_test_resource(self):
def create_test_resource(self) -> Resource:
# replace the JsonResource with a Resource wrapping the WellKnownResource
res = Resource()
res.putChild(b".well-known", well_known_resource(self.hs))
@@ -31,12 +33,12 @@ class WellKnownTests(unittest.HomeserverTestCase):
"default_identity_server": "https://testis",
}
)
def test_client_well_known(self):
def test_client_well_known(self) -> None:
channel = self.make_request(
"GET", "/.well-known/matrix/client", shorthand=False
)
self.assertEqual(channel.code, 200)
self.assertEqual(channel.code, HTTPStatus.OK)
self.assertEqual(
channel.json_body,
{
@@ -50,27 +52,27 @@ class WellKnownTests(unittest.HomeserverTestCase):
"public_baseurl": None,
}
)
def test_client_well_known_no_public_baseurl(self):
def test_client_well_known_no_public_baseurl(self) -> None:
channel = self.make_request(
"GET", "/.well-known/matrix/client", shorthand=False
)
self.assertEqual(channel.code, 404)
self.assertEqual(channel.code, HTTPStatus.NOT_FOUND)
@unittest.override_config({"serve_server_wellknown": True})
def test_server_well_known(self):
def test_server_well_known(self) -> None:
channel = self.make_request(
"GET", "/.well-known/matrix/server", shorthand=False
)
self.assertEqual(channel.code, 200)
self.assertEqual(channel.code, HTTPStatus.OK)
self.assertEqual(
channel.json_body,
{"m.server": "test:443"},
)
def test_server_well_known_disabled(self):
def test_server_well_known_disabled(self) -> None:
channel = self.make_request(
"GET", "/.well-known/matrix/server", shorthand=False
)
self.assertEqual(channel.code, 404)
self.assertEqual(channel.code, HTTPStatus.NOT_FOUND)
+15 -6
View File
@@ -31,6 +31,7 @@ from synapse.storage.databases.main.appservice import (
ApplicationServiceStore,
ApplicationServiceTransactionStore,
)
from synapse.types import DeviceLists
from synapse.util import Clock
from tests import unittest
@@ -267,7 +268,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")])
txn = self.get_success(
defer.ensureDeferred(
self.store.create_appservice_txn(service, events, [], [], {}, {})
self.store.create_appservice_txn(
service, events, [], [], {}, {}, DeviceLists()
)
)
)
self.assertEqual(txn.id, 1)
@@ -283,7 +286,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
self.get_success(self._insert_txn(service.id, 9644, events))
self.get_success(self._insert_txn(service.id, 9645, events))
txn = self.get_success(
self.store.create_appservice_txn(service, events, [], [], {}, {})
self.store.create_appservice_txn(
service, events, [], [], {}, {}, DeviceLists()
)
)
self.assertEqual(txn.id, 9646)
self.assertEqual(txn.events, events)
@@ -296,7 +301,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
events = cast(List[EventBase], [Mock(event_id="e1"), Mock(event_id="e2")])
self.get_success(self._set_last_txn(service.id, 9643))
txn = self.get_success(
self.store.create_appservice_txn(service, events, [], [], {}, {})
self.store.create_appservice_txn(
service, events, [], [], {}, {}, DeviceLists()
)
)
self.assertEqual(txn.id, 9644)
self.assertEqual(txn.events, events)
@@ -320,7 +327,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
self.get_success(self._insert_txn(self.as_list[3]["id"], 9643, events))
txn = self.get_success(
self.store.create_appservice_txn(service, events, [], [], {}, {})
self.store.create_appservice_txn(
service, events, [], [], {}, {}, DeviceLists()
)
)
self.assertEqual(txn.id, 9644)
self.assertEqual(txn.events, events)
@@ -476,12 +485,12 @@ class ApplicationServiceStoreTypeStreamIds(unittest.HomeserverTestCase):
value = self.get_success(
self.store.get_type_stream_id_for_appservice(self.service, "read_receipt")
)
self.assertEqual(value, 0)
self.assertEqual(value, 1)
value = self.get_success(
self.store.get_type_stream_id_for_appservice(self.service, "presence")
)
self.assertEqual(value, 0)
self.assertEqual(value, 1)
def test_get_type_stream_id_for_appservice_invalid_type(self) -> None:
self.get_failure(
+73 -3
View File
@@ -13,11 +13,12 @@
# limitations under the License.
import logging
from typing import Optional
from unittest.mock import patch
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase
from synapse.types import JsonDict
from synapse.visibility import filter_events_for_server
from synapse.events import EventBase, make_event_from_dict
from synapse.types import JsonDict, create_requester
from synapse.visibility import filter_events_for_client, filter_events_for_server
from tests import unittest
from tests.utils import create_room
@@ -185,3 +186,72 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
self.get_success(self.storage.persistence.persist_event(event, context))
return event
class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase):
def test_out_of_band_invite_rejection(self):
# this is where we have received an invite event over federation, and then
# rejected it.
invite_pdu = {
"room_id": "!room:id",
"depth": 1,
"auth_events": [],
"prev_events": [],
"origin_server_ts": 1,
"sender": "@someone:" + self.OTHER_SERVER_NAME,
"type": "m.room.member",
"state_key": "@user:test",
"content": {"membership": "invite"},
}
self.add_hashes_and_signatures(invite_pdu)
invite_event_id = make_event_from_dict(invite_pdu, RoomVersions.V9).event_id
self.get_success(
self.hs.get_federation_server().on_invite_request(
self.OTHER_SERVER_NAME,
invite_pdu,
"9",
)
)
# stub out do_remotely_reject_invite so that we fall back to a locally-
# generated rejection
with patch.object(
self.hs.get_federation_handler(),
"do_remotely_reject_invite",
side_effect=Exception(),
):
reject_event_id, _ = self.get_success(
self.hs.get_room_member_handler().remote_reject_invite(
invite_event_id,
txn_id=None,
requester=create_requester("@user:test"),
content={},
)
)
invite_event, reject_event = self.get_success(
self.hs.get_datastores().main.get_events_as_list(
[invite_event_id, reject_event_id]
)
)
# the invited user should be able to see both the invite and the rejection
self.assertEqual(
self.get_success(
filter_events_for_client(
self.hs.get_storage(), "@user:test", [invite_event, reject_event]
)
),
[invite_event, reject_event],
)
# other users should see neither
self.assertEqual(
self.get_success(
filter_events_for_client(
self.hs.get_storage(), "@other:test", [invite_event, reject_event]
)
),
[],
)
-36
View File
@@ -38,18 +38,8 @@ lint_targets =
setup.py
synapse
tests
scripts
# annoyingly, black doesn't find these so we have to list them
scripts/export_signing_key
scripts/generate_config
scripts/generate_log_config
scripts/hash_password
scripts/register_new_matrix_user
scripts/synapse_port_db
scripts/update_synapse_database
scripts-dev
scripts-dev/build_debian_packages
scripts-dev/sign_json
stubs
contrib
synctl
@@ -168,32 +158,6 @@ commands =
extras = lint
commands = isort -c --df {[base]lint_targets}
[testenv:combine]
skip_install = true
usedevelop = false
deps =
coverage
pip>=10
commands=
coverage combine
coverage report
[testenv:cov-erase]
skip_install = true
usedevelop = false
deps =
coverage
commands=
coverage erase
[testenv:cov-html]
skip_install = true
usedevelop = false
deps =
coverage
commands=
coverage html
[testenv:mypy]
deps =
{[base]deps}