1
0

Merge branch 'develop' into comp-worker-shorthand

This commit is contained in:
realtyem
2023-02-18 20:05:54 -06:00
committed by GitHub
130 changed files with 1326 additions and 688 deletions
+6
View File
@@ -1,3 +1,9 @@
Synapse 1.77.0 (2023-02-14)
===========================
No significant changes since 1.77.0rc2.
Synapse 1.77.0rc2 (2023-02-10)
==============================
+1
View File
@@ -0,0 +1 @@
Prevent clients from reporting nonexistent events.
+1
View File
@@ -0,0 +1 @@
Document how to start Synapse with Poetry. Contributed by @thezaidbintariq.
+1
View File
@@ -0,0 +1 @@
Add account data to the command line [user data export tool](https://matrix-org.github.io/synapse/v1.78/usage/administration/admin_faq.html#how-can-i-export-user-data).
+1
View File
@@ -0,0 +1 @@
Skip calculating unread push actions in /sync when enable_push is false.
+1
View File
@@ -0,0 +1 @@
Add a schema dump symlinks inside `contrib`, to make it easier for IDEs to interrogate Synapse's database schema.
+1
View File
@@ -0,0 +1 @@
Allow Synapse to use a specific Redis [logical database](https://redis.io/commands/select/) in worker-mode deployments.
+1
View File
@@ -0,0 +1 @@
Update [MSC3952](https://github.com/matrix-org/matrix-spec-proposals/pull/3952) support based on changes to the MSC.
+1
View File
@@ -0,0 +1 @@
Experimental support for [MSC3966](https://github.com/matrix-org/matrix-spec-proposals/pull/3966): the `exact_event_property_contains` push rule condition.
+1
View File
@@ -0,0 +1 @@
Improve type hints.
+1
View File
@@ -0,0 +1 @@
Fix a bug introduced in Synapse 1.76.0 where partially-joined rooms could not be deleted using the [purge room API](https://matrix-org.github.io/synapse/latest/admin_api/rooms.html#delete-room-api).
+1
View File
@@ -0,0 +1 @@
Faster joins: omit device list updates originating from partial state rooms in /sync responses without lazy loading of members enabled.
+1
View File
@@ -0,0 +1 @@
Fix clashing database transaction name.
+1
View File
@@ -0,0 +1 @@
Improve type hints.
+1
View File
@@ -0,0 +1 @@
Remove spurious `dont_notify` action from the defaults for the `.m.rule.reaction` pushrule.
+1
View File
@@ -0,0 +1 @@
Fix a long-standing bug where federated joins would fail if the first server in the list of servers to try is not in the room.
+2
View File
@@ -0,0 +1,2 @@
Update the error code returned when user sends a duplicate annotation.
+1
View File
@@ -0,0 +1 @@
Fix a mistake in registration_shared_secret_path docs.
+1
View File
@@ -0,0 +1 @@
Reduce the likelihood of a rare race condition where rejoining a restricted room over federation would fail.
+1
View File
@@ -0,0 +1 @@
Improve type hints.
+28
View File
@@ -0,0 +1,28 @@
# Schema symlinks
This directory contains symlinks to the latest dump of the postgres full schema. This is useful to have, as it allows IDEs to understand our schema and provide autocomplete, linters, inspections, etc.
In particular, the DataGrip functionality in IntelliJ's products seems to only consider files called `*.sql` when defining a schema from DDL; `*.sql.postgres` will be ignored. To get around this we symlink those files to ones ending in `.sql`. We've chosen to ignore the `.sql.sqlite` schema dumps here, as they're not intended for production use (and are much quicker to test against).
## Example
![](datagrip-aware-of-schema.png)
## Caveats
- Doesn't include temporary tables created ad-hoc by Synapse.
- Postgres only. IDEs will likely be confused by SQLite-specific queries.
- Will not include migrations created after the latest schema dump.
- Symlinks might confuse checkouts on Windows systems.
## Instructions
### Jetbrains IDEs with DataGrip plugin
- View -> Tool Windows -> Database
- `+` Icon -> DDL Data Source
- Pick a name, e.g. `Synapse schema dump`
- Under sources, click `+`.
- Add an entry with Path pointing to this directory, and dialect set to PostgreSQL.
- OK, and OK.
- IDE should now be aware of the schema.
- Try control-clicking on a table name in a bit of SQL e.g. in `_get_forgotten_rooms_for_user_txn`.
+1
View File
@@ -0,0 +1 @@
../../synapse/storage/schema/common/full_schemas/72/full.sql.postgres
Binary file not shown.

After

Width:  |  Height:  |  Size: 13 KiB

+1
View File
@@ -0,0 +1 @@
../../synapse/storage/schema/main/full_schemas/72/full.sql.postgres
+1
View File
@@ -0,0 +1 @@
../../synapse/storage/schema/common/schema_version.sql
+1
View File
@@ -0,0 +1 @@
../../synapse/storage/schema/state/full_schemas/72/full.sql.postgres
+1
View File
@@ -68,6 +68,7 @@ redis:
enabled: true
host: redis
port: 6379
# dbid: <redis_logical_db_id>
# password: <secret_password>
```
+6
View File
@@ -1,3 +1,9 @@
matrix-synapse-py3 (1.77.0) stable; urgency=medium
* New Synapse release 1.77.0.
-- Synapse Packaging team <packages@matrix.org> Tue, 14 Feb 2023 12:59:02 +0100
matrix-synapse-py3 (1.77.0~rc2) stable; urgency=medium
* New Synapse release 1.77.0rc2.
+3
View File
@@ -71,6 +71,9 @@ output-directory
│ ├───invite_state
│ └───knock_state
└───user_data
├───account_data
│ ├───global
│ └───<room_id>
├───connections
├───devices
└───profile
@@ -2232,7 +2232,7 @@ key on startup and store it in this file.
Example configuration:
```yaml
registration_shared_secret_file: /path/to/secrets/file
registration_shared_secret_path: /path/to/secrets/file
```
_Added in Synapse 1.67.0._
@@ -3927,6 +3927,9 @@ This setting has the following sub-options:
* `host` and `port`: Optional host and port to use to connect to redis. Defaults to
localhost and 6379
* `password`: Optional password if configured on the Redis instance.
* `dbid`: Optional redis dbid if needs to connect to specific redis logical db.
_Added in Synapse 1.78.0._
Example configuration:
```yaml
@@ -3935,6 +3938,7 @@ redis:
host: localhost
port: 6379
password: <secret_password>
dbid: <dbid>
```
---
## Individual worker configuration
+11
View File
@@ -160,7 +160,18 @@ recommend the use of `systemd` where available: for information on setting up
[Systemd with Workers](systemd-with-workers/). To use `synctl`, see
[Using synctl with Workers](synctl_workers.md).
## Start Synapse with Poetry
The following applies to Synapse installations that have been installed from source using `poetry`.
You can start the main Synapse process with Poetry by running the following command:
```console
poetry run synapse_homeserver -c [your homeserver.yaml]
```
For worker setups, you can run the following command
```console
poetry run synapse_worker -c [your worker.yaml]
```
## Available worker applications
### `synapse.app.generic_worker`
-5
View File
@@ -31,8 +31,6 @@ exclude = (?x)
|synapse/storage/databases/__init__.py
|synapse/storage/databases/main/cache.py
|synapse/storage/schema/
|tests/server.py
)$
[mypy-synapse.federation.transport.client]
@@ -56,9 +54,6 @@ disallow_untyped_defs = False
[mypy-synapse.storage.database]
disallow_untyped_defs = False
[mypy-tests.unittest]
disallow_untyped_defs = False
[mypy-tests.util.caches.test_descriptors]
disallow_untyped_defs = False
Generated
+36 -33
View File
@@ -146,14 +146,14 @@ css = ["tinycss2 (>=1.1.0,<1.2)"]
[[package]]
name = "canonicaljson"
version = "1.6.4"
version = "1.6.5"
description = "Canonical JSON"
category = "main"
optional = false
python-versions = ">=3.7"
files = [
{file = "canonicaljson-1.6.4-py3-none-any.whl", hash = "sha256:55d282853b4245dbcd953fe54c39b91571813d7c44e1dbf66e3c4f97ff134a48"},
{file = "canonicaljson-1.6.4.tar.gz", hash = "sha256:6c09b2119511f30eb1126cfcd973a10824e20f1cfd25039cde3d1218dd9c8d8f"},
{file = "canonicaljson-1.6.5-py3-none-any.whl", hash = "sha256:806ea6f2cbb7405d20259e1c36dd1214ba5c242fa9165f5bd0bf2081f82c23fb"},
{file = "canonicaljson-1.6.5.tar.gz", hash = "sha256:68dfc157b011e07d94bf74b5d4ccc01958584ed942d9dfd5fdd706609e81cd4b"},
]
[package.dependencies]
@@ -1146,36 +1146,38 @@ files = [
[[package]]
name = "mypy"
version = "0.981"
version = "1.0.0"
description = "Optional static typing for Python"
category = "dev"
optional = false
python-versions = ">=3.7"
files = [
{file = "mypy-0.981-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:4bc460e43b7785f78862dab78674e62ec3cd523485baecfdf81a555ed29ecfa0"},
{file = "mypy-0.981-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:756fad8b263b3ba39e4e204ee53042671b660c36c9017412b43af210ddee7b08"},
{file = "mypy-0.981-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a16a0145d6d7d00fbede2da3a3096dcc9ecea091adfa8da48fa6a7b75d35562d"},
{file = "mypy-0.981-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ce65f70b14a21fdac84c294cde75e6dbdabbcff22975335e20827b3b94bdbf49"},
{file = "mypy-0.981-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:6e35d764784b42c3e256848fb8ed1d4292c9fc0098413adb28d84974c095b279"},
{file = "mypy-0.981-cp310-cp310-win_amd64.whl", hash = "sha256:e53773073c864d5f5cec7f3fc72fbbcef65410cde8cc18d4f7242dea60dac52e"},
{file = "mypy-0.981-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:6ee196b1d10b8b215e835f438e06965d7a480f6fe016eddbc285f13955cca659"},
{file = "mypy-0.981-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8ad21d4c9d3673726cf986ea1d0c9fb66905258709550ddf7944c8f885f208be"},
{file = "mypy-0.981-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:d1debb09043e1f5ee845fa1e96d180e89115b30e47c5d3ce53bc967bab53f62d"},
{file = "mypy-0.981-cp37-cp37m-win_amd64.whl", hash = "sha256:9f362470a3480165c4c6151786b5379351b790d56952005be18bdbdd4c7ce0ae"},
{file = "mypy-0.981-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:c9e0efb95ed6ca1654951bd5ec2f3fa91b295d78bf6527e026529d4aaa1e0c30"},
{file = "mypy-0.981-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:e178eaffc3c5cd211a87965c8c0df6da91ed7d258b5fc72b8e047c3771317ddb"},
{file = "mypy-0.981-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:06e1eac8d99bd404ed8dd34ca29673c4346e76dd8e612ea507763dccd7e13c7a"},
{file = "mypy-0.981-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fa38f82f53e1e7beb45557ff167c177802ba7b387ad017eab1663d567017c8ee"},
{file = "mypy-0.981-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:64e1f6af81c003f85f0dfed52db632817dabb51b65c0318ffbf5ff51995bbb08"},
{file = "mypy-0.981-cp38-cp38-win_amd64.whl", hash = "sha256:e1acf62a8c4f7c092462c738aa2c2489e275ed386320c10b2e9bff31f6f7e8d6"},
{file = "mypy-0.981-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:b6ede64e52257931315826fdbfc6ea878d89a965580d1a65638ef77cb551f56d"},
{file = "mypy-0.981-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:eb3978b191b9fa0488524bb4ffedf2c573340e8c2b4206fc191d44c7093abfb7"},
{file = "mypy-0.981-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:77f8fcf7b4b3cc0c74fb33ae54a4cd00bb854d65645c48beccf65fa10b17882c"},
{file = "mypy-0.981-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f64d2ce043a209a297df322eb4054dfbaa9de9e8738291706eaafda81ab2b362"},
{file = "mypy-0.981-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:2ee3dbc53d4df7e6e3b1c68ac6a971d3a4fb2852bf10a05fda228721dd44fae1"},
{file = "mypy-0.981-cp39-cp39-win_amd64.whl", hash = "sha256:8e8e49aa9cc23aa4c926dc200ce32959d3501c4905147a66ce032f05cb5ecb92"},
{file = "mypy-0.981-py3-none-any.whl", hash = "sha256:794f385653e2b749387a42afb1e14c2135e18daeb027e0d97162e4b7031210f8"},
{file = "mypy-0.981.tar.gz", hash = "sha256:ad77c13037d3402fbeffda07d51e3f228ba078d1c7096a73759c9419ea031bf4"},
{file = "mypy-1.0.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:e0626db16705ab9f7fa6c249c017c887baf20738ce7f9129da162bb3075fc1af"},
{file = "mypy-1.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1ace23f6bb4aec4604b86c4843276e8fa548d667dbbd0cb83a3ae14b18b2db6c"},
{file = "mypy-1.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:87edfaf344c9401942883fad030909116aa77b0fa7e6e8e1c5407e14549afe9a"},
{file = "mypy-1.0.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:0ab090d9240d6b4e99e1fa998c2d0aa5b29fc0fb06bd30e7ad6183c95fa07593"},
{file = "mypy-1.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:7cc2c01dfc5a3cbddfa6c13f530ef3b95292f926329929001d45e124342cd6b7"},
{file = "mypy-1.0.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:14d776869a3e6c89c17eb943100f7868f677703c8a4e00b3803918f86aafbc52"},
{file = "mypy-1.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:bb2782a036d9eb6b5a6efcdda0986774bf798beef86a62da86cb73e2a10b423d"},
{file = "mypy-1.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5cfca124f0ac6707747544c127880893ad72a656e136adc935c8600740b21ff5"},
{file = "mypy-1.0.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:8845125d0b7c57838a10fd8925b0f5f709d0e08568ce587cc862aacce453e3dd"},
{file = "mypy-1.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:01b1b9e1ed40544ef486fa8ac022232ccc57109f379611633ede8e71630d07d2"},
{file = "mypy-1.0.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:c7cf862aef988b5fbaa17764ad1d21b4831436701c7d2b653156a9497d92c83c"},
{file = "mypy-1.0.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5cd187d92b6939617f1168a4fe68f68add749902c010e66fe574c165c742ed88"},
{file = "mypy-1.0.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:4e5175026618c178dfba6188228b845b64131034ab3ba52acaffa8f6c361f805"},
{file = "mypy-1.0.0-cp37-cp37m-win_amd64.whl", hash = "sha256:2f6ac8c87e046dc18c7d1d7f6653a66787a4555085b056fe2d599f1f1a2a2d21"},
{file = "mypy-1.0.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:7306edca1c6f1b5fa0bc9aa645e6ac8393014fa82d0fa180d0ebc990ebe15964"},
{file = "mypy-1.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:3cfad08f16a9c6611e6143485a93de0e1e13f48cfb90bcad7d5fde1c0cec3d36"},
{file = "mypy-1.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:67cced7f15654710386e5c10b96608f1ee3d5c94ca1da5a2aad5889793a824c1"},
{file = "mypy-1.0.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:a86b794e8a56ada65c573183756eac8ac5b8d3d59daf9d5ebd72ecdbb7867a43"},
{file = "mypy-1.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:50979d5efff8d4135d9db293c6cb2c42260e70fb010cbc697b1311a4d7a39ddb"},
{file = "mypy-1.0.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:3ae4c7a99e5153496243146a3baf33b9beff714464ca386b5f62daad601d87af"},
{file = "mypy-1.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5e398652d005a198a7f3c132426b33c6b85d98aa7dc852137a2a3be8890c4072"},
{file = "mypy-1.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:be78077064d016bc1b639c2cbcc5be945b47b4261a4f4b7d8923f6c69c5c9457"},
{file = "mypy-1.0.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:92024447a339400ea00ac228369cd242e988dd775640755fa4ac0c126e49bb74"},
{file = "mypy-1.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:fe523fcbd52c05040c7bee370d66fee8373c5972171e4fbc323153433198592d"},
{file = "mypy-1.0.0-py3-none-any.whl", hash = "sha256:2efa963bdddb27cb4a0d42545cd137a8d2b883bd181bbc4525b568ef6eca258f"},
{file = "mypy-1.0.0.tar.gz", hash = "sha256:f34495079c8d9da05b183f9f7daec2878280c2ad7cc81da686ef0b484cea2ecf"},
]
[package.dependencies]
@@ -1186,6 +1188,7 @@ typing-extensions = ">=3.10"
[package.extras]
dmypy = ["psutil (>=4.0)"]
install-types = ["pip"]
python2 = ["typed-ast (>=1.4.0,<2)"]
reports = ["lxml"]
@@ -1203,18 +1206,18 @@ files = [
[[package]]
name = "mypy-zope"
version = "0.3.11"
version = "0.9.0"
description = "Plugin for mypy to support zope interfaces"
category = "dev"
optional = false
python-versions = "*"
files = [
{file = "mypy-zope-0.3.11.tar.gz", hash = "sha256:d4255f9f04d48c79083bbd4e2fea06513a6ac7b8de06f8c4ce563fd85142ca05"},
{file = "mypy_zope-0.3.11-py3-none-any.whl", hash = "sha256:ec080a6508d1f7805c8d2054f9fdd13c849742ce96803519e1fdfa3d3cab7140"},
{file = "mypy-zope-0.9.0.tar.gz", hash = "sha256:88bf6cd056e38b338e6956055958a7805b4ff84404ccd99e29883a3647a1aeb3"},
{file = "mypy_zope-0.9.0-py3-none-any.whl", hash = "sha256:e1bb4b57084f76ff8a154a3e07880a1af2ac6536c491dad4b143d529f72c5d15"},
]
[package.dependencies]
mypy = "0.981"
mypy = "1.0.0"
"zope.interface" = "*"
"zope.schema" = "*"
@@ -1705,7 +1708,7 @@ files = [
cffi = ">=1.4.1"
[package.extras]
docs = ["sphinx (>=1.6.5)", "sphinx-rtd-theme"]
docs = ["sphinx (>=1.6.5)", "sphinx_rtd_theme"]
tests = ["hypothesis (>=3.27.0)", "pytest (>=3.2.1,!=3.3.0)"]
[[package]]
+1 -1
View File
@@ -89,7 +89,7 @@ manifest-path = "rust/Cargo.toml"
[tool.poetry]
name = "matrix-synapse"
version = "1.77.0rc2"
version = "1.77.0"
description = "Homeserver for the Matrix decentralised comms protocol"
authors = ["Matrix.org Team and Contributors <packages@matrix.org>"]
license = "Apache-2.0"
+18 -18
View File
@@ -15,8 +15,8 @@
#![feature(test)]
use std::collections::BTreeSet;
use synapse::push::{
evaluator::PushRuleEvaluator, Condition, EventMatchCondition, FilteredPushRules, PushRules,
SimpleJsonValue,
evaluator::PushRuleEvaluator, Condition, EventMatchCondition, FilteredPushRules, JsonValue,
PushRules, SimpleJsonValue,
};
use test::Bencher;
@@ -27,15 +27,15 @@ fn bench_match_exact(b: &mut Bencher) {
let flattened_keys = [
(
"type".to_string(),
SimpleJsonValue::Str("m.text".to_string()),
JsonValue::Value(SimpleJsonValue::Str("m.text".to_string())),
),
(
"room_id".to_string(),
SimpleJsonValue::Str("!room:server".to_string()),
JsonValue::Value(SimpleJsonValue::Str("!room:server".to_string())),
),
(
"content.body".to_string(),
SimpleJsonValue::Str("test message".to_string()),
JsonValue::Value(SimpleJsonValue::Str("test message".to_string())),
),
]
.into_iter()
@@ -45,7 +45,6 @@ fn bench_match_exact(b: &mut Bencher) {
flattened_keys,
false,
BTreeSet::new(),
false,
10,
Some(0),
Default::default(),
@@ -54,6 +53,7 @@ fn bench_match_exact(b: &mut Bencher) {
vec![],
false,
false,
false,
)
.unwrap();
@@ -76,15 +76,15 @@ fn bench_match_word(b: &mut Bencher) {
let flattened_keys = [
(
"type".to_string(),
SimpleJsonValue::Str("m.text".to_string()),
JsonValue::Value(SimpleJsonValue::Str("m.text".to_string())),
),
(
"room_id".to_string(),
SimpleJsonValue::Str("!room:server".to_string()),
JsonValue::Value(SimpleJsonValue::Str("!room:server".to_string())),
),
(
"content.body".to_string(),
SimpleJsonValue::Str("test message".to_string()),
JsonValue::Value(SimpleJsonValue::Str("test message".to_string())),
),
]
.into_iter()
@@ -94,7 +94,6 @@ fn bench_match_word(b: &mut Bencher) {
flattened_keys,
false,
BTreeSet::new(),
false,
10,
Some(0),
Default::default(),
@@ -103,6 +102,7 @@ fn bench_match_word(b: &mut Bencher) {
vec![],
false,
false,
false,
)
.unwrap();
@@ -125,15 +125,15 @@ fn bench_match_word_miss(b: &mut Bencher) {
let flattened_keys = [
(
"type".to_string(),
SimpleJsonValue::Str("m.text".to_string()),
JsonValue::Value(SimpleJsonValue::Str("m.text".to_string())),
),
(
"room_id".to_string(),
SimpleJsonValue::Str("!room:server".to_string()),
JsonValue::Value(SimpleJsonValue::Str("!room:server".to_string())),
),
(
"content.body".to_string(),
SimpleJsonValue::Str("test message".to_string()),
JsonValue::Value(SimpleJsonValue::Str("test message".to_string())),
),
]
.into_iter()
@@ -143,7 +143,6 @@ fn bench_match_word_miss(b: &mut Bencher) {
flattened_keys,
false,
BTreeSet::new(),
false,
10,
Some(0),
Default::default(),
@@ -152,6 +151,7 @@ fn bench_match_word_miss(b: &mut Bencher) {
vec![],
false,
false,
false,
)
.unwrap();
@@ -174,15 +174,15 @@ fn bench_eval_message(b: &mut Bencher) {
let flattened_keys = [
(
"type".to_string(),
SimpleJsonValue::Str("m.text".to_string()),
JsonValue::Value(SimpleJsonValue::Str("m.text".to_string())),
),
(
"room_id".to_string(),
SimpleJsonValue::Str("!room:server".to_string()),
JsonValue::Value(SimpleJsonValue::Str("!room:server".to_string())),
),
(
"content.body".to_string(),
SimpleJsonValue::Str("test message".to_string()),
JsonValue::Value(SimpleJsonValue::Str("test message".to_string())),
),
]
.into_iter()
@@ -192,7 +192,6 @@ fn bench_eval_message(b: &mut Bencher) {
flattened_keys,
false,
BTreeSet::new(),
false,
10,
Some(0),
Default::default(),
@@ -201,6 +200,7 @@ fn bench_eval_message(b: &mut Bencher) {
vec![],
false,
false,
false,
)
.unwrap();
+6 -3
View File
@@ -21,13 +21,13 @@ use lazy_static::lazy_static;
use serde_json::Value;
use super::KnownCondition;
use crate::push::Action;
use crate::push::Condition;
use crate::push::EventMatchCondition;
use crate::push::PushRule;
use crate::push::RelatedEventMatchCondition;
use crate::push::SetTweak;
use crate::push::TweakValue;
use crate::push::{Action, ExactEventMatchCondition, SimpleJsonValue};
const HIGHLIGHT_ACTION: Action = Action::SetTweak(SetTweak {
set_tweak: Cow::Borrowed("highlight"),
@@ -168,7 +168,10 @@ pub const BASE_APPEND_OVERRIDE_RULES: &[PushRule] = &[
rule_id: Cow::Borrowed(".org.matrix.msc3952.is_room_mention"),
priority_class: 5,
conditions: Cow::Borrowed(&[
Condition::Known(KnownCondition::IsRoomMention),
Condition::Known(KnownCondition::ExactEventMatch(ExactEventMatchCondition {
key: Cow::Borrowed("content.org.matrix.msc3952.mentions.room"),
value: Cow::Borrowed(&SimpleJsonValue::Bool(true)),
})),
Condition::Known(KnownCondition::SenderNotificationPermission {
key: Cow::Borrowed("room"),
}),
@@ -223,7 +226,7 @@ pub const BASE_APPEND_OVERRIDE_RULES: &[PushRule] = &[
pattern_type: None,
},
))]),
actions: Cow::Borrowed(&[Action::DontNotify]),
actions: Cow::Borrowed(&[]),
default: true,
default_enabled: true,
},
+51 -21
View File
@@ -14,6 +14,7 @@
use std::collections::{BTreeMap, BTreeSet};
use crate::push::JsonValue;
use anyhow::{Context, Error};
use lazy_static::lazy_static;
use log::warn;
@@ -63,7 +64,7 @@ impl RoomVersionFeatures {
pub struct PushRuleEvaluator {
/// A mapping of "flattened" keys to simple JSON values in the event, e.g.
/// includes things like "type" and "content.msgtype".
flattened_keys: BTreeMap<String, SimpleJsonValue>,
flattened_keys: BTreeMap<String, JsonValue>,
/// The "content.body", if any.
body: String,
@@ -72,8 +73,6 @@ pub struct PushRuleEvaluator {
has_mentions: bool,
/// The user mentions that were part of the message.
user_mentions: BTreeSet<String>,
/// True if the message is a room message.
room_mention: bool,
/// The number of users in the room.
room_member_count: u64,
@@ -87,7 +86,7 @@ pub struct PushRuleEvaluator {
/// The related events, indexed by relation type. Flattened in the same manner as
/// `flattened_keys`.
related_events_flattened: BTreeMap<String, BTreeMap<String, SimpleJsonValue>>,
related_events_flattened: BTreeMap<String, BTreeMap<String, JsonValue>>,
/// If msc3664, push rules for related events, is enabled.
related_event_match_enabled: bool,
@@ -101,6 +100,9 @@ pub struct PushRuleEvaluator {
/// If MSC3758 (exact_event_match push rule condition) is enabled.
msc3758_exact_event_match: bool,
/// If MSC3966 (exact_event_property_contains push rule condition) is enabled.
msc3966_exact_event_property_contains: bool,
}
#[pymethods]
@@ -109,21 +111,21 @@ impl PushRuleEvaluator {
#[allow(clippy::too_many_arguments)]
#[new]
pub fn py_new(
flattened_keys: BTreeMap<String, SimpleJsonValue>,
flattened_keys: BTreeMap<String, JsonValue>,
has_mentions: bool,
user_mentions: BTreeSet<String>,
room_mention: bool,
room_member_count: u64,
sender_power_level: Option<i64>,
notification_power_levels: BTreeMap<String, i64>,
related_events_flattened: BTreeMap<String, BTreeMap<String, SimpleJsonValue>>,
related_events_flattened: BTreeMap<String, BTreeMap<String, JsonValue>>,
related_event_match_enabled: bool,
room_version_feature_flags: Vec<String>,
msc3931_enabled: bool,
msc3758_exact_event_match: bool,
msc3966_exact_event_property_contains: bool,
) -> Result<Self, Error> {
let body = match flattened_keys.get("content.body") {
Some(SimpleJsonValue::Str(s)) => s.clone(),
Some(JsonValue::Value(SimpleJsonValue::Str(s))) => s.clone(),
_ => String::new(),
};
@@ -132,7 +134,6 @@ impl PushRuleEvaluator {
body,
has_mentions,
user_mentions,
room_mention,
room_member_count,
notification_power_levels,
sender_power_level,
@@ -141,6 +142,7 @@ impl PushRuleEvaluator {
room_version_feature_flags,
msc3931_enabled,
msc3758_exact_event_match,
msc3966_exact_event_property_contains,
})
}
@@ -263,6 +265,9 @@ impl PushRuleEvaluator {
KnownCondition::RelatedEventMatch(event_match) => {
self.match_related_event_match(event_match, user_id)?
}
KnownCondition::ExactEventPropertyContains(exact_event_match) => {
self.match_exact_event_property_contains(exact_event_match)?
}
KnownCondition::IsUserMention => {
if let Some(uid) = user_id {
self.user_mentions.contains(uid)
@@ -270,7 +275,6 @@ impl PushRuleEvaluator {
false
}
}
KnownCondition::IsRoomMention => self.room_mention,
KnownCondition::ContainsDisplayName => {
if let Some(dn) = display_name {
if !dn.is_empty() {
@@ -345,7 +349,7 @@ impl PushRuleEvaluator {
return Ok(false);
};
let haystack = if let Some(SimpleJsonValue::Str(haystack)) =
let haystack = if let Some(JsonValue::Value(SimpleJsonValue::Str(haystack))) =
self.flattened_keys.get(&*event_match.key)
{
haystack
@@ -377,7 +381,9 @@ impl PushRuleEvaluator {
let value = &exact_event_match.value;
let haystack = if let Some(haystack) = self.flattened_keys.get(&*exact_event_match.key) {
let haystack = if let Some(JsonValue::Value(haystack)) =
self.flattened_keys.get(&*exact_event_match.key)
{
haystack
} else {
return Ok(false);
@@ -441,11 +447,12 @@ impl PushRuleEvaluator {
return Ok(false);
};
let haystack = if let Some(SimpleJsonValue::Str(haystack)) = event.get(&**key) {
haystack
} else {
return Ok(false);
};
let haystack =
if let Some(JsonValue::Value(SimpleJsonValue::Str(haystack))) = event.get(&**key) {
haystack
} else {
return Ok(false);
};
// For the content.body we match against "words", but for everything
// else we match against the entire value.
@@ -459,6 +466,29 @@ impl PushRuleEvaluator {
compiled_pattern.is_match(haystack)
}
/// Evaluates a `exact_event_property_contains` condition. (MSC3758)
fn match_exact_event_property_contains(
&self,
exact_event_match: &ExactEventMatchCondition,
) -> Result<bool, Error> {
// First check if the feature is enabled.
if !self.msc3966_exact_event_property_contains {
return Ok(false);
}
let value = &exact_event_match.value;
let haystack = if let Some(JsonValue::Array(haystack)) =
self.flattened_keys.get(&*exact_event_match.key)
{
haystack
} else {
return Ok(false);
};
Ok(haystack.contains(&**value))
}
/// Match the member count against an 'is' condition
/// The `is` condition can be things like '>2', '==3' or even just '4'.
fn match_member_count(&self, is: &str) -> Result<bool, Error> {
@@ -488,13 +518,12 @@ fn push_rule_evaluator() {
let mut flattened_keys = BTreeMap::new();
flattened_keys.insert(
"content.body".to_string(),
SimpleJsonValue::Str("foo bar bob hello".to_string()),
JsonValue::Value(SimpleJsonValue::Str("foo bar bob hello".to_string())),
);
let evaluator = PushRuleEvaluator::py_new(
flattened_keys,
false,
BTreeSet::new(),
false,
10,
Some(0),
BTreeMap::new(),
@@ -503,6 +532,7 @@ fn push_rule_evaluator() {
vec![],
true,
true,
true,
)
.unwrap();
@@ -519,14 +549,13 @@ fn test_requires_room_version_supports_condition() {
let mut flattened_keys = BTreeMap::new();
flattened_keys.insert(
"content.body".to_string(),
SimpleJsonValue::Str("foo bar bob hello".to_string()),
JsonValue::Value(SimpleJsonValue::Str("foo bar bob hello".to_string())),
);
let flags = vec![RoomVersionFeatures::ExtensibleEvents.as_str().to_string()];
let evaluator = PushRuleEvaluator::py_new(
flattened_keys,
false,
BTreeSet::new(),
false,
10,
Some(0),
BTreeMap::new(),
@@ -535,6 +564,7 @@ fn test_requires_room_version_supports_condition() {
flags,
true,
true,
true,
)
.unwrap();
+32 -14
View File
@@ -58,7 +58,7 @@ use anyhow::{Context, Error};
use log::warn;
use pyo3::exceptions::PyTypeError;
use pyo3::prelude::*;
use pyo3::types::{PyBool, PyLong, PyString};
use pyo3::types::{PyBool, PyList, PyLong, PyString};
use pythonize::{depythonize, pythonize};
use serde::de::Error as _;
use serde::{Deserialize, Serialize};
@@ -280,6 +280,35 @@ impl<'source> FromPyObject<'source> for SimpleJsonValue {
}
}
/// A JSON values (list, string, int, boolean, or null).
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(untagged)]
pub enum JsonValue {
Array(Vec<SimpleJsonValue>),
Value(SimpleJsonValue),
}
impl<'source> FromPyObject<'source> for JsonValue {
fn extract(ob: &'source PyAny) -> PyResult<Self> {
if let Ok(l) = <PyList as pyo3::PyTryFrom>::try_from(ob) {
match l.iter().map(SimpleJsonValue::extract).collect() {
Ok(a) => Ok(JsonValue::Array(a)),
Err(e) => Err(PyTypeError::new_err(format!(
"Can't convert to JsonValue::Array: {}",
e
))),
}
} else if let Ok(v) = SimpleJsonValue::extract(ob) {
Ok(JsonValue::Value(v))
} else {
Err(PyTypeError::new_err(format!(
"Can't convert from {} to JsonValue",
ob.get_type().name()?
)))
}
}
}
/// A condition used in push rules to match against an event.
///
/// We need this split as `serde` doesn't give us the ability to have a
@@ -303,10 +332,10 @@ pub enum KnownCondition {
ExactEventMatch(ExactEventMatchCondition),
#[serde(rename = "im.nheko.msc3664.related_event_match")]
RelatedEventMatch(RelatedEventMatchCondition),
#[serde(rename = "org.matrix.msc3966.exact_event_property_contains")]
ExactEventPropertyContains(ExactEventMatchCondition),
#[serde(rename = "org.matrix.msc3952.is_user_mention")]
IsUserMention,
#[serde(rename = "org.matrix.msc3952.is_room_mention")]
IsRoomMention,
ContainsDisplayName,
RoomMemberCount {
#[serde(skip_serializing_if = "Option::is_none")]
@@ -636,17 +665,6 @@ fn test_deserialize_unstable_msc3952_user_condition() {
));
}
#[test]
fn test_deserialize_unstable_msc3952_room_condition() {
let json = r#"{"kind":"org.matrix.msc3952.is_room_mention"}"#;
let condition: Condition = serde_json::from_str(json).unwrap();
assert!(matches!(
condition,
Condition::Known(KnownCondition::IsRoomMention)
));
}
#[test]
fn test_deserialize_custom_condition() {
let json = r#"{"kind":"custom_tag"}"#;
+17 -3
View File
@@ -19,7 +19,8 @@ usage() {
echo "-c"
echo " CI mode. Prints every command that the script runs."
echo "-o <path>"
echo " Directory to output full schema files to."
echo " Directory to output full schema files to. You probably want to use"
echo " '-o synapse/storage/schema'"
echo "-n <schema number>"
echo " Schema number for the new snapshot. Used to set the location of files within "
echo " the output directory, mimicking that of synapse/storage/schemas."
@@ -27,6 +28,11 @@ usage() {
echo "-h"
echo " Display this help text."
echo ""
echo ""
echo "You probably want to invoke this with something like"
echo " docker run --rm -e POSTGRES_PASSWORD=postgres -e POSTGRES_USER=postgres -e POSTGRES_DB=synapse -p 5432:5432 postgres:11-alpine"
echo " echo postgres | scripts-dev/make_full_schema.sh -p postgres -n MY_SCHEMA_NUMBER -o synapse/storage/schema"
echo ""
echo " NB: make sure to run this against the *oldest* supported version of postgres,"
echo " or else pg_dump might output non-backwards-compatible syntax."
}
@@ -189,7 +195,7 @@ python -m synapse.app.homeserver --generate-keys -c "$SQLITE_CONFIG"
# Make sure the SQLite3 database is using the latest schema and has no pending background update.
echo "Running db background jobs..."
synapse/_scripts/update_synapse_database.py --database-config "$SQLITE_CONFIG" --run-background-updates
poetry run python synapse/_scripts/update_synapse_database.py --database-config "$SQLITE_CONFIG" --run-background-updates
# Create the PostgreSQL database.
echo "Creating postgres databases..."
@@ -198,7 +204,7 @@ createdb --lc-collate=C --lc-ctype=C --template=template0 "$POSTGRES_MAIN_DB_NAM
createdb --lc-collate=C --lc-ctype=C --template=template0 "$POSTGRES_STATE_DB_NAME"
echo "Running db background jobs..."
synapse/_scripts/update_synapse_database.py --database-config "$POSTGRES_CONFIG" --run-background-updates
poetry run python synapse/_scripts/update_synapse_database.py --database-config "$POSTGRES_CONFIG" --run-background-updates
echo "Dropping unwanted db tables..."
@@ -293,4 +299,12 @@ pg_dump --format=plain --data-only --inserts --no-tablespaces --no-acl --no-owne
pg_dump --format=plain --schema-only --no-tablespaces --no-acl --no-owner "$POSTGRES_STATE_DB_NAME" | cleanup_pg_schema > "$OUTPUT_DIR/state/full_schemas/$SCHEMA_NUMBER/full.sql.postgres"
pg_dump --format=plain --data-only --inserts --no-tablespaces --no-acl --no-owner "$POSTGRES_STATE_DB_NAME" | cleanup_pg_schema >> "$OUTPUT_DIR/state/full_schemas/$SCHEMA_NUMBER/full.sql.postgres"
if [[ "$OUTPUT_DIR" == *synapse/storage/schema ]]; then
echo "Updating contrib/datagrip symlinks..."
ln -sf "../../synapse/storage/schema/common/full_schemas/$SCHEMA_NUMBER/full.sql.postgres" "contrib/datagrip/common.sql"
ln -sf "../../synapse/storage/schema/main/full_schemas/$SCHEMA_NUMBER/full.sql.postgres" "contrib/datagrip/main.sql"
ln -sf "../../synapse/storage/schema/state/full_schemas/$SCHEMA_NUMBER/full.sql.postgres" "contrib/datagrip/state.sql"
else
echo "Not updating contrib/datagrip symlinks (unknown output directory)"
fi
echo "Done! Files dumped to: $OUTPUT_DIR"
+4 -4
View File
@@ -14,7 +14,7 @@
from typing import Any, Collection, Dict, Mapping, Optional, Sequence, Set, Tuple, Union
from synapse.types import JsonDict, SimpleJsonValue
from synapse.types import JsonDict, JsonValue
class PushRule:
@property
@@ -56,18 +56,18 @@ def get_base_rule_ids() -> Collection[str]: ...
class PushRuleEvaluator:
def __init__(
self,
flattened_keys: Mapping[str, SimpleJsonValue],
flattened_keys: Mapping[str, JsonValue],
has_mentions: bool,
user_mentions: Set[str],
room_mention: bool,
room_member_count: int,
sender_power_level: Optional[int],
notification_power_levels: Mapping[str, int],
related_events_flattened: Mapping[str, Mapping[str, SimpleJsonValue]],
related_events_flattened: Mapping[str, Mapping[str, JsonValue]],
related_event_match_enabled: bool,
room_version_feature_flags: Tuple[str, ...],
msc3931_enabled: bool,
msc3758_exact_event_match: bool,
msc3966_exact_event_property_contains: bool,
): ...
def run(
self,
+4
View File
@@ -108,6 +108,10 @@ class Codes(str, Enum):
USER_AWAITING_APPROVAL = "ORG.MATRIX.MSC3866_USER_AWAITING_APPROVAL"
# Attempt to send a second annotation with the same event type & annotation key
# MSC2677
DUPLICATE_ANNOTATION = "M_DUPLICATE_ANNOTATION"
class CodeMessageException(RuntimeError):
"""An exception with integer code and message string attributes.
+14 -1
View File
@@ -17,7 +17,7 @@ import logging
import os
import sys
import tempfile
from typing import List, Optional
from typing import List, Mapping, Optional
from twisted.internet import defer, task
@@ -222,6 +222,19 @@ class FileExfiltrationWriter(ExfiltrationWriter):
with open(connection_file, "a") as f:
print(json.dumps(connection), file=f)
def write_account_data(
self, file_name: str, account_data: Mapping[str, JsonDict]
) -> None:
account_data_directory = os.path.join(
self.base_directory, "user_data", "account_data"
)
os.makedirs(account_data_directory, exist_ok=True)
account_data_file = os.path.join(account_data_directory, file_name)
with open(account_data_file, "a") as f:
print(json.dumps(account_data), file=f)
def finished(self) -> str:
return self.base_directory
+9 -3
View File
@@ -179,12 +179,18 @@ class ExperimentalConfig(Config):
"msc3783_escape_event_match_key", False
)
# MSC3952: Intentional mentions
self.msc3952_intentional_mentions = experimental.get(
"msc3952_intentional_mentions", False
# MSC3952: Intentional mentions, this depends on MSC3758.
self.msc3952_intentional_mentions = (
experimental.get("msc3952_intentional_mentions", False)
and self.msc3758_exact_event_match
)
# MSC3959: Do not generate notifications for edits.
self.msc3958_supress_edit_notifs = experimental.get(
"msc3958_supress_edit_notifs", False
)
# MSC3966: exact_event_property_contains push rule condition.
self.msc3966_exact_event_property_contains = experimental.get(
"msc3966_exact_event_property_contains", False
)
+1
View File
@@ -33,4 +33,5 @@ class RedisConfig(Config):
self.redis_host = redis_config.get("host", "localhost")
self.redis_port = redis_config.get("port", 6379)
self.redis_dbid = redis_config.get("dbid", None)
self.redis_password = redis_config.get("password")
+5 -6
View File
@@ -884,7 +884,7 @@ class FederationClient(FederationBase):
if 500 <= e.code < 600:
failover = True
elif e.code == 400 and synapse_error.errcode in failover_errcodes:
elif 400 <= e.code < 500 and synapse_error.errcode in failover_errcodes:
failover = True
elif failover_on_unknown_endpoint and self._is_unknown_endpoint(
@@ -999,14 +999,13 @@ class FederationClient(FederationBase):
return destination, ev, room_version
failover_errcodes = {Codes.NOT_FOUND}
# MSC3083 defines additional error codes for room joins. Unfortunately
# we do not yet know the room version, assume these will only be returned
# by valid room versions.
failover_errcodes = (
(Codes.UNABLE_AUTHORISE_JOIN, Codes.UNABLE_TO_GRANT_JOIN)
if membership == Membership.JOIN
else None
)
if membership == Membership.JOIN:
failover_errcodes.add(Codes.UNABLE_AUTHORISE_JOIN)
failover_errcodes.add(Codes.UNABLE_TO_GRANT_JOIN)
return await self._try_destination_list(
"make_" + membership,
+34 -15
View File
@@ -14,7 +14,7 @@
import abc
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set
from synapse.api.constants import Direction, Membership
from synapse.events import EventBase
@@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
class AdminHandler:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self._store = hs.get_datastores().main
self._device_handler = hs.get_device_handler()
self._storage_controllers = hs.get_storage_controllers()
self._state_storage_controller = self._storage_controllers.state
@@ -38,7 +38,7 @@ class AdminHandler:
async def get_whois(self, user: UserID) -> JsonDict:
connections = []
sessions = await self.store.get_user_ip_and_agents(user)
sessions = await self._store.get_user_ip_and_agents(user)
for session in sessions:
connections.append(
{
@@ -57,7 +57,7 @@ class AdminHandler:
async def get_user(self, user: UserID) -> Optional[JsonDict]:
"""Function to get user details"""
user_info_dict = await self.store.get_user_by_id(user.to_string())
user_info_dict = await self._store.get_user_by_id(user.to_string())
if user_info_dict is None:
return None
@@ -89,11 +89,11 @@ class AdminHandler:
}
# Add additional user metadata
profile = await self.store.get_profileinfo(user.localpart)
threepids = await self.store.user_get_threepids(user.to_string())
profile = await self._store.get_profileinfo(user.localpart)
threepids = await self._store.user_get_threepids(user.to_string())
external_ids = [
({"auth_provider": auth_provider, "external_id": external_id})
for auth_provider, external_id in await self.store.get_external_ids_by_user(
for auth_provider, external_id in await self._store.get_external_ids_by_user(
user.to_string()
)
]
@@ -101,7 +101,7 @@ class AdminHandler:
user_info_dict["avatar_url"] = profile.avatar_url
user_info_dict["threepids"] = threepids
user_info_dict["external_ids"] = external_ids
user_info_dict["erased"] = await self.store.is_user_erased(user.to_string())
user_info_dict["erased"] = await self._store.is_user_erased(user.to_string())
return user_info_dict
@@ -117,7 +117,7 @@ class AdminHandler:
The returned value is that returned by `writer.finished()`.
"""
# Get all rooms the user is in or has been in
rooms = await self.store.get_rooms_for_local_user_where_membership_is(
rooms = await self._store.get_rooms_for_local_user_where_membership_is(
user_id,
membership_list=(
Membership.JOIN,
@@ -131,7 +131,7 @@ class AdminHandler:
# We only try and fetch events for rooms the user has been in. If
# they've been e.g. invited to a room without joining then we handle
# those separately.
rooms_user_has_been_in = await self.store.get_rooms_user_has_been_in(user_id)
rooms_user_has_been_in = await self._store.get_rooms_user_has_been_in(user_id)
for index, room in enumerate(rooms):
room_id = room.room_id
@@ -140,7 +140,7 @@ class AdminHandler:
"[%s] Handling room %s, %d/%d", user_id, room_id, index + 1, len(rooms)
)
forgotten = await self.store.did_forget(user_id, room_id)
forgotten = await self._store.did_forget(user_id, room_id)
if forgotten:
logger.info("[%s] User forgot room %d, ignoring", user_id, room_id)
continue
@@ -152,14 +152,14 @@ class AdminHandler:
if room.membership == Membership.INVITE:
event_id = room.event_id
invite = await self.store.get_event(event_id, allow_none=True)
invite = await self._store.get_event(event_id, allow_none=True)
if invite:
invited_state = invite.unsigned["invite_room_state"]
writer.write_invite(room_id, invite, invited_state)
if room.membership == Membership.KNOCK:
event_id = room.event_id
knock = await self.store.get_event(event_id, allow_none=True)
knock = await self._store.get_event(event_id, allow_none=True)
if knock:
knock_state = knock.unsigned["knock_room_state"]
writer.write_knock(room_id, knock, knock_state)
@@ -170,7 +170,7 @@ class AdminHandler:
# were joined. We estimate that point by looking at the
# stream_ordering of the last membership if it wasn't a join.
if room.membership == Membership.JOIN:
stream_ordering = self.store.get_room_max_stream_ordering()
stream_ordering = self._store.get_room_max_stream_ordering()
else:
stream_ordering = room.stream_ordering
@@ -197,7 +197,7 @@ class AdminHandler:
# events that we have and then filtering, this isn't the most
# efficient method perhaps but it does guarantee we get everything.
while True:
events, _ = await self.store.paginate_room_events(
events, _ = await self._store.paginate_room_events(
room_id, from_key, to_key, limit=100, direction=Direction.FORWARDS
)
if not events:
@@ -263,6 +263,13 @@ class AdminHandler:
connections["devices"][""]["sessions"][0]["connections"]
)
# Get all account data the user has global and in rooms
global_data = await self._store.get_global_account_data_for_user(user_id)
by_room_data = await self._store.get_room_account_data_for_user(user_id)
writer.write_account_data("global", global_data)
for room_id in by_room_data:
writer.write_account_data(room_id, by_room_data[room_id])
return writer.finished()
@@ -340,6 +347,18 @@ class ExfiltrationWriter(metaclass=abc.ABCMeta):
"""
raise NotImplementedError()
@abc.abstractmethod
def write_account_data(
self, file_name: str, account_data: Mapping[str, JsonDict]
) -> None:
"""Write the account data of a user.
Args:
file_name: file name to write data
account_data: mapping of global or room account_data
"""
raise NotImplementedError()
@abc.abstractmethod
def finished(self) -> Any:
"""Called when all data has successfully been exported and written.
+1 -1
View File
@@ -201,7 +201,7 @@ class AuthHandler:
for auth_checker_class in INTERACTIVE_AUTH_CHECKERS:
inst = auth_checker_class(hs)
if inst.is_enabled():
self.checkers[inst.AUTH_TYPE] = inst # type: ignore
self.checkers[inst.AUTH_TYPE] = inst
self.bcrypt_rounds = hs.config.registration.bcrypt_rounds
+20 -1
View File
@@ -952,7 +952,20 @@ class FederationHandler:
#
# Note that this requires the /send_join request to come back to the
# same server.
prev_event_ids = None
if room_version.msc3083_join_rules:
# Note that the room's state can change out from under us and render our
# nice join rules-conformant event non-conformant by the time we build the
# event. When this happens, our validation at the end fails and we respond
# to the requesting server with a 403, which is misleading — it indicates
# that the user is not allowed to join the room and the joining server
# should not bother retrying via this homeserver or any others, when
# in fact we've just messed up with building the event.
#
# To reduce the likelihood of this race, we capture the forward extremities
# of the room (prev_event_ids) just before fetching the current state, and
# hope that the state we fetch corresponds to the prev events we chose.
prev_event_ids = await self.store.get_prev_events_for_room(room_id)
state_ids = await self._state_storage_controller.get_current_state_ids(
room_id
)
@@ -994,7 +1007,8 @@ class FederationHandler:
event,
unpersisted_context,
) = await self.event_creation_handler.create_new_client_event(
builder=builder
builder=builder,
prev_event_ids=prev_event_ids,
)
except SynapseError as e:
logger.warning("Failed to create join to %s because %s", room_id, e)
@@ -1880,6 +1894,11 @@ class FederationHandler:
logger.info("Updating current state for %s", room_id)
# TODO(faster_joins): notify workers in notify_room_un_partial_stated
# https://github.com/matrix-org/synapse/issues/12994
#
# NB: there's a potential race here. If room is purged just before we
# call this, we _might_ end up inserting rows into current_state_events.
# (The logic is hard to chase through.) We think this is fine, but if
# not the HS admin should purge the room again.
await self.state_handler.update_current_state(room_id)
logger.info("Handling any pending device list updates")
+5 -1
View File
@@ -1337,7 +1337,11 @@ class EventCreationHandler:
relation.parent_id, event.type, aggregation_key, event.sender
)
if already_exists:
raise SynapseError(400, "Can't send same reaction twice")
raise SynapseError(
400,
"Can't send same reaction twice",
errcode=Codes.DUPLICATE_ANNOTATION,
)
# Don't attempt to start a thread if the parent event is a relation.
elif relation.rel_type == RelationTypes.THREAD:
+13
View File
@@ -269,6 +269,8 @@ class SyncHandler:
self._state_storage_controller = self._storage_controllers.state
self._device_handler = hs.get_device_handler()
self.should_calculate_push_rules = hs.config.push.enable_push
# TODO: flush cache entries on subsequent sync request.
# Once we get the next /sync request (ie, one with the same access token
# that sets 'since' to 'next_batch'), we know that device won't need a
@@ -1288,6 +1290,12 @@ class SyncHandler:
async def unread_notifs_for_room_id(
self, room_id: str, sync_config: SyncConfig
) -> RoomNotifCounts:
if not self.should_calculate_push_rules:
# If push rules have been universally disabled then we know we won't
# have any unread counts in the DB, so we may as well skip asking
# the DB.
return RoomNotifCounts.empty()
with Measure(self.clock, "unread_notifs_for_room_id"):
return await self.store.get_unread_event_push_actions_by_room_for_user(
@@ -1391,6 +1399,11 @@ class SyncHandler:
for room_id, is_partial_state in results.items()
if is_partial_state
)
membership_change_events = [
event
for event in membership_change_events
if not results.get(event.room_id, False)
]
# Incremental eager syncs should additionally include rooms that
# - we are joined to
+14 -4
View File
@@ -13,7 +13,8 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Any
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, ClassVar, Sequence, Type
from twisted.web.client import PartialDownloadError
@@ -27,19 +28,28 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
class UserInteractiveAuthChecker:
class UserInteractiveAuthChecker(ABC):
"""Abstract base class for an interactive auth checker"""
def __init__(self, hs: "HomeServer"):
# This should really be an "abstract class property", i.e. it should
# be an error to instantiate a subclass that doesn't specify an AUTH_TYPE.
# But calling this a `ClassVar` is simpler than a decorator stack of
# @property @abstractmethod and @classmethod (if that's even the right order).
AUTH_TYPE: ClassVar[str]
def __init__(self, hs: "HomeServer"): # noqa: B027
pass
@abstractmethod
def is_enabled(self) -> bool:
"""Check if the configuration of the homeserver allows this checker to work
Returns:
True if this login type is enabled.
"""
raise NotImplementedError()
@abstractmethod
async def check_auth(self, authdict: dict, clientip: str) -> Any:
"""Given the authentication dict from the client, attempt to check this step
@@ -304,7 +314,7 @@ class RegistrationTokenAuthChecker(UserInteractiveAuthChecker):
)
INTERACTIVE_AUTH_CHECKERS = [
INTERACTIVE_AUTH_CHECKERS: Sequence[Type[UserInteractiveAuthChecker]] = [
DummyAuthChecker,
TermsAuthChecker,
RecaptchaAuthChecker,
+1 -1
View File
@@ -1267,7 +1267,7 @@ class MatrixFederationHttpClient:
def _flatten_response_never_received(e: BaseException) -> str:
if hasattr(e, "reasons"):
reasons = ", ".join(
_flatten_response_never_received(f.value) for f in e.reasons # type: ignore[attr-defined]
_flatten_response_never_received(f.value) for f in e.reasons
)
return "%s:[%s]" % (type(e).__name__, reasons)
+17 -7
View File
@@ -188,7 +188,7 @@ from typing import (
)
import attr
from typing_extensions import ParamSpec
from typing_extensions import Concatenate, ParamSpec
from twisted.internet import defer
from twisted.web.http import Request
@@ -445,7 +445,7 @@ def init_tracer(hs: "HomeServer") -> None:
opentracing = None # type: ignore[assignment]
return
if not opentracing or not JaegerConfig:
if opentracing is None or JaegerConfig is None:
raise ConfigError(
"The server has been configured to use opentracing but opentracing is not "
"installed."
@@ -872,7 +872,7 @@ def extract_text_map(carrier: Dict[str, str]) -> Optional["opentracing.SpanConte
def _custom_sync_async_decorator(
func: Callable[P, R],
wrapping_logic: Callable[[Callable[P, R], Any, Any], ContextManager[None]],
wrapping_logic: Callable[Concatenate[Callable[P, R], P], ContextManager[None]],
) -> Callable[P, R]:
"""
Decorates a function that is sync or async (coroutines), or that returns a Twisted
@@ -902,10 +902,14 @@ def _custom_sync_async_decorator(
"""
if inspect.iscoroutinefunction(func):
# In this branch, R = Awaitable[RInner], for some other type RInner
@wraps(func)
async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
async def _wrapper(
*args: P.args, **kwargs: P.kwargs
) -> Any: # Return type is RInner
with wrapping_logic(func, *args, **kwargs):
# type-ignore: func() returns R, but mypy doesn't know that R is
# Awaitable here.
return await func(*args, **kwargs) # type: ignore[misc]
else:
@@ -972,7 +976,11 @@ def trace_with_opname(
if not opentracing:
return func
return _custom_sync_async_decorator(func, _wrapping_logic)
# type-ignore: mypy seems to be confused by the ParamSpecs here.
# I think the problem is https://github.com/python/mypy/issues/12909
return _custom_sync_async_decorator(
func, _wrapping_logic # type: ignore[arg-type]
)
return _decorator
@@ -1018,7 +1026,9 @@ def tag_args(func: Callable[P, R]) -> Callable[P, R]:
set_tag(SynapseTags.FUNC_KWARGS, str(kwargs))
yield
return _custom_sync_async_decorator(func, _wrapping_logic)
# type-ignore: mypy seems to be confused by the ParamSpecs here.
# I think the problem is https://github.com/python/mypy/issues/12909
return _custom_sync_async_decorator(func, _wrapping_logic) # type: ignore[arg-type]
@contextlib.contextmanager
+14 -11
View File
@@ -44,7 +44,7 @@ from synapse.events.snapshot import EventContext
from synapse.state import POWER_KEY
from synapse.storage.databases.main.roommember import EventIdMembership
from synapse.synapse_rust.push import FilteredPushRules, PushRuleEvaluator
from synapse.types import SimpleJsonValue
from synapse.types import JsonValue
from synapse.types.state import StateFilter
from synapse.util.caches import register_cache
from synapse.util.metrics import measure_func
@@ -259,13 +259,13 @@ class BulkPushRuleEvaluator:
async def _related_events(
self, event: EventBase
) -> Dict[str, Dict[str, SimpleJsonValue]]:
) -> Dict[str, Dict[str, JsonValue]]:
"""Fetches the related events for 'event'. Sets the im.vector.is_falling_back key if the event is from a fallback relation
Returns:
Mapping of relation type to flattened events.
"""
related_events: Dict[str, Dict[str, SimpleJsonValue]] = {}
related_events: Dict[str, Dict[str, JsonValue]] = {}
if self._related_event_match_enabled:
related_event_id = event.content.get("m.relates_to", {}).get("event_id")
relation_type = event.content.get("m.relates_to", {}).get("rel_type")
@@ -400,7 +400,6 @@ class BulkPushRuleEvaluator:
mentions = event.content.get(EventContentFields.MSC3952_MENTIONS)
has_mentions = self._intentional_mentions_enabled and isinstance(mentions, dict)
user_mentions: Set[str] = set()
room_mention = False
if has_mentions:
# mypy seems to have lost the type even though it must be a dict here.
assert isinstance(mentions, dict)
@@ -410,8 +409,6 @@ class BulkPushRuleEvaluator:
user_mentions = set(
filter(lambda item: isinstance(item, str), user_mentions_raw)
)
# Room mention is only true if the value is exactly true.
room_mention = mentions.get("room") is True
evaluator = PushRuleEvaluator(
_flatten_dict(
@@ -420,7 +417,6 @@ class BulkPushRuleEvaluator:
),
has_mentions,
user_mentions,
room_mention,
room_member_count,
sender_power_level,
notification_levels,
@@ -429,6 +425,7 @@ class BulkPushRuleEvaluator:
event.room_version.msc3931_push_features,
self.hs.config.experimental.msc1767_enabled, # MSC3931 flag
self.hs.config.experimental.msc3758_exact_event_match,
self.hs.config.experimental.msc3966_exact_event_property_contains,
)
users = rules_by_user.keys()
@@ -502,18 +499,22 @@ RulesByUser = Dict[str, List[Rule]]
StateGroup = Union[object, int]
def _is_simple_value(value: Any) -> bool:
return isinstance(value, (bool, str)) or type(value) is int or value is None
def _flatten_dict(
d: Union[EventBase, Mapping[str, Any]],
prefix: Optional[List[str]] = None,
result: Optional[Dict[str, SimpleJsonValue]] = None,
result: Optional[Dict[str, JsonValue]] = None,
*,
msc3783_escape_event_match_key: bool = False,
) -> Dict[str, SimpleJsonValue]:
) -> Dict[str, JsonValue]:
"""
Given a JSON dictionary (or event) which might contain sub dictionaries,
flatten it into a single layer dictionary by combining the keys & sub-keys.
String, integer, boolean, and null values are kept. All others are dropped.
String, integer, boolean, null or lists of those values are kept. All others are dropped.
Transforms:
@@ -542,8 +543,10 @@ def _flatten_dict(
# nested fields.
key = key.replace("\\", "\\\\").replace(".", "\\.")
if isinstance(value, (bool, str)) or type(value) is int or value is None:
if _is_simple_value(value):
result[".".join(prefix + [key])] = value
elif isinstance(value, (list, tuple)):
result[".".join(prefix + [key])] = [v for v in value if _is_simple_value(v)]
elif isinstance(value, Mapping):
# do not set `room_version` due to recursion considerations below
_flatten_dict(
+10 -1
View File
@@ -16,7 +16,7 @@ import logging
from http import HTTPStatus
from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import Codes, SynapseError
from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.site import SynapseRequest
@@ -39,6 +39,7 @@ class ReportEventRestServlet(RestServlet):
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.store = hs.get_datastores().main
self._event_handler = self.hs.get_event_handler()
async def on_POST(
self, request: SynapseRequest, room_id: str, event_id: str
@@ -61,6 +62,14 @@ class ReportEventRestServlet(RestServlet):
Codes.BAD_JSON,
)
event = await self._event_handler.get_event(
requester.user, room_id, event_id, show_redacted=False
)
if event is None:
raise NotFoundError(
"Unable to report event: it does not exist or you aren't able to see it."
)
await self.store.add_event_report(
room_id=room_id,
event_id=event_id,
+6 -3
View File
@@ -16,6 +16,7 @@
import logging
import os
import urllib
from abc import ABC, abstractmethod
from types import TracebackType
from typing import Awaitable, Dict, Generator, List, Optional, Tuple, Type
@@ -284,13 +285,14 @@ async def respond_with_responder(
finish_request(request)
class Responder:
class Responder(ABC):
"""Represents a response that can be streamed to the requester.
Responder is a context manager which *must* be used, so that any resources
held can be cleaned up.
"""
@abstractmethod
def write_to_consumer(self, consumer: IConsumer) -> Awaitable:
"""Stream response into consumer
@@ -300,11 +302,12 @@ class Responder:
Returns:
Resolves once the response has finished being written
"""
raise NotImplementedError()
def __enter__(self) -> None:
def __enter__(self) -> None: # noqa: B027
pass
def __exit__(
def __exit__( # noqa: B027
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
+1
View File
@@ -827,6 +827,7 @@ class HomeServer(metaclass=abc.ABCMeta):
hs=self,
host=self.config.redis.redis_host,
port=self.config.redis.redis_port,
dbid=self.config.redis.redis_dbid,
password=self.config.redis.redis_password,
reconnect=True,
)
@@ -262,7 +262,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
for batch in batch_iter(signature_query, 50):
cross_sigs_result = await self.db_pool.runInteraction(
"get_e2e_cross_signing_signatures",
"get_e2e_cross_signing_signatures_for_devices",
self._get_e2e_cross_signing_signatures_for_devices_txn,
batch,
)
@@ -203,11 +203,18 @@ class RoomNotifCounts:
# Map of thread ID to the notification counts.
threads: Dict[str, NotifCounts]
@staticmethod
def empty() -> "RoomNotifCounts":
return _EMPTY_ROOM_NOTIF_COUNTS
def __len__(self) -> int:
# To properly account for the amount of space in any caches.
return len(self.threads) + 1
_EMPTY_ROOM_NOTIF_COUNTS = RoomNotifCounts(NotifCounts(), {})
def _serialize_action(
actions: Collection[Union[Mapping, str]], is_highlight: bool
) -> str:
@@ -420,12 +420,14 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
"event_push_actions",
"event_search",
"event_failed_pull_attempts",
# Note: the partial state tables have foreign keys between each other, and to
# `events` and `rooms`. We need to delete from them in the right order.
"partial_state_events",
"partial_state_rooms_servers",
"partial_state_rooms",
"events",
"federation_inbound_events_staging",
"local_current_membership",
"partial_state_rooms_servers",
"partial_state_rooms",
"receipts_graph",
"receipts_linearized",
"room_aliases",
+2 -2
View File
@@ -25,7 +25,7 @@ try:
except ImportError:
class PostgresEngine(BaseDatabaseEngine): # type: ignore[no-redef]
def __new__(cls, *args: object, **kwargs: object) -> NoReturn: # type: ignore[misc]
def __new__(cls, *args: object, **kwargs: object) -> NoReturn:
raise RuntimeError(
f"Cannot create {cls.__name__} -- psycopg2 module is not installed"
)
@@ -36,7 +36,7 @@ try:
except ImportError:
class Sqlite3Engine(BaseDatabaseEngine): # type: ignore[no-redef]
def __new__(cls, *args: object, **kwargs: object) -> NoReturn: # type: ignore[misc]
def __new__(cls, *args: object, **kwargs: object) -> NoReturn:
raise RuntimeError(
f"Cannot create {cls.__name__} -- sqlite3 module is not installed"
)
+62 -12
View File
@@ -12,7 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from types import TracebackType
from typing import Any, Iterator, List, Mapping, Optional, Sequence, Tuple, Type, Union
from typing import (
Any,
Callable,
Iterator,
List,
Mapping,
Optional,
Sequence,
Tuple,
Type,
Union,
)
from typing_extensions import Protocol
@@ -112,15 +123,35 @@ class DBAPI2Module(Protocol):
# extends from this hierarchy. See
# https://docs.python.org/3/library/sqlite3.html?highlight=sqlite3#exceptions
# https://www.postgresql.org/docs/current/errcodes-appendix.html#ERRCODES-TABLE
Warning: Type[Exception]
Error: Type[Exception]
#
# Note: rather than
# x: T
# we write
# @property
# def x(self) -> T: ...
# which expresses that the protocol attribute `x` is read-only. The mypy docs
# https://mypy.readthedocs.io/en/latest/common_issues.html#covariant-subtyping-of-mutable-protocol-members-is-rejected
# explain why this is necessary for safety. TL;DR: we shouldn't be able to write
# to `x`, only read from it. See also https://github.com/python/mypy/issues/6002 .
@property
def Warning(self) -> Type[Exception]:
...
@property
def Error(self) -> Type[Exception]:
...
# Errors are divided into `InterfaceError`s (something went wrong in the database
# driver) and `DatabaseError`s (something went wrong in the database). These are
# both subclasses of `Error`, but we can't currently express this in type
# annotations due to https://github.com/python/mypy/issues/8397
InterfaceError: Type[Exception]
DatabaseError: Type[Exception]
@property
def InterfaceError(self) -> Type[Exception]:
...
@property
def DatabaseError(self) -> Type[Exception]:
...
# Everything below is a subclass of `DatabaseError`.
@@ -128,7 +159,9 @@ class DBAPI2Module(Protocol):
# - An integer was too big for its data type.
# - An invalid date time was provided.
# - A string contained a null code point.
DataError: Type[Exception]
@property
def DataError(self) -> Type[Exception]:
...
# Roughly: something went wrong in the database, but it's not within the application
# programmer's control. Examples:
@@ -138,28 +171,45 @@ class DBAPI2Module(Protocol):
# - A serialisation failure occurred.
# - The database ran out of resources, such as storage, memory, connections, etc.
# - The database encountered an error from the operating system.
OperationalError: Type[Exception]
@property
def OperationalError(self) -> Type[Exception]:
...
# Roughly: we've given the database data which breaks a rule we asked it to enforce.
# Examples:
# - Stop, criminal scum! You violated the foreign key constraint
# - Also check constraints, non-null constraints, etc.
IntegrityError: Type[Exception]
@property
def IntegrityError(self) -> Type[Exception]:
...
# Roughly: something went wrong within the database server itself.
InternalError: Type[Exception]
@property
def InternalError(self) -> Type[Exception]:
...
# Roughly: the application did something silly that needs to be fixed. Examples:
# - We don't have permissions to do something.
# - We tried to create a table with duplicate column names.
# - We tried to use a reserved name.
# - We referred to a column that doesn't exist.
ProgrammingError: Type[Exception]
@property
def ProgrammingError(self) -> Type[Exception]:
...
# Roughly: we've tried to do something that this database doesn't support.
NotSupportedError: Type[Exception]
@property
def NotSupportedError(self) -> Type[Exception]:
...
def connect(self, **parameters: object) -> Connection:
# We originally wrote
# def connect(self, *args, **kwargs) -> Connection: ...
# But mypy doesn't seem to like that because sqlite3.connect takes a mandatory
# positional argument. We can't make that part of the signature though, because
# psycopg2.connect doesn't have a mandatory positional argument. Instead, we use
# the following slightly unusual workaround.
@property
def connect(self) -> Callable[..., Connection]:
...
+4 -3
View File
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import Generic, List, Optional, Tuple, TypeVar
from synapse.types import StrCollection, UserID
@@ -22,7 +22,8 @@ K = TypeVar("K")
R = TypeVar("R")
class EventSource(Generic[K, R]):
class EventSource(ABC, Generic[K, R]):
@abstractmethod
async def get_new_events(
self,
user: UserID,
@@ -32,4 +33,4 @@ class EventSource(Generic[K, R]):
is_guest: bool,
explicit_room_id: Optional[str] = None,
) -> Tuple[List[R], K]:
...
raise NotImplementedError()
+1
View File
@@ -71,6 +71,7 @@ MutableStateMap = MutableMapping[StateKey, T]
# JSON types. These could be made stronger, but will do for now.
# A "simple" (canonical) JSON value.
SimpleJsonValue = Optional[Union[str, int, bool]]
JsonValue = Union[List[SimpleJsonValue], Tuple[SimpleJsonValue, ...], SimpleJsonValue]
# A JSON-serialisable dict.
JsonDict = Dict[str, Any]
# A JSON-serialisable mapping; roughly speaking an immutable JSONDict.
+6 -2
View File
@@ -67,7 +67,9 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase):
}
# Listen with the config
self.hs._listen_http(parse_listener_def(0, config))
hs = self.hs
assert isinstance(hs, GenericWorkerServer)
hs._listen_http(parse_listener_def(0, config))
# Grab the resource from the site that was told to listen
site = self.reactor.tcpServers[0][1]
@@ -115,7 +117,9 @@ class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase):
}
# Listen with the config
self.hs._listener_http(self.hs.config, parse_listener_def(0, config))
hs = self.hs
assert isinstance(hs, SynapseHomeServer)
hs._listener_http(self.hs.config, parse_listener_def(0, config))
# Grab the resource from the site that was told to listen
site = self.reactor.tcpServers[0][1]
+2 -4
View File
@@ -11,12 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, List, Optional, Sequence, Tuple, cast
from typing import List, Optional, Sequence, Tuple, cast
from unittest.mock import Mock
from typing_extensions import TypeAlias
from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor
from synapse.appservice import (
ApplicationService,
@@ -40,9 +41,6 @@ from tests.test_utils import simple_async_mock
from ..utils import MockClock
if TYPE_CHECKING:
from twisted.internet.testing import MemoryReactor
class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
def setUp(self) -> None:
+11 -11
View File
@@ -192,7 +192,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
key1 = signedjson.key.generate_signing_key("1")
r = self.hs.get_datastores().main.store_server_verify_keys(
"server9",
time.time() * 1000,
int(time.time() * 1000),
[("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), 1000))],
)
self.get_success(r)
@@ -287,7 +287,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
key1 = signedjson.key.generate_signing_key("1")
r = self.hs.get_datastores().main.store_server_verify_keys(
"server9",
time.time() * 1000,
int(time.time() * 1000),
# None is not a valid value in FetchKeyResult, but we're abusing this
# API to insert null values into the database. The nulls get converted
# to 0 when fetched in KeyStore.get_server_verify_keys.
@@ -466,9 +466,9 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
key_json = self.get_success(
self.hs.get_datastores().main.get_server_keys_json([lookup_triplet])
)
res = key_json[lookup_triplet]
self.assertEqual(len(res), 1)
res = res[0]
res_keys = key_json[lookup_triplet]
self.assertEqual(len(res_keys), 1)
res = res_keys[0]
self.assertEqual(res["key_id"], testverifykey_id)
self.assertEqual(res["from_server"], SERVER_NAME)
self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
@@ -584,9 +584,9 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
key_json = self.get_success(
self.hs.get_datastores().main.get_server_keys_json([lookup_triplet])
)
res = key_json[lookup_triplet]
self.assertEqual(len(res), 1)
res = res[0]
res_keys = key_json[lookup_triplet]
self.assertEqual(len(res_keys), 1)
res = res_keys[0]
self.assertEqual(res["key_id"], testverifykey_id)
self.assertEqual(res["from_server"], self.mock_perspective_server.server_name)
self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
@@ -705,9 +705,9 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
key_json = self.get_success(
self.hs.get_datastores().main.get_server_keys_json([lookup_triplet])
)
res = key_json[lookup_triplet]
self.assertEqual(len(res), 1)
res = res[0]
res_keys = key_json[lookup_triplet]
self.assertEqual(len(res_keys), 1)
res = res_keys[0]
self.assertEqual(res["key_id"], testverifykey_id)
self.assertEqual(res["from_server"], self.mock_perspective_server.server_name)
self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
+5 -7
View File
@@ -156,11 +156,11 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
# Mock out the calls over federation.
fed_transport_client = Mock(spec=["send_transaction"])
fed_transport_client.send_transaction = simple_async_mock({})
self.fed_transport_client = Mock(spec=["send_transaction"])
self.fed_transport_client.send_transaction = simple_async_mock({})
hs = self.setup_test_homeserver(
federation_transport_client=fed_transport_client,
federation_transport_client=self.fed_transport_client,
)
load_legacy_presence_router(hs)
@@ -422,7 +422,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
#
# Thus we reset the mock, and try sending all online local user
# presence again
self.hs.get_federation_transport_client().send_transaction.reset_mock()
self.fed_transport_client.send_transaction.reset_mock()
# Broadcast local user online presence
self.get_success(
@@ -447,9 +447,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
}
found_users = set()
calls = (
self.hs.get_federation_transport_client().send_transaction.call_args_list
)
calls = self.fed_transport_client.send_transaction.call_args_list
for call in calls:
call_args = call[0]
federation_transaction: Transaction = call_args[0]
+20 -15
View File
@@ -17,7 +17,7 @@ from unittest.mock import Mock
from synapse.api.errors import Codes, SynapseError
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.types import JsonDict, UserID
from synapse.types import JsonDict, UserID, create_requester
from tests import unittest
from tests.test_utils import make_awaitable
@@ -56,7 +56,11 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
# Artificially raise the complexity
store = self.hs.get_datastores().main
store.get_current_state_event_counts = lambda x: make_awaitable(500 * 1.23)
async def get_current_state_event_counts(room_id: str) -> int:
return int(500 * 1.23)
store.get_current_state_event_counts = get_current_state_event_counts # type: ignore[assignment]
# Get the room complexity again -- make sure it's our artificial value
channel = self.make_signed_federation_request(
@@ -75,12 +79,12 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
handler.federation_handler.do_invite_join = Mock(
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
return_value=make_awaitable(("", 1))
)
d = handler._remote_join(
None,
create_requester(u1),
["other.example.com"],
"roomid",
UserID.from_string(u1),
@@ -106,12 +110,12 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
handler.federation_handler.do_invite_join = Mock(
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
return_value=make_awaitable(("", 1))
)
d = handler._remote_join(
None,
create_requester(u1),
["other.example.com"],
"roomid",
UserID.from_string(u1),
@@ -144,17 +148,18 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=make_awaitable(None))
handler.federation_handler.do_invite_join = Mock(
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
return_value=make_awaitable(("", 1))
)
# Artificially raise the complexity
self.hs.get_datastores().main.get_current_state_event_counts = (
lambda x: make_awaitable(600)
)
async def get_current_state_event_counts(room_id: str) -> int:
return 600
self.hs.get_datastores().main.get_current_state_event_counts = get_current_state_event_counts # type: ignore[assignment]
d = handler._remote_join(
None,
create_requester(u1),
["other.example.com"],
room_1,
UserID.from_string(u1),
@@ -200,12 +205,12 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
handler.federation_handler.do_invite_join = Mock(
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
return_value=make_awaitable(("", 1))
)
d = handler._remote_join(
None,
create_requester(u1),
["other.example.com"],
"roomid",
UserID.from_string(u1),
@@ -230,12 +235,12 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
handler.federation_handler.do_invite_join = Mock(
handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
return_value=make_awaitable(("", 1))
)
d = handler._remote_join(
None,
create_requester(u1),
["other.example.com"],
"roomid",
UserID.from_string(u1),
+22 -10
View File
@@ -5,7 +5,11 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.federation.sender import PerDestinationQueue, TransactionManager
from synapse.federation.sender import (
FederationSender,
PerDestinationQueue,
TransactionManager,
)
from synapse.federation.units import Edu, Transaction
from synapse.rest import admin
from synapse.rest.client import login, room
@@ -33,8 +37,9 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.federation_transport_client = Mock(spec=["send_transaction"])
return self.setup_test_homeserver(
federation_transport_client=Mock(spec=["send_transaction"]),
federation_transport_client=self.federation_transport_client,
)
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
@@ -52,10 +57,14 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
self.pdus: List[JsonDict] = []
self.failed_pdus: List[JsonDict] = []
self.is_online = True
self.hs.get_federation_transport_client().send_transaction.side_effect = (
self.federation_transport_client.send_transaction.side_effect = (
self.record_transaction
)
federation_sender = hs.get_federation_sender()
assert isinstance(federation_sender, FederationSender)
self.federation_sender = federation_sender
def default_config(self) -> JsonDict:
config = super().default_config()
config["federation_sender_instances"] = None
@@ -229,11 +238,11 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
# let's delete the federation transmission queue
# (this pretends we are starting up fresh.)
self.assertFalse(
self.hs.get_federation_sender()
._per_destination_queues["host2"]
.transmission_loop_running
self.federation_sender._per_destination_queues[
"host2"
].transmission_loop_running
)
del self.hs.get_federation_sender()._per_destination_queues["host2"]
del self.federation_sender._per_destination_queues["host2"]
# let's also clear any backoffs
self.get_success(
@@ -322,6 +331,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
# also fetch event 5 so we know its last_successful_stream_ordering later
event_5 = self.get_success(self.hs.get_datastores().main.get_event(event_id_5))
assert event_2.internal_metadata.stream_ordering is not None
self.get_success(
self.hs.get_datastores().main.set_destination_last_successful_stream_ordering(
"host2", event_2.internal_metadata.stream_ordering
@@ -425,15 +435,16 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
def wake_destination_track(destination: str) -> None:
woken.append(destination)
self.hs.get_federation_sender().wake_destination = wake_destination_track
self.federation_sender.wake_destination = wake_destination_track # type: ignore[assignment]
# cancel the pre-existing timer for _wake_destinations_needing_catchup
# this is because we are calling it manually rather than waiting for it
# to be called automatically
self.hs.get_federation_sender()._catchup_after_startup_timer.cancel()
assert self.federation_sender._catchup_after_startup_timer is not None
self.federation_sender._catchup_after_startup_timer.cancel()
self.get_success(
self.hs.get_federation_sender()._wake_destinations_needing_catchup(), by=5.0
self.federation_sender._wake_destinations_needing_catchup(), by=5.0
)
# ASSERT (_wake_destinations_needing_catchup):
@@ -475,6 +486,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
)
)
assert event_1.internal_metadata.stream_ordering is not None
self.get_success(
self.hs.get_datastores().main.set_destination_last_successful_stream_ordering(
"host2", event_1.internal_metadata.stream_ordering
+2 -2
View File
@@ -178,7 +178,7 @@ class FederationClientTest(FederatingHomeserverTestCase):
RoomVersions.V9,
)
)
self.assertIsNotNone(pulled_pdu_info2)
assert pulled_pdu_info2 is not None
remote_pdu2 = pulled_pdu_info2.pdu
# Sanity check that we are working against the same event
@@ -226,7 +226,7 @@ class FederationClientTest(FederatingHomeserverTestCase):
RoomVersions.V9,
)
)
self.assertIsNotNone(pulled_pdu_info)
assert pulled_pdu_info is not None
remote_pdu = pulled_pdu_info.pdu
# check the right call got made to the agent
+24 -31
View File
@@ -22,6 +22,7 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EduTypes, RoomEncryptionAlgorithms
from synapse.federation.units import Transaction
from synapse.handlers.device import DeviceHandler
from synapse.rest import admin
from synapse.rest.client import login
from synapse.server import HomeServer
@@ -41,8 +42,9 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
"""
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.federation_transport_client = Mock(spec=["send_transaction"])
hs = self.setup_test_homeserver(
federation_transport_client=Mock(spec=["send_transaction"]),
federation_transport_client=self.federation_transport_client,
)
hs.get_storage_controllers().state.get_current_hosts_in_room = Mock( # type: ignore[assignment]
@@ -61,9 +63,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
return config
def test_send_receipts(self) -> None:
mock_send_transaction = (
self.hs.get_federation_transport_client().send_transaction
)
mock_send_transaction = self.federation_transport_client.send_transaction
mock_send_transaction.return_value = make_awaitable({})
sender = self.hs.get_federation_sender()
@@ -103,9 +103,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
)
def test_send_receipts_thread(self) -> None:
mock_send_transaction = (
self.hs.get_federation_transport_client().send_transaction
)
mock_send_transaction = self.federation_transport_client.send_transaction
mock_send_transaction.return_value = make_awaitable({})
# Create receipts for:
@@ -181,9 +179,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
def test_send_receipts_with_backoff(self) -> None:
"""Send two receipts in quick succession; the second should be flushed, but
only after 20ms"""
mock_send_transaction = (
self.hs.get_federation_transport_client().send_transaction
)
mock_send_transaction = self.federation_transport_client.send_transaction
mock_send_transaction.return_value = make_awaitable({})
sender = self.hs.get_federation_sender()
@@ -277,10 +273,11 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.federation_transport_client = Mock(
spec=["send_transaction", "query_user_devices"]
)
return self.setup_test_homeserver(
federation_transport_client=Mock(
spec=["send_transaction", "query_user_devices"]
),
federation_transport_client=self.federation_transport_client,
)
def default_config(self) -> JsonDict:
@@ -310,9 +307,13 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
hs.get_datastores().main.get_current_hosts_in_room = get_current_hosts_in_room # type: ignore[assignment]
device_handler = hs.get_device_handler()
assert isinstance(device_handler, DeviceHandler)
self.device_handler = device_handler
# whenever send_transaction is called, record the edu data
self.edus: List[JsonDict] = []
self.hs.get_federation_transport_client().send_transaction.side_effect = (
self.federation_transport_client.send_transaction.side_effect = (
self.record_transaction
)
@@ -353,7 +354,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
# Send the server a device list EDU for the other user, this will cause
# it to try and resync the device lists.
self.hs.get_federation_transport_client().query_user_devices.return_value = (
self.federation_transport_client.query_user_devices.return_value = (
make_awaitable(
{
"stream_id": "1",
@@ -364,7 +365,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
)
self.get_success(
self.hs.get_device_handler().device_list_updater.incoming_device_list_update(
self.device_handler.device_list_updater.incoming_device_list_update(
"host2",
{
"user_id": "@user2:host2",
@@ -507,9 +508,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D3", stream_id)
# delete them again
self.get_success(
self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
)
self.get_success(self.device_handler.delete_devices(u1, ["D1", "D2", "D3"]))
# We queue up device list updates to be sent over federation, so we
# advance to clear the queue.
@@ -533,7 +532,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
"""If the destination server is unreachable, all the updates should get sent on
recovery
"""
mock_send_txn = self.hs.get_federation_transport_client().send_transaction
mock_send_txn = self.federation_transport_client.send_transaction
mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail"))
# create devices
@@ -543,9 +542,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
self.login("user", "pass", device_id="D3")
# delete them again
self.get_success(
self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
)
self.get_success(self.device_handler.delete_devices(u1, ["D1", "D2", "D3"]))
# We queue up device list updates to be sent over federation, so we
# advance to clear the queue.
@@ -580,7 +577,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
This case tests the behaviour when the server has never been reachable.
"""
mock_send_txn = self.hs.get_federation_transport_client().send_transaction
mock_send_txn = self.federation_transport_client.send_transaction
mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail"))
# create devices
@@ -590,9 +587,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
self.login("user", "pass", device_id="D3")
# delete them again
self.get_success(
self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
)
self.get_success(self.device_handler.delete_devices(u1, ["D1", "D2", "D3"]))
# We queue up device list updates to be sent over federation, so we
# advance to clear the queue.
@@ -640,7 +635,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
self.check_device_update_edu(self.edus.pop(0), u1, "D1", None)
# now the server goes offline
mock_send_txn = self.hs.get_federation_transport_client().send_transaction
mock_send_txn = self.federation_transport_client.send_transaction
mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail"))
self.login("user", "pass", device_id="D2")
@@ -651,9 +646,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
self.reactor.advance(1)
# delete them again
self.get_success(
self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
)
self.get_success(self.device_handler.delete_devices(u1, ["D1", "D2", "D3"]))
self.assertGreaterEqual(mock_send_txn.call_count, 3)
+27
View File
@@ -296,3 +296,30 @@ class ExfiltrateData(unittest.HomeserverTestCase):
self.assertEqual(args[0][0]["user_agent"], "user_agent")
self.assertGreater(args[0][0]["last_seen"], 0)
self.assertNotIn("access_token", args[0][0])
def test_account_data(self) -> None:
"""Tests that user account data get exported."""
# add account data
self.get_success(
self._store.add_account_data_for_user(self.user2, "m.global", {"a": 1})
)
self.get_success(
self._store.add_account_data_to_room(
self.user2, "test_room", "m.per_room", {"b": 2}
)
)
writer = Mock()
self.get_success(self.admin_handler.export_user_data(self.user2, writer))
# two calls, one call for user data and one call for room data
writer.write_account_data.assert_called()
args = writer.write_account_data.call_args_list[0][0]
self.assertEqual(args[0], "global")
self.assertEqual(args[1]["m.global"]["a"], 1)
args = writer.write_account_data.call_args_list[1][0]
self.assertEqual(args[0], "test_room")
self.assertEqual(args[1]["m.per_room"]["b"], 2)
+1 -1
View File
@@ -899,7 +899,7 @@ class ApplicationServicesHandlerDeviceListsTestCase(unittest.HomeserverTestCase)
# Mock out application services, and allow defining our own in tests
self._services: List[ApplicationService] = []
self.hs.get_datastores().main.get_app_services = Mock(
self.hs.get_datastores().main.get_app_services = Mock( # type: ignore[assignment]
return_value=self._services
)
+4 -4
View File
@@ -61,7 +61,7 @@ class CasHandlerTestCase(HomeserverTestCase):
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
cas_response = CasResponse("test_user", {})
request = _mock_request()
@@ -89,7 +89,7 @@ class CasHandlerTestCase(HomeserverTestCase):
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
# Map a user via SSO.
cas_response = CasResponse("test_user", {})
@@ -129,7 +129,7 @@ class CasHandlerTestCase(HomeserverTestCase):
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
cas_response = CasResponse("föö", {})
request = _mock_request()
@@ -160,7 +160,7 @@ class CasHandlerTestCase(HomeserverTestCase):
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
# The response doesn't have the proper userGroup or department.
cas_response = CasResponse("test_user", {})
+29 -26
View File
@@ -23,6 +23,7 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import RoomEncryptionAlgorithms
from synapse.api.errors import Codes, SynapseError
from synapse.handlers.device import DeviceHandler
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
@@ -187,37 +188,37 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
)
# we should now have an unused alg1 key
res = self.get_success(
fallback_res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
)
self.assertEqual(res, ["alg1"])
self.assertEqual(fallback_res, ["alg1"])
# claiming an OTK when no OTKs are available should return the fallback
# key
res = self.get_success(
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
)
)
self.assertEqual(
res,
claim_res,
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
)
# we shouldn't have any unused fallback keys again
res = self.get_success(
unused_res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
)
self.assertEqual(res, [])
self.assertEqual(unused_res, [])
# claiming an OTK again should return the same fallback key
res = self.get_success(
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
)
)
self.assertEqual(
res,
claim_res,
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
)
@@ -231,10 +232,10 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
)
)
res = self.get_success(
unused_res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
)
self.assertEqual(res, [])
self.assertEqual(unused_res, [])
# uploading a new fallback key should result in an unused fallback key
self.get_success(
@@ -245,10 +246,10 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
)
)
res = self.get_success(
unused_res = self.get_success(
self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
)
self.assertEqual(res, ["alg1"])
self.assertEqual(unused_res, ["alg1"])
# if the user uploads a one-time key, the next claim should fetch the
# one-time key, and then go back to the fallback
@@ -258,23 +259,23 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
)
)
res = self.get_success(
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
)
)
self.assertEqual(
res,
claim_res,
{"failures": {}, "one_time_keys": {local_user: {device_id: otk}}},
)
res = self.get_success(
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
)
)
self.assertEqual(
res,
claim_res,
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key2}}},
)
@@ -287,13 +288,13 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
)
)
res = self.get_success(
claim_res = self.get_success(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
)
)
self.assertEqual(
res,
claim_res,
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}},
)
@@ -366,7 +367,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1))
# upload two device keys, which will be signed later by the self-signing key
device_key_1 = {
device_key_1: JsonDict = {
"user_id": local_user,
"device_id": "abc",
"algorithms": [
@@ -379,7 +380,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
},
"signatures": {local_user: {"ed25519:abc": "base64+signature"}},
}
device_key_2 = {
device_key_2: JsonDict = {
"user_id": local_user,
"device_id": "def",
"algorithms": [
@@ -451,8 +452,10 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
}
self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1))
device_handler = self.hs.get_device_handler()
assert isinstance(device_handler, DeviceHandler)
e = self.get_failure(
self.hs.get_device_handler().check_device_registered(
device_handler.check_device_registered(
user_id=local_user,
device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
initial_device_display_name="new display name",
@@ -475,7 +478,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
device_id = "xyz"
# private key: OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA
device_pubkey = "NnHhnqiMFQkq969szYkooLaBAXW244ZOxgukCvm2ZeY"
device_key = {
device_key: JsonDict = {
"user_id": local_user,
"device_id": device_id,
"algorithms": [
@@ -497,7 +500,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
# private key: 2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0
master_pubkey = "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk"
master_key = {
master_key: JsonDict = {
"user_id": local_user,
"usage": ["master"],
"keys": {"ed25519:" + master_pubkey: master_pubkey},
@@ -540,7 +543,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
# the first user
other_user = "@otherboris:" + self.hs.hostname
other_master_pubkey = "fHZ3NPiKxoLQm5OoZbKa99SYxprOjNs4TwJUKP+twCM"
other_master_key = {
other_master_key: JsonDict = {
# private key: oyw2ZUx0O4GifbfFYM0nQvj9CL0b8B7cyN4FprtK8OI
"user_id": other_user,
"usage": ["master"],
@@ -702,7 +705,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
self.hs.get_federation_client().query_client_keys = mock.Mock(
self.hs.get_federation_client().query_client_keys = mock.Mock( # type: ignore[assignment]
return_value=make_awaitable(
{
"device_keys": {remote_user_id: {}},
@@ -782,7 +785,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
self.hs.get_federation_client().query_user_devices = mock.Mock(
self.hs.get_federation_client().query_user_devices = mock.Mock( # type: ignore[assignment]
return_value=make_awaitable(
{
"user_id": remote_user_id,
+9 -9
View File
@@ -371,14 +371,14 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
# We mock out the FederationClient.backfill method, to pretend that a remote
# server has returned our fake event.
federation_client_backfill_mock = Mock(return_value=make_awaitable([event]))
self.hs.get_federation_client().backfill = federation_client_backfill_mock
self.hs.get_federation_client().backfill = federation_client_backfill_mock # type: ignore[assignment]
# We also mock the persist method with a side effect of itself. This allows us
# to track when it has been called while preserving its function.
persist_events_and_notify_mock = Mock(
side_effect=self.hs.get_federation_event_handler().persist_events_and_notify
)
self.hs.get_federation_event_handler().persist_events_and_notify = (
self.hs.get_federation_event_handler().persist_events_and_notify = ( # type: ignore[assignment]
persist_events_and_notify_mock
)
@@ -712,12 +712,12 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room
), patch.object(store, "is_partial_state_room", mock_is_partial_state_room):
# Start the partial state sync.
fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id")
fed_handler._start_partial_state_room_sync("hs1", {"hs2"}, "room_id")
self.assertEqual(mock_sync_partial_state_room.call_count, 1)
# Try to start another partial state sync.
# Nothing should happen.
fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id")
fed_handler._start_partial_state_room_sync("hs3", {"hs2"}, "room_id")
self.assertEqual(mock_sync_partial_state_room.call_count, 1)
# End the partial state sync
@@ -729,7 +729,7 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
# The next attempt to start the partial state sync should work.
is_partial_state = True
fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id")
fed_handler._start_partial_state_room_sync("hs3", {"hs2"}, "room_id")
self.assertEqual(mock_sync_partial_state_room.call_count, 2)
def test_partial_state_room_sync_restart(self) -> None:
@@ -764,7 +764,7 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room
), patch.object(store, "is_partial_state_room", mock_is_partial_state_room):
# Start the partial state sync.
fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id")
fed_handler._start_partial_state_room_sync("hs1", {"hs2"}, "room_id")
self.assertEqual(mock_sync_partial_state_room.call_count, 1)
# Fail the partial state sync.
@@ -773,11 +773,11 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
self.assertEqual(mock_sync_partial_state_room.call_count, 1)
# Start the partial state sync again.
fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id")
fed_handler._start_partial_state_room_sync("hs1", {"hs2"}, "room_id")
self.assertEqual(mock_sync_partial_state_room.call_count, 2)
# Deduplicate another partial state sync.
fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id")
fed_handler._start_partial_state_room_sync("hs3", {"hs2"}, "room_id")
self.assertEqual(mock_sync_partial_state_room.call_count, 2)
# Fail the partial state sync.
@@ -786,6 +786,6 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
self.assertEqual(mock_sync_partial_state_room.call_count, 3)
mock_sync_partial_state_room.assert_called_with(
initial_destination="hs3",
other_destinations=["hs2"],
other_destinations={"hs2"},
room_id="room_id",
)
+4 -2
View File
@@ -29,6 +29,7 @@ from synapse.logging.context import LoggingContext
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.state import StateResolutionStore
from synapse.state.v2 import _mainline_sort, _reverse_topological_power_sort
from synapse.types import JsonDict
from synapse.util import Clock
@@ -161,6 +162,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
if prev_exists_as_outlier:
prev_event.internal_metadata.outlier = True
persistence = self.hs.get_storage_controllers().persistence
assert persistence is not None
self.get_success(
persistence.persist_event(
prev_event,
@@ -861,7 +863,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
bert_member_event.event_id: bert_member_event,
rejected_kick_event.event_id: rejected_kick_event,
},
state_res_store=main_store,
state_res_store=StateResolutionStore(main_store),
)
),
[bert_member_event.event_id, rejected_kick_event.event_id],
@@ -906,7 +908,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
rejected_power_levels_event.event_id,
],
event_map={},
state_res_store=main_store,
state_res_store=StateResolutionStore(main_store),
full_conflicted_set=set(),
)
),
+6 -5
View File
@@ -41,20 +41,21 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.handler = self.hs.get_event_creation_handler()
self._persist_event_storage_controller = (
self.hs.get_storage_controllers().persistence
)
persistence = self.hs.get_storage_controllers().persistence
assert persistence is not None
self._persist_event_storage_controller = persistence
self.user_id = self.register_user("tester", "foobar")
self.access_token = self.login("tester", "foobar")
self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token)
self.info = self.get_success(
info = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(
self.access_token,
)
)
self.token_id = self.info.token_id
assert info is not None
self.token_id = info.token_id
self.requester = create_requester(self.user_id, access_token_id=self.token_id)
+1 -1
View File
@@ -852,7 +852,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
username: The username to use for the test.
registration: Whether to test with registration URLs.
"""
self.hs.get_identity_handler().send_threepid_validation = Mock(
self.hs.get_identity_handler().send_threepid_validation = Mock( # type: ignore[assignment]
return_value=make_awaitable(0),
)
+9 -5
View File
@@ -62,7 +62,7 @@ class TestSpamChecker:
request_info: Collection[Tuple[str, str]],
auth_provider_id: Optional[str],
) -> RegistrationBehaviour:
pass
return RegistrationBehaviour.ALLOW
class DenyAll(TestSpamChecker):
@@ -111,7 +111,7 @@ class TestLegacyRegistrationSpamChecker:
username: Optional[str],
request_info: Collection[Tuple[str, str]],
) -> RegistrationBehaviour:
pass
return RegistrationBehaviour.ALLOW
class LegacyAllowAll(TestLegacyRegistrationSpamChecker):
@@ -203,7 +203,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
@override_config({"limit_usage_by_mau": True})
def test_get_or_create_user_mau_not_blocked(self) -> None:
self.store.count_monthly_users = Mock(
self.store.count_monthly_users = Mock( # type: ignore[assignment]
return_value=make_awaitable(self.hs.config.server.max_mau_value - 1)
)
# Ensure does not throw exception
@@ -304,7 +304,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self) -> None:
room_alias_str = "#room:test"
self.store.count_real_users = Mock(return_value=make_awaitable(1))
self.store.count_real_users = Mock(return_value=make_awaitable(1)) # type: ignore[assignment]
self.store.is_real_user = Mock(return_value=make_awaitable(True))
user_id = self.get_success(self.handler.register_user(localpart="real"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
@@ -319,7 +319,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
def test_auto_create_auto_join_rooms_when_user_is_not_the_first_real_user(
self,
) -> None:
self.store.count_real_users = Mock(return_value=make_awaitable(2))
self.store.count_real_users = Mock(return_value=make_awaitable(2)) # type: ignore[assignment]
self.store.is_real_user = Mock(return_value=make_awaitable(True))
user_id = self.get_success(self.handler.register_user(localpart="real"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
@@ -346,6 +346,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
# Ensure the room is properly not federated.
room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
assert room is not None
self.assertFalse(room["federatable"])
self.assertFalse(room["public"])
self.assertEqual(room["join_rules"], "public")
@@ -375,6 +376,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
# Ensure the room is properly a public room.
room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
assert room is not None
self.assertEqual(room["join_rules"], "public")
# Both users should be in the room.
@@ -413,6 +415,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
# Ensure the room is properly a private room.
room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
assert room is not None
self.assertFalse(room["public"])
self.assertEqual(room["join_rules"], "invite")
self.assertEqual(room["guest_access"], "can_join")
@@ -456,6 +459,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
# Ensure the room is properly a private room.
room = self.get_success(self.store.get_room_with_stats(room_id["room_id"]))
assert room is not None
self.assertFalse(room["public"])
self.assertEqual(room["join_rules"], "invite")
self.assertEqual(room["guest_access"], "can_join")
+7 -7
View File
@@ -134,7 +134,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
# send a mocked-up SAML response to the callback
saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
@@ -164,7 +164,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
# Map a user via SSO.
saml_response = FakeAuthnResponse(
@@ -206,11 +206,11 @@ class SamlHandlerTestCase(HomeserverTestCase):
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
# mock out the error renderer too
sso_handler = self.hs.get_sso_handler()
sso_handler.render_error = Mock(return_value=None)
sso_handler.render_error = Mock(return_value=None) # type: ignore[assignment]
saml_response = FakeAuthnResponse({"uid": "test", "username": "föö"})
request = _mock_request()
@@ -227,9 +227,9 @@ class SamlHandlerTestCase(HomeserverTestCase):
# stub out the auth handler and error renderer
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
sso_handler = self.hs.get_sso_handler()
sso_handler.render_error = Mock(return_value=None)
sso_handler.render_error = Mock(return_value=None) # type: ignore[assignment]
# register a user to occupy the first-choice MXID
store = self.hs.get_datastores().main
@@ -312,7 +312,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
# The response doesn't have the proper userGroup or department.
saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
+5 -7
View File
@@ -74,8 +74,8 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
mock_keyring.verify_json_for_server.return_value = make_awaitable(True)
# we mock out the federation client too
mock_federation_client = Mock(spec=["put_json"])
mock_federation_client.put_json.return_value = make_awaitable((200, "OK"))
self.mock_federation_client = Mock(spec=["put_json"])
self.mock_federation_client.put_json.return_value = make_awaitable((200, "OK"))
# the tests assume that we are starting at unix time 1000
reactor.pump((1000,))
@@ -83,7 +83,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.mock_hs_notifier = Mock()
hs = self.setup_test_homeserver(
notifier=self.mock_hs_notifier,
federation_http_client=mock_federation_client,
federation_http_client=self.mock_federation_client,
keyring=mock_keyring,
replication_streams={},
)
@@ -233,8 +233,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
)
)
put_json = self.hs.get_federation_http_client().put_json
put_json.assert_called_once_with(
self.mock_federation_client.put_json.assert_called_once_with(
"farm",
path="/_matrix/federation/v1/send/1000000",
data=_expect_edu_transaction(
@@ -349,8 +348,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
put_json = self.hs.get_federation_http_client().put_json
put_json.assert_called_once_with(
self.mock_federation_client.put_json.assert_called_once_with(
"farm",
path="/_matrix/federation/v1/send/1000000",
data=_expect_edu_transaction(
+20 -13
View File
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
from typing import Any, Tuple
from unittest.mock import Mock, patch
from urllib.parse import quote
@@ -24,7 +24,7 @@ from synapse.appservice import ApplicationService
from synapse.rest.client import login, register, room, user_directory
from synapse.server import HomeServer
from synapse.storage.roommember import ProfileInfo
from synapse.types import create_requester
from synapse.types import UserProfile, create_requester
from synapse.util import Clock
from tests import unittest
@@ -34,6 +34,12 @@ from tests.test_utils.event_injection import inject_member_event
from tests.unittest import override_config
# A spam checker which doesn't implement anything, so create a bare object.
class UselessSpamChecker:
def __init__(self, config: Any):
pass
class UserDirectoryTestCase(unittest.HomeserverTestCase):
"""Tests the UserDirectoryHandler.
@@ -773,7 +779,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
s = self.get_success(self.handler.search_users(u1, "user2", 10))
self.assertEqual(len(s["results"]), 1)
async def allow_all(user_profile: ProfileInfo) -> bool:
async def allow_all(user_profile: UserProfile) -> bool:
# Allow all users.
return False
@@ -787,7 +793,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(s["results"]), 1)
# Configure a spam checker that filters all users.
async def block_all(user_profile: ProfileInfo) -> bool:
async def block_all(user_profile: UserProfile) -> bool:
# All users are spammy.
return True
@@ -797,6 +803,13 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
s = self.get_success(self.handler.search_users(u1, "user2", 10))
self.assertEqual(len(s["results"]), 0)
@override_config(
{
"spam_checker": {
"module": "tests.handlers.test_user_directory.UselessSpamChecker"
}
}
)
def test_legacy_spam_checker(self) -> None:
"""
A spam checker without the expected method should be ignored.
@@ -825,11 +838,6 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.assertEqual(shares_private, {(u1, u2, room), (u2, u1, room)})
self.assertEqual(public_users, set())
# Configure a spam checker.
spam_checker = self.hs.get_spam_checker()
# The spam checker doesn't need any methods, so create a bare object.
spam_checker.spam_checker = object()
# We get one search result when searching for user2 by user1.
s = self.get_success(self.handler.search_users(u1, "user2", 10))
self.assertEqual(len(s["results"]), 1)
@@ -954,10 +962,9 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
)
context = self.get_success(unpersisted_context.persist(event))
self.get_success(
self.hs.get_storage_controllers().persistence.persist_event(event, context)
)
persistence = self.hs.get_storage_controllers().persistence
assert persistence is not None
self.get_success(persistence.persist_event(event, context))
def test_local_user_leaving_room_remains_in_user_directory(self) -> None:
"""We've chosen to simplify the user directory's implementation by
@@ -30,7 +30,7 @@ from twisted.internet.interfaces import (
IOpenSSLClientConnectionCreator,
IProtocolFactory,
)
from twisted.internet.protocol import Factory
from twisted.internet.protocol import Factory, Protocol
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
from twisted.web._newclient import ResponseNeverReceived
from twisted.web.client import Agent
@@ -63,7 +63,7 @@ from tests.http import (
get_test_ca_cert_file,
)
from tests.server import FakeTransport, ThreadedMemoryReactorClock
from tests.utils import default_config
from tests.utils import checked_cast, default_config
logger = logging.getLogger(__name__)
@@ -146,8 +146,10 @@ class MatrixFederationAgentTests(unittest.TestCase):
#
# Normally this would be done by the TCP socket code in Twisted, but we are
# stubbing that out here.
client_protocol = client_factory.buildProtocol(dummy_address)
assert isinstance(client_protocol, _WrappingProtocol)
# NB: we use a checked_cast here to workaround https://github.com/Shoobx/mypy-zope/issues/91)
client_protocol = checked_cast(
_WrappingProtocol, client_factory.buildProtocol(dummy_address)
)
client_protocol.makeConnection(
FakeTransport(server_protocol, self.reactor, client_protocol)
)
@@ -446,7 +448,6 @@ class MatrixFederationAgentTests(unittest.TestCase):
server_ssl_protocol = _wrap_server_factory_for_tls(
_get_test_protocol_factory()
).buildProtocol(dummy_address)
assert isinstance(server_ssl_protocol, TLSMemoryBIOProtocol)
# Tell the HTTP server to send outgoing traffic back via the proxy's transport.
proxy_server_transport = proxy_server.transport
@@ -465,7 +466,8 @@ class MatrixFederationAgentTests(unittest.TestCase):
else:
assert isinstance(proxy_server_transport, FakeTransport)
client_protocol = proxy_server_transport.other
c2s_transport = client_protocol.transport
assert isinstance(client_protocol, Protocol)
c2s_transport = checked_cast(FakeTransport, client_protocol.transport)
c2s_transport.other = server_ssl_protocol
self.reactor.advance(0)
@@ -1529,7 +1531,7 @@ def _check_logcontext(context: LoggingContextOrSentinel) -> None:
def _wrap_server_factory_for_tls(
factory: IProtocolFactory, sanlist: Optional[List[bytes]] = None
) -> IProtocolFactory:
) -> TLSMemoryBIOFactory:
"""Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory
The resultant factory will create a TLS server which presents a certificate
signed by our test CA, valid for the domains in `sanlist`
+22 -23
View File
@@ -28,7 +28,7 @@ from twisted.internet.endpoints import (
_WrappingProtocol,
)
from twisted.internet.interfaces import IProtocol, IProtocolFactory
from twisted.internet.protocol import Factory
from twisted.internet.protocol import Factory, Protocol
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
from twisted.web.http import HTTPChannel
@@ -43,6 +43,7 @@ from tests.http import (
)
from tests.server import FakeTransport, ThreadedMemoryReactorClock
from tests.unittest import TestCase
from tests.utils import checked_cast
logger = logging.getLogger(__name__)
@@ -620,7 +621,6 @@ class MatrixFederationAgentTests(TestCase):
server_ssl_protocol = _wrap_server_factory_for_tls(
_get_test_protocol_factory()
).buildProtocol(dummy_address)
assert isinstance(server_ssl_protocol, TLSMemoryBIOProtocol)
# Tell the HTTP server to send outgoing traffic back via the proxy's transport.
proxy_server_transport = proxy_server.transport
@@ -644,7 +644,8 @@ class MatrixFederationAgentTests(TestCase):
else:
assert isinstance(proxy_server_transport, FakeTransport)
client_protocol = proxy_server_transport.other
c2s_transport = client_protocol.transport
assert isinstance(client_protocol, Protocol)
c2s_transport = checked_cast(FakeTransport, client_protocol.transport)
c2s_transport.other = server_ssl_protocol
self.reactor.advance(0)
@@ -757,12 +758,14 @@ class MatrixFederationAgentTests(TestCase):
assert isinstance(proxy_server, HTTPChannel)
# fish the transports back out so that we can do the old switcheroo
s2c_transport = proxy_server.transport
assert isinstance(s2c_transport, FakeTransport)
client_protocol = s2c_transport.other
assert isinstance(client_protocol, _WrappingProtocol)
c2s_transport = client_protocol.transport
assert isinstance(c2s_transport, FakeTransport)
# To help mypy out with the various Protocols and wrappers and mocks, we do
# some explicit casting. Without the casts, we hit the bug I reported at
# https://github.com/Shoobx/mypy-zope/issues/91 .
# We also double-checked these casts at runtime (test-time) because I found it
# quite confusing to deduce these types in the first place!
s2c_transport = checked_cast(FakeTransport, proxy_server.transport)
client_protocol = checked_cast(_WrappingProtocol, s2c_transport.other)
c2s_transport = checked_cast(FakeTransport, client_protocol.transport)
# the FakeTransport is async, so we need to pump the reactor
self.reactor.advance(0)
@@ -822,9 +825,9 @@ class MatrixFederationAgentTests(TestCase):
@patch.dict(os.environ, {"http_proxy": "proxy.com:8888"})
def test_proxy_with_no_scheme(self) -> None:
http_proxy_agent = ProxyAgent(self.reactor, use_proxy=True)
assert isinstance(http_proxy_agent.http_proxy_endpoint, HostnameEndpoint)
self.assertEqual(http_proxy_agent.http_proxy_endpoint._hostStr, "proxy.com")
self.assertEqual(http_proxy_agent.http_proxy_endpoint._port, 8888)
proxy_ep = checked_cast(HostnameEndpoint, http_proxy_agent.http_proxy_endpoint)
self.assertEqual(proxy_ep._hostStr, "proxy.com")
self.assertEqual(proxy_ep._port, 8888)
@patch.dict(os.environ, {"http_proxy": "socks://proxy.com:8888"})
def test_proxy_with_unsupported_scheme(self) -> None:
@@ -834,25 +837,21 @@ class MatrixFederationAgentTests(TestCase):
@patch.dict(os.environ, {"http_proxy": "http://proxy.com:8888"})
def test_proxy_with_http_scheme(self) -> None:
http_proxy_agent = ProxyAgent(self.reactor, use_proxy=True)
assert isinstance(http_proxy_agent.http_proxy_endpoint, HostnameEndpoint)
self.assertEqual(http_proxy_agent.http_proxy_endpoint._hostStr, "proxy.com")
self.assertEqual(http_proxy_agent.http_proxy_endpoint._port, 8888)
proxy_ep = checked_cast(HostnameEndpoint, http_proxy_agent.http_proxy_endpoint)
self.assertEqual(proxy_ep._hostStr, "proxy.com")
self.assertEqual(proxy_ep._port, 8888)
@patch.dict(os.environ, {"http_proxy": "https://proxy.com:8888"})
def test_proxy_with_https_scheme(self) -> None:
https_proxy_agent = ProxyAgent(self.reactor, use_proxy=True)
assert isinstance(https_proxy_agent.http_proxy_endpoint, _WrapperEndpoint)
self.assertEqual(
https_proxy_agent.http_proxy_endpoint._wrappedEndpoint._hostStr, "proxy.com"
)
self.assertEqual(
https_proxy_agent.http_proxy_endpoint._wrappedEndpoint._port, 8888
)
proxy_ep = checked_cast(_WrapperEndpoint, https_proxy_agent.http_proxy_endpoint)
self.assertEqual(proxy_ep._wrappedEndpoint._hostStr, "proxy.com")
self.assertEqual(proxy_ep._wrappedEndpoint._port, 8888)
def _wrap_server_factory_for_tls(
factory: IProtocolFactory, sanlist: Optional[List[bytes]] = None
) -> IProtocolFactory:
) -> TLSMemoryBIOFactory:
"""Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory
The resultant factory will create a TLS server which presents a certificate
+9 -8
View File
@@ -21,6 +21,7 @@ from synapse.logging import RemoteHandler
from tests.logging import LoggerCleanupMixin
from tests.server import FakeTransport, get_clock
from tests.unittest import TestCase
from tests.utils import checked_cast
def connect_logging_client(
@@ -56,8 +57,8 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
client, server = connect_logging_client(self.reactor, 0)
# Trigger data being sent
assert isinstance(client.transport, FakeTransport)
client.transport.flush()
client_transport = checked_cast(FakeTransport, client.transport)
client_transport.flush()
# One log message, with a single trailing newline
logs = server.data.decode("utf8").splitlines()
@@ -89,8 +90,8 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
# Allow the reconnection
client, server = connect_logging_client(self.reactor, 0)
assert isinstance(client.transport, FakeTransport)
client.transport.flush()
client_transport = checked_cast(FakeTransport, client.transport)
client_transport.flush()
# Only the 7 infos made it through, the debugs were elided
logs = server.data.splitlines()
@@ -123,8 +124,8 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
# Allow the reconnection
client, server = connect_logging_client(self.reactor, 0)
assert isinstance(client.transport, FakeTransport)
client.transport.flush()
client_transport = checked_cast(FakeTransport, client.transport)
client_transport.flush()
# The 10 warnings made it through, the debugs and infos were elided
logs = server.data.splitlines()
@@ -148,8 +149,8 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase):
# Allow the reconnection
client, server = connect_logging_client(self.reactor, 0)
assert isinstance(client.transport, FakeTransport)
client.transport.flush()
client_transport = checked_cast(FakeTransport, client.transport)
client_transport.flush()
# The first five and last five warnings made it through, the debugs and
# infos were elided
+6 -8
View File
@@ -68,11 +68,11 @@ class ModuleApiTestCase(BaseModuleApiTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
# Mock out the calls over federation.
fed_transport_client = Mock(spec=["send_transaction"])
fed_transport_client.send_transaction = simple_async_mock({})
self.fed_transport_client = Mock(spec=["send_transaction"])
self.fed_transport_client.send_transaction = simple_async_mock({})
return self.setup_test_homeserver(
federation_transport_client=fed_transport_client,
federation_transport_client=self.fed_transport_client,
)
def test_can_register_user(self) -> None:
@@ -417,7 +417,7 @@ class ModuleApiTestCase(BaseModuleApiTestCase):
#
# Thus we reset the mock, and try sending online local user
# presence again
self.hs.get_federation_transport_client().send_transaction.reset_mock()
self.fed_transport_client.send_transaction.reset_mock()
# Broadcast local user online presence
self.get_success(
@@ -429,9 +429,7 @@ class ModuleApiTestCase(BaseModuleApiTestCase):
# Check that a presence update was sent as part of a federation transaction
found_update = False
calls = (
self.hs.get_federation_transport_client().send_transaction.call_args_list
)
calls = self.fed_transport_client.send_transaction.call_args_list
for call in calls:
call_args = call[0]
federation_transaction: Transaction = call_args[0]
@@ -581,7 +579,7 @@ class ModuleApiTestCase(BaseModuleApiTestCase):
mocked_remote_join = simple_async_mock(
return_value=("fake-event-id", fake_stream_id)
)
self.hs.get_room_member_handler()._remote_join = mocked_remote_join
self.hs.get_room_member_handler()._remote_join = mocked_remote_join # type: ignore[assignment]
fake_remote_host = f"{self.module_api.server_name}-remote"
# Given that the join is to be faked, we expect the relevant join event not to
+16 -2
View File
@@ -227,7 +227,14 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
)
return len(result) > 0
@override_config({"experimental_features": {"msc3952_intentional_mentions": True}})
@override_config(
{
"experimental_features": {
"msc3758_exact_event_match": True,
"msc3952_intentional_mentions": True,
}
}
)
def test_user_mentions(self) -> None:
"""Test the behavior of an event which includes invalid user mentions."""
bulk_evaluator = BulkPushRuleEvaluator(self.hs)
@@ -323,7 +330,14 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
)
)
@override_config({"experimental_features": {"msc3952_intentional_mentions": True}})
@override_config(
{
"experimental_features": {
"msc3758_exact_event_match": True,
"msc3952_intentional_mentions": True,
}
}
)
def test_room_mentions(self) -> None:
"""Test the behavior of an event which includes invalid room mentions."""
bulk_evaluator = BulkPushRuleEvaluator(self.hs)
+35 -16
View File
@@ -23,6 +23,7 @@ from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.api.errors import Codes, SynapseError
from synapse.push.emailpusher import EmailPusher
from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.util import Clock
@@ -105,6 +106,7 @@ class EmailPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(self.access_token)
)
assert user_tuple is not None
self.token_id = user_tuple.token_id
# We need to add email to account before we can create a pusher.
@@ -114,7 +116,7 @@ class EmailPusherTests(HomeserverTestCase):
)
)
self.pusher = self.get_success(
pusher = self.get_success(
self.hs.get_pusherpool().add_or_update_pusher(
user_id=self.user_id,
access_token=self.token_id,
@@ -127,6 +129,8 @@ class EmailPusherTests(HomeserverTestCase):
data={},
)
)
assert isinstance(pusher, EmailPusher)
self.pusher = pusher
self.auth_handler = hs.get_auth_handler()
self.store = hs.get_datastores().main
@@ -375,10 +379,13 @@ class EmailPusherTests(HomeserverTestCase):
)
# check that the pusher for that email address has been deleted
pushers = self.get_success(
self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id})
pushers = list(
self.get_success(
self.hs.get_datastores().main.get_pushers_by(
{"user_name": self.user_id}
)
)
)
pushers = list(pushers)
self.assertEqual(len(pushers), 0)
def test_remove_unlinked_pushers_background_job(self) -> None:
@@ -413,10 +420,13 @@ class EmailPusherTests(HomeserverTestCase):
self.wait_for_background_updates()
# Check that all pushers with unlinked addresses were deleted
pushers = self.get_success(
self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id})
pushers = list(
self.get_success(
self.hs.get_datastores().main.get_pushers_by(
{"user_name": self.user_id}
)
)
)
pushers = list(pushers)
self.assertEqual(len(pushers), 0)
def _check_for_mail(self) -> Tuple[Sequence, Dict]:
@@ -428,10 +438,13 @@ class EmailPusherTests(HomeserverTestCase):
that notification.
"""
# Get the stream ordering before it gets sent
pushers = self.get_success(
self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id})
pushers = list(
self.get_success(
self.hs.get_datastores().main.get_pushers_by(
{"user_name": self.user_id}
)
)
)
pushers = list(pushers)
self.assertEqual(len(pushers), 1)
last_stream_ordering = pushers[0].last_stream_ordering
@@ -439,10 +452,13 @@ class EmailPusherTests(HomeserverTestCase):
self.pump(10)
# It hasn't succeeded yet, so the stream ordering shouldn't have moved
pushers = self.get_success(
self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id})
pushers = list(
self.get_success(
self.hs.get_datastores().main.get_pushers_by(
{"user_name": self.user_id}
)
)
)
pushers = list(pushers)
self.assertEqual(len(pushers), 1)
self.assertEqual(last_stream_ordering, pushers[0].last_stream_ordering)
@@ -458,10 +474,13 @@ class EmailPusherTests(HomeserverTestCase):
self.assertEqual(len(self.email_attempts), 1)
# The stream ordering has increased
pushers = self.get_success(
self.hs.get_datastores().main.get_pushers_by({"user_name": self.user_id})
pushers = list(
self.get_success(
self.hs.get_datastores().main.get_pushers_by(
{"user_name": self.user_id}
)
)
)
pushers = list(pushers)
self.assertEqual(len(pushers), 1)
self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)
+29 -16
View File
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple
from typing import Any, List, Tuple
from unittest.mock import Mock
from twisted.internet.defer import Deferred
@@ -22,7 +22,6 @@ from synapse.logging.context import make_deferred_yieldable
from synapse.push import PusherConfig, PusherConfigException
from synapse.rest.client import login, push_rule, pusher, receipts, room
from synapse.server import HomeServer
from synapse.storage.databases.main.registration import TokenLookupResult
from synapse.types import JsonDict
from synapse.util import Clock
@@ -67,9 +66,10 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token)
)
assert user_tuple is not None
token_id = user_tuple.token_id
def test_data(data: Optional[JsonDict]) -> None:
def test_data(data: Any) -> None:
self.get_failure(
self.hs.get_pusherpool().add_or_update_pusher(
user_id=user_id,
@@ -113,6 +113,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token)
)
assert user_tuple is not None
token_id = user_tuple.token_id
self.get_success(
@@ -140,10 +141,11 @@ class HTTPPusherTests(HomeserverTestCase):
self.helper.send(room, body="There!", tok=other_access_token)
# Get the stream ordering before it gets sent
pushers = self.get_success(
self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
pushers = list(
self.get_success(
self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
)
)
pushers = list(pushers)
self.assertEqual(len(pushers), 1)
last_stream_ordering = pushers[0].last_stream_ordering
@@ -151,10 +153,11 @@ class HTTPPusherTests(HomeserverTestCase):
self.pump()
# It hasn't succeeded yet, so the stream ordering shouldn't have moved
pushers = self.get_success(
self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
pushers = list(
self.get_success(
self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
)
)
pushers = list(pushers)
self.assertEqual(len(pushers), 1)
self.assertEqual(last_stream_ordering, pushers[0].last_stream_ordering)
@@ -172,10 +175,11 @@ class HTTPPusherTests(HomeserverTestCase):
self.pump()
# The stream ordering has increased
pushers = self.get_success(
self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
pushers = list(
self.get_success(
self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
)
)
pushers = list(pushers)
self.assertEqual(len(pushers), 1)
self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)
last_stream_ordering = pushers[0].last_stream_ordering
@@ -194,10 +198,11 @@ class HTTPPusherTests(HomeserverTestCase):
self.pump()
# The stream ordering has increased, again
pushers = self.get_success(
self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
pushers = list(
self.get_success(
self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
)
)
pushers = list(pushers)
self.assertEqual(len(pushers), 1)
self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)
@@ -229,6 +234,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token)
)
assert user_tuple is not None
token_id = user_tuple.token_id
self.get_success(
@@ -349,6 +355,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token)
)
assert user_tuple is not None
token_id = user_tuple.token_id
self.get_success(
@@ -435,6 +442,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token)
)
assert user_tuple is not None
token_id = user_tuple.token_id
self.get_success(
@@ -512,6 +520,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token)
)
assert user_tuple is not None
token_id = user_tuple.token_id
self.get_success(
@@ -618,6 +627,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token)
)
assert user_tuple is not None
token_id = user_tuple.token_id
self.get_success(
@@ -753,6 +763,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token)
)
assert user_tuple is not None
token_id = user_tuple.token_id
self.get_success(
@@ -895,6 +906,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token)
)
assert user_tuple is not None
token_id = user_tuple.token_id
device_id = user_tuple.device_id
@@ -941,9 +953,10 @@ class HTTPPusherTests(HomeserverTestCase):
)
# Look up the user info for the access token so we can compare the device ID.
lookup_result: TokenLookupResult = self.get_success(
lookup_result = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token)
)
assert lookup_result is not None
# Get the user's devices and check it has the correct device ID.
channel = self.make_request("GET", "/pushers", access_token=access_token)
+50 -26
View File
@@ -32,6 +32,7 @@ from synapse.storage.databases.main.appservice import _make_exclusive_regex
from synapse.synapse_rust.push import PushRuleEvaluator
from synapse.types import JsonDict, JsonMapping, UserID
from synapse.util import Clock
from synapse.util.frozenutils import freeze
from tests import unittest
from tests.test_utils.event_injection import create_event, inject_member_event
@@ -57,17 +58,24 @@ class FlattenDictTestCase(unittest.TestCase):
)
def test_non_string(self) -> None:
"""Booleans, ints, and nulls should be kept while other items are dropped."""
"""String, booleans, ints, nulls and list of those should be kept while other items are dropped."""
input: Dict[str, Any] = {
"woo": "woo",
"foo": True,
"bar": 1,
"baz": None,
"fuzz": [],
"fuzz": ["woo", True, 1, None, [], {}],
"boo": {},
}
self.assertEqual(
{"woo": "woo", "foo": True, "bar": 1, "baz": None}, _flatten_dict(input)
{
"woo": "woo",
"foo": True,
"bar": 1,
"baz": None,
"fuzz": ["woo", True, 1, None],
},
_flatten_dict(input),
)
def test_event(self) -> None:
@@ -117,6 +125,7 @@ class FlattenDictTestCase(unittest.TestCase):
"room_id": "!test:test",
"sender": "@alice:test",
"type": "m.room.message",
"content.org.matrix.msc1767.markup": [],
}
self.assertEqual(expected, _flatten_dict(event))
@@ -128,6 +137,7 @@ class FlattenDictTestCase(unittest.TestCase):
"room_id": "!test:test",
"sender": "@alice:test",
"type": "m.room.message",
"content.org.matrix.msc1767.markup": [],
}
self.assertEqual(expected, _flatten_dict(event))
@@ -139,7 +149,6 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
*,
has_mentions: bool = False,
user_mentions: Optional[Set[str]] = None,
room_mention: bool = False,
related_events: Optional[JsonDict] = None,
) -> PushRuleEvaluator:
event = FrozenEvent(
@@ -160,7 +169,6 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
_flatten_dict(event),
has_mentions,
user_mentions or set(),
room_mention,
room_member_count,
sender_power_level,
cast(Dict[str, int], power_levels.get("notifications", {})),
@@ -169,6 +177,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
room_version_feature_flags=event.room_version.msc3931_push_features,
msc3931_enabled=True,
msc3758_exact_event_match=True,
msc3966_exact_event_property_contains=True,
)
def test_display_name(self) -> None:
@@ -221,27 +230,6 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
# Note that invalid data is tested at tests.push.test_bulk_push_rule_evaluator.TestBulkPushRuleEvaluator.test_mentions
# since the BulkPushRuleEvaluator is what handles data sanitisation.
def test_room_mentions(self) -> None:
"""Check for room mentions."""
condition = {"kind": "org.matrix.msc3952.is_room_mention"}
# No room mention shouldn't match.
evaluator = self._get_evaluator({}, has_mentions=True)
self.assertFalse(evaluator.matches(condition, None, None))
# Room mention should match.
evaluator = self._get_evaluator({}, has_mentions=True, room_mention=True)
self.assertTrue(evaluator.matches(condition, None, None))
# A room mention and user mention is valid.
evaluator = self._get_evaluator(
{}, has_mentions=True, user_mentions={"@another:test"}, room_mention=True
)
self.assertTrue(evaluator.matches(condition, None, None))
# Note that invalid data is tested at tests.push.test_bulk_push_rule_evaluator.TestBulkPushRuleEvaluator.test_mentions
# since the BulkPushRuleEvaluator is what handles data sanitisation.
def _assert_matches(
self, condition: JsonDict, content: JsonMapping, msg: Optional[str] = None
) -> None:
@@ -549,6 +537,42 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
"incorrect types should not match",
)
def test_exact_event_property_contains(self) -> None:
"""Check that exact_event_property_contains conditions work as expected."""
condition = {
"kind": "org.matrix.msc3966.exact_event_property_contains",
"key": "content.value",
"value": "foobaz",
}
self._assert_matches(
condition,
{"value": ["foobaz"]},
"exact value should match",
)
self._assert_matches(
condition,
{"value": ["foobaz", "bugz"]},
"extra values should match",
)
self._assert_not_matches(
condition,
{"value": ["FoobaZ"]},
"values should match and be case-sensitive",
)
self._assert_not_matches(
condition,
{"value": "foobaz"},
"does not search in a string",
)
# it should work on frozendicts too
self._assert_matches(
condition,
freeze({"value": ["foobaz"]}),
"values should match on frozendicts",
)
def test_no_body(self) -> None:
"""Not having a body shouldn't break the evaluator."""
evaluator = self._get_evaluator({})
+5 -5
View File
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Optional
from typing import Any, List, Optional, Sequence
from twisted.test.proto_helpers import MemoryReactor
@@ -139,7 +139,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
)
# this is the point in the DAG where we make a fork
fork_point: List[str] = self.get_success(
fork_point: Sequence[str] = self.get_success(
self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id)
)
@@ -168,7 +168,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
pl_event = self.get_success(
inject_event(
self.hs,
prev_event_ids=prev_events,
prev_event_ids=list(prev_events),
type=EventTypes.PowerLevels,
state_key="",
sender=self.user_id,
@@ -294,7 +294,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
)
# this is the point in the DAG where we make a fork
fork_point: List[str] = self.get_success(
fork_point: Sequence[str] = self.get_success(
self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id)
)
@@ -323,7 +323,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
e = self.get_success(
inject_event(
self.hs,
prev_event_ids=prev_events,
prev_event_ids=list(prev_events),
type=EventTypes.PowerLevels,
state_key="",
sender=self.user_id,
@@ -37,7 +37,7 @@ class PartialStateStreamsTestCase(BaseMultiWorkerStreamTestCase):
room_id = self.helper.create_room_as("@bob:test")
# Mark the room as partial-stated.
self.get_success(
self.store.store_partial_state_room(room_id, ["serv1", "serv2"], 0, "serv1")
self.store.store_partial_state_room(room_id, {"serv1", "serv2"}, 0, "serv1")
)
worker = self.make_worker_hs("synapse.app.generic_worker")
+3 -1
View File
@@ -13,7 +13,7 @@
# limitations under the License.
from unittest.mock import Mock
from synapse.handlers.typing import RoomMember
from synapse.handlers.typing import RoomMember, TypingWriterHandler
from synapse.replication.tcp.streams import TypingStream
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -33,6 +33,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
def test_typing(self) -> None:
typing = self.hs.get_typing_handler()
assert isinstance(typing, TypingWriterHandler)
self.reconnect()
@@ -88,6 +89,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
sends the proper position and RDATA).
"""
typing = self.hs.get_typing_handler()
assert isinstance(typing, TypingWriterHandler)
self.reconnect()
+1
View File
@@ -127,6 +127,7 @@ class ChannelsTestCase(BaseMultiWorkerStreamTestCase):
# ... updating the cache ID gen on the master still shouldn't cause the
# deferred to wake up.
assert store._cache_id_gen is not None
ctx = store._cache_id_gen.get_next()
self.get_success(ctx.__aenter__())
self.get_success(ctx.__aexit__(None, None, None))
@@ -16,6 +16,7 @@ from unittest.mock import Mock
from synapse.api.constants import EventTypes, Membership
from synapse.events.builder import EventBuilderFactory
from synapse.handlers.typing import TypingWriterHandler
from synapse.rest.admin import register_servlets_for_client_rest_resource
from synapse.rest.client import login, room
from synapse.types import UserID, create_requester
@@ -174,6 +175,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
token = self.login("user3", "pass")
typing_handler = self.hs.get_typing_handler()
assert isinstance(typing_handler, TypingWriterHandler)
sent_on_1 = False
sent_on_2 = False
+1
View File
@@ -50,6 +50,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
user_dict = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token)
)
assert user_dict is not None
token_id = user_dict.token_id
self.get_success(
+3 -2
View File
@@ -2913,7 +2913,8 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
other_user_tok = self.login("user", "pass")
event_builder_factory = self.hs.get_event_builder_factory()
event_creation_handler = self.hs.get_event_creation_handler()
storage_controllers = self.hs.get_storage_controllers()
persistence = self.hs.get_storage_controllers().persistence
assert persistence is not None
# Create two rooms, one with a local user only and one with both a local
# and remote user.
@@ -2940,7 +2941,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
context = self.get_success(unpersisted_context.persist(event))
self.get_success(storage_controllers.persistence.persist_event(event, context))
self.get_success(persistence.persist_event(event, context))
# Now get rooms
url = "/_synapse/admin/v1/users/@joiner:remote_hs/joined_rooms"
+11 -4
View File
@@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
@@ -33,9 +35,14 @@ class UsernameAvailableTestCase(unittest.HomeserverTestCase):
self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
async def check_username(username: str) -> bool:
if username == "allowed":
return True
async def check_username(
localpart: str,
guest_access_token: Optional[str] = None,
assigned_user_id: Optional[str] = None,
inhibit_user_in_use_error: bool = False,
) -> None:
if localpart == "allowed":
return
raise SynapseError(
400,
"User ID already taken.",
@@ -43,7 +50,7 @@ class UsernameAvailableTestCase(unittest.HomeserverTestCase):
)
handler = self.hs.get_registration_handler()
handler.check_username = check_username
handler.check_username = check_username # type: ignore[assignment]
def test_username_available(self) -> None:
"""

Some files were not shown because too many files have changed in this diff Show More