diff --git a/.github/workflows/triage_labelled.yml b/.github/workflows/triage_labelled.yml index e506be393f..c9ed653b62 100644 --- a/.github/workflows/triage_labelled.yml +++ b/.github/workflows/triage_labelled.yml @@ -16,6 +16,10 @@ jobs: with: project-url: "https://github.com/orgs/matrix-org/projects/67" github-token: ${{ secrets.ELEMENT_BOT_TOKEN }} + # This action will error if the issue already exists on the project. Which is + # common as `X-Needs-Info` will often be added to issues that are already in + # the triage queue. Prevent the whole job from failing in this case. + continue-on-error: true - name: Set status env: GITHUB_TOKEN: ${{ secrets.ELEMENT_BOT_TOKEN }} diff --git a/CHANGES.md b/CHANGES.md index d75a55d1e5..0616664741 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,12 @@ +# Synapse 1.136.0 (2025-08-12) + +Note: This release includes the security fixes from `1.135.2` and `1.136.0rc2`, detailed below. + +### Bugfixes + +- Fix bug introduced in 1.135.2 and 1.136.0rc2 where the [Make Room Admin API](https://element-hq.github.io/synapse/latest/admin_api/rooms.html#make-room-admin-api) would not treat a room v12's creator power level as the highest in room. ([\#18805](https://github.com/element-hq/synapse/issues/18805)) + + # Synapse 1.135.2 (2025-08-11) This is the Synapse portion of the [Matrix coordinated security release](https://matrix.org/blog/2025/07/security-predisclosure/). This release includes support for [room version](https://spec.matrix.org/v1.15/rooms/) 12 which fixes a number of security vulnerabilities, including [CVE-2025-49090](https://www.cve.org/CVERecord?id=CVE-2025-49090). @@ -23,7 +32,77 @@ Two patched Synapse releases are now available: - Speed up upgrading a room with large numbers of banned users. ([\#18574](https://github.com/element-hq/synapse/issues/18574)) +# Synapse 1.136.0rc2 (2025-08-11) +- Update MSC4293 redaction logic for room v12. ([\#80](https://github.com/element-hq/synapse/issues/80)) + +### Internal Changes + +- Add a parameter to `upgrade_rooms(..)` to allow auto join local users. ([\#83](https://github.com/element-hq/synapse/issues/83)) + + +# Synapse 1.136.0rc1 (2025-08-05) + +Please check [the relevant section in the upgrade notes](https://github.com/element-hq/synapse/blob/develop/docs/upgrade.md#upgrading-to-v11360) as this release contains changes to MAS support, metrics labels and the module API which may require your attention when upgrading. + +### Features + +- Add configurable rate limiting for the creation of rooms. ([\#18514](https://github.com/element-hq/synapse/issues/18514)) +- Add support for [MSC4293](https://github.com/matrix-org/matrix-spec-proposals/pull/4293) - Redact on Kick/Ban. ([\#18540](https://github.com/element-hq/synapse/issues/18540)) +- When admins enable themselves to see soft-failed events, they will also see if the cause is due to the policy server flagging them as spam via `unsigned`. ([\#18585](https://github.com/element-hq/synapse/issues/18585)) +- Add ability to configure forward/outbound proxy via homeserver config instead of environment variables. See `http_proxy`, `https_proxy`, `no_proxy_hosts`. ([\#18686](https://github.com/element-hq/synapse/issues/18686)) +- Advertise experimental support for [MSC4306](https://github.com/matrix-org/matrix-spec-proposals/pull/4306) (Thread Subscriptions) through `/_matrix/clients/versions` if enabled. 
([\#18722](https://github.com/element-hq/synapse/issues/18722)) +- Stabilise support for delegating authentication to [Matrix Authentication Service](https://github.com/element-hq/matrix-authentication-service/). ([\#18759](https://github.com/element-hq/synapse/issues/18759)) +- Implement the push rules for experimental [MSC4306: Thread Subscriptions](https://github.com/matrix-org/matrix-doc/issues/4306). ([\#18762](https://github.com/element-hq/synapse/issues/18762)) + +### Bugfixes + +- Allow return code 403 (allowed by C2S Spec since v1.2) when fetching profiles via federation. ([\#18696](https://github.com/element-hq/synapse/issues/18696)) +- Register the MSC4306 (Thread Subscriptions) endpoints in the CS API when the experimental feature is enabled. ([\#18726](https://github.com/element-hq/synapse/issues/18726)) +- Fix a long-standing bug where suspended users could not have server notices sent to them (a 403 was returned to the admin). ([\#18750](https://github.com/element-hq/synapse/issues/18750)) +- Fix an issue that could cause logcontexts to be lost on rate-limited requests. Found by @realtyem. ([\#18763](https://github.com/element-hq/synapse/issues/18763)) +- Fix invalidation of storage cache that was broken in 1.135.0. ([\#18786](https://github.com/element-hq/synapse/issues/18786)) + +### Improved Documentation + +- Minor improvements to README. ([\#18700](https://github.com/element-hq/synapse/issues/18700)) +- Document that there can be multiple workers handling the `receipts` stream. ([\#18760](https://github.com/element-hq/synapse/issues/18760)) +- Improve worker documentation for some device paths. ([\#18761](https://github.com/element-hq/synapse/issues/18761)) + +### Deprecations and Removals + +- Deprecate `run_as_background_process` exported as part of the module API interface in favor of `ModuleApi.run_as_background_process`. See [the relevant section in the upgrade notes](https://github.com/element-hq/synapse/blob/develop/docs/upgrade.md#upgrading-to-v11360) for more information. ([\#18737](https://github.com/element-hq/synapse/issues/18737)) + +### Internal Changes + +- Add debug logging for HMAC digest verification failures when using the admin API to register users. ([\#18474](https://github.com/element-hq/synapse/issues/18474)) +- Speed up upgrading a room with large numbers of banned users. ([\#18574](https://github.com/element-hq/synapse/issues/18574)) +- Fix config documentation generation script on Windows by enforcing UTF-8. ([\#18580](https://github.com/element-hq/synapse/issues/18580)) +- Refactor cache, background process, `Counter`, `LaterGauge`, `GaugeBucketCollector`, `Histogram`, and `Gauge` metrics to be homeserver-scoped. ([\#18656](https://github.com/element-hq/synapse/issues/18656), [\#18714](https://github.com/element-hq/synapse/issues/18714), [\#18715](https://github.com/element-hq/synapse/issues/18715), [\#18724](https://github.com/element-hq/synapse/issues/18724), [\#18753](https://github.com/element-hq/synapse/issues/18753), [\#18725](https://github.com/element-hq/synapse/issues/18725), [\#18670](https://github.com/element-hq/synapse/issues/18670), [\#18748](https://github.com/element-hq/synapse/issues/18748), [\#18751](https://github.com/element-hq/synapse/issues/18751)) +- Reduce database usage in Sliding Sync by not querying for background update completion after the update is known to be complete. ([\#18718](https://github.com/element-hq/synapse/issues/18718)) +- Improve order of validation and ratelimiting in room creation. 
([\#18723](https://github.com/element-hq/synapse/issues/18723)) +- Bump minimum version bound on Twisted to 21.2.0. ([\#18727](https://github.com/element-hq/synapse/issues/18727), [\#18729](https://github.com/element-hq/synapse/issues/18729)) +- Use `twisted.internet.testing` module in tests instead of deprecated `twisted.test.proto_helpers`. ([\#18728](https://github.com/element-hq/synapse/issues/18728)) +- Remove obsolete `/send_event` replication endpoint. ([\#18730](https://github.com/element-hq/synapse/issues/18730)) +- Update metrics linting to be able to handle custom metrics. ([\#18733](https://github.com/element-hq/synapse/issues/18733)) +- Work around `twisted.protocols.amp.TooLong` error by reducing logging in some tests. ([\#18736](https://github.com/element-hq/synapse/issues/18736)) +- Prevent "Move labelled issues to correct projects" GitHub Actions workflow from failing when an issue is already on the project board. ([\#18755](https://github.com/element-hq/synapse/issues/18755)) +- Bump minimum supported Rust version (MSRV) to 1.82.0. Missed in [#18553](https://github.com/element-hq/synapse/pull/18553) (released in Synapse 1.134.0). ([\#18757](https://github.com/element-hq/synapse/issues/18757)) +- Make `Clock.sleep(...)` return a coroutine, so that mypy can catch places where we don't await on it. ([\#18772](https://github.com/element-hq/synapse/issues/18772)) +- Update implementation of [MSC4306: Thread Subscriptions](https://github.com/matrix-org/matrix-doc/issues/4306) to include automatic subscription conflict prevention as introduced in later drafts. ([\#18756](https://github.com/element-hq/synapse/issues/18756)) + + + +### Updates to locked dependencies + +* Bump gitpython from 3.1.44 to 3.1.45. ([\#18743](https://github.com/element-hq/synapse/issues/18743)) +* Bump mypy-zope from 1.0.12 to 1.0.13. ([\#18744](https://github.com/element-hq/synapse/issues/18744)) +* Bump phonenumbers from 9.0.9 to 9.0.10. ([\#18741](https://github.com/element-hq/synapse/issues/18741)) +* Bump ruff from 0.12.4 to 0.12.5. ([\#18742](https://github.com/element-hq/synapse/issues/18742)) +* Bump sentry-sdk from 2.32.0 to 2.33.2. ([\#18745](https://github.com/element-hq/synapse/issues/18745)) +* Bump tokio from 1.46.1 to 1.47.0. ([\#18740](https://github.com/element-hq/synapse/issues/18740)) +* Bump types-jsonschema from 4.24.0.20250708 to 4.25.0.20250720. ([\#18703](https://github.com/element-hq/synapse/issues/18703)) +* Bump types-psycopg2 from 2.9.21.20250516 to 2.9.21.20250718. 
([\#18706](https://github.com/element-hq/synapse/issues/18706)) # Synapse 1.135.0 (2025-08-01) diff --git a/Cargo.lock b/Cargo.lock index d11dc3e8a8..0ddd5ab396 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -13,9 +13,9 @@ dependencies = [ [[package]] name = "adler2" -version = "2.0.0" +version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" +checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" [[package]] name = "aho-corasick" @@ -46,15 +46,15 @@ checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" [[package]] name = "autocfg" -version = "1.3.0" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0" +checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" [[package]] name = "backtrace" -version = "0.3.74" +version = "0.3.75" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d82cb332cdfaed17ae235a638438ac4d4839913cc2af585c3c6746e8f8bee1a" +checksum = "6806a6321ec58106fea15becdad98371e28d92ccbc7c8f1b3b6dd724fe8f1002" dependencies = [ "addr2line", "cfg-if", @@ -73,9 +73,9 @@ checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" [[package]] name = "bitflags" -version = "2.8.0" +version = "2.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f68f53c83ab957f72c32642f3868eec03eb974d1fb82e453128456482613d36" +checksum = "1b8e56985ec62d17e9c1001dc89c88ecd7dc08e47eba5ec7c29c7b5eeecde967" [[package]] name = "blake2" @@ -97,9 +97,9 @@ dependencies = [ [[package]] name = "bumpalo" -version = "3.16.0" +version = "3.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" +checksum = "46c5e41b57b8bba42a04676d81cb89e9ee8e859a1a66f80a5a72e1cb76b34d43" [[package]] name = "bytes" @@ -109,18 +109,18 @@ checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" [[package]] name = "cc" -version = "1.2.19" +version = "1.2.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e3a13707ac958681c13b39b458c073d0d9bc8a22cb1b2f4c8e55eb72c13f362" +checksum = "deec109607ca693028562ed836a5f1c4b8bd77755c4e132fc5ce11b0b6211ae7" dependencies = [ "shlex", ] [[package]] name = "cfg-if" -version = "1.0.0" +version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +checksum = "9555578bc9e57714c812a1f84e4fc5b4d21fcb063490c624de019f7464c91268" [[package]] name = "cfg_aliases" @@ -130,9 +130,9 @@ checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" [[package]] name = "core-foundation" -version = "0.10.0" +version = "0.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b55271e5c8c478ad3f38ad24ef34923091e0548492a266d19b3c0b4d82574c63" +checksum = "b2a6cd9ae233e7f62ba4e9353e81a88df7fc8a5987b8d445b4d90c879bd156f6" dependencies = [ "core-foundation-sys", "libc", @@ -155,9 +155,9 @@ dependencies = [ [[package]] name = "cpufeatures" -version = "0.2.12" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53fe5e26ff1b7aef8bca9c6080520cfb8d9333c7568e1829cef191a9723e5504" +checksum = 
"59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" dependencies = [ "libc", ] @@ -316,25 +316,29 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.15" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" +checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592" dependencies = [ "cfg-if", + "js-sys", "libc", - "wasi 0.11.0+wasi-snapshot-preview1", + "wasi 0.11.1+wasi-snapshot-preview1", + "wasm-bindgen", ] [[package]] name = "getrandom" -version = "0.3.1" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43a49c392881ce6d5c3b8cb70f98717b7c07aabbdff06687b9030dbfbe2725f8" +checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4" dependencies = [ "cfg-if", + "js-sys", "libc", - "wasi 0.13.3+wasi-0.2.2", - "windows-targets", + "r-efi", + "wasi 0.14.2+wasi-0.2.4", + "wasm-bindgen", ] [[package]] @@ -345,9 +349,9 @@ checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" [[package]] name = "h2" -version = "0.4.9" +version = "0.4.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75249d144030531f8dee69fe9cea04d3edf809a017ae445e2abdff6629e86633" +checksum = "17da50a276f1e01e0ba6c029e47b7100754904ee8a278f886546e98575380785" dependencies = [ "atomic-waker", "bytes", @@ -364,9 +368,9 @@ dependencies = [ [[package]] name = "hashbrown" -version = "0.15.2" +version = "0.15.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" +checksum = "5971ac85611da7067dbfcabef3c70ebb5606018acd9e2a3903a0da507521e0d5" [[package]] name = "headers" @@ -472,11 +476,10 @@ dependencies = [ [[package]] name = "hyper-rustls" -version = "0.27.5" +version = "0.27.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2d191583f3da1305256f22463b9bb0471acad48a4e534a5218b9963e9c1f59b2" +checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58" dependencies = [ - "futures-util", "http", "hyper", "hyper-util", @@ -490,9 +493,9 @@ dependencies = [ [[package]] name = "hyper-util" -version = "0.1.14" +version = "0.1.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc2fdfdbff08affe55bb779f33b053aa1fe5dd5b54c257343c17edfa55711bdb" +checksum = "8d9b05277c7e8da2c93a568989bb6207bef0112e8d17df7a6eda4a3cf143bc5e" dependencies = [ "base64", "bytes", @@ -506,24 +509,12 @@ dependencies = [ "libc", "percent-encoding", "pin-project-lite", - "socket2", + "socket2 0.6.0", "tokio", "tower-service", "tracing", ] -[[package]] -name = "icu_collections" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db2fa452206ebee18c4b5c2274dbf1de17008e874b4dc4f0aea9d01ca79e4526" -dependencies = [ - "displaydoc", - "yoke 0.7.5", - "zerofrom", - "zerovec 0.10.4", -] - [[package]] name = "icu_collections" version = "2.0.0" @@ -532,9 +523,9 @@ checksum = "200072f5d0e3614556f94a9930d5dc3e0662a652823904c3a75dc3b0af7fee47" dependencies = [ "displaydoc", "potential_utf", - "yoke 0.8.0", + "yoke", "zerofrom", - "zerovec 0.11.2", + "zerovec", ] [[package]] @@ -544,13 +535,13 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6ae5921528335e91da1b6c695dbf1ec37df5ac13faa3f91e5640be93aa2fbefd" dependencies = [ "displaydoc", - "icu_collections 2.0.0", + 
"icu_collections", "icu_locale_core", "icu_locale_data", - "icu_provider 2.0.0", + "icu_provider", "potential_utf", - "tinystr 0.8.1", - "zerovec 0.11.2", + "tinystr", + "zerovec", ] [[package]] @@ -560,10 +551,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0cde2700ccaed3872079a65fb1a78f6c0a36c91570f28755dda67bc8f7d9f00a" dependencies = [ "displaydoc", - "litemap 0.8.0", - "tinystr 0.8.1", - "writeable 0.6.1", - "zerovec 0.11.2", + "litemap", + "tinystr", + "writeable", + "zerovec", ] [[package]] @@ -572,100 +563,48 @@ version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4fdef0c124749d06a743c69e938350816554eb63ac979166590e2b4ee4252765" -[[package]] -name = "icu_locid" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13acbb8371917fc971be86fc8057c41a64b521c184808a698c02acc242dbf637" -dependencies = [ - "displaydoc", - "litemap 0.7.5", - "tinystr 0.7.6", - "writeable 0.5.5", - "zerovec 0.10.4", -] - -[[package]] -name = "icu_locid_transform" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01d11ac35de8e40fdeda00d9e1e9d92525f3f9d887cdd7aa81d727596788b54e" -dependencies = [ - "displaydoc", - "icu_locid", - "icu_locid_transform_data", - "icu_provider 1.5.0", - "tinystr 0.7.6", - "zerovec 0.10.4", -] - -[[package]] -name = "icu_locid_transform_data" -version = "1.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7515e6d781098bf9f7205ab3fc7e9709d34554ae0b21ddbcb5febfa4bc7df11d" - [[package]] name = "icu_normalizer" -version = "1.5.0" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19ce3e0da2ec68599d193c93d088142efd7f9c5d6fc9b803774855747dc6a84f" +checksum = "436880e8e18df4d7bbc06d58432329d6458cc84531f7ac5f024e93deadb37979" dependencies = [ "displaydoc", - "icu_collections 1.5.0", + "icu_collections", "icu_normalizer_data", "icu_properties", - "icu_provider 1.5.0", + "icu_provider", "smallvec", - "utf16_iter", - "utf8_iter", - "write16", - "zerovec 0.10.4", + "zerovec", ] [[package]] name = "icu_normalizer_data" -version = "1.5.1" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c5e8338228bdc8ab83303f16b797e177953730f601a96c25d10cb3ab0daa0cb7" +checksum = "00210d6893afc98edb752b664b8890f0ef174c8adbb8d0be9710fa66fbbf72d3" [[package]] name = "icu_properties" -version = "1.5.1" +version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93d6020766cfc6302c15dbbc9c8778c37e62c14427cb7f6e601d849e092aeef5" +checksum = "016c619c1eeb94efb86809b015c58f479963de65bdb6253345c1a1276f22e32b" dependencies = [ "displaydoc", - "icu_collections 1.5.0", - "icu_locid_transform", + "icu_collections", + "icu_locale_core", "icu_properties_data", - "icu_provider 1.5.0", - "tinystr 0.7.6", - "zerovec 0.10.4", + "icu_provider", + "potential_utf", + "zerotrie", + "zerovec", ] [[package]] name = "icu_properties_data" -version = "1.5.1" +version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85fb8799753b75aee8d2a21d7c14d9f38921b54b3dbda10f5a3c7a7b82dba5e2" - -[[package]] -name = "icu_provider" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ed421c8a8ef78d3e2dbc98a973be2f3770cb42b606e3ab18d6237c4dfde68d9" -dependencies = [ - "displaydoc", - "icu_locid", - "icu_provider_macros", - "stable_deref_trait", 
- "tinystr 0.7.6", - "writeable 0.5.5", - "yoke 0.7.5", - "zerofrom", - "zerovec 0.10.4", -] +checksum = "298459143998310acd25ffe6810ed544932242d3f07083eee1084d83a71bd632" [[package]] name = "icu_provider" @@ -676,23 +615,12 @@ dependencies = [ "displaydoc", "icu_locale_core", "stable_deref_trait", - "tinystr 0.8.1", - "writeable 0.6.1", - "yoke 0.8.0", + "tinystr", + "writeable", + "yoke", "zerofrom", "zerotrie", - "zerovec 0.11.2", -] - -[[package]] -name = "icu_provider_macros" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6" -dependencies = [ - "proc-macro2", - "quote", - "syn", + "zerovec", ] [[package]] @@ -703,14 +631,14 @@ checksum = "e185fc13b6401c138cf40db12b863b35f5edf31b88192a545857b41aeaf7d3d3" dependencies = [ "core_maths", "displaydoc", - "icu_collections 2.0.0", + "icu_collections", "icu_locale", "icu_locale_core", - "icu_provider 2.0.0", + "icu_provider", "icu_segmenter_data", "potential_utf", "utf8_iter", - "zerovec 0.11.2", + "zerovec", ] [[package]] @@ -732,9 +660,9 @@ dependencies = [ [[package]] name = "idna_adapter" -version = "1.2.0" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "daca1df1c957320b2cf139ac61e7bd64fed304c5040df000a745aa1de3b4ef71" +checksum = "3acae9609540aa318d1bc588455225fb2085b9ed0c4f6bd0d9d5bcd86f1a0344" dependencies = [ "icu_normalizer", "icu_properties", @@ -742,9 +670,9 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.9.0" +version = "2.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cea70ddb795996207ad57735b50c5982d8844f38ba9ee5f1aedcfb708a2aa11e" +checksum = "fe4cd85333e22411419a0bcae1297d25e58c9443848b11dc6a86fefe8c78a661" dependencies = [ "equivalent", "hashbrown", @@ -752,15 +680,15 @@ dependencies = [ [[package]] name = "indoc" -version = "2.0.5" +version = "2.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5" +checksum = "f4c7245a08504955605670dbf141fceab975f15ca21570696aebe9d2e71576bd" [[package]] name = "io-uring" -version = "0.7.8" +version = "0.7.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b86e202f00093dcba4275d4636b93ef9dd75d025ae560d2521b45ea28ab49013" +checksum = "d93587f37623a1a17d94ef2bc9ada592f5465fe7732084ab7beefabe5c77c0c4" dependencies = [ "bitflags", "cfg-if", @@ -785,9 +713,9 @@ dependencies = [ [[package]] name = "itoa" -version = "1.0.11" +version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" +checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" [[package]] name = "js-sys" @@ -807,9 +735,9 @@ checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" [[package]] name = "libc" -version = "0.2.172" +version = "0.2.174" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d750af042f7ef4f724306de029d18836c26c1765a54a6a3f094cbd23a7267ffa" +checksum = "1171693293099992e19cddea4e8b849964e9846f4acee11b3948bcc337be8776" [[package]] name = "libm" @@ -817,12 +745,6 @@ version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f9fbbcab51052fe104eb5e5d351cf728d30a5be1fe14d9be8a3b097481fb97de" -[[package]] -name = "litemap" -version = "0.7.5" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "23fb14cb19457329c82206317a5663005a4d404783dc74f4252769b0d5f42856" - [[package]] name = "litemap" version = "0.8.0" @@ -836,10 +758,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94" [[package]] -name = "memchr" -version = "2.7.2" +name = "lru-slab" +version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c8640c5d730cb13ebd907d8d04b52f55ac9a2eec55b440c8892f40d56c76c1d" +checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" + +[[package]] +name = "memchr" +version = "2.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a282da65faaf38286cf3be983213fcf1d2e2a58700e808f83f4ea9a4804bc0" [[package]] name = "memoffset" @@ -858,22 +786,22 @@ checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" [[package]] name = "miniz_oxide" -version = "0.8.8" +version = "0.8.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3be647b768db090acb35d5ec5db2b0e1f1de11133ca123b9eacf5137868f892a" +checksum = "1fa76a2c86f704bdb222d66965fb3d63269ce38518b83cb0575fca855ebb6316" dependencies = [ "adler2", ] [[package]] name = "mio" -version = "1.0.3" +version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2886843bf800fba2e3377cff24abf6379b4c4d5c6681eaf9ea5b0d15090450bd" +checksum = "78bed444cc8a2160f01cbcf811ef18cac863ad68ae8ca62092e8db51d51c761c" dependencies = [ "libc", - "wasi 0.11.0+wasi-snapshot-preview1", - "windows-sys 0.52.0", + "wasi 0.11.1+wasi-snapshot-preview1", + "windows-sys 0.59.0", ] [[package]] @@ -917,9 +845,9 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" [[package]] name = "portable-atomic" -version = "1.6.0" +version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7170ef9988bc169ba16dd36a7fa041e5c4cbeb6a35b76d4c03daded371eae7c0" +checksum = "f84267b20a16ea918e43c6a88433c2d54fa145c92a811b5b047ccbe153674483" [[package]] name = "potential_utf" @@ -928,20 +856,23 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5a7c30837279ca13e7c867e9e40053bc68740f988cb07f7ca6df43cc734b585" dependencies = [ "serde", - "zerovec 0.11.2", + "zerovec", ] [[package]] name = "ppv-lite86" -version = "0.2.17" +version = "0.2.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" +checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" +dependencies = [ + "zerocopy", +] [[package]] name = "proc-macro2" -version = "1.0.89" +version = "1.0.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f139b0662de085916d1fb67d2b4169d1addddda1919e696f3252b740b629986e" +checksum = "02b3e5e68a3a1a02aad3ec490a98007cbc13c37cbe84a3cd7b8e406d76e7f778" dependencies = [ "unicode-ident", ] @@ -1032,92 +963,82 @@ dependencies = [ [[package]] name = "quinn" -version = "0.11.5" +version = "0.11.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c7c5fdde3cdae7203427dc4f0a68fe0ed09833edc525a03456b153b79828684" +checksum = "626214629cda6781b6dc1d316ba307189c85ba657213ce642d9c77670f8202c8" dependencies = [ "bytes", + "cfg_aliases", "pin-project-lite", "quinn-proto", "quinn-udp", "rustc-hash", "rustls", - "socket2", + "socket2 
0.5.10", "thiserror", "tokio", "tracing", + "web-time", ] [[package]] name = "quinn-proto" -version = "0.11.8" +version = "0.11.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fadfaed2cd7f389d0161bb73eeb07b7b78f8691047a6f3e73caaeae55310a4a6" +checksum = "49df843a9161c85bb8aae55f101bc0bac8bcafd637a620d9122fd7e0b2f7422e" dependencies = [ "bytes", - "rand 0.8.5", + "getrandom 0.3.3", + "lru-slab", + "rand", "ring", "rustc-hash", "rustls", + "rustls-pki-types", "slab", "thiserror", "tinyvec", "tracing", + "web-time", ] [[package]] name = "quinn-udp" -version = "0.5.11" +version = "0.5.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "541d0f57c6ec747a90738a52741d3221f7960e8ac2f0ff4b1a63680e033b4ab5" +checksum = "fcebb1209ee276352ef14ff8732e24cc2b02bbac986cd74a4c81bcb2f9881970" dependencies = [ "cfg_aliases", "libc", "once_cell", - "socket2", + "socket2 0.5.10", "tracing", "windows-sys 0.59.0", ] [[package]] name = "quote" -version = "1.0.36" +version = "1.0.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7" +checksum = "1885c039570dc00dcb4ff087a89e185fd56bae234ddc7f056a945bf36467248d" dependencies = [ "proc-macro2", ] [[package]] -name = "rand" -version = "0.8.5" +name = "r-efi" +version = "5.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" -dependencies = [ - "libc", - "rand_chacha 0.3.1", - "rand_core 0.6.4", -] +checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" [[package]] name = "rand" -version = "0.9.0" +version = "0.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3779b94aeb87e8bd4e834cee3650289ee9e0d5677f976ecdb6d219e5f4f6cd94" +checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" dependencies = [ - "rand_chacha 0.9.0", - "rand_core 0.9.0", - "zerocopy", -] - -[[package]] -name = "rand_chacha" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" -dependencies = [ - "ppv-lite86", - "rand_core 0.6.4", + "rand_chacha", + "rand_core", ] [[package]] @@ -1127,26 +1048,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" dependencies = [ "ppv-lite86", - "rand_core 0.9.0", + "rand_core", ] [[package]] name = "rand_core" -version = "0.6.4" +version = "0.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38" dependencies = [ - "getrandom 0.2.15", -] - -[[package]] -name = "rand_core" -version = "0.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b08f3c9802962f7e1b25113931d94f43ed9725bebc59db9d0c3e9a23b67e15ff" -dependencies = [ - "getrandom 0.3.1", - "zerocopy", + "getrandom 0.3.3", ] [[package]] @@ -1163,9 +1074,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.8" +version = "0.4.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "368758f23274712b504848e9d5a6f010445cc8b87a7cdb4d7cbee666c1288da3" +checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" dependencies = [ 
"aho-corasick", "memchr", @@ -1228,7 +1139,7 @@ checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7" dependencies = [ "cc", "cfg-if", - "getrandom 0.2.15", + "getrandom 0.2.16", "libc", "untrusted", "windows-sys 0.52.0", @@ -1236,9 +1147,9 @@ dependencies = [ [[package]] name = "rustc-demangle" -version = "0.1.24" +version = "0.1.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" +checksum = "56f7d92ca342cea22a06f2121d944b4fd82af56988c270852495420f961d4ace" [[package]] name = "rustc-hash" @@ -1248,9 +1159,9 @@ checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" [[package]] name = "rustls" -version = "0.23.26" +version = "0.23.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df51b5869f3a441595eac5e8ff14d486ff285f7b8c0df8770e49c3b56351f0f0" +checksum = "c0ebcbd2f03de0fc1122ad9bb24b127a5a6cd51d72604a3f3c50ac459762b6cc" dependencies = [ "once_cell", "ring", @@ -1274,15 +1185,19 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.11.0" +version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "917ce264624a4b4db1c364dcc35bfca9ded014d0a958cd47ad3e960e988ea51c" +checksum = "229a4a4c221013e7e1f1a043678c5cc39fe5171437c88fb47151a21e6f5b5c79" +dependencies = [ + "web-time", + "zeroize", +] [[package]] name = "rustls-webpki" -version = "0.103.1" +version = "0.103.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fef8b8769aaccf73098557a87cd1816b4f9c7c16811c9c77142aa695c16f2c03" +checksum = "0a17884ae0c1b773f1ccd2bd4a8c72f16da897310a98b0e84bf349ad5ead92fc" dependencies = [ "ring", "rustls-pki-types", @@ -1291,15 +1206,15 @@ dependencies = [ [[package]] name = "rustversion" -version = "1.0.20" +version = "1.0.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eded382c5f5f786b989652c49544c4877d9f015cc22e145a5ea8ea66c2921cd2" +checksum = "8a0d197bd2c9dc6e53b84da9556a69ba4cdfab8619eb41a8bd1cc2027a0f6b1d" [[package]] name = "ryu" -version = "1.0.18" +version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" +checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" [[package]] name = "schannel" @@ -1407,29 +1322,36 @@ checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" [[package]] name = "slab" -version = "0.4.9" +version = "0.4.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f92a496fb766b417c996b9c5e57daf2f7ad3b0bebe1ccfca4856390e3d3bb67" -dependencies = [ - "autocfg", -] +checksum = "04dc19736151f35336d325007ac991178d504a119863a2fcb3758cdb5e52c50d" [[package]] name = "smallvec" -version = "1.15.0" +version = "1.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8917285742e9f3e1683f0a9c4e6b57960b7314d0b08d30d1ecd426713ee2eee9" +checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" [[package]] name = "socket2" -version = "0.5.9" +version = "0.5.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f5fd57c80058a56cf5c777ab8a126398ece8e442983605d280a44ce79d0edef" +checksum = "e22376abed350d73dd1cd119b57ffccad95b4e585a7cda43e286245ce23c0678" dependencies = [ "libc", "windows-sys 0.52.0", ] +[[package]] +name = "socket2" +version = "0.6.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "233504af464074f9d066d7b5416c5f9b894a5862a6506e306f7b816cdd6f1807" +dependencies = [ + "libc", + "windows-sys 0.59.0", +] + [[package]] name = "stable_deref_trait" version = "1.2.0" @@ -1438,15 +1360,15 @@ checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" [[package]] name = "subtle" -version = "2.5.0" +version = "2.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81cdd64d312baedb58e21336b31bc043b77e01cc99033ce76ef539f78e965ebc" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "syn" -version = "2.0.85" +version = "2.0.104" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5023162dfcd14ef8f32034d8bcd4cc5ddc61ef7a247c024a33e24e1f24d21b56" +checksum = "17b6f705963418cdb9927482fa304bc562ece2fdd4f616084c50b7023b435a40" dependencies = [ "proc-macro2", "quote", @@ -1511,34 +1433,24 @@ checksum = "e502f78cdbb8ba4718f566c418c52bc729126ffd16baee5baa718cf25dd5a69a" [[package]] name = "thiserror" -version = "1.0.65" +version = "2.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d11abd9594d9b38965ef50805c5e469ca9cc6f197f883f717e0269a3057b3d5" +checksum = "567b8a2dae586314f7be2a752ec7474332959c6460e02bde30d702a66d488708" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.65" +version = "2.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae71770322cbd277e69d762a16c444af02aa0575ac0d174f0b9562d3b37f8602" +checksum = "7f7cf42b4507d8ea322120659672cf1b9dbb93f8f2d4ecfd6e51350ff5b17a1d" dependencies = [ "proc-macro2", "quote", "syn", ] -[[package]] -name = "tinystr" -version = "0.7.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9117f5d4db391c1cf6927e7bea3db74b9a1c1add8f7eda9ffd5364f40f57b82f" -dependencies = [ - "displaydoc", - "zerovec 0.10.4", -] - [[package]] name = "tinystr" version = "0.8.1" @@ -1546,7 +1458,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5d4f6d1145dcb577acf783d4e601bc1d76a13337bb54e6233add580b07344c8b" dependencies = [ "displaydoc", - "zerovec 0.11.2", + "zerovec", ] [[package]] @@ -1566,9 +1478,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.46.1" +version = "1.47.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0cc3a2344dafbe23a245241fe8b09735b521110d30fcefbbd5feb1797ca35d17" +checksum = "43864ed400b6043a4757a25c7a64a8efde741aed79a056a2fb348a406701bb35" dependencies = [ "backtrace", "bytes", @@ -1577,8 +1489,8 @@ dependencies = [ "mio", "pin-project-lite", "slab", - "socket2", - "windows-sys 0.52.0", + "socket2 0.6.0", + "windows-sys 0.59.0", ] [[package]] @@ -1676,9 +1588,9 @@ checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" [[package]] name = "typenum" -version = "1.17.0" +version = "1.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" +checksum = "1dccffe3ce07af9386bfd29e80c0ab1a8205a2fc34e4bcd40364df902cfa8f3f" [[package]] name = "ulid" @@ -1686,21 +1598,21 @@ version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "470dbf6591da1b39d43c14523b2b469c86879a53e8b758c8e090a470fe7b1fbe" dependencies = [ - "rand 0.9.0", + "rand", "web-time", ] 
[[package]] name = "unicode-ident" -version = "1.0.12" +version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" +checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512" [[package]] name = "unindent" -version = "0.2.3" +version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce" +checksum = "7264e107f553ccae879d21fbea1d6724ac785e8c3bfc762137959b5802826ef3" [[package]] name = "untrusted" @@ -1719,12 +1631,6 @@ dependencies = [ "percent-encoding", ] -[[package]] -name = "utf16_iter" -version = "1.0.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8232dd3cdaed5356e0f716d285e4b40b932ac434100fe9b7e0e8e935b9e6246" - [[package]] name = "utf8_iter" version = "1.0.4" @@ -1733,9 +1639,9 @@ checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" [[package]] name = "version_check" -version = "0.9.4" +version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" [[package]] name = "want" @@ -1748,15 +1654,15 @@ dependencies = [ [[package]] name = "wasi" -version = "0.11.0+wasi-snapshot-preview1" +version = "0.11.1+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" [[package]] name = "wasi" -version = "0.13.3+wasi-0.2.2" +version = "0.14.2+wasi-0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26816d2e1a4a36a2940b96c5296ce403917633dff8f3440e9b236ed6f6bacad2" +checksum = "9683f9a5a998d873c0d21fcbe3c083009670149a8fab228644b8bd36b2c48cb3" dependencies = [ "wit-bindgen-rt", ] @@ -1949,43 +1855,19 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] name = "wit-bindgen-rt" -version = "0.33.0" +version = "0.39.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3268f3d866458b787f390cf61f4bbb563b922d091359f9608842999eaee3943c" +checksum = "6f42320e61fe2cfd34354ecb597f86f413484a798ba44a8ca1165c58d42da6c1" dependencies = [ "bitflags", ] -[[package]] -name = "write16" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1890f4022759daae28ed4fe62859b1236caebfc61ede2f63ed4e695f3f6d936" - -[[package]] -name = "writeable" -version = "0.5.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e9df38ee2d2c3c5948ea468a8406ff0db0b29ae1ffde1bcf20ef305bcc95c51" - [[package]] name = "writeable" version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ea2f10b9bb0928dfb1b42b65e1f9e36f7f54dbdf08457afefb38afcdec4fa2bb" -[[package]] -name = "yoke" -version = "0.7.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "120e6aef9aa629e3d4f52dc8cc43a015c7724194c97dfaf45180d2daf2b77f40" -dependencies = [ - "serde", - "stable_deref_trait", - "yoke-derive 0.7.5", - "zerofrom", -] - [[package]] name = "yoke" version = "0.8.0" @@ -1994,22 +1876,10 @@ checksum = "5f41bb01b8226ef4bfd589436a297c53d118f65921786300e427be8d487695cc" dependencies = [ 
"serde", "stable_deref_trait", - "yoke-derive 0.8.0", + "yoke-derive", "zerofrom", ] -[[package]] -name = "yoke-derive" -version = "0.7.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2380878cad4ac9aac1e2435f3eb4020e8374b5f13c296cb75b4620ff8e229154" -dependencies = [ - "proc-macro2", - "quote", - "syn", - "synstructure", -] - [[package]] name = "yoke-derive" version = "0.8.0" @@ -2024,18 +1894,18 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.8.17" +version = "0.8.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aa91407dacce3a68c56de03abe2760159582b846c6a4acd2f456618087f12713" +checksum = "1039dd0d3c310cf05de012d8a39ff557cb0d23087fd44cad61df08fc31907a2f" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.8.17" +version = "0.8.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06718a168365cad3d5ff0bb133aad346959a2074bd4a85c121255a11304a8626" +checksum = "9ecf5b4cc5364572d7f4c329661bcc82724222973f2cab6f050a4e5c22f75181" dependencies = [ "proc-macro2", "quote", @@ -2076,17 +1946,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "36f0bbd478583f79edad978b407914f61b2972f5af6fa089686016be8f9af595" dependencies = [ "displaydoc", -] - -[[package]] -name = "zerovec" -version = "0.10.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aa2b893d79df23bfb12d5461018d408ea19dfafe76c2c7ef6d4eba614f8ff079" -dependencies = [ - "yoke 0.7.5", + "yoke", "zerofrom", - "zerovec-derive 0.10.3", ] [[package]] @@ -2095,20 +1956,9 @@ version = "0.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4a05eb080e015ba39cc9e23bbe5e7fb04d5fb040350f99f34e338d5fdd294428" dependencies = [ - "yoke 0.8.0", + "yoke", "zerofrom", - "zerovec-derive 0.11.1", -] - -[[package]] -name = "zerovec-derive" -version = "0.10.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6" -dependencies = [ - "proc-macro2", - "quote", - "syn", + "zerovec-derive", ] [[package]] diff --git a/README.rst b/README.rst index 8974990ed1..92854f631c 100644 --- a/README.rst +++ b/README.rst @@ -8,7 +8,7 @@ Synapse is an open source `Matrix `__ homeserver implementation, written and maintained by `Element `_. `Matrix `__ is the open standard for -secure and interoperable real time communications. You can directly run +secure and interoperable real-time communications. You can directly run and manage the source code in this repository, available under an AGPL license (or alternatively under a commercial license from Element). There is no support provided by Element unless you have a @@ -23,13 +23,13 @@ ESS builds on Synapse to offer a complete Matrix-based backend including the ful `Admin Console product `_, giving admins the power to easily manage an organization-wide deployment. It includes advanced identity management, auditing, -moderation and data retention options as well as Long Term Support and -SLAs. ESS can be used to support any Matrix-based frontend client. +moderation and data retention options as well as Long-Term Support and +SLAs. ESS supports any Matrix-compatible client. .. contents:: -🛠️ Installing and configuration -=============================== +🛠️ Installation and configuration +================================== The Synapse documentation describes `how to install Synapse `_. 
We recommend using `Docker images `_ or `Debian packages from Matrix.org @@ -133,7 +133,7 @@ connect from a client: see An easy way to get started is to login or register via Element at https://app.element.io/#/login or https://app.element.io/#/register respectively. You will need to change the server you are logging into from ``matrix.org`` -and instead specify a Homeserver URL of ``https://:8448`` +and instead specify a homeserver URL of ``https://:8448`` (or just ``https://`` if you are using a reverse proxy). If you prefer to use another client, refer to our `client breakdown `_. @@ -162,16 +162,15 @@ the public internet. Without it, anyone can freely register accounts on your hom This can be exploited by attackers to create spambots targeting the rest of the Matrix federation. -Your new user name will be formed partly from the ``server_name``, and partly -from a localpart you specify when you create the account. Your name will take -the form of:: +Your new Matrix ID will be formed partly from the ``server_name``, and partly +from a localpart you specify when you create the account in the form of:: @localpart:my.domain.name (pronounced "at localpart on my dot domain dot name"). As when logging in, you will need to specify a "Custom server". Specify your -desired ``localpart`` in the 'User name' box. +desired ``localpart`` in the 'Username' box. 🎯 Troubleshooting and support ============================== @@ -209,10 +208,10 @@ Identity servers have the job of mapping email addresses and other 3rd Party IDs (3PIDs) to Matrix user IDs, as well as verifying the ownership of 3PIDs before creating that mapping. -**They are not where accounts or credentials are stored - these live on home -servers. Identity Servers are just for mapping 3rd party IDs to matrix IDs.** +**Identity servers do not store accounts or credentials - these are stored and managed on homeservers. +Identity Servers are just for mapping 3rd Party IDs to Matrix IDs.** -This process is very security-sensitive, as there is obvious risk of spam if it +This process is highly security-sensitive, as there is an obvious risk of spam if it is too easy to sign up for Matrix accounts or harvest 3PID data. In the longer term, we hope to create a decentralised system to manage it (`matrix-doc #712 `_), but in the meantime, @@ -238,9 +237,9 @@ email address. We welcome contributions to Synapse from the community! The best place to get started is our `guide for contributors `_. -This is part of our larger `documentation `_, which includes - +This is part of our broader `documentation `_, which includes information for Synapse developers as well as Synapse administrators. + Developers might be particularly interested in: * `Synapse's database schema `_, diff --git a/build_rust.py b/build_rust.py index d2726cee26..5c796af461 100644 --- a/build_rust.py +++ b/build_rust.py @@ -19,17 +19,17 @@ def build(setup_kwargs: Dict[str, Any]) -> None: # This flag is a no-op in the latest versions. Instead, we need to # specify this in the `bdist_wheel` config below. py_limited_api=True, - # We force always building in release mode, as we can't tell the - # difference between using `poetry` in development vs production. + # We always build in release mode, as we can't distinguish + # between using `poetry` in development vs production. debug=False, ) setup_kwargs.setdefault("rust_extensions", []).append(extension) setup_kwargs["zip_safe"] = False - # We lookup the minimum supported python version by looking at - # `python_requires` (e.g. 
">=3.9.0,<4.0.0") and finding the first python + # We look up the minimum supported Python version with + # `python_requires` (e.g. ">=3.9.0,<4.0.0") and finding the first Python # version that matches. We then convert that into the `py_limited_api` form, - # e.g. cp39 for python 3.9. + # e.g. cp39 for Python 3.9. py_limited_api: str python_bounds = SpecifierSet(setup_kwargs["python_requires"]) for minor_version in itertools.count(start=8): diff --git a/contrib/grafana/synapse.json b/contrib/grafana/synapse.json index 62b58a199d..e23afcf2d3 100644 --- a/contrib/grafana/synapse.json +++ b/contrib/grafana/synapse.json @@ -4396,7 +4396,7 @@ "exemplar": false, "expr": "(time() - max without (job, index, host) (avg_over_time(synapse_federation_last_received_pdu_time[10m]))) / 60", "instant": false, - "legendFormat": "{{server_name}} ", + "legendFormat": "{{origin_server_name}} ", "range": true, "refId": "A" } @@ -4518,7 +4518,7 @@ "exemplar": false, "expr": "(time() - max without (job, index, host) (avg_over_time(synapse_federation_last_sent_pdu_time[10m]))) / 60", "instant": false, - "legendFormat": "{{server_name}}", + "legendFormat": "{{destination_server_name}}", "range": true, "refId": "A" } diff --git a/debian/changelog b/debian/changelog index bfaa18ac4f..838e7115de 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,3 +1,21 @@ +matrix-synapse-py3 (1.136.0) stable; urgency=medium + + * New Synapse release 1.136.0. + + -- Synapse Packaging team Tue, 12 Aug 2025 13:18:03 +0100 + +matrix-synapse-py3 (1.136.0~rc2) stable; urgency=medium + + * New Synapse release 1.136.0rc2. + + -- Synapse Packaging team Mon, 11 Aug 2025 12:18:52 -0600 + +matrix-synapse-py3 (1.136.0~rc1) stable; urgency=medium + + * New Synapse release 1.136.0rc1. + + -- Synapse Packaging team Tue, 05 Aug 2025 08:13:30 -0600 + matrix-synapse-py3 (1.135.2) stable; urgency=medium * New Synapse release 1.135.2. diff --git a/docker/complement/conf/workers-shared-extra.yaml.j2 b/docker/complement/conf/workers-shared-extra.yaml.j2 index 48b44ddf90..168c385191 100644 --- a/docker/complement/conf/workers-shared-extra.yaml.j2 +++ b/docker/complement/conf/workers-shared-extra.yaml.j2 @@ -98,6 +98,10 @@ rc_delayed_event_mgmt: per_second: 9999 burst_count: 9999 +rc_room_creation: + per_second: 9999 + burst_count: 9999 + federation_rr_transactions_per_room_per_second: 9999 allow_device_name_lookup_over_federation: true diff --git a/docs/admin_api/client_server_api_extensions.md b/docs/admin_api/client_server_api_extensions.md index 9cf74b23eb..08fac6289b 100644 --- a/docs/admin_api/client_server_api_extensions.md +++ b/docs/admin_api/client_server_api_extensions.md @@ -22,4 +22,46 @@ To receive soft failed events in APIs like `/sync` and `/messages`, set `return_ to `true` in the admin client config. When `false`, the normal behaviour of these endpoints is to exclude soft failed events. +**Note**: If the policy server flagged the event as spam and that caused soft failure, that will be indicated +in the event's `unsigned` content like so: + +```json +{ + "type": "m.room.message", + "other": "event_fields_go_here", + "unsigned": { + "io.element.synapse.soft_failed": true, + "io.element.synapse.policy_server_spammy": true + } +} +``` + Default: `false` + +## See events marked spammy by policy servers + +Learn more about policy servers from [MSC4284](https://github.com/matrix-org/matrix-spec-proposals/pull/4284). 
+ +Similar to `return_soft_failed_events`, clients logged in with admin accounts can see events which were +flagged by the policy server as spammy (and thus soft failed) by setting `return_policy_server_spammy_events` +to `true`. + +`return_policy_server_spammy_events` may be `true` while `return_soft_failed_events` is `false` to only see +policy server-flagged events. When `return_soft_failed_events` is `true` however, `return_policy_server_spammy_events` +is always `true`. + +Events which were flagged by the policy will be flagged as `io.element.synapse.policy_server_spammy` in the +event's `unsigned` content, like so: + +```json +{ + "type": "m.room.message", + "other": "event_fields_go_here", + "unsigned": { + "io.element.synapse.soft_failed": true, + "io.element.synapse.policy_server_spammy": true + } +} +``` + +Default: `true` if `return_soft_failed_events` is `true`, otherwise `false` diff --git a/docs/setup/forward_proxy.md b/docs/setup/forward_proxy.md index f02c7b5fc5..eab8bb9951 100644 --- a/docs/setup/forward_proxy.md +++ b/docs/setup/forward_proxy.md @@ -7,8 +7,23 @@ proxy is supported, not SOCKS proxy or anything else. ## Configure -The `http_proxy`, `https_proxy`, `no_proxy` environment variables are used to -specify proxy settings. The environment variable is not case sensitive. +The proxy settings can be configured in the homeserver configuration file via +[`http_proxy`](../usage/configuration/config_documentation.md#http_proxy), +[`https_proxy`](../usage/configuration/config_documentation.md#https_proxy), and +[`no_proxy_hosts`](../usage/configuration/config_documentation.md#no_proxy_hosts). + +`homeserver.yaml` example: +```yaml +http_proxy: http://USERNAME:PASSWORD@10.0.1.1:8080/ +https_proxy: http://USERNAME:PASSWORD@proxy.example.com:8080/ +no_proxy_hosts: + - master.hostname.example.com + - 10.1.0.0/16 + - 172.30.0.0/16 +``` + +The proxy settings can also be configured via the `http_proxy`, `https_proxy`, +`no_proxy` environment variables. The environment variable is not case sensitive. - `http_proxy`: Proxy server to use for HTTP requests. - `https_proxy`: Proxy server to use for HTTPS requests. - `no_proxy`: Comma-separated list of hosts, IP addresses, or IP ranges in CIDR @@ -44,7 +59,7 @@ The proxy will be **used** for: - phone-home stats - recaptcha validation - CAS auth validation -- OpenID Connect +- OpenID Connect (OIDC) - Outbound federation - Federation (checking public key revocation) - Fetching public keys of other servers @@ -53,7 +68,7 @@ The proxy will be **used** for: It will **not be used** for: - Application Services -- Identity servers +- Matrix Identity servers - In worker configurations - connections between workers - connections from workers to Redis diff --git a/docs/upgrade.md b/docs/upgrade.md index e79ca93c04..082d204b58 100644 --- a/docs/upgrade.md +++ b/docs/upgrade.md @@ -117,6 +117,77 @@ each upgrade are complete before moving on to the next upgrade, to avoid stacking them up. You can monitor the currently running background updates with [the Admin API](usage/administration/admin_api/background_updates.html#status). +# Upgrading to v1.136.0 + +## Deprecate `run_as_background_process` exported as part of the module API interface in favor of `ModuleApi.run_as_background_process` + +The `run_as_background_process` function is now a method of the `ModuleApi` class. If +you were using the function directly from the module API, it will continue to work fine +but the background process metrics will not include an accurate `server_name` label. 
+This kind of metric labeling isn't relevant for many use cases and is used to +differentiate Synapse instances running in the same Python process (relevant to Synapse +Pro: Small Hosts). We recommend updating your usage to use the new +`ModuleApi.run_as_background_process` method to stay on top of future changes. + +
+Example run_as_background_process upgrade + +Before: +```python +class MyModule: + def __init__(self, module_api: ModuleApi) -> None: + run_as_background_process(__name__ + ":setup_database", self.setup_database) +``` + +After: +```python +class MyModule: + def __init__(self, module_api: ModuleApi) -> None: + module_api.run_as_background_process(__name__ + ":setup_database", self.setup_database) +``` + +
+ +## Metric labels have changed on `synapse_federation_last_received_pdu_time` and `synapse_federation_last_sent_pdu_time` + +Previously, the `synapse_federation_last_received_pdu_time` and +`synapse_federation_last_sent_pdu_time` metrics both used the `server_name` label to +differentiate between different servers that we send and receive events from. + +Since we're now using the `server_name` label to differentiate between different Synapse +homeserver instances running in the same process, these metrics have been changed as follows: + + - `synapse_federation_last_received_pdu_time` now uses the `origin_server_name` label + - `synapse_federation_last_sent_pdu_time` now uses the `destination_server_name` label + +The Grafana dashboard JSON in `contrib/grafana/synapse.json` has been updated to reflect +this change but you will need to manually update your own existing Grafana dashboards +using these metrics. + +## Stable integration with Matrix Authentication Service + +Support for [Matrix Authentication Service (MAS)](https://github.com/element-hq/matrix-authentication-service) is now stable, with a simplified configuration. +This stable integration requires MAS 0.20.0 or later. + +The existing `experimental_features.msc3861` configuration option is now deprecated and will be removed in Synapse v1.137.0. + +Synapse deployments already using MAS should now use the new configuration options: + +```yaml +matrix_authentication_service: + # Enable the MAS integration + enabled: true + # The base URL where Synapse will contact MAS + endpoint: http://localhost:8080 + # The shared secret used to authenticate MAS requests, must be the same as `matrix.secret` in the MAS configuration + # See https://element-hq.github.io/matrix-authentication-service/reference/configuration.html#matrix + secret: "asecurerandomsecretstring" +``` + +They must remove the `experimental_features.msc3861` configuration option from their configuration. + +They can also remove the client previously used by Synapse [in the MAS configuration](https://element-hq.github.io/matrix-authentication-service/reference/configuration.html#clients) as it is no longer in use. + # Upgrading to v1.135.0 ## `on_user_registration` module API callback may now run on any worker @@ -137,10 +208,10 @@ native ICU library on your system is no longer required. ## Documented endpoint which can be delegated to a federation worker The endpoint `^/_matrix/federation/v1/version$` can be delegated to a federation -worker. This is not new behaviour, but had not been documented yet. The -[list of delegatable endpoints](workers.md#synapseappgeneric_worker) has +worker. This is not new behaviour, but had not been documented yet. The +[list of delegatable endpoints](workers.md#synapseappgeneric_worker) has been updated to include it. Make sure to check your reverse proxy rules if you -are using workers. +are using workers. # Upgrading to v1.126.0 diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md index 44aab77e5a..68303308cd 100644 --- a/docs/usage/configuration/config_documentation.md +++ b/docs/usage/configuration/config_documentation.md @@ -610,6 +610,61 @@ manhole_settings: ssh_pub_key_path: CONFDIR/id_rsa.pub ``` --- +### `http_proxy` + +*(string|null)* Proxy server to use for HTTP requests. +For more details, see the [forward proxy documentation](../../setup/forward_proxy.md). There is no default for this option. 
+ +Example configuration: +```yaml +http_proxy: http://USERNAME:PASSWORD@10.0.1.1:8080/ +``` +--- +### `https_proxy` + +*(string|null)* Proxy server to use for HTTPS requests. +For more details, see the [forward proxy documentation](../../setup/forward_proxy.md). There is no default for this option. + +Example configuration: +```yaml +https_proxy: http://USERNAME:PASSWORD@proxy.example.com:8080/ +``` +--- +### `no_proxy_hosts` + +*(array)* List of hosts, IP addresses, or IP ranges in CIDR format which should not use the proxy. Synapse will directly connect to these hosts. +For more details, see the [forward proxy documentation](../../setup/forward_proxy.md). There is no default for this option. + +Example configuration: +```yaml +no_proxy_hosts: +- master.hostname.example.com +- 10.1.0.0/16 +- 172.30.0.0/16 +``` +--- +### `matrix_authentication_service` + +*(object)* The `matrix_authentication_service` setting configures integration with [Matrix Authentication Service (MAS)](https://github.com/element-hq/matrix-authentication-service). + +This setting has the following sub-options: + +* `enabled` (boolean): Whether or not to enable the MAS integration. If this is set to `false`, Synapse will use its legacy internal authentication API. Defaults to `false`. + +* `endpoint` (string): The URL where Synapse can reach MAS. This *must* have the `discovery` and `oauth` resources mounted. Defaults to `"http://localhost:8080"`. + +* `secret` (string|null): A shared secret that will be used to authenticate requests from and to MAS. + +* `secret_path` (string|null): Alternative to `secret`, reading the shared secret from a file. The file should be a plain text file, containing only the secret. Synapse reads the secret from the given file once at startup. + +Example configuration: +```yaml +matrix_authentication_service: + enabled: true + secret: someverysecuresecret + endpoint: http://localhost:8080 +``` +--- ### `dummy_events_threshold` *(integer)* Forward extremities can build up in a room due to networking delays between homeservers. Once this happens in a large room, calculation of the state of that room can become quite expensive. To mitigate this, once the number of forward extremities reaches a given threshold, Synapse will send an `org.matrix.dummy_event` event, which will reduce the forward extremities in the room. @@ -1963,6 +2018,31 @@ rc_reports: burst_count: 20.0 ``` --- +### `rc_room_creation` + +*(object)* Sets rate limits for how often users are able to create rooms. + +This setting has the following sub-options: + +* `per_second` (number): Maximum number of requests a client can send per second. + +* `burst_count` (number): Maximum number of requests a client can send before being throttled. + +Default configuration: +```yaml +rc_room_creation: + per_user: + per_second: 0.016 + burst_count: 10.0 +``` + +Example configuration: +```yaml +rc_room_creation: + per_second: 1.0 + burst_count: 5.0 +``` +--- ### `federation_rr_transactions_per_room_per_second` *(integer)* Sets outgoing federation transaction frequency for sending read-receipts, per-room. diff --git a/docs/workers.md b/docs/workers.md index 59c60dd0ad..c275b4acd5 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -260,7 +260,7 @@ information. 
^/_matrix/client/(r0|v3|unstable)/keys/claim$ ^/_matrix/client/(r0|v3|unstable)/room_keys/ ^/_matrix/client/(r0|v3|unstable)/keys/upload - ^/_matrix/client/(api/v1|r0|v3|unstable/keys/device_signing/upload$ + ^/_matrix/client/(api/v1|r0|v3|unstable)/keys/device_signing/upload$ ^/_matrix/client/(api/v1|r0|v3|unstable)/keys/signatures/upload$ # Registration/login requests @@ -532,8 +532,9 @@ the stream writer for the `account_data` stream: ##### The `receipts` stream -The following endpoints should be routed directly to the worker configured as -the stream writer for the `receipts` stream: +The `receipts` stream supports multiple writers. The following endpoints +can be handled by any worker, but should be routed directly to one of the workers +configured as stream writer for the `receipts` stream: ^/_matrix/client/(r0|v3|unstable)/rooms/.*/receipt ^/_matrix/client/(r0|v3|unstable)/rooms/.*/read_markers @@ -555,13 +556,13 @@ the stream writer for the `push_rules` stream: ##### The `device_lists` stream The `device_lists` stream supports multiple writers. The following endpoints -can be handled by any worker, but should be routed directly one of the workers +can be handled by any worker, but should be routed directly to one of the workers configured as stream writer for the `device_lists` stream: ^/_matrix/client/(r0|v3)/delete_devices$ - ^/_matrix/client/(api/v1|r0|v3|unstable)/devices/ + ^/_matrix/client/(api/v1|r0|v3|unstable)/devices(/|$) ^/_matrix/client/(r0|v3|unstable)/keys/upload - ^/_matrix/client/(api/v1|r0|v3|unstable/keys/device_signing/upload$ + ^/_matrix/client/(api/v1|r0|v3|unstable)/keys/device_signing/upload$ ^/_matrix/client/(api/v1|r0|v3|unstable)/keys/signatures/upload$ #### Restrict outbound federation traffic to a specific set of workers diff --git a/mypy.ini b/mypy.ini index cf64248cc5..ae903f858a 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,6 +1,17 @@ [mypy] namespace_packages = True -plugins = pydantic.mypy, mypy_zope:plugin, scripts-dev/mypy_synapse_plugin.py +# Our custom mypy plugin should remain first in this list. +# +# mypy has a limitation where it only chooses the first plugin that returns a non-None +# value for each hook (known-limitation, c.f. +# https://github.com/python/mypy/issues/19524). We workaround this by putting our custom +# plugin first in the plugin order and then manually calling any other conflicting +# plugin hooks in our own plugin followed by our own checks. +# +# If you add a new plugin, make sure to check whether the hooks being used conflict with +# our custom plugin hooks and if so, manually call the other plugin's hooks in our +# custom plugin. 
(also applies to if the plugin is updated in the future) +plugins = scripts-dev/mypy_synapse_plugin.py, pydantic.mypy, mypy_zope:plugin follow_imports = normal show_error_codes = True show_traceback = True @@ -99,3 +110,6 @@ ignore_missing_imports = True [mypy-multipart.*] ignore_missing_imports = True + +[mypy-mypy_zope.*] +ignore_missing_imports = True diff --git a/poetry.lock b/poetry.lock index 109d7512c0..dc6d7711f7 100644 --- a/poetry.lock +++ b/poetry.lock @@ -504,18 +504,19 @@ smmap = ">=3.0.1,<6" [[package]] name = "gitpython" -version = "3.1.44" +version = "3.1.45" description = "GitPython is a Python library used to interact with Git repositories" optional = false python-versions = ">=3.7" groups = ["dev"] files = [ - {file = "GitPython-3.1.44-py3-none-any.whl", hash = "sha256:9e0e10cda9bed1ee64bc9a6de50e7e38a9c9943241cd7f585f6df3ed28011110"}, - {file = "gitpython-3.1.44.tar.gz", hash = "sha256:c87e30b26253bf5418b01b0660f818967f3c503193838337fe5e573331249269"}, + {file = "gitpython-3.1.45-py3-none-any.whl", hash = "sha256:8908cb2e02fb3b93b7eb0f2827125cb699869470432cc885f019b8fd0fccff77"}, + {file = "gitpython-3.1.45.tar.gz", hash = "sha256:85b0ee964ceddf211c41b9f27a49086010a190fd8132a24e21f362a4b36a791c"}, ] [package.dependencies] gitdb = ">=4.0.1,<5" +typing-extensions = {version = ">=3.10.0.2", markers = "python_version < \"3.10\""} [package.extras] doc = ["sphinx (>=7.1.2,<7.2)", "sphinx-autodoc-typehints", "sphinx_rtd_theme"] @@ -1453,18 +1454,18 @@ files = [ [[package]] name = "mypy-zope" -version = "1.0.12" +version = "1.0.13" description = "Plugin for mypy to support zope interfaces" optional = false python-versions = "*" groups = ["dev"] files = [ - {file = "mypy_zope-1.0.12-py3-none-any.whl", hash = "sha256:f2ecf169f886fbc266e9339db0c2f3818528a7536b9bb4f5ece1d5854dc2f27c"}, - {file = "mypy_zope-1.0.12.tar.gz", hash = "sha256:d6f8f99eb5644885553b4ec7afc8d68f5daf412c9bf238ec3c36b65d97df6cbe"}, + {file = "mypy_zope-1.0.13-py3-none-any.whl", hash = "sha256:13740c4cbc910cca2c143c6709e1c483c991abeeeb7b629ad6f73d8ac1edad15"}, + {file = "mypy_zope-1.0.13.tar.gz", hash = "sha256:63fb4d035ea874baf280dc69e714dcde4bd2a4a4837a0fd8d90ce91bea510f99"}, ] [package.dependencies] -mypy = ">=1.0.0,<1.17.0" +mypy = ">=1.0.0,<1.18.0" "zope.interface" = "*" "zope.schema" = "*" @@ -1542,14 +1543,14 @@ files = [ [[package]] name = "phonenumbers" -version = "9.0.9" +version = "9.0.10" description = "Python version of Google's common library for parsing, formatting, storing and validating international phone numbers." optional = false python-versions = "*" groups = ["main"] files = [ - {file = "phonenumbers-9.0.9-py2.py3-none-any.whl", hash = "sha256:13b91aa153f87675902829b38a556bad54824f9c121b89588bbb5fa8550d97ef"}, - {file = "phonenumbers-9.0.9.tar.gz", hash = "sha256:c640545019a07e68b0bea57a5fede6eef45c7391165d28935f45615f9a567a5b"}, + {file = "phonenumbers-9.0.10-py2.py3-none-any.whl", hash = "sha256:13b12d269be1f2b363c9bc2868656a7e2e8b50f1a1cef629c75005da6c374c6b"}, + {file = "phonenumbers-9.0.10.tar.gz", hash = "sha256:c2d15a6a9d0534b14a7764f51246ada99563e263f65b80b0251d1a760ac4a1ba"}, ] [[package]] @@ -2408,30 +2409,30 @@ files = [ [[package]] name = "ruff" -version = "0.12.4" +version = "0.12.7" description = "An extremely fast Python linter and code formatter, written in Rust." 
optional = false python-versions = ">=3.7" groups = ["dev"] files = [ - {file = "ruff-0.12.4-py3-none-linux_armv6l.whl", hash = "sha256:cb0d261dac457ab939aeb247e804125a5d521b21adf27e721895b0d3f83a0d0a"}, - {file = "ruff-0.12.4-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:55c0f4ca9769408d9b9bac530c30d3e66490bd2beb2d3dae3e4128a1f05c7442"}, - {file = "ruff-0.12.4-py3-none-macosx_11_0_arm64.whl", hash = "sha256:a8224cc3722c9ad9044da7f89c4c1ec452aef2cfe3904365025dd2f51daeae0e"}, - {file = "ruff-0.12.4-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e9949d01d64fa3672449a51ddb5d7548b33e130240ad418884ee6efa7a229586"}, - {file = "ruff-0.12.4-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:be0593c69df9ad1465e8a2d10e3defd111fdb62dcd5be23ae2c06da77e8fcffb"}, - {file = "ruff-0.12.4-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a7dea966bcb55d4ecc4cc3270bccb6f87a337326c9dcd3c07d5b97000dbff41c"}, - {file = "ruff-0.12.4-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:afcfa3ab5ab5dd0e1c39bf286d829e042a15e966b3726eea79528e2e24d8371a"}, - {file = "ruff-0.12.4-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c057ce464b1413c926cdb203a0f858cd52f3e73dcb3270a3318d1630f6395bb3"}, - {file = "ruff-0.12.4-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e64b90d1122dc2713330350626b10d60818930819623abbb56535c6466cce045"}, - {file = "ruff-0.12.4-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2abc48f3d9667fdc74022380b5c745873499ff827393a636f7a59da1515e7c57"}, - {file = "ruff-0.12.4-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:2b2449dc0c138d877d629bea151bee8c0ae3b8e9c43f5fcaafcd0c0d0726b184"}, - {file = "ruff-0.12.4-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:56e45bb11f625db55f9b70477062e6a1a04d53628eda7784dce6e0f55fd549eb"}, - {file = "ruff-0.12.4-py3-none-musllinux_1_2_i686.whl", hash = "sha256:478fccdb82ca148a98a9ff43658944f7ab5ec41c3c49d77cd99d44da019371a1"}, - {file = "ruff-0.12.4-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:0fc426bec2e4e5f4c4f182b9d2ce6a75c85ba9bcdbe5c6f2a74fcb8df437df4b"}, - {file = "ruff-0.12.4-py3-none-win32.whl", hash = "sha256:4de27977827893cdfb1211d42d84bc180fceb7b72471104671c59be37041cf93"}, - {file = "ruff-0.12.4-py3-none-win_amd64.whl", hash = "sha256:fe0b9e9eb23736b453143d72d2ceca5db323963330d5b7859d60d101147d461a"}, - {file = "ruff-0.12.4-py3-none-win_arm64.whl", hash = "sha256:0618ec4442a83ab545e5b71202a5c0ed7791e8471435b94e655b570a5031a98e"}, - {file = "ruff-0.12.4.tar.gz", hash = "sha256:13efa16df6c6eeb7d0f091abae50f58e9522f3843edb40d56ad52a5a4a4b6873"}, + {file = "ruff-0.12.7-py3-none-linux_armv6l.whl", hash = "sha256:76e4f31529899b8c434c3c1dede98c4483b89590e15fb49f2d46183801565303"}, + {file = "ruff-0.12.7-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:789b7a03e72507c54fb3ba6209e4bb36517b90f1a3569ea17084e3fd295500fb"}, + {file = "ruff-0.12.7-py3-none-macosx_11_0_arm64.whl", hash = "sha256:2e1c2a3b8626339bb6369116e7030a4cf194ea48f49b64bb505732a7fce4f4e3"}, + {file = "ruff-0.12.7-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:32dec41817623d388e645612ec70d5757a6d9c035f3744a52c7b195a57e03860"}, + {file = "ruff-0.12.7-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:47ef751f722053a5df5fa48d412dbb54d41ab9b17875c6840a58ec63ff0c247c"}, + {file = "ruff-0.12.7-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:a828a5fc25a3efd3e1ff7b241fd392686c9386f20e5ac90aa9234a5faa12c423"}, + {file = "ruff-0.12.7-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:5726f59b171111fa6a69d82aef48f00b56598b03a22f0f4170664ff4d8298efb"}, + {file = "ruff-0.12.7-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:74e6f5c04c4dd4aba223f4fe6e7104f79e0eebf7d307e4f9b18c18362124bccd"}, + {file = "ruff-0.12.7-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5d0bfe4e77fba61bf2ccadf8cf005d6133e3ce08793bbe870dd1c734f2699a3e"}, + {file = "ruff-0.12.7-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:06bfb01e1623bf7f59ea749a841da56f8f653d641bfd046edee32ede7ff6c606"}, + {file = "ruff-0.12.7-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:e41df94a957d50083fd09b916d6e89e497246698c3f3d5c681c8b3e7b9bb4ac8"}, + {file = "ruff-0.12.7-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:4000623300563c709458d0ce170c3d0d788c23a058912f28bbadc6f905d67afa"}, + {file = "ruff-0.12.7-py3-none-musllinux_1_2_i686.whl", hash = "sha256:69ffe0e5f9b2cf2b8e289a3f8945b402a1b19eff24ec389f45f23c42a3dd6fb5"}, + {file = "ruff-0.12.7-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:a07a5c8ffa2611a52732bdc67bf88e243abd84fe2d7f6daef3826b59abbfeda4"}, + {file = "ruff-0.12.7-py3-none-win32.whl", hash = "sha256:c928f1b2ec59fb77dfdf70e0419408898b63998789cc98197e15f560b9e77f77"}, + {file = "ruff-0.12.7-py3-none-win_amd64.whl", hash = "sha256:9c18f3d707ee9edf89da76131956aba1270c6348bfee8f6c647de841eac7194f"}, + {file = "ruff-0.12.7-py3-none-win_arm64.whl", hash = "sha256:dfce05101dbd11833a0776716d5d1578641b7fddb537fe7fa956ab85d1769b69"}, + {file = "ruff-0.12.7.tar.gz", hash = "sha256:1fc3193f238bc2d7968772c82831a4ff69252f673be371fb49663f0068b7ec71"}, ] [[package]] @@ -2469,15 +2470,15 @@ doc = ["Sphinx", "sphinx-rtd-theme"] [[package]] name = "sentry-sdk" -version = "2.32.0" +version = "2.34.1" description = "Python client for Sentry (https://sentry.io)" optional = true python-versions = ">=3.6" groups = ["main"] markers = "extra == \"all\" or extra == \"sentry\"" files = [ - {file = "sentry_sdk-2.32.0-py2.py3-none-any.whl", hash = "sha256:6cf51521b099562d7ce3606da928c473643abe99b00ce4cb5626ea735f4ec345"}, - {file = "sentry_sdk-2.32.0.tar.gz", hash = "sha256:9016c75d9316b0f6921ac14c8cd4fb938f26002430ac5be9945ab280f78bec6b"}, + {file = "sentry_sdk-2.34.1-py2.py3-none-any.whl", hash = "sha256:b7a072e1cdc5abc48101d5146e1ae680fa81fe886d8d95aaa25a0b450c818d32"}, + {file = "sentry_sdk-2.34.1.tar.gz", hash = "sha256:69274eb8c5c38562a544c3e9f68b5be0a43be4b697f5fd385bf98e4fbe672687"}, ] [package.dependencies] @@ -2931,14 +2932,14 @@ files = [ [[package]] name = "types-jsonschema" -version = "4.24.0.20250708" +version = "4.25.0.20250720" description = "Typing stubs for jsonschema" optional = false python-versions = ">=3.9" groups = ["dev"] files = [ - {file = "types_jsonschema-4.24.0.20250708-py3-none-any.whl", hash = "sha256:d574aa3421d178a8435cc898cf4cf5e5e8c8f37b949c8e3ceeff06da433a18bf"}, - {file = "types_jsonschema-4.24.0.20250708.tar.gz", hash = "sha256:a910e4944681cbb1b18a93ffb502e09910db788314312fc763df08d8ac2aadb7"}, + {file = "types_jsonschema-4.25.0.20250720-py3-none-any.whl", hash = "sha256:7d7897c715310d8bf9ae27a2cedba78bbb09e4cad83ce06d2aa79b73a88941df"}, + {file = "types_jsonschema-4.25.0.20250720.tar.gz", hash = "sha256:765a3b6144798fe3161fd8cbe570a756ed3e8c0e5adb7c09693eb49faad39dbd"}, ] [package.dependencies] @@ -2982,14 +2983,14 @@ files = [ [[package]] name = 
"types-psycopg2" -version = "2.9.21.20250516" +version = "2.9.21.20250718" description = "Typing stubs for psycopg2" optional = false python-versions = ">=3.9" groups = ["dev"] files = [ - {file = "types_psycopg2-2.9.21.20250516-py3-none-any.whl", hash = "sha256:2a9212d1e5e507017b31486ce8147634d06b85d652769d7a2d91d53cb4edbd41"}, - {file = "types_psycopg2-2.9.21.20250516.tar.gz", hash = "sha256:6721018279175cce10b9582202e2a2b4a0da667857ccf82a97691bdb5ecd610f"}, + {file = "types_psycopg2-2.9.21.20250718-py3-none-any.whl", hash = "sha256:bcf085d4293bda48f5943a46dadf0389b2f98f7e8007722f7e1c12ee0f541858"}, + {file = "types_psycopg2-2.9.21.20250718.tar.gz", hash = "sha256:dc09a97272ef67e739e57b9f4740b761208f4514257e311c0b05c8c7a37d04b4"}, ] [[package]] @@ -3352,4 +3353,4 @@ url-preview = ["lxml"] [metadata] lock-version = "2.1" python-versions = "^3.9.0" -content-hash = "b1a0f4708465fd597d0bc7ebb09443ce0e2613cd58a33387a28036249f26856b" +content-hash = "600a349d08dde732df251583094a121b5385eb43ae0c6ceff10dcf9749359446" diff --git a/pyproject.toml b/pyproject.toml index 44336344cd..0d298a6135 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -101,7 +101,7 @@ module-name = "synapse.synapse_rust" [tool.poetry] name = "matrix-synapse" -version = "1.135.2" +version = "1.136.0" description = "Homeserver for the Matrix decentralised comms protocol" authors = ["Matrix.org Team and Contributors "] license = "AGPL-3.0-or-later" @@ -178,8 +178,13 @@ signedjson = "^1.1.0" service-identity = ">=18.1.0" # Twisted 18.9 introduces some logger improvements that the structured # logger utilises -Twisted = {extras = ["tls"], version = ">=18.9.0"} -treq = ">=15.1" +# Twisted 19.7.0 moves test helpers to a new module and deprecates the old location. +# Twisted 21.2.0 introduces contextvar support. +# We could likely bump this to 22.1 without making distro packagers' +# lives hard (as of 2025-07, distro support is Ubuntu LTS: 22.1, Debian stable: 22.4, +# RHEL 9: 22.10) +Twisted = {extras = ["tls"], version = ">=21.2.0"} +treq = ">=21.5.0" # Twisted has required pyopenssl 16.0 since about Twisted 16.6. pyOpenSSL = ">=16.0.0" PyYAML = ">=5.3" @@ -319,7 +324,7 @@ all = [ # failing on new releases. Keeping lower bounds loose here means that dependabot # can bump versions without having to update the content-hash in the lockfile. # This helps prevents merge conflicts when running a batch of dependabot updates. 
-ruff = "0.12.4" +ruff = "0.12.7" # Type checking only works with the pydantic.v1 compat module from pydantic v2 pydantic = "^2" diff --git a/rust/Cargo.toml b/rust/Cargo.toml index ab87de33ab..0706357294 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -7,7 +7,7 @@ name = "synapse" version = "0.1.0" edition = "2021" -rust-version = "1.81.0" +rust-version = "1.82.0" [lib] name = "synapse" diff --git a/rust/benches/evaluator.rs b/rust/benches/evaluator.rs index 28537e187e..96169fd45d 100644 --- a/rust/benches/evaluator.rs +++ b/rust/benches/evaluator.rs @@ -61,6 +61,7 @@ fn bench_match_exact(b: &mut Bencher) { vec![], false, false, + false, ) .unwrap(); @@ -71,10 +72,10 @@ fn bench_match_exact(b: &mut Bencher) { }, )); - let matched = eval.match_condition(&condition, None, None).unwrap(); + let matched = eval.match_condition(&condition, None, None, None).unwrap(); assert!(matched, "Didn't match"); - b.iter(|| eval.match_condition(&condition, None, None).unwrap()); + b.iter(|| eval.match_condition(&condition, None, None, None).unwrap()); } #[bench] @@ -107,6 +108,7 @@ fn bench_match_word(b: &mut Bencher) { vec![], false, false, + false, ) .unwrap(); @@ -117,10 +119,10 @@ fn bench_match_word(b: &mut Bencher) { }, )); - let matched = eval.match_condition(&condition, None, None).unwrap(); + let matched = eval.match_condition(&condition, None, None, None).unwrap(); assert!(matched, "Didn't match"); - b.iter(|| eval.match_condition(&condition, None, None).unwrap()); + b.iter(|| eval.match_condition(&condition, None, None, None).unwrap()); } #[bench] @@ -153,6 +155,7 @@ fn bench_match_word_miss(b: &mut Bencher) { vec![], false, false, + false, ) .unwrap(); @@ -163,10 +166,10 @@ fn bench_match_word_miss(b: &mut Bencher) { }, )); - let matched = eval.match_condition(&condition, None, None).unwrap(); + let matched = eval.match_condition(&condition, None, None, None).unwrap(); assert!(!matched, "Didn't match"); - b.iter(|| eval.match_condition(&condition, None, None).unwrap()); + b.iter(|| eval.match_condition(&condition, None, None, None).unwrap()); } #[bench] @@ -199,6 +202,7 @@ fn bench_eval_message(b: &mut Bencher) { vec![], false, false, + false, ) .unwrap(); @@ -210,7 +214,8 @@ fn bench_eval_message(b: &mut Bencher) { false, false, false, + false, ); - b.iter(|| eval.run(&rules, Some("bob"), Some("person"))); + b.iter(|| eval.run(&rules, Some("bob"), Some("person"), None)); } diff --git a/rust/src/events/internal_metadata.rs b/rust/src/events/internal_metadata.rs index eeb6074c10..4711fc540f 100644 --- a/rust/src/events/internal_metadata.rs +++ b/rust/src/events/internal_metadata.rs @@ -54,6 +54,7 @@ enum EventInternalMetadataData { RecheckRedaction(bool), SoftFailed(bool), ProactivelySend(bool), + PolicyServerSpammy(bool), Redacted(bool), TxnId(Box), TokenId(i64), @@ -96,6 +97,13 @@ impl EventInternalMetadataData { .to_owned() .into_any(), ), + EventInternalMetadataData::PolicyServerSpammy(o) => ( + pyo3::intern!(py, "policy_server_spammy"), + o.into_pyobject(py) + .unwrap_infallible() + .to_owned() + .into_any(), + ), EventInternalMetadataData::Redacted(o) => ( pyo3::intern!(py, "redacted"), o.into_pyobject(py) @@ -155,6 +163,11 @@ impl EventInternalMetadataData { .extract() .with_context(|| format!("'{key_str}' has invalid type"))?, ), + "policy_server_spammy" => EventInternalMetadataData::PolicyServerSpammy( + value + .extract() + .with_context(|| format!("'{key_str}' has invalid type"))?, + ), "redacted" => EventInternalMetadataData::Redacted( value .extract() @@ -427,6 +440,17 @@ impl 
EventInternalMetadata { set_property!(self, ProactivelySend, obj); } + #[getter] + fn get_policy_server_spammy(&self) -> PyResult { + Ok(get_property_opt!(self, PolicyServerSpammy) + .copied() + .unwrap_or(false)) + } + #[setter] + fn set_policy_server_spammy(&mut self, obj: bool) { + set_property!(self, PolicyServerSpammy, obj); + } + #[getter] fn get_redacted(&self) -> PyResult { let bool = get_property!(self, Redacted)?; diff --git a/rust/src/push/base_rules.rs b/rust/src/push/base_rules.rs index e0832ada1c..ec027ca251 100644 --- a/rust/src/push/base_rules.rs +++ b/rust/src/push/base_rules.rs @@ -290,6 +290,26 @@ pub const BASE_APPEND_CONTENT_RULES: &[PushRule] = &[PushRule { }]; pub const BASE_APPEND_UNDERRIDE_RULES: &[PushRule] = &[ + PushRule { + rule_id: Cow::Borrowed("global/content/.io.element.msc4306.rule.unsubscribed_thread"), + priority_class: 1, + conditions: Cow::Borrowed(&[Condition::Known( + KnownCondition::Msc4306ThreadSubscription { subscribed: false }, + )]), + actions: Cow::Borrowed(&[]), + default: true, + default_enabled: true, + }, + PushRule { + rule_id: Cow::Borrowed("global/content/.io.element.msc4306.rule.subscribed_thread"), + priority_class: 1, + conditions: Cow::Borrowed(&[Condition::Known( + KnownCondition::Msc4306ThreadSubscription { subscribed: true }, + )]), + actions: Cow::Borrowed(&[Action::Notify, SOUND_ACTION]), + default: true, + default_enabled: true, + }, PushRule { rule_id: Cow::Borrowed("global/underride/.m.rule.call"), priority_class: 1, diff --git a/rust/src/push/evaluator.rs b/rust/src/push/evaluator.rs index db406acb88..1cbca4c635 100644 --- a/rust/src/push/evaluator.rs +++ b/rust/src/push/evaluator.rs @@ -106,8 +106,11 @@ pub struct PushRuleEvaluator { /// flag as MSC1767 (extensible events core). msc3931_enabled: bool, - // If MSC4210 (remove legacy mentions) is enabled. + /// If MSC4210 (remove legacy mentions) is enabled. msc4210_enabled: bool, + + /// If MSC4306 (thread subscriptions) is enabled. + msc4306_enabled: bool, } #[pymethods] @@ -126,6 +129,7 @@ impl PushRuleEvaluator { room_version_feature_flags, msc3931_enabled, msc4210_enabled, + msc4306_enabled, ))] pub fn py_new( flattened_keys: BTreeMap, @@ -138,6 +142,7 @@ impl PushRuleEvaluator { room_version_feature_flags: Vec, msc3931_enabled: bool, msc4210_enabled: bool, + msc4306_enabled: bool, ) -> Result { let body = match flattened_keys.get("content.body") { Some(JsonValue::Value(SimpleJsonValue::Str(s))) => s.clone().into_owned(), @@ -156,6 +161,7 @@ impl PushRuleEvaluator { room_version_feature_flags, msc3931_enabled, msc4210_enabled, + msc4306_enabled, }) } @@ -167,12 +173,19 @@ impl PushRuleEvaluator { /// /// Returns the set of actions, if any, that match (filtering out any /// `dont_notify` and `coalesce` actions). - #[pyo3(signature = (push_rules, user_id=None, display_name=None))] + /// + /// msc4306_thread_subscription_state: (Only populated if MSC4306 is enabled) + /// The thread subscription state corresponding to the thread containing this event. + /// - `None` if the event is not in a thread, or if MSC4306 is disabled. 
+ /// - `Some(true)` if the event is in a thread and the user has a subscription for that thread + /// - `Some(false)` if the event is in a thread and the user does NOT have a subscription for that thread + #[pyo3(signature = (push_rules, user_id=None, display_name=None, msc4306_thread_subscription_state=None))] pub fn run( &self, push_rules: &FilteredPushRules, user_id: Option<&str>, display_name: Option<&str>, + msc4306_thread_subscription_state: Option, ) -> Vec { 'outer: for (push_rule, enabled) in push_rules.iter() { if !enabled { @@ -204,7 +217,12 @@ impl PushRuleEvaluator { Condition::Known(KnownCondition::RoomVersionSupports { feature: _ }), ); - match self.match_condition(condition, user_id, display_name) { + match self.match_condition( + condition, + user_id, + display_name, + msc4306_thread_subscription_state, + ) { Ok(true) => {} Ok(false) => continue 'outer, Err(err) => { @@ -237,14 +255,20 @@ impl PushRuleEvaluator { } /// Check if the given condition matches. - #[pyo3(signature = (condition, user_id=None, display_name=None))] + #[pyo3(signature = (condition, user_id=None, display_name=None, msc4306_thread_subscription_state=None))] fn matches( &self, condition: Condition, user_id: Option<&str>, display_name: Option<&str>, + msc4306_thread_subscription_state: Option, ) -> bool { - match self.match_condition(&condition, user_id, display_name) { + match self.match_condition( + &condition, + user_id, + display_name, + msc4306_thread_subscription_state, + ) { Ok(true) => true, Ok(false) => false, Err(err) => { @@ -262,6 +286,7 @@ impl PushRuleEvaluator { condition: &Condition, user_id: Option<&str>, display_name: Option<&str>, + msc4306_thread_subscription_state: Option, ) -> Result { let known_condition = match condition { Condition::Known(known) => known, @@ -393,6 +418,13 @@ impl PushRuleEvaluator { && self.room_version_feature_flags.contains(&flag) } } + KnownCondition::Msc4306ThreadSubscription { subscribed } => { + if !self.msc4306_enabled { + false + } else { + msc4306_thread_subscription_state == Some(*subscribed) + } + } }; Ok(result) @@ -536,10 +568,11 @@ fn push_rule_evaluator() { vec![], true, false, + false, ) .unwrap(); - let result = evaluator.run(&FilteredPushRules::default(), None, Some("bob")); + let result = evaluator.run(&FilteredPushRules::default(), None, Some("bob"), None); assert_eq!(result.len(), 3); } @@ -566,6 +599,7 @@ fn test_requires_room_version_supports_condition() { flags, true, false, + false, ) .unwrap(); @@ -575,6 +609,7 @@ fn test_requires_room_version_supports_condition() { &FilteredPushRules::default(), Some("@bob:example.org"), None, + None, ); assert_eq!(result.len(), 3); @@ -593,7 +628,17 @@ fn test_requires_room_version_supports_condition() { }; let rules = PushRules::new(vec![custom_rule]); result = evaluator.run( - &FilteredPushRules::py_new(rules, BTreeMap::new(), true, false, true, false, false), + &FilteredPushRules::py_new( + rules, + BTreeMap::new(), + true, + false, + true, + false, + false, + false, + ), + None, None, None, ); diff --git a/rust/src/push/mod.rs b/rust/src/push/mod.rs index bd0e853ac3..b07a12e5cc 100644 --- a/rust/src/push/mod.rs +++ b/rust/src/push/mod.rs @@ -369,6 +369,10 @@ pub enum KnownCondition { RoomVersionSupports { feature: Cow<'static, str>, }, + #[serde(rename = "io.element.msc4306.thread_subscription")] + Msc4306ThreadSubscription { + subscribed: bool, + }, } impl<'source> IntoPyObject<'source> for Condition { @@ -547,11 +551,13 @@ pub struct FilteredPushRules { msc3664_enabled: bool, 
msc4028_push_encrypted_events: bool, msc4210_enabled: bool, + msc4306_enabled: bool, } #[pymethods] impl FilteredPushRules { #[new] + #[allow(clippy::too_many_arguments)] pub fn py_new( push_rules: PushRules, enabled_map: BTreeMap, @@ -560,6 +566,7 @@ impl FilteredPushRules { msc3664_enabled: bool, msc4028_push_encrypted_events: bool, msc4210_enabled: bool, + msc4306_enabled: bool, ) -> Self { Self { push_rules, @@ -569,6 +576,7 @@ impl FilteredPushRules { msc3664_enabled, msc4028_push_encrypted_events, msc4210_enabled, + msc4306_enabled, } } @@ -619,6 +627,10 @@ impl FilteredPushRules { return false; } + if !self.msc4306_enabled && rule.rule_id.contains("/.io.element.msc4306.rule.") { + return false; + } + true }) .map(|r| { diff --git a/schema/synapse-config.schema.yaml b/schema/synapse-config.schema.yaml index 1b3bdcd27a..584f6e0ae8 100644 --- a/schema/synapse-config.schema.yaml +++ b/schema/synapse-config.schema.yaml @@ -1,5 +1,5 @@ $schema: https://element-hq.github.io/synapse/latest/schema/v1/meta.schema.json -$id: https://element-hq.github.io/synapse/schema/synapse/v1.135/synapse-config.schema.json +$id: https://element-hq.github.io/synapse/schema/synapse/v1.136/synapse-config.schema.json type: object properties: modules: @@ -629,6 +629,70 @@ properties: password: mypassword ssh_priv_key_path: CONFDIR/id_rsa ssh_pub_key_path: CONFDIR/id_rsa.pub + http_proxy: + type: ["string", "null"] + description: >- + Proxy server to use for HTTP requests. + + For more details, see the [forward proxy documentation](../../setup/forward_proxy.md). + examples: + - "http://USERNAME:PASSWORD@10.0.1.1:8080/" + https_proxy: + type: ["string", "null"] + description: >- + Proxy server to use for HTTPS requests. + + For more details, see the [forward proxy documentation](../../setup/forward_proxy.md). + examples: + - "http://USERNAME:PASSWORD@proxy.example.com:8080/" + no_proxy_hosts: + type: array + description: >- + List of hosts, IP addresses, or IP ranges in CIDR format which should not use the + proxy. Synapse will directly connect to these hosts. + + For more details, see the [forward proxy documentation](../../setup/forward_proxy.md). + examples: + - - master.hostname.example.com + - 10.1.0.0/16 + - 172.30.0.0/16 + matrix_authentication_service: + type: object + description: >- + The `matrix_authentication_service` setting configures integration with + [Matrix Authentication Service (MAS)](https://github.com/element-hq/matrix-authentication-service). + properties: + enabled: + type: boolean + description: >- + Whether or not to enable the MAS integration. If this is set to + `false`, Synapse will use its legacy internal authentication API. + default: false + + endpoint: + type: string + format: uri + description: >- + The URL where Synapse can reach MAS. This *must* have the `discovery` + and `oauth` resources mounted. + default: http://localhost:8080 + + secret: + type: ["string", "null"] + description: >- + A shared secret that will be used to authenticate requests from and to MAS. + + secret_path: + type: ["string", "null"] + description: >- + Alternative to `secret`, reading the shared secret from a file. + The file should be a plain text file, containing only the secret. + Synapse reads the secret from the given file once at startup. 
+ + examples: + - enabled: true + secret: someverysecuresecret + endpoint: http://localhost:8080 dummy_events_threshold: type: integer description: >- @@ -2201,6 +2265,17 @@ properties: examples: - per_second: 2.0 burst_count: 20.0 + rc_room_creation: + $ref: "#/$defs/rc" + description: >- + Sets rate limits for how often users are able to create rooms. + default: + per_user: + per_second: 0.016 + burst_count: 10.0 + examples: + - per_second: 1.0 + burst_count: 5.0 federation_rr_transactions_per_room_per_second: type: integer description: >- diff --git a/scripts-dev/gen_config_documentation.py b/scripts-dev/gen_config_documentation.py index 8e9d402c6a..9a49c07a34 100755 --- a/scripts-dev/gen_config_documentation.py +++ b/scripts-dev/gen_config_documentation.py @@ -473,6 +473,10 @@ def section(prop: str, values: dict) -> str: def main() -> None: + # For Windows: reconfigure the terminal to be UTF-8 for `print()` calls. + if sys.platform == "win32": + sys.stdout.reconfigure(encoding="utf-8") + def usage(err_msg: str) -> int: script_name = (sys.argv[:1] or ["__main__.py"])[0] print(err_msg, file=sys.stderr) @@ -485,7 +489,10 @@ def main() -> None: exit(usage("Too many arguments.")) if not (filepath := (sys.argv[1:] or [""])[0]): exit(usage("No schema file provided.")) - with open(filepath) as f: + with open(filepath, "r", encoding="utf-8") as f: + # Note: Windows requires that we specify the encoding otherwise it uses + # things like CP-1251, which can cause explosions. + # See https://github.com/yaml/pyyaml/issues/123 for more info. return yaml.safe_load(f) schema = read_json_file_arg() diff --git a/scripts-dev/mypy_synapse_plugin.py b/scripts-dev/mypy_synapse_plugin.py index a15c3c005c..610dec415a 100644 --- a/scripts-dev/mypy_synapse_plugin.py +++ b/scripts-dev/mypy_synapse_plugin.py @@ -23,28 +23,195 @@ can crop up, e.g the cache descriptors. """ -from typing import Callable, Optional, Tuple, Type, Union +import enum +from typing import Callable, Mapping, Optional, Tuple, Type, Union +import attr import mypy.types from mypy.erasetype import remove_instance_last_known_values from mypy.errorcodes import ErrorCode -from mypy.nodes import ARG_NAMED_OPT, TempNode, Var -from mypy.plugin import FunctionSigContext, MethodSigContext, Plugin +from mypy.nodes import ARG_NAMED_OPT, ListExpr, NameExpr, TempNode, TupleExpr, Var +from mypy.plugin import ( + ClassDefContext, + Context, + FunctionLike, + FunctionSigContext, + MethodSigContext, + MypyFile, + Plugin, +) from mypy.typeops import bind_self from mypy.types import ( AnyType, CallableType, Instance, NoneType, + Options, TupleType, TypeAliasType, TypeVarType, UninhabitedType, UnionType, ) +from mypy_zope import plugin as mypy_zope_plugin +from pydantic.mypy import plugin as mypy_pydantic_plugin + +PROMETHEUS_METRIC_MISSING_SERVER_NAME_LABEL = ErrorCode( + "missing-server-name-label", + "`SERVER_NAME_LABEL` required in metric", + category="per-homeserver-tenant-metrics", +) + +PROMETHEUS_METRIC_MISSING_FROM_LIST_TO_CHECK = ErrorCode( + "metric-type-missing-from-list", + "Every Prometheus metric type must be included in the `prometheus_metric_fullname_to_label_arg_map`.", + category="per-homeserver-tenant-metrics", +) + + +class Sentinel(enum.Enum): + # defining a sentinel in this way allows mypy to correctly handle the + # type of a dictionary lookup and subsequent type narrowing. 
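+    # For example, `prometheus_metric_fullname_to_label_arg_map.get(fullname, Sentinel.UNSET_SENTINEL)`
+    # returns either an `Optional[ArgLocation]` or the sentinel, and a subsequent
+    # `assert arg_location is not Sentinel.UNSET_SENTINEL` lets mypy narrow the
+    # sentinel away again (see `check_prometheus_metric_instantiation` below).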
+ UNSET_SENTINEL = object() + + +@attr.s(auto_attribs=True) +class ArgLocation: + keyword_name: str + """ + The keyword argument name for this argument + """ + position: int + """ + The 0-based positional index of this argument + """ + + +prometheus_metric_fullname_to_label_arg_map: Mapping[str, Optional[ArgLocation]] = { + # `Collector` subclasses: + "prometheus_client.metrics.MetricWrapperBase": ArgLocation("labelnames", 2), + "prometheus_client.metrics.Counter": ArgLocation("labelnames", 2), + "prometheus_client.metrics.Histogram": ArgLocation("labelnames", 2), + "prometheus_client.metrics.Gauge": ArgLocation("labelnames", 2), + "prometheus_client.metrics.Summary": ArgLocation("labelnames", 2), + "prometheus_client.metrics.Info": ArgLocation("labelnames", 2), + "prometheus_client.metrics.Enum": ArgLocation("labelnames", 2), + "synapse.metrics.LaterGauge": ArgLocation("labelnames", 2), + "synapse.metrics.InFlightGauge": ArgLocation("labels", 2), + "synapse.metrics.GaugeBucketCollector": ArgLocation("labelnames", 2), + "prometheus_client.registry.Collector": None, + "prometheus_client.registry._EmptyCollector": None, + "prometheus_client.registry.CollectorRegistry": None, + "prometheus_client.process_collector.ProcessCollector": None, + "prometheus_client.platform_collector.PlatformCollector": None, + "prometheus_client.gc_collector.GCCollector": None, + "synapse.metrics._gc.GCCounts": None, + "synapse.metrics._gc.PyPyGCStats": None, + "synapse.metrics._reactor_metrics.ReactorLastSeenMetric": None, + "synapse.metrics.CPUMetrics": None, + "synapse.metrics.jemalloc.JemallocCollector": None, + "synapse.util.metrics.DynamicCollectorRegistry": None, + "synapse.metrics.background_process_metrics._Collector": None, + # + # `Metric` subclasses: + "prometheus_client.metrics_core.Metric": None, + "prometheus_client.metrics_core.UnknownMetricFamily": ArgLocation("labels", 3), + "prometheus_client.metrics_core.CounterMetricFamily": ArgLocation("labels", 3), + "prometheus_client.metrics_core.GaugeMetricFamily": ArgLocation("labels", 3), + "prometheus_client.metrics_core.SummaryMetricFamily": ArgLocation("labels", 3), + "prometheus_client.metrics_core.InfoMetricFamily": ArgLocation("labels", 3), + "prometheus_client.metrics_core.HistogramMetricFamily": ArgLocation("labels", 3), + "prometheus_client.metrics_core.GaugeHistogramMetricFamily": ArgLocation( + "labels", 4 + ), + "prometheus_client.metrics_core.StateSetMetricFamily": ArgLocation("labels", 3), + "synapse.metrics.GaugeHistogramMetricFamilyWithLabels": ArgLocation( + "labelnames", 4 + ), +} +""" +Map from the fullname of the Prometheus `Metric`/`Collector` classes to the keyword +argument name and positional index of the label names. This map is useful because +different metrics have different signatures for passing in label names and we just need +to know where to look. + +This map should include any metrics that we collect with Prometheus. Which corresponds +to anything that inherits from `prometheus_client.registry.Collector` +(`synapse.metrics._types.Collector`) or `prometheus_client.metrics_core.Metric`. The +exhaustiveness of this list is enforced by `analyze_prometheus_metric_classes`. + +The entries with `None` always fail the lint because they don't have a `labelnames` +argument (therefore, no `SERVER_NAME_LABEL`), but we include them here so that people +can notice and manually allow via a type ignore comment as the source of truth +should be in the source code. 
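+
+For example, a process-level collector that intentionally omits the label can be
+allowed by adding `# type: ignore[missing-server-name-label]` to the line that
+constructs it.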
+""" + +# Unbound at this point because we don't know the mypy version yet. +# This is set in the `plugin(...)` function below. +MypyPydanticPluginClass: Type[Plugin] +MypyZopePluginClass: Type[Plugin] class SynapsePlugin(Plugin): + def __init__(self, options: Options): + super().__init__(options) + self.mypy_pydantic_plugin = MypyPydanticPluginClass(options) + self.mypy_zope_plugin = MypyZopePluginClass(options) + + def set_modules(self, modules: dict[str, MypyFile]) -> None: + """ + This is called by mypy internals. We have to override this to ensure it's also + called for any other plugins that we're manually handling. + + Here is how mypy describes it: + + > [`self._modules`] can't be set in `__init__` because it is executed too soon + > in `build.py`. Therefore, `build.py` *must* set it later before graph processing + > starts by calling `set_modules()`. + """ + super().set_modules(modules) + self.mypy_pydantic_plugin.set_modules(modules) + self.mypy_zope_plugin.set_modules(modules) + + def get_base_class_hook( + self, fullname: str + ) -> Optional[Callable[[ClassDefContext], None]]: + def _get_base_class_hook(ctx: ClassDefContext) -> None: + # Run any `get_base_class_hook` checks from other plugins first. + # + # Unfortunately, because mypy only chooses the first plugin that returns a + # non-None value (known-limitation, c.f. + # https://github.com/python/mypy/issues/19524), we workaround this by + # putting our custom plugin first in the plugin order and then calling the + # other plugin's hook manually followed by our own checks. + if callback := self.mypy_pydantic_plugin.get_base_class_hook(fullname): + callback(ctx) + if callback := self.mypy_zope_plugin.get_base_class_hook(fullname): + callback(ctx) + + # Now run our own checks + analyze_prometheus_metric_classes(ctx) + + return _get_base_class_hook + + def get_function_signature_hook( + self, fullname: str + ) -> Optional[Callable[[FunctionSigContext], FunctionLike]]: + # Strip off the unique identifier for classes that are dynamically created inside + # functions. ex. `synapse.metrics.jemalloc.JemallocCollector@185` (this is the line + # number) + if "@" in fullname: + fullname = fullname.split("@", 1)[0] + + # Look for any Prometheus metrics to make sure they have the `SERVER_NAME_LABEL` + # label. + if fullname in prometheus_metric_fullname_to_label_arg_map.keys(): + # Because it's difficult to determine the `fullname` of the function in the + # callback, let's just pass it in while we have it. + return lambda ctx: check_prometheus_metric_instantiation(ctx, fullname) + + return None + def get_method_signature_hook( self, fullname: str ) -> Optional[Callable[[MethodSigContext], CallableType]]: @@ -65,6 +232,157 @@ class SynapsePlugin(Plugin): return None +def analyze_prometheus_metric_classes(ctx: ClassDefContext) -> None: + """ + Cross-check the list of Prometheus metric classes against the + `prometheus_metric_fullname_to_label_arg_map` to ensure the list is exhaustive and + up-to-date. + """ + + fullname = ctx.cls.fullname + # Strip off the unique identifier for classes that are dynamically created inside + # functions. ex. `synapse.metrics.jemalloc.JemallocCollector@185` (this is the line + # number) + if "@" in fullname: + fullname = fullname.split("@", 1)[0] + + if any( + ancestor_type.fullname + in ( + # All of the Prometheus metric classes inherit from the `Collector`. + "prometheus_client.registry.Collector", + "synapse.metrics._types.Collector", + # And custom metrics that inherit from `Metric`. 
+ "prometheus_client.metrics_core.Metric", + ) + for ancestor_type in ctx.cls.info.mro + ): + if fullname not in prometheus_metric_fullname_to_label_arg_map: + ctx.api.fail( + f"Expected {fullname} to be in `prometheus_metric_fullname_to_label_arg_map`, " + f"but it was not found. This is a problem with our custom mypy plugin. " + f"Please add it to the map.", + Context(), + code=PROMETHEUS_METRIC_MISSING_FROM_LIST_TO_CHECK, + ) + + +def check_prometheus_metric_instantiation( + ctx: FunctionSigContext, fullname: str +) -> CallableType: + """ + Ensure that the `prometheus_client` metrics include the `SERVER_NAME_LABEL` label + when instantiated. + + This is important because we support multiple Synapse instances running in the same + process, where all metrics share a single global `REGISTRY`. The `server_name` label + ensures metrics are correctly separated by homeserver. + + There are also some metrics that apply at the process level, such as CPU usage, + Python garbage collection, and Twisted reactor tick time, which shouldn't have the + `SERVER_NAME_LABEL`. In those cases, use a type ignore comment to disable the + check, e.g. `# type: ignore[missing-server-name-label]`. + + Args: + ctx: The `FunctionSigContext` from mypy. + fullname: The fully qualified name of the function being called, + e.g. `"prometheus_client.metrics.Counter"` + """ + # The true signature, this isn't being modified so this is what will be returned. + signature = ctx.default_signature + + # Find where the label names argument is in the function signature. + arg_location = prometheus_metric_fullname_to_label_arg_map.get( + fullname, Sentinel.UNSET_SENTINEL + ) + assert arg_location is not Sentinel.UNSET_SENTINEL, ( + f"Expected to find {fullname} in `prometheus_metric_fullname_to_label_arg_map`, " + f"but it was not found. This is a problem with our custom mypy plugin. " + f"Please add it to the map. Context: {ctx.context}" + ) + # People should be using `# type: ignore[missing-server-name-label]` for + # process-level metrics that should not have the `SERVER_NAME_LABEL`. + if arg_location is None: + ctx.api.fail( + f"{signature.name} does not have a `labelnames`/`labels` argument " + "(if this is untrue, update `prometheus_metric_fullname_to_label_arg_map` " + "in our custom mypy plugin) and should probably have a type ignore comment, " + "e.g. `# type: ignore[missing-server-name-label]`. The reason we don't " + "automatically ignore this is the source of truth should be in the source code.", + ctx.context, + code=PROMETHEUS_METRIC_MISSING_SERVER_NAME_LABEL, + ) + return signature + + # Sanity check the arguments are still as expected in this version of + # `prometheus_client`. ex. `Counter(name, documentation, labelnames, ...)` + # + # `signature.arg_names` should be: ["name", "documentation", "labelnames", ...] + if ( + len(signature.arg_names) < (arg_location.position + 1) + or signature.arg_names[arg_location.position] != arg_location.keyword_name + ): + ctx.api.fail( + f"Expected argument number {arg_location.position + 1} of {signature.name} to be `labelnames`/`labels`, " + f"but got {signature.arg_names[arg_location.position]}", + ctx.context, + ) + return signature + + # Ensure mypy is passing the correct number of arguments because we are doing some + # dirty indexing into `ctx.args` later on. 
+ assert len(ctx.args) == len(signature.arg_names), ( + f"Expected the list of arguments in the {signature.name} signature ({len(signature.arg_names)})" + f"to match the number of arguments from the function signature context ({len(ctx.args)})" + ) + + # Check if the `labelnames` argument includes `SERVER_NAME_LABEL` + # + # `ctx.args` should look like this: + # ``` + # [ + # [StrExpr("name")], + # [StrExpr("documentation")], + # [ListExpr([StrExpr("label1"), StrExpr("label2")])] + # ... + # ] + # ``` + labelnames_arg_expression = ( + ctx.args[arg_location.position][0] + if len(ctx.args[arg_location.position]) > 0 + else None + ) + if isinstance(labelnames_arg_expression, (ListExpr, TupleExpr)): + # Check if the `labelnames` argument includes the `server_name` label (`SERVER_NAME_LABEL`). + for labelname_expression in labelnames_arg_expression.items: + if ( + isinstance(labelname_expression, NameExpr) + and labelname_expression.fullname == "synapse.metrics.SERVER_NAME_LABEL" + ): + # Found the `SERVER_NAME_LABEL`, all good! + break + else: + ctx.api.fail( + f"Expected {signature.name} to include `SERVER_NAME_LABEL` in the list of labels. " + "If this is a process-level metric (vs homeserver-level), use a type ignore comment " + "to disable this check.", + ctx.context, + code=PROMETHEUS_METRIC_MISSING_SERVER_NAME_LABEL, + ) + else: + ctx.api.fail( + f"Expected the `labelnames` argument of {signature.name} to be a list of label names " + f"(including `SERVER_NAME_LABEL`), but got {labelnames_arg_expression}. " + "If this is a process-level metric (vs homeserver-level), use a type ignore comment " + "to disable this check.", + ctx.context, + code=PROMETHEUS_METRIC_MISSING_SERVER_NAME_LABEL, + ) + return signature + + return signature + + def _get_true_return_type(signature: CallableType) -> mypy.types.Type: """ Get the "final" return type of a callable which might return an Awaitable/Deferred. @@ -372,10 +690,13 @@ def is_cacheable( def plugin(version: str) -> Type[SynapsePlugin]: + global MypyPydanticPluginClass, MypyZopePluginClass # This is the entry point of the plugin, and lets us deal with the fact # that the mypy plugin interface is *not* stable by looking at the version # string. # # However, since we pin the version of mypy Synapse uses in CI, we don't # really care. + MypyPydanticPluginClass = mypy_pydantic_plugin(version) + MypyZopePluginClass = mypy_zope_plugin(version) return SynapsePlugin diff --git a/synapse/__init__.py b/synapse/__init__.py index e7784ac5d7..3bd1b3307e 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -45,16 +45,6 @@ if py_version < (3, 9): # Allow using the asyncio reactor via env var. 
if strtobool(os.environ.get("SYNAPSE_ASYNC_IO_REACTOR", "0")): - from incremental import Version - - import twisted - - # We need a bugfix that is included in Twisted 21.2.0: - # https://twistedmatrix.com/trac/ticket/9787 - if twisted.version < Version("Twisted", 21, 2, 0): - print("Using asyncio reactor requires Twisted>=21.2.0") - sys.exit(1) - import asyncio from twisted.internet import asyncioreactor diff --git a/synapse/_pydantic_compat.py b/synapse/_pydantic_compat.py index e9b43aebe3..a520c0e897 100644 --- a/synapse/_pydantic_compat.py +++ b/synapse/_pydantic_compat.py @@ -34,9 +34,11 @@ HAS_PYDANTIC_V2: bool = Version(pydantic_version).major == 2 if TYPE_CHECKING or HAS_PYDANTIC_V2: from pydantic.v1 import ( + AnyHttpUrl, BaseModel, Extra, Field, + FilePath, MissingError, PydanticValueError, StrictBool, @@ -55,9 +57,11 @@ if TYPE_CHECKING or HAS_PYDANTIC_V2: from pydantic.v1.typing import get_args else: from pydantic import ( + AnyHttpUrl, BaseModel, Extra, Field, + FilePath, MissingError, PydanticValueError, StrictBool, @@ -77,6 +81,7 @@ else: __all__ = ( "HAS_PYDANTIC_V2", + "AnyHttpUrl", "BaseModel", "constr", "conbytes", @@ -85,6 +90,7 @@ __all__ = ( "ErrorWrapper", "Extra", "Field", + "FilePath", "get_args", "MissingError", "parse_obj_as", diff --git a/synapse/_scripts/review_recent_signups.py b/synapse/_scripts/review_recent_signups.py index 62723c539d..0ff7fae567 100644 --- a/synapse/_scripts/review_recent_signups.py +++ b/synapse/_scripts/review_recent_signups.py @@ -29,19 +29,21 @@ import attr from synapse.config._base import ( Config, + ConfigError, RootConfig, find_config_files, read_config_files, ) from synapse.config.database import DatabaseConfig +from synapse.config.server import ServerConfig from synapse.storage.database import DatabasePool, LoggingTransaction, make_conn from synapse.storage.engines import create_engine class ReviewConfig(RootConfig): - "A config class that just pulls out the database config" + "A config class that just pulls out the server and database config" - config_classes = [DatabaseConfig] + config_classes = [ServerConfig, DatabaseConfig] @attr.s(auto_attribs=True) @@ -148,6 +150,10 @@ def main() -> None: config_dict = read_config_files(config_files) config.parse_config_dict(config_dict, "", "") + server_name = config.server.server_name + if not isinstance(server_name, str): + raise ConfigError("Must be a string", ("server_name",)) + since_ms = time.time() * 1000 - Config.parse_duration(config_args.since) exclude_users_with_email = config_args.exclude_emails exclude_users_with_appservice = config_args.exclude_app_service @@ -159,7 +165,12 @@ def main() -> None: engine = create_engine(database_config.config) - with make_conn(database_config, engine, "review_recent_signups") as db_conn: + with make_conn( + db_config=database_config, + engine=engine, + default_txn_name="review_recent_signups", + server_name=server_name, + ) as db_conn: # This generates a type of Cursor, not LoggingTransaction. 
user_infos = get_recent_users( db_conn.cursor(), diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py index 9a0b459e65..0f54cfc64a 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py @@ -672,8 +672,14 @@ class Porter: engine = create_engine(db_config.config) hs = MockHomeserver(self.hs_config) + server_name = hs.hostname - with make_conn(db_config, engine, "portdb") as db_conn: + with make_conn( + db_config=db_config, + engine=engine, + default_txn_name="portdb", + server_name=server_name, + ) as db_conn: engine.check_database( db_conn, allow_outdated_version=allow_outdated_version ) diff --git a/synapse/_scripts/update_synapse_database.py b/synapse/_scripts/update_synapse_database.py index d8b4dbd6c6..70e5598418 100644 --- a/synapse/_scripts/update_synapse_database.py +++ b/synapse/_scripts/update_synapse_database.py @@ -53,6 +53,7 @@ class MockHomeserver(HomeServer): def run_background_updates(hs: HomeServer) -> None: + server_name = hs.hostname main = hs.get_datastores().main state = hs.get_datastores().state @@ -66,7 +67,11 @@ def run_background_updates(hs: HomeServer) -> None: def run() -> None: # Apply all background updates on the database. defer.ensureDeferred( - run_as_background_process("background_updates", run_background_updates) + run_as_background_process( + "background_updates", + server_name, + run_background_updates, + ) ) reactor.callWhenRunning(run) diff --git a/synapse/api/auth/__init__.py b/synapse/api/auth/__init__.py index 1b801d3ad3..d253938329 100644 --- a/synapse/api/auth/__init__.py +++ b/synapse/api/auth/__init__.py @@ -20,10 +20,13 @@ # from typing import TYPE_CHECKING, Optional, Protocol, Tuple +from prometheus_client import Histogram + from twisted.web.server import Request from synapse.appservice import ApplicationService from synapse.http.site import SynapseRequest +from synapse.metrics import SERVER_NAME_LABEL from synapse.types import Requester if TYPE_CHECKING: @@ -33,6 +36,13 @@ if TYPE_CHECKING: GUEST_DEVICE_ID = "guest_device" +introspection_response_timer = Histogram( + "synapse_api_auth_delegated_introspection_response", + "Time taken to get a response for an introspection request", + labelnames=["code", SERVER_NAME_LABEL], +) + + class Auth(Protocol): """The interface that an auth provider must implement.""" diff --git a/synapse/api/auth/mas.py b/synapse/api/auth/mas.py new file mode 100644 index 0000000000..00bad76856 --- /dev/null +++ b/synapse/api/auth/mas.py @@ -0,0 +1,432 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2025 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# . 
+# +# +import logging +from typing import TYPE_CHECKING, Optional +from urllib.parse import urlencode + +from synapse._pydantic_compat import ( + BaseModel, + Extra, + StrictBool, + StrictInt, + StrictStr, + ValidationError, +) +from synapse.api.auth.base import BaseAuth +from synapse.api.errors import ( + AuthError, + HttpResponseException, + InvalidClientTokenError, + SynapseError, + UnrecognizedRequestError, +) +from synapse.http.site import SynapseRequest +from synapse.logging.context import PreserveLoggingContext +from synapse.logging.opentracing import ( + active_span, + force_tracing, + inject_request_headers, + start_active_span, +) +from synapse.metrics import SERVER_NAME_LABEL +from synapse.synapse_rust.http_client import HttpClient +from synapse.types import JsonDict, Requester, UserID, create_requester +from synapse.util import json_decoder +from synapse.util.caches.cached_call import RetryOnExceptionCachedCall +from synapse.util.caches.response_cache import ResponseCache, ResponseCacheContext + +from . import introspection_response_timer + +if TYPE_CHECKING: + from synapse.rest.admin.experimental_features import ExperimentalFeature + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + +# Scope as defined by MSC2967 +# https://github.com/matrix-org/matrix-spec-proposals/pull/2967 +SCOPE_MATRIX_API = "urn:matrix:org.matrix.msc2967.client:api:*" +SCOPE_MATRIX_DEVICE_PREFIX = "urn:matrix:org.matrix.msc2967.client:device:" + + +class ServerMetadata(BaseModel): + class Config: + extra = Extra.allow + + issuer: StrictStr + account_management_uri: StrictStr + + +class IntrospectionResponse(BaseModel): + retrieved_at_ms: StrictInt + active: StrictBool + scope: Optional[StrictStr] + username: Optional[StrictStr] + sub: Optional[StrictStr] + device_id: Optional[StrictStr] + expires_in: Optional[StrictInt] + + class Config: + extra = Extra.allow + + def get_scope_set(self) -> set[str]: + if not self.scope: + return set() + + return {token for token in self.scope.split(" ") if token} + + def is_active(self, now_ms: int) -> bool: + if not self.active: + return False + + # Compatibility tokens don't expire and don't have an 'expires_in' field + if self.expires_in is None: + return True + + absolute_expiry_ms = self.expires_in * 1000 + self.retrieved_at_ms + return now_ms < absolute_expiry_ms + + +class MasDelegatedAuth(BaseAuth): + def __init__(self, hs: "HomeServer"): + super().__init__(hs) + + self.server_name = hs.hostname + self._clock = hs.get_clock() + self._config = hs.config.mas + + self._http_client = hs.get_proxied_http_client() + self._rust_http_client = HttpClient( + reactor=hs.get_reactor(), + user_agent=self._http_client.user_agent.decode("utf8"), + ) + self._server_metadata = RetryOnExceptionCachedCall[ServerMetadata]( + self._load_metadata + ) + self._force_tracing_for_users = hs.config.tracing.force_tracing_for_users + + # # Token Introspection Cache + # This remembers what users/devices are represented by which access tokens, + # in order to reduce overall system load: + # - on Synapse (as requests are relatively expensive) + # - on the network + # - on MAS + # + # Since there is no invalidation mechanism currently, + # the entries expire after 2 minutes. + # This does mean tokens can be treated as valid by Synapse + # for longer than reality. + # + # Ideally, tokens should logically be invalidated in the following circumstances: + # - If a session logout happens. 
+ # In this case, MAS will delete the device within Synapse + # anyway and this is good enough as an invalidation. + # - If the client refreshes their token in MAS. + # In this case, the device still exists and it's not the end of the world for + # the old access token to continue working for a short time. + self._introspection_cache: ResponseCache[str] = ResponseCache( + clock=self._clock, + name="mas_token_introspection", + server_name=self.server_name, + timeout_ms=120_000, + # don't log because the keys are access tokens + enable_logging=False, + ) + + @property + def _metadata_url(self) -> str: + return f"{self._config.endpoint.rstrip('/')}/.well-known/openid-configuration" + + @property + def _introspection_endpoint(self) -> str: + return f"{self._config.endpoint.rstrip('/')}/oauth2/introspect" + + async def _load_metadata(self) -> ServerMetadata: + response = await self._http_client.get_json(self._metadata_url) + metadata = ServerMetadata(**response) + return metadata + + async def issuer(self) -> str: + metadata = await self._server_metadata.get() + return metadata.issuer + + async def account_management_url(self) -> str: + metadata = await self._server_metadata.get() + return metadata.account_management_uri + + async def auth_metadata(self) -> JsonDict: + metadata = await self._server_metadata.get() + return metadata.dict() + + def is_request_using_the_shared_secret(self, request: SynapseRequest) -> bool: + """ + Check if the request is using the shared secret. + + Args: + request: The request to check. + + Returns: + True if the request is using the shared secret, False otherwise. + """ + access_token = self.get_access_token_from_request(request) + shared_secret = self._config.secret() + if not shared_secret: + return False + + return access_token == shared_secret + + async def _introspect_token( + self, token: str, cache_context: ResponseCacheContext[str] + ) -> IntrospectionResponse: + """ + Send a token to the introspection endpoint and returns the introspection response + + Parameters: + token: The token to introspect + + Raises: + HttpResponseException: If the introspection endpoint returns a non-2xx response + ValueError: If the introspection endpoint returns an invalid JSON response + JSONDecodeError: If the introspection endpoint returns a non-JSON response + Exception: If the HTTP request fails + + Returns: + The introspection response + """ + + # By default, we shouldn't cache the result unless we know it's valid + cache_context.should_cache = False + raw_headers: dict[str, str] = { + "Content-Type": "application/x-www-form-urlencoded", + "Accept": "application/json", + "Authorization": f"Bearer {self._config.secret()}", + # Tell MAS that we support reading the device ID as an explicit + # value, not encoded in the scope. 
This is supported by MAS 0.15+ + "X-MAS-Supports-Device-Id": "1", + } + + args = {"token": token, "token_type_hint": "access_token"} + body = urlencode(args, True) + + # Do the actual request + + logger.debug("Fetching token from MAS") + start_time = self._clock.time() + try: + with start_active_span("mas-introspect-token"): + inject_request_headers(raw_headers) + with PreserveLoggingContext(): + resp_body = await self._rust_http_client.post( + url=self._introspection_endpoint, + response_limit=1 * 1024 * 1024, + headers=raw_headers, + request_body=body, + ) + except HttpResponseException as e: + end_time = self._clock.time() + introspection_response_timer.labels( + code=e.code, **{SERVER_NAME_LABEL: self.server_name} + ).observe(end_time - start_time) + raise + except Exception: + end_time = self._clock.time() + introspection_response_timer.labels( + code="ERR", **{SERVER_NAME_LABEL: self.server_name} + ).observe(end_time - start_time) + raise + + logger.debug("Fetched token from MAS") + + end_time = self._clock.time() + introspection_response_timer.labels( + code=200, **{SERVER_NAME_LABEL: self.server_name} + ).observe(end_time - start_time) + + raw_response = json_decoder.decode(resp_body.decode("utf-8")) + try: + response = IntrospectionResponse( + retrieved_at_ms=self._clock.time_msec(), + **raw_response, + ) + except ValidationError as e: + raise ValueError( + "The introspection endpoint returned an invalid JSON response" + ) from e + + # We had a valid response, so we can cache it + cache_context.should_cache = True + return response + + async def is_server_admin(self, requester: Requester) -> bool: + return "urn:synapse:admin:*" in requester.scope + + async def get_user_by_req( + self, + request: SynapseRequest, + allow_guest: bool = False, + allow_expired: bool = False, + allow_locked: bool = False, + ) -> Requester: + parent_span = active_span() + with start_active_span("get_user_by_req"): + access_token = self.get_access_token_from_request(request) + + requester = await self.get_appservice_user(request, access_token) + if not requester: + requester = await self.get_user_by_access_token( + token=access_token, + allow_expired=allow_expired, + ) + + await self._record_request(request, requester) + + request.requester = requester + + if parent_span: + if requester.authenticated_entity in self._force_tracing_for_users: + # request tracing is enabled for this user, so we need to force it + # tracing on for the parent span (which will be the servlet span). + # + # It's too late for the get_user_by_req span to inherit the setting, + # so we also force it on for that. 
+ force_tracing() + force_tracing(parent_span) + parent_span.set_tag( + "authenticated_entity", requester.authenticated_entity + ) + parent_span.set_tag("user_id", requester.user.to_string()) + if requester.device_id is not None: + parent_span.set_tag("device_id", requester.device_id) + if requester.app_service is not None: + parent_span.set_tag("appservice_id", requester.app_service.id) + return requester + + async def get_user_by_access_token( + self, + token: str, + allow_expired: bool = False, + ) -> Requester: + try: + introspection_result = await self._introspection_cache.wrap( + token, self._introspect_token, token, cache_context=True + ) + except Exception: + logger.exception("Failed to introspect token") + raise SynapseError(503, "Unable to introspect the access token") + + logger.debug("Introspection result: %r", introspection_result) + if not introspection_result.is_active(self._clock.time_msec()): + raise InvalidClientTokenError("Token is not active") + + # Let's look at the scope + scope = introspection_result.get_scope_set() + + # Determine type of user based on presence of particular scopes + if SCOPE_MATRIX_API not in scope: + raise InvalidClientTokenError( + "Token doesn't grant access to the Matrix C-S API" + ) + + if introspection_result.username is None: + raise AuthError( + 500, + "Invalid username claim in the introspection result", + ) + + user_id = UserID( + localpart=introspection_result.username, + domain=self.server_name, + ) + + # Try to find a user from the username claim + user_info = await self.store.get_user_by_id(user_id=user_id.to_string()) + if user_info is None: + raise AuthError( + 500, + "User not found", + ) + + # MAS will give us the device ID as an explicit value for *compatibility* sessions + # If present, we get it from here, if not we get it in the scope for next-gen sessions + device_id = introspection_result.device_id + if device_id is None: + # Find device_ids in scope + # We only allow a single device_id in the scope, so we find them all in the + # scope list, and raise if there are more than one. The OIDC server should be + # the one enforcing valid scopes, so we raise a 500 if we find an invalid scope. + device_ids = [ + tok[len(SCOPE_MATRIX_DEVICE_PREFIX) :] + for tok in scope + if tok.startswith(SCOPE_MATRIX_DEVICE_PREFIX) + ] + + if len(device_ids) > 1: + raise AuthError( + 500, + "Multiple device IDs in scope", + ) + + device_id = device_ids[0] if device_ids else None + + if device_id is not None: + # Sanity check the device_id + if len(device_id) > 255 or len(device_id) < 1: + raise AuthError( + 500, + "Invalid device ID in introspection result", + ) + + # Make sure the device exists. 
This helps with introspection cache + # invalidation: if we log out, the device gets deleted by MAS + device = await self.store.get_device( + user_id=user_id.to_string(), + device_id=device_id, + ) + if device is None: + # Invalidate the introspection cache, the device was deleted + self._introspection_cache.unset(token) + raise InvalidClientTokenError("Token is not active") + + return create_requester( + user_id=user_id, + device_id=device_id, + scope=scope, + ) + + async def get_user_by_req_experimental_feature( + self, + request: SynapseRequest, + feature: "ExperimentalFeature", + allow_guest: bool = False, + allow_expired: bool = False, + allow_locked: bool = False, + ) -> Requester: + try: + requester = await self.get_user_by_req( + request, + allow_guest=allow_guest, + allow_expired=allow_expired, + allow_locked=allow_locked, + ) + if await self.store.is_feature_enabled(requester.user.to_string(), feature): + return requester + + raise UnrecognizedRequestError(code=404) + except (AuthError, InvalidClientTokenError): + if feature.is_globally_enabled(self.hs.config): + # If its globally enabled then return the auth error + raise + + raise UnrecognizedRequestError(code=404) diff --git a/synapse/api/auth/msc3861_delegated.py b/synapse/api/auth/msc3861_delegated.py index 567f2e834c..928b2c8f8b 100644 --- a/synapse/api/auth/msc3861_delegated.py +++ b/synapse/api/auth/msc3861_delegated.py @@ -28,7 +28,6 @@ from authlib.oauth2.auth import encode_client_secret_basic, encode_client_secret from authlib.oauth2.rfc7523 import ClientSecretJWT, PrivateKeyJWT, private_key_jwt_sign from authlib.oauth2.rfc7662 import IntrospectionToken from authlib.oidc.discovery import OpenIDProviderMetadata, get_well_known_url -from prometheus_client import Histogram from synapse.api.auth.base import BaseAuth from synapse.api.errors import ( @@ -47,25 +46,21 @@ from synapse.logging.opentracing import ( inject_request_headers, start_active_span, ) +from synapse.metrics import SERVER_NAME_LABEL from synapse.synapse_rust.http_client import HttpClient from synapse.types import Requester, UserID, create_requester from synapse.util import json_decoder from synapse.util.caches.cached_call import RetryOnExceptionCachedCall from synapse.util.caches.response_cache import ResponseCache, ResponseCacheContext +from . 
import introspection_response_timer + if TYPE_CHECKING: from synapse.rest.admin.experimental_features import ExperimentalFeature from synapse.server import HomeServer logger = logging.getLogger(__name__) -introspection_response_timer = Histogram( - "synapse_api_auth_delegated_introspection_response", - "Time taken to get a response for an introspection request", - ["code"], -) - - # Scope as defined by MSC2967 # https://github.com/matrix-org/matrix-spec-proposals/pull/2967 SCOPE_MATRIX_API = "urn:matrix:org.matrix.msc2967.client:api:*" @@ -341,17 +336,23 @@ class MSC3861DelegatedAuth(BaseAuth): ) except HttpResponseException as e: end_time = self._clock.time() - introspection_response_timer.labels(e.code).observe(end_time - start_time) + introspection_response_timer.labels( + code=e.code, **{SERVER_NAME_LABEL: self.server_name} + ).observe(end_time - start_time) raise except Exception: end_time = self._clock.time() - introspection_response_timer.labels("ERR").observe(end_time - start_time) + introspection_response_timer.labels( + code="ERR", **{SERVER_NAME_LABEL: self.server_name} + ).observe(end_time - start_time) raise logger.debug("Fetched token from MAS") end_time = self._clock.time() - introspection_response_timer.labels(200).observe(end_time - start_time) + introspection_response_timer.labels( + code=200, **{SERVER_NAME_LABEL: self.server_name} + ).observe(end_time - start_time) resp = json_decoder.decode(resp_body.decode("utf-8")) diff --git a/synapse/api/errors.py b/synapse/api/errors.py index b832c2f6a1..ec4d707b7b 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -140,6 +140,12 @@ class Codes(str, Enum): # Part of MSC4155 INVITE_BLOCKED = "ORG.MATRIX.MSC4155.M_INVITE_BLOCKED" + # Part of MSC4306: Thread Subscriptions + MSC4306_CONFLICTING_UNSUBSCRIPTION = ( + "IO.ELEMENT.MSC4306.M_CONFLICTING_UNSUBSCRIPTION" + ) + MSC4306_NOT_IN_THREAD = "IO.ELEMENT.MSC4306.M_NOT_IN_THREAD" + class CodeMessageException(RuntimeError): """An exception with integer code, a message string attributes and optional headers. diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 16aab93cd6..48989540bb 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -75,7 +75,7 @@ from synapse.http.site import SynapseSite from synapse.logging.context import PreserveLoggingContext from synapse.logging.opentracing import init_tracer from synapse.metrics import install_gc_manager, register_threadpool -from synapse.metrics.background_process_metrics import wrap_as_background_process +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.jemalloc import setup_jemalloc_stats from synapse.module_api.callbacks.spamchecker_callbacks import load_legacy_spam_checkers from synapse.module_api.callbacks.third_party_event_rules_callbacks import ( @@ -512,6 +512,7 @@ async def start(hs: "HomeServer") -> None: Args: hs: homeserver instance """ + server_name = hs.hostname reactor = hs.get_reactor() # We want to use a separate thread pool for the resolver so that large @@ -524,22 +525,34 @@ async def start(hs: "HomeServer") -> None: ) # Register the threadpools with our metrics. - register_threadpool("default", reactor.getThreadPool()) - register_threadpool("gai_resolver", resolver_threadpool) + register_threadpool( + name="default", server_name=server_name, threadpool=reactor.getThreadPool() + ) + register_threadpool( + name="gai_resolver", server_name=server_name, threadpool=resolver_threadpool + ) # Set up the SIGHUP machinery. 
if hasattr(signal, "SIGHUP"): - @wrap_as_background_process("sighup") - async def handle_sighup(*args: Any, **kwargs: Any) -> None: - # Tell systemd our state, if we're using it. This will silently fail if - # we're not using systemd. - sdnotify(b"RELOADING=1") + def handle_sighup(*args: Any, **kwargs: Any) -> "defer.Deferred[None]": + async def _handle_sighup(*args: Any, **kwargs: Any) -> None: + # Tell systemd our state, if we're using it. This will silently fail if + # we're not using systemd. + sdnotify(b"RELOADING=1") - for i, args, kwargs in _sighup_callbacks: - i(*args, **kwargs) + for i, args, kwargs in _sighup_callbacks: + i(*args, **kwargs) - sdnotify(b"READY=1") + sdnotify(b"READY=1") + + return run_as_background_process( + "sighup", + server_name, + _handle_sighup, + *args, + **kwargs, + ) # We defer running the sighup handlers until next reactor tick. This # is so that we're in a sane state, e.g. flushing the logs may fail diff --git a/synapse/app/phone_stats_home.py b/synapse/app/phone_stats_home.py index 7e8c7cf37e..69d3ac78fd 100644 --- a/synapse/app/phone_stats_home.py +++ b/synapse/app/phone_stats_home.py @@ -26,7 +26,12 @@ from typing import TYPE_CHECKING, List, Mapping, Sized, Tuple from prometheus_client import Gauge -from synapse.metrics.background_process_metrics import wrap_as_background_process +from twisted.internet import defer + +from synapse.metrics import SERVER_NAME_LABEL +from synapse.metrics.background_process_metrics import ( + run_as_background_process, +) from synapse.types import JsonDict from synapse.util.constants import ONE_HOUR_SECONDS, ONE_MINUTE_SECONDS @@ -53,138 +58,158 @@ Phone home stats are sent every 3 hours _stats_process: List[Tuple[int, "resource.struct_rusage"]] = [] # Gauges to expose monthly active user control metrics -current_mau_gauge = Gauge("synapse_admin_mau_current", "Current MAU") +current_mau_gauge = Gauge( + "synapse_admin_mau_current", + "Current MAU", + labelnames=[SERVER_NAME_LABEL], +) current_mau_by_service_gauge = Gauge( "synapse_admin_mau_current_mau_by_service", "Current MAU by service", - ["app_service"], + labelnames=["app_service", SERVER_NAME_LABEL], +) +max_mau_gauge = Gauge( + "synapse_admin_mau_max", + "MAU Limit", + labelnames=[SERVER_NAME_LABEL], ) -max_mau_gauge = Gauge("synapse_admin_mau_max", "MAU Limit") registered_reserved_users_mau_gauge = Gauge( "synapse_admin_mau_registered_reserved_users", "Registered users with reserved threepids", + labelnames=[SERVER_NAME_LABEL], ) -@wrap_as_background_process("phone_stats_home") -async def phone_stats_home( +def phone_stats_home( hs: "HomeServer", stats: JsonDict, stats_process: List[Tuple[int, "resource.struct_rusage"]] = _stats_process, -) -> None: - """Collect usage statistics and send them to the configured endpoint. +) -> "defer.Deferred[None]": + server_name = hs.hostname - Args: - hs: the HomeServer object to use for gathering usage data. - stats: the dict in which to store the statistics sent to the configured - endpoint. Mostly used in tests to figure out the data that is supposed to - be sent. - stats_process: statistics about resource usage of the process. - """ + async def _phone_stats_home( + hs: "HomeServer", + stats: JsonDict, + stats_process: List[Tuple[int, "resource.struct_rusage"]] = _stats_process, + ) -> None: + """Collect usage statistics and send them to the configured endpoint. - logger.info("Gathering stats for reporting") - now = int(hs.get_clock().time()) - # Ensure the homeserver has started. 
- assert hs.start_time is not None - uptime = int(now - hs.start_time) - if uptime < 0: - uptime = 0 + Args: + hs: the HomeServer object to use for gathering usage data. + stats: the dict in which to store the statistics sent to the configured + endpoint. Mostly used in tests to figure out the data that is supposed to + be sent. + stats_process: statistics about resource usage of the process. + """ - # - # Performance statistics. Keep this early in the function to maintain reliability of `test_performance_100` test. - # - old = stats_process[0] - new = (now, resource.getrusage(resource.RUSAGE_SELF)) - stats_process[0] = new + logger.info("Gathering stats for reporting") + now = int(hs.get_clock().time()) + # Ensure the homeserver has started. + assert hs.start_time is not None + uptime = int(now - hs.start_time) + if uptime < 0: + uptime = 0 - # Get RSS in bytes - stats["memory_rss"] = new[1].ru_maxrss + # + # Performance statistics. Keep this early in the function to maintain reliability of `test_performance_100` test. + # + old = stats_process[0] + new = (now, resource.getrusage(resource.RUSAGE_SELF)) + stats_process[0] = new - # Get CPU time in % of a single core, not % of all cores - used_cpu_time = (new[1].ru_utime + new[1].ru_stime) - ( - old[1].ru_utime + old[1].ru_stime - ) - if used_cpu_time == 0 or new[0] == old[0]: - stats["cpu_average"] = 0 - else: - stats["cpu_average"] = math.floor(used_cpu_time / (new[0] - old[0]) * 100) + # Get RSS in bytes + stats["memory_rss"] = new[1].ru_maxrss - # - # General statistics - # - - store = hs.get_datastores().main - common_metrics = await hs.get_common_usage_metrics_manager().get_metrics() - - stats["homeserver"] = hs.config.server.server_name - stats["server_context"] = hs.config.server.server_context - stats["timestamp"] = now - stats["uptime_seconds"] = uptime - version = sys.version_info - stats["python_version"] = "{}.{}.{}".format( - version.major, version.minor, version.micro - ) - stats["total_users"] = await store.count_all_users() - - total_nonbridged_users = await store.count_nonbridged_users() - stats["total_nonbridged_users"] = total_nonbridged_users - - daily_user_type_results = await store.count_daily_user_type() - for name, count in daily_user_type_results.items(): - stats["daily_user_type_" + name] = count - - room_count = await store.get_room_count() - stats["total_room_count"] = room_count - - stats["daily_active_users"] = common_metrics.daily_active_users - stats["monthly_active_users"] = await store.count_monthly_users() - daily_active_e2ee_rooms = await store.count_daily_active_e2ee_rooms() - stats["daily_active_e2ee_rooms"] = daily_active_e2ee_rooms - stats["daily_e2ee_messages"] = await store.count_daily_e2ee_messages() - daily_sent_e2ee_messages = await store.count_daily_sent_e2ee_messages() - stats["daily_sent_e2ee_messages"] = daily_sent_e2ee_messages - stats["daily_active_rooms"] = await store.count_daily_active_rooms() - stats["daily_messages"] = await store.count_daily_messages() - daily_sent_messages = await store.count_daily_sent_messages() - stats["daily_sent_messages"] = daily_sent_messages - - r30v2_results = await store.count_r30v2_users() - for name, count in r30v2_results.items(): - stats["r30v2_users_" + name] = count - - stats["cache_factor"] = hs.config.caches.global_factor - stats["event_cache_size"] = hs.config.caches.event_cache_size - - # - # Database version - # - - # This only reports info about the *main* database. 
- stats["database_engine"] = store.db_pool.engine.module.__name__ - stats["database_server_version"] = store.db_pool.engine.server_version - - # - # Logging configuration - # - synapse_logger = logging.getLogger("synapse") - log_level = synapse_logger.getEffectiveLevel() - stats["log_level"] = logging.getLevelName(log_level) - - logger.info( - "Reporting stats to %s: %s", hs.config.metrics.report_stats_endpoint, stats - ) - try: - await hs.get_proxied_http_client().put_json( - hs.config.metrics.report_stats_endpoint, stats + # Get CPU time in % of a single core, not % of all cores + used_cpu_time = (new[1].ru_utime + new[1].ru_stime) - ( + old[1].ru_utime + old[1].ru_stime ) - except Exception as e: - logger.warning("Error reporting stats: %s", e) + if used_cpu_time == 0 or new[0] == old[0]: + stats["cpu_average"] = 0 + else: + stats["cpu_average"] = math.floor(used_cpu_time / (new[0] - old[0]) * 100) + + # + # General statistics + # + + store = hs.get_datastores().main + common_metrics = await hs.get_common_usage_metrics_manager().get_metrics() + + stats["homeserver"] = hs.config.server.server_name + stats["server_context"] = hs.config.server.server_context + stats["timestamp"] = now + stats["uptime_seconds"] = uptime + version = sys.version_info + stats["python_version"] = "{}.{}.{}".format( + version.major, version.minor, version.micro + ) + stats["total_users"] = await store.count_all_users() + + total_nonbridged_users = await store.count_nonbridged_users() + stats["total_nonbridged_users"] = total_nonbridged_users + + daily_user_type_results = await store.count_daily_user_type() + for name, count in daily_user_type_results.items(): + stats["daily_user_type_" + name] = count + + room_count = await store.get_room_count() + stats["total_room_count"] = room_count + + stats["daily_active_users"] = common_metrics.daily_active_users + stats["monthly_active_users"] = await store.count_monthly_users() + daily_active_e2ee_rooms = await store.count_daily_active_e2ee_rooms() + stats["daily_active_e2ee_rooms"] = daily_active_e2ee_rooms + stats["daily_e2ee_messages"] = await store.count_daily_e2ee_messages() + daily_sent_e2ee_messages = await store.count_daily_sent_e2ee_messages() + stats["daily_sent_e2ee_messages"] = daily_sent_e2ee_messages + stats["daily_active_rooms"] = await store.count_daily_active_rooms() + stats["daily_messages"] = await store.count_daily_messages() + daily_sent_messages = await store.count_daily_sent_messages() + stats["daily_sent_messages"] = daily_sent_messages + + r30v2_results = await store.count_r30v2_users() + for name, count in r30v2_results.items(): + stats["r30v2_users_" + name] = count + + stats["cache_factor"] = hs.config.caches.global_factor + stats["event_cache_size"] = hs.config.caches.event_cache_size + + # + # Database version + # + + # This only reports info about the *main* database. 
+ stats["database_engine"] = store.db_pool.engine.module.__name__ + stats["database_server_version"] = store.db_pool.engine.server_version + + # + # Logging configuration + # + synapse_logger = logging.getLogger("synapse") + log_level = synapse_logger.getEffectiveLevel() + stats["log_level"] = logging.getLevelName(log_level) + + logger.info( + "Reporting stats to %s: %s", hs.config.metrics.report_stats_endpoint, stats + ) + try: + await hs.get_proxied_http_client().put_json( + hs.config.metrics.report_stats_endpoint, stats + ) + except Exception as e: + logger.warning("Error reporting stats: %s", e) + + return run_as_background_process( + "phone_stats_home", server_name, _phone_stats_home, hs, stats, stats_process + ) def start_phone_stats_home(hs: "HomeServer") -> None: """ Start the background tasks which report phone home stats. """ + server_name = hs.hostname clock = hs.get_clock() stats: JsonDict = {} @@ -210,25 +235,39 @@ def start_phone_stats_home(hs: "HomeServer") -> None: ) hs.get_datastores().main.reap_monthly_active_users() - @wrap_as_background_process("generate_monthly_active_users") - async def generate_monthly_active_users() -> None: - current_mau_count = 0 - current_mau_count_by_service: Mapping[str, int] = {} - reserved_users: Sized = () - store = hs.get_datastores().main - if hs.config.server.limit_usage_by_mau or hs.config.server.mau_stats_only: - current_mau_count = await store.get_monthly_active_count() - current_mau_count_by_service = ( - await store.get_monthly_active_count_by_service() + def generate_monthly_active_users() -> "defer.Deferred[None]": + async def _generate_monthly_active_users() -> None: + current_mau_count = 0 + current_mau_count_by_service: Mapping[str, int] = {} + reserved_users: Sized = () + store = hs.get_datastores().main + if hs.config.server.limit_usage_by_mau or hs.config.server.mau_stats_only: + current_mau_count = await store.get_monthly_active_count() + current_mau_count_by_service = ( + await store.get_monthly_active_count_by_service() + ) + reserved_users = await store.get_registered_reserved_users() + current_mau_gauge.labels(**{SERVER_NAME_LABEL: server_name}).set( + float(current_mau_count) ) - reserved_users = await store.get_registered_reserved_users() - current_mau_gauge.set(float(current_mau_count)) - for app_service, count in current_mau_count_by_service.items(): - current_mau_by_service_gauge.labels(app_service).set(float(count)) + for app_service, count in current_mau_count_by_service.items(): + current_mau_by_service_gauge.labels( + app_service=app_service, **{SERVER_NAME_LABEL: server_name} + ).set(float(count)) - registered_reserved_users_mau_gauge.set(float(len(reserved_users))) - max_mau_gauge.set(float(hs.config.server.max_mau_value)) + registered_reserved_users_mau_gauge.labels( + **{SERVER_NAME_LABEL: server_name} + ).set(float(len(reserved_users))) + max_mau_gauge.labels(**{SERVER_NAME_LABEL: server_name}).set( + float(hs.config.server.max_mau_value) + ) + + return run_as_background_process( + "generate_monthly_active_users", + server_name, + _generate_monthly_active_users, + ) if hs.config.server.limit_usage_by_mau or hs.config.server.mau_stats_only: generate_monthly_active_users() diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 8c21e0951a..55069cc5d3 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -48,6 +48,7 @@ from synapse.events import EventBase from synapse.events.utils import SerializeEventConfig, serialize_event from synapse.http.client import SimpleHttpClient, 
is_unknown_endpoint from synapse.logging import opentracing +from synapse.metrics import SERVER_NAME_LABEL from synapse.types import DeviceListUpdates, JsonDict, JsonMapping, ThirdPartyInstanceID from synapse.util.caches.response_cache import ResponseCache @@ -59,29 +60,31 @@ logger = logging.getLogger(__name__) sent_transactions_counter = Counter( "synapse_appservice_api_sent_transactions", "Number of /transactions/ requests sent", - ["service"], + labelnames=["service", SERVER_NAME_LABEL], ) failed_transactions_counter = Counter( "synapse_appservice_api_failed_transactions", "Number of /transactions/ requests that failed to send", - ["service"], + labelnames=["service", SERVER_NAME_LABEL], ) sent_events_counter = Counter( - "synapse_appservice_api_sent_events", "Number of events sent to the AS", ["service"] + "synapse_appservice_api_sent_events", + "Number of events sent to the AS", + labelnames=["service", SERVER_NAME_LABEL], ) sent_ephemeral_counter = Counter( "synapse_appservice_api_sent_ephemeral", "Number of ephemeral events sent to the AS", - ["service"], + labelnames=["service", SERVER_NAME_LABEL], ) sent_todevice_counter = Counter( "synapse_appservice_api_sent_todevice", "Number of todevice messages sent to the AS", - ["service"], + labelnames=["service", SERVER_NAME_LABEL], ) HOUR_IN_MS = 60 * 60 * 1000 @@ -382,6 +385,7 @@ class ApplicationServiceApi(SimpleHttpClient): "left": list(device_list_summary.left), } + labels = {"service": service.id, SERVER_NAME_LABEL: self.server_name} try: args = None if self.config.use_appservice_legacy_authorization: @@ -399,10 +403,10 @@ class ApplicationServiceApi(SimpleHttpClient): service.url, [event.get("event_id") for event in events], ) - sent_transactions_counter.labels(service.id).inc() - sent_events_counter.labels(service.id).inc(len(serialized_events)) - sent_ephemeral_counter.labels(service.id).inc(len(ephemeral)) - sent_todevice_counter.labels(service.id).inc(len(to_device_messages)) + sent_transactions_counter.labels(**labels).inc() + sent_events_counter.labels(**labels).inc(len(serialized_events)) + sent_ephemeral_counter.labels(**labels).inc(len(ephemeral)) + sent_todevice_counter.labels(**labels).inc(len(to_device_messages)) return True except CodeMessageException as e: logger.warning( @@ -421,7 +425,7 @@ class ApplicationServiceApi(SimpleHttpClient): ex.args, exc_info=logger.isEnabledFor(logging.DEBUG), ) - failed_transactions_counter.labels(service.id).inc() + failed_transactions_counter.labels(**labels).inc() return False async def claim_client_keys( diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py index 9d7fc0995a..01f77c4cb6 100644 --- a/synapse/appservice/scheduler.py +++ b/synapse/appservice/scheduler.py @@ -103,18 +103,16 @@ MAX_TO_DEVICE_MESSAGES_PER_TRANSACTION = 100 class ApplicationServiceScheduler: - """Public facing API for this module. Does the required DI to tie the - components together. This also serves as the "event_pool", which in this + """ + Public facing API for this module. Does the required dependency injection (DI) to + tie the components together. This also serves as the "event_pool", which in this case is a simple array. 
""" def __init__(self, hs: "HomeServer"): - self.clock = hs.get_clock() + self.txn_ctrl = _TransactionController(hs) self.store = hs.get_datastores().main - self.as_api = hs.get_application_service_api() - - self.txn_ctrl = _TransactionController(self.clock, self.store, self.as_api) - self.queuer = _ServiceQueuer(self.txn_ctrl, self.clock, hs) + self.queuer = _ServiceQueuer(self.txn_ctrl, hs) async def start(self) -> None: logger.info("Starting appservice scheduler") @@ -184,9 +182,7 @@ class _ServiceQueuer: appservice at a given time. """ - def __init__( - self, txn_ctrl: "_TransactionController", clock: Clock, hs: "HomeServer" - ): + def __init__(self, txn_ctrl: "_TransactionController", hs: "HomeServer"): # dict of {service_id: [events]} self.queued_events: Dict[str, List[EventBase]] = {} # dict of {service_id: [events]} @@ -199,10 +195,11 @@ class _ServiceQueuer: # the appservices which currently have a transaction in flight self.requests_in_flight: Set[str] = set() self.txn_ctrl = txn_ctrl - self.clock = clock self._msc3202_transaction_extensions_enabled: bool = ( hs.config.experimental.msc3202_transaction_extensions ) + self.server_name = hs.hostname + self.clock = hs.get_clock() self._store = hs.get_datastores().main def start_background_request(self, service: ApplicationService) -> None: @@ -210,7 +207,9 @@ class _ServiceQueuer: if service.id in self.requests_in_flight: return - run_as_background_process("as-sender", self._send_request, service) + run_as_background_process( + "as-sender", self.server_name, self._send_request, service + ) async def _send_request(self, service: ApplicationService) -> None: # sanity-check: we shouldn't get here if this service already has a sender @@ -359,10 +358,11 @@ class _TransactionController: (Note we have only have one of these in the homeserver.) """ - def __init__(self, clock: Clock, store: DataStore, as_api: ApplicationServiceApi): - self.clock = clock - self.store = store - self.as_api = as_api + def __init__(self, hs: "HomeServer"): + self.server_name = hs.hostname + self.clock = hs.get_clock() + self.store = hs.get_datastores().main + self.as_api = hs.get_application_service_api() # map from service id to recoverer instance self.recoverers: Dict[str, "_Recoverer"] = {} @@ -446,7 +446,12 @@ class _TransactionController: logger.info("Starting recoverer for AS ID %s", service.id) assert service.id not in self.recoverers recoverer = self.RECOVERER_CLASS( - self.clock, self.store, self.as_api, service, self.on_recovered + self.server_name, + self.clock, + self.store, + self.as_api, + service, + self.on_recovered, ) self.recoverers[service.id] = recoverer recoverer.recover() @@ -477,21 +482,24 @@ class _Recoverer: We have one of these for each appservice which is currently considered DOWN. Args: - clock (synapse.util.Clock): - store (synapse.storage.DataStore): - as_api (synapse.appservice.api.ApplicationServiceApi): - service (synapse.appservice.ApplicationService): the service we are managing - callback (callable[_Recoverer]): called once the service recovers. + server_name: the homeserver name (used to label metrics) (this should be `hs.hostname`). + clock: + store: + as_api: + service: the service we are managing + callback: called once the service recovers. 
""" def __init__( self, + server_name: str, clock: Clock, store: DataStore, as_api: ApplicationServiceApi, service: ApplicationService, callback: Callable[["_Recoverer"], Awaitable[None]], ): + self.server_name = server_name self.clock = clock self.store = store self.as_api = as_api @@ -504,7 +512,11 @@ class _Recoverer: delay = 2**self.backoff_counter logger.info("Scheduling retries on %s in %fs", self.service.id, delay) self.scheduled_recovery = self.clock.call_later( - delay, run_as_background_process, "as-recoverer", self.retry + delay, + run_as_background_process, + "as-recoverer", + self.server_name, + self.retry, ) def _backoff(self) -> None: @@ -525,6 +537,7 @@ class _Recoverer: # Run a retry, which will resechedule a recovery if it fails. run_as_background_process( "retry", + self.server_name, self.retry, ) diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi index 8b065f175d..5e03635206 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi @@ -36,6 +36,7 @@ from synapse.config import ( # noqa: F401 jwt, key, logger, + mas, metrics, modules, oembed, @@ -124,6 +125,7 @@ class RootConfig: background_updates: background_updates.BackgroundUpdateConfig auto_accept_invites: auto_accept_invites.AutoAcceptInvitesConfig user_types: user_types.UserTypesConfig + mas: mas.MasConfig config_classes: List[Type["Config"]] = ... config_files: List[str] diff --git a/synapse/config/auth.py b/synapse/config/auth.py index 9246fd6430..31b332dc09 100644 --- a/synapse/config/auth.py +++ b/synapse/config/auth.py @@ -36,13 +36,14 @@ class AuthConfig(Config): if password_config is None: password_config = {} - # The default value of password_config.enabled is True, unless msc3861 is enabled. - msc3861_enabled = ( - (config.get("experimental_features") or {}) - .get("msc3861", {}) - .get("enabled", False) - ) - passwords_enabled = password_config.get("enabled", not msc3861_enabled) + auth_delegated = (config.get("experimental_features") or {}).get( + "msc3861", {} + ).get("enabled", False) or ( + config.get("matrix_authentication_service") or {} + ).get("enabled", False) + + # The default value of password_config.enabled is True, unless auth is delegated + passwords_enabled = password_config.get("enabled", not auth_delegated) # 'only_for_reauth' allows users who have previously set a password to use it, # even though passwords would otherwise be disabled. 
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 1b7474034f..c4181a6b0b 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -582,6 +582,9 @@ class ExperimentalConfig(Config): # MSC4155: Invite filtering self.msc4155_enabled: bool = experimental.get("msc4155_enabled", False) + # MSC4293: Redact on Kick/Ban + self.msc4293_enabled: bool = experimental.get("msc4293_enabled", False) + # MSC4306: Thread Subscriptions # (and MSC4308: sliding sync extension for thread subscriptions) self.msc4306_enabled: bool = experimental.get("msc4306_enabled", False) diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index 0b2413a83b..5d7089c2e6 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -36,6 +36,7 @@ from .federation import FederationConfig from .jwt import JWTConfig from .key import KeyConfig from .logger import LoggingConfig +from .mas import MasConfig from .metrics import MetricsConfig from .modules import ModulesConfig from .oembed import OembedConfig @@ -109,4 +110,6 @@ class HomeServerConfig(RootConfig): BackgroundUpdateConfig, AutoAcceptInvitesConfig, UserTypesConfig, + # This must be last, as it checks for conflicts with other config options. + MasConfig, ] diff --git a/synapse/config/mas.py b/synapse/config/mas.py new file mode 100644 index 0000000000..fe0d326f7a --- /dev/null +++ b/synapse/config/mas.py @@ -0,0 +1,192 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2025 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# . +# +# + +from typing import Any, Optional + +from synapse._pydantic_compat import ( + AnyHttpUrl, + Field, + FilePath, + StrictBool, + StrictStr, + ValidationError, + validator, +) +from synapse.config.experimental import read_secret_from_file_once +from synapse.types import JsonDict +from synapse.util.pydantic_models import ParseModel + +from ._base import Config, ConfigError, RootConfig + + +class MasConfigModel(ParseModel): + enabled: StrictBool = False + endpoint: AnyHttpUrl = Field(default="http://localhost:8080") + secret: Optional[StrictStr] = Field(default=None) + secret_path: Optional[FilePath] = Field(default=None) + + @validator("secret") + def validate_secret_is_set_if_enabled(cls, v: Any, values: dict) -> Any: + if values.get("enabled", False) and not values.get("secret_path") and not v: + raise ValueError( + "You must set a `secret` or `secret_path` when enabling Matrix Authentication Service integration." + ) + + return v + + @validator("secret_path") + def validate_secret_path_is_set_if_enabled(cls, v: Any, values: dict) -> Any: + if values.get("secret"): + raise ValueError( + "`secret` and `secret_path` cannot be set at the same time." 
+ ) + + return v + + +class MasConfig(Config): + section = "mas" + + def read_config( + self, config: JsonDict, allow_secrets_in_config: bool, **kwargs: Any + ) -> None: + mas_config = config.get("matrix_authentication_service", {}) + if mas_config is None: + mas_config = {} + + try: + parsed = MasConfigModel(**mas_config) + except ValidationError as e: + raise ConfigError( + "Could not validate Matrix Authentication Service configuration", + path=("matrix_authentication_service",), + ) from e + + if parsed.secret and not allow_secrets_in_config: + raise ConfigError( + "Config options that expect an in-line secret as value are disabled", + ("matrix_authentication_service", "secret"), + ) + + self.enabled = parsed.enabled + self.endpoint = parsed.endpoint + self._secret = parsed.secret + self._secret_path = parsed.secret_path + + self.check_config_conflicts(self.root) + + def check_config_conflicts( + self, + root: RootConfig, + ) -> None: + """Checks for any configuration conflicts with other parts of Synapse. + + Raises: + ConfigError: If there are any configuration conflicts. + """ + + if not self.enabled: + return + + if root.experimental.msc3861.enabled: + raise ConfigError( + "Experimental MSC3861 was replaced by Matrix Authentication Service." + "Please disable MSC3861 or disable Matrix Authentication Service.", + ("experimental", "msc3861"), + ) + + if ( + root.auth.password_enabled_for_reauth + or root.auth.password_enabled_for_login + ): + raise ConfigError( + "Password auth cannot be enabled when OAuth delegation is enabled", + ("password_config", "enabled"), + ) + + if root.registration.enable_registration: + raise ConfigError( + "Registration cannot be enabled when OAuth delegation is enabled", + ("enable_registration",), + ) + + # We only need to test the user consent version, as if it must be set if the user_consent section was present in the config + if root.consent.user_consent_version is not None: + raise ConfigError( + "User consent cannot be enabled when OAuth delegation is enabled", + ("user_consent",), + ) + + if ( + root.oidc.oidc_enabled + or root.saml2.saml2_enabled + or root.cas.cas_enabled + or root.jwt.jwt_enabled + ): + raise ConfigError("SSO cannot be enabled when OAuth delegation is enabled") + + if bool(root.authproviders.password_providers): + raise ConfigError( + "Password auth providers cannot be enabled when OAuth delegation is enabled" + ) + + if root.captcha.enable_registration_captcha: + raise ConfigError( + "CAPTCHA cannot be enabled when OAuth delegation is enabled", + ("captcha", "enable_registration_captcha"), + ) + + if root.auth.login_via_existing_enabled: + raise ConfigError( + "Login via existing session cannot be enabled when OAuth delegation is enabled", + ("login_via_existing_session", "enabled"), + ) + + if root.registration.refresh_token_lifetime: + raise ConfigError( + "refresh_token_lifetime cannot be set when OAuth delegation is enabled", + ("refresh_token_lifetime",), + ) + + if root.registration.nonrefreshable_access_token_lifetime: + raise ConfigError( + "nonrefreshable_access_token_lifetime cannot be set when OAuth delegation is enabled", + ("nonrefreshable_access_token_lifetime",), + ) + + if root.registration.session_lifetime: + raise ConfigError( + "session_lifetime cannot be set when OAuth delegation is enabled", + ("session_lifetime",), + ) + + if root.registration.enable_3pid_changes: + raise ConfigError( + "enable_3pid_changes cannot be enabled when OAuth delegation is enabled", + ("enable_3pid_changes",), + ) + + def 
secret(self) -> str: + if self._secret is not None: + return self._secret + elif self._secret_path is not None: + return read_secret_from_file_once( + str(self._secret_path), + ("matrix_authentication_service", "secret_path"), + ) + else: + raise RuntimeError( + "Neither `secret` nor `secret_path` are set, this is a bug.", + ) diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index 290701615f..b082daa8f7 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -241,6 +241,12 @@ class RatelimitConfig(Config): defaults={"per_second": 1, "burst_count": 5}, ) + self.rc_room_creation = RatelimitSettings.parse( + config, + "rc_room_creation", + defaults={"per_second": 0.016, "burst_count": 10}, + ) + self.rc_reports = RatelimitSettings.parse( config, "rc_reports", diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 8adf21079e..283199aa11 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -148,15 +148,14 @@ class RegistrationConfig(Config): self.enable_set_displayname = config.get("enable_set_displayname", True) self.enable_set_avatar_url = config.get("enable_set_avatar_url", True) + auth_delegated = (config.get("experimental_features") or {}).get( + "msc3861", {} + ).get("enabled", False) or ( + config.get("matrix_authentication_service") or {} + ).get("enabled", False) + # The default value of enable_3pid_changes is True, unless msc3861 is enabled. - msc3861_enabled = ( - (config.get("experimental_features") or {}) - .get("msc3861", {}) - .get("enabled", False) - ) - self.enable_3pid_changes = config.get( - "enable_3pid_changes", not msc3861_enabled - ) + self.enable_3pid_changes = config.get("enable_3pid_changes", not auth_delegated) self.disable_msisdn_registration = config.get( "disable_msisdn_registration", False diff --git a/synapse/config/repository.py b/synapse/config/repository.py index e6a5064c16..efdc505659 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -22,11 +22,10 @@ import logging import os from typing import Any, Dict, List, Tuple -from urllib.request import getproxies_environment import attr -from synapse.config.server import generate_ip_set +from synapse.config.server import generate_ip_set, parse_proxy_config from synapse.types import JsonDict from synapse.util.check_dependencies import check_requirements from synapse.util.module_loader import load_module @@ -61,7 +60,7 @@ THUMBNAIL_SUPPORTED_MEDIA_FORMAT_MAP = { "image/png": "png", } -HTTP_PROXY_SET_WARNING = """\ +URL_PREVIEW_BLACKLIST_IGNORED_BECAUSE_HTTP_PROXY_SET_WARNING = """\ The Synapse config url_preview_ip_range_blacklist will be ignored as an HTTP(s) proxy is configured.""" @@ -234,17 +233,25 @@ class ContentRepositoryConfig(Config): if self.url_preview_enabled: check_requirements("url-preview") - proxy_env = getproxies_environment() - if "url_preview_ip_range_blacklist" not in config: - if "http" not in proxy_env or "https" not in proxy_env: + proxy_config = parse_proxy_config(config) + is_proxy_configured = ( + proxy_config.http_proxy is not None + or proxy_config.https_proxy is not None + ) + if "url_preview_ip_range_blacklist" in config: + if is_proxy_configured: + logger.warning( + "".join( + URL_PREVIEW_BLACKLIST_IGNORED_BECAUSE_HTTP_PROXY_SET_WARNING + ) + ) + else: + if not is_proxy_configured: raise ConfigError( "For security, you must specify an explicit target IP address " "blacklist in url_preview_ip_range_blacklist for url previewing " "to work" ) - else: 
- if "http" in proxy_env or "https" in proxy_env: - logger.warning("".join(HTTP_PROXY_SET_WARNING)) # we always block '0.0.0.0' and '::', which are supposed to be # unroutable addresses. diff --git a/synapse/config/server.py b/synapse/config/server.py index 6893450989..e15bceb296 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -25,11 +25,13 @@ import logging import os.path import urllib.parse from textwrap import indent -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, TypedDict, Union +from urllib.request import getproxies_environment import attr import yaml from netaddr import AddrFormatError, IPNetwork, IPSet +from typing_extensions import TypeGuard from twisted.conch.ssh.keys import Key @@ -43,6 +45,21 @@ from ._util import validate_config logger = logging.getLogger(__name__) + +# Directly from the mypy docs: +# https://typing.python.org/en/latest/spec/narrowing.html#typeguard +def is_str_list(val: Any, allow_empty: bool) -> TypeGuard[list[str]]: + """ + Type-narrow a value to a list of strings (compatible with mypy). + """ + if not isinstance(val, list): + return False + + if len(val) == 0: + return allow_empty + return all(isinstance(x, str) for x in val) + + DIRECT_TCP_ERROR = """ Using direct TCP replication for workers is no longer supported. @@ -291,6 +308,102 @@ class LimitRemoteRoomsConfig: ) +class ProxyConfigDictionary(TypedDict): + """ + Dictionary of proxy settings suitable for interacting with `urllib.request` API's + """ + + http: Optional[str] + """ + Proxy server to use for HTTP requests. + """ + https: Optional[str] + """ + Proxy server to use for HTTPS requests. + """ + no: str + """ + Comma-separated list of hosts, IP addresses, or IP ranges in CIDR format which + should not use the proxy. + + Empty string means no hosts should be excluded from the proxy. + """ + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class ProxyConfig: + """ + Synapse configuration for HTTP proxy settings. + """ + + http_proxy: Optional[str] + """ + Proxy server to use for HTTP requests. + """ + https_proxy: Optional[str] + """ + Proxy server to use for HTTPS requests. + """ + no_proxy_hosts: Optional[List[str]] + """ + List of hosts, IP addresses, or IP ranges in CIDR format which should not use the + proxy. Synapse will directly connect to these hosts. + """ + + def get_proxies_dictionary(self) -> ProxyConfigDictionary: + """ + Returns a dictionary of proxy settings suitable for interacting with + `urllib.request` API's (e.g. `urllib.request.proxy_bypass_environment`) + + The keys are `"http"`, `"https"`, and `"no"`. + """ + return ProxyConfigDictionary( + http=self.http_proxy, + https=self.https_proxy, + no=",".join(self.no_proxy_hosts) if self.no_proxy_hosts else "", + ) + + +def parse_proxy_config(config: JsonDict) -> ProxyConfig: + """ + Figure out forward proxy config for outgoing HTTP requests. + + Prefer values from the given config over the environment variables (`http_proxy`, + `https_proxy`, `no_proxy`, not case-sensitive). + + Args: + config: The top-level homeserver configuration dictionary. 
+ """ + proxies_from_env = getproxies_environment() + http_proxy = config.get("http_proxy", proxies_from_env.get("http")) + if http_proxy is not None and not isinstance(http_proxy, str): + raise ConfigError("'http_proxy' must be a string", ("http_proxy",)) + + https_proxy = config.get("https_proxy", proxies_from_env.get("https")) + if https_proxy is not None and not isinstance(https_proxy, str): + raise ConfigError("'https_proxy' must be a string", ("https_proxy",)) + + # List of hosts which should not use the proxy. Synapse will directly connect to + # these hosts. + no_proxy_hosts = config.get("no_proxy_hosts") + # The `no_proxy` environment variable should be a comma-separated list of hosts, + # IP addresses, or IP ranges in CIDR format + no_proxy_from_env = proxies_from_env.get("no") + if no_proxy_hosts is None and no_proxy_from_env is not None: + no_proxy_hosts = no_proxy_from_env.split(",") + + if no_proxy_hosts is not None and not is_str_list(no_proxy_hosts, allow_empty=True): + raise ConfigError( + "'no_proxy_hosts' must be a list of strings", ("no_proxy_hosts",) + ) + + return ProxyConfig( + http_proxy=http_proxy, + https_proxy=https_proxy, + no_proxy_hosts=no_proxy_hosts, + ) + + class ServerConfig(Config): section = "server" @@ -718,6 +831,17 @@ class ServerConfig(Config): ) ) + # Figure out forward proxy config for outgoing HTTP requests. + # + # Prefer values from the file config over the environment variables + self.proxy_config = parse_proxy_config(config) + logger.debug( + "Using proxy settings: http_proxy=%s, https_proxy=%s, no_proxy=%s", + self.proxy_config.http_proxy, + self.proxy_config.https_proxy, + self.proxy_config.no_proxy_hosts, + ) + self.cleanup_extremities_with_dummy_events = config.get( "cleanup_extremities_with_dummy_events", True ) diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 643d2d4e66..8c59772e56 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -152,6 +152,8 @@ class Keyring: def __init__( self, hs: "HomeServer", key_fetchers: "Optional[Iterable[KeyFetcher]]" = None ): + self.server_name = hs.hostname + if key_fetchers is None: # Always fetch keys from the database. mutable_key_fetchers: List[KeyFetcher] = [StoreKeyFetcher(hs)] @@ -169,7 +171,8 @@ class Keyring: self._fetch_keys_queue: BatchingQueue[ _FetchKeyRequest, Dict[str, Dict[str, FetchKeyResult]] ] = BatchingQueue( - "keyring_server", + name="keyring_server", + server_name=self.server_name, clock=hs.get_clock(), # The method called to fetch each key process_batch_callback=self._inner_fetch_key_requests, @@ -473,8 +476,12 @@ class Keyring: class KeyFetcher(metaclass=abc.ABCMeta): def __init__(self, hs: "HomeServer"): + self.server_name = hs.hostname self._queue = BatchingQueue( - self.__class__.__name__, hs.get_clock(), self._fetch_keys + name=self.__class__.__name__, + server_name=self.server_name, + clock=hs.get_clock(), + process_batch_callback=self._fetch_keys, ) async def get_keys( diff --git a/synapse/events/auto_accept_invites.py b/synapse/events/auto_accept_invites.py index 9e5f76f33f..6873ee9d31 100644 --- a/synapse/events/auto_accept_invites.py +++ b/synapse/events/auto_accept_invites.py @@ -34,6 +34,7 @@ class InviteAutoAccepter: def __init__(self, config: AutoAcceptInvitesConfig, api: ModuleApi): # Keep a reference to the Module API. 
self._api = api + self.server_name = api.server_name self._config = config if not self._config.enabled: diff --git a/synapse/events/utils.py b/synapse/events/utils.py index a6c07fcfd7..cae27136ce 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -545,8 +545,11 @@ def serialize_event( d["content"] = dict(d["content"]) d["content"]["redacts"] = e.redacts - if config.include_admin_metadata and e.internal_metadata.is_soft_failed(): - d["unsigned"]["io.element.synapse.soft_failed"] = True + if config.include_admin_metadata: + if e.internal_metadata.is_soft_failed(): + d["unsigned"]["io.element.synapse.soft_failed"] = True + if e.internal_metadata.policy_server_spammy: + d["unsigned"]["io.element.synapse.policy_server_spammy"] = True only_event_fields = config.only_event_fields if only_event_fields: diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index 05c7809dc8..a1c9c286ac 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -174,6 +174,7 @@ class FederationBase: "Event not allowed by policy server, soft-failing %s", pdu.event_id ) pdu.internal_metadata.soft_failed = True + pdu.internal_metadata.policy_server_spammy = True # Note: we don't redact the event so admins can inspect the event after the # fact. Other processes may redact the event, but that won't be applied to # the database copy of the event until the server's config requires it. diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 35c5ac6311..542d9650d4 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -74,6 +74,7 @@ from synapse.federation.transport.client import SendJoinResponse from synapse.http.client import is_unknown_endpoint from synapse.http.types import QueryParams from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, tag_args, trace +from synapse.metrics import SERVER_NAME_LABEL from synapse.types import JsonDict, StrCollection, UserID, get_domain_from_id from synapse.types.handlers.policy_server import RECOMMENDATION_OK, RECOMMENDATION_SPAM from synapse.util.async_helpers import concurrently_execute @@ -85,7 +86,9 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -sent_queries_counter = Counter("synapse_federation_client_sent_queries", "", ["type"]) +sent_queries_counter = Counter( + "synapse_federation_client_sent_queries", "", labelnames=["type", SERVER_NAME_LABEL] +) PDU_RETRY_TIME_MS = 1 * 60 * 1000 @@ -209,7 +212,10 @@ class FederationClient(FederationBase): Returns: The JSON object from the response """ - sent_queries_counter.labels(query_type).inc() + sent_queries_counter.labels( + type=query_type, + **{SERVER_NAME_LABEL: self.server_name}, + ).inc() return await self.transport_layer.make_query( destination, @@ -231,7 +237,10 @@ class FederationClient(FederationBase): Returns: The JSON object from the response """ - sent_queries_counter.labels("client_device_keys").inc() + sent_queries_counter.labels( + type="client_device_keys", + **{SERVER_NAME_LABEL: self.server_name}, + ).inc() return await self.transport_layer.query_client_keys( destination, content, timeout ) @@ -242,7 +251,10 @@ class FederationClient(FederationBase): """Query the device keys for a list of user ids hosted on a remote server. 
""" - sent_queries_counter.labels("user_devices").inc() + sent_queries_counter.labels( + type="user_devices", + **{SERVER_NAME_LABEL: self.server_name}, + ).inc() return await self.transport_layer.query_user_devices( destination, user_id, timeout ) @@ -264,7 +276,10 @@ class FederationClient(FederationBase): Returns: The JSON object from the response """ - sent_queries_counter.labels("client_one_time_keys").inc() + sent_queries_counter.labels( + type="client_one_time_keys", + **{SERVER_NAME_LABEL: self.server_name}, + ).inc() # Convert the query with counts into a stable and unstable query and check # if attempting to claim more than 1 OTK. diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 3e6b8b8493..127518e1f7 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -82,6 +82,7 @@ from synapse.logging.opentracing import ( tag_args, trace, ) +from synapse.metrics import SERVER_NAME_LABEL from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.replication.http.federation import ( ReplicationFederationSendEduRestServlet, @@ -104,23 +105,30 @@ TRANSACTION_CONCURRENCY_LIMIT = 10 logger = logging.getLogger(__name__) -received_pdus_counter = Counter("synapse_federation_server_received_pdus", "") +received_pdus_counter = Counter( + "synapse_federation_server_received_pdus", "", labelnames=[SERVER_NAME_LABEL] +) -received_edus_counter = Counter("synapse_federation_server_received_edus", "") +received_edus_counter = Counter( + "synapse_federation_server_received_edus", "", labelnames=[SERVER_NAME_LABEL] +) received_queries_counter = Counter( - "synapse_federation_server_received_queries", "", ["type"] + "synapse_federation_server_received_queries", + "", + labelnames=["type", SERVER_NAME_LABEL], ) pdu_process_time = Histogram( "synapse_federation_server_pdu_process_time", "Time taken to process an event", + labelnames=[SERVER_NAME_LABEL], ) last_pdu_ts_metric = Gauge( "synapse_federation_last_received_pdu_time", "The timestamp of the last PDU which was successfully received from the given domain", - labelnames=("server_name",), + labelnames=("origin_server_name", SERVER_NAME_LABEL), ) @@ -434,7 +442,9 @@ class FederationServer(FederationBase): report back to the sending server. 
""" - received_pdus_counter.inc(len(transaction.pdus)) + received_pdus_counter.labels(**{SERVER_NAME_LABEL: self.server_name}).inc( + len(transaction.pdus) + ) origin_host, _ = parse_server_name(origin) @@ -545,7 +555,9 @@ class FederationServer(FederationBase): ) if newest_pdu_ts and origin in self._federation_metrics_domains: - last_pdu_ts_metric.labels(server_name=origin).set(newest_pdu_ts / 1000) + last_pdu_ts_metric.labels( + origin_server_name=origin, **{SERVER_NAME_LABEL: self.server_name} + ).set(newest_pdu_ts / 1000) return pdu_results @@ -553,7 +565,7 @@ class FederationServer(FederationBase): """Process the EDUs in a received transaction.""" async def _process_edu(edu_dict: JsonDict) -> None: - received_edus_counter.inc() + received_edus_counter.labels(**{SERVER_NAME_LABEL: self.server_name}).inc() edu = Edu( origin=origin, @@ -668,7 +680,10 @@ class FederationServer(FederationBase): async def on_query_request( self, query_type: str, args: Dict[str, str] ) -> Tuple[int, Dict[str, Any]]: - received_queries_counter.labels(query_type).inc() + received_queries_counter.labels( + type=query_type, + **{SERVER_NAME_LABEL: self.server_name}, + ).inc() resp = await self.registry.on_query(query_type, args) return 200, resp @@ -1310,9 +1325,9 @@ class FederationServer(FederationBase): origin, event.event_id ) if received_ts is not None: - pdu_process_time.observe( - (self._clock.time_msec() - received_ts) / 1000 - ) + pdu_process_time.labels( + **{SERVER_NAME_LABEL: self.server_name} + ).observe((self._clock.time_msec() - received_ts) / 1000) next = await self._get_next_nonspam_staged_event_for_room( room_id, room_version diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index e309836a52..7f511d570c 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -54,7 +54,7 @@ from sortedcontainers import SortedDict from synapse.api.presence import UserPresenceState from synapse.federation.sender import AbstractFederationSender, FederationSender -from synapse.metrics import LaterGauge +from synapse.metrics import SERVER_NAME_LABEL, LaterGauge from synapse.replication.tcp.streams.federation import FederationStream from synapse.types import JsonDict, ReadReceipt, RoomStreamToken, StrCollection from synapse.util.metrics import Measure @@ -113,10 +113,10 @@ class FederationRemoteSendQueue(AbstractFederationSender): # changes. ARGH. 
def register(name: str, queue: Sized) -> None: LaterGauge( - "synapse_federation_send_queue_%s_size" % (queue_name,), - "", - [], - lambda: len(queue), + name="synapse_federation_send_queue_%s_size" % (queue_name,), + desc="", + labelnames=[SERVER_NAME_LABEL], + caller=lambda: {(self.server_name,): len(queue)}, ) for queue_name in [ diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 8010cc62f3..8befbe3722 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -160,6 +160,7 @@ from synapse.federation.sender.transaction_manager import TransactionManager from synapse.federation.units import Edu from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics import ( + SERVER_NAME_LABEL, LaterGauge, event_processing_loop_counter, event_processing_loop_room_count, @@ -189,11 +190,13 @@ logger = logging.getLogger(__name__) sent_pdus_destination_dist_count = Counter( "synapse_federation_client_sent_pdu_destinations_count", "Number of PDUs queued for sending to one or more destinations", + labelnames=[SERVER_NAME_LABEL], ) sent_pdus_destination_dist_total = Counter( "synapse_federation_client_sent_pdu_destinations", "Total number of PDUs queued for sending across all destinations", + labelnames=[SERVER_NAME_LABEL], ) # Time (in s) to wait before trying to wake up destinations that have @@ -296,6 +299,7 @@ class _DestinationWakeupQueue: Staggers waking up of per destination queues to ensure that we don't attempt to start TLS connections with many hosts all at once, leading to pinned CPU. + """ # The maximum duration in seconds between queuing up a destination and it @@ -303,6 +307,10 @@ class _DestinationWakeupQueue: _MAX_TIME_IN_QUEUE = 30.0 sender: "FederationSender" = attr.ib() + server_name: str = attr.ib() + """ + Our homeserver name (used to label metrics) (`hs.hostname`). 
+ """ clock: Clock = attr.ib() max_delay_s: int = attr.ib() @@ -391,31 +399,37 @@ class FederationSender(AbstractFederationSender): self._per_destination_queues: Dict[str, PerDestinationQueue] = {} LaterGauge( - "synapse_federation_transaction_queue_pending_destinations", - "", - [], - lambda: sum( - 1 - for d in self._per_destination_queues.values() - if d.transmission_loop_running - ), + name="synapse_federation_transaction_queue_pending_destinations", + desc="", + labelnames=[SERVER_NAME_LABEL], + caller=lambda: { + (self.server_name,): sum( + 1 + for d in self._per_destination_queues.values() + if d.transmission_loop_running + ) + }, ) LaterGauge( - "synapse_federation_transaction_queue_pending_pdus", - "", - [], - lambda: sum( - d.pending_pdu_count() for d in self._per_destination_queues.values() - ), + name="synapse_federation_transaction_queue_pending_pdus", + desc="", + labelnames=[SERVER_NAME_LABEL], + caller=lambda: { + (self.server_name,): sum( + d.pending_pdu_count() for d in self._per_destination_queues.values() + ) + }, ) LaterGauge( - "synapse_federation_transaction_queue_pending_edus", - "", - [], - lambda: sum( - d.pending_edu_count() for d in self._per_destination_queues.values() - ), + name="synapse_federation_transaction_queue_pending_edus", + desc="", + labelnames=[SERVER_NAME_LABEL], + caller=lambda: { + (self.server_name,): sum( + d.pending_edu_count() for d in self._per_destination_queues.values() + ) + }, ) self._is_processing = False @@ -427,7 +441,7 @@ class FederationSender(AbstractFederationSender): 1.0 / hs.config.ratelimiting.federation_rr_transactions_per_room_per_second ) self._destination_wakeup_queue = _DestinationWakeupQueue( - self, self.clock, max_delay_s=rr_txn_interval_per_room_s + self, self.server_name, self.clock, max_delay_s=rr_txn_interval_per_room_s ) # Regularly wake up destinations that have outstanding PDUs to be caught up @@ -435,6 +449,7 @@ class FederationSender(AbstractFederationSender): run_as_background_process, WAKEUP_RETRY_PERIOD_SEC * 1000.0, "wake_destinations_needing_catchup", + self.server_name, self._wake_destinations_needing_catchup, ) @@ -477,7 +492,9 @@ class FederationSender(AbstractFederationSender): # fire off a processing loop in the background run_as_background_process( - "process_event_queue_for_federation", self._process_event_queue_loop + "process_event_queue_for_federation", + self.server_name, + self._process_event_queue_loop, ) async def _process_event_queue_loop(self) -> None: @@ -650,7 +667,8 @@ class FederationSender(AbstractFederationSender): ts = event_to_received_ts[event.event_id] assert ts is not None synapse.metrics.event_processing_lag_by_event.labels( - "federation_sender" + name="federation_sender", + **{SERVER_NAME_LABEL: self.server_name}, ).observe((now - ts) / 1000) async def handle_room_events(events: List[EventBase]) -> None: @@ -694,22 +712,30 @@ class FederationSender(AbstractFederationSender): assert ts is not None synapse.metrics.event_processing_lag.labels( - "federation_sender" + name="federation_sender", + **{SERVER_NAME_LABEL: self.server_name}, ).set(now - ts) synapse.metrics.event_processing_last_ts.labels( - "federation_sender" + name="federation_sender", + **{SERVER_NAME_LABEL: self.server_name}, ).set(ts) - events_processed_counter.inc(len(event_entries)) + events_processed_counter.labels( + **{SERVER_NAME_LABEL: self.server_name} + ).inc(len(event_entries)) - event_processing_loop_room_count.labels("federation_sender").inc( - len(events_by_room) - ) + 
event_processing_loop_room_count.labels( + name="federation_sender", + **{SERVER_NAME_LABEL: self.server_name}, + ).inc(len(events_by_room)) - event_processing_loop_counter.labels("federation_sender").inc() + event_processing_loop_counter.labels( + name="federation_sender", + **{SERVER_NAME_LABEL: self.server_name}, + ).inc() synapse.metrics.event_processing_positions.labels( - "federation_sender" + name="federation_sender", **{SERVER_NAME_LABEL: self.server_name} ).set(next_token) finally: @@ -727,8 +753,12 @@ class FederationSender(AbstractFederationSender): if not destinations: return - sent_pdus_destination_dist_total.inc(len(destinations)) - sent_pdus_destination_dist_count.inc() + sent_pdus_destination_dist_total.labels( + **{SERVER_NAME_LABEL: self.server_name} + ).inc(len(destinations)) + sent_pdus_destination_dist_count.labels( + **{SERVER_NAME_LABEL: self.server_name} + ).inc() assert pdu.internal_metadata.stream_ordering diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index 8d6c77faee..4c844d403a 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -40,7 +40,7 @@ from synapse.federation.units import Edu from synapse.handlers.presence import format_user_presence_state from synapse.logging import issue9533_logger from synapse.logging.opentracing import SynapseTags, set_tag -from synapse.metrics import sent_transactions_counter +from synapse.metrics import SERVER_NAME_LABEL, sent_transactions_counter from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import JsonDict, ReadReceipt from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter @@ -56,13 +56,15 @@ logger = logging.getLogger(__name__) sent_edus_counter = Counter( - "synapse_federation_client_sent_edus", "Total number of EDUs successfully sent" + "synapse_federation_client_sent_edus", + "Total number of EDUs successfully sent", + labelnames=[SERVER_NAME_LABEL], ) sent_edus_by_type = Counter( "synapse_federation_client_sent_edus_by_type", "Number of sent EDUs successfully sent, by event type", - ["type"], + labelnames=["type", SERVER_NAME_LABEL], ) @@ -91,7 +93,7 @@ class PerDestinationQueue: transaction_manager: "synapse.federation.sender.TransactionManager", destination: str, ): - self._server_name = hs.hostname + self.server_name = hs.hostname self._clock = hs.get_clock() self._storage_controllers = hs.get_storage_controllers() self._store = hs.get_datastores().main @@ -311,6 +313,7 @@ class PerDestinationQueue: run_as_background_process( "federation_transaction_transmission_loop", + self.server_name, self._transaction_transmission_loop, ) @@ -322,7 +325,12 @@ class PerDestinationQueue: # This will throw if we wouldn't retry. We do this here so we fail # quickly, but we will later check this again in the http client, # hence why we throw the result away. 
- await get_retry_limiter(self._destination, self._clock, self._store) + await get_retry_limiter( + destination=self._destination, + our_server_name=self.server_name, + clock=self._clock, + store=self._store, + ) if self._catching_up: # we potentially need to catch-up first @@ -362,10 +370,17 @@ class PerDestinationQueue: self._destination, pending_pdus, pending_edus ) - sent_transactions_counter.inc() - sent_edus_counter.inc(len(pending_edus)) + sent_transactions_counter.labels( + **{SERVER_NAME_LABEL: self.server_name} + ).inc() + sent_edus_counter.labels( + **{SERVER_NAME_LABEL: self.server_name} + ).inc(len(pending_edus)) for edu in pending_edus: - sent_edus_by_type.labels(edu.edu_type).inc() + sent_edus_by_type.labels( + type=edu.edu_type, + **{SERVER_NAME_LABEL: self.server_name}, + ).inc() except NotRetryingDestination as e: logger.debug( @@ -566,7 +581,7 @@ class PerDestinationQueue: new_pdus = await filter_events_for_server( self._storage_controllers, self._destination, - self._server_name, + self.server_name, new_pdus, redact=False, filter_out_erased_senders=True, @@ -590,7 +605,9 @@ class PerDestinationQueue: self._destination, room_catchup_pdus, [] ) - sent_transactions_counter.inc() + sent_transactions_counter.labels( + **{SERVER_NAME_LABEL: self.server_name} + ).inc() # We pulled this from the DB, so it'll be non-null assert pdu.internal_metadata.stream_ordering @@ -613,7 +630,7 @@ class PerDestinationQueue: # Send at most limit EDUs for receipts. for content in self._pending_receipt_edus[:limit]: yield Edu( - origin=self._server_name, + origin=self.server_name, destination=self._destination, edu_type=EduTypes.RECEIPT, content=content, @@ -639,7 +656,7 @@ class PerDestinationQueue: ) edus = [ Edu( - origin=self._server_name, + origin=self.server_name, destination=self._destination, edu_type=edu_type, content=content, @@ -666,7 +683,7 @@ class PerDestinationQueue: edus = [ Edu( - origin=self._server_name, + origin=self.server_name, destination=self._destination, edu_type=EduTypes.DIRECT_TO_DEVICE, content=content, @@ -739,7 +756,7 @@ class _TransactionQueueManager: pending_edus.append( Edu( - origin=self.queue._server_name, + origin=self.queue.server_name, destination=self.queue._destination, edu_type=EduTypes.PRESENCE, content={"push": presence_to_add}, diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py index 21e2fed085..63ed13c6fa 100644 --- a/synapse/federation/sender/transaction_manager.py +++ b/synapse/federation/sender/transaction_manager.py @@ -34,6 +34,7 @@ from synapse.logging.opentracing import ( tags, whitelisted_homeserver, ) +from synapse.metrics import SERVER_NAME_LABEL from synapse.types import JsonDict from synapse.util import json_decoder from synapse.util.metrics import measure_func @@ -47,7 +48,7 @@ issue_8631_logger = logging.getLogger("synapse.8631_debug") last_pdu_ts_metric = Gauge( "synapse_federation_last_sent_pdu_time", "The timestamp of the last PDU which was successfully sent to the given domain", - labelnames=("server_name",), + labelnames=("destination_server_name", SERVER_NAME_LABEL), ) @@ -191,6 +192,7 @@ class TransactionManager: if pdus and destination in self._federation_metrics_domains: last_pdu = pdus[-1] - last_pdu_ts_metric.labels(server_name=destination).set( - last_pdu.origin_server_ts / 1000 - ) + last_pdu_ts_metric.labels( + destination_server_name=destination, + **{SERVER_NAME_LABEL: self.server_name}, + ).set(last_pdu.origin_server_ts / 1000) diff --git 
a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py index 7004d95a0f..39a22b8cbb 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py @@ -38,6 +38,9 @@ logger = logging.getLogger(__name__) class AccountValidityHandler: def __init__(self, hs: "HomeServer"): self.hs = hs + self.server_name = ( + hs.hostname + ) # nb must be called this for @wrap_as_background_process self.config = hs.config self.store = hs.get_datastores().main self.send_email_handler = hs.get_send_email_handler() diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 93224d0c1b..5bd239e5fe 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -42,6 +42,7 @@ from synapse.events import EventBase from synapse.handlers.presence import format_user_presence_state from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics import ( + SERVER_NAME_LABEL, event_processing_loop_counter, event_processing_loop_room_count, ) @@ -68,12 +69,16 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -events_processed_counter = Counter("synapse_handlers_appservice_events_processed", "") +events_processed_counter = Counter( + "synapse_handlers_appservice_events_processed", "", labelnames=[SERVER_NAME_LABEL] +) class ApplicationServicesHandler: def __init__(self, hs: "HomeServer"): - self.server_name = hs.hostname + self.server_name = ( + hs.hostname + ) # nb must be called this for @wrap_as_background_process self.store = hs.get_datastores().main self.is_mine_id = hs.is_mine_id self.appservice_api = hs.get_application_service_api() @@ -166,7 +171,9 @@ class ApplicationServicesHandler: except Exception: logger.error("Application Services Failure") - run_as_background_process("as_scheduler", start_scheduler) + run_as_background_process( + "as_scheduler", self.server_name, start_scheduler + ) self.started_scheduler = True # Fork off pushes to these services @@ -180,7 +187,8 @@ class ApplicationServicesHandler: assert ts is not None synapse.metrics.event_processing_lag_by_event.labels( - "appservice_sender" + name="appservice_sender", + **{SERVER_NAME_LABEL: self.server_name}, ).observe((now - ts) / 1000) async def handle_room_events(events: Iterable[EventBase]) -> None: @@ -200,16 +208,23 @@ class ApplicationServicesHandler: await self.store.set_appservice_last_pos(upper_bound) synapse.metrics.event_processing_positions.labels( - "appservice_sender" + name="appservice_sender", + **{SERVER_NAME_LABEL: self.server_name}, ).set(upper_bound) - events_processed_counter.inc(len(events)) + events_processed_counter.labels( + **{SERVER_NAME_LABEL: self.server_name} + ).inc(len(events)) - event_processing_loop_room_count.labels("appservice_sender").inc( - len(events_by_room) - ) + event_processing_loop_room_count.labels( + name="appservice_sender", + **{SERVER_NAME_LABEL: self.server_name}, + ).inc(len(events_by_room)) - event_processing_loop_counter.labels("appservice_sender").inc() + event_processing_loop_counter.labels( + name="appservice_sender", + **{SERVER_NAME_LABEL: self.server_name}, + ).inc() if events: now = self.clock.time_msec() @@ -217,10 +232,12 @@ class ApplicationServicesHandler: assert ts is not None synapse.metrics.event_processing_lag.labels( - "appservice_sender" + name="appservice_sender", + **{SERVER_NAME_LABEL: self.server_name}, ).set(now - ts) synapse.metrics.event_processing_last_ts.labels( - "appservice_sender" + name="appservice_sender", + **{SERVER_NAME_LABEL: 
self.server_name}, ).set(ts) finally: self.is_processing = False diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index ec29b1a34b..2d1990cce5 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -70,6 +70,7 @@ from synapse.http import get_request_user_agent from synapse.http.server import finish_request, respond_with_html from synapse.http.site import SynapseRequest from synapse.logging.context import defer_to_thread +from synapse.metrics import SERVER_NAME_LABEL from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.databases.main.registration import ( LoginTokenExpired, @@ -95,7 +96,7 @@ INVALID_USERNAME_OR_PASSWORD = "Invalid username or password" invalid_login_token_counter = Counter( "synapse_user_login_invalid_login_tokens", "Counts the number of rejected m.login.token on /login", - ["reason"], + labelnames=["reason", SERVER_NAME_LABEL], ) @@ -199,6 +200,7 @@ class AuthHandler: SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000 def __init__(self, hs: "HomeServer"): + self.server_name = hs.hostname self.store = hs.get_datastores().main self.auth = hs.get_auth() self.auth_blocking = hs.get_auth_blocking() @@ -248,6 +250,7 @@ class AuthHandler: run_as_background_process, 5 * 60 * 1000, "expire_old_sessions", + self.server_name, self._expire_old_sessions, ) @@ -272,8 +275,6 @@ class AuthHandler: hs.config.sso.sso_account_deactivated_template ) - self._server_name = hs.config.server.server_name - # cast to tuple for use with str.startswith self._whitelisted_sso_clients = tuple(hs.config.sso.sso_client_whitelist) @@ -281,7 +282,9 @@ class AuthHandler: # response. self._extra_attributes: Dict[str, SsoLoginExtraAttributes] = {} - self.msc3861_oauth_delegation_enabled = hs.config.experimental.msc3861.enabled + self._auth_delegation_enabled = ( + hs.config.mas.enabled or hs.config.experimental.msc3861.enabled + ) async def validate_user_via_ui_auth( self, @@ -332,7 +335,7 @@ class AuthHandler: LimitExceededError if the ratelimiter's failed request count for this user is too high to proceed """ - if self.msc3861_oauth_delegation_enabled: + if self._auth_delegation_enabled: raise SynapseError( HTTPStatus.INTERNAL_SERVER_ERROR, "UIA shouldn't be used with MSC3861" ) @@ -1479,11 +1482,20 @@ class AuthHandler: try: return await self.store.consume_login_token(login_token) except LoginTokenExpired: - invalid_login_token_counter.labels("expired").inc() + invalid_login_token_counter.labels( + reason="expired", + **{SERVER_NAME_LABEL: self.server_name}, + ).inc() except LoginTokenReused: - invalid_login_token_counter.labels("reused").inc() + invalid_login_token_counter.labels( + reason="reused", + **{SERVER_NAME_LABEL: self.server_name}, + ).inc() except NotFoundError: - invalid_login_token_counter.labels("not found").inc() + invalid_login_token_counter.labels( + reason="not found", + **{SERVER_NAME_LABEL: self.server_name}, + ).inc() raise AuthError(403, "Invalid login token", errcode=Codes.FORBIDDEN) @@ -1858,7 +1870,7 @@ class AuthHandler: html = self._sso_redirect_confirm_template.render( display_url=display_url, redirect_url=redirect_url, - server_name=self._server_name, + server_name=self.server_name, new_user=new_user, user_id=registered_user_id, user_profile=user_profile_data, diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index 4247faaecb..e4169321cc 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -42,6 +42,7 @@ class 
DeactivateAccountHandler: def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main self.hs = hs + self.server_name = hs.hostname self._auth_handler = hs.get_auth_handler() self._device_handler = hs.get_device_handler() self._room_member_handler = hs.get_room_member_handler() @@ -271,7 +272,9 @@ class DeactivateAccountHandler: pending deactivation, if it isn't already running. """ if not self._user_parter_running: - run_as_background_process("user_parter_loop", self._user_parter_loop) + run_as_background_process( + "user_parter_loop", self.server_name, self._user_parter_loop + ) async def _user_parter_loop(self) -> None: """Loop that parts deactivated users from rooms""" diff --git a/synapse/handlers/delayed_events.py b/synapse/handlers/delayed_events.py index beb0e819c2..ce13dcc737 100644 --- a/synapse/handlers/delayed_events.py +++ b/synapse/handlers/delayed_events.py @@ -22,7 +22,7 @@ from synapse.api.errors import ShadowBanError from synapse.api.ratelimiting import Ratelimiter from synapse.config.workers import MAIN_PROCESS_INSTANCE_NAME from synapse.logging.opentracing import set_tag -from synapse.metrics import event_processing_positions +from synapse.metrics import SERVER_NAME_LABEL, event_processing_positions from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.http.delayed_events import ( ReplicationAddedDelayedEventRestServlet, @@ -110,12 +110,13 @@ class DelayedEventsHandler: # Can send the events in background after having awaited on marking them as processed run_as_background_process( "_send_events", + self.server_name, self._send_events, events, ) self._initialized_from_db = run_as_background_process( - "_schedule_db_events", _schedule_db_events + "_schedule_db_events", self.server_name, _schedule_db_events ) else: self._repl_client = ReplicationAddedDelayedEventRestServlet.make_client(hs) @@ -140,7 +141,9 @@ class DelayedEventsHandler: finally: self._event_processing = False - run_as_background_process("delayed_events.notify_new_event", process) + run_as_background_process( + "delayed_events.notify_new_event", self.server_name, process + ) async def _unsafe_process_new_event(self) -> None: # If self._event_pos is None then means we haven't fetched it from the DB yet @@ -188,7 +191,9 @@ class DelayedEventsHandler: self._event_pos = max_pos # Expose current event processing position to prometheus - event_processing_positions.labels("delayed_events").set(max_pos) + event_processing_positions.labels( + name="delayed_events", **{SERVER_NAME_LABEL: self.server_name} + ).set(max_pos) await self._store.update_delayed_events_stream_pos(max_pos) @@ -450,6 +455,7 @@ class DelayedEventsHandler: delay_sec, run_as_background_process, "_send_on_timeout", + self.server_name, self._send_on_timeout, ) else: diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 80d49fc18d..acae34e71f 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -193,8 +193,9 @@ class DeviceHandler: self.clock.looping_call( run_as_background_process, DELETE_STALE_DEVICES_INTERVAL_MS, - "delete_stale_devices", - self._delete_stale_devices, + desc="delete_stale_devices", + server_name=self.server_name, + func=self._delete_stale_devices, ) async def _delete_stale_devices(self) -> None: @@ -963,6 +964,9 @@ class DeviceWriterHandler(DeviceHandler): def __init__(self, hs: "HomeServer"): super().__init__(hs) + self.server_name = ( + hs.hostname + ) # nb must be called this for @measure_func and 
@wrap_as_background_process # We only need to poke the federation sender explicitly if its on the # same instance. Other federation sender instances will get notified by # `synapse.app.generic_worker.FederationSenderHandler` when it sees it @@ -1440,6 +1444,7 @@ class DeviceListUpdater(DeviceListWorkerUpdater): def __init__(self, hs: "HomeServer", device_handler: DeviceWriterHandler): super().__init__(hs) + self.server_name = hs.hostname self.federation = hs.get_federation_client() self.server_name = hs.hostname # nb must be called this for @measure_func self.clock = hs.get_clock() # nb must be called this for @measure_func @@ -1470,6 +1475,7 @@ class DeviceListUpdater(DeviceListWorkerUpdater): self.clock.looping_call( run_as_background_process, 30 * 1000, + server_name=self.server_name, func=self._maybe_retry_device_resync, desc="_maybe_retry_device_resync", ) @@ -1591,6 +1597,7 @@ class DeviceListUpdater(DeviceListWorkerUpdater): await self.store.mark_remote_users_device_caches_as_stale([user_id]) run_as_background_process( "_maybe_retry_device_resync", + self.server_name, self.multi_user_device_resync, [user_id], False, diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index c709ed2c63..34aae7ef3c 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -71,6 +71,7 @@ from synapse.handlers.pagination import PURGE_PAGINATION_LOCK_NAME from synapse.http.servlet import assert_params_in_dict from synapse.logging.context import nested_logging_context from synapse.logging.opentracing import SynapseTags, set_tag, tag_args, trace +from synapse.metrics import SERVER_NAME_LABEL from synapse.metrics.background_process_metrics import run_as_background_process from synapse.module_api import NOT_SPAM from synapse.storage.databases.main.events_worker import EventRedactBehaviour @@ -90,7 +91,7 @@ logger = logging.getLogger(__name__) backfill_processing_before_timer = Histogram( "synapse_federation_backfill_processing_before_time_seconds", "sec", - [], + labelnames=[SERVER_NAME_LABEL], buckets=( 0.1, 0.5, @@ -187,7 +188,9 @@ class FederationHandler: # were shut down. if not hs.config.worker.worker_app: run_as_background_process( - "resume_sync_partial_state_room", self._resume_partial_state_room_sync + "resume_sync_partial_state_room", + self.server_name, + self._resume_partial_state_room_sync, ) @trace @@ -316,6 +319,7 @@ class FederationHandler: ) run_as_background_process( "_maybe_backfill_inner_anyway_with_max_depth", + self.server_name, self.maybe_backfill, room_id=room_id, # We use `MAX_DEPTH` so that we find all backfill points next @@ -530,9 +534,9 @@ class FederationHandler: # backfill points regardless of `current_depth`. if processing_start_time is not None: processing_end_time = self.clock.time_msec() - backfill_processing_before_timer.observe( - (processing_end_time - processing_start_time) / 1000 - ) + backfill_processing_before_timer.labels( + **{SERVER_NAME_LABEL: self.server_name} + ).observe((processing_end_time - processing_start_time) / 1000) success = await try_backfill(likely_domains) if success: @@ -798,7 +802,10 @@ class FederationHandler: # have. Hence we fire off the background task, but don't wait for it. 
run_as_background_process( - "handle_queued_pdus", self._handle_queued_pdus, room_queue + "handle_queued_pdus", + self.server_name, + self._handle_queued_pdus, + room_queue, ) async def do_knock( @@ -1870,7 +1877,9 @@ class FederationHandler: ) run_as_background_process( - desc="sync_partial_state_room", func=_sync_partial_state_room_wrapper + desc="sync_partial_state_room", + server_name=self.server_name, + func=_sync_partial_state_room_wrapper, ) async def _sync_partial_state_room( diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 5cec2b01e5..2ef7e77b1d 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -76,6 +76,7 @@ from synapse.logging.opentracing import ( tag_args, trace, ) +from synapse.metrics import SERVER_NAME_LABEL from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.http.federation import ( ReplicationFederationSendEventsRestServlet, @@ -105,13 +106,14 @@ logger = logging.getLogger(__name__) soft_failed_event_counter = Counter( "synapse_federation_soft_failed_events_total", "Events received over federation that we marked as soft_failed", + labelnames=[SERVER_NAME_LABEL], ) # Added to debug performance and track progress on optimizations backfill_processing_after_timer = Histogram( "synapse_federation_backfill_processing_after_time_seconds", "sec", - [], + labelnames=[SERVER_NAME_LABEL], buckets=( 0.1, 0.25, @@ -146,6 +148,7 @@ class FederationEventHandler: """ def __init__(self, hs: "HomeServer"): + self.server_name = hs.hostname self._clock = hs.get_clock() self._store = hs.get_datastores().main self._state_store = hs.get_datastores().state @@ -170,7 +173,6 @@ class FederationEventHandler: self._is_mine_id = hs.is_mine_id self._is_mine_server_name = hs.is_mine_server_name - self._server_name = hs.hostname self._instance_name = hs.get_instance_name() self._config = hs.config @@ -249,7 +251,7 @@ class FederationEventHandler: # Note that if we were never in the room then we would have already # dropped the event, since we wouldn't know the room version. is_in_room = await self._event_auth_handler.is_host_in_room( - room_id, self._server_name + room_id, self.server_name ) if not is_in_room: logger.info( @@ -690,7 +692,9 @@ class FederationEventHandler: if not events: return - with backfill_processing_after_timer.time(): + with backfill_processing_after_timer.labels( + **{SERVER_NAME_LABEL: self.server_name} + ).time(): # if there are any events in the wrong room, the remote server is buggy and # should not be trusted. 
for ev in events: @@ -930,6 +934,7 @@ class FederationEventHandler: if len(events_with_failed_pull_attempts) > 0: run_as_background_process( "_process_new_pulled_events_with_failed_pull_attempts", + self.server_name, _process_new_pulled_events, events_with_failed_pull_attempts, ) @@ -1523,6 +1528,7 @@ class FederationEventHandler: if resync: run_as_background_process( "resync_device_due_to_pdu", + self.server_name, self._resync_device, event.sender, ) @@ -2049,7 +2055,9 @@ class FederationEventHandler: "hs": origin, }, ) - soft_failed_event_counter.inc() + soft_failed_event_counter.labels( + **{SERVER_NAME_LABEL: self.server_name} + ).inc() event.internal_metadata.soft_failed = True async def _load_or_fetch_auth_events_for_event( diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index cb64df2d01..fff46b640b 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -67,7 +67,6 @@ from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME from synapse.logging import opentracing from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.replication.http.send_event import ReplicationSendEventRestServlet from synapse.replication.http.send_events import ReplicationSendEventsRestServlet from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.types import ( @@ -97,6 +96,7 @@ class MessageHandler: """Contains some read only APIs to get state about a room""" def __init__(self, hs: "HomeServer"): + self.server_name = hs.hostname self.auth = hs.get_auth() self.clock = hs.get_clock() self.state = hs.get_state_handler() @@ -112,7 +112,7 @@ class MessageHandler: if not hs.config.worker.worker_app: run_as_background_process( - "_schedule_next_expiry", self._schedule_next_expiry + "_schedule_next_expiry", self.server_name, self._schedule_next_expiry ) async def get_room_data( @@ -444,6 +444,7 @@ class MessageHandler: delay, run_as_background_process, "_expire_event", + self.server_name, self._expire_event, event_id, ) @@ -504,7 +505,6 @@ class EventCreationHandler: self.room_prejoin_state_types = self.hs.config.api.room_prejoin_state - self.send_event = ReplicationSendEventRestServlet.make_client(hs) self.send_events = ReplicationSendEventsRestServlet.make_client(hs) self.request_ratelimiter = hs.get_request_ratelimiter() @@ -546,6 +546,7 @@ class EventCreationHandler: self.clock.looping_call( lambda: run_as_background_process( "send_dummy_events_to_fill_extremities", + self.server_name, self._send_dummy_events_to_fill_extremities, ), 5 * 60 * 1000, @@ -646,38 +647,46 @@ class EventCreationHandler: """ await self.auth_blocking.check_auth_blocking(requester=requester) - requester_suspended = await self.store.get_user_suspended_status( - requester.user.to_string() + # The requester may be a regular user, but puppeted by the server. + request_by_server = ( + requester.authenticated_entity == self.hs.config.server.server_name ) - if requester_suspended: - # We want to allow suspended users to perform "corrective" actions - # asked of them by server admins, such as redact their messages and - # leave rooms. 
- if event_dict["type"] in ["m.room.redaction", "m.room.member"]: - if event_dict["type"] == "m.room.redaction": - event = await self.store.get_event( - event_dict["content"]["redacts"], allow_none=True - ) - if event: - if event.sender != requester.user.to_string(): + + # If the request is initiated by the server, ignore whether the + # requester or target is suspended. + if not request_by_server: + requester_suspended = await self.store.get_user_suspended_status( + requester.user.to_string() + ) + if requester_suspended: + # We want to allow suspended users to perform "corrective" actions + # asked of them by server admins, such as redact their messages and + # leave rooms. + if event_dict["type"] in ["m.room.redaction", "m.room.member"]: + if event_dict["type"] == "m.room.redaction": + event = await self.store.get_event( + event_dict["content"]["redacts"], allow_none=True + ) + if event: + if event.sender != requester.user.to_string(): + raise SynapseError( + 403, + "You can only redact your own events while account is suspended.", + Codes.USER_ACCOUNT_SUSPENDED, + ) + if event_dict["type"] == "m.room.member": + if event_dict["content"]["membership"] != "leave": raise SynapseError( 403, - "You can only redact your own events while account is suspended.", + "Changing membership while account is suspended is not allowed.", Codes.USER_ACCOUNT_SUSPENDED, ) - if event_dict["type"] == "m.room.member": - if event_dict["content"]["membership"] != "leave": - raise SynapseError( - 403, - "Changing membership while account is suspended is not allowed.", - Codes.USER_ACCOUNT_SUSPENDED, - ) - else: - raise SynapseError( - 403, - "Sending messages while account is suspended is not allowed.", - Codes.USER_ACCOUNT_SUSPENDED, - ) + else: + raise SynapseError( + 403, + "Sending messages while account is suspended is not allowed.", + Codes.USER_ACCOUNT_SUSPENDED, + ) is_create_event = ( event_dict["type"] == EventTypes.Create and event_dict["state_key"] == "" @@ -1107,6 +1116,9 @@ class EventCreationHandler: policy_allowed = await self._policy_handler.is_event_allowed(event) if not policy_allowed: + # We shouldn't need to set the metadata because the raise should + # cause the request to be denied, but just in case: + event.internal_metadata.policy_server_spammy = True logger.warning( "Event not allowed by policy server, rejecting %s", event.event_id, @@ -2070,6 +2082,7 @@ class EventCreationHandler: # matters as sometimes presence code can take a while. 
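Note on the restructured suspension checks above: server-initiated requests now bypass the checks entirely, while suspended users may only redact their own events or leave rooms. The branching is easier to see when pulled out into a pure function; the sketch below is a hypothetical, simplified restatement (the names and the return-a-message convention are illustrative, not the real EventCreationHandler API):

    from typing import Optional

    def is_blocked_by_suspension(
        request_by_server: bool,
        requester_suspended: bool,
        event_type: str,
        membership: Optional[str],
        redaction_target_sender: Optional[str],
        requester_user_id: str,
    ) -> Optional[str]:
        """Return an error message if the event must be rejected, else None."""
        # Requests puppeted by the server itself skip the suspension checks.
        if request_by_server or not requester_suspended:
            return None
        if event_type == "m.room.redaction":
            # Redacting someone else's event is still forbidden; redacting your
            # own events (or a target we could not load) is allowed.
            if redaction_target_sender not in (None, requester_user_id):
                return "You can only redact your own events while account is suspended."
            return None
        if event_type == "m.room.member":
            # Leaving rooms is the only membership change a suspended user may make.
            if membership != "leave":
                return "Changing membership while account is suspended is not allowed."
            return None
        return "Sending messages while account is suspended is not allowed."

Returning the error string rather than raising keeps the sketch independent of SynapseError and the Codes enum used in the handler itself.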
run_as_background_process( "bump_presence_active_time", + self.server_name, self._bump_active_time, requester.user, requester.device_id, diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 4070b74b7a..df1a7e714c 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -79,12 +79,12 @@ class PaginationHandler: def __init__(self, hs: "HomeServer"): self.hs = hs + self.server_name = hs.hostname self.auth = hs.get_auth() self.store = hs.get_datastores().main self._storage_controllers = hs.get_storage_controllers() self._state_storage_controller = self._storage_controllers.state self.clock = hs.get_clock() - self._server_name = hs.hostname self._room_shutdown_handler = hs.get_room_shutdown_handler() self._relations_handler = hs.get_relations_handler() self._worker_locks = hs.get_worker_locks_handler() @@ -119,6 +119,7 @@ class PaginationHandler: run_as_background_process, job.interval, "purge_history_for_rooms_in_range", + self.server_name, self.purge_history_for_rooms_in_range, job.shortest_max_lifetime, job.longest_max_lifetime, @@ -245,6 +246,7 @@ class PaginationHandler: # other purges in the same room. run_as_background_process( PURGE_HISTORY_ACTION_NAME, + self.server_name, self.purge_history, room_id, token, @@ -395,7 +397,7 @@ class PaginationHandler: write=True, ): # first check that we have no users in this room - joined = await self.store.is_host_joined(room_id, self._server_name) + joined = await self.store.is_host_joined(room_id, self.server_name) if joined: if force: logger.info( @@ -604,6 +606,7 @@ class PaginationHandler: # for a costly federation call and processing. run_as_background_process( "maybe_backfill_in_the_background", + self.server_name, self.hs.get_federation_handler().maybe_backfill, room_id, curr_topo, diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index c652e333a6..b253117498 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -105,7 +105,7 @@ from synapse.api.presence import UserDevicePresenceState, UserPresenceState from synapse.appservice import ApplicationService from synapse.events.presence_router import PresenceRouter from synapse.logging.context import run_in_background -from synapse.metrics import LaterGauge +from synapse.metrics import SERVER_NAME_LABEL, LaterGauge from synapse.metrics.background_process_metrics import ( run_as_background_process, wrap_as_background_process, @@ -137,24 +137,40 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -notified_presence_counter = Counter("synapse_handler_presence_notified_presence", "") +notified_presence_counter = Counter( + "synapse_handler_presence_notified_presence", "", labelnames=[SERVER_NAME_LABEL] +) federation_presence_out_counter = Counter( - "synapse_handler_presence_federation_presence_out", "" + "synapse_handler_presence_federation_presence_out", + "", + labelnames=[SERVER_NAME_LABEL], +) +presence_updates_counter = Counter( + "synapse_handler_presence_presence_updates", "", labelnames=[SERVER_NAME_LABEL] +) +timers_fired_counter = Counter( + "synapse_handler_presence_timers_fired", "", labelnames=[SERVER_NAME_LABEL] ) -presence_updates_counter = Counter("synapse_handler_presence_presence_updates", "") -timers_fired_counter = Counter("synapse_handler_presence_timers_fired", "") federation_presence_counter = Counter( - "synapse_handler_presence_federation_presence", "" + "synapse_handler_presence_federation_presence", "", labelnames=[SERVER_NAME_LABEL] +) +bump_active_time_counter = 
Counter( + "synapse_handler_presence_bump_active_time", "", labelnames=[SERVER_NAME_LABEL] ) -bump_active_time_counter = Counter("synapse_handler_presence_bump_active_time", "") -get_updates_counter = Counter("synapse_handler_presence_get_updates", "", ["type"]) +get_updates_counter = Counter( + "synapse_handler_presence_get_updates", "", labelnames=["type", SERVER_NAME_LABEL] +) notify_reason_counter = Counter( - "synapse_handler_presence_notify_reason", "", ["locality", "reason"] + "synapse_handler_presence_notify_reason", + "", + labelnames=["locality", "reason", SERVER_NAME_LABEL], ) state_transition_counter = Counter( - "synapse_handler_presence_state_transition", "", ["locality", "from", "to"] + "synapse_handler_presence_state_transition", + "", + labelnames=["locality", "from", "to", SERVER_NAME_LABEL], ) # If a user was last active in the last LAST_ACTIVE_GRANULARITY, consider them @@ -484,6 +500,7 @@ class _NullContextManager(ContextManager[None]): class WorkerPresenceHandler(BasePresenceHandler): def __init__(self, hs: "HomeServer"): super().__init__(hs) + self.server_name = hs.hostname self._presence_writer_instance = hs.config.worker.writers.presence[0] # Route presence EDUs to the right worker @@ -517,6 +534,7 @@ class WorkerPresenceHandler(BasePresenceHandler): "shutdown", run_as_background_process, "generic_presence.on_shutdown", + self.server_name, self._on_shutdown, ) @@ -666,7 +684,9 @@ class WorkerPresenceHandler(BasePresenceHandler): old_state = self.user_to_current_state.get(new_state.user_id) self.user_to_current_state[new_state.user_id] = new_state is_mine = self.is_mine_id(new_state.user_id) - if not old_state or should_notify(old_state, new_state, is_mine): + if not old_state or should_notify( + old_state, new_state, is_mine, self.server_name + ): state_to_notify.append(new_state) stream_id = token @@ -747,7 +767,9 @@ class WorkerPresenceHandler(BasePresenceHandler): class PresenceHandler(BasePresenceHandler): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.server_name = hs.hostname + self.server_name = ( + hs.hostname + ) # nb must be called this for @wrap_as_background_process self.wheel_timer: WheelTimer[str] = WheelTimer() self.notifier = hs.get_notifier() @@ -758,10 +780,10 @@ class PresenceHandler(BasePresenceHandler): ) LaterGauge( - "synapse_handlers_presence_user_to_current_state_size", - "", - [], - lambda: len(self.user_to_current_state), + name="synapse_handlers_presence_user_to_current_state_size", + desc="", + labelnames=[SERVER_NAME_LABEL], + caller=lambda: {(self.server_name,): len(self.user_to_current_state)}, ) # The per-device presence state, maps user to devices to per-device presence state. 
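The LaterGauge call sites in these hunks (send-queue sizes, pending transaction counts, the presence state map size) switch from a zero-label callback returning a single number to a labelled form whose caller returns a mapping of label-value tuples to values, for example {(self.server_name,): len(self.user_to_current_state)}. LaterGauge itself lives in synapse.metrics and is not shown in this diff; the snippet below is only a rough, hypothetical sketch of how such a lazily evaluated, labelled gauge can be exposed through prometheus_client's custom-collector API, to illustrate the shape of the new caller contract:

    from typing import Callable, Dict, Iterable, Sequence, Tuple

    from prometheus_client import REGISTRY
    from prometheus_client.core import GaugeMetricFamily


    class LazyLabelledGauge:
        """Hypothetical stand-in for synapse.metrics.LaterGauge: values are
        computed by `caller` at scrape time instead of being pushed eagerly."""

        def __init__(
            self,
            name: str,
            desc: str,
            labelnames: Sequence[str],
            caller: Callable[[], Dict[Tuple[str, ...], float]],
        ) -> None:
            self._name = name
            self._desc = desc
            self._labelnames = list(labelnames)
            self._caller = caller
            REGISTRY.register(self)

        def collect(self) -> Iterable[GaugeMetricFamily]:
            family = GaugeMetricFamily(self._name, self._desc, labels=self._labelnames)
            for label_values, value in self._caller().items():
                # Each key is a tuple of label values, e.g. ("hs1.example.com",).
                family.add_metric(list(label_values), value)
            yield family


    # Illustrative usage: export the size of an in-memory queue, keyed by an
    # assumed homeserver-name label (the real label key is SERVER_NAME_LABEL).
    pending_queue: list = []
    LazyLabelledGauge(
        name="example_send_queue_size",
        desc="Items waiting in an in-memory queue",
        labelnames=["server_name_label"],
        caller=lambda: {("hs1.example.com",): len(pending_queue)},
    )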
@@ -815,6 +837,7 @@ class PresenceHandler(BasePresenceHandler): "shutdown", run_as_background_process, "presence.on_shutdown", + self.server_name, self._on_shutdown, ) @@ -860,10 +883,10 @@ class PresenceHandler(BasePresenceHandler): ) LaterGauge( - "synapse_handlers_presence_wheel_timer_size", - "", - [], - lambda: len(self.wheel_timer), + name="synapse_handlers_presence_wheel_timer_size", + desc="", + labelnames=[SERVER_NAME_LABEL], + caller=lambda: {(self.server_name,): len(self.wheel_timer)}, ) # Used to handle sending of presence to newly joined users/servers @@ -972,6 +995,7 @@ class PresenceHandler(BasePresenceHandler): prev_state, new_state, is_mine=self.is_mine_id(user_id), + our_server_name=self.server_name, wheel_timer=self.wheel_timer, now=now, # When overriding disabled presence, don't kick off all the @@ -991,10 +1015,14 @@ class PresenceHandler(BasePresenceHandler): # TODO: We should probably ensure there are no races hereafter - presence_updates_counter.inc(len(new_states)) + presence_updates_counter.labels( + **{SERVER_NAME_LABEL: self.server_name} + ).inc(len(new_states)) if to_notify: - notified_presence_counter.inc(len(to_notify)) + notified_presence_counter.labels( + **{SERVER_NAME_LABEL: self.server_name} + ).inc(len(to_notify)) await self._persist_and_notify(list(to_notify.values())) self.unpersisted_users_changes |= {s.user_id for s in new_states} @@ -1013,7 +1041,9 @@ class PresenceHandler(BasePresenceHandler): if user_id not in to_notify } if to_federation_ping: - federation_presence_out_counter.inc(len(to_federation_ping)) + federation_presence_out_counter.labels( + **{SERVER_NAME_LABEL: self.server_name} + ).inc(len(to_federation_ping)) hosts_to_states = await get_interested_remotes( self.store, @@ -1063,7 +1093,9 @@ class PresenceHandler(BasePresenceHandler): for user_id in users_to_check ] - timers_fired_counter.inc(len(states)) + timers_fired_counter.labels(**{SERVER_NAME_LABEL: self.server_name}).inc( + len(states) + ) # Set of user ID & device IDs which are currently syncing. 
syncing_user_devices = { @@ -1097,7 +1129,7 @@ class PresenceHandler(BasePresenceHandler): user_id = user.to_string() - bump_active_time_counter.inc() + bump_active_time_counter.labels(**{SERVER_NAME_LABEL: self.server_name}).inc() now = self.clock.time_msec() @@ -1349,7 +1381,9 @@ class PresenceHandler(BasePresenceHandler): updates.append(prev_state.copy_and_replace(**new_fields)) if updates: - federation_presence_counter.inc(len(updates)) + federation_presence_counter.labels( + **{SERVER_NAME_LABEL: self.server_name} + ).inc(len(updates)) await self._update_states(updates) async def set_state( @@ -1495,7 +1529,9 @@ class PresenceHandler(BasePresenceHandler): finally: self._event_processing = False - run_as_background_process("presence.notify_new_event", _process_presence) + run_as_background_process( + "presence.notify_new_event", self.server_name, _process_presence + ) async def _unsafe_process(self) -> None: # Loop round handling deltas until we're up to date @@ -1532,9 +1568,9 @@ class PresenceHandler(BasePresenceHandler): self._event_pos = max_pos # Expose current event processing position to prometheus - synapse.metrics.event_processing_positions.labels("presence").set( - max_pos - ) + synapse.metrics.event_processing_positions.labels( + name="presence", **{SERVER_NAME_LABEL: self.server_name} + ).set(max_pos) async def _handle_state_delta(self, room_id: str, deltas: List[StateDelta]) -> None: """Process current state deltas for the room to find new joins that need @@ -1660,7 +1696,10 @@ class PresenceHandler(BasePresenceHandler): def should_notify( - old_state: UserPresenceState, new_state: UserPresenceState, is_mine: bool + old_state: UserPresenceState, + new_state: UserPresenceState, + is_mine: bool, + our_server_name: str, ) -> bool: """Decides if a presence state change should be sent to interested parties.""" user_location = "remote" @@ -1671,19 +1710,38 @@ def should_notify( return False if old_state.status_msg != new_state.status_msg: - notify_reason_counter.labels(user_location, "status_msg_change").inc() + notify_reason_counter.labels( + locality=user_location, + reason="status_msg_change", + **{SERVER_NAME_LABEL: our_server_name}, + ).inc() return True if old_state.state != new_state.state: - notify_reason_counter.labels(user_location, "state_change").inc() + notify_reason_counter.labels( + locality=user_location, + reason="state_change", + **{SERVER_NAME_LABEL: our_server_name}, + ).inc() state_transition_counter.labels( - user_location, old_state.state, new_state.state + **{ + "locality": user_location, + # `from` is a reserved word in Python so we have to label it this way if + # we want to use keyword args. 
+ "from": old_state.state, + "to": new_state.state, + SERVER_NAME_LABEL: our_server_name, + }, ).inc() return True if old_state.state == PresenceState.ONLINE: if new_state.currently_active != old_state.currently_active: - notify_reason_counter.labels(user_location, "current_active_change").inc() + notify_reason_counter.labels( + locality=user_location, + reason="current_active_change", + **{SERVER_NAME_LABEL: our_server_name}, + ).inc() return True if ( @@ -1693,14 +1751,18 @@ def should_notify( # Only notify about last active bumps if we're not currently active if not new_state.currently_active: notify_reason_counter.labels( - user_location, "last_active_change_online" + locality=user_location, + reason="last_active_change_online", + **{SERVER_NAME_LABEL: our_server_name}, ).inc() return True elif new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY: # Always notify for a transition where last active gets bumped. notify_reason_counter.labels( - user_location, "last_active_change_not_online" + locality=user_location, + reason="last_active_change_not_online", + **{SERVER_NAME_LABEL: our_server_name}, + ).inc() return True @@ -1767,6 +1829,7 @@ class PresenceEventSource(EventSource[int, UserPresenceState]): self.server_name = hs.hostname self.get_presence_handler = hs.get_presence_handler self.get_presence_router = hs.get_presence_router + self.server_name = hs.hostname self.clock = hs.get_clock() self.store = hs.get_datastores().main @@ -1878,7 +1941,10 @@ class PresenceEventSource(EventSource[int, UserPresenceState]): # If we have the full list of changes for presence we can # simply check which ones share a room with the user. - get_updates_counter.labels("stream").inc() + get_updates_counter.labels( + type="stream", + **{SERVER_NAME_LABEL: self.server_name}, + ).inc() sharing_users = await self.store.do_users_share_a_room( user_id, updated_users @@ -1891,7 +1957,10 @@ class PresenceEventSource(EventSource[int, UserPresenceState]): else: # Too many possible updates. Find all users we can see and check # if any of them have changed. - get_updates_counter.labels("full").inc() + get_updates_counter.labels( + type="full", + **{SERVER_NAME_LABEL: self.server_name}, + ).inc() users_interested_in = ( await self.store.get_users_who_share_room_with_user(user_id) @@ -2141,6 +2210,7 @@ def handle_update( prev_state: UserPresenceState, new_state: UserPresenceState, is_mine: bool, + our_server_name: str, wheel_timer: WheelTimer, now: int, persist: bool, @@ -2153,6 +2223,7 @@ def handle_update( prev_state new_state is_mine: Whether the user is ours + our_server_name: The homeserver name of our server (`hs.hostname`) wheel_timer now: Time now in ms persist: True if this state should persist until another update occurs.
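These presence hunks show the two recurring mechanics of the labelling change. First, free functions such as should_notify and handle_update cannot read self.server_name, so the homeserver name is threaded through as an explicit our_server_name argument and forwarded to the metric labels. Second, because the label key is held in a constant (and because `from` is a Python keyword), the labels are supplied by unpacking a dict rather than as literal keyword arguments. A small self-contained sketch of both points, assuming SERVER_NAME_LABEL is a plain label-name string (its real value lives in synapse.metrics and is not shown here):

    from prometheus_client import Counter

    # Assumed stand-in for synapse.metrics.SERVER_NAME_LABEL.
    SERVER_NAME_LABEL = "server_name_label"

    state_transition_counter = Counter(
        "example_presence_state_transition",
        "Presence state transitions, by locality and old/new state",
        labelnames=["locality", "from", "to", SERVER_NAME_LABEL],
    )


    def record_transition(
        locality: str, old_state: str, new_state: str, our_server_name: str
    ) -> None:
        # `from=old_state` would be a syntax error, and SERVER_NAME_LABEL is a
        # variable, so the whole label set is passed as an unpacked dict.
        state_transition_counter.labels(
            **{
                "locality": locality,
                "from": old_state,
                "to": new_state,
                SERVER_NAME_LABEL: our_server_name,
            }
        ).inc()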
@@ -2221,7 +2292,7 @@ def handle_update( ) # Check whether the change was something worth notifying about - if should_notify(prev_state, new_state, is_mine): + if should_notify(prev_state, new_state, is_mine, our_server_name): new_state = new_state.copy_and_replace(last_federation_update_ts=now) persist_and_notify = True diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index da392e115f..dbff28e7fb 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -124,7 +124,7 @@ class ProfileHandler: except RequestSendFailed as e: raise SynapseError(502, "Failed to fetch profile") from e except HttpResponseException as e: - if e.code < 500 and e.code != 404: + if e.code < 500 and e.code not in (403, 404): # Other codes are not allowed in c2s API logger.info( "Server replied with wrong response: %s %s", e.code, e.msg diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 6322d980d4..5761a7f70b 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -45,6 +45,7 @@ from synapse.api.errors import ( from synapse.appservice import ApplicationService from synapse.config.server import is_threepid_reserved from synapse.http.servlet import assert_params_in_dict +from synapse.metrics import SERVER_NAME_LABEL from synapse.replication.http.login import RegisterDeviceReplicationServlet from synapse.replication.http.register import ( ReplicationPostRegisterActionsServlet, @@ -62,29 +63,38 @@ logger = logging.getLogger(__name__) registration_counter = Counter( "synapse_user_registrations_total", "Number of new users registered (since restart)", - ["guest", "shadow_banned", "auth_provider"], + labelnames=["guest", "shadow_banned", "auth_provider", SERVER_NAME_LABEL], ) login_counter = Counter( "synapse_user_logins_total", "Number of user logins (since restart)", - ["guest", "auth_provider"], + labelnames=["guest", "auth_provider", SERVER_NAME_LABEL], ) -def init_counters_for_auth_provider(auth_provider_id: str) -> None: +def init_counters_for_auth_provider(auth_provider_id: str, server_name: str) -> None: """Ensure the prometheus counters for the given auth provider are initialised This fixes a problem where the counters are not reported for a given auth provider until the user first logs in/registers. + + Args: + auth_provider_id: The ID of the auth provider to initialise counters for. + server_name: Our server name (used to label metrics) (this should be `hs.hostname`). 
""" for is_guest in (True, False): - login_counter.labels(guest=is_guest, auth_provider=auth_provider_id) + login_counter.labels( + guest=is_guest, + auth_provider=auth_provider_id, + **{SERVER_NAME_LABEL: server_name}, + ) for shadow_banned in (True, False): registration_counter.labels( guest=is_guest, shadow_banned=shadow_banned, auth_provider=auth_provider_id, + **{SERVER_NAME_LABEL: server_name}, ) @@ -97,6 +107,7 @@ class LoginDict(TypedDict): class RegistrationHandler: def __init__(self, hs: "HomeServer"): + self.server_name = hs.hostname self.store = hs.get_datastores().main self._storage_controllers = hs.get_storage_controllers() self.clock = hs.get_clock() @@ -112,7 +123,6 @@ class RegistrationHandler: self._account_validity_handler = hs.get_account_validity_handler() self._user_consent_version = self.hs.config.consent.user_consent_version self._server_notices_mxid = hs.config.servernotices.server_notices_mxid - self._server_name = hs.hostname self._user_types_config = hs.config.user_types self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker @@ -138,7 +148,9 @@ class RegistrationHandler: ) self.refresh_token_lifetime = hs.config.registration.refresh_token_lifetime - init_counters_for_auth_provider("") + init_counters_for_auth_provider( + auth_provider_id="", server_name=self.server_name + ) async def check_username( self, @@ -362,6 +374,7 @@ class RegistrationHandler: guest=make_guest, shadow_banned=shadow_banned, auth_provider=(auth_provider_id or ""), + **{SERVER_NAME_LABEL: self.server_name}, ).inc() # If the user does not need to consent at registration, auto-join any @@ -422,7 +435,7 @@ class RegistrationHandler: if self.hs.config.registration.auto_join_user_id: fake_requester = create_requester( self.hs.config.registration.auto_join_user_id, - authenticated_entity=self._server_name, + authenticated_entity=self.server_name, ) # If the room requires an invite, add the user to the list of invites. @@ -435,7 +448,7 @@ class RegistrationHandler: requires_join = True else: fake_requester = create_requester( - user_id, authenticated_entity=self._server_name + user_id, authenticated_entity=self.server_name ) # Choose whether to federate the new room. @@ -467,7 +480,7 @@ class RegistrationHandler: await room_member_handler.update_membership( requester=create_requester( - user_id, authenticated_entity=self._server_name + user_id, authenticated_entity=self.server_name ), target=UserID.from_string(user_id), room_id=room_id, @@ -493,7 +506,7 @@ class RegistrationHandler: if requires_join: await room_member_handler.update_membership( requester=create_requester( - user_id, authenticated_entity=self._server_name + user_id, authenticated_entity=self.server_name ), target=UserID.from_string(user_id), room_id=room_id, @@ -539,7 +552,7 @@ class RegistrationHandler: # we don't have a local user in the room to craft up an invite with. requires_invite = await self.store.is_host_joined( room_id, - self._server_name, + self.server_name, ) if requires_invite: @@ -567,7 +580,7 @@ class RegistrationHandler: await room_member_handler.update_membership( requester=create_requester( self.hs.config.registration.auto_join_user_id, - authenticated_entity=self._server_name, + authenticated_entity=self.server_name, ), target=UserID.from_string(user_id), room_id=room_id, @@ -579,7 +592,7 @@ class RegistrationHandler: # Send the join. 
await room_member_handler.update_membership( requester=create_requester( - user_id, authenticated_entity=self._server_name + user_id, authenticated_entity=self.server_name ), target=UserID.from_string(user_id), room_id=room_id, @@ -790,6 +803,7 @@ class RegistrationHandler: login_counter.labels( guest=is_guest, auth_provider=(auth_provider_id or ""), + **{SERVER_NAME_LABEL: self.server_name}, ).inc() return ( diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index a8b29debd9..47bd139ca7 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -66,6 +66,7 @@ from synapse.api.errors import ( SynapseError, ) from synapse.api.filtering import Filter +from synapse.api.ratelimiting import Ratelimiter from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.event_auth import validate_event_for_room_version from synapse.events import EventBase @@ -134,7 +135,12 @@ class RoomCreationHandler: self.room_member_handler = hs.get_room_member_handler() self._event_auth_handler = hs.get_event_auth_handler() self.config = hs.config - self.request_ratelimiter = hs.get_request_ratelimiter() + self.common_request_ratelimiter = hs.get_request_ratelimiter() + self.creation_ratelimiter = Ratelimiter( + store=self.store, + clock=self.clock, + cfg=self.config.ratelimiting.rc_room_creation, + ) # Room state based off defined presets self._presets_dict: Dict[str, Dict[str, Any]] = { @@ -216,7 +222,11 @@ class RoomCreationHandler: ShadowBanError if the requester is shadow-banned. """ if ratelimit: - await self.request_ratelimiter.ratelimit(requester) + await self.creation_ratelimiter.ratelimit(requester, update=False) + + # then apply the ratelimits + await self.common_request_ratelimiter.ratelimit(requester) + await self.creation_ratelimiter.ratelimit(requester) user_id = requester.user.to_string() @@ -566,6 +576,7 @@ class RoomCreationHandler: created with _generate_room_id()) new_room_version: the new room version to use tombstone_event_id: the ID of the tombstone event in the old room. + additional_creators: additional room creators, for MSC4289. creation_event_with_context: The create event of the new room, if the new room supports room ID as create event ID hash. auto_member: Whether to automatically join local users to the new @@ -1060,6 +1071,25 @@ class RoomCreationHandler: await self.auth_blocking.check_auth_blocking(requester=requester) + if ratelimit: + # Limit the rate of room creations, + # using both the limiter specific to room creations as well + # as the general request ratelimiter. + # + # Note that we don't rate limit the individual + # events in the room — room creation isn't atomic and + # historically it was very janky if half the events in the + # initial state don't make it because of rate limiting. 
+ + # First check the room creation ratelimiter without updating it + # (this is so we don't consume a token if the other ratelimiter doesn't + # allow us to proceed) + await self.creation_ratelimiter.ratelimit(requester, update=False) + + # then apply the ratelimits + await self.common_request_ratelimiter.ratelimit(requester) + await self.creation_ratelimiter.ratelimit(requester) + if ( self._server_notices_mxid is not None and user_id == self._server_notices_mxid @@ -1091,25 +1121,6 @@ class RoomCreationHandler: Codes.MISSING_PARAM, ) - if not is_requester_admin: - spam_check = await self._spam_checker_module_callbacks.user_may_create_room( - user_id, config - ) - if spam_check != self._spam_checker_module_callbacks.NOT_SPAM: - raise SynapseError( - 403, - "You are not permitted to create rooms", - errcode=spam_check[0], - additional_fields=spam_check[1], - ) - - if ratelimit: - # Rate limit once in advance, but don't rate limit the individual - # events in the room — room creation isn't atomic and it's very - # janky if half the events in the initial state don't make it because - # of rate limiting. - await self.request_ratelimiter.ratelimit(requester) - room_version_id = config.get( "room_version", self.config.server.default_room_version.identifier ) @@ -1202,6 +1213,19 @@ class RoomCreationHandler: self._validate_room_config(config, visibility) + # Run the spam checker after other validation + if not is_requester_admin: + spam_check = await self._spam_checker_module_callbacks.user_may_create_room( + user_id, config + ) + if spam_check != self._spam_checker_module_callbacks.NOT_SPAM: + raise SynapseError( + 403, + "You are not permitted to create rooms", + errcode=spam_check[0], + additional_fields=spam_check[1], + ) + creation_content = config.get("creation_content", {}) # override any attempt to set room versions via the creation_content creation_content["room_version"] = room_version.identifier diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index fea25b9920..5ba64912c9 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -49,7 +49,7 @@ from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN from synapse.handlers.state_deltas import MatchChange, StateDeltasHandler from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME from synapse.logging import opentracing -from synapse.metrics import event_processing_positions +from synapse.metrics import SERVER_NAME_LABEL, event_processing_positions from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.http.push import ReplicationCopyPusherRestServlet from synapse.storage.databases.main.state_deltas import StateDelta @@ -746,35 +746,41 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): and requester.user.to_string() == self._server_notices_mxid ) - requester_suspended = await self.store.get_user_suspended_status( - requester.user.to_string() - ) - if action == Membership.INVITE and requester_suspended: - raise SynapseError( - 403, - "Sending invites while account is suspended is not allowed.", - Codes.USER_ACCOUNT_SUSPENDED, - ) + # The requester may be a regular user, but puppeted by the server. 
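The ordering in the `RoomCreationHandler` hunks above is deliberate: probe the room-creation limiter with `update=False`, then apply the general request limiter, and only then consume a room-creation token, so a rejection by the general limiter never burns a creation token. A rough standalone sketch of that check-then-consume idea, using a deliberately simplified fixed-window limiter rather than Synapse's `Ratelimiter`:

    import time


    class LimitExceededError(Exception):
        pass


    class SimpleLimiter:
        """Toy fixed-window limiter: at most `limit` actions per `window` seconds."""

        def __init__(self, limit: int, window: float) -> None:
            self.limit = limit
            self.window = window
            self.timestamps: list = []

        def ratelimit(self, update: bool = True) -> None:
            now = time.monotonic()
            self.timestamps = [t for t in self.timestamps if now - t < self.window]
            if len(self.timestamps) >= self.limit:
                raise LimitExceededError()
            if update:
                self.timestamps.append(now)


    creation_limiter = SimpleLimiter(limit=2, window=3600.0)
    common_limiter = SimpleLimiter(limit=10, window=1.0)


    def create_room() -> None:
        # Probe the room-creation limiter first, without consuming a token...
        creation_limiter.ratelimit(update=False)
        # ...so that a rejection by the general limiter leaves it untouched.
        common_limiter.ratelimit()
        creation_limiter.ratelimit()
        print("room created")


    create_room()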
+ request_by_server = requester.authenticated_entity == self._server_name - if target.to_string() != requester.user.to_string(): - target_suspended = await self.store.get_user_suspended_status( - target.to_string() + # If the request is initiated by the server, ignore whether the + # requester or target is suspended. + if not request_by_server: + requester_suspended = await self.store.get_user_suspended_status( + requester.user.to_string() ) - else: - target_suspended = requester_suspended + if action == Membership.INVITE and requester_suspended: + raise SynapseError( + 403, + "Sending invites while account is suspended is not allowed.", + Codes.USER_ACCOUNT_SUSPENDED, + ) - if action == Membership.JOIN and target_suspended: - raise SynapseError( - 403, - "Joining rooms while account is suspended is not allowed.", - Codes.USER_ACCOUNT_SUSPENDED, - ) - if action == Membership.KNOCK and target_suspended: - raise SynapseError( - 403, - "Knocking on rooms while account is suspended is not allowed.", - Codes.USER_ACCOUNT_SUSPENDED, - ) + if target.to_string() != requester.user.to_string(): + target_suspended = await self.store.get_user_suspended_status( + target.to_string() + ) + else: + target_suspended = requester_suspended + + if action == Membership.JOIN and target_suspended: + raise SynapseError( + 403, + "Joining rooms while account is suspended is not allowed.", + Codes.USER_ACCOUNT_SUSPENDED, + ) + if action == Membership.KNOCK and target_suspended: + raise SynapseError( + 403, + "Knocking on rooms while account is suspended is not allowed.", + Codes.USER_ACCOUNT_SUSPENDED, + ) if ( not self.allow_per_room_profiles and not is_requester_server_notices_user @@ -2163,6 +2169,7 @@ class RoomForgetterHandler(StateDeltasHandler): super().__init__(hs) self._hs = hs + self.server_name = hs.hostname self._store = hs.get_datastores().main self._storage_controllers = hs.get_storage_controllers() self._clock = hs.get_clock() @@ -2194,7 +2201,9 @@ class RoomForgetterHandler(StateDeltasHandler): finally: self._is_processing = False - run_as_background_process("room_forgetter.notify_new_event", process) + run_as_background_process( + "room_forgetter.notify_new_event", self.server_name, process + ) async def _unsafe_process(self) -> None: # If self.pos is None then means we haven't fetched it from DB @@ -2251,7 +2260,9 @@ class RoomForgetterHandler(StateDeltasHandler): self.pos = max_pos # Expose current event processing position to prometheus - event_processing_positions.labels("room_forgetter").set(max_pos) + event_processing_positions.labels( + name="room_forgetter", **{SERVER_NAME_LABEL: self.server_name} + ).set(max_pos) await self._store.update_room_forgetter_stream_pos(max_pos) diff --git a/synapse/handlers/send_email.py b/synapse/handlers/send_email.py index 92fed980e6..6469b182c8 100644 --- a/synapse/handlers/send_email.py +++ b/synapse/handlers/send_email.py @@ -24,16 +24,13 @@ import logging from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText from io import BytesIO -from typing import TYPE_CHECKING, Any, Dict, Optional +from typing import TYPE_CHECKING, Dict, Optional -from pkg_resources import parse_version - -import twisted from twisted.internet.defer import Deferred from twisted.internet.endpoints import HostnameEndpoint -from twisted.internet.interfaces import IOpenSSLContextFactory, IProtocolFactory +from twisted.internet.interfaces import IProtocolFactory from twisted.internet.ssl import optionsForClientTLS -from twisted.mail.smtp import ESMTPSender, 
ESMTPSenderFactory +from twisted.mail.smtp import ESMTPSenderFactory from twisted.protocols.tls import TLSMemoryBIOFactory from synapse.logging.context import make_deferred_yieldable @@ -44,49 +41,6 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -_is_old_twisted = parse_version(twisted.__version__) < parse_version("21") - - -class _BackportESMTPSender(ESMTPSender): - """Extend old versions of ESMTPSender to configure TLS. - - Unfortunately, before Twisted 21.2, ESMTPSender doesn't give an easy way to - disable TLS, or to configure the hostname used for TLS certificate validation. - This backports the `hostname` parameter for that functionality. - """ - - __hostname: Optional[str] - - def __init__(self, *args: Any, **kwargs: Any) -> None: - """""" - self.__hostname = kwargs.pop("hostname", None) - super().__init__(*args, **kwargs) - - def _getContextFactory(self) -> Optional[IOpenSSLContextFactory]: - if self.context is not None: - return self.context - elif self.__hostname is None: - return None # disable TLS if hostname is None - return optionsForClientTLS(self.__hostname) - - -class _BackportESMTPSenderFactory(ESMTPSenderFactory): - """An ESMTPSenderFactory for _BackportESMTPSender. - - This backports the `hostname` parameter, to disable or configure TLS. - """ - - __hostname: Optional[str] - - def __init__(self, *args: Any, **kwargs: Any) -> None: - self.__hostname = kwargs.pop("hostname", None) - super().__init__(*args, **kwargs) - - def protocol(self, *args: Any, **kwargs: Any) -> ESMTPSender: # type: ignore - # this overrides ESMTPSenderFactory's `protocol` attribute, with a Callable - # instantiating our _BackportESMTPSender, providing the hostname parameter - return _BackportESMTPSender(*args, **kwargs, hostname=self.__hostname) - async def _sendmail( reactor: ISynapseReactor, @@ -129,9 +83,7 @@ async def _sendmail( elif tlsname is None: tlsname = smtphost - factory: IProtocolFactory = ( - _BackportESMTPSenderFactory if _is_old_twisted else ESMTPSenderFactory - )( + factory: IProtocolFactory = ESMTPSenderFactory( username, password, from_addr, diff --git a/synapse/handlers/sliding_sync/__init__.py b/synapse/handlers/sliding_sync/__init__.py index cb56eb53fc..a9573ba0f1 100644 --- a/synapse/handlers/sliding_sync/__init__.py +++ b/synapse/handlers/sliding_sync/__init__.py @@ -38,6 +38,7 @@ from synapse.logging.opentracing import ( tag_args, trace, ) +from synapse.metrics import SERVER_NAME_LABEL from synapse.storage.databases.main.roommember import extract_heroes_from_room_summary from synapse.storage.databases.main.state_deltas import StateDelta from synapse.storage.databases.main.stream import PaginateFunction @@ -79,7 +80,7 @@ logger = logging.getLogger(__name__) sync_processing_time = Histogram( "synapse_sliding_sync_processing_time", "Time taken to generate a sliding sync response, ignoring wait times.", - ["initial"], + labelnames=["initial", SERVER_NAME_LABEL], ) # Limit the number of state_keys we should remember sending down the connection for each @@ -94,6 +95,7 @@ MAX_NUMBER_PREVIOUS_STATE_KEYS_TO_REMEMBER = 100 class SlidingSyncHandler: def __init__(self, hs: "HomeServer"): + self.server_name = hs.hostname self.clock = hs.get_clock() self.store = hs.get_datastores().main self.storage_controllers = hs.get_storage_controllers() @@ -368,9 +370,9 @@ class SlidingSyncHandler: set_tag(SynapseTags.FUNC_ARG_PREFIX + "sync_config.user", user_id) end_time_s = self.clock.time() - sync_processing_time.labels(from_token is not None).observe( - end_time_s - start_time_s - ) 
+ sync_processing_time.labels( + initial=from_token is not None, **{SERVER_NAME_LABEL: self.server_name} + ).observe(end_time_s - start_time_s) return sliding_sync_result diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 48f7ba094e..eec420cbb1 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -202,7 +202,7 @@ class SsoHandler: def __init__(self, hs: "HomeServer"): self._clock = hs.get_clock() self._store = hs.get_datastores().main - self._server_name = hs.hostname + self.server_name = hs.hostname self._is_mine_server_name = hs.is_mine_server_name self._registration_handler = hs.get_registration_handler() self._auth_handler = hs.get_auth_handler() @@ -238,7 +238,9 @@ class SsoHandler: p_id = p.idp_id assert p_id not in self._identity_providers self._identity_providers[p_id] = p - init_counters_for_auth_provider(p_id) + init_counters_for_auth_provider( + auth_provider_id=p_id, server_name=self.server_name + ) def get_identity_providers(self) -> Mapping[str, SsoIdentityProvider]: """Get the configured identity providers""" @@ -569,7 +571,7 @@ class SsoHandler: return attributes # Check if this mxid already exists - user_id = UserID(attributes.localpart, self._server_name).to_string() + user_id = UserID(attributes.localpart, self.server_name).to_string() if not await self._store.get_users_by_id_case_insensitive(user_id): # This mxid is free break @@ -907,7 +909,7 @@ class SsoHandler: # render an error page. html = self._bad_user_template.render( - server_name=self._server_name, + server_name=self.server_name, user_id_to_verify=user_id_to_verify, ) respond_with_html(request, 200, html) @@ -959,7 +961,7 @@ class SsoHandler: if contains_invalid_mxid_characters(localpart): raise SynapseError(400, "localpart is invalid: %s" % (localpart,)) - user_id = UserID(localpart, self._server_name).to_string() + user_id = UserID(localpart, self.server_name).to_string() user_infos = await self._store.get_users_by_id_case_insensitive(user_id) logger.info("[session %s] users: %s", session_id, user_infos) diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py index aa33260809..a2602ea818 100644 --- a/synapse/handlers/stats.py +++ b/synapse/handlers/stats.py @@ -32,7 +32,7 @@ from typing import ( ) from synapse.api.constants import EventContentFields, EventTypes, Membership -from synapse.metrics import event_processing_positions +from synapse.metrics import SERVER_NAME_LABEL, event_processing_positions from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.databases.main.state_deltas import StateDelta from synapse.types import JsonDict @@ -54,6 +54,7 @@ class StatsHandler: def __init__(self, hs: "HomeServer"): self.hs = hs + self.server_name = hs.hostname self.store = hs.get_datastores().main self._storage_controllers = hs.get_storage_controllers() self.state = hs.get_state_handler() @@ -89,7 +90,7 @@ class StatsHandler: finally: self._is_processing = False - run_as_background_process("stats.notify_new_event", process) + run_as_background_process("stats.notify_new_event", self.server_name, process) async def _unsafe_process(self) -> None: # If self.pos is None then means we haven't fetched it from DB @@ -146,7 +147,9 @@ class StatsHandler: logger.debug("Handled room stats to %s -> %s", self.pos, max_pos) - event_processing_positions.labels("stats").set(max_pos) + event_processing_positions.labels( + name="stats", **{SERVER_NAME_LABEL: self.server_name} + ).set(max_pos) self.pos = max_pos diff --git 
a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 69064e751a..7bfe4e8760 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -63,6 +63,7 @@ from synapse.logging.opentracing import ( start_active_span, trace, ) +from synapse.metrics import SERVER_NAME_LABEL from synapse.storage.databases.main.event_push_actions import RoomNotifCounts from synapse.storage.databases.main.roommember import extract_heroes_from_room_summary from synapse.storage.databases.main.stream import PaginateFunction @@ -104,7 +105,7 @@ non_empty_sync_counter = Counter( "Count of non empty sync responses. type is initial_sync/full_state_sync" "/incremental_sync. lazy_loaded indicates if lazy loaded members were " "enabled for that request.", - ["type", "lazy_loaded"], + labelnames=["type", "lazy_loaded", SERVER_NAME_LABEL], ) # Store the cache that tracks which lazy-loaded members have been sent to a given @@ -614,7 +615,11 @@ class SyncHandler: lazy_loaded = "true" else: lazy_loaded = "false" - non_empty_sync_counter.labels(sync_label, lazy_loaded).inc() + non_empty_sync_counter.labels( + type=sync_label, + lazy_loaded=lazy_loaded, + **{SERVER_NAME_LABEL: self.server_name}, + ).inc() return result diff --git a/synapse/handlers/thread_subscriptions.py b/synapse/handlers/thread_subscriptions.py index 79e4d6040d..bda4342949 100644 --- a/synapse/handlers/thread_subscriptions.py +++ b/synapse/handlers/thread_subscriptions.py @@ -1,9 +1,15 @@ import logging +from http import HTTPStatus from typing import TYPE_CHECKING, Optional -from synapse.api.errors import AuthError, NotFoundError -from synapse.storage.databases.main.thread_subscriptions import ThreadSubscription -from synapse.types import UserID +from synapse.api.constants import RelationTypes +from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError +from synapse.events import relation_from_event +from synapse.storage.databases.main.thread_subscriptions import ( + AutomaticSubscriptionConflicted, + ThreadSubscription, +) +from synapse.types import EventOrderings, UserID if TYPE_CHECKING: from synapse.server import HomeServer @@ -55,42 +61,79 @@ class ThreadSubscriptionsHandler: room_id: str, thread_root_event_id: str, *, - automatic: bool, + automatic_event_id: Optional[str], ) -> Optional[int]: """Sets or updates a user's subscription settings for a specific thread root. Args: requester_user_id: The ID of the user whose settings are being updated. thread_root_event_id: The event ID of the thread root. - automatic: whether the user was subscribed by an automatic decision by - their client. + automatic_event_id: if the user was subscribed by an automatic decision by + their client, the event ID that caused this. Returns: The stream ID for this update, if the update isn't no-opped. Raises: NotFoundError if the user cannot access the thread root event, or it isn't - known to this homeserver. + known to this homeserver. Ditto for the automatic cause event if supplied. + + SynapseError(400, M_NOT_IN_THREAD): if the client supplied an automatic cause event + that is not an event in the given thread. + + SynapseError(409, M_SKIPPED): if the client requested an automatic subscription + but it was skipped because an unsubscription is logically later than the cause event.
""" # First check that the user can access the thread root event # and that it exists try: - event = await self.event_handler.get_event( + thread_root_event = await self.event_handler.get_event( user_id, room_id, thread_root_event_id ) - if event is None: + if thread_root_event is None: raise NotFoundError("No such thread root") except AuthError: logger.info("rejecting thread subscriptions change (thread not accessible)") raise NotFoundError("No such thread root") - return await self.store.subscribe_user_to_thread( + if automatic_event_id: + autosub_cause_event = await self.event_handler.get_event( + user_id, room_id, automatic_event_id + ) + if autosub_cause_event is None: + raise NotFoundError("Automatic subscription event not found") + relation = relation_from_event(autosub_cause_event) + if ( + relation is None + or relation.rel_type != RelationTypes.THREAD + or relation.parent_id != thread_root_event_id + ): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Automatic subscription must use an event in the thread", + errcode=Codes.MSC4306_NOT_IN_THREAD, + ) + + automatic_event_orderings = EventOrderings.from_event(autosub_cause_event) + else: + automatic_event_orderings = None + + outcome = await self.store.subscribe_user_to_thread( user_id.to_string(), - event.room_id, + room_id, thread_root_event_id, - automatic=automatic, + automatic_event_orderings=automatic_event_orderings, ) + if isinstance(outcome, AutomaticSubscriptionConflicted): + raise SynapseError( + HTTPStatus.CONFLICT, + "Automatic subscription obsoleted by an unsubscription request.", + errcode=Codes.MSC4306_CONFLICTING_UNSUBSCRIPTION, + ) + + return outcome + async def unsubscribe_user_from_thread( self, user_id: UserID, room_id: str, thread_root_event_id: str ) -> Optional[int]: diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 3c49655598..6a7b36ea0c 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -80,7 +80,9 @@ class FollowerTypingHandler: def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main self._storage_controllers = hs.get_storage_controllers() - self.server_name = hs.config.server.server_name + self.server_name = ( + hs.hostname + ) # nb must be called this for @wrap_as_background_process self.clock = hs.get_clock() self.is_mine_id = hs.is_mine_id self.is_mine_server_name = hs.is_mine_server_name @@ -143,7 +145,11 @@ class FollowerTypingHandler: last_fed_poke = self._member_last_federation_poke.get(member, None) if not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL <= now: run_as_background_process( - "typing._push_remote", self._push_remote, member=member, typing=True + "typing._push_remote", + self.server_name, + self._push_remote, + member=member, + typing=True, ) # Add a paranoia timer to ensure that we always have a timer for @@ -216,6 +222,7 @@ class FollowerTypingHandler: if self.federation: run_as_background_process( "_send_changes_in_typing_to_remotes", + self.server_name, self._send_changes_in_typing_to_remotes, row.room_id, prev_typing, @@ -378,7 +385,11 @@ class TypingWriterHandler(FollowerTypingHandler): if self.hs.is_mine_id(member.user_id): # Only send updates for changes to our own users. 
run_as_background_process( - "typing._push_remote", self._push_remote, member, typing + "typing._push_remote", + self.server_name, + self._push_remote, + member, + typing, ) self._push_update_local(member=member, typing=typing) diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index 5f9e96706a..130099a239 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -35,6 +35,7 @@ from synapse.api.constants import ( ) from synapse.api.errors import Codes, SynapseError from synapse.handlers.state_deltas import MatchChange, StateDeltasHandler +from synapse.metrics import SERVER_NAME_LABEL from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.databases.main.state_deltas import StateDelta from synapse.storage.databases.main.user_directory import SearchResult @@ -192,7 +193,9 @@ class UserDirectoryHandler(StateDeltasHandler): self._is_processing = False self._is_processing = True - run_as_background_process("user_directory.notify_new_event", process) + run_as_background_process( + "user_directory.notify_new_event", self.server_name, process + ) async def handle_local_profile_change( self, user_id: str, profile: ProfileInfo @@ -260,9 +263,9 @@ class UserDirectoryHandler(StateDeltasHandler): self.pos = max_pos # Expose current event processing position to prometheus - synapse.metrics.event_processing_positions.labels("user_dir").set( - max_pos - ) + synapse.metrics.event_processing_positions.labels( + name="user_dir", **{SERVER_NAME_LABEL: self.server_name} + ).set(max_pos) await self.store.update_user_directory_stream_pos(max_pos) @@ -606,7 +609,9 @@ class UserDirectoryHandler(StateDeltasHandler): self._is_refreshing_remote_profiles = False self._is_refreshing_remote_profiles = True - run_as_background_process("user_directory.refresh_remote_profiles", process) + run_as_background_process( + "user_directory.refresh_remote_profiles", self.server_name, process + ) async def _unsafe_refresh_remote_profiles(self) -> None: limit = MAX_SERVERS_TO_REFRESH_PROFILES_FOR_IN_ONE_GO - len( @@ -688,7 +693,9 @@ class UserDirectoryHandler(StateDeltasHandler): self._is_refreshing_remote_profiles_for_servers.add(server_name) run_as_background_process( - "user_directory.refresh_remote_profiles_for_remote_server", process + "user_directory.refresh_remote_profiles_for_remote_server", + self.server_name, + process, ) async def _unsafe_refresh_remote_profiles_for_remote_server( diff --git a/synapse/handlers/worker_lock.py b/synapse/handlers/worker_lock.py index 3077e9e463..0b375790dd 100644 --- a/synapse/handlers/worker_lock.py +++ b/synapse/handlers/worker_lock.py @@ -66,6 +66,9 @@ class WorkerLocksHandler: """ def __init__(self, hs: "HomeServer") -> None: + self.server_name = ( + hs.hostname + ) # nb must be called this for @wrap_as_background_process self._reactor = hs.get_reactor() self._store = hs.get_datastores().main self._clock = hs.get_clock() diff --git a/synapse/http/client.py b/synapse/http/client.py index 928bfb228a..1f6d4dcd86 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -85,6 +85,7 @@ from synapse.http.replicationagent import ReplicationAgent from synapse.http.types import QueryParams from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.opentracing import set_tag, start_active_span, tags +from synapse.metrics import SERVER_NAME_LABEL from synapse.types import ISynapseReactor, StrSequence from synapse.util import 
json_decoder from synapse.util.async_helpers import timeout_deferred @@ -108,9 +109,13 @@ except ImportError: logger = logging.getLogger(__name__) -outgoing_requests_counter = Counter("synapse_http_client_requests", "", ["method"]) +outgoing_requests_counter = Counter( + "synapse_http_client_requests", "", labelnames=["method", SERVER_NAME_LABEL] +) incoming_responses_counter = Counter( - "synapse_http_client_responses", "", ["method", "code"] + "synapse_http_client_responses", + "", + labelnames=["method", "code", SERVER_NAME_LABEL], ) # the type of the headers map, to be passed to the t.w.h.Headers. @@ -346,6 +351,7 @@ class BaseHttpClient: treq_args: Optional[Dict[str, Any]] = None, ): self.hs = hs + self.server_name = hs.hostname self.reactor = hs.get_reactor() self._extra_treq_args = treq_args or {} @@ -384,7 +390,9 @@ class BaseHttpClient: RequestTimedOutError if the request times out before the headers are read """ - outgoing_requests_counter.labels(method).inc() + outgoing_requests_counter.labels( + method=method, **{SERVER_NAME_LABEL: self.server_name} + ).inc() # log request but strip `access_token` (AS requests for example include this) logger.debug("Sending request %s %s", method, redact_uri(uri)) @@ -438,7 +446,11 @@ class BaseHttpClient: response = await make_deferred_yieldable(request_deferred) - incoming_responses_counter.labels(method, response.code).inc() + incoming_responses_counter.labels( + method=method, + code=response.code, + **{SERVER_NAME_LABEL: self.server_name}, + ).inc() logger.info( "Received response to %s %s: %s", method, @@ -447,7 +459,11 @@ class BaseHttpClient: ) return response except Exception as e: - incoming_responses_counter.labels(method, "ERR").inc() + incoming_responses_counter.labels( + method=method, + code="ERR", + **{SERVER_NAME_LABEL: self.server_name}, + ).inc() logger.info( "Error sending request to %s %s: %s %s", method, @@ -821,12 +837,12 @@ class SimpleHttpClient(BaseHttpClient): pool.cachedConnectionTimeout = 2 * 60 self.agent: IAgent = ProxyAgent( - self.reactor, - hs.get_reactor(), + reactor=self.reactor, + proxy_reactor=hs.get_reactor(), connectTimeout=15, contextFactory=self.hs.get_http_client_context_factory(), pool=pool, - use_proxy=use_proxy, + proxy_config=hs.config.server.proxy_config, ) if self._ip_blocklist: @@ -855,6 +871,7 @@ class ReplicationClient(BaseHttpClient): hs: The HomeServer instance to pass in """ super().__init__(hs) + self.server_name = hs.hostname # Use a pool, but a very small one. 
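One convention worth noting in the HTTP client changes above: successful responses increment the response counter with their real status code, while any exception increments the same counter with `code="ERR"`, so failures stay visible in the same metric family. A self-contained sketch of that convention (the request function and metric name here are stand-ins, not Synapse's):

    from prometheus_client import Counter

    SERVER_NAME_LABEL = "server_name"  # local stand-in for synapse.metrics.SERVER_NAME_LABEL

    incoming_responses_counter = Counter(
        "example_http_client_responses",
        "Responses by method, status code (or ERR) and homeserver",
        labelnames=["method", "code", SERVER_NAME_LABEL],
    )


    def fake_request(method: str, uri: str) -> int:
        if uri.endswith("/boom"):
            raise ConnectionError("simulated failure")
        return 200


    def request_with_metrics(method: str, uri: str, server_name: str) -> int:
        try:
            code = fake_request(method, uri)
        except Exception:
            # Failures are recorded in the same series family under a sentinel code.
            incoming_responses_counter.labels(
                method=method, code="ERR", **{SERVER_NAME_LABEL: server_name}
            ).inc()
            raise
        incoming_responses_counter.labels(
            method=method, code=code, **{SERVER_NAME_LABEL: server_name}
        ).inc()
        return code


    request_with_metrics("GET", "https://example.com/ok", "example.com")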
pool = HTTPConnectionPool(self.reactor) @@ -891,7 +908,9 @@ class ReplicationClient(BaseHttpClient): RequestTimedOutError if the request times out before the headers are read """ - outgoing_requests_counter.labels(method).inc() + outgoing_requests_counter.labels( + method=method, **{SERVER_NAME_LABEL: self.server_name} + ).inc() logger.debug("Sending request %s %s", method, uri) @@ -948,7 +967,11 @@ class ReplicationClient(BaseHttpClient): response = await make_deferred_yieldable(request_deferred) - incoming_responses_counter.labels(method, response.code).inc() + incoming_responses_counter.labels( + method=method, + code=response.code, + **{SERVER_NAME_LABEL: self.server_name}, + ).inc() logger.info( "Received response to %s %s: %s", method, @@ -957,7 +980,11 @@ class ReplicationClient(BaseHttpClient): ) return response except Exception as e: - incoming_responses_counter.labels(method, "ERR").inc() + incoming_responses_counter.labels( + method=method, + code="ERR", + **{SERVER_NAME_LABEL: self.server_name}, + ).inc() logger.info( "Error sending request to %s %s: %s %s", method, diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index 15609a799f..6ebadf0dbf 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -21,7 +21,6 @@ import logging import urllib.parse from typing import Any, Generator, List, Optional from urllib.request import ( # type: ignore[attr-defined] - getproxies_environment, proxy_bypass_environment, ) @@ -40,6 +39,7 @@ from twisted.web.client import URI, Agent, HTTPConnectionPool from twisted.web.http_headers import Headers from twisted.web.iweb import IAgent, IAgentEndpointFactory, IBodyProducer, IResponse +from synapse.config.server import ProxyConfig from synapse.crypto.context_factory import FederationPolicyForHTTPS from synapse.http import proxyagent from synapse.http.client import BlocklistingAgentWrapper, BlocklistingReactorWrapper @@ -77,6 +77,8 @@ class MatrixFederationAgent: ip_blocklist: Disallowed IP addresses. + proxy_config: Proxy configuration to use for this agent. + proxy_reactor: twisted reactor to use for connections to the proxy server reactor might have some blocking applied (i.e. for DNS queries), but we need unblocked access to the proxy. 
@@ -92,12 +94,14 @@ class MatrixFederationAgent: def __init__( self, + *, server_name: str, reactor: ISynapseReactor, tls_client_options_factory: Optional[FederationPolicyForHTTPS], user_agent: bytes, ip_allowlist: Optional[IPSet], ip_blocklist: IPSet, + proxy_config: Optional[ProxyConfig] = None, _srv_resolver: Optional[SrvResolver] = None, _well_known_resolver: Optional[WellKnownResolver] = None, ): @@ -129,10 +133,11 @@ class MatrixFederationAgent: self._agent = Agent.usingEndpointFactory( reactor, MatrixHostnameEndpointFactory( - reactor, - proxy_reactor, - tls_client_options_factory, - _srv_resolver, + reactor=reactor, + proxy_reactor=proxy_reactor, + tls_client_options_factory=tls_client_options_factory, + srv_resolver=_srv_resolver, + proxy_config=proxy_config, ), pool=self._pool, ) @@ -144,11 +149,11 @@ class MatrixFederationAgent: reactor=reactor, agent=BlocklistingAgentWrapper( ProxyAgent( - reactor, - proxy_reactor, + reactor=reactor, + proxy_reactor=proxy_reactor, pool=self._pool, contextFactory=tls_client_options_factory, - use_proxy=True, + proxy_config=proxy_config, ), ip_blocklist=ip_blocklist, ), @@ -246,14 +251,17 @@ class MatrixHostnameEndpointFactory: def __init__( self, + *, reactor: IReactorCore, proxy_reactor: IReactorCore, tls_client_options_factory: Optional[FederationPolicyForHTTPS], srv_resolver: Optional[SrvResolver], + proxy_config: Optional[ProxyConfig], ): self._reactor = reactor self._proxy_reactor = proxy_reactor self._tls_client_options_factory = tls_client_options_factory + self._proxy_config = proxy_config if srv_resolver is None: srv_resolver = SrvResolver() @@ -262,11 +270,12 @@ class MatrixHostnameEndpointFactory: def endpointForURI(self, parsed_uri: URI) -> "MatrixHostnameEndpoint": return MatrixHostnameEndpoint( - self._reactor, - self._proxy_reactor, - self._tls_client_options_factory, - self._srv_resolver, - parsed_uri, + reactor=self._reactor, + proxy_reactor=self._proxy_reactor, + tls_client_options_factory=self._tls_client_options_factory, + srv_resolver=self._srv_resolver, + proxy_config=self._proxy_config, + parsed_uri=parsed_uri, ) @@ -283,6 +292,7 @@ class MatrixHostnameEndpoint: tls_client_options_factory: factory to use for fetching client tls options, or none to disable TLS. srv_resolver: The SRV resolver to use + proxy_config: Proxy configuration to use for this agent. parsed_uri: The parsed URI that we're wanting to connect to. Raises: @@ -292,26 +302,28 @@ class MatrixHostnameEndpoint: def __init__( self, + *, reactor: IReactorCore, proxy_reactor: IReactorCore, tls_client_options_factory: Optional[FederationPolicyForHTTPS], srv_resolver: SrvResolver, + proxy_config: Optional[ProxyConfig], parsed_uri: URI, ): self._reactor = reactor self._parsed_uri = parsed_uri + self.proxy_config = proxy_config # http_proxy is not needed because federation is always over TLS - proxies = getproxies_environment() - https_proxy = proxies["https"].encode() if "https" in proxies else None - self.no_proxy = proxies["no"] if "no" in proxies else None # endpoint and credentials to use to connect to the outbound https proxy, if any. 
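A recurring mechanical change in these agent and endpoint classes is the bare `*` added to `__init__`, which forces call sites to name `reactor`, `proxy_reactor` and the other same-typed arguments instead of passing them positionally. A tiny illustration of what that enforces:

    class Endpoint:
        def __init__(self, *, reactor: object, proxy_reactor: object) -> None:
            # With the bare `*`, both arguments must be passed by keyword.
            self.reactor = reactor
            self.proxy_reactor = proxy_reactor


    Endpoint(reactor="main reactor", proxy_reactor="proxy reactor")  # fine
    try:
        Endpoint("main reactor", "proxy reactor")  # type: ignore[misc]
    except TypeError as exc:
        print("rejected positional call:", exc)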
( self._https_proxy_endpoint, self._https_proxy_creds, ) = proxyagent.http_proxy_endpoint( - https_proxy, + self.proxy_config.https_proxy.encode() + if self.proxy_config and self.proxy_config.https_proxy + else None, proxy_reactor, tls_client_options_factory, ) @@ -348,10 +360,10 @@ class MatrixHostnameEndpoint: port = server.port should_skip_proxy = False - if self.no_proxy is not None: + if self.proxy_config is not None: should_skip_proxy = proxy_bypass_environment( host.decode(), - proxies={"no": self.no_proxy}, + proxies=self.proxy_config.get_proxies_dictionary(), ) endpoint: IStreamClientEndpoint diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 0013b97723..15f8e147ab 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -87,6 +87,7 @@ from synapse.http.types import QueryParams from synapse.logging import opentracing from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.opentracing import set_tag, start_active_span, tags +from synapse.metrics import SERVER_NAME_LABEL from synapse.types import JsonDict from synapse.util import json_decoder from synapse.util.async_helpers import AwakenableSleeper, Linearizer, timeout_deferred @@ -99,10 +100,14 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) outgoing_requests_counter = Counter( - "synapse_http_matrixfederationclient_requests", "", ["method"] + "synapse_http_matrixfederationclient_requests", + "", + labelnames=["method", SERVER_NAME_LABEL], ) incoming_responses_counter = Counter( - "synapse_http_matrixfederationclient_responses", "", ["method", "code"] + "synapse_http_matrixfederationclient_responses", + "", + labelnames=["method", "code", SERVER_NAME_LABEL], ) @@ -423,6 +428,7 @@ class MatrixFederationHttpClient: user_agent=user_agent.encode("ascii"), ip_allowlist=hs.config.server.federation_ip_range_allowlist, ip_blocklist=hs.config.server.federation_ip_range_blocklist, + proxy_config=hs.config.server.proxy_config, ) else: proxy_authorization_secret = hs.config.worker.worker_replication_secret @@ -437,9 +443,9 @@ class MatrixFederationHttpClient: # locations federation_proxy_locations = outbound_federation_restricted_to.locations federation_agent = ProxyAgent( - self.reactor, - self.reactor, - tls_client_options_factory, + reactor=self.reactor, + proxy_reactor=self.reactor, + contextFactory=tls_client_options_factory, federation_proxy_locations=federation_proxy_locations, federation_proxy_credentials=federation_proxy_credentials, ) @@ -619,9 +625,10 @@ class MatrixFederationHttpClient: raise FederationDeniedError(request.destination) limiter = await synapse.util.retryutils.get_retry_limiter( - request.destination, - self.clock, - self._store, + destination=request.destination, + our_server_name=self.server_name, + clock=self.clock, + store=self._store, backoff_on_404=backoff_on_404, ignore_backoff=ignore_backoff, notifier=self.hs.get_notifier(), @@ -695,7 +702,9 @@ class MatrixFederationHttpClient: _sec_timeout, ) - outgoing_requests_counter.labels(request.method).inc() + outgoing_requests_counter.labels( + method=request.method, **{SERVER_NAME_LABEL: self.server_name} + ).inc() try: with Measure( @@ -734,7 +743,9 @@ class MatrixFederationHttpClient: raise RequestSendFailed(e, can_retry=True) from e incoming_responses_counter.labels( - request.method, response.code + method=request.method, + code=response.code, + **{SERVER_NAME_LABEL: self.server_name}, ).inc() 
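The `no_proxy` handling above still funnels through `urllib.request.proxy_bypass_environment`, which expects a proxies mapping with a `no` entry; `ProxyConfig.get_proxies_dictionary()` is assumed to produce that shape from the configured no-proxy hosts. A quick standalone check of how that helper treats a bypass list:

    from urllib.request import proxy_bypass_environment  # type: ignore[attr-defined]

    # Assumed to match the shape returned by ProxyConfig.get_proxies_dictionary():
    # a mapping whose "no" entry is a comma-separated bypass list.
    proxies = {"no": "localhost,127.0.0.1,.internal.example.com"}

    for host in ("localhost", "db.internal.example.com", "matrix.org"):
        print(host, "bypasses proxy:", bool(proxy_bypass_environment(host, proxies=proxies)))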
set_tag(tags.HTTP_STATUS_CODE, response.code) diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py index 6217f9b0b2..ab413990c5 100644 --- a/synapse/http/proxyagent.py +++ b/synapse/http/proxyagent.py @@ -24,7 +24,6 @@ import re from typing import Any, Collection, Dict, List, Optional, Sequence, Tuple, Union, cast from urllib.parse import urlparse from urllib.request import ( # type: ignore[attr-defined] - getproxies_environment, proxy_bypass_environment, ) @@ -54,6 +53,7 @@ from twisted.web.error import SchemeNotSupported from twisted.web.http_headers import Headers from twisted.web.iweb import IAgent, IBodyProducer, IPolicyForHTTPS, IResponse +from synapse.config.server import ProxyConfig from synapse.config.workers import ( InstanceLocationConfig, InstanceTcpLocationConfig, @@ -99,8 +99,7 @@ class ProxyAgent(_AgentBase): pool: connection pool to be used. If None, a non-persistent pool instance will be created. - use_proxy: Whether proxy settings should be discovered and used - from conventional environment variables. + proxy_config: Proxy configuration to use for this agent. federation_proxy_locations: An optional list of locations to proxy outbound federation traffic through (only requests that use the `matrix-federation://` scheme @@ -118,13 +117,14 @@ class ProxyAgent(_AgentBase): def __init__( self, + *, reactor: IReactorCore, proxy_reactor: Optional[IReactorCore] = None, contextFactory: Optional[IPolicyForHTTPS] = None, connectTimeout: Optional[float] = None, bindAddress: Optional[bytes] = None, pool: Optional[HTTPConnectionPool] = None, - use_proxy: bool = False, + proxy_config: Optional[ProxyConfig] = None, federation_proxy_locations: Collection[InstanceLocationConfig] = (), federation_proxy_credentials: Optional[ProxyCredentials] = None, ): @@ -145,31 +145,33 @@ class ProxyAgent(_AgentBase): if bindAddress is not None: self._endpoint_kwargs["bindAddress"] = bindAddress - http_proxy = None - https_proxy = None - no_proxy = None - if use_proxy: - proxies = getproxies_environment() - http_proxy = proxies["http"].encode() if "http" in proxies else None - https_proxy = proxies["https"].encode() if "https" in proxies else None - no_proxy = proxies["no"] if "no" in proxies else None + self.proxy_config = proxy_config + if self.proxy_config is not None: logger.debug( "Using proxy settings: http_proxy=%s, https_proxy=%s, no_proxy=%s", - http_proxy, - https_proxy, - no_proxy, + self.proxy_config.http_proxy, + self.proxy_config.https_proxy, + self.proxy_config.no_proxy_hosts, ) self.http_proxy_endpoint, self.http_proxy_creds = http_proxy_endpoint( - http_proxy, self.proxy_reactor, contextFactory, **self._endpoint_kwargs + self.proxy_config.http_proxy.encode() + if self.proxy_config and self.proxy_config.http_proxy + else None, + self.proxy_reactor, + contextFactory, + **self._endpoint_kwargs, ) self.https_proxy_endpoint, self.https_proxy_creds = http_proxy_endpoint( - https_proxy, self.proxy_reactor, contextFactory, **self._endpoint_kwargs + self.proxy_config.https_proxy.encode() + if self.proxy_config and self.proxy_config.https_proxy + else None, + self.proxy_reactor, + contextFactory, + **self._endpoint_kwargs, ) - self.no_proxy = no_proxy - self._policy_for_https = contextFactory self._reactor = cast(IReactorTime, reactor) @@ -268,10 +270,10 @@ class ProxyAgent(_AgentBase): request_path = parsed_uri.originForm should_skip_proxy = False - if self.no_proxy is not None: + if self.proxy_config is not None: should_skip_proxy = proxy_bypass_environment( 
parsed_uri.host.decode(), - proxies={"no": self.no_proxy}, + proxies=self.proxy_config.get_proxies_dictionary(), ) if ( diff --git a/synapse/http/request_metrics.py b/synapse/http/request_metrics.py index 366f06eb80..a9b049f904 100644 --- a/synapse/http/request_metrics.py +++ b/synapse/http/request_metrics.py @@ -27,40 +27,52 @@ from typing import Dict, Mapping, Set, Tuple from prometheus_client.core import Counter, Histogram from synapse.logging.context import current_context -from synapse.metrics import LaterGauge +from synapse.metrics import SERVER_NAME_LABEL, LaterGauge logger = logging.getLogger(__name__) # total number of responses served, split by method/servlet/tag response_count = Counter( - "synapse_http_server_response_count", "", ["method", "servlet", "tag"] + "synapse_http_server_response_count", + "", + labelnames=["method", "servlet", "tag", SERVER_NAME_LABEL], ) requests_counter = Counter( - "synapse_http_server_requests_received", "", ["method", "servlet"] + "synapse_http_server_requests_received", + "", + labelnames=["method", "servlet", SERVER_NAME_LABEL], ) outgoing_responses_counter = Counter( - "synapse_http_server_responses", "", ["method", "code"] + "synapse_http_server_responses", + "", + labelnames=["method", "code", SERVER_NAME_LABEL], ) response_timer = Histogram( "synapse_http_server_response_time_seconds", "sec", - ["method", "servlet", "tag", "code"], + labelnames=["method", "servlet", "tag", "code", SERVER_NAME_LABEL], ) response_ru_utime = Counter( - "synapse_http_server_response_ru_utime_seconds", "sec", ["method", "servlet", "tag"] + "synapse_http_server_response_ru_utime_seconds", + "sec", + labelnames=["method", "servlet", "tag", SERVER_NAME_LABEL], ) response_ru_stime = Counter( - "synapse_http_server_response_ru_stime_seconds", "sec", ["method", "servlet", "tag"] + "synapse_http_server_response_ru_stime_seconds", + "sec", + labelnames=["method", "servlet", "tag", SERVER_NAME_LABEL], ) response_db_txn_count = Counter( - "synapse_http_server_response_db_txn_count", "", ["method", "servlet", "tag"] + "synapse_http_server_response_db_txn_count", + "", + labelnames=["method", "servlet", "tag", SERVER_NAME_LABEL], ) # seconds spent waiting for db txns, excluding scheduling time, when processing @@ -68,34 +80,42 @@ response_db_txn_count = Counter( response_db_txn_duration = Counter( "synapse_http_server_response_db_txn_duration_seconds", "", - ["method", "servlet", "tag"], + labelnames=["method", "servlet", "tag", SERVER_NAME_LABEL], ) # seconds spent waiting for a db connection, when processing this request response_db_sched_duration = Counter( "synapse_http_server_response_db_sched_duration_seconds", "", - ["method", "servlet", "tag"], + labelnames=["method", "servlet", "tag", SERVER_NAME_LABEL], ) # size in bytes of the response written response_size = Counter( - "synapse_http_server_response_size", "", ["method", "servlet", "tag"] + "synapse_http_server_response_size", + "", + labelnames=["method", "servlet", "tag", SERVER_NAME_LABEL], ) # In flight metrics are incremented while the requests are in flight, rather # than when the response was written. 
in_flight_requests_ru_utime = Counter( - "synapse_http_server_in_flight_requests_ru_utime_seconds", "", ["method", "servlet"] + "synapse_http_server_in_flight_requests_ru_utime_seconds", + "", + labelnames=["method", "servlet", SERVER_NAME_LABEL], ) in_flight_requests_ru_stime = Counter( - "synapse_http_server_in_flight_requests_ru_stime_seconds", "", ["method", "servlet"] + "synapse_http_server_in_flight_requests_ru_stime_seconds", + "", + labelnames=["method", "servlet", SERVER_NAME_LABEL], ) in_flight_requests_db_txn_count = Counter( - "synapse_http_server_in_flight_requests_db_txn_count", "", ["method", "servlet"] + "synapse_http_server_in_flight_requests_db_txn_count", + "", + labelnames=["method", "servlet", SERVER_NAME_LABEL], ) # seconds spent waiting for db txns, excluding scheduling time, when processing @@ -103,14 +123,14 @@ in_flight_requests_db_txn_count = Counter( in_flight_requests_db_txn_duration = Counter( "synapse_http_server_in_flight_requests_db_txn_duration_seconds", "", - ["method", "servlet"], + labelnames=["method", "servlet", SERVER_NAME_LABEL], ) # seconds spent waiting for a db connection, when processing this request in_flight_requests_db_sched_duration = Counter( "synapse_http_server_in_flight_requests_db_sched_duration_seconds", "", - ["method", "servlet"], + labelnames=["method", "servlet", SERVER_NAME_LABEL], ) _in_flight_requests: Set["RequestMetrics"] = set() @@ -124,31 +144,42 @@ def _get_in_flight_counts() -> Mapping[Tuple[str, ...], int]: # Cast to a list to prevent it changing while the Prometheus # thread is collecting metrics with _in_flight_requests_lock: - reqs = list(_in_flight_requests) + request_metrics = list(_in_flight_requests) - for rm in reqs: - rm.update_metrics() + for request_metric in request_metrics: + request_metric.update_metrics() # Map from (method, name) -> int, the number of in flight requests of that # type. The key type is Tuple[str, str], but we leave the length unspecified # for compatability with LaterGauge's annotations. 
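The in-flight request count is exported just below through a `LaterGauge`, i.e. a collector whose `collect()` invokes a callback at scrape time and turns the returned mapping of label tuples into gauge samples. A minimal collector along the same lines, built only on `prometheus_client` primitives (all names here are illustrative):

    from typing import Dict, Iterable, Tuple

    from prometheus_client import REGISTRY
    from prometheus_client.core import GaugeMetricFamily


    class CallbackGauge:
        """Gauge whose samples come from a callback invoked at scrape time."""

        def __init__(self, name: str, documentation: str, labelnames, caller) -> None:
            self._name = name
            self._documentation = documentation
            self._labelnames = list(labelnames)
            self._caller = caller

        def collect(self) -> Iterable[GaugeMetricFamily]:
            family = GaugeMetricFamily(
                self._name, self._documentation, labels=self._labelnames
            )
            # The callback returns {label-value tuple: gauge value}.
            for label_values, value in self._caller().items():
                family.add_metric(list(label_values), value)
            yield family


    def _get_in_flight_counts() -> Dict[Tuple[str, str, str], int]:
        # The real callback walks the set of in-flight RequestMetrics objects;
        # fixed sample data stands in for that here.
        return {
            ("GET", "sync", "example.com"): 3,
            ("PUT", "send_event", "example.com"): 1,
        }


    REGISTRY.register(
        CallbackGauge(
            name="example_in_flight_requests",
            documentation="In-flight requests by method, servlet and homeserver",
            labelnames=["method", "servlet", "server_name"],
            caller=_get_in_flight_counts,
        )
    )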
counts: Dict[Tuple[str, ...], int] = {} - for rm in reqs: - key = (rm.method, rm.name) + for request_metric in request_metrics: + key = ( + request_metric.method, + request_metric.name, + request_metric.our_server_name, + ) counts[key] = counts.get(key, 0) + 1 return counts LaterGauge( - "synapse_http_server_in_flight_requests_count", - "", - ["method", "servlet"], - _get_in_flight_counts, + name="synapse_http_server_in_flight_requests_count", + desc="", + labelnames=["method", "servlet", SERVER_NAME_LABEL], + caller=_get_in_flight_counts, ) class RequestMetrics: + def __init__(self, our_server_name: str) -> None: + """ + Args: + our_server_name: Our homeserver name (used to label metrics) (`hs.hostname`) + """ + self.our_server_name = our_server_name + def start(self, time_sec: float, name: str, method: str) -> None: self.start_ts = time_sec self.start_context = current_context() @@ -194,33 +225,40 @@ class RequestMetrics: response_code_str = str(response_code) - outgoing_responses_counter.labels(self.method, response_code_str).inc() + outgoing_responses_counter.labels( + method=self.method, + code=response_code_str, + **{SERVER_NAME_LABEL: self.our_server_name}, + ).inc() - response_count.labels(self.method, self.name, tag).inc() + response_base_labels = { + "method": self.method, + "servlet": self.name, + "tag": tag, + SERVER_NAME_LABEL: self.our_server_name, + } - response_timer.labels(self.method, self.name, tag, response_code_str).observe( - time_sec - self.start_ts - ) + response_count.labels(**response_base_labels).inc() + + response_timer.labels( + code=response_code_str, + **response_base_labels, + ).observe(time_sec - self.start_ts) resource_usage = context.get_resource_usage() - response_ru_utime.labels(self.method, self.name, tag).inc( - resource_usage.ru_utime - ) - response_ru_stime.labels(self.method, self.name, tag).inc( - resource_usage.ru_stime - ) - response_db_txn_count.labels(self.method, self.name, tag).inc( + response_ru_utime.labels(**response_base_labels).inc(resource_usage.ru_utime) + response_ru_stime.labels(**response_base_labels).inc(resource_usage.ru_stime) + response_db_txn_count.labels(**response_base_labels).inc( resource_usage.db_txn_count ) - response_db_txn_duration.labels(self.method, self.name, tag).inc( + response_db_txn_duration.labels(**response_base_labels).inc( resource_usage.db_txn_duration_sec ) - response_db_sched_duration.labels(self.method, self.name, tag).inc( + response_db_sched_duration.labels(**response_base_labels).inc( resource_usage.db_sched_duration_sec ) - - response_size.labels(self.method, self.name, tag).inc(sent_bytes) + response_size.labels(**response_base_labels).inc(sent_bytes) # We always call this at the end to ensure that we update the metrics # regardless of whether a call to /metrics while the request was in @@ -240,24 +278,30 @@ class RequestMetrics: diff = new_stats - self._request_stats self._request_stats = new_stats + in_flight_labels = { + "method": self.method, + "servlet": self.name, + SERVER_NAME_LABEL: self.our_server_name, + } + # max() is used since rapid use of ru_stime/ru_utime can end up with the # count going backwards due to NTP, time smearing, fine-grained # correction, or floating points. Who knows, really? 
- in_flight_requests_ru_utime.labels(self.method, self.name).inc( + in_flight_requests_ru_utime.labels(**in_flight_labels).inc( max(diff.ru_utime, 0) ) - in_flight_requests_ru_stime.labels(self.method, self.name).inc( + in_flight_requests_ru_stime.labels(**in_flight_labels).inc( max(diff.ru_stime, 0) ) - in_flight_requests_db_txn_count.labels(self.method, self.name).inc( + in_flight_requests_db_txn_count.labels(**in_flight_labels).inc( diff.db_txn_count ) - in_flight_requests_db_txn_duration.labels(self.method, self.name).inc( + in_flight_requests_db_txn_duration.labels(**in_flight_labels).inc( diff.db_txn_duration_sec ) - in_flight_requests_db_sched_duration.labels(self.method, self.name).inc( + in_flight_requests_db_sched_duration.labels(**in_flight_labels).inc( diff.db_sched_duration_sec ) diff --git a/synapse/http/server.py b/synapse/http/server.py index 395d82fd16..f8f58ec6d0 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -337,7 +337,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta): callback_return = await self._async_render(request) except LimitExceededError as e: if e.pause: - self._clock.sleep(e.pause) + await self._clock.sleep(e.pause) raise if callback_return is not None: diff --git a/synapse/http/site.py b/synapse/http/site.py index e83a4447b2..55088fc190 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -44,6 +44,7 @@ from synapse.logging.context import ( LoggingContext, PreserveLoggingContext, ) +from synapse.metrics import SERVER_NAME_LABEL from synapse.types import ISynapseReactor, Requester if TYPE_CHECKING: @@ -83,12 +84,14 @@ class SynapseRequest(Request): self, channel: HTTPChannel, site: "SynapseSite", + our_server_name: str, *args: Any, max_request_body_size: int = 1024, request_id_header: Optional[str] = None, **kw: Any, ): super().__init__(channel, *args, **kw) + self.our_server_name = our_server_name self._max_request_body_size = max_request_body_size self.request_id_header = request_id_header self.synapse_site = site @@ -334,7 +337,11 @@ class SynapseRequest(Request): # dispatching to the handler, so that the handler # can update the servlet name in the request # metrics - requests_counter.labels(self.get_method(), self.request_metrics.name).inc() + requests_counter.labels( + method=self.get_method(), + servlet=self.request_metrics.name, + **{SERVER_NAME_LABEL: self.our_server_name}, + ).inc() @contextlib.contextmanager def processing(self) -> Generator[None, None, None]: @@ -455,7 +462,7 @@ class SynapseRequest(Request): self.request_metrics.name. 
""" self.start_time = time.time() - self.request_metrics = RequestMetrics() + self.request_metrics = RequestMetrics(our_server_name=self.our_server_name) self.request_metrics.start( self.start_time, name=servlet_name, method=self.get_method() ) @@ -694,6 +701,7 @@ class SynapseSite(ProxySite): self.site_tag = site_tag self.reactor: ISynapseReactor = reactor + self.server_name = hs.hostname assert config.http_options is not None proxied = config.http_options.x_forwarded @@ -705,6 +713,7 @@ class SynapseSite(ProxySite): return request_class( channel, self, + our_server_name=self.server_name, max_request_body_size=max_request_body_size, queued=queued, request_id_header=request_id_header, diff --git a/synapse/logging/loggers.py b/synapse/logging/loggers.py new file mode 100644 index 0000000000..7f7bfef5d4 --- /dev/null +++ b/synapse/logging/loggers.py @@ -0,0 +1,25 @@ +import logging + +root_logger = logging.getLogger() + + +class ExplicitlyConfiguredLogger(logging.Logger): + """ + A custom logger class that only allows logging if the logger is explicitly + configured (does not inherit log level from parent). + """ + + def isEnabledFor(self, level: int) -> bool: + # Check if the logger is explicitly configured + explicitly_configured_logger = self.manager.loggerDict.get(self.name) + + log_level = logging.NOTSET + if isinstance(explicitly_configured_logger, logging.Logger): + log_level = explicitly_configured_logger.level + + # If the logger is not configured, we don't log anything + if log_level == logging.NOTSET: + return False + + # Otherwise, follow the normal logging behavior + return level >= log_level diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py index d7259176e7..aae88d25c9 100644 --- a/synapse/media/media_repository.py +++ b/synapse/media/media_repository.py @@ -186,12 +186,16 @@ class MediaRepository: def _start_update_recently_accessed(self) -> Deferred: return run_as_background_process( - "update_recently_accessed_media", self._update_recently_accessed + "update_recently_accessed_media", + self.server_name, + self._update_recently_accessed, ) def _start_apply_media_retention_rules(self) -> Deferred: return run_as_background_process( - "apply_media_retention_rules", self._apply_media_retention_rules + "apply_media_retention_rules", + self.server_name, + self._apply_media_retention_rules, ) async def _update_recently_accessed(self) -> None: diff --git a/synapse/media/url_previewer.py b/synapse/media/url_previewer.py index eb0104e543..8f106a3d5f 100644 --- a/synapse/media/url_previewer.py +++ b/synapse/media/url_previewer.py @@ -740,7 +740,7 @@ class UrlPreviewer: def _start_expire_url_cache_data(self) -> Deferred: return run_as_background_process( - "expire_url_cache_data", self._expire_url_cache_data + "expire_url_cache_data", self.server_name, self._expire_url_cache_data ) async def _expire_url_cache_data(self) -> None: diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py index de750a5de2..11e2551a16 100644 --- a/synapse/metrics/__init__.py +++ b/synapse/metrics/__init__.py @@ -33,6 +33,7 @@ from typing import ( Iterable, Mapping, Optional, + Sequence, Set, Tuple, Type, @@ -91,6 +92,7 @@ terms, an endpoint you can scrape is called an *instance*, usually corresponding single process." (source: https://prometheus.io/docs/concepts/jobs_instances/) """ + CONTENT_TYPE_LATEST = "text/plain; version=0.0.4; charset=utf-8" """ Content type of the latest text format for Prometheus metrics. 
@@ -154,13 +156,13 @@ class _RegistryProxy: RegistryProxy = cast(CollectorRegistry, _RegistryProxy) -@attr.s(slots=True, hash=True, auto_attribs=True) +@attr.s(slots=True, hash=True, auto_attribs=True, kw_only=True) class LaterGauge(Collector): """A Gauge which periodically calls a user-provided callback to produce metrics.""" name: str desc: str - labels: Optional[StrSequence] = attr.ib(hash=False) + labelnames: Optional[StrSequence] = attr.ib(hash=False) # callback: should either return a value (if there are no labels for this metric), # or dict mapping from a label tuple to a value caller: Callable[ @@ -168,7 +170,9 @@ class LaterGauge(Collector): ] def collect(self) -> Iterable[Metric]: - g = GaugeMetricFamily(self.name, self.desc, labels=self.labels) + # The decision to add `SERVER_NAME_LABEL` is from the `LaterGauge` usage itself + # (we don't enforce it here, one level up). + g = GaugeMetricFamily(self.name, self.desc, labels=self.labelnames) # type: ignore[missing-server-name-label] try: calls = self.caller() @@ -302,7 +306,9 @@ class InFlightGauge(Generic[MetricsEntry], Collector): Note: may be called by a separate thread. """ - in_flight = GaugeMetricFamily( + # The decision to add `SERVER_NAME_LABEL` is from the `GaugeBucketCollector` + # usage itself (we don't enforce it here, one level up). + in_flight = GaugeMetricFamily( # type: ignore[missing-server-name-label] self.name + "_total", self.desc, labels=self.labels ) @@ -326,7 +332,9 @@ class InFlightGauge(Generic[MetricsEntry], Collector): yield in_flight for name in self.sub_metrics: - gauge = GaugeMetricFamily( + # The decision to add `SERVER_NAME_LABEL` is from the `InFlightGauge` usage + # itself (we don't enforce it here, one level up). + gauge = GaugeMetricFamily( # type: ignore[missing-server-name-label] "_".join([self.name, name]), "", labels=self.labels ) for key, metrics in metrics_by_key.items(): @@ -342,6 +350,51 @@ class InFlightGauge(Generic[MetricsEntry], Collector): all_gauges[self.name] = self +class GaugeHistogramMetricFamilyWithLabels(GaugeHistogramMetricFamily): + """ + Custom version of `GaugeHistogramMetricFamily` from `prometheus_client` that allows + specifying labels and label values. + + A single gauge histogram and its samples. + + For use by custom collectors. + """ + + def __init__( + self, + *, + name: str, + documentation: str, + gsum_value: float, + buckets: Optional[Sequence[Tuple[str, float]]] = None, + labelnames: StrSequence = (), + labelvalues: StrSequence = (), + unit: str = "", + ): + # Sanity check the number of label values matches the number of label names. + if len(labelvalues) != len(labelnames): + raise ValueError( + "The number of label values must match the number of label names" + ) + + # Call the super to validate and set the labelnames. We use this stable API + # instead of setting the internal `_labelnames` field directly. + super().__init__( + name=name, + documentation=documentation, + labels=labelnames, + # Since `GaugeHistogramMetricFamily` doesn't support supplying `labels` and + # `buckets` at the same time (artificial limitation), we will just set these + # as `None` and set up the buckets ourselves just below. + buckets=None, + gsum_value=None, + ) + + # Create a gauge for each bucket. + if buckets is not None: + self.add_metric(labels=labelvalues, buckets=buckets, gsum_value=gsum_value) + + class GaugeBucketCollector(Collector): """Like a Histogram, but the buckets are Gauges which are updated atomically. 
@@ -354,14 +407,17 @@ class GaugeBucketCollector(Collector): __slots__ = ( "_name", "_documentation", + "_labelnames", "_bucket_bounds", "_metric", ) def __init__( self, + *, name: str, documentation: str, + labelnames: Optional[StrSequence], buckets: Iterable[float], registry: CollectorRegistry = REGISTRY, ): @@ -375,6 +431,7 @@ class GaugeBucketCollector(Collector): """ self._name = name self._documentation = documentation + self._labelnames = labelnames if labelnames else () # the tops of the buckets self._bucket_bounds = [float(b) for b in buckets] @@ -386,7 +443,7 @@ class GaugeBucketCollector(Collector): # We initially set this to None. We won't report metrics until # this has been initialised after a successful data update - self._metric: Optional[GaugeHistogramMetricFamily] = None + self._metric: Optional[GaugeHistogramMetricFamilyWithLabels] = None registry.register(self) @@ -395,15 +452,26 @@ class GaugeBucketCollector(Collector): if self._metric is not None: yield self._metric - def update_data(self, values: Iterable[float]) -> None: + def update_data(self, values: Iterable[float], labels: StrSequence = ()) -> None: """Update the data to be reported by the metric The existing data is cleared, and each measurement in the input is assigned to the relevant bucket. - """ - self._metric = self._values_to_metric(values) - def _values_to_metric(self, values: Iterable[float]) -> GaugeHistogramMetricFamily: + Args: + values + labels + """ + self._metric = self._values_to_metric(values, labels) + + def _values_to_metric( + self, values: Iterable[float], labels: StrSequence = () + ) -> GaugeHistogramMetricFamilyWithLabels: + """ + Args: + values + labels + """ total = 0.0 bucket_values = [0 for _ in self._bucket_bounds] @@ -421,9 +489,13 @@ class GaugeBucketCollector(Collector): # that bucket or below. accumulated_values = itertools.accumulate(bucket_values) - return GaugeHistogramMetricFamily( - self._name, - self._documentation, + # The decision to add `SERVER_NAME_LABEL` is from the `GaugeBucketCollector` + # usage itself (we don't enforce it here, one level up). + return GaugeHistogramMetricFamilyWithLabels( # type: ignore[missing-server-name-label] + name=self._name, + documentation=self._documentation, + labelnames=self._labelnames, + labelvalues=labels, buckets=list( zip((str(b) for b in self._bucket_bounds), accumulated_values) ), @@ -455,61 +527,82 @@ class CPUMetrics(Collector): line = s.read() raw_stats = line.split(") ", 1)[1].split(" ") - user = GaugeMetricFamily("process_cpu_user_seconds_total", "") + # This is a process-level metric, so it does not have the `SERVER_NAME_LABEL`. + user = GaugeMetricFamily("process_cpu_user_seconds_total", "") # type: ignore[missing-server-name-label] user.add_metric([], float(raw_stats[11]) / self.ticks_per_sec) yield user - sys = GaugeMetricFamily("process_cpu_system_seconds_total", "") + # This is a process-level metric, so it does not have the `SERVER_NAME_LABEL`. + sys = GaugeMetricFamily("process_cpu_system_seconds_total", "") # type: ignore[missing-server-name-label] sys.add_metric([], float(raw_stats[12]) / self.ticks_per_sec) yield sys -REGISTRY.register(CPUMetrics()) +# This is a process-level metric, so it does not have the `SERVER_NAME_LABEL`. 
+REGISTRY.register(CPUMetrics()) # type: ignore[missing-server-name-label] # # Federation Metrics # -sent_transactions_counter = Counter("synapse_federation_client_sent_transactions", "") +sent_transactions_counter = Counter( + "synapse_federation_client_sent_transactions", "", labelnames=[SERVER_NAME_LABEL] +) -events_processed_counter = Counter("synapse_federation_client_events_processed", "") +events_processed_counter = Counter( + "synapse_federation_client_events_processed", "", labelnames=[SERVER_NAME_LABEL] +) event_processing_loop_counter = Counter( - "synapse_event_processing_loop_count", "Event processing loop iterations", ["name"] + "synapse_event_processing_loop_count", + "Event processing loop iterations", + labelnames=["name", SERVER_NAME_LABEL], ) event_processing_loop_room_count = Counter( "synapse_event_processing_loop_room_count", "Rooms seen per event processing loop iteration", - ["name"], + labelnames=["name", SERVER_NAME_LABEL], ) # Used to track where various components have processed in the event stream, # e.g. federation sending, appservice sending, etc. -event_processing_positions = Gauge("synapse_event_processing_positions", "", ["name"]) +event_processing_positions = Gauge( + "synapse_event_processing_positions", "", labelnames=["name", SERVER_NAME_LABEL] +) # Used to track the current max events stream position -event_persisted_position = Gauge("synapse_event_persisted_position", "") +event_persisted_position = Gauge( + "synapse_event_persisted_position", "", labelnames=[SERVER_NAME_LABEL] +) # Used to track the received_ts of the last event processed by various # components -event_processing_last_ts = Gauge("synapse_event_processing_last_ts", "", ["name"]) +event_processing_last_ts = Gauge( + "synapse_event_processing_last_ts", "", labelnames=["name", SERVER_NAME_LABEL] +) # Used to track the lag processing events. This is the time difference # between the last processed event's received_ts and the time it was # finished being processed. -event_processing_lag = Gauge("synapse_event_processing_lag", "", ["name"]) +event_processing_lag = Gauge( + "synapse_event_processing_lag", "", labelnames=["name", SERVER_NAME_LABEL] +) event_processing_lag_by_event = Histogram( "synapse_event_processing_lag_by_event", "Time between an event being persisted and it being queued up to be sent to the relevant remote servers", - ["name"], + labelnames=["name", SERVER_NAME_LABEL], ) # Build info of the running server. -build_info = Gauge( +# +# This is a process-level metric, so it does not have the `SERVER_NAME_LABEL`. We +# consider this process-level because all Synapse homeservers running in the process +# will use the same Synapse version. 
+build_info = Gauge( # type: ignore[missing-server-name-label] "synapse_build_info", "Build information", ["pythonversion", "version", "osversion"] ) build_info.labels( @@ -525,44 +618,57 @@ threepid_send_requests = Histogram( " there is a request with try count of 4, then there would have been one" " each for 1, 2 and 3", buckets=(1, 2, 3, 4, 5, 10), - labelnames=("type", "reason"), + labelnames=("type", "reason", SERVER_NAME_LABEL), ) threadpool_total_threads = Gauge( "synapse_threadpool_total_threads", "Total number of threads currently in the threadpool", - ["name"], + labelnames=["name", SERVER_NAME_LABEL], ) threadpool_total_working_threads = Gauge( "synapse_threadpool_working_threads", "Number of threads currently working in the threadpool", - ["name"], + labelnames=["name", SERVER_NAME_LABEL], ) threadpool_total_min_threads = Gauge( "synapse_threadpool_min_threads", "Minimum number of threads configured in the threadpool", - ["name"], + labelnames=["name", SERVER_NAME_LABEL], ) threadpool_total_max_threads = Gauge( "synapse_threadpool_max_threads", "Maximum number of threads configured in the threadpool", - ["name"], + labelnames=["name", SERVER_NAME_LABEL], ) -def register_threadpool(name: str, threadpool: ThreadPool) -> None: - """Add metrics for the threadpool.""" +def register_threadpool(*, name: str, server_name: str, threadpool: ThreadPool) -> None: + """ + Add metrics for the threadpool. - threadpool_total_min_threads.labels(name).set(threadpool.min) - threadpool_total_max_threads.labels(name).set(threadpool.max) + Args: + name: The name of the threadpool, used to identify it in the metrics. + server_name: The homeserver name (used to label metrics) (this should be `hs.hostname`). + threadpool: The threadpool to register metrics for. + """ - threadpool_total_threads.labels(name).set_function(lambda: len(threadpool.threads)) - threadpool_total_working_threads.labels(name).set_function( - lambda: len(threadpool.working) - ) + threadpool_total_min_threads.labels( + name=name, **{SERVER_NAME_LABEL: server_name} + ).set(threadpool.min) + threadpool_total_max_threads.labels( + name=name, **{SERVER_NAME_LABEL: server_name} + ).set(threadpool.max) + + threadpool_total_threads.labels( + name=name, **{SERVER_NAME_LABEL: server_name} + ).set_function(lambda: len(threadpool.threads)) + threadpool_total_working_threads.labels( + name=name, **{SERVER_NAME_LABEL: server_name} + ).set_function(lambda: len(threadpool.working)) class MetricsResource(Resource): diff --git a/synapse/metrics/_gc.py b/synapse/metrics/_gc.py index d16481a0f6..e7783b05e6 100644 --- a/synapse/metrics/_gc.py +++ b/synapse/metrics/_gc.py @@ -54,8 +54,9 @@ running_on_pypy = platform.python_implementation() == "PyPy" # Python GC metrics # -gc_unreachable = Gauge("python_gc_unreachable_total", "Unreachable GC objects", ["gen"]) -gc_time = Histogram( +# These are process-level metrics, so they do not have the `SERVER_NAME_LABEL`. +gc_unreachable = Gauge("python_gc_unreachable_total", "Unreachable GC objects", ["gen"]) # type: ignore[missing-server-name-label] +gc_time = Histogram( # type: ignore[missing-server-name-label] "python_gc_time", "Time taken to GC (sec)", ["gen"], @@ -82,7 +83,8 @@ gc_time = Histogram( class GCCounts(Collector): def collect(self) -> Iterable[Metric]: - cm = GaugeMetricFamily("python_gc_counts", "GC object counts", labels=["gen"]) + # This is a process-level metric, so it does not have the `SERVER_NAME_LABEL`. 
+ cm = GaugeMetricFamily("python_gc_counts", "GC object counts", labels=["gen"]) # type: ignore[missing-server-name-label] for n, m in enumerate(gc.get_count()): cm.add_metric([str(n)], m) @@ -101,7 +103,8 @@ def install_gc_manager() -> None: if running_on_pypy: return - REGISTRY.register(GCCounts()) + # This is a process-level metric, so it does not have the `SERVER_NAME_LABEL`. + REGISTRY.register(GCCounts()) # type: ignore[missing-server-name-label] gc.disable() @@ -176,7 +179,8 @@ class PyPyGCStats(Collector): # # Total time spent in GC: 0.073 # s.total_gc_time - pypy_gc_time = CounterMetricFamily( + # This is a process-level metric, so it does not have the `SERVER_NAME_LABEL`. + pypy_gc_time = CounterMetricFamily( # type: ignore[missing-server-name-label] "pypy_gc_time_seconds_total", "Total time spent in PyPy GC", labels=[], @@ -184,7 +188,8 @@ class PyPyGCStats(Collector): pypy_gc_time.add_metric([], s.total_gc_time / 1000) yield pypy_gc_time - pypy_mem = GaugeMetricFamily( + # This is a process-level metric, so it does not have the `SERVER_NAME_LABEL`. + pypy_mem = GaugeMetricFamily( # type: ignore[missing-server-name-label] "pypy_memory_bytes", "Memory tracked by PyPy allocator", labels=["state", "class", "kind"], @@ -208,4 +213,5 @@ class PyPyGCStats(Collector): if running_on_pypy: - REGISTRY.register(PyPyGCStats()) + # This is a process-level metric, so it does not have the `SERVER_NAME_LABEL`. + REGISTRY.register(PyPyGCStats()) # type: ignore[missing-server-name-label] diff --git a/synapse/metrics/_reactor_metrics.py b/synapse/metrics/_reactor_metrics.py index fda0cd018b..9852d0b932 100644 --- a/synapse/metrics/_reactor_metrics.py +++ b/synapse/metrics/_reactor_metrics.py @@ -62,7 +62,8 @@ logger = logging.getLogger(__name__) # Twisted reactor metrics # -tick_time = Histogram( +# This is a process-level metric, so it does not have the `SERVER_NAME_LABEL`. +tick_time = Histogram( # type: ignore[missing-server-name-label] "python_twisted_reactor_tick_time", "Tick time of the Twisted reactor (sec)", buckets=[0.001, 0.002, 0.005, 0.01, 0.025, 0.05, 0.1, 0.2, 0.5, 1, 2, 5], @@ -114,7 +115,8 @@ class ReactorLastSeenMetric(Collector): self._call_wrapper = call_wrapper def collect(self) -> Iterable[Metric]: - cm = GaugeMetricFamily( + # This is a process-level metric, so it does not have the `SERVER_NAME_LABEL`. + cm = GaugeMetricFamily( # type: ignore[missing-server-name-label] "python_twisted_reactor_last_seen", "Seconds since the Twisted reactor was last seen", ) @@ -165,4 +167,5 @@ except Exception as e: if wrapper: - REGISTRY.register(ReactorLastSeenMetric(wrapper)) + # This is a process-level metric, so it does not have the `SERVER_NAME_LABEL`. 
+ REGISTRY.register(ReactorLastSeenMetric(wrapper)) # type: ignore[missing-server-name-label] diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py index 49d0ff9fc1..f7f2d88885 100644 --- a/synapse/metrics/background_process_metrics.py +++ b/synapse/metrics/background_process_metrics.py @@ -31,6 +31,7 @@ from typing import ( Dict, Iterable, Optional, + Protocol, Set, Type, TypeVar, @@ -39,7 +40,7 @@ from typing import ( from prometheus_client import Metric from prometheus_client.core import REGISTRY, Counter, Gauge -from typing_extensions import ParamSpec +from typing_extensions import Concatenate, ParamSpec from twisted.internet import defer @@ -49,6 +50,7 @@ from synapse.logging.context import ( PreserveLoggingContext, ) from synapse.logging.opentracing import SynapseTags, start_active_span +from synapse.metrics import SERVER_NAME_LABEL from synapse.metrics._types import Collector if TYPE_CHECKING: @@ -64,13 +66,13 @@ logger = logging.getLogger(__name__) _background_process_start_count = Counter( "synapse_background_process_start_count", "Number of background processes started", - ["name"], + labelnames=["name", SERVER_NAME_LABEL], ) _background_process_in_flight_count = Gauge( "synapse_background_process_in_flight_count", "Number of background processes in flight", - labelnames=["name"], + labelnames=["name", SERVER_NAME_LABEL], ) # we set registry=None in all of these to stop them getting registered with @@ -80,21 +82,21 @@ _background_process_in_flight_count = Gauge( _background_process_ru_utime = Counter( "synapse_background_process_ru_utime_seconds", "User CPU time used by background processes, in seconds", - ["name"], + labelnames=["name", SERVER_NAME_LABEL], registry=None, ) _background_process_ru_stime = Counter( "synapse_background_process_ru_stime_seconds", "System CPU time used by background processes, in seconds", - ["name"], + labelnames=["name", SERVER_NAME_LABEL], registry=None, ) _background_process_db_txn_count = Counter( "synapse_background_process_db_txn_count", "Number of database transactions done by background processes", - ["name"], + labelnames=["name", SERVER_NAME_LABEL], registry=None, ) @@ -104,14 +106,14 @@ _background_process_db_txn_duration = Counter( "Seconds spent by background processes waiting for database " "transactions, excluding scheduling time" ), - ["name"], + labelnames=["name", SERVER_NAME_LABEL], registry=None, ) _background_process_db_sched_duration = Counter( "synapse_background_process_db_sched_duration_seconds", "Seconds spent by background processes waiting for database connections", - ["name"], + labelnames=["name", SERVER_NAME_LABEL], registry=None, ) @@ -165,12 +167,15 @@ class _Collector(Collector): yield from m.collect() -REGISTRY.register(_Collector()) +# The `SERVER_NAME_LABEL` is included in the individual metrics added to this registry, +# so we don't need to worry about it on the collector itself. +REGISTRY.register(_Collector()) # type: ignore[missing-server-name-label] class _BackgroundProcess: - def __init__(self, desc: str, ctx: LoggingContext): + def __init__(self, *, desc: str, server_name: str, ctx: LoggingContext): self.desc = desc + self.server_name = server_name self._context = ctx self._reported_stats: Optional[ContextResourceUsage] = None @@ -185,15 +190,21 @@ class _BackgroundProcess: # For unknown reasons, the difference in times can be negative. See comment in # synapse.http.request_metrics.RequestMetrics.update_metrics. 
- _background_process_ru_utime.labels(self.desc).inc(max(diff.ru_utime, 0)) - _background_process_ru_stime.labels(self.desc).inc(max(diff.ru_stime, 0)) - _background_process_db_txn_count.labels(self.desc).inc(diff.db_txn_count) - _background_process_db_txn_duration.labels(self.desc).inc( - diff.db_txn_duration_sec - ) - _background_process_db_sched_duration.labels(self.desc).inc( - diff.db_sched_duration_sec - ) + _background_process_ru_utime.labels( + name=self.desc, **{SERVER_NAME_LABEL: self.server_name} + ).inc(max(diff.ru_utime, 0)) + _background_process_ru_stime.labels( + name=self.desc, **{SERVER_NAME_LABEL: self.server_name} + ).inc(max(diff.ru_stime, 0)) + _background_process_db_txn_count.labels( + name=self.desc, **{SERVER_NAME_LABEL: self.server_name} + ).inc(diff.db_txn_count) + _background_process_db_txn_duration.labels( + name=self.desc, **{SERVER_NAME_LABEL: self.server_name} + ).inc(diff.db_txn_duration_sec) + _background_process_db_sched_duration.labels( + name=self.desc, **{SERVER_NAME_LABEL: self.server_name} + ).inc(diff.db_sched_duration_sec) R = TypeVar("R") @@ -201,6 +212,7 @@ R = TypeVar("R") def run_as_background_process( desc: "LiteralString", + server_name: str, func: Callable[..., Awaitable[Optional[R]]], *args: Any, bg_start_span: bool = True, @@ -218,6 +230,8 @@ def run_as_background_process( Args: desc: a description for this background process type + server_name: The homeserver name that this background process is being run for + (this should be `hs.hostname`). func: a function, which may return a Deferred or a coroutine bg_start_span: Whether to start an opentracing span. Defaults to True. Should only be disabled for processes that will not log to or tag @@ -236,10 +250,16 @@ def run_as_background_process( count = _background_process_counts.get(desc, 0) _background_process_counts[desc] = count + 1 - _background_process_start_count.labels(desc).inc() - _background_process_in_flight_count.labels(desc).inc() + _background_process_start_count.labels( + name=desc, **{SERVER_NAME_LABEL: server_name} + ).inc() + _background_process_in_flight_count.labels( + name=desc, **{SERVER_NAME_LABEL: server_name} + ).inc() - with BackgroundProcessLoggingContext(desc, count) as context: + with BackgroundProcessLoggingContext( + name=desc, server_name=server_name, instance_id=count + ) as context: try: if bg_start_span: ctx = start_active_span( @@ -256,7 +276,9 @@ def run_as_background_process( ) return None finally: - _background_process_in_flight_count.labels(desc).dec() + _background_process_in_flight_count.labels( + name=desc, **{SERVER_NAME_LABEL: server_name} + ).dec() with PreserveLoggingContext(): # Note that we return a Deferred here so that it can be used in a @@ -267,6 +289,14 @@ def run_as_background_process( P = ParamSpec("P") +class HasServerName(Protocol): + server_name: str + """ + The homeserver name that this cache is associated with (used to label the metric) + (`hs.hostname`). + """ + + def wrap_as_background_process( desc: "LiteralString", ) -> Callable[ @@ -292,22 +322,37 @@ def wrap_as_background_process( multiple places. 
""" - def wrap_as_background_process_inner( - func: Callable[P, Awaitable[Optional[R]]], + def wrapper( + func: Callable[Concatenate[HasServerName, P], Awaitable[Optional[R]]], ) -> Callable[P, "defer.Deferred[Optional[R]]"]: @wraps(func) - def wrap_as_background_process_inner_2( - *args: P.args, **kwargs: P.kwargs + def wrapped_func( + self: HasServerName, *args: P.args, **kwargs: P.kwargs ) -> "defer.Deferred[Optional[R]]": - # type-ignore: mypy is confusing kwargs with the bg_start_span kwarg. - # Argument 4 to "run_as_background_process" has incompatible type - # "**P.kwargs"; expected "bool" - # See https://github.com/python/mypy/issues/8862 - return run_as_background_process(desc, func, *args, **kwargs) # type: ignore[arg-type] + assert self.server_name is not None, ( + "The `server_name` attribute must be set on the object where `@wrap_as_background_process` decorator is used." + ) - return wrap_as_background_process_inner_2 + return run_as_background_process( + desc, + self.server_name, + func, + self, + *args, + # type-ignore: mypy is confusing kwargs with the bg_start_span kwarg. + # Argument 4 to "run_as_background_process" has incompatible type + # "**P.kwargs"; expected "bool" + # See https://github.com/python/mypy/issues/8862 + **kwargs, # type: ignore[arg-type] + ) - return wrap_as_background_process_inner + # There are some shenanigans here, because we're decorating a method but + # explicitly making use of the `self` parameter. The key thing here is that the + # return type within the return type for `measure_func` itself describes how the + # decorated function will be called. + return wrapped_func # type: ignore[return-value] + + return wrapper # type: ignore[return-value] class BackgroundProcessLoggingContext(LoggingContext): @@ -317,13 +362,20 @@ class BackgroundProcessLoggingContext(LoggingContext): __slots__ = ["_proc"] - def __init__(self, name: str, instance_id: Optional[Union[int, str]] = None): + def __init__( + self, + *, + name: str, + server_name: str, + instance_id: Optional[Union[int, str]] = None, + ): """ Args: name: The name of the background process. Each distinct `name` gets a separate prometheus time series. - + server_name: The homeserver name that this background process is being run for + (this should be `hs.hostname`). instance_id: an identifer to add to `name` to distinguish this instance of the named background process in the logs. If this is `None`, one is made up based on id(self). 
@@ -331,7 +383,9 @@ class BackgroundProcessLoggingContext(LoggingContext): if instance_id is None: instance_id = id(self) super().__init__("%s-%s" % (name, instance_id)) - self._proc: Optional[_BackgroundProcess] = _BackgroundProcess(name, self) + self._proc: Optional[_BackgroundProcess] = _BackgroundProcess( + desc=name, server_name=server_name, ctx=self + ) def start(self, rusage: "Optional[resource.struct_rusage]") -> None: """Log context has started running (again).""" diff --git a/synapse/metrics/common_usage_metrics.py b/synapse/metrics/common_usage_metrics.py index 970367e9e0..cd1c3c8649 100644 --- a/synapse/metrics/common_usage_metrics.py +++ b/synapse/metrics/common_usage_metrics.py @@ -22,6 +22,7 @@ from typing import TYPE_CHECKING import attr +from synapse.metrics import SERVER_NAME_LABEL from synapse.metrics.background_process_metrics import run_as_background_process if TYPE_CHECKING: @@ -33,6 +34,7 @@ from prometheus_client import Gauge current_dau_gauge = Gauge( "synapse_admin_daily_active_users", "Current daily active users count", + labelnames=[SERVER_NAME_LABEL], ) @@ -47,6 +49,7 @@ class CommonUsageMetricsManager: """Collects common usage metrics.""" def __init__(self, hs: "HomeServer") -> None: + self.server_name = hs.hostname self._store = hs.get_datastores().main self._clock = hs.get_clock() @@ -62,12 +65,15 @@ class CommonUsageMetricsManager: async def setup(self) -> None: """Keep the gauges for common usage metrics up to date.""" run_as_background_process( - desc="common_usage_metrics_update_gauges", func=self._update_gauges + desc="common_usage_metrics_update_gauges", + server_name=self.server_name, + func=self._update_gauges, ) self._clock.looping_call( run_as_background_process, 5 * 60 * 1000, desc="common_usage_metrics_update_gauges", + server_name=self.server_name, func=self._update_gauges, ) @@ -85,4 +91,6 @@ class CommonUsageMetricsManager: """Update the Prometheus gauges.""" metrics = await self._collect() - current_dau_gauge.set(float(metrics.daily_active_users)) + current_dau_gauge.labels( + **{SERVER_NAME_LABEL: self.server_name}, + ).set(float(metrics.daily_active_users)) diff --git a/synapse/metrics/jemalloc.py b/synapse/metrics/jemalloc.py index 321ff58083..fb8adbe060 100644 --- a/synapse/metrics/jemalloc.py +++ b/synapse/metrics/jemalloc.py @@ -188,7 +188,8 @@ def _setup_jemalloc_stats() -> None: def collect(self) -> Iterable[Metric]: stats.refresh_stats() - g = GaugeMetricFamily( + # This is a process-level metric, so it does not have the `SERVER_NAME_LABEL`. + g = GaugeMetricFamily( # type: ignore[missing-server-name-label] "jemalloc_stats_app_memory_bytes", "The stats reported by jemalloc", labels=["type"], @@ -230,7 +231,8 @@ def _setup_jemalloc_stats() -> None: yield g - REGISTRY.register(JemallocCollector()) + # This is a process-level metric, so it does not have the `SERVER_NAME_LABEL`. 
+ REGISTRY.register(JemallocCollector()) # type: ignore[missing-server-name-label] logger.debug("Added jemalloc stats") diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index fd0811ca12..9309aa9394 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -23,6 +23,7 @@ import logging from typing import ( TYPE_CHECKING, Any, + Awaitable, Callable, Collection, Dict, @@ -80,7 +81,9 @@ from synapse.logging.context import ( make_deferred_yieldable, run_in_background, ) -from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.metrics.background_process_metrics import ( + run_as_background_process as _run_as_background_process, +) from synapse.module_api.callbacks.account_validity_callbacks import ( IS_USER_EXPIRED_CALLBACK, ON_LEGACY_ADMIN_REQUEST, @@ -158,6 +161,9 @@ from synapse.util.caches.descriptors import CachedFunction, cached as _cached from synapse.util.frozenutils import freeze if TYPE_CHECKING: + # Old versions don't have `LiteralString` + from typing_extensions import LiteralString + from synapse.app.generic_worker import GenericWorkerStore from synapse.server import HomeServer @@ -216,6 +222,65 @@ class UserIpAndAgent: last_seen: int +def run_as_background_process( + desc: "LiteralString", + func: Callable[..., Awaitable[Optional[T]]], + *args: Any, + bg_start_span: bool = True, + **kwargs: Any, +) -> "defer.Deferred[Optional[T]]": + """ + XXX: Deprecated: use `ModuleApi.run_as_background_process` instead. + + Run the given function in its own logcontext, with resource metrics + + This should be used to wrap processes which are fired off to run in the + background, instead of being associated with a particular request. + + It returns a Deferred which completes when the function completes, but it doesn't + follow the synapse logcontext rules, which makes it appropriate for passing to + clock.looping_call and friends (or for firing-and-forgetting in the middle of a + normal synapse async function). + + Args: + desc: a description for this background process type + func: a function, which may return a Deferred or a coroutine + bg_start_span: Whether to start an opentracing span. Defaults to True. + Should only be disabled for processes that will not log to or tag + a span. + args: positional args for func + kwargs: keyword args for func + + Returns: + Deferred which returns the result of func, or `None` if func raises. + Note that the returned Deferred does not follow the synapse logcontext + rules. + """ + + logger.warning( + "Using deprecated `run_as_background_process` that's exported from the Module API. " + "Prefer `ModuleApi.run_as_background_process` instead.", + ) + + # This function has long been exported from the module API, so we can't simply + # change its signature to require a `server_name` argument. Since + # `run_as_background_process` internally in Synapse now requires `server_name`, we + # just have to stub it out with a placeholder value and tell people to use the new + # function instead. 
+ stub_server_name = "synapse_module_running_from_unknown_server" + + return _run_as_background_process( + desc, + stub_server_name, + func, + *args, + bg_start_span=bg_start_span, + **kwargs, + ) + + def cached( *, max_entries: int = 1000, @@ -277,7 +342,9 @@ class ModuleApi: self._device_handler = hs.get_device_handler() self.custom_template_dir = hs.config.server.custom_template_directory self._callbacks = hs.get_module_api_callbacks() - self.msc3861_oauth_delegation_enabled = hs.config.experimental.msc3861.enabled + self._auth_delegation_enabled = ( + hs.config.mas.enabled or hs.config.experimental.msc3861.enabled + ) self._event_serializer = hs.get_event_client_serializer() try: @@ -484,7 +551,7 @@ class ModuleApi: Added in Synapse v1.46.0. """ - if self.msc3861_oauth_delegation_enabled: + if self._auth_delegation_enabled: raise ConfigError( "Cannot use password auth provider callbacks when OAuth delegation is enabled" ) @@ -1323,7 +1390,7 @@ class ModuleApi: if self._hs.config.worker.run_background_tasks or run_on_all_instances: self._clock.looping_call( - run_as_background_process, + self.run_as_background_process, msec, desc, lambda: maybe_awaitable(f(*args, **kwargs)), @@ -1381,7 +1448,7 @@ class ModuleApi: return self._clock.call_later( # convert ms to seconds as needed by call_later. msec * 0.001, - run_as_background_process, + self.run_as_background_process, desc, lambda: maybe_awaitable(f(*args, **kwargs)), ) @@ -1588,6 +1655,44 @@ class ModuleApi: return {key: state_events[event_id] for key, event_id in state_ids.items()} + def run_as_background_process( + self, + desc: "LiteralString", + func: Callable[..., Awaitable[Optional[T]]], + *args: Any, + bg_start_span: bool = True, + **kwargs: Any, + ) -> "defer.Deferred[Optional[T]]": + """Run the given function in its own logcontext, with resource metrics + + This should be used to wrap processes which are fired off to run in the + background, instead of being associated with a particular request. + + It returns a Deferred which completes when the function completes, but it doesn't + follow the synapse logcontext rules, which makes it appropriate for passing to + clock.looping_call and friends (or for firing-and-forgetting in the middle of a + normal synapse async function). + + Args: + desc: a description for this background process type + func: a function, which may return a Deferred or a coroutine + bg_start_span: Whether to start an opentracing span. Defaults to True. + Should only be disabled for processes that will not log to or tag + a span. + args: positional args for func + kwargs: keyword args for func + + Returns: + Deferred which returns the result of func, or `None` if func raises. + Note that the returned Deferred does not follow the synapse logcontext + rules. 
+ """ + return _run_as_background_process( + desc, self.server_name, func, *args, bg_start_span=bg_start_span, **kwargs + ) + async def defer_to_thread( self, f: Callable[P, T], diff --git a/synapse/notifier.py b/synapse/notifier.py index 6190432b87..448a715e2a 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -29,6 +29,7 @@ from typing import ( Iterable, List, Literal, + Mapping, Optional, Set, Tuple, @@ -50,7 +51,7 @@ from synapse.handlers.presence import format_user_presence_state from synapse.logging import issue9533_logger from synapse.logging.context import PreserveLoggingContext from synapse.logging.opentracing import log_kv, start_active_span -from synapse.metrics import LaterGauge +from synapse.metrics import SERVER_NAME_LABEL, LaterGauge from synapse.streams.config import PaginationConfig from synapse.types import ( ISynapseReactor, @@ -74,10 +75,15 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -notified_events_counter = Counter("synapse_notifier_notified_events", "") +# FIXME: Unused metric, remove if not needed. +notified_events_counter = Counter( + "synapse_notifier_notified_events", "", labelnames=[SERVER_NAME_LABEL] +) users_woken_by_stream_counter = Counter( - "synapse_notifier_users_woken_by_stream", "", ["stream"] + "synapse_notifier_users_woken_by_stream", + "", + labelnames=["stream", SERVER_NAME_LABEL], ) T = TypeVar("T") @@ -224,6 +230,7 @@ class Notifier: self.room_to_user_streams: Dict[str, Set[_NotifierUserStream]] = {} self.hs = hs + self.server_name = hs.hostname self._storage_controllers = hs.get_storage_controllers() self.event_sources = hs.get_event_sources() self.store = hs.get_datastores().main @@ -257,7 +264,10 @@ class Notifier: # This is not a very cheap test to perform, but it's only executed # when rendering the metrics page, which is likely once per minute at # most when scraping it. - def count_listeners() -> int: + # + # Ideally, we'd use `Mapping[Tuple[str], int]` here but mypy doesn't like it. + # This is close enough and better than a type ignore. 
+ def count_listeners() -> Mapping[Tuple[str, ...], int]: all_user_streams: Set[_NotifierUserStream] = set() for streams in list(self.room_to_user_streams.values()): @@ -265,18 +275,34 @@ class Notifier: for stream in list(self.user_to_user_stream.values()): all_user_streams.add(stream) - return sum(stream.count_listeners() for stream in all_user_streams) - - LaterGauge("synapse_notifier_listeners", "", [], count_listeners) + return { + (self.server_name,): sum( + stream.count_listeners() for stream in all_user_streams + ) + } LaterGauge( - "synapse_notifier_rooms", - "", - [], - lambda: count(bool, list(self.room_to_user_streams.values())), + name="synapse_notifier_listeners", + desc="", + labelnames=[SERVER_NAME_LABEL], + caller=count_listeners, + ) + + LaterGauge( + name="synapse_notifier_rooms", + desc="", + labelnames=[SERVER_NAME_LABEL], + caller=lambda: { + (self.server_name,): count( + bool, list(self.room_to_user_streams.values()) + ) + }, ) LaterGauge( - "synapse_notifier_users", "", [], lambda: len(self.user_to_user_stream) + name="synapse_notifier_users", + desc="", + labelnames=[SERVER_NAME_LABEL], + caller=lambda: {(self.server_name,): len(self.user_to_user_stream)}, ) def add_replication_callback(self, cb: Callable[[], None]) -> None: @@ -350,9 +376,10 @@ class Notifier: for listener in listeners: listener.callback(current_token) - users_woken_by_stream_counter.labels(StreamKeyType.UN_PARTIAL_STATED_ROOMS).inc( - len(user_streams) - ) + users_woken_by_stream_counter.labels( + stream=StreamKeyType.UN_PARTIAL_STATED_ROOMS, + **{SERVER_NAME_LABEL: self.server_name}, + ).inc(len(user_streams)) # Poke the replication so that other workers also see the write to # the un-partial-stated rooms stream. @@ -575,7 +602,10 @@ class Notifier: listener.callback(current_token) if user_streams: - users_woken_by_stream_counter.labels(stream_key).inc(len(user_streams)) + users_woken_by_stream_counter.labels( + stream=stream_key, + **{SERVER_NAME_LABEL: self.server_name}, + ).inc(len(user_streams)) self.notify_replication() diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index fed9931930..bb9d5dbcaa 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -25,6 +25,7 @@ from typing import ( Any, Collection, Dict, + FrozenSet, List, Mapping, Optional, @@ -50,6 +51,7 @@ from synapse.event_auth import auth_types_for_event, get_user_power_level from synapse.events import EventBase, relation_from_event from synapse.events.snapshot import EventContext from synapse.logging.context import make_deferred_yieldable, run_in_background +from synapse.metrics import SERVER_NAME_LABEL from synapse.state import CREATE_KEY, POWER_KEY from synapse.storage.databases.main.roommember import EventIdMembership from synapse.storage.invite_rule import InviteRule @@ -68,11 +70,17 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +# FIXME: Unused metric, remove if not needed. push_rules_invalidation_counter = Counter( - "synapse_push_bulk_push_rule_evaluator_push_rules_invalidation_counter", "" + "synapse_push_bulk_push_rule_evaluator_push_rules_invalidation_counter", + "", + labelnames=[SERVER_NAME_LABEL], ) +# FIXME: Unused metric, remove if not needed. 
push_rules_state_size_counter = Counter( - "synapse_push_bulk_push_rule_evaluator_push_rules_state_size_counter", "" + "synapse_push_bulk_push_rule_evaluator_push_rules_state_size_counter", + "", + labelnames=[SERVER_NAME_LABEL], ) @@ -470,8 +478,18 @@ class BulkPushRuleEvaluator: event.room_version.msc3931_push_features, self.hs.config.experimental.msc1767_enabled, # MSC3931 flag self.hs.config.experimental.msc4210_enabled, + self.hs.config.experimental.msc4306_enabled, ) + msc4306_thread_subscribers: Optional[FrozenSet[str]] = None + if self.hs.config.experimental.msc4306_enabled and thread_id != MAIN_TIMELINE: + # pull out, in batch, all local subscribers to this thread + # (in the common case, they will all be getting processed for push + # rules right now) + msc4306_thread_subscribers = await self.store.get_subscribers_to_thread( + event.room_id, thread_id + ) + for uid, rules in rules_by_user.items(): if event.sender == uid: continue @@ -496,7 +514,13 @@ class BulkPushRuleEvaluator: # current user, it'll be added to the dict later. actions_by_user[uid] = [] - actions = evaluator.run(rules, uid, display_name) + msc4306_thread_subscription_state: Optional[bool] = None + if msc4306_thread_subscribers is not None: + msc4306_thread_subscription_state = uid in msc4306_thread_subscribers + + actions = evaluator.run( + rules, uid, display_name, msc4306_thread_subscription_state + ) if "notify" in actions: # Push rules say we should notify the user of this event actions_by_user[uid] = actions diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py index 0a14c534f7..09ca14584a 100644 --- a/synapse/push/emailpusher.py +++ b/synapse/push/emailpusher.py @@ -68,6 +68,7 @@ class EmailPusher(Pusher): super().__init__(hs, pusher_config) self.mailer = mailer + self.server_name = hs.hostname self.store = self.hs.get_datastores().main self.email = pusher_config.pushkey self.timed_call: Optional[IDelayedCall] = None @@ -117,7 +118,7 @@ class EmailPusher(Pusher): if self._is_processing: return - run_as_background_process("emailpush.process", self._process) + run_as_background_process("emailpush.process", self.server_name, self._process) def _pause_processing(self) -> None: """Used by tests to temporarily pause processing of events. 
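The pusher and mailer hunks below all follow the same labelling pattern: declare the counter with `SERVER_NAME_LABEL` in its `labelnames`, then increment it with the label supplied as a dynamic keyword. A self-contained sketch with an invented metric name:

```python
from prometheus_client import Counter

from synapse.metrics import SERVER_NAME_LABEL

example_pushes_processed = Counter(
    "synapse_example_pushes_processed",
    "Number of example pushes processed",
    labelnames=[SERVER_NAME_LABEL],
)


def record_processed_push(server_name: str) -> None:
    # SERVER_NAME_LABEL is a string constant, so it is supplied as a dynamic
    # keyword argument rather than written out as a literal keyword.
    example_pushes_processed.labels(**{SERVER_NAME_LABEL: server_name}).inc()
```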
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 7df8a128c9..5946a6e972 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -31,6 +31,7 @@ from twisted.internet.interfaces import IDelayedCall from synapse.api.constants import EventTypes from synapse.events import EventBase from synapse.logging import opentracing +from synapse.metrics import SERVER_NAME_LABEL from synapse.metrics.background_process_metrics import run_as_background_process from synapse.push import Pusher, PusherConfig, PusherConfigException from synapse.storage.databases.main.event_push_actions import HttpPushAction @@ -46,21 +47,25 @@ logger = logging.getLogger(__name__) http_push_processed_counter = Counter( "synapse_http_httppusher_http_pushes_processed", "Number of push notifications successfully sent", + labelnames=[SERVER_NAME_LABEL], ) http_push_failed_counter = Counter( "synapse_http_httppusher_http_pushes_failed", "Number of push notifications which failed", + labelnames=[SERVER_NAME_LABEL], ) http_badges_processed_counter = Counter( "synapse_http_httppusher_badge_updates_processed", "Number of badge updates successfully sent", + labelnames=[SERVER_NAME_LABEL], ) http_badges_failed_counter = Counter( "synapse_http_httppusher_badge_updates_failed", "Number of badge updates which failed", + labelnames=[SERVER_NAME_LABEL], ) @@ -106,6 +111,7 @@ class HttpPusher(Pusher): def __init__(self, hs: "HomeServer", pusher_config: PusherConfig): super().__init__(hs, pusher_config) + self.server_name = hs.hostname self._storage_controllers = self.hs.get_storage_controllers() self.app_display_name = pusher_config.app_display_name self.device_display_name = pusher_config.device_display_name @@ -176,7 +182,9 @@ class HttpPusher(Pusher): # We could check the receipts are actually m.read receipts here, # but currently that's the only type of receipt anyway... 
- run_as_background_process("http_pusher.on_new_receipts", self._update_badge) + run_as_background_process( + "http_pusher.on_new_receipts", self.server_name, self._update_badge + ) async def _update_badge(self) -> None: # XXX as per https://github.com/matrix-org/matrix-doc/issues/2627, this seems @@ -211,7 +219,7 @@ class HttpPusher(Pusher): if self.failing_since and self.timed_call and self.timed_call.active(): return - run_as_background_process("httppush.process", self._process) + run_as_background_process("httppush.process", self.server_name, self._process) async def _process(self) -> None: # we should never get here if we are already processing @@ -265,7 +273,9 @@ class HttpPusher(Pusher): processed = await self._process_one(push_action) if processed: - http_push_processed_counter.inc() + http_push_processed_counter.labels( + **{SERVER_NAME_LABEL: self.server_name} + ).inc() self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC self.last_stream_ordering = push_action.stream_ordering pusher_still_exists = ( @@ -289,7 +299,9 @@ class HttpPusher(Pusher): self.app_id, self.pushkey, self.user_id, self.failing_since ) else: - http_push_failed_counter.inc() + http_push_failed_counter.labels( + **{SERVER_NAME_LABEL: self.server_name} + ).inc() if not self.failing_since: self.failing_since = self.clock.time_msec() await self.store.update_pusher_failing_since( @@ -540,9 +552,13 @@ class HttpPusher(Pusher): } try: await self.http_client.post_json_get_json(self.url, d) - http_badges_processed_counter.inc() + http_badges_processed_counter.labels( + **{SERVER_NAME_LABEL: self.server_name} + ).inc() except Exception as e: logger.warning( "Failed to send badge count to %s: %s %s", self.name, type(e), e ) - http_badges_failed_counter.inc() + http_badges_failed_counter.labels( + **{SERVER_NAME_LABEL: self.server_name} + ).inc() diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index fadba480dd..d76cc8237b 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -32,6 +32,7 @@ from synapse.api.constants import EventContentFields, EventTypes, Membership, Ro from synapse.api.errors import StoreError from synapse.config.emailconfig import EmailSubjectConfig from synapse.events import EventBase +from synapse.metrics import SERVER_NAME_LABEL from synapse.push.presentable_names import ( calculate_room_name, descriptor_from_member_events, @@ -60,7 +61,7 @@ T = TypeVar("T") emails_sent_counter = Counter( "synapse_emails_sent_total", "Emails sent by type", - ["type"], + labelnames=["type", SERVER_NAME_LABEL], ) @@ -123,6 +124,7 @@ class Mailer: template_text: jinja2.Template, ): self.hs = hs + self.server_name = hs.hostname self.template_html = template_html self.template_text = template_text @@ -137,8 +139,6 @@ class Mailer: logger.info("Created Mailer for app_name %s", app_name) - emails_sent_counter.labels("password_reset") - async def send_password_reset_mail( self, email_address: str, token: str, client_secret: str, sid: str ) -> None: @@ -162,7 +162,10 @@ class Mailer: template_vars: TemplateVars = {"link": link} - emails_sent_counter.labels("password_reset").inc() + emails_sent_counter.labels( + type="password_reset", + **{SERVER_NAME_LABEL: self.server_name}, + ).inc() await self.send_email( email_address, @@ -171,8 +174,6 @@ class Mailer: template_vars, ) - emails_sent_counter.labels("registration") - async def send_registration_mail( self, email_address: str, token: str, client_secret: str, sid: str ) -> None: @@ -196,7 +197,10 @@ class Mailer: template_vars: TemplateVars = {"link": 
link} - emails_sent_counter.labels("registration").inc() + emails_sent_counter.labels( + type="registration", + **{SERVER_NAME_LABEL: self.server_name}, + ).inc() await self.send_email( email_address, @@ -205,8 +209,6 @@ class Mailer: template_vars, ) - emails_sent_counter.labels("already_in_use") - async def send_already_in_use_mail(self, email_address: str) -> None: """Send an email if the address is already bound to an user account @@ -214,6 +216,11 @@ class Mailer: email_address: Email address we're sending to the "already in use" mail """ + emails_sent_counter.labels( + type="already_in_use", + **{SERVER_NAME_LABEL: self.server_name}, + ).inc() + await self.send_email( email_address, self.email_subjects.email_already_in_use @@ -221,8 +228,6 @@ class Mailer: {}, ) - emails_sent_counter.labels("add_threepid") - async def send_add_threepid_mail( self, email_address: str, token: str, client_secret: str, sid: str ) -> None: @@ -247,7 +252,10 @@ class Mailer: template_vars: TemplateVars = {"link": link} - emails_sent_counter.labels("add_threepid").inc() + emails_sent_counter.labels( + type="add_threepid", + **{SERVER_NAME_LABEL: self.server_name}, + ).inc() await self.send_email( email_address, @@ -256,8 +264,6 @@ class Mailer: template_vars, ) - emails_sent_counter.labels("notification") - async def send_notification_mail( self, app_id: str, @@ -352,7 +358,10 @@ class Mailer: "reason": reason, } - emails_sent_counter.labels("notification").inc() + emails_sent_counter.labels( + type="notification", + **{SERVER_NAME_LABEL: self.server_name}, + ).inc() await self.send_email( email_address, summary_text, template_vars, unsubscribe_link diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 59550a41de..d1f79ec999 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -25,6 +25,7 @@ from typing import TYPE_CHECKING, Dict, Iterable, Optional from prometheus_client import Gauge from synapse.api.errors import Codes, SynapseError +from synapse.metrics import SERVER_NAME_LABEL from synapse.metrics.background_process_metrics import ( run_as_background_process, wrap_as_background_process, @@ -47,7 +48,9 @@ logger = logging.getLogger(__name__) synapse_pushers = Gauge( - "synapse_pushers", "Number of active synapse pushers", ["kind", "app_id"] + "synapse_pushers", + "Number of active synapse pushers", + labelnames=["kind", "app_id", SERVER_NAME_LABEL], ) @@ -68,6 +71,9 @@ class PusherPool: def __init__(self, hs: "HomeServer"): self.hs = hs + self.server_name = ( + hs.hostname + ) # nb must be called this for @wrap_as_background_process self.pusher_factory = PusherFactory(hs) self.store = self.hs.get_datastores().main self.clock = self.hs.get_clock() @@ -106,7 +112,9 @@ class PusherPool: if not self._should_start_pushers: logger.info("Not starting pushers because they are disabled in the config") return - run_as_background_process("start_pushers", self._start_pushers) + run_as_background_process( + "start_pushers", self.server_name, self._start_pushers + ) async def add_or_update_pusher( self, @@ -422,11 +430,17 @@ class PusherPool: previous_pusher.on_stop() synapse_pushers.labels( - type(previous_pusher).__name__, previous_pusher.app_id + kind=type(previous_pusher).__name__, + app_id=previous_pusher.app_id, + **{SERVER_NAME_LABEL: self.server_name}, ).dec() byuser[appid_pushkey] = pusher - synapse_pushers.labels(type(pusher).__name__, pusher.app_id).inc() + synapse_pushers.labels( + kind=type(pusher).__name__, + app_id=pusher.app_id, + **{SERVER_NAME_LABEL: 
self.server_name}, + ).inc() logger.info("Starting pusher %s / %s", pusher.user_id, appid_pushkey) @@ -485,4 +499,8 @@ class PusherPool: pusher = byuser.pop(appid_pushkey) pusher.on_stop() - synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec() + synapse_pushers.labels( + kind=type(pusher).__name__, + app_id=pusher.app_id, + **{SERVER_NAME_LABEL: self.server_name}, + ).dec() diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py index ab2e6707cd..68cc6ce1fc 100644 --- a/synapse/replication/http/__init__.py +++ b/synapse/replication/http/__init__.py @@ -32,7 +32,6 @@ from synapse.replication.http import ( presence, push, register, - send_event, send_events, state, streams, @@ -51,7 +50,6 @@ class ReplicationRestResource(JsonResource): self.register_servlets(hs) def register_servlets(self, hs: "HomeServer") -> None: - send_event.register_servlets(hs, self) send_events.register_servlets(hs, self) federation.register_servlets(hs, self) presence.register_servlets(hs, self) diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index 31204a8384..0850a99e0c 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -38,6 +38,7 @@ from synapse.http.servlet import parse_json_object_from_request from synapse.http.site import SynapseRequest from synapse.logging import opentracing from synapse.logging.opentracing import trace_with_opname +from synapse.metrics import SERVER_NAME_LABEL from synapse.types import JsonDict from synapse.util.caches.response_cache import ResponseCache from synapse.util.cancellation import is_function_cancellable @@ -51,13 +52,13 @@ logger = logging.getLogger(__name__) _pending_outgoing_requests = Gauge( "synapse_pending_outgoing_replication_requests", "Number of active outgoing replication requests, by replication method name", - ["name"], + labelnames=["name", SERVER_NAME_LABEL], ) _outgoing_request_counter = Counter( "synapse_outgoing_replication_requests", "Number of outgoing replication requests, by replication method name and result", - ["name", "code"], + labelnames=["name", "code", SERVER_NAME_LABEL], ) @@ -205,13 +206,17 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): parameter to specify which instance to hit (the instance must be in the `instance_map` config). """ + server_name = hs.hostname clock = hs.get_clock() client = hs.get_replication_client() local_instance_name = hs.get_instance_name() instance_map = hs.config.worker.instance_map - outgoing_gauge = _pending_outgoing_requests.labels(cls.NAME) + outgoing_gauge = _pending_outgoing_requests.labels( + name=cls.NAME, + **{SERVER_NAME_LABEL: server_name}, + ) replication_secret = None if hs.config.worker.worker_replication_secret: @@ -333,15 +338,27 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): # We convert to SynapseError as we know that it was a SynapseError # on the main process that we should send to the client. 
(And # importantly, not stack traces everywhere) - _outgoing_request_counter.labels(cls.NAME, e.code).inc() + _outgoing_request_counter.labels( + name=cls.NAME, + code=e.code, + **{SERVER_NAME_LABEL: server_name}, + ).inc() raise e.to_synapse_error() except Exception as e: - _outgoing_request_counter.labels(cls.NAME, "ERR").inc() + _outgoing_request_counter.labels( + name=cls.NAME, + code="ERR", + **{SERVER_NAME_LABEL: server_name}, + ).inc() raise SynapseError( 502, f"Failed to talk to {instance_name} process" ) from e - _outgoing_request_counter.labels(cls.NAME, 200).inc() + _outgoing_request_counter.labels( + name=cls.NAME, + code=200, + **{SERVER_NAME_LABEL: server_name}, + ).inc() # Wait on any streams that the remote may have written to. for stream_name, position in result.pop( diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py deleted file mode 100644 index edda419a03..0000000000 --- a/synapse/replication/http/send_event.py +++ /dev/null @@ -1,164 +0,0 @@ -# -# This file is licensed under the Affero General Public License (AGPL) version 3. -# -# Copyright (C) 2023 New Vector, Ltd -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# See the GNU Affero General Public License for more details: -# . -# -# Originally licensed under the Apache License, Version 2.0: -# . -# -# [This file includes modifications made by New Vector Limited] -# -# - -import logging -from typing import TYPE_CHECKING, List, Tuple - -from twisted.web.server import Request - -from synapse.api.room_versions import KNOWN_ROOM_VERSIONS -from synapse.events import EventBase, make_event_from_dict -from synapse.events.snapshot import EventContext -from synapse.http.server import HttpServer -from synapse.replication.http._base import ReplicationEndpoint -from synapse.types import JsonDict, Requester, UserID -from synapse.util.metrics import Measure - -if TYPE_CHECKING: - from synapse.server import HomeServer - from synapse.storage.databases.main import DataStore - -logger = logging.getLogger(__name__) - - -class ReplicationSendEventRestServlet(ReplicationEndpoint): - """Handles events newly created on workers, including persisting and - notifying. - - The API looks like: - - POST /_synapse/replication/send_event/:event_id/:txn_id - - { - "event": { .. serialized event .. }, - "room_version": .., // "1", "2", "3", etc: the version of the room - // containing the event - "event_format_version": .., // 1,2,3 etc: the event format version - "internal_metadata": { .. serialized internal_metadata .. }, - "outlier": true|false, - "rejected_reason": .., // The event.rejected_reason field - "context": { .. serialized event context .. }, - "requester": { .. serialized requester .. }, - "ratelimit": true, - "extra_users": [], - } - - 200 OK - - { "stream_id": 12345, "event_id": "$abcdef..." } - - Responds with a 409 when a `PartialStateConflictError` is raised due to an event - context that needs to be recomputed due to the un-partial stating of a room. - - The returned event ID may not match the sent event if it was deduplicated. 
- """ - - NAME = "send_event" - PATH_ARGS = ("event_id",) - - def __init__(self, hs: "HomeServer"): - super().__init__(hs) - - self.server_name = hs.hostname - self.event_creation_handler = hs.get_event_creation_handler() - self.store = hs.get_datastores().main - self._storage_controllers = hs.get_storage_controllers() - self.clock = hs.get_clock() - - @staticmethod - async def _serialize_payload( # type: ignore[override] - event_id: str, - store: "DataStore", - event: EventBase, - context: EventContext, - requester: Requester, - ratelimit: bool, - extra_users: List[UserID], - ) -> JsonDict: - """ - Args: - event_id - store - requester - event - context - ratelimit - extra_users: Any extra users to notify about event - """ - serialized_context = await context.serialize(event, store) - - payload = { - "event": event.get_pdu_json(), - "room_version": event.room_version.identifier, - "event_format_version": event.format_version, - "internal_metadata": event.internal_metadata.get_dict(), - "outlier": event.internal_metadata.is_outlier(), - "rejected_reason": event.rejected_reason, - "context": serialized_context, - "requester": requester.serialize(), - "ratelimit": ratelimit, - "extra_users": [u.to_string() for u in extra_users], - } - - return payload - - async def _handle_request( # type: ignore[override] - self, request: Request, content: JsonDict, event_id: str - ) -> Tuple[int, JsonDict]: - with Measure( - self.clock, name="repl_send_event_parse", server_name=self.server_name - ): - event_dict = content["event"] - room_ver = KNOWN_ROOM_VERSIONS[content["room_version"]] - internal_metadata = content["internal_metadata"] - rejected_reason = content["rejected_reason"] - - event = make_event_from_dict( - event_dict, room_ver, internal_metadata, rejected_reason - ) - event.internal_metadata.outlier = content["outlier"] - - requester = Requester.deserialize(self.store, content["requester"]) - context = EventContext.deserialize( - self._storage_controllers, content["context"] - ) - - ratelimit = content["ratelimit"] - extra_users = [UserID.from_string(u) for u in content["extra_users"]] - - logger.info( - "Got event to send with ID: %s into room: %s", event.event_id, event.room_id - ) - - event = await self.event_creation_handler.persist_and_notify_client_events( - requester, [(event, context)], ratelimit=ratelimit, extra_users=extra_users - ) - - return ( - 200, - { - "stream_id": event.internal_metadata.stream_ordering, - "event_id": event.event_id, - }, - ) - - -def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: - ReplicationSendEventRestServlet(hs).register(http_server) diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index b99f11f7c6..ee9250cf7d 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -413,6 +413,7 @@ class FederationSenderHandler: def __init__(self, hs: "HomeServer"): assert hs.should_send_federation() + self.server_name = hs.hostname self.store = hs.get_datastores().main self._is_mine_id = hs.is_mine_id self._hs = hs @@ -503,7 +504,9 @@ class FederationSenderHandler: # no need to queue up another task. 
return - run_as_background_process("_save_and_send_ack", self._save_and_send_ack) + run_as_background_process( + "_save_and_send_ack", self.server_name, self._save_and_send_ack + ) async def _save_and_send_ack(self) -> None: """Save the current federation position in the database and send an ACK diff --git a/synapse/replication/tcp/external_cache.py b/synapse/replication/tcp/external_cache.py index a95771b5f6..497b26fcaf 100644 --- a/synapse/replication/tcp/external_cache.py +++ b/synapse/replication/tcp/external_cache.py @@ -26,6 +26,7 @@ from prometheus_client import Counter, Histogram from synapse.logging import opentracing from synapse.logging.context import make_deferred_yieldable +from synapse.metrics import SERVER_NAME_LABEL from synapse.util import json_decoder, json_encoder if TYPE_CHECKING: @@ -36,19 +37,19 @@ if TYPE_CHECKING: set_counter = Counter( "synapse_external_cache_set", "Number of times we set a cache", - labelnames=["cache_name"], + labelnames=["cache_name", SERVER_NAME_LABEL], ) get_counter = Counter( "synapse_external_cache_get", "Number of times we get a cache", - labelnames=["cache_name", "hit"], + labelnames=["cache_name", "hit", SERVER_NAME_LABEL], ) response_timer = Histogram( "synapse_external_cache_response_time_seconds", "Time taken to get a response from Redis for a cache get/set request", - labelnames=["method"], + labelnames=["method", SERVER_NAME_LABEL], buckets=( 0.001, 0.002, @@ -69,6 +70,8 @@ class ExternalCache: """ def __init__(self, hs: "HomeServer"): + self.server_name = hs.hostname + if hs.config.redis.redis_enabled: self._redis_connection: Optional["ConnectionHandler"] = ( hs.get_outbound_redis_connection() @@ -93,7 +96,9 @@ class ExternalCache: if self._redis_connection is None: return - set_counter.labels(cache_name).inc() + set_counter.labels( + cache_name=cache_name, **{SERVER_NAME_LABEL: self.server_name} + ).inc() # txredisapi requires the value to be string, bytes or numbers, so we # encode stuff in JSON. 
@@ -105,7 +110,9 @@ class ExternalCache: "ExternalCache.set", tags={opentracing.SynapseTags.CACHE_NAME: cache_name}, ): - with response_timer.labels("set").time(): + with response_timer.labels( + method="set", **{SERVER_NAME_LABEL: self.server_name} + ).time(): return await make_deferred_yieldable( self._redis_connection.set( self._get_redis_key(cache_name, key), @@ -124,14 +131,20 @@ class ExternalCache: "ExternalCache.get", tags={opentracing.SynapseTags.CACHE_NAME: cache_name}, ): - with response_timer.labels("get").time(): + with response_timer.labels( + method="get", **{SERVER_NAME_LABEL: self.server_name} + ).time(): result = await make_deferred_yieldable( self._redis_connection.get(self._get_redis_key(cache_name, key)) ) logger.debug("Got cache result %s %s: %r", cache_name, key, result) - get_counter.labels(cache_name, result is not None).inc() + get_counter.labels( + cache_name=cache_name, + hit=result is not None, + **{SERVER_NAME_LABEL: self.server_name}, + ).inc() if not result: return None diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 3611c678c2..0f14c7e380 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -40,7 +40,7 @@ from prometheus_client import Counter from twisted.internet.protocol import ReconnectingClientFactory -from synapse.metrics import LaterGauge +from synapse.metrics import SERVER_NAME_LABEL, LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.tcp.commands import ( ClearUserSyncsCommand, @@ -85,13 +85,26 @@ logger = logging.getLogger(__name__) # number of updates received for each RDATA stream inbound_rdata_count = Counter( - "synapse_replication_tcp_protocol_inbound_rdata_count", "", ["stream_name"] + "synapse_replication_tcp_protocol_inbound_rdata_count", + "", + labelnames=["stream_name", SERVER_NAME_LABEL], +) +user_sync_counter = Counter( + "synapse_replication_tcp_resource_user_sync", "", labelnames=[SERVER_NAME_LABEL] +) +federation_ack_counter = Counter( + "synapse_replication_tcp_resource_federation_ack", + "", + labelnames=[SERVER_NAME_LABEL], +) +# FIXME: Unused metric, remove if not needed. 
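The recurring change in these hunks is making module-level Prometheus metrics homeserver-scoped: each `Counter`/`Histogram` gains a `SERVER_NAME_LABEL` label name, and every call site passes the instance's `hs.hostname` for it. A minimal sketch of that pattern, assuming only that `SERVER_NAME_LABEL` is a plain string constant; the metric name, label value and helper below are illustrative, not Synapse code:

```python
from prometheus_client import Counter

SERVER_NAME_LABEL = "server_name"  # assumption: a plain string label key

example_get_counter = Counter(
    "example_external_cache_get",  # illustrative metric name
    "Number of cache gets, per homeserver",
    labelnames=["cache_name", SERVER_NAME_LABEL],
)

def record_cache_get(cache_name: str, server_name: str) -> None:
    # Labels are passed by keyword; the server-name label is expanded from a dict
    # so call sites do not hard-code the constant's value.
    example_get_counter.labels(
        cache_name=cache_name, **{SERVER_NAME_LABEL: server_name}
    ).inc()
```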
+remove_pusher_counter = Counter( + "synapse_replication_tcp_resource_remove_pusher", "", labelnames=[SERVER_NAME_LABEL] ) -user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "") -federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "") -remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "") -user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "") +user_ip_cache_counter = Counter( + "synapse_replication_tcp_resource_user_ip_cache", "", labelnames=[SERVER_NAME_LABEL] +) # the type of the entries in _command_queues_by_stream @@ -106,6 +119,7 @@ class ReplicationCommandHandler: """ def __init__(self, hs: "HomeServer"): + self.server_name = hs.hostname self._replication_data_handler = hs.get_replication_data_handler() self._presence_handler = hs.get_presence_handler() self._store = hs.get_datastores().main @@ -230,10 +244,10 @@ class ReplicationCommandHandler: self._connections: List[IReplicationConnection] = [] LaterGauge( - "synapse_replication_tcp_resource_total_connections", - "", - [], - lambda: len(self._connections), + name="synapse_replication_tcp_resource_total_connections", + desc="", + labelnames=[SERVER_NAME_LABEL], + caller=lambda: {(self.server_name,): len(self._connections)}, ) # When POSITION or RDATA commands arrive, we stick them in a queue and process @@ -253,11 +267,11 @@ class ReplicationCommandHandler: self._streams_by_connection: Dict[IReplicationConnection, Set[str]] = {} LaterGauge( - "synapse_replication_tcp_command_queue", - "Number of inbound RDATA/POSITION commands queued for processing", - ["stream_name"], - lambda: { - (stream_name,): len(queue) + name="synapse_replication_tcp_command_queue", + desc="Number of inbound RDATA/POSITION commands queued for processing", + labelnames=["stream_name", SERVER_NAME_LABEL], + caller=lambda: { + (stream_name, self.server_name): len(queue) for stream_name, queue in self._command_queues_by_stream.items() }, ) @@ -340,7 +354,10 @@ class ReplicationCommandHandler: # fire off a background process to start processing the queue. 
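The `LaterGauge` registrations are converted to keyword arguments in the same pass, and their `caller` callbacks now return a mapping from a tuple of label values (ending with the server name) to the gauge value. A small sketch that mirrors the keyword form used in these hunks; `LaterGauge` is Synapse-internal, and the metric name and values below are illustrative:

```python
from synapse.metrics import SERVER_NAME_LABEL, LaterGauge

connections: list = []        # illustrative stand-in for e.g. self._connections
server_name = "example.com"   # illustrative homeserver name

LaterGauge(
    name="example_total_connections",  # illustrative metric name
    desc="Number of live replication connections",
    labelnames=[SERVER_NAME_LABEL],
    # The caller returns {(label values...): value}; with only the server-name
    # label, each key is a one-element tuple.
    caller=lambda: {(server_name,): len(connections)},
)
```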
run_as_background_process( - "process-replication-data", self._unsafe_process_queue, stream_name + "process-replication-data", + self.server_name, + self._unsafe_process_queue, + stream_name, ) async def _unsafe_process_queue(self, stream_name: str) -> None: @@ -456,7 +473,7 @@ class ReplicationCommandHandler: def on_USER_SYNC( self, conn: IReplicationConnection, cmd: UserSyncCommand ) -> Optional[Awaitable[None]]: - user_sync_counter.inc() + user_sync_counter.labels(**{SERVER_NAME_LABEL: self.server_name}).inc() if self._is_presence_writer: return self._presence_handler.update_external_syncs_row( @@ -480,7 +497,7 @@ class ReplicationCommandHandler: def on_FEDERATION_ACK( self, conn: IReplicationConnection, cmd: FederationAckCommand ) -> None: - federation_ack_counter.inc() + federation_ack_counter.labels(**{SERVER_NAME_LABEL: self.server_name}).inc() if self._federation_sender: self._federation_sender.federation_ack(cmd.instance_name, cmd.token) @@ -488,7 +505,7 @@ class ReplicationCommandHandler: def on_USER_IP( self, conn: IReplicationConnection, cmd: UserIpCommand ) -> Optional[Awaitable[None]]: - user_ip_cache_counter.inc() + user_ip_cache_counter.labels(**{SERVER_NAME_LABEL: self.server_name}).inc() if self._is_master or self._should_insert_client_ips: # We make a point of only returning an awaitable if there's actually @@ -528,7 +545,9 @@ class ReplicationCommandHandler: return stream_name = cmd.stream_name - inbound_rdata_count.labels(stream_name).inc() + inbound_rdata_count.labels( + stream_name=stream_name, **{SERVER_NAME_LABEL: self.server_name} + ).inc() # We put the received command into a queue here for two reasons: # 1. so we don't try and concurrently handle multiple rows for the diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index fb9c539122..969f0303e0 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -39,7 +39,7 @@ from twisted.protocols.basic import LineOnlyReceiver from twisted.python.failure import Failure from synapse.logging.context import PreserveLoggingContext -from synapse.metrics import LaterGauge +from synapse.metrics import SERVER_NAME_LABEL, LaterGauge from synapse.metrics.background_process_metrics import ( BackgroundProcessLoggingContext, run_as_background_process, @@ -64,19 +64,21 @@ if TYPE_CHECKING: connection_close_counter = Counter( - "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"] + "synapse_replication_tcp_protocol_close_reason", + "", + labelnames=["reason_type", SERVER_NAME_LABEL], ) tcp_inbound_commands_counter = Counter( "synapse_replication_tcp_protocol_inbound_commands", "Number of commands received from replication, by command and name of process connected to", - ["command", "name"], + labelnames=["command", "name", SERVER_NAME_LABEL], ) tcp_outbound_commands_counter = Counter( "synapse_replication_tcp_protocol_outbound_commands", "Number of commands sent to replication, by command and name of process connected to", - ["command", "name"], + labelnames=["command", "name", SERVER_NAME_LABEL], ) # A list of all connected protocols. 
This allows us to send metrics about the @@ -137,7 +139,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): max_line_buffer = 10000 - def __init__(self, clock: Clock, handler: "ReplicationCommandHandler"): + def __init__( + self, server_name: str, clock: Clock, handler: "ReplicationCommandHandler" + ): + self.server_name = server_name self.clock = clock self.command_handler = handler @@ -166,7 +171,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): # capture the sentinel context as its containing context and won't prevent # GC of / unintentionally reactivate what would be the current context. self._logging_context = BackgroundProcessLoggingContext( - "replication-conn", self.conn_id + name="replication-conn", + server_name=self.server_name, + instance_id=self.conn_id, ) def connectionMade(self) -> None: @@ -244,7 +251,11 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): self.last_received_command = self.clock.time_msec() - tcp_inbound_commands_counter.labels(cmd.NAME, self.name).inc() + tcp_inbound_commands_counter.labels( + command=cmd.NAME, + name=self.name, + **{SERVER_NAME_LABEL: self.server_name}, + ).inc() self.handle_command(cmd) @@ -280,7 +291,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): if isawaitable(res): run_as_background_process( - "replication-" + cmd.get_logcontext_id(), lambda: res + "replication-" + cmd.get_logcontext_id(), + self.server_name, + lambda: res, ) handled = True @@ -318,7 +331,11 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): self._queue_command(cmd) return - tcp_outbound_commands_counter.labels(cmd.NAME, self.name).inc() + tcp_outbound_commands_counter.labels( + command=cmd.NAME, + name=self.name, + **{SERVER_NAME_LABEL: self.server_name}, + ).inc() string = "%s %s" % (cmd.NAME, cmd.to_line()) if "\n" in string: @@ -390,9 +407,15 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): logger.info("[%s] Replication connection closed: %r", self.id(), reason) if isinstance(reason, Failure): assert reason.type is not None - connection_close_counter.labels(reason.type.__name__).inc() + connection_close_counter.labels( + reason_type=reason.type.__name__, + **{SERVER_NAME_LABEL: self.server_name}, + ).inc() else: - connection_close_counter.labels(reason.__class__.__name__).inc() # type: ignore[unreachable] + connection_close_counter.labels( # type: ignore[unreachable] + reason_type=reason.__class__.__name__, + **{SERVER_NAME_LABEL: self.server_name}, + ).inc() try: # Remove us from list of connections to be monitored @@ -449,7 +472,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): def __init__( self, server_name: str, clock: Clock, handler: "ReplicationCommandHandler" ): - super().__init__(clock, handler) + super().__init__(server_name, clock, handler) self.server_name = server_name @@ -474,7 +497,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): clock: Clock, command_handler: "ReplicationCommandHandler", ): - super().__init__(clock, command_handler) + super().__init__(server_name, clock, command_handler) self.client_name = client_name self.server_name = server_name @@ -501,10 +524,12 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): # The following simply registers metrics for the replication connections pending_commands = LaterGauge( - "synapse_replication_tcp_protocol_pending_commands", - "", - ["name"], - lambda: {(p.name,): len(p.pending_commands) for p in connected_connections}, + 
name="synapse_replication_tcp_protocol_pending_commands", + desc="", + labelnames=["name", SERVER_NAME_LABEL], + caller=lambda: { + (p.name, p.server_name): len(p.pending_commands) for p in connected_connections + }, ) @@ -516,10 +541,12 @@ def transport_buffer_size(protocol: BaseReplicationStreamProtocol) -> int: transport_send_buffer = LaterGauge( - "synapse_replication_tcp_protocol_transport_send_buffer", - "", - ["name"], - lambda: {(p.name,): transport_buffer_size(p) for p in connected_connections}, + name="synapse_replication_tcp_protocol_transport_send_buffer", + desc="", + labelnames=["name", SERVER_NAME_LABEL], + caller=lambda: { + (p.name, p.server_name): transport_buffer_size(p) for p in connected_connections + }, ) @@ -541,22 +568,22 @@ def transport_kernel_read_buffer_size( tcp_transport_kernel_send_buffer = LaterGauge( - "synapse_replication_tcp_protocol_transport_kernel_send_buffer", - "", - ["name"], - lambda: { - (p.name,): transport_kernel_read_buffer_size(p, False) + name="synapse_replication_tcp_protocol_transport_kernel_send_buffer", + desc="", + labelnames=["name", SERVER_NAME_LABEL], + caller=lambda: { + (p.name, p.server_name): transport_kernel_read_buffer_size(p, False) for p in connected_connections }, ) tcp_transport_kernel_read_buffer = LaterGauge( - "synapse_replication_tcp_protocol_transport_kernel_read_buffer", - "", - ["name"], - lambda: { - (p.name,): transport_kernel_read_buffer_size(p, True) + name="synapse_replication_tcp_protocol_transport_kernel_read_buffer", + desc="", + labelnames=["name", SERVER_NAME_LABEL], + caller=lambda: { + (p.name, p.server_name): transport_kernel_read_buffer_size(p, True) for p in connected_connections }, ) diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py index c4601a6141..aba79b2378 100644 --- a/synapse/replication/tcp/redis.py +++ b/synapse/replication/tcp/redis.py @@ -37,6 +37,7 @@ from twisted.internet.interfaces import IAddress, IConnector from twisted.python.failure import Failure from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable +from synapse.metrics import SERVER_NAME_LABEL from synapse.metrics.background_process_metrics import ( BackgroundProcessLoggingContext, run_as_background_process, @@ -97,6 +98,9 @@ class RedisSubscriber(SubscriberProtocol): immediately after initialisation. Attributes: + server_name: The homeserver name of the Synapse instance that this connection + is associated with. This is used to label metrics and should be set to + `hs.hostname`. synapse_handler: The command handler to handle incoming commands. synapse_stream_prefix: The *redis* stream name to subscribe to and publish from (not anything to do with Synapse replication streams). @@ -104,6 +108,7 @@ class RedisSubscriber(SubscriberProtocol): commands. """ + server_name: str synapse_handler: "ReplicationCommandHandler" synapse_stream_prefix: str synapse_channel_names: List[str] @@ -114,18 +119,36 @@ class RedisSubscriber(SubscriberProtocol): # a logcontext which we use for processing incoming commands. We declare it as a # background process so that the CPU stats get reported to prometheus. - with PreserveLoggingContext(): - # thanks to `PreserveLoggingContext()`, the new logcontext is guaranteed to - # capture the sentinel context as its containing context and won't prevent - # GC of / unintentionally reactivate what would be the current context. 
- self._logging_context = BackgroundProcessLoggingContext( - "replication_command_handler" - ) + self._logging_context: Optional[BackgroundProcessLoggingContext] = None + + def _get_logging_context(self) -> BackgroundProcessLoggingContext: + """ + We lazily create the logging context so that `self.server_name` is set and + available. See `RedisDirectTcpReplicationClientFactory.buildProtocol` for more + details on why we set `self.server_name` after the fact instead of in the + constructor. + """ + assert self.server_name is not None, ( + "self.server_name must be set before using _get_logging_context()" + ) + if self._logging_context is None: + # a logcontext which we use for processing incoming commands. We declare it as a + # background process so that the CPU stats get reported to prometheus. + with PreserveLoggingContext(): + # thanks to `PreserveLoggingContext()`, the new logcontext is guaranteed to + # capture the sentinel context as its containing context and won't prevent + # GC of / unintentionally reactivate what would be the current context. + self._logging_context = BackgroundProcessLoggingContext( + name="replication_command_handler", server_name=self.server_name + ) + return self._logging_context def connectionMade(self) -> None: logger.info("Connected to redis") super().connectionMade() - run_as_background_process("subscribe-replication", self._send_subscribe) + run_as_background_process( + "subscribe-replication", self.server_name, self._send_subscribe + ) async def _send_subscribe(self) -> None: # it's important to make sure that we only send the REPLICATE command once we @@ -152,7 +175,7 @@ class RedisSubscriber(SubscriberProtocol): def messageReceived(self, pattern: str, channel: str, message: str) -> None: """Received a message from redis.""" - with PreserveLoggingContext(self._logging_context): + with PreserveLoggingContext(self._get_logging_context()): self._parse_and_dispatch_message(message) def _parse_and_dispatch_message(self, message: str) -> None: @@ -171,7 +194,11 @@ class RedisSubscriber(SubscriberProtocol): # We use "redis" as the name here as we don't have 1:1 connections to # remote instances. - tcp_inbound_commands_counter.labels(cmd.NAME, "redis").inc() + tcp_inbound_commands_counter.labels( + command=cmd.NAME, + name="redis", + **{SERVER_NAME_LABEL: self.server_name}, + ).inc() self.handle_command(cmd) @@ -197,7 +224,7 @@ class RedisSubscriber(SubscriberProtocol): if isawaitable(res): run_as_background_process( - "replication-" + cmd.get_logcontext_id(), lambda: res + "replication-" + cmd.get_logcontext_id(), self.server_name, lambda: res ) def connectionLost(self, reason: Failure) -> None: # type: ignore[override] @@ -207,7 +234,7 @@ class RedisSubscriber(SubscriberProtocol): # mark the logging context as finished by triggering `__exit__()` with PreserveLoggingContext(): - with self._logging_context: + with self._get_logging_context(): pass # the sentinel context is now active, which may not be correct. # PreserveLoggingContext() will restore the correct logging context. @@ -219,7 +246,11 @@ class RedisSubscriber(SubscriberProtocol): cmd: The command to send """ run_as_background_process( - "send-cmd", self._async_send_command, cmd, bg_start_span=False + "send-cmd", + self.server_name, + self._async_send_command, + cmd, + bg_start_span=False, ) async def _async_send_command(self, cmd: Command) -> None: @@ -232,7 +263,11 @@ class RedisSubscriber(SubscriberProtocol): # We use "redis" as the name here as we don't have 1:1 connections to # remote instances. 
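Alongside the metric labels, the `run_as_background_process(...)` calls in these files gain the homeserver name as their second positional argument, matching the call sites shown above and below. A hedged sketch of the new calling convention as it appears in this diff; the task name and coroutine are illustrative:

```python
from synapse.metrics.background_process_metrics import run_as_background_process

async def _flush_queue() -> None:
    """Illustrative background task body."""
    ...

def schedule_flush(server_name: str) -> None:
    # Previously: run_as_background_process("flush-queue", _flush_queue)
    # Now the homeserver name is threaded through so the background-process
    # metrics can be labelled per homeserver, as at the call sites in this diff.
    run_as_background_process("flush-queue", server_name, _flush_queue)
```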
- tcp_outbound_commands_counter.labels(cmd.NAME, "redis").inc() + tcp_outbound_commands_counter.labels( + command=cmd.NAME, + name="redis", + **{SERVER_NAME_LABEL: self.server_name}, + ).inc() channel_name = cmd.redis_channel_name(self.synapse_stream_prefix) @@ -275,6 +310,10 @@ class SynapseRedisFactory(RedisFactory): convertNumbers=convertNumbers, ) + self.server_name = ( + hs.hostname + ) # nb must be called this for @wrap_as_background_process + hs.get_clock().looping_call(self._send_ping, 30 * 1000) @wrap_as_background_process("redis_ping") @@ -350,6 +389,7 @@ class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory): password=hs.config.redis.redis_password, ) + self.server_name = hs.hostname self.synapse_handler = hs.get_replication_command_handler() self.synapse_stream_prefix = hs.hostname self.synapse_channel_names = channel_names @@ -364,6 +404,7 @@ class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory): # as to do so would involve overriding `buildProtocol` entirely, however # the base method does some other things than just instantiating the # protocol. + p.server_name = self.server_name p.synapse_handler = self.synapse_handler p.synapse_outbound_redis_connection = self.synapse_outbound_redis_connection p.synapse_stream_prefix = self.synapse_stream_prefix diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index 0080a76f6f..d800cfe6f6 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -29,6 +29,7 @@ from prometheus_client import Counter from twisted.internet.interfaces import IAddress from twisted.internet.protocol import ServerFactory +from synapse.metrics import SERVER_NAME_LABEL from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.tcp.commands import PositionCommand from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol @@ -40,7 +41,9 @@ if TYPE_CHECKING: from synapse.server import HomeServer stream_updates_counter = Counter( - "synapse_replication_tcp_resource_stream_updates", "", ["stream_name"] + "synapse_replication_tcp_resource_stream_updates", + "", + labelnames=["stream_name", SERVER_NAME_LABEL], ) logger = logging.getLogger(__name__) @@ -144,7 +147,9 @@ class ReplicationStreamer: logger.debug("Notifier poke loop already running") return - run_as_background_process("replication_notifier", self._run_notifier_loop) + run_as_background_process( + "replication_notifier", self.server_name, self._run_notifier_loop + ) async def _run_notifier_loop(self) -> None: self.is_looping = True @@ -224,7 +229,10 @@ class ReplicationStreamer: len(updates), current_token, ) - stream_updates_counter.labels(stream.NAME).inc(len(updates)) + stream_updates_counter.labels( + stream_name=stream.NAME, + **{SERVER_NAME_LABEL: self.server_name}, + ).inc(len(updates)) else: # The token has advanced but there is no data to diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 9694fff4fe..ec7e935d6a 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -739,7 +739,7 @@ class ThreadSubscriptionsStream(_StreamFromIdGen): NAME = "thread_subscriptions" ROW_TYPE = ThreadSubscriptionsStreamRow - def __init__(self, hs: Any): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main super().__init__( hs.get_instance_name(), @@ -751,7 +751,7 @@ class ThreadSubscriptionsStream(_StreamFromIdGen): self, instance_name: str, 
from_token: int, to_token: int, limit: int ) -> StreamUpdateResult: updates = await self.store.get_updated_thread_subscriptions( - from_token, to_token, limit + from_id=from_token, to_id=to_token, limit=limit ) rows = [ ( diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index aeb4267bb7..a24ca09846 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -63,6 +63,7 @@ from synapse.rest.client import ( sync, tags, thirdparty, + thread_subscriptions, tokenrefresh, user_directory, versions, @@ -122,6 +123,7 @@ CLIENT_SERVLET_FUNCTIONS: Tuple[RegisterServletsFunc, ...] = ( login_token_request.register_servlets, rendezvous.register_servlets, auth_metadata.register_servlets, + thread_subscriptions.register_servlets, ) SERVLET_GROUPS: Dict[str, Iterable[RegisterServletsFunc]] = { diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 32df4b244c..d9a6e99c5d 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -272,11 +272,15 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: # Admin servlets below may not work on workers. if hs.config.worker.worker_app is not None: # Some admin servlets can be mounted on workers when MSC3861 is enabled. + # Note that this is only for MSC3861 mode, as modern MAS using the + # matrix_authentication_service integration uses the dedicated MAS API. if hs.config.experimental.msc3861.enabled: register_servlets_for_msc3861_delegation(hs, http_server) return + auth_delegated = hs.config.mas.enabled or hs.config.experimental.msc3861.enabled + register_servlets_for_client_rest_resource(hs, http_server) BlockRoomRestServlet(hs).register(http_server) ListRoomRestServlet(hs).register(http_server) @@ -287,10 +291,10 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: DeleteRoomStatusByRoomIdRestServlet(hs).register(http_server) JoinRoomAliasServlet(hs).register(http_server) VersionServlet(hs).register(http_server) - if not hs.config.experimental.msc3861.enabled: + if not auth_delegated: UserAdminServlet(hs).register(http_server) UserMembershipRestServlet(hs).register(http_server) - if not hs.config.experimental.msc3861.enabled: + if not auth_delegated: UserTokenRestServlet(hs).register(http_server) UserRestServletV2(hs).register(http_server) UsersRestServletV2(hs).register(http_server) @@ -307,7 +311,7 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: RoomEventContextServlet(hs).register(http_server) RateLimitRestServlet(hs).register(http_server) UsernameAvailableRestServlet(hs).register(http_server) - if not hs.config.experimental.msc3861.enabled: + if not auth_delegated: ListRegistrationTokensRestServlet(hs).register(http_server) NewRegistrationTokenRestServlet(hs).register(http_server) RegistrationTokenRestServlet(hs).register(http_server) @@ -341,16 +345,18 @@ def register_servlets_for_client_rest_resource( hs: "HomeServer", http_server: HttpServer ) -> None: """Register only the servlets which need to be exposed on /_matrix/client/xxx""" + auth_delegated = hs.config.mas.enabled or hs.config.experimental.msc3861.enabled + WhoisRestServlet(hs).register(http_server) PurgeHistoryStatusRestServlet(hs).register(http_server) PurgeHistoryRestServlet(hs).register(http_server) # The following resources can only be run on the main process. 
if hs.config.worker.worker_app is None: DeactivateAccountRestServlet(hs).register(http_server) - if not hs.config.experimental.msc3861.enabled: + if not auth_delegated: ResetPasswordRestServlet(hs).register(http_server) SearchUsersRestServlet(hs).register(http_server) - if not hs.config.experimental.msc3861.enabled: + if not auth_delegated: UserRegisterServlet(hs).register(http_server) AccountValidityRenewServlet(hs).register(http_server) diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index efcc60a2de..5bed89c2c4 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -634,7 +634,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): for creator in creators: if self.is_mine_id(creator): # include the creator as they won't be in the PL users map. - admin_users.insert(0, creator) + admin_users.append(creator) if not admin_users: raise SynapseError( diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 8240b270ba..25a38dc4ac 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -42,6 +42,7 @@ from synapse.http.servlet import ( parse_strings_from_args, ) from synapse.http.site import SynapseRequest +from synapse.logging.loggers import ExplicitlyConfiguredLogger from synapse.rest.admin._base import ( admin_patterns, assert_requester_is_admin, @@ -60,6 +61,25 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +original_logger_class = logging.getLoggerClass() +# Because this can log sensitive information, use a custom logger class that only allows +# logging if the logger is explicitly configured. +logging.setLoggerClass(ExplicitlyConfiguredLogger) +user_registration_sensitive_debug_logger = logging.getLogger( + "synapse.rest.admin.users.registration_debug" +) +""" +A logger for debugging the user registration process. + +Because this can log sensitive information (such as passwords and +`registration_shared_secret`), we want people to explictly opt-in before seeing anything +in the logs. Requires explicitly setting `synapse.rest.admin.users.registration_debug` +in the logging configuration and does not inherit the log level from the parent logger. 
+""" +# Restore the original logger class +logging.setLoggerClass(original_logger_class) + + class UsersRestServletV2(RestServlet): PATTERNS = admin_patterns("/users$", "v2") @@ -89,7 +109,9 @@ class UsersRestServletV2(RestServlet): self.auth = hs.get_auth() self.admin_handler = hs.get_admin_handler() self._msc3866_enabled = hs.config.experimental.msc3866.enabled - self._msc3861_enabled = hs.config.experimental.msc3861.enabled + self._auth_delegation_enabled = ( + hs.config.mas.enabled or hs.config.experimental.msc3861.enabled + ) async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) @@ -101,10 +123,10 @@ class UsersRestServletV2(RestServlet): name = parse_string(request, "name", encoding="utf-8") guests = parse_boolean(request, "guests", default=True) - if self._msc3861_enabled and guests: + if self._auth_delegation_enabled and guests: raise SynapseError( HTTPStatus.BAD_REQUEST, - "The guests parameter is not supported when MSC3861 is enabled.", + "The guests parameter is not supported when delegating to MAS.", errcode=Codes.INVALID_PARAM, ) @@ -635,6 +657,34 @@ class UserRegisterServlet(RestServlet): want_mac = want_mac_builder.hexdigest() if not hmac.compare_digest(want_mac.encode("ascii"), got_mac.encode("ascii")): + # If the sensitive debug logger is enabled, log the full details. + # + # For reference, the `user_registration_sensitive_debug_logger.debug(...)` + # call is enough to gate the logging of sensitive information unless + # explicitly enabled. We only have this if-statement to avoid logging the + # suggestion to enable the debug logger if you already have it enabled. + if user_registration_sensitive_debug_logger.isEnabledFor(logging.DEBUG): + user_registration_sensitive_debug_logger.debug( + "UserRegisterServlet: Incorrect HMAC digest: actual=%s, expected=%s, registration_shared_secret=%s, body=%s", + got_mac, + want_mac, + self.hs.config.registration.registration_shared_secret, + body, + ) + else: + # Otherwise, just log the non-sensitive essentials and advertise the + # debug logger for sensitive information. + logger.debug( + ( + "UserRegisterServlet: HMAC incorrect (username=%s): actual=%s, expected=%s - " + "If you need more information, explicitly enable the `synapse.rest.admin.users.registration_debug` " + "logger at the `DEBUG` level to log things like the full request body and " + "`registration_shared_secret` used to calculate the HMAC." 
+ ), + username, + got_mac, + want_mac, + ) raise SynapseError(HTTPStatus.FORBIDDEN, "HMAC incorrect") should_issue_refresh_token = body.get("refresh_token", False) diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py index 9d0649a505..d9f0c169e8 100644 --- a/synapse/rest/client/account.py +++ b/synapse/rest/client/account.py @@ -47,7 +47,7 @@ from synapse.http.servlet import ( parse_string, ) from synapse.http.site import SynapseRequest -from synapse.metrics import threepid_send_requests +from synapse.metrics import SERVER_NAME_LABEL, threepid_send_requests from synapse.push.mailer import Mailer from synapse.types import JsonDict from synapse.types.rest import RequestBodyModel @@ -76,6 +76,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs + self.server_name = hs.hostname self.datastore = hs.get_datastores().main self.config = hs.config self.identity_handler = hs.get_identity_handler() @@ -136,9 +137,11 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): self.mailer.send_password_reset_mail, body.next_link, ) - threepid_send_requests.labels(type="email", reason="password_reset").observe( - body.send_attempt - ) + threepid_send_requests.labels( + type="email", + reason="password_reset", + **{SERVER_NAME_LABEL: self.server_name}, + ).observe(body.send_attempt) # Wrap the session id in a JSON object return 200, {"sid": sid} @@ -325,6 +328,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs + self.server_name = hs.hostname self.config = hs.config self.identity_handler = hs.get_identity_handler() self.store = self.hs.get_datastores().main @@ -394,9 +398,11 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): body.next_link, ) - threepid_send_requests.labels(type="email", reason="add_threepid").observe( - body.send_attempt - ) + threepid_send_requests.labels( + type="email", + reason="add_threepid", + **{SERVER_NAME_LABEL: self.server_name}, + ).observe(body.send_attempt) # Wrap the session id in a JSON object return 200, {"sid": sid} @@ -407,6 +413,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.hs = hs + self.server_name = hs.hostname super().__init__() self.store = self.hs.get_datastores().main self.identity_handler = hs.get_identity_handler() @@ -469,9 +476,11 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): body.next_link, ) - threepid_send_requests.labels(type="msisdn", reason="add_threepid").observe( - body.send_attempt - ) + threepid_send_requests.labels( + type="msisdn", + reason="add_threepid", + **{SERVER_NAME_LABEL: self.server_name}, + ).observe(body.send_attempt) logger.info("MSISDN %s: got response from identity server: %s", msisdn, ret) return 200, ret @@ -604,7 +613,7 @@ class ThreepidRestServlet(RestServlet): # ThreePidBindRestServelet.PostBody with an `alias_generator` to handle # `threePidCreds` versus `three_pid_creds`. 
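The opt-in sensitive logger introduced above for registration HMAC debugging relies on temporarily swapping the logger class so that only one named logger gets the restrictive behaviour. A simplified, self-contained sketch of that pattern; `OptInLogger` is a hypothetical stand-in for Synapse's `ExplicitlyConfiguredLogger`, whose real implementation is not shown in this diff:

```python
import logging

class OptInLogger(logging.Logger):
    """Hypothetical stand-in: emits records only if a level was set explicitly."""

    def isEnabledFor(self, level: int) -> bool:
        return self.level != logging.NOTSET and super().isEnabledFor(level)

_original_class = logging.getLoggerClass()
logging.setLoggerClass(OptInLogger)
# Only this logger is created while the custom class is active...
sensitive_logger = logging.getLogger("example.registration_debug")
# ...then the previous class is restored so other loggers are unaffected.
logging.setLoggerClass(_original_class)

logging.basicConfig(level=logging.DEBUG)
sensitive_logger.debug("dropped: the logger was never explicitly configured")
sensitive_logger.setLevel(logging.DEBUG)  # explicit opt-in
sensitive_logger.debug("now emitted")
```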
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - if self.hs.config.experimental.msc3861.enabled: + if self.hs.config.mas.enabled or self.hs.config.experimental.msc3861.enabled: raise NotFoundError(errcode=Codes.UNRECOGNIZED) if not self.hs.config.registration.enable_3pid_changes: @@ -896,18 +905,19 @@ class AccountStatusRestServlet(RestServlet): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: + auth_delegated = hs.config.mas.enabled or hs.config.experimental.msc3861.enabled + ThreepidRestServlet(hs).register(http_server) WhoamiRestServlet(hs).register(http_server) - if not hs.config.experimental.msc3861.enabled: + if not auth_delegated: DeactivateAccountRestServlet(hs).register(http_server) - # These servlets are only registered on the main process if hs.config.worker.worker_app is None: ThreepidBindRestServlet(hs).register(http_server) ThreepidUnbindRestServlet(hs).register(http_server) - if not hs.config.experimental.msc3861.enabled: + if not auth_delegated: EmailPasswordRequestTokenRestServlet(hs).register(http_server) PasswordRestServlet(hs).register(http_server) EmailThreepidRequestTokenRestServlet(hs).register(http_server) @@ -917,5 +927,5 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ThreepidAddRestServlet(hs).register(http_server) ThreepidDeleteRestServlet(hs).register(http_server) - if hs.config.experimental.msc3720_enabled: - AccountStatusRestServlet(hs).register(http_server) + if hs.config.experimental.msc3720_enabled: + AccountStatusRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/auth.py b/synapse/rest/client/auth.py index b8dca7c797..600bb51a7e 100644 --- a/synapse/rest/client/auth.py +++ b/synapse/rest/client/auth.py @@ -20,10 +20,11 @@ # import logging -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING from twisted.web.server import Request +from synapse.api.auth.mas import MasDelegatedAuth from synapse.api.constants import LoginType from synapse.api.errors import LoginError, SynapseError from synapse.api.urls import CLIENT_API_PREFIX @@ -66,22 +67,30 @@ class AuthRestServlet(RestServlet): if not session: raise SynapseError(400, "No session supplied") - if ( - self.hs.config.experimental.msc3861.enabled - and stagetype == "org.matrix.cross_signing_reset" - ): - # If MSC3861 is enabled, we can assume self._auth is an instance of MSC3861DelegatedAuth - # We import lazily here because of the authlib requirement - from synapse.api.auth.msc3861_delegated import MSC3861DelegatedAuth + if stagetype == "org.matrix.cross_signing_reset": + if self.hs.config.mas.enabled: + assert isinstance(self.auth, MasDelegatedAuth) - auth = cast(MSC3861DelegatedAuth, self.auth) - - url = await auth.account_management_url() - if url is not None: + url = await self.auth.account_management_url() url = f"{url}?action=org.matrix.cross_signing_reset" - else: - url = await auth.issuer() - respond_with_redirect(request, str.encode(url)) + return respond_with_redirect( + request, + url.encode(), + ) + + elif self.hs.config.experimental.msc3861.enabled: + # If MSC3861 is enabled, we can assume self._auth is an instance of MSC3861DelegatedAuth + # We import lazily here because of the authlib requirement + from synapse.api.auth.msc3861_delegated import MSC3861DelegatedAuth + + assert isinstance(self.auth, MSC3861DelegatedAuth) + + base = await self.auth.account_management_url() + if base is not None: + url = f"{base}?action=org.matrix.cross_signing_reset" + else: + url = await 
self.auth.issuer() + return respond_with_redirect(request, url.encode()) if stagetype == LoginType.RECAPTCHA: html = self.recaptcha_template.render( diff --git a/synapse/rest/client/auth_metadata.py b/synapse/rest/client/auth_metadata.py index 5444a89be6..25e01a6574 100644 --- a/synapse/rest/client/auth_metadata.py +++ b/synapse/rest/client/auth_metadata.py @@ -15,6 +15,7 @@ import logging import typing from typing import Tuple, cast +from synapse.api.auth.mas import MasDelegatedAuth from synapse.api.errors import Codes, SynapseError from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet @@ -48,13 +49,18 @@ class AuthIssuerServlet(RestServlet): self._auth = hs.get_auth() async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - if self._config.experimental.msc3861.enabled: + if self._config.mas.enabled: + assert isinstance(self._auth, MasDelegatedAuth) + return 200, {"issuer": await self._auth.issuer()} + + elif self._config.experimental.msc3861.enabled: # If MSC3861 is enabled, we can assume self._auth is an instance of MSC3861DelegatedAuth # We import lazily here because of the authlib requirement from synapse.api.auth.msc3861_delegated import MSC3861DelegatedAuth - auth = cast(MSC3861DelegatedAuth, self._auth) - return 200, {"issuer": await auth.issuer()} + assert isinstance(self._auth, MSC3861DelegatedAuth) + return 200, {"issuer": await self._auth.issuer()} + else: # Wouldn't expect this to be reached: the servelet shouldn't have been # registered. Still, fail gracefully if we are registered for some reason. @@ -82,13 +88,18 @@ class AuthMetadataServlet(RestServlet): self._auth = hs.get_auth() async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - if self._config.experimental.msc3861.enabled: + if self._config.mas.enabled: + assert isinstance(self._auth, MasDelegatedAuth) + return 200, await self._auth.auth_metadata() + + elif self._config.experimental.msc3861.enabled: # If MSC3861 is enabled, we can assume self._auth is an instance of MSC3861DelegatedAuth # We import lazily here because of the authlib requirement from synapse.api.auth.msc3861_delegated import MSC3861DelegatedAuth auth = cast(MSC3861DelegatedAuth, self._auth) return 200, await auth.auth_metadata() + else: # Wouldn't expect this to be reached: the servlet shouldn't have been # registered. Still, fail gracefully if we are registered for some reason. 
@@ -100,7 +111,6 @@ class AuthMetadataServlet(RestServlet): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: - # We use the MSC3861 values as they are used by multiple MSCs - if hs.config.experimental.msc3861.enabled: + if hs.config.mas.enabled or hs.config.experimental.msc3861.enabled: AuthIssuerServlet(hs).register(http_server) AuthMetadataServlet(hs).register(http_server) diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py index 5667af20d4..0777abde7f 100644 --- a/synapse/rest/client/devices.py +++ b/synapse/rest/client/devices.py @@ -144,7 +144,9 @@ class DeviceRestServlet(RestServlet): self.device_handler = handler self.auth_handler = hs.get_auth_handler() self._msc3852_enabled = hs.config.experimental.msc3852_enabled - self._msc3861_oauth_delegation_enabled = hs.config.experimental.msc3861.enabled + self._auth_delegation_enabled = ( + hs.config.mas.enabled or hs.config.experimental.msc3861.enabled + ) async def on_GET( self, request: SynapseRequest, device_id: str @@ -196,7 +198,7 @@ class DeviceRestServlet(RestServlet): pass else: - if self._msc3861_oauth_delegation_enabled: + if self._auth_delegation_enabled: raise UnrecognizedRequestError(code=404) await self.auth_handler.validate_user_via_ui_auth( @@ -573,7 +575,8 @@ class DehydratedDeviceV2Servlet(RestServlet): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: - if not hs.config.experimental.msc3861.enabled: + auth_delegated = hs.config.mas.enabled or hs.config.experimental.msc3861.enabled + if not auth_delegated: DeleteDevicesRestServlet(hs).register(http_server) DevicesRestServlet(hs).register(http_server) DeviceRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py index 09749b840f..9f39889c75 100644 --- a/synapse/rest/client/keys.py +++ b/synapse/rest/client/keys.py @@ -23,8 +23,9 @@ import logging import re from collections import Counter -from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, cast +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple +from synapse.api.auth.mas import MasDelegatedAuth from synapse.api.errors import ( InteractiveAuthIncompleteError, InvalidAPICallError, @@ -404,19 +405,11 @@ class SigningKeyUploadServlet(RestServlet): if is_cross_signing_setup: # With MSC3861, UIA is not possible. Instead, the auth service has to # explicitly mark the master key as replaceable. - if self.hs.config.experimental.msc3861.enabled: + if self.hs.config.mas.enabled: if not master_key_updatable_without_uia: - # If MSC3861 is enabled, we can assume self.auth is an instance of MSC3861DelegatedAuth - # We import lazily here because of the authlib requirement - from synapse.api.auth.msc3861_delegated import MSC3861DelegatedAuth - - auth = cast(MSC3861DelegatedAuth, self.auth) - - uri = await auth.account_management_url() - if uri is not None: - url = f"{uri}?action=org.matrix.cross_signing_reset" - else: - url = await auth.issuer() + assert isinstance(self.auth, MasDelegatedAuth) + url = await self.auth.account_management_url() + url = f"{url}?action=org.matrix.cross_signing_reset" # We use a dummy session ID as this isn't really a UIA flow, but we # reuse the same API shape for better client compatibility. 
@@ -437,6 +430,41 @@ class SigningKeyUploadServlet(RestServlet): "then try again.", }, ) + + elif self.hs.config.experimental.msc3861.enabled: + if not master_key_updatable_without_uia: + # If MSC3861 is enabled, we can assume self.auth is an instance of MSC3861DelegatedAuth + # We import lazily here because of the authlib requirement + from synapse.api.auth.msc3861_delegated import MSC3861DelegatedAuth + + assert isinstance(self.auth, MSC3861DelegatedAuth) + + uri = await self.auth.account_management_url() + if uri is not None: + url = f"{uri}?action=org.matrix.cross_signing_reset" + else: + url = await self.auth.issuer() + + # We use a dummy session ID as this isn't really a UIA flow, but we + # reuse the same API shape for better client compatibility. + raise InteractiveAuthIncompleteError( + "dummy", + { + "session": "dummy", + "flows": [ + {"stages": ["org.matrix.cross_signing_reset"]}, + ], + "params": { + "org.matrix.cross_signing_reset": { + "url": url, + }, + }, + "msg": "To reset your end-to-end encryption cross-signing " + f"identity, you first need to approve it at {url} and " + "then try again.", + }, + ) + else: # Without MSC3861, we require UIA. await self.auth_handler.validate_user_via_ui_auth( diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index aa0aa36cd9..acb9111ad2 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -715,7 +715,7 @@ class CasTicketServlet(RestServlet): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: - if hs.config.experimental.msc3861.enabled: + if hs.config.mas.enabled or hs.config.experimental.msc3861.enabled: return LoginRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/logout.py b/synapse/rest/client/logout.py index 206865e989..39c62b9e26 100644 --- a/synapse/rest/client/logout.py +++ b/synapse/rest/client/logout.py @@ -86,7 +86,7 @@ class LogoutAllRestServlet(RestServlet): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: - if hs.config.experimental.msc3861.enabled: + if hs.config.mas.enabled or hs.config.experimental.msc3861.enabled: return LogoutRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index 58231d2b04..102c04bb67 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -56,7 +56,7 @@ from synapse.http.servlet import ( parse_string, ) from synapse.http.site import SynapseRequest -from synapse.metrics import threepid_send_requests +from synapse.metrics import SERVER_NAME_LABEL, threepid_send_requests from synapse.push.mailer import Mailer from synapse.types import JsonDict from synapse.util.msisdn import phone_number_to_msisdn @@ -82,6 +82,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs + self.server_name = hs.hostname self.identity_handler = hs.get_identity_handler() self.config = hs.config @@ -163,9 +164,11 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): next_link, ) - threepid_send_requests.labels(type="email", reason="register").observe( - send_attempt - ) + threepid_send_requests.labels( + type="email", + reason="register", + **{SERVER_NAME_LABEL: self.server_name}, + ).observe(send_attempt) # Wrap the session id in a JSON object return 200, {"sid": sid} @@ -177,6 +180,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs + self.server_name = 
hs.hostname self.identity_handler = hs.get_identity_handler() async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: @@ -240,9 +244,11 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): next_link, ) - threepid_send_requests.labels(type="msisdn", reason="register").observe( - send_attempt - ) + threepid_send_requests.labels( + type="msisdn", + reason="register", + **{SERVER_NAME_LABEL: self.server_name}, + ).observe(send_attempt) return 200, ret @@ -323,10 +329,12 @@ class UsernameAvailabilityRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs + self.server_name = hs.hostname self.registration_handler = hs.get_registration_handler() self.ratelimiter = FederationRateLimiter( - hs.get_clock(), - FederationRatelimitSettings( + our_server_name=self.server_name, + clock=hs.get_clock(), + config=FederationRatelimitSettings( # Time window of 2s window_size=2000, # Artificially delay requests if rate > sleep_limit/window_size @@ -1036,7 +1044,7 @@ def _calculate_registration_flows( def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: - if hs.config.experimental.msc3861.enabled: + if hs.config.mas.enabled or hs.config.experimental.msc3861.enabled: RegisterAppServiceOnlyRestServlet(hs).register(http_server) return diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 6b0deda0df..64deae7650 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -65,6 +65,7 @@ from synapse.http.servlet import ( from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.opentracing import set_tag +from synapse.metrics import SERVER_NAME_LABEL from synapse.metrics.background_process_metrics import run_as_background_process from synapse.rest.client._base import client_patterns from synapse.rest.client.transactions import HttpTransactionCache @@ -120,7 +121,7 @@ messsages_response_timer = Histogram( # picture of /messages response time for bigger rooms. We don't want the # tiny rooms that can always respond fast skewing our results when we're trying # to optimize the bigger cases. 
- ["room_size"], + labelnames=["room_size", SERVER_NAME_LABEL], buckets=( 0.005, 0.01, @@ -801,6 +802,7 @@ class RoomMessageListRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() self._hs = hs + self.server_name = hs.hostname self.clock = hs.get_clock() self.pagination_handler = hs.get_pagination_handler() self.auth = hs.get_auth() @@ -849,7 +851,8 @@ class RoomMessageListRestServlet(RestServlet): processing_end_time = self.clock.time_msec() room_member_count = await make_deferred_yieldable(room_member_count_deferred) messsages_response_timer.labels( - room_size=_RoomSize.from_member_count(room_member_count) + room_size=_RoomSize.from_member_count(room_member_count), + **{SERVER_NAME_LABEL: self.server_name}, ).observe((processing_end_time - processing_start_time) / 1000) return 200, msgs @@ -1100,6 +1103,7 @@ class RoomMembershipRestServlet(TransactionRestServlet): super().__init__(hs) self.room_member_handler = hs.get_room_member_handler() self.auth = hs.get_auth() + self.config = hs.config def register(self, http_server: HttpServer) -> None: # /rooms/$roomid/[join|invite|leave|ban|unban|kick] @@ -1123,12 +1127,12 @@ class RoomMembershipRestServlet(TransactionRestServlet): }: raise AuthError(403, "Guest access not allowed") - content = parse_json_object_from_request(request, allow_empty_body=True) + request_body = parse_json_object_from_request(request, allow_empty_body=True) if membership_action == "invite" and all( - key in content for key in ("medium", "address") + key in request_body for key in ("medium", "address") ): - if not all(key in content for key in ("id_server", "id_access_token")): + if not all(key in request_body for key in ("id_server", "id_access_token")): raise SynapseError( HTTPStatus.BAD_REQUEST, "`id_server` and `id_access_token` are required when doing 3pid invite", @@ -1139,12 +1143,12 @@ class RoomMembershipRestServlet(TransactionRestServlet): await self.room_member_handler.do_3pid_invite( room_id, requester.user, - content["medium"], - content["address"], - content["id_server"], + request_body["medium"], + request_body["address"], + request_body["id_server"], requester, txn_id, - content["id_access_token"], + request_body["id_access_token"], ) except ShadowBanError: # Pretend the request succeeded. 
@@ -1153,12 +1157,19 @@ class RoomMembershipRestServlet(TransactionRestServlet): target = requester.user if membership_action in ["invite", "ban", "unban", "kick"]: - assert_params_in_dict(content, ["user_id"]) - target = UserID.from_string(content["user_id"]) + assert_params_in_dict(request_body, ["user_id"]) + target = UserID.from_string(request_body["user_id"]) event_content = None - if "reason" in content: - event_content = {"reason": content["reason"]} + if "reason" in request_body: + event_content = {"reason": request_body["reason"]} + if self.config.experimental.msc4293_enabled: + if "org.matrix.msc4293.redact_events" in request_body: + if event_content is None: + event_content = {} + event_content["org.matrix.msc4293.redact_events"] = request_body[ + "org.matrix.msc4293.redact_events" + ] try: await self.room_member_handler.update_membership( @@ -1167,7 +1178,7 @@ class RoomMembershipRestServlet(TransactionRestServlet): room_id=room_id, action=membership_action, txn_id=txn_id, - third_party_signed=content.get("third_party_signed", None), + third_party_signed=request_body.get("third_party_signed", None), content=event_content, ) except ShadowBanError: @@ -1213,6 +1224,7 @@ class RoomRedactEventRestServlet(TransactionRestServlet): def __init__(self, hs: "HomeServer"): super().__init__(hs) + self.server_name = hs.hostname self.event_creation_handler = hs.get_event_creation_handler() self.auth = hs.get_auth() self._store = hs.get_datastores().main @@ -1297,6 +1309,7 @@ class RoomRedactEventRestServlet(TransactionRestServlet): if with_relations: run_as_background_process( "redact_related_events", + self.server_name, self._relation_handler.redact_events_related_to, requester=requester, event_id=event_id, diff --git a/synapse/rest/client/thread_subscriptions.py b/synapse/rest/client/thread_subscriptions.py index 5307132ec3..4e7b5d06db 100644 --- a/synapse/rest/client/thread_subscriptions.py +++ b/synapse/rest/client/thread_subscriptions.py @@ -1,7 +1,6 @@ from http import HTTPStatus -from typing import Tuple +from typing import TYPE_CHECKING, Optional, Tuple -from synapse._pydantic_compat import StrictBool from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.http.server import HttpServer from synapse.http.servlet import ( @@ -10,9 +9,12 @@ from synapse.http.servlet import ( ) from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns -from synapse.server import HomeServer from synapse.types import JsonDict, RoomID from synapse.types.rest import RequestBodyModel +from synapse.util.pydantic_models import AnyEventId + +if TYPE_CHECKING: + from synapse.server import HomeServer class ThreadSubscriptionsRestServlet(RestServlet): @@ -30,7 +32,12 @@ class ThreadSubscriptionsRestServlet(RestServlet): self.handler = hs.get_thread_subscriptions_handler() class PutBody(RequestBodyModel): - automatic: StrictBool + automatic: Optional[AnyEventId] + """ + If supplied, the event ID of an event giving rise to this automatic subscription. + + If omitted, this subscription is a manual subscription. 
+ """ async def on_GET( self, request: SynapseRequest, room_id: str, thread_root_id: str @@ -61,15 +68,15 @@ class ThreadSubscriptionsRestServlet(RestServlet): raise SynapseError( HTTPStatus.BAD_REQUEST, "Invalid event ID", errcode=Codes.INVALID_PARAM ) - requester = await self.auth.get_user_by_req(request) - body = parse_and_validate_json_object_from_request(request, self.PutBody) + requester = await self.auth.get_user_by_req(request) + await self.handler.subscribe_user_to_thread( requester.user, room_id, thread_root_id, - automatic=body.automatic, + automatic_event_id=body.automatic, ) return HTTPStatus.OK, {} diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index fa39eb9e6d..7f78379534 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -177,6 +177,8 @@ class VersionsRestServlet(RestServlet): "uk.tcpip.msc4133": self.config.experimental.msc4133_enabled, # MSC4155: Invite filtering "org.matrix.msc4155": self.config.experimental.msc4155_enabled, + # MSC4306: Support for thread subscriptions + "org.matrix.msc4306": self.config.experimental.msc4306_enabled, }, }, ) diff --git a/synapse/rest/synapse/client/__init__.py b/synapse/rest/synapse/client/__init__.py index 043c508379..665ce77dd7 100644 --- a/synapse/rest/synapse/client/__init__.py +++ b/synapse/rest/synapse/client/__init__.py @@ -56,8 +56,9 @@ def build_synapse_client_resource_tree(hs: "HomeServer") -> Mapping[str, Resourc "/_synapse/client/unsubscribe": UnsubscribeResource(hs), } - # Expose the JWKS endpoint if OAuth2 delegation is enabled - if hs.config.experimental.msc3861.enabled: + if hs.config.mas.enabled: + resources["/_synapse/mas"] = MasResource(hs) + elif hs.config.experimental.msc3861.enabled: from synapse.rest.synapse.client.jwks import JwksResource resources["/_synapse/jwks"] = JwksResource(hs) diff --git a/synapse/rest/synapse/mas/_base.py b/synapse/rest/synapse/mas/_base.py index caf392fc3a..7346198b75 100644 --- a/synapse/rest/synapse/mas/_base.py +++ b/synapse/rest/synapse/mas/_base.py @@ -16,6 +16,7 @@ from typing import TYPE_CHECKING, cast +from synapse.api.auth.mas import MasDelegatedAuth from synapse.api.errors import SynapseError from synapse.http.server import DirectServeJsonResource @@ -27,14 +28,21 @@ if TYPE_CHECKING: class MasBaseResource(DirectServeJsonResource): def __init__(self, hs: "HomeServer"): - # Importing this module requires authlib, which is an optional - # dependency but required if msc3861 is enabled - from synapse.api.auth.msc3861_delegated import MSC3861DelegatedAuth + auth = hs.get_auth() + if hs.config.mas.enabled: + assert isinstance(auth, MasDelegatedAuth) + + self._is_request_from_mas = auth.is_request_using_the_shared_secret + else: + # Importing this module requires authlib, which is an optional + # dependency but required if msc3861 is enabled + from synapse.api.auth.msc3861_delegated import MSC3861DelegatedAuth + + assert isinstance(auth, MSC3861DelegatedAuth) + + self._is_request_from_mas = auth.is_request_using_the_admin_token DirectServeJsonResource.__init__(self, extract_context=True) - auth = hs.get_auth() - assert isinstance(auth, MSC3861DelegatedAuth) - self.msc3861_auth = auth self.store = cast("GenericWorkerStore", hs.get_datastores().main) self.hostname = hs.hostname @@ -43,5 +51,5 @@ class MasBaseResource(DirectServeJsonResource): Throws a 403 if the request is not coming from MAS. 
""" - if not self.msc3861_auth.is_request_using_the_admin_token(request): + if not self._is_request_from_mas(request): raise SynapseError(403, "This endpoint must only be called by MAS") diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py index b4476e5a69..e4fe4c45ef 100644 --- a/synapse/rest/well_known.py +++ b/synapse/rest/well_known.py @@ -18,11 +18,12 @@ # # import logging -from typing import TYPE_CHECKING, Optional, Tuple, cast +from typing import TYPE_CHECKING, Optional, Tuple from twisted.web.resource import Resource from twisted.web.server import Request +from synapse.api.auth.mas import MasDelegatedAuth from synapse.api.errors import NotFoundError from synapse.http.server import DirectServeJsonResource from synapse.http.site import SynapseRequest @@ -52,18 +53,25 @@ class WellKnownBuilder: "base_url": self._config.registration.default_identity_server } - # We use the MSC3861 values as they are used by multiple MSCs - if self._config.experimental.msc3861.enabled: + if self._config.mas.enabled: + assert isinstance(self._auth, MasDelegatedAuth) + + result["org.matrix.msc2965.authentication"] = { + "issuer": await self._auth.issuer(), + "account": await self._auth.account_management_url(), + } + + elif self._config.experimental.msc3861.enabled: # If MSC3861 is enabled, we can assume self._auth is an instance of MSC3861DelegatedAuth # We import lazily here because of the authlib requirement from synapse.api.auth.msc3861_delegated import MSC3861DelegatedAuth - auth = cast(MSC3861DelegatedAuth, self._auth) + assert isinstance(self._auth, MSC3861DelegatedAuth) result["org.matrix.msc2965.authentication"] = { - "issuer": await auth.issuer(), + "issuer": await self._auth.issuer(), } - account_management_url = await auth.account_management_url() + account_management_url = await self._auth.account_management_url() if account_management_url is not None: result["org.matrix.msc2965.authentication"]["account"] = ( account_management_url diff --git a/synapse/server.py b/synapse/server.py index 231bd14907..bf82f79bec 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -40,6 +40,7 @@ from twisted.web.resource import Resource from synapse.api.auth import Auth from synapse.api.auth.internal import InternalAuth +from synapse.api.auth.mas import MasDelegatedAuth from synapse.api.auth_blocking import AuthBlocking from synapse.api.filtering import Filtering from synapse.api.ratelimiting import Ratelimiter, RequestRatelimiter @@ -423,7 +424,7 @@ class HomeServer(metaclass=abc.ABCMeta): @cache_in_self def get_distributor(self) -> Distributor: - return Distributor() + return Distributor(server_name=self.hostname) @cache_in_self def get_registration_ratelimiter(self) -> Ratelimiter: @@ -451,6 +452,8 @@ class HomeServer(metaclass=abc.ABCMeta): @cache_in_self def get_auth(self) -> Auth: + if self.config.mas.enabled: + return MasDelegatedAuth(self) if self.config.experimental.msc3861.enabled: from synapse.api.auth.msc3861_delegated import MSC3861DelegatedAuth @@ -849,7 +852,8 @@ class HomeServer(metaclass=abc.ABCMeta): @cache_in_self def get_federation_ratelimiter(self) -> FederationRateLimiter: return FederationRateLimiter( - self.get_clock(), + our_server_name=self.hostname, + clock=self.get_clock(), config=self.config.ratelimiting.rc_federation, metrics_name="federation_servlets", ) @@ -980,7 +984,10 @@ class HomeServer(metaclass=abc.ABCMeta): ) # Register the threadpool with our metrics. 
- register_threadpool("media", media_threadpool) + server_name = self.hostname + register_threadpool( + name="media", server_name=server_name, threadpool=media_threadpool + ) return media_threadpool diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index d5f892a7c0..3d8016c264 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -51,6 +51,7 @@ from synapse.events.snapshot import ( ) from synapse.logging.context import ContextResourceUsage from synapse.logging.opentracing import tag_args, trace +from synapse.metrics import SERVER_NAME_LABEL from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServlet from synapse.state import v1, v2 from synapse.storage.databases.main.event_federation import StateDifference @@ -75,6 +76,7 @@ metrics_logger = logging.getLogger("synapse.state.metrics") state_groups_histogram = Histogram( "synapse_state_number_state_groups_in_resolution", "Number of state groups used when performing a state resolution", + labelnames=[SERVER_NAME_LABEL], buckets=(1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"), ) @@ -608,20 +610,24 @@ _biggest_room_by_cpu_counter = Counter( "synapse_state_res_cpu_for_biggest_room_seconds", "CPU time spent performing state resolution for the single most expensive " "room for state resolution", + labelnames=[SERVER_NAME_LABEL], ) _biggest_room_by_db_counter = Counter( "synapse_state_res_db_for_biggest_room_seconds", "Database time spent performing state resolution for the single most " "expensive room for state resolution", + labelnames=[SERVER_NAME_LABEL], ) _cpu_times = Histogram( "synapse_state_res_cpu_for_all_rooms_seconds", "CPU time (utime+stime) spent computing a single state resolution", + labelnames=[SERVER_NAME_LABEL], ) _db_times = Histogram( "synapse_state_res_db_for_all_rooms_seconds", "Database time spent computing a single state resolution", + labelnames=[SERVER_NAME_LABEL], ) @@ -736,7 +742,9 @@ class StateResolutionHandler: f"State groups have been deleted: {shortstr(missing_state_groups)}" ) - state_groups_histogram.observe(len(state_groups_ids)) + state_groups_histogram.labels( + **{SERVER_NAME_LABEL: self.server_name} + ).observe(len(state_groups_ids)) new_state = await self.resolve_events_with_store( room_id, @@ -823,8 +831,12 @@ class StateResolutionHandler: room_metrics.db_time += rusage.db_txn_duration_sec room_metrics.db_events += rusage.evt_db_fetch_count - _cpu_times.observe(rusage.ru_utime + rusage.ru_stime) - _db_times.observe(rusage.db_txn_duration_sec) + _cpu_times.labels(**{SERVER_NAME_LABEL: self.server_name}).observe( + rusage.ru_utime + rusage.ru_stime + ) + _db_times.labels(**{SERVER_NAME_LABEL: self.server_name}).observe( + rusage.db_txn_duration_sec + ) def _report_metrics(self) -> None: if not self._state_res_metrics: @@ -881,7 +893,9 @@ class StateResolutionHandler: # report info on the single biggest to prometheus _, biggest_metrics = biggest[0] - prometheus_counter_metric.inc(extract_key(biggest_metrics)) + prometheus_counter_metric.labels(**{SERVER_NAME_LABEL: self.server_name}).inc( + extract_key(biggest_metrics) + ) def _make_state_cache_entry( diff --git a/synapse/state/v2.py b/synapse/state/v2.py index cb0f3d94c9..8bf6706434 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -53,7 +53,7 @@ class Clock(Protocol): # This is usually synapse.util.Clock, but it's replaced with a FakeClock in tests. # We only ever sleep(0) though, so that other async functions can make forward # progress without waiting for stateres to complete. 
- def sleep(self, duration_ms: float) -> Awaitable[None]: ... + async def sleep(self, duration_ms: float) -> None: ... class StateResolutionStore(Protocol): diff --git a/synapse/storage/admin_client_config.py b/synapse/storage/admin_client_config.py index 4359721a6d..07acddc660 100644 --- a/synapse/storage/admin_client_config.py +++ b/synapse/storage/admin_client_config.py @@ -15,8 +15,12 @@ class AdminClientConfig: # `unsigned` portion of the event to inform clients that the event # is soft-failed. self.return_soft_failed_events: bool = False + self.return_policy_server_spammy_events: bool = False if account_data: self.return_soft_failed_events = account_data.get( "return_soft_failed_events", False ) + self.return_policy_server_spammy_events = account_data.get( + "return_policy_server_spammy_events", self.return_soft_failed_events + ) diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index d170bbddaa..acc0abee63 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -249,6 +249,7 @@ class BackgroundUpdater: self._clock = hs.get_clock() self.db_pool = database self.hs = hs + self.server_name = hs.hostname self._database_name = database.name() @@ -395,7 +396,10 @@ class BackgroundUpdater: self._all_done = False sleep = self.sleep_enabled run_as_background_process( - "background_updates", self.run_background_updates, sleep + "background_updates", + self.server_name, + self.run_background_updates, + sleep, ) async def run_background_updates(self, sleep: bool) -> None: diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py index 9f54430a22..95a34f7be1 100644 --- a/synapse/storage/controllers/persist_events.py +++ b/synapse/storage/controllers/persist_events.py @@ -61,6 +61,7 @@ from synapse.logging.opentracing import ( start_active_span_follows_from, trace, ) +from synapse.metrics import SERVER_NAME_LABEL from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.controllers.state import StateStorageController from synapse.storage.databases import Databases @@ -82,25 +83,30 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) # The number of times we are recalculating the current state -state_delta_counter = Counter("synapse_storage_events_state_delta", "") +state_delta_counter = Counter( + "synapse_storage_events_state_delta", "", labelnames=[SERVER_NAME_LABEL] +) # The number of times we are recalculating state when there is only a # single forward extremity state_delta_single_event_counter = Counter( - "synapse_storage_events_state_delta_single_event", "" + "synapse_storage_events_state_delta_single_event", + "", + labelnames=[SERVER_NAME_LABEL], ) # The number of times we are reculating state when we could have resonably # calculated the delta when we calculated the state for an event we were # persisting. state_delta_reuse_delta_counter = Counter( - "synapse_storage_events_state_delta_reuse_delta", "" + "synapse_storage_events_state_delta_reuse_delta", "", labelnames=[SERVER_NAME_LABEL] ) # The number of forward extremities for each new event. 
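These counters follow the per-homeserver labelling pattern applied throughout this diff: declare the metric with `SERVER_NAME_LABEL` in `labelnames` and route every `inc()`/`observe()` through `.labels(...)`. A minimal `prometheus_client` sketch of that pattern (illustrative metric names, not Synapse's real ones):

```python
from prometheus_client import Counter, Histogram

SERVER_NAME_LABEL = "server_name"  # assumed to match synapse.metrics.SERVER_NAME_LABEL

events_persisted = Counter(
    "example_events_persisted",
    "Events persisted, per homeserver",
    labelnames=[SERVER_NAME_LABEL],
)
forward_extremities = Histogram(
    "example_forward_extremities",
    "Forward extremities per persisted event",
    labelnames=[SERVER_NAME_LABEL],
    buckets=(1, 2, 3, 5, 7, 10, float("inf")),
)

server_name = "hs1.example.org"
# Every increment/observation now goes through .labels(...), so two
# homeservers sharing one Python process report separate time series.
events_persisted.labels(**{SERVER_NAME_LABEL: server_name}).inc(3)
forward_extremities.labels(**{SERVER_NAME_LABEL: server_name}).observe(2)
```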
forward_extremities_counter = Histogram( "synapse_storage_events_forward_extremities_persisted", "Number of forward extremities for each new event", + labelnames=[SERVER_NAME_LABEL], buckets=(1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"), ) @@ -109,22 +115,26 @@ forward_extremities_counter = Histogram( stale_forward_extremities_counter = Histogram( "synapse_storage_events_stale_forward_extremities_persisted", "Number of unchanged forward extremities for each new event", + labelnames=[SERVER_NAME_LABEL], buckets=(0, 1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"), ) state_resolutions_during_persistence = Counter( "synapse_storage_events_state_resolutions_during_persistence", "Number of times we had to do state res to calculate new current state", + labelnames=[SERVER_NAME_LABEL], ) potential_times_prune_extremities = Counter( "synapse_storage_events_potential_times_prune_extremities", "Number of times we might be able to prune extremities", + labelnames=[SERVER_NAME_LABEL], ) times_pruned_extremities = Counter( "synapse_storage_events_times_pruned_extremities", "Number of times we were actually be able to prune extremities", + labelnames=[SERVER_NAME_LABEL], ) @@ -185,6 +195,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]): def __init__( self, + server_name: str, per_item_callback: Callable[ [str, _EventPersistQueueTask], Awaitable[_PersistResult], @@ -195,6 +206,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]): The per_item_callback will be called for each item added via add_to_queue, and its result will be returned via the Deferreds returned from add_to_queue. """ + self.server_name = server_name self._event_persist_queues: Dict[str, Deque[_EventPersistQueueItem]] = {} self._currently_persisting_rooms: Set[str] = set() self._per_item_callback = per_item_callback @@ -299,7 +311,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]): self._currently_persisting_rooms.discard(room_id) # set handle_queue_loop off in the background - run_as_background_process("persist_events", handle_queue_loop) + run_as_background_process("persist_events", self.server_name, handle_queue_loop) def _get_drainining_queue( self, room_id: str @@ -342,7 +354,7 @@ class EventsPersistenceStorageController: self._instance_name = hs.get_instance_name() self.is_mine_id = hs.is_mine_id self._event_persist_queue = _EventPeristenceQueue( - self._process_event_persist_queue_task + self.server_name, self._process_event_persist_queue_task ) self._state_resolution_handler = hs.get_state_resolution_handler() self._state_controller = state_controller @@ -707,9 +719,11 @@ class EventsPersistenceStorageController: if all_single_prev_not_state: return (new_forward_extremities, None) - state_delta_counter.inc() + state_delta_counter.labels(**{SERVER_NAME_LABEL: self.server_name}).inc() if len(new_latest_event_ids) == 1: - state_delta_single_event_counter.inc() + state_delta_single_event_counter.labels( + **{SERVER_NAME_LABEL: self.server_name} + ).inc() # This is a fairly handwavey check to see if we could # have guessed what the delta would have been when @@ -724,7 +738,9 @@ class EventsPersistenceStorageController: for ev, _ in ev_ctx_rm: prev_event_ids = set(ev.prev_event_ids()) if latest_event_ids == prev_event_ids: - state_delta_reuse_delta_counter.inc() + state_delta_reuse_delta_counter.labels( + **{SERVER_NAME_LABEL: self.server_name} + ).inc() break logger.debug("Calculating state delta for room %s", room_id) @@ -833,9 +849,13 @@ class EventsPersistenceStorageController: # We only update 
metrics for events that change forward extremities # (e.g. we ignore backfill/outliers/etc) if result != latest_event_ids: - forward_extremities_counter.observe(len(result)) + forward_extremities_counter.labels( + **{SERVER_NAME_LABEL: self.server_name} + ).observe(len(result)) stale = latest_event_ids & result - stale_forward_extremities_counter.observe(len(stale)) + stale_forward_extremities_counter.labels( + **{SERVER_NAME_LABEL: self.server_name} + ).observe(len(stale)) return result @@ -994,7 +1014,9 @@ class EventsPersistenceStorageController: ), ) - state_resolutions_during_persistence.inc() + state_resolutions_during_persistence.labels( + **{SERVER_NAME_LABEL: self.server_name} + ).inc() # If the returned state matches the state group of one of the new # forward extremities then we check if we are able to prune some state @@ -1022,7 +1044,9 @@ class EventsPersistenceStorageController: """See if we can prune any of the extremities after calculating the resolved state. """ - potential_times_prune_extremities.inc() + potential_times_prune_extremities.labels( + **{SERVER_NAME_LABEL: self.server_name} + ).inc() # We keep all the extremities that have the same state group, and # see if we can drop the others. @@ -1120,7 +1144,7 @@ class EventsPersistenceStorageController: return new_latest_event_ids - times_pruned_extremities.inc() + times_pruned_extremities.labels(**{SERVER_NAME_LABEL: self.server_name}).inc() logger.info( "Pruning forward extremities in room %s: from %s -> %s", diff --git a/synapse/storage/controllers/purge_events.py b/synapse/storage/controllers/purge_events.py index df3f264b06..14b37ac543 100644 --- a/synapse/storage/controllers/purge_events.py +++ b/synapse/storage/controllers/purge_events.py @@ -46,6 +46,9 @@ class PurgeEventsStorageController: """High level interface for purging rooms and event history.""" def __init__(self, hs: "HomeServer", stores: Databases): + self.server_name = ( + hs.hostname + ) # nb must be called this for @wrap_as_background_process self.stores = stores if hs.config.worker.run_background_tasks: diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 6188195614..f7aec16c96 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -61,7 +61,7 @@ from synapse.logging.context import ( current_context, make_deferred_yieldable, ) -from synapse.metrics import LaterGauge, register_threadpool +from synapse.metrics import SERVER_NAME_LABEL, LaterGauge, register_threadpool from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.background_updates import BackgroundUpdater from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine @@ -82,11 +82,23 @@ sql_logger = logging.getLogger("synapse.storage.SQL") transaction_logger = logging.getLogger("synapse.storage.txn") perf_logger = logging.getLogger("synapse.storage.TIME") -sql_scheduling_timer = Histogram("synapse_storage_schedule_time", "sec") +sql_scheduling_timer = Histogram( + "synapse_storage_schedule_time", "sec", labelnames=[SERVER_NAME_LABEL] +) -sql_query_timer = Histogram("synapse_storage_query_time", "sec", ["verb"]) -sql_txn_count = Counter("synapse_storage_transaction_time_count", "sec", ["desc"]) -sql_txn_duration = Counter("synapse_storage_transaction_time_sum", "sec", ["desc"]) +sql_query_timer = Histogram( + "synapse_storage_query_time", "sec", labelnames=["verb", SERVER_NAME_LABEL] +) +sql_txn_count = Counter( + "synapse_storage_transaction_time_count", + "sec", + 
labelnames=["desc", SERVER_NAME_LABEL], +) +sql_txn_duration = Counter( + "synapse_storage_transaction_time_sum", + "sec", + labelnames=["desc", SERVER_NAME_LABEL], +) # Unique indexes which have been added in background updates. Maps from table name @@ -118,9 +130,11 @@ class _PoolConnection(Connection): def make_pool( + *, reactor: IReactorCore, db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine, + server_name: str, ) -> adbapi.ConnectionPool: """Get the connection pool for the database.""" @@ -134,7 +148,12 @@ def make_pool( # etc. with LoggingContext("db.on_new_connection"): engine.on_new_connection( - LoggingDatabaseConnection(conn, engine, "on_new_connection") + LoggingDatabaseConnection( + conn=conn, + engine=engine, + default_txn_name="on_new_connection", + server_name=server_name, + ) ) connection_pool = adbapi.ConnectionPool( @@ -144,15 +163,21 @@ def make_pool( **db_args, ) - register_threadpool(f"database-{db_config.name}", connection_pool.threadpool) + register_threadpool( + name=f"database-{db_config.name}", + server_name=server_name, + threadpool=connection_pool.threadpool, + ) return connection_pool def make_conn( + *, db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine, default_txn_name: str, + server_name: str, ) -> "LoggingDatabaseConnection": """Make a new connection to the database and return it. @@ -166,13 +191,18 @@ def make_conn( if not k.startswith("cp_") } native_db_conn = engine.module.connect(**db_params) - db_conn = LoggingDatabaseConnection(native_db_conn, engine, default_txn_name) + db_conn = LoggingDatabaseConnection( + conn=native_db_conn, + engine=engine, + default_txn_name=default_txn_name, + server_name=server_name, + ) engine.on_new_connection(db_conn) return db_conn -@attr.s(slots=True, auto_attribs=True) +@attr.s(slots=True, auto_attribs=True, kw_only=True) class LoggingDatabaseConnection: """A wrapper around a database connection that returns `LoggingTransaction` as its cursor class. 
@@ -183,6 +213,7 @@ class LoggingDatabaseConnection: conn: Connection engine: BaseDatabaseEngine default_txn_name: str + server_name: str def cursor( self, @@ -196,8 +227,9 @@ class LoggingDatabaseConnection: txn_name = self.default_txn_name return LoggingTransaction( - self.conn.cursor(), + txn=self.conn.cursor(), name=txn_name, + server_name=self.server_name, database_engine=self.engine, after_callbacks=after_callbacks, async_after_callbacks=async_after_callbacks, @@ -266,6 +298,7 @@ class LoggingTransaction: __slots__ = [ "txn", "name", + "server_name", "database_engine", "after_callbacks", "async_after_callbacks", @@ -274,8 +307,10 @@ class LoggingTransaction: def __init__( self, + *, txn: Cursor, name: str, + server_name: str, database_engine: BaseDatabaseEngine, after_callbacks: Optional[List[_CallbackListEntry]] = None, async_after_callbacks: Optional[List[_AsyncCallbackListEntry]] = None, @@ -283,6 +318,7 @@ class LoggingTransaction: ): self.txn = txn self.name = name + self.server_name = server_name self.database_engine = database_engine self.after_callbacks = after_callbacks self.async_after_callbacks = async_after_callbacks @@ -493,7 +529,9 @@ class LoggingTransaction: finally: secs = time.time() - start sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs) - sql_query_timer.labels(sql.split()[0]).observe(secs) + sql_query_timer.labels( + verb=sql.split()[0], **{SERVER_NAME_LABEL: self.server_name} + ).observe(secs) def close(self) -> None: self.txn.close() @@ -561,17 +599,23 @@ class DatabasePool: engine: BaseDatabaseEngine, ): self.hs = hs + self.server_name = hs.hostname self._clock = hs.get_clock() self._txn_limit = database_config.config.get("txn_limit", 0) self._database_config = database_config - self._db_pool = make_pool(hs.get_reactor(), database_config, engine) + self._db_pool = make_pool( + reactor=hs.get_reactor(), + db_config=database_config, + engine=engine, + server_name=self.server_name, + ) self.updates = BackgroundUpdater(hs, self) LaterGauge( - "synapse_background_update_status", - "Background update status", - [], - self.updates.get_status, + name="synapse_background_update_status", + desc="Background update status", + labelnames=[SERVER_NAME_LABEL], + caller=lambda: {(self.server_name,): self.updates.get_status()}, ) self._previous_txn_total_time = 0.0 @@ -602,6 +646,7 @@ class DatabasePool: 0.0, run_as_background_process, "upsert_safety_check", + self.server_name, self._check_safe_to_upsert, ) @@ -644,6 +689,7 @@ class DatabasePool: 15.0, run_as_background_process, "upsert_safety_check", + self.server_name, self._check_safe_to_upsert, ) @@ -866,8 +912,14 @@ class DatabasePool: self._current_txn_total_time += duration self._txn_perf_counters.update(desc, duration) - sql_txn_count.labels(desc).inc(1) - sql_txn_duration.labels(desc).inc(duration) + sql_txn_count.labels( + desc=desc, + **{SERVER_NAME_LABEL: self.server_name}, + ).inc(1) + sql_txn_duration.labels( + desc=desc, + **{SERVER_NAME_LABEL: self.server_name}, + ).inc(duration) async def runInteraction( self, @@ -1003,7 +1055,9 @@ class DatabasePool: operation_name="db.connection", ): sched_duration_sec = monotonic_time() - start_time - sql_scheduling_timer.observe(sched_duration_sec) + sql_scheduling_timer.labels( + **{SERVER_NAME_LABEL: self.server_name} + ).observe(sched_duration_sec) context.add_database_scheduled(sched_duration_sec) if self._txn_limit > 0: @@ -1036,7 +1090,10 @@ class DatabasePool: ) db_conn = LoggingDatabaseConnection( - conn, self.engine, "runWithConnection" + conn=conn, + 
engine=self.engine, + default_txn_name="runWithConnection", + server_name=self.server_name, ) return func(db_conn, *args, **kwargs) finally: diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py index 81886ff765..6442ab6c7a 100644 --- a/synapse/storage/databases/__init__.py +++ b/synapse/storage/databases/__init__.py @@ -69,11 +69,18 @@ class Databases(Generic[DataStoreT]): state_deletion: Optional[StateDeletionDataStore] = None persist_events: Optional[PersistEventsStore] = None + server_name = hs.hostname + for database_config in hs.config.database.databases: db_name = database_config.name engine = create_engine(database_config.config) - with make_conn(database_config, engine, "startup") as db_conn: + with make_conn( + db_config=database_config, + engine=engine, + default_txn_name="startup", + server_name=server_name, + ) as db_conn: logger.info("[database config %r]: Checking database server", db_name) engine.check_database(db_conn) diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index 883ab93f7c..c049789e44 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -78,6 +78,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) db=database, notifier=hs.get_replication_notifier(), stream_name="account_data", + server_name=self.server_name, instance_name=self._instance_name, tables=[ ("room_account_data", "instance_name", "stream_id"), diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index dc37f67110..7794926812 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -104,10 +104,11 @@ class CacheInvalidationWorkerStore(SQLBaseStore): # caches to invalidate. (This reduces the amount of writes to the DB # that happen). self._cache_id_gen = MultiWriterIdGenerator( - db_conn, - database, + db_conn=db_conn, + db=database, notifier=hs.get_replication_notifier(), stream_name="caches", + server_name=self.server_name, instance_name=hs.get_instance_name(), tables=[ ( diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index da10afbebe..c10e2d2611 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -109,6 +109,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): db=database, notifier=hs.get_replication_notifier(), stream_name="to_device", + server_name=self.server_name, instance_name=self._instance_name, tables=[ ("device_inbox", "instance_name", "stream_id"), @@ -156,6 +157,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): run_as_background_process, DEVICE_FEDERATION_INBOX_CLEANUP_INTERVAL_MS, "_delete_old_federation_inbox_rows", + self.server_name, self._delete_old_federation_inbox_rows, ) @@ -1029,7 +1031,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): # We sleep a bit so that we don't hammer the database in a tight # loop first time we run this. 
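The back-off described in this comment only works if the sleep is actually awaited, which is what the very next change fixes: calling an async sleep without `await` merely creates and discards a coroutine, so nothing pauses. A plain-asyncio illustration (not Synapse's Clock):

```python
import asyncio
import time


async def broken_backoff() -> float:
    start = time.monotonic()
    asyncio.sleep(1)  # bug: coroutine created but never awaited, no delay
    return time.monotonic() - start


async def fixed_backoff() -> float:
    start = time.monotonic()
    await asyncio.sleep(1)  # actually suspends for ~1 second
    return time.monotonic() - start


print(asyncio.run(broken_backoff()))  # ~0.0 (plus a "never awaited" warning)
print(asyncio.run(fixed_backoff()))   # ~1.0
```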
- self._clock.sleep(1) + await self._clock.sleep(1) async def get_devices_with_messages( self, user_id: str, device_ids: StrCollection diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 6ed9f85800..a28cc40a95 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -103,6 +103,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): db=database, notifier=hs.get_replication_notifier(), stream_name="device_lists_stream", + server_name=self.server_name, instance_name=self._instance_name, tables=[ ("device_lists_stream", "instance_name", "stream_id"), diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index a4a8aafa0c..de72e66ceb 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -125,6 +125,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker db=database, notifier=hs.get_replication_notifier(), stream_name="e2e_cross_signing_keys", + server_name=self.server_name, instance_name=self._instance_name, tables=[ ("e2e_cross_signing_keys", "instance_name", "stream_id"), diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 28d202ef0a..26a91109df 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -45,6 +45,7 @@ from synapse.api.errors import StoreError from synapse.api.room_versions import EventFormatVersions, RoomVersion from synapse.events import EventBase, make_event_from_dict from synapse.logging.opentracing import tag_args, trace +from synapse.metrics import SERVER_NAME_LABEL from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import db_to_json, make_in_list_sql_clause from synapse.storage.background_updates import ForeignKeyConstraint @@ -70,17 +71,20 @@ if TYPE_CHECKING: oldest_pdu_in_federation_staging = Gauge( "synapse_federation_server_oldest_inbound_pdu_in_staging", "The age in seconds since we received the oldest pdu in the federation staging area", + labelnames=[SERVER_NAME_LABEL], ) number_pdus_in_federation_queue = Gauge( "synapse_federation_server_number_inbound_pdu_in_staging", "The total number of events in the inbound federation staging", + labelnames=[SERVER_NAME_LABEL], ) pdus_pruned_from_federation_queue = Counter( "synapse_federation_server_number_inbound_pdu_pruned", "The number of events in the inbound federation staging that have been " "pruned due to the queue getting too long", + labelnames=[SERVER_NAME_LABEL], ) logger = logging.getLogger(__name__) @@ -2248,7 +2252,9 @@ class EventFederationWorkerStore( if not to_delete: return False - pdus_pruned_from_federation_queue.inc(len(to_delete)) + pdus_pruned_from_federation_queue.labels( + **{SERVER_NAME_LABEL: self.server_name} + ).inc(len(to_delete)) logger.info( "Pruning %d events in room %s from federation queue", len(to_delete), @@ -2301,8 +2307,12 @@ class EventFederationWorkerStore( "_get_stats_for_federation_staging", _get_stats_for_federation_staging_txn ) - number_pdus_in_federation_queue.set(count) - oldest_pdu_in_federation_staging.set(age) + number_pdus_in_federation_queue.labels( + **{SERVER_NAME_LABEL: self.server_name} + ).set(count) + oldest_pdu_in_federation_staging.labels( + **{SERVER_NAME_LABEL: self.server_name} + ).set(age) 
async def clean_room_for_join(self, room_id: str) -> None: await self.db_pool.runInteraction( diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 741146417f..2478367f0d 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -51,10 +51,16 @@ from synapse.api.constants import ( ) from synapse.api.errors import PartialStateConflictError from synapse.api.room_versions import RoomVersions -from synapse.events import EventBase, StrippedStateEvent, relation_from_event +from synapse.events import ( + EventBase, + StrippedStateEvent, + is_creator, + relation_from_event, +) from synapse.events.snapshot import EventContext from synapse.events.utils import parse_stripped_state_event from synapse.logging.opentracing import trace +from synapse.metrics import SERVER_NAME_LABEL from synapse.storage._base import db_to_json, make_in_list_sql_clause from synapse.storage.database import ( DatabasePool, @@ -89,11 +95,13 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -persist_event_counter = Counter("synapse_storage_events_persisted_events", "") +persist_event_counter = Counter( + "synapse_storage_events_persisted_events", "", labelnames=[SERVER_NAME_LABEL] +) event_counter = Counter( "synapse_storage_events_persisted_events_sep", "", - ["type", "origin_type", "origin_entity"], + labelnames=["type", "origin_type", "origin_entity", SERVER_NAME_LABEL], ) # State event type/key pairs that we need to gather to fill in the @@ -237,6 +245,7 @@ class PersistEventsStore: db_conn: LoggingDatabaseConnection, ): self.hs = hs + self.server_name = hs.hostname self.db_pool = db self.store = main_data_store self.database_engine = db.engine @@ -357,12 +366,16 @@ class PersistEventsStore: new_event_links=new_event_links, sliding_sync_table_changes=sliding_sync_table_changes, ) - persist_event_counter.inc(len(events_and_contexts)) + persist_event_counter.labels(**{SERVER_NAME_LABEL: self.server_name}).inc( + len(events_and_contexts) + ) if not use_negative_stream_ordering: # we don't want to set the event_persisted_position to a negative # stream_ordering. 
- synapse.metrics.event_persisted_position.set(stream) + synapse.metrics.event_persisted_position.labels( + **{SERVER_NAME_LABEL: self.server_name} + ).set(stream) for event, context in events_and_contexts: if context.app_service: @@ -375,13 +388,147 @@ class PersistEventsStore: origin_type = "remote" origin_entity = get_domain_from_id(event.sender) - event_counter.labels(event.type, origin_type, origin_entity).inc() + event_counter.labels( + type=event.type, + origin_type=origin_type, + origin_entity=origin_entity, + **{SERVER_NAME_LABEL: self.server_name}, + ).inc() + + if ( + not self.hs.config.experimental.msc4293_enabled + or event.type != EventTypes.Member + or event.state_key is None + ): + continue + + # check if this is an unban/join that will undo a ban/kick redaction for + # a user in the room + if event.membership in [Membership.LEAVE, Membership.JOIN]: + if ( + event.membership == Membership.LEAVE + and event.sender == event.state_key + ): + # self-leave, ignore + continue + + # if there is an existing ban/leave causing redactions for + # this user/room combination update the entry with the stream + # ordering when the redactions should stop - in the case of a backfilled + # event where the stream ordering is negative, use the current max stream + # ordering + stream_ordering = event.internal_metadata.stream_ordering + assert stream_ordering is not None + if stream_ordering < 0: + stream_ordering = self._stream_id_gen.get_current_token() + await self.db_pool.simple_update( + "room_ban_redactions", + {"room_id": event.room_id, "user_id": event.state_key}, + {"redact_end_ordering": stream_ordering}, + desc="room_ban_redactions update redact_end_ordering", + ) + + # check for msc4293 redact_events flag and apply if found + if event.membership not in [Membership.LEAVE, Membership.BAN]: + continue + redact = event.content.get("org.matrix.msc4293.redact_events", False) + if not redact or not isinstance(redact, bool): + continue + # self-bans currently are not authorized so we don't check for that + # case + if ( + event.membership == Membership.BAN + and event.sender == event.state_key + ): + continue + + # check that sender can redact + redact_allowed = await self._can_sender_redact(event) + + # Signal that this user's past events in this room + # should be redacted by adding an entry to + # `room_ban_redactions`. 
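Read together with the event-fetch path later in this diff (which consults `redact_end_ordering` before applying a stored redaction), the rule being implemented can be restated as a small pure function. This is a hedged paraphrase of that rule, not code from the patch:

```python
from typing import Optional


def redaction_window_applies(
    event_stream_ordering: int,
    redact_end_ordering: Optional[int],
) -> bool:
    """Should a room_ban_redactions entry redact this earlier event?"""
    if redact_end_ordering is None:
        # No unban/re-join has closed the window yet: keep redacting.
        return True
    if event_stream_ordering < 0:
        # Backfilled events carry negative stream orderings, so they always
        # fall before the closing membership event and stay redacted.
        return True
    # Events at or after the closing membership event are left alone.
    return event_stream_ordering < redact_end_ordering


assert redaction_window_applies(500, None) is True
assert redaction_window_applies(-7, 1000) is True
assert redaction_window_applies(999, 1000) is True
assert redaction_window_applies(1000, 1000) is False
```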
+ if redact_allowed: + await self.db_pool.simple_upsert( + "room_ban_redactions", + {"room_id": event.room_id, "user_id": event.state_key}, + { + "redacting_event_id": event.event_id, + "redact_end_ordering": None, + }, + { + "room_id": event.room_id, + "user_id": event.state_key, + "redacting_event_id": event.event_id, + "redact_end_ordering": None, + }, + ) + + # normally the cache entry for a redacted event would be invalidated + # by an arriving redaction event, but since we are not creating redaction + # events we invalidate manually + self.store._invalidate_local_get_event_cache_room_id(event.room_id) + + self.store._invalidate_async_get_event_cache_room_id(event.room_id) if new_forward_extremities: self.store.get_latest_event_ids_in_room.prefill( (room_id,), frozenset(new_forward_extremities) ) + async def _can_sender_redact(self, event: EventBase) -> bool: + state_filter = StateFilter.from_types( + [(EventTypes.PowerLevels, ""), (EventTypes.Create, "")] + ) + state = await self.store.get_partial_filtered_current_state_ids( + event.room_id, state_filter + ) + pl_id = state[(EventTypes.PowerLevels, "")] + pl_event = await self.store.get_event(pl_id, allow_none=True) + + create_id = state[(EventTypes.Create, "")] + create_event = await self.store.get_event(create_id, allow_none=True) + + if create_event is None: + # not sure how this would happen but if it does then just deny the redaction + logger.warning("No create event found for room %s", event.room_id) + return False + + if create_event.room_version.msc4289_creator_power_enabled: + # per the spec, grant the creator infinite power level and all other users 0 + if is_creator(create_event, event.sender): + return True + if pl_event is None: + # per the spec, users other than the room creator have power level + # 0, which is less than the default to redact events (50). 
+ return False + else: + # per the spec, if a power level event isn't in the room, grant the creator + # level 100 (the default redaction level is 50) and all other users 0 + if pl_event is None: + return create_event.sender == event.sender + + assert pl_event is not None + sender_level = pl_event.content.get("users", {}).get(event.sender) + if sender_level is None: + sender_level = pl_event.content.get("users_default", 0) + + redact_level = pl_event.content.get("redact") + if redact_level is None: + redact_level = pl_event.content.get("events_default", 0) + + room_redaction_level = pl_event.content.get("events", {}).get( + "m.room.redaction" + ) + if room_redaction_level is not None: + if sender_level < room_redaction_level: + return False + + if sender_level >= redact_level: + return True + + return False + async def _calculate_sliding_sync_table_changes( self, room_id: str, @@ -2720,7 +2867,7 @@ class PersistEventsStore: txn: LoggingTransaction, events_and_contexts: List[Tuple[EventBase, EventContext]], ) -> None: - to_prefill = [] + to_prefill: List[EventCacheEntry] = [] ev_map = {e.event_id: e for e, _ in events_and_contexts} if not ev_map: diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index d9ef93f826..7f015aa22c 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -17,7 +17,7 @@ # [This file includes modifications made by New Vector Limited] # # - +import json import logging import threading import weakref @@ -68,6 +68,7 @@ from synapse.logging.opentracing import ( tag_args, trace, ) +from synapse.metrics import SERVER_NAME_LABEL from synapse.metrics.background_process_metrics import ( run_as_background_process, wrap_as_background_process, @@ -138,6 +139,7 @@ EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events event_fetch_ongoing_gauge = Gauge( "synapse_event_fetch_ongoing", "The number of event fetchers that are running", + labelnames=[SERVER_NAME_LABEL], ) @@ -235,6 +237,7 @@ class EventsWorkerStore(SQLBaseStore): db=database, notifier=hs.get_replication_notifier(), stream_name="events", + server_name=self.server_name, instance_name=hs.get_instance_name(), tables=[ ("events", "instance_name", "stream_ordering"), @@ -249,6 +252,7 @@ class EventsWorkerStore(SQLBaseStore): db=database, notifier=hs.get_replication_notifier(), stream_name="backfill", + server_name=self.server_name, instance_name=hs.get_instance_name(), tables=[ ("events", "instance_name", "stream_ordering"), @@ -310,7 +314,9 @@ class EventsWorkerStore(SQLBaseStore): Tuple[Iterable[str], "defer.Deferred[Dict[str, _EventRow]]"] ] = [] self._event_fetch_ongoing = 0 - event_fetch_ongoing_gauge.set(self._event_fetch_ongoing) + event_fetch_ongoing_gauge.labels(**{SERVER_NAME_LABEL: self.server_name}).set( + self._event_fetch_ongoing + ) # We define this sequence here so that it can be referenced from both # the DataStore and PersistEventStore. 
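Stepping back to `_can_sender_redact` above, the decision reduces to the following self-contained sketch (simplified inputs, hypothetical helper): in creator-power room versions the creator is all-powerful and other users default to power level 0 when no power-level event exists, classic rooms fall back to the creator-gets-100 default, and otherwise the usual `redact`/`events_default` and `events["m.room.redaction"]` thresholds apply.

```python
from typing import Any, Dict, Optional


def can_sender_redact(
    sender: str,
    creator: str,
    creator_power_enabled: bool,
    pl_content: Optional[Dict[str, Any]],
) -> bool:
    if creator_power_enabled:
        if sender == creator:
            return True  # the creator's power level is effectively infinite
        if pl_content is None:
            return False  # everyone else defaults to power level 0
    elif pl_content is None:
        # Classic rooms without a power-level event: creator defaults to 100,
        # everyone else to 0, and the default redaction level is 50.
        return sender == creator

    assert pl_content is not None
    sender_level = pl_content.get("users", {}).get(sender)
    if sender_level is None:
        sender_level = pl_content.get("users_default", 0)

    redact_level = pl_content.get("redact")
    if redact_level is None:
        redact_level = pl_content.get("events_default", 0)

    room_redaction_level = pl_content.get("events", {}).get("m.room.redaction")
    if room_redaction_level is not None and sender_level < room_redaction_level:
        return False
    return sender_level >= redact_level


assert can_sender_redact("@mod:hs", "@creator:hs", True, {"users": {"@mod:hs": 50}, "redact": 50})
assert not can_sender_redact("@user:hs", "@creator:hs", True, None)
```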
@@ -334,6 +340,7 @@ class EventsWorkerStore(SQLBaseStore): db=database, notifier=hs.get_replication_notifier(), stream_name="un_partial_stated_event_stream", + server_name=self.server_name, instance_name=hs.get_instance_name(), tables=[("un_partial_stated_event_stream", "instance_name", "stream_id")], sequence_name="un_partial_stated_event_stream_sequence", @@ -364,6 +371,12 @@ replaces_index="event_txn_id_device_id_txn_id", ) + self._has_finished_sliding_sync_background_jobs = False + """ + Flag to track when the sliding sync background jobs have + finished (so we don't have to keep querying it every time) + """ + def get_un_partial_stated_events_token(self, instance_name: str) -> int: return ( self._un_partial_stated_events_stream_id_gen.get_current_token_for_writer( @@ -976,6 +989,13 @@ class EventsWorkerStore(SQLBaseStore): self._event_ref.clear() self._current_event_fetches.clear() + def _invalidate_async_get_event_cache_room_id(self, room_id: str) -> None: + """ + Clears the async get_event cache for a room. Currently a no-op until + an async get_event cache is implemented - see https://github.com/matrix-org/synapse/pull/13242 + for preliminary work. + """ + async def _get_events_from_cache( self, events: Iterable[str], update_metrics: bool = True ) -> Dict[str, EventCacheEntry]: @@ -1124,14 +1144,18 @@ class EventsWorkerStore(SQLBaseStore): and self._event_fetch_ongoing < EVENT_QUEUE_THREADS ): self._event_fetch_ongoing += 1 - event_fetch_ongoing_gauge.set(self._event_fetch_ongoing) + event_fetch_ongoing_gauge.labels( + **{SERVER_NAME_LABEL: self.server_name} + ).set(self._event_fetch_ongoing) # `_event_fetch_ongoing` is decremented in `_fetch_thread`. should_start = True else: should_start = False if should_start: - run_as_background_process("fetch_events", self._fetch_thread) + run_as_background_process( + "fetch_events", self.server_name, self._fetch_thread + ) async def _fetch_thread(self) -> None: """Services requests for events from `_event_fetch_list`.""" @@ -1146,7 +1170,9 @@ class EventsWorkerStore(SQLBaseStore): event_fetches_to_fail = [] with self._event_fetch_lock: self._event_fetch_ongoing -= 1 - event_fetch_ongoing_gauge.set(self._event_fetch_ongoing) + event_fetch_ongoing_gauge.labels( + **{SERVER_NAME_LABEL: self.server_name} + ).set(self._event_fetch_ongoing) # There may still be work remaining in `_event_fetch_list` if we # failed, or it was added in between us deciding to exit and @@ -1575,6 +1601,44 @@ class EventsWorkerStore(SQLBaseStore): if d: d.redactions.append(redacter) + # check for MSC4293 redactions + to_check = [] + events: List[_EventRow] = [] + for e in evs: + event = event_dict.get(e) + if not event: + continue + events.append(event) + event_json = json.loads(event.json) + room_id = event_json.get("room_id") + user_id = event_json.get("sender") + to_check.append((room_id, user_id)) + + # likely that some of these events may be for the same room/user combo, in + # which case we don't need to do redundant queries + to_check_set = set(to_check) + for room_and_user in to_check_set: + room_redactions_sql = "SELECT redacting_event_id, redact_end_ordering FROM room_ban_redactions WHERE room_id = ? and user_id = ?"
+ txn.execute(room_redactions_sql, room_and_user) + + res = txn.fetchone() + # we have a redaction for a room, user_id combo - apply it to matching events + if not res: + continue + for e_row in events: + e_json = json.loads(e_row.json) + room_id = e_json.get("room_id") + user_id = e_json.get("sender") + if room_and_user != (room_id, user_id): + continue + redacting_event_id, redact_end_ordering = res + if redact_end_ordering: + # Avoid redacting any events arriving *after* the membership event which + # ends an active redaction - note that this will always redact + # backfilled events, as they have a negative stream ordering + if e_row.stream_ordering >= redact_end_ordering: + continue + e_row.redactions.append(redacting_event_id) return event_dict def _maybe_redact_event_row( @@ -2608,13 +2672,19 @@ class EventsWorkerStore(SQLBaseStore): async def have_finished_sliding_sync_background_jobs(self) -> bool: """Return if it's safe to use the sliding sync membership tables.""" - return await self.db_pool.updates.have_completed_background_updates( + if self._has_finished_sliding_sync_background_jobs: + # as an optimisation, once the job finishes, don't issue another + # database transaction to check it, since it won't 'un-finish' + return True + + self._has_finished_sliding_sync_background_jobs = await self.db_pool.updates.have_completed_background_updates( ( _BackgroundUpdates.SLIDING_SYNC_PREFILL_JOINED_ROOMS_TO_RECALCULATE_TABLE_BG_UPDATE, _BackgroundUpdates.SLIDING_SYNC_JOINED_ROOMS_BG_UPDATE, _BackgroundUpdates.SLIDING_SYNC_MEMBERSHIP_SNAPSHOTS_BG_UPDATE, ) ) + return self._has_finished_sliding_sync_background_jobs async def get_sent_invite_count_by_user(self, user_id: str, from_ts: int) -> int: """ diff --git a/synapse/storage/databases/main/lock.py b/synapse/storage/databases/main/lock.py index 8277ad8c33..e733f65cb1 100644 --- a/synapse/storage/databases/main/lock.py +++ b/synapse/storage/databases/main/lock.py @@ -24,9 +24,13 @@ from types import TracebackType from typing import TYPE_CHECKING, Collection, Optional, Set, Tuple, Type from weakref import WeakValueDictionary +from twisted.internet import defer from twisted.internet.task import LoopingCall -from synapse.metrics.background_process_metrics import wrap_as_background_process +from synapse.metrics.background_process_metrics import ( + run_as_background_process, + wrap_as_background_process, +) from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( DatabasePool, @@ -196,6 +200,7 @@ class LockStore(SQLBaseStore): return None lock = Lock( + self.server_name, self._reactor, self._clock, self, @@ -263,6 +268,7 @@ class LockStore(SQLBaseStore): ) lock = Lock( + self.server_name, self._reactor, self._clock, self, @@ -366,6 +372,7 @@ class Lock: def __init__( self, + server_name: str, reactor: ISynapseReactor, clock: Clock, store: LockStore, @@ -374,6 +381,11 @@ class Lock: lock_key: str, token: str, ) -> None: + """ + Args: + server_name: The homeserver name (used to label metrics) (this should be `hs.hostname`). 
+ """ + self._server_name = server_name self._reactor = reactor self._clock = clock self._store = store @@ -396,6 +408,7 @@ class Lock: self._looping_call = self._clock.looping_call( self._renew, _RENEWAL_INTERVAL_MS, + self._server_name, self._store, self._clock, self._read_write, @@ -405,31 +418,55 @@ class Lock: ) @staticmethod - @wrap_as_background_process("Lock._renew") - async def _renew( + def _renew( + server_name: str, store: LockStore, clock: Clock, read_write: bool, lock_name: str, lock_key: str, token: str, - ) -> None: + ) -> "defer.Deferred[None]": """Renew the lock. Note: this is a static method, rather than using self.*, so that we don't end up with a reference to `self` in the reactor, which would stop this from being cleaned up if we dropped the context manager. + + Args: + server_name: The homeserver name (used to label metrics) (this should be `hs.hostname`). """ - table = "worker_read_write_locks" if read_write else "worker_locks" - await store.db_pool.simple_update( - table=table, - keyvalues={ - "lock_name": lock_name, - "lock_key": lock_key, - "token": token, - }, - updatevalues={"last_renewed_ts": clock.time_msec()}, - desc="renew_lock", + + async def _internal_renew( + store: LockStore, + clock: Clock, + read_write: bool, + lock_name: str, + lock_key: str, + token: str, + ) -> None: + table = "worker_read_write_locks" if read_write else "worker_locks" + await store.db_pool.simple_update( + table=table, + keyvalues={ + "lock_name": lock_name, + "lock_key": lock_key, + "token": token, + }, + updatevalues={"last_renewed_ts": clock.time_msec()}, + desc="renew_lock", + ) + + return run_as_background_process( + "Lock._renew", + server_name, + _internal_renew, + store, + clock, + read_write, + lock_name, + lock_key, + token, ) async def is_still_valid(self) -> bool: diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py index 9ce1100b5c..a3467bff3d 100644 --- a/synapse/storage/databases/main/metrics.py +++ b/synapse/storage/databases/main/metrics.py @@ -23,7 +23,7 @@ import logging import time from typing import TYPE_CHECKING, Dict, List, Tuple, cast -from synapse.metrics import GaugeBucketCollector +from synapse.metrics import SERVER_NAME_LABEL, GaugeBucketCollector from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( @@ -42,9 +42,10 @@ logger = logging.getLogger(__name__) # Collect metrics on the number of forward extremities that exist. _extremities_collecter = GaugeBucketCollector( - "synapse_forward_extremities", - "Number of rooms on the server with the given number of forward extremities" + name="synapse_forward_extremities", + documentation="Number of rooms on the server with the given number of forward extremities" " or fewer", + labelnames=[SERVER_NAME_LABEL], buckets=[1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500], ) @@ -54,9 +55,10 @@ _extremities_collecter = GaugeBucketCollector( # we could remove from state resolution by reducing the graph to a single # forward extremity. 
_excess_state_events_collecter = GaugeBucketCollector( - "synapse_excess_extremity_events", - "Number of rooms on the server with the given number of excess extremity " + name="synapse_excess_extremity_events", + documentation="Number of rooms on the server with the given number of excess extremity " "events, or fewer", + labelnames=[SERVER_NAME_LABEL], buckets=[0] + [1 << n for n in range(12)], ) @@ -100,10 +102,12 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): res = await self.db_pool.runInteraction("read_forward_extremities", fetch) - _extremities_collecter.update_data(x[0] for x in res) + _extremities_collecter.update_data( + values=(x[0] for x in res), labels=(self.server_name,) + ) _excess_state_events_collecter.update_data( - (x[0] - 1) * x[1] for x in res if x[1] + values=((x[0] - 1) * x[1] for x in res if x[1]), labels=(self.server_name,) ) async def count_daily_e2ee_messages(self) -> int: diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py index 12cff1d352..587f51df2c 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -91,6 +91,7 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore) db=database, notifier=hs.get_replication_notifier(), stream_name="presence_stream", + server_name=self.server_name, instance_name=self._instance_name, tables=[("presence_stream", "instance_name", "stream_id")], sequence_name="presence_stream_sequence", diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 3bc977d497..d686140556 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -110,6 +110,7 @@ def _load_rules( msc3381_polls_enabled=experimental_config.msc3381_polls_enabled, msc4028_push_encrypted_events=experimental_config.msc4028_push_encrypted_events, msc4210_enabled=experimental_config.msc4210_enabled, + msc4306_enabled=experimental_config.msc4306_enabled, ) return filtered_rules @@ -146,6 +147,7 @@ class PushRulesWorkerStore( db=database, notifier=hs.get_replication_notifier(), stream_name="push_rules_stream", + server_name=self.server_name, instance_name=self._instance_name, tables=[ ("push_rules_stream", "instance_name", "stream_id"), diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index a8a37b6c85..9a0a12b5c1 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -88,6 +88,7 @@ class PusherWorkerStore(SQLBaseStore): db=database, notifier=hs.get_replication_notifier(), stream_name="pushers", + server_name=self.server_name, instance_name=self._instance_name, tables=[ ("pushers", "instance_name", "id"), diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 16af68108d..d74bb0184a 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -124,6 +124,7 @@ class ReceiptsWorkerStore(SQLBaseStore): db_conn: LoggingDatabaseConnection, hs: "HomeServer", ): + super().__init__(database, db_conn, hs) self._instance_name = hs.get_instance_name() # In the worker store this is an ID tracker which we overwrite in the non-worker @@ -138,6 +139,7 @@ class ReceiptsWorkerStore(SQLBaseStore): db_conn=db_conn, db=database, notifier=hs.get_replication_notifier(), + server_name=self.server_name, stream_name="receipts", 
instance_name=self._instance_name, tables=[("receipts_linearized", "instance_name", "stream_id")], @@ -145,8 +147,6 @@ class ReceiptsWorkerStore(SQLBaseStore): writers=hs.config.worker.writers.receipts, ) - super().__init__(database, db_conn, hs) - max_receipts_stream_id = self.get_max_receipt_stream_id() receipts_stream_prefill, min_receipts_stream_id = self.db_pool.get_cache_dict( db_conn, diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 604365badf..6ffc3aed34 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -160,6 +160,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): db=database, notifier=hs.get_replication_notifier(), stream_name="un_partial_stated_room_stream", + server_name=self.server_name, instance_name=self._instance_name, tables=[("un_partial_stated_room_stream", "instance_name", "stream_id")], sequence_name="un_partial_stated_room_stream_sequence", diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index d7699ce6cd..67e7e99baa 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -43,7 +43,7 @@ from synapse.api.constants import EventTypes, Membership from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.logging.opentracing import trace -from synapse.metrics import LaterGauge +from synapse.metrics import SERVER_NAME_LABEL, LaterGauge from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import ( @@ -117,10 +117,10 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): self._count_known_servers, ) LaterGauge( - "synapse_federation_known_servers", - "", - [], - lambda: self._known_servers_count, + name="synapse_federation_known_servers", + desc="", + labelnames=[SERVER_NAME_LABEL], + caller=lambda: {(self.server_name,): self._known_servers_count}, ) @wrap_as_background_process("_count_known_servers") diff --git a/synapse/storage/databases/main/thread_subscriptions.py b/synapse/storage/databases/main/thread_subscriptions.py index e04e692e6a..24a99cf449 100644 --- a/synapse/storage/databases/main/thread_subscriptions.py +++ b/synapse/storage/databases/main/thread_subscriptions.py @@ -14,7 +14,7 @@ import logging from typing import ( TYPE_CHECKING, Any, - Dict, + FrozenSet, Iterable, List, Optional, @@ -33,6 +33,7 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.util.id_generators import MultiWriterIdGenerator +from synapse.types import EventOrderings from synapse.util.caches.descriptors import cached if TYPE_CHECKING: @@ -50,6 +51,14 @@ class ThreadSubscription: """ +class AutomaticSubscriptionConflicted: + """ + Marker return value to signal that an automatic subscription was skipped, + because it conflicted with an unsubscription that we consider to have + been made later than the event causing the automatic subscription. 
+ """ + + class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore): def __init__( self, @@ -69,6 +78,7 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore): db=database, notifier=hs.get_replication_notifier(), stream_name="thread_subscriptions", + server_name=self.server_name, instance_name=self._instance_name, tables=[ ("thread_subscriptions", "instance_name", "stream_id"), @@ -90,6 +100,7 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore): self.get_subscription_for_thread.invalidate( (row.user_id, row.room_id, row.event_id) ) + self.get_subscribers_to_thread.invalidate((row.room_id, row.event_id)) super().process_replication_rows(stream_name, instance_name, token, rows) @@ -100,75 +111,196 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore): self._thread_subscriptions_id_gen.advance(instance_name, token) super().process_replication_position(stream_name, instance_name, token) + @staticmethod + def _should_skip_autosubscription_after_unsubscription( + *, + autosub: EventOrderings, + unsubscribed_at: EventOrderings, + ) -> bool: + """ + Returns whether an automatic subscription occurring *after* an unsubscription + should be skipped, because the unsubscription already 'acknowledges' the event + causing the automatic subscription (the cause event). + + To determine *after*, we use `stream_ordering` unless the event is backfilled + (negative `stream_ordering`) and fallback to topological ordering. + + Args: + autosub: the stream_ordering and topological_ordering of the cause event + unsubscribed_at: + the maximum stream ordering and the maximum topological ordering at the time of unsubscription + + Returns: + True if the automatic subscription should be skipped + """ + # For normal rooms, these two orderings should be positive, because + # they don't refer to a specific event but rather the maximum at the + # time of unsubscription. + # + # However, for rooms that have never been joined and that are being peeked at, + # we might not have a single non-backfilled event and therefore the stream + # ordering might be negative, so we don't assert this case. + assert unsubscribed_at.topological > 0 + + unsubscribed_at_backfilled = unsubscribed_at.stream < 0 + if ( + not unsubscribed_at_backfilled + and unsubscribed_at.stream >= autosub.stream > 0 + ): + # non-backfilled events: the unsubscription is later according to + # the stream + return True + + if autosub.stream < 0: + # the auto-subscription cause event was backfilled, so fall back to + # topological ordering + if unsubscribed_at.topological >= autosub.topological: + return True + + return False + async def subscribe_user_to_thread( - self, user_id: str, room_id: str, thread_root_event_id: str, *, automatic: bool - ) -> Optional[int]: + self, + user_id: str, + room_id: str, + thread_root_event_id: str, + *, + automatic_event_orderings: Optional[EventOrderings], + ) -> Optional[Union[int, AutomaticSubscriptionConflicted]]: """Updates a user's subscription settings for a specific thread root. If no change would be made to the subscription, does not produce any database change. 
+ Case-by-case: + - if we already have an automatic subscription: + - new automatic subscriptions will be no-ops (no database write), + - new manual subscriptions will overwrite the automatic subscription + - if we already have a manual subscription: + we don't update (no database write) in either case, because: + - the existing manual subscription wins over a new automatic subscription request + - there would be no need to write a manual subscription because we already have one + Args: user_id: The ID of the user whose settings are being updated. room_id: The ID of the room the thread root belongs to. thread_root_event_id: The event ID of the thread root. - automatic: Whether the subscription was performed automatically by the user's client. - Only `False` will overwrite an existing value of automatic for a subscription row. + automatic_event_orderings: + Value depends on whether the subscription was performed automatically by the user's client. + For manual subscriptions: None. + For automatic subscriptions: the orderings of the event. Returns: - The stream ID for this update, if the update isn't no-opped. + If a subscription is made: (int) the stream ID for this update. + If a subscription already exists and did not need to be updated: None + If an automatic subscription conflicted with an unsubscription: AutomaticSubscriptionConflicted """ assert self._can_write_to_thread_subscriptions - def _subscribe_user_to_thread_txn(txn: LoggingTransaction) -> Optional[int]: - already_automatic = self.db_pool.simple_select_one_onecol_txn( - txn, - table="thread_subscriptions", - keyvalues={ - "user_id": user_id, - "event_id": thread_root_event_id, - "room_id": room_id, - "subscribed": True, - }, - retcol="automatic", - allow_none=True, - ) - - if already_automatic is None: - already_subscribed = False - already_automatic = True - else: - already_subscribed = True - # convert int (SQLite bool) to Python bool - already_automatic = bool(already_automatic) - - if already_subscribed and already_automatic == automatic: - # there is nothing we need to do here - return None - - stream_id = self._thread_subscriptions_id_gen.get_next_txn(txn) - - values: Dict[str, Optional[Union[bool, int, str]]] = { - "subscribed": True, - "stream_id": stream_id, - "instance_name": self._instance_name, - "automatic": already_automatic and automatic, - } - - self.db_pool.simple_upsert_txn( - txn, - table="thread_subscriptions", - keyvalues={ - "user_id": user_id, - "event_id": thread_root_event_id, - "room_id": room_id, - }, - values=values, - ) - + def _invalidate_subscription_caches(txn: LoggingTransaction) -> None: txn.call_after( self.get_subscription_for_thread.invalidate, (user_id, room_id, thread_root_event_id), ) + txn.call_after( + self.get_subscribers_to_thread.invalidate, + (room_id, thread_root_event_id), + ) + + def _subscribe_user_to_thread_txn( + txn: LoggingTransaction, + ) -> Optional[Union[int, AutomaticSubscriptionConflicted]]: + requested_automatic = automatic_event_orderings is not None + + row = self.db_pool.simple_select_one_txn( + txn, + table="thread_subscriptions", + keyvalues={ + "user_id": user_id, + "event_id": thread_root_event_id, + "room_id": room_id, + }, + retcols=( + "subscribed", + "automatic", + "unsubscribed_at_stream_ordering", + "unsubscribed_at_topological_ordering", + ), + allow_none=True, + ) + + if row is None: + # We have never subscribed before, simply insert the row and finish + stream_id = self._thread_subscriptions_id_gen.get_next_txn(txn) + self.db_pool.simple_insert_txn( + 
txn, + table="thread_subscriptions", + values={ + "user_id": user_id, + "event_id": thread_root_event_id, + "room_id": room_id, + "subscribed": True, + "stream_id": stream_id, + "instance_name": self._instance_name, + "automatic": requested_automatic, + "unsubscribed_at_stream_ordering": None, + "unsubscribed_at_topological_ordering": None, + }, + ) + _invalidate_subscription_caches(txn) + return stream_id + + # we already have either a subscription or a prior unsubscription here + ( + subscribed, + already_automatic, + unsubscribed_at_stream_ordering, + unsubscribed_at_topological_ordering, + ) = row + + if subscribed and (not already_automatic or requested_automatic): + # we are already subscribed and the current subscription state + # is good enough (either we already have a manual subscription, + # or we requested an automatic subscription) + # In that case, nothing to change here. + # (See docstring for case-by-case explanation) + return None + + if not subscribed and requested_automatic: + assert automatic_event_orderings is not None + # we previously unsubscribed and we are now automatically subscribing + # Check whether the new autosubscription should be skipped + if ThreadSubscriptionsWorkerStore._should_skip_autosubscription_after_unsubscription( + autosub=automatic_event_orderings, + unsubscribed_at=EventOrderings( + unsubscribed_at_stream_ordering, + unsubscribed_at_topological_ordering, + ), + ): + # skip the subscription + return AutomaticSubscriptionConflicted() + + # At this point: we have now finished checking that we need to make + # a subscription, updating the current row. + + stream_id = self._thread_subscriptions_id_gen.get_next_txn(txn) + self.db_pool.simple_update_txn( + txn, + table="thread_subscriptions", + keyvalues={ + "user_id": user_id, + "event_id": thread_root_event_id, + "room_id": room_id, + }, + updatevalues={ + "subscribed": True, + "stream_id": stream_id, + "instance_name": self._instance_name, + "automatic": requested_automatic, + "unsubscribed_at_stream_ordering": None, + "unsubscribed_at_topological_ordering": None, + }, + ) + _invalidate_subscription_caches(txn) return stream_id @@ -213,6 +345,21 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore): stream_id = self._thread_subscriptions_id_gen.get_next_txn(txn) + # Find the maximum stream ordering and topological ordering of the room, + # which we then store against this unsubscription so we can skip future + # automatic subscriptions that are caused by an event logically earlier + # than this unsubscription. + txn.execute( + """ + SELECT MAX(stream_ordering) AS mso, MAX(topological_ordering) AS mto FROM events + WHERE room_id = ? 
+ """, + (room_id,), + ) + ord_row = txn.fetchone() + assert ord_row is not None + max_stream_ordering, max_topological_ordering = ord_row + self.db_pool.simple_update_txn( txn, table="thread_subscriptions", @@ -226,6 +373,8 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore): "subscribed": False, "stream_id": stream_id, "instance_name": self._instance_name, + "unsubscribed_at_stream_ordering": max_stream_ordering, + "unsubscribed_at_topological_ordering": max_topological_ordering, }, ) @@ -233,6 +382,10 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore): self.get_subscription_for_thread.invalidate, (user_id, room_id, thread_root_event_id), ) + txn.call_after( + self.get_subscribers_to_thread.invalidate, + (room_id, thread_root_event_id), + ) return stream_id @@ -245,7 +398,9 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore): Purge all subscriptions for the user. The fact that subscriptions have been purged will not be streamed; all stream rows for the user will in fact be removed. - This is intended only for dealing with user deactivation. + + This must only be used for user deactivation, + because it does not invalidate the `subscribers_to_thread` cache. """ def _purge_thread_subscription_settings_for_user_txn( @@ -306,6 +461,42 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore): return ThreadSubscription(automatic=automatic) + # max_entries=100 rationale: + # this returns a potentially large datastructure + # (since each entry contains a set which contains a potentially large number of user IDs), + # whereas the default of 10'000 entries for @cached feels more + # suitable for very small cache entries. + # + # Overall, when bearing in mind the usual profile of a small community-server or company-server + # (where cache tuning hasn't been done, so we're in out-of-box configuration), it is very + # unlikely we would benefit from keeping hot the subscribers for as many as 100 threads, + # since it's unlikely that so many threads will be active in a short span of time on a small homeserver. + # It feels that medium servers will probably also not exhaust this limit. + # Larger homeservers are more likely to be carefully tuned, either with a larger global cache factor + # or carefully following the usage patterns & cache metrics. + # Finally, the query is not so intensive that computing it every time is a huge deal, but given people + # often send messages back-to-back in the same thread it seems like it would offer a mild benefit. + @cached(max_entries=100) + async def get_subscribers_to_thread( + self, room_id: str, thread_root_event_id: str + ) -> FrozenSet[str]: + """ + Returns: + the set of user_ids for local users who are subscribed to the given thread. + """ + return frozenset( + await self.db_pool.simple_select_onecol( + table="thread_subscriptions", + keyvalues={ + "room_id": room_id, + "event_id": thread_root_event_id, + "subscribed": True, + }, + retcol="user_id", + desc="get_subscribers_to_thread", + ) + ) + def get_max_thread_subscriptions_stream_id(self) -> int: """Get the current maximum stream_id for thread subscriptions. @@ -315,7 +506,7 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore): return self._thread_subscriptions_id_gen.get_current_token() async def get_updated_thread_subscriptions( - self, from_id: int, to_id: int, limit: int + self, *, from_id: int, to_id: int, limit: int ) -> List[Tuple[int, str, str, str]]: """Get updates to thread subscriptions between two stream IDs. 
@@ -348,7 +539,7 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore): ) async def get_updated_thread_subscriptions_for_user( - self, user_id: str, from_id: int, to_id: int, limit: int + self, user_id: str, *, from_id: int, to_id: int, limit: int ) -> List[Tuple[int, str, str]]: """Get updates to thread subscriptions for a specific user. diff --git a/synapse/storage/schema/main/delta/92/08_room_ban_redactions.sql b/synapse/storage/schema/main/delta/92/08_room_ban_redactions.sql new file mode 100644 index 0000000000..566ddcbdd7 --- /dev/null +++ b/synapse/storage/schema/main/delta/92/08_room_ban_redactions.sql @@ -0,0 +1,21 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3. +-- +-- Copyright (C) 2025 New Vector, Ltd +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- . + +CREATE TABLE room_ban_redactions( + room_id text NOT NULL, + user_id text NOT NULL, + redacting_event_id text NOT NULL, + redact_end_ordering bigint DEFAULT NULL, -- stream ordering after which redactions are not applied + CONSTRAINT room_ban_redaction_uniqueness UNIQUE (room_id, user_id) +); + diff --git a/synapse/storage/schema/main/delta/92/09_thread_subscriptions_update.sql b/synapse/storage/schema/main/delta/92/09_thread_subscriptions_update.sql new file mode 100644 index 0000000000..03b8a1a635 --- /dev/null +++ b/synapse/storage/schema/main/delta/92/09_thread_subscriptions_update.sql @@ -0,0 +1,20 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3. +-- +-- Copyright (C) 2025 New Vector, Ltd +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- . + +-- The maximum stream_ordering in the room when the unsubscription was made. +ALTER TABLE thread_subscriptions + ADD COLUMN unsubscribed_at_stream_ordering BIGINT; + +-- The maximum topological_ordering in the room when the unsubscription was made. +ALTER TABLE thread_subscriptions + ADD COLUMN unsubscribed_at_topological_ordering BIGINT; diff --git a/synapse/storage/schema/main/delta/92/09_thread_subscriptions_update.sql.postgres b/synapse/storage/schema/main/delta/92/09_thread_subscriptions_update.sql.postgres new file mode 100644 index 0000000000..fc5d555db5 --- /dev/null +++ b/synapse/storage/schema/main/delta/92/09_thread_subscriptions_update.sql.postgres @@ -0,0 +1,18 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3. +-- +-- Copyright (C) 2025 New Vector, Ltd +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- . 
+ +COMMENT ON COLUMN thread_subscriptions.unsubscribed_at_stream_ordering IS + $$The maximum stream_ordering in the room when the unsubscription was made.$$; + +COMMENT ON COLUMN thread_subscriptions.unsubscribed_at_topological_ordering IS + $$The maximum topological_ordering in the room when the unsubscription was made.$$; diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 026a0517d2..a15a161ce8 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -195,6 +195,8 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator): db stream_name: A name for the stream, for use in the `stream_positions` table. (Does not need to be the same as the replication stream name) + server_name: The homeserver name of the server (used to label metrics) + (this should be `hs.hostname`). instance_name: The name of this instance. tables: List of tables associated with the stream. Tuple of table name, column name that stores the writer's instance name, and @@ -210,10 +212,12 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator): def __init__( self, + *, db_conn: LoggingDatabaseConnection, db: DatabasePool, notifier: "ReplicationNotifier", stream_name: str, + server_name: str, instance_name: str, tables: List[Tuple[str, str, str]], sequence_name: str, @@ -223,6 +227,7 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator): self._db = db self._notifier = notifier self._stream_name = stream_name + self.server_name = server_name self._instance_name = instance_name self._positive = positive self._writers = writers @@ -561,6 +566,7 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator): txn.call_after( run_as_background_process, "MultiWriterIdGenerator._update_table", + self.server_name, self._db.runInteraction, "MultiWriterIdGenerator._update_table", self._update_stream_positions_table_txn, @@ -597,6 +603,7 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator): txn.call_after( run_as_background_process, "MultiWriterIdGenerator._update_table", + self.server_name, self._db.runInteraction, "MultiWriterIdGenerator._update_table", self._update_stream_positions_table_txn, diff --git a/synapse/synapse_rust/events.pyi b/synapse/synapse_rust/events.pyi index 7d3422572d..a82211283b 100644 --- a/synapse/synapse_rust/events.pyi +++ b/synapse/synapse_rust/events.pyi @@ -33,6 +33,9 @@ class EventInternalMetadata: proactively_send: bool redacted: bool + policy_server_spammy: bool + """whether the policy server indicated that this event is spammy""" + txn_id: str """The transaction ID, if it was set when the event was created.""" token_id: int diff --git a/synapse/synapse_rust/http_client.pyi b/synapse/synapse_rust/http_client.pyi index cdc501e606..9fb7831e6b 100644 --- a/synapse/synapse_rust/http_client.pyi +++ b/synapse/synapse_rust/http_client.pyi @@ -10,17 +10,19 @@ # See the GNU Affero General Public License for more details: # . -from typing import Awaitable, Mapping +from typing import Mapping + +from twisted.internet.defer import Deferred from synapse.types import ISynapseReactor class HttpClient: def __init__(self, reactor: ISynapseReactor, user_agent: str) -> None: ... - def get(self, url: str, response_limit: int) -> Awaitable[bytes]: ... + def get(self, url: str, response_limit: int) -> Deferred[bytes]: ... def post( self, url: str, response_limit: int, headers: Mapping[str, str], request_body: str, - ) -> Awaitable[bytes]: ... + ) -> Deferred[bytes]: ... 
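
From the `MultiWriterIdGenerator` change onwards, this diff threads a `server_name` (always `hs.hostname`) into metrics and background processes so that every metric carries a per-homeserver label. The following self-contained sketch shows that labelling convention; the metric name, helper function, example hostname, and the literal value of `SERVER_NAME_LABEL` are illustrative assumptions, not taken from the diff:

```python
# Sketch of the per-homeserver metrics labelling convention applied throughout
# this diff: each metric gains a server-name label alongside its existing labels.
from prometheus_client import Gauge

SERVER_NAME_LABEL = "server_name"  # assumed value; Synapse defines this constant in synapse.metrics

queue_depth = Gauge(
    "example_queue_depth",
    "Number of items waiting in an example queue",
    labelnames=("name", SERVER_NAME_LABEL),
)


def report_queue_depth(name: str, server_name: str, depth: int) -> None:
    # Mirrors the call-site pattern used in the diff:
    #   metric.labels(name=..., **{SERVER_NAME_LABEL: server_name})
    queue_depth.labels(name=name, **{SERVER_NAME_LABEL: server_name}).set(depth)


report_queue_depth("federation_sender", "example.org", 3)
```

The same `**{SERVER_NAME_LABEL: server_name}` expansion appears at the call sites in `batching_queue.py`, `deferred_cache.py`, and `ratelimitutils.py` later in this diff.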
diff --git a/synapse/synapse_rust/push.pyi b/synapse/synapse_rust/push.pyi index 3f317c3288..a3e12ad648 100644 --- a/synapse/synapse_rust/push.pyi +++ b/synapse/synapse_rust/push.pyi @@ -49,6 +49,7 @@ class FilteredPushRules: msc3664_enabled: bool, msc4028_push_encrypted_events: bool, msc4210_enabled: bool, + msc4306_enabled: bool, ): ... def rules(self) -> Collection[Tuple[PushRule, bool]]: ... @@ -67,13 +68,19 @@ class PushRuleEvaluator: room_version_feature_flags: Tuple[str, ...], msc3931_enabled: bool, msc4210_enabled: bool, + msc4306_enabled: bool, ): ... def run( self, push_rules: FilteredPushRules, user_id: Optional[str], display_name: Optional[str], + msc4306_thread_subscription_state: Optional[bool], ) -> Collection[Union[Mapping, str]]: ... def matches( - self, condition: JsonDict, user_id: Optional[str], display_name: Optional[str] + self, + condition: JsonDict, + user_id: Optional[str], + display_name: Optional[str], + msc4306_thread_subscription_state: Optional[bool] = None, ) -> bool: ... diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py index 914bb6cb23..943f211b11 100644 --- a/synapse/types/__init__.py +++ b/synapse/types/__init__.py @@ -73,6 +73,7 @@ if TYPE_CHECKING: from typing_extensions import Self from synapse.appservice.api import ApplicationService + from synapse.events import EventBase from synapse.storage.databases.main import DataStore, PurgeEventsStore from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore from synapse.storage.util.id_generators import MultiWriterIdGenerator @@ -1530,3 +1531,31 @@ class ScheduledTask: result: Optional[JsonMapping] # Optional error that should be assigned a value when the status is FAILED error: Optional[str] + + +@attr.s(auto_attribs=True, frozen=True, slots=True) +class EventOrderings: + stream: int + """ + The stream_ordering of the event. + Negative numbers mean the event was backfilled. + """ + + topological: int + """ + The topological_ordering of the event. + Currently this is equivalent to the `depth` attributes of + the PDU. + """ + + @staticmethod + def from_event(event: "EventBase") -> "EventOrderings": + """ + Get the orderings from an event. 
+ + Preconditions: + - the event must have been persisted (otherwise it won't have a stream ordering) + """ + stream = event.internal_metadata.stream_ordering + assert stream is not None + return EventOrderings(stream, event.depth) diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index bd4d20accb..36129c3a67 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -27,7 +27,6 @@ from typing import ( Any, Callable, Dict, - Generator, Iterator, Mapping, Optional, @@ -42,7 +41,6 @@ from matrix_common.versionstring import get_distribution_version_string from typing_extensions import ParamSpec from twisted.internet import defer, task -from twisted.internet.defer import Deferred from twisted.internet.interfaces import IDelayedCall, IReactorTime from twisted.internet.task import LoopingCall from twisted.python.failure import Failure @@ -121,13 +119,11 @@ class Clock: _reactor: IReactorTime = attr.ib() - @defer.inlineCallbacks - def sleep(self, seconds: float) -> "Generator[Deferred[float], Any, Any]": + async def sleep(self, seconds: float) -> None: d: defer.Deferred[float] = defer.Deferred() with context.PreserveLoggingContext(): self._reactor.callLater(seconds, d.callback, seconds) - res = yield d - return res + await d def time(self) -> float: """Returns the current system time in seconds since epoch.""" diff --git a/synapse/util/batching_queue.py b/synapse/util/batching_queue.py index 3fb697751f..4c0f129423 100644 --- a/synapse/util/batching_queue.py +++ b/synapse/util/batching_queue.py @@ -37,6 +37,7 @@ from prometheus_client import Gauge from twisted.internet import defer from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable +from synapse.metrics import SERVER_NAME_LABEL from synapse.metrics.background_process_metrics import run_as_background_process from synapse.util import Clock @@ -49,19 +50,19 @@ R = TypeVar("R") number_queued = Gauge( "synapse_util_batching_queue_number_queued", "The number of items waiting in the queue across all keys", - labelnames=("name",), + labelnames=("name", SERVER_NAME_LABEL), ) number_in_flight = Gauge( "synapse_util_batching_queue_number_pending", "The number of items across all keys either being processed or waiting in a queue", - labelnames=("name",), + labelnames=("name", SERVER_NAME_LABEL), ) number_of_keys = Gauge( "synapse_util_batching_queue_number_of_keys", "The number of distinct keys that have items queued", - labelnames=("name",), + labelnames=("name", SERVER_NAME_LABEL), ) @@ -85,6 +86,8 @@ class BatchingQueue(Generic[V, R]): Args: name: A name for the queue, used for logging contexts and metrics. This must be unique, otherwise the metrics will be wrong. + server_name: The homeserver name of the server (used to label metrics) + (this should be `hs.hostname`). clock: The clock to use to schedule work. process_batch_callback: The callback to to be run to process a batch of work. @@ -92,11 +95,14 @@ class BatchingQueue(Generic[V, R]): def __init__( self, + *, name: str, + server_name: str, clock: Clock, process_batch_callback: Callable[[List[V]], Awaitable[R]], ): self._name = name + self.server_name = server_name self._clock = clock # The set of keys currently being processed. @@ -109,14 +115,18 @@ class BatchingQueue(Generic[V, R]): # The function to call with batches of values. 
self._process_batch_callback = process_batch_callback - number_queued.labels(self._name).set_function( - lambda: sum(len(q) for q in self._next_values.values()) + number_queued.labels( + name=self._name, **{SERVER_NAME_LABEL: self.server_name} + ).set_function(lambda: sum(len(q) for q in self._next_values.values())) + + number_of_keys.labels( + name=self._name, **{SERVER_NAME_LABEL: self.server_name} + ).set_function(lambda: len(self._next_values)) + + self._number_in_flight_metric: Gauge = number_in_flight.labels( + name=self._name, **{SERVER_NAME_LABEL: self.server_name} ) - number_of_keys.labels(self._name).set_function(lambda: len(self._next_values)) - - self._number_in_flight_metric: Gauge = number_in_flight.labels(self._name) - async def add_to_queue(self, value: V, key: Hashable = ()) -> R: """Adds the value to the queue with the given key, returning the result of the processing function for the batch that included the given value. @@ -135,7 +145,9 @@ class BatchingQueue(Generic[V, R]): # If we're not currently processing the key fire off a background # process to start processing. if key not in self._processing_keys: - run_as_background_process(self._name, self._process_queue, key) + run_as_background_process( + self._name, self.server_name, self._process_queue, key + ) with self._number_in_flight_metric.track_inprogress(): return await make_deferred_yieldable(d) diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index 3087ad6adc..710a29e3f0 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -42,9 +42,10 @@ TRACK_MEMORY_USAGE = False # We track cache metrics in a special registry that lets us update the metrics # just before they are returned from the scrape endpoint. -CACHE_METRIC_REGISTRY = DynamicCollectorRegistry() - -caches_by_name: Dict[str, Sized] = {} +# +# The `SERVER_NAME_LABEL` is included in the individual metrics added to this registry, +# so we don't need to worry about it on the collector itself. 
+CACHE_METRIC_REGISTRY = DynamicCollectorRegistry() # type: ignore[missing-server-name-label] cache_size = Gauge( "synapse_util_caches_cache_size", @@ -242,8 +243,7 @@ def register_cache( server_name=server_name, collect_callback=collect_callback, ) - metric_name = "cache_%s_%s" % (cache_type, cache_name) - caches_by_name[cache_name] = cache + metric_name = "cache_%s_%s_%s" % (cache_type, cache_name, server_name) CACHE_METRIC_REGISTRY.register_hook(metric_name, metric.collect) return metric diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py index 0c6c912918..92d446ce2a 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py @@ -43,6 +43,7 @@ from prometheus_client import Gauge from twisted.internet import defer from twisted.python.failure import Failure +from synapse.metrics import SERVER_NAME_LABEL from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.lrucache import LruCache from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry @@ -50,7 +51,7 @@ from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry cache_pending_metric = Gauge( "synapse_util_caches_cache_pending", "Number of lookups currently pending for this cache", - ["name"], + labelnames=["name", SERVER_NAME_LABEL], ) T = TypeVar("T") @@ -111,7 +112,9 @@ class DeferredCache(Generic[KT, VT]): ] = cache_type() def metrics_cb() -> None: - cache_pending_metric.labels(name).set(len(self._pending_deferred_cache)) + cache_pending_metric.labels( + name=name, **{SERVER_NAME_LABEL: server_name} + ).set(len(self._pending_deferred_cache)) # cache is used for completed results and maps to the result itself, rather than # a Deferred. diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index 4be4c6f01b..1962a3fdfa 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -99,7 +99,9 @@ class ExpiringCache(Generic[KT, VT]): return def f() -> "defer.Deferred[None]": - return run_as_background_process("prune_cache", self._prune_cache) + return run_as_background_process( + "prune_cache", server_name, self._prune_cache + ) self._clock.looping_call(f, self._expiry_ms / 2) diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 466362e79c..927162700a 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -45,11 +45,13 @@ from typing import ( overload, ) -from twisted.internet import reactor +from twisted.internet import defer, reactor from twisted.internet.interfaces import IReactorTime from synapse.config import cache as cache_config -from synapse.metrics.background_process_metrics import wrap_as_background_process +from synapse.metrics.background_process_metrics import ( + run_as_background_process, +) from synapse.metrics.jemalloc import get_jemalloc_stats from synapse.util import Clock, caches from synapse.util.caches import CacheMetric, EvictionReason, register_cache @@ -118,103 +120,121 @@ USE_GLOBAL_LIST = False GLOBAL_ROOT = ListNode["_Node"].create_root_node() -@wrap_as_background_process("LruCache._expire_old_entries") -async def _expire_old_entries( - clock: Clock, expiry_seconds: float, autotune_config: Optional[dict] -) -> None: +def _expire_old_entries( + server_name: str, + clock: Clock, + expiry_seconds: float, + autotune_config: Optional[dict], +) -> "defer.Deferred[None]": """Walks the global cache list to find cache entries that haven't been accessed 
in the given number of seconds, or if a given memory threshold has been breached. """ - if autotune_config: - max_cache_memory_usage = autotune_config["max_cache_memory_usage"] - target_cache_memory_usage = autotune_config["target_cache_memory_usage"] - min_cache_ttl = autotune_config["min_cache_ttl"] / 1000 - now = int(clock.time()) - node = GLOBAL_ROOT.prev_node - assert node is not None + async def _internal_expire_old_entries( + clock: Clock, expiry_seconds: float, autotune_config: Optional[dict] + ) -> None: + if autotune_config: + max_cache_memory_usage = autotune_config["max_cache_memory_usage"] + target_cache_memory_usage = autotune_config["target_cache_memory_usage"] + min_cache_ttl = autotune_config["min_cache_ttl"] / 1000 - i = 0 + now = int(clock.time()) + node = GLOBAL_ROOT.prev_node + assert node is not None - logger.debug("Searching for stale caches") + i = 0 - evicting_due_to_memory = False + logger.debug("Searching for stale caches") - # determine if we're evicting due to memory - jemalloc_interface = get_jemalloc_stats() - if jemalloc_interface and autotune_config: - try: - jemalloc_interface.refresh_stats() - mem_usage = jemalloc_interface.get_stat("allocated") - if mem_usage > max_cache_memory_usage: - logger.info("Begin memory-based cache eviction.") - evicting_due_to_memory = True - except Exception: - logger.warning( - "Unable to read allocated memory, skipping memory-based cache eviction." - ) + evicting_due_to_memory = False - while node is not GLOBAL_ROOT: - # Only the root node isn't a `_TimedListNode`. - assert isinstance(node, _TimedListNode) - - # if node has not aged past expiry_seconds and we are not evicting due to memory usage, there's - # nothing to do here - if ( - node.last_access_ts_secs > now - expiry_seconds - and not evicting_due_to_memory - ): - break - - # if entry is newer than min_cache_entry_ttl then do not evict and don't evict anything newer - if evicting_due_to_memory and now - node.last_access_ts_secs < min_cache_ttl: - break - - cache_entry = node.get_cache_entry() - next_node = node.prev_node - - # The node should always have a reference to a cache entry and a valid - # `prev_node`, as we only drop them when we remove the node from the - # list. - assert next_node is not None - assert cache_entry is not None - cache_entry.drop_from_cache() - - # Check mem allocation periodically if we are evicting a bunch of caches - if jemalloc_interface and evicting_due_to_memory and (i + 1) % 100 == 0: + # determine if we're evicting due to memory + jemalloc_interface = get_jemalloc_stats() + if jemalloc_interface and autotune_config: try: jemalloc_interface.refresh_stats() mem_usage = jemalloc_interface.get_stat("allocated") - if mem_usage < target_cache_memory_usage: - evicting_due_to_memory = False - logger.info("Stop memory-based cache eviction.") + if mem_usage > max_cache_memory_usage: + logger.info("Begin memory-based cache eviction.") + evicting_due_to_memory = True except Exception: logger.warning( - "Unable to read allocated memory, this may affect memory-based cache eviction." + "Unable to read allocated memory, skipping memory-based cache eviction." ) - # If we've failed to read the current memory usage then we - # should stop trying to evict based on memory usage - evicting_due_to_memory = False - # If we do lots of work at once we yield to allow other stuff to happen. 
- if (i + 1) % 10000 == 0: - logger.debug("Waiting during drop") - if node.last_access_ts_secs > now - expiry_seconds: - await clock.sleep(0.5) - else: - await clock.sleep(0) - logger.debug("Waking during drop") + while node is not GLOBAL_ROOT: + # Only the root node isn't a `_TimedListNode`. + assert isinstance(node, _TimedListNode) - node = next_node + # if node has not aged past expiry_seconds and we are not evicting due to memory usage, there's + # nothing to do here + if ( + node.last_access_ts_secs > now - expiry_seconds + and not evicting_due_to_memory + ): + break - # If we've yielded then our current node may have been evicted, so we - # need to check that its still valid. - if node.prev_node is None: - break + # if entry is newer than min_cache_entry_ttl then do not evict and don't evict anything newer + if ( + evicting_due_to_memory + and now - node.last_access_ts_secs < min_cache_ttl + ): + break - i += 1 + cache_entry = node.get_cache_entry() + next_node = node.prev_node - logger.info("Dropped %d items from caches", i) + # The node should always have a reference to a cache entry and a valid + # `prev_node`, as we only drop them when we remove the node from the + # list. + assert next_node is not None + assert cache_entry is not None + cache_entry.drop_from_cache() + + # Check mem allocation periodically if we are evicting a bunch of caches + if jemalloc_interface and evicting_due_to_memory and (i + 1) % 100 == 0: + try: + jemalloc_interface.refresh_stats() + mem_usage = jemalloc_interface.get_stat("allocated") + if mem_usage < target_cache_memory_usage: + evicting_due_to_memory = False + logger.info("Stop memory-based cache eviction.") + except Exception: + logger.warning( + "Unable to read allocated memory, this may affect memory-based cache eviction." + ) + # If we've failed to read the current memory usage then we + # should stop trying to evict based on memory usage + evicting_due_to_memory = False + + # If we do lots of work at once we yield to allow other stuff to happen. + if (i + 1) % 10000 == 0: + logger.debug("Waiting during drop") + if node.last_access_ts_secs > now - expiry_seconds: + await clock.sleep(0.5) + else: + await clock.sleep(0) + logger.debug("Waking during drop") + + node = next_node + + # If we've yielded then our current node may have been evicted, so we + # need to check that its still valid. + if node.prev_node is None: + break + + i += 1 + + logger.info("Dropped %d items from caches", i) + + return run_as_background_process( + "LruCache._expire_old_entries", + server_name, + _internal_expire_old_entries, + clock, + expiry_seconds, + autotune_config, + ) def setup_expire_lru_cache_entries(hs: "HomeServer") -> None: @@ -234,10 +254,12 @@ def setup_expire_lru_cache_entries(hs: "HomeServer") -> None: global USE_GLOBAL_LIST USE_GLOBAL_LIST = True + server_name = hs.hostname clock = hs.get_clock() clock.looping_call( _expire_old_entries, 30 * 1000, + server_name, clock, expiry_time, hs.config.caches.cache_autotuning, diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py index 95786bd3dd..f48ae3373c 100644 --- a/synapse/util/distributor.py +++ b/synapse/util/distributor.py @@ -58,7 +58,13 @@ class Distributor: model will do for today. """ - def __init__(self) -> None: + def __init__(self, server_name: str) -> None: + """ + Args: + server_name: The homeserver name of the server (used to label metrics) + (this should be `hs.hostname`). 
+ """ + self.server_name = server_name self.signals: Dict[str, Signal] = {} self.pre_registration: Dict[str, List[Callable]] = {} @@ -91,7 +97,9 @@ class Distributor: if name not in self.signals: raise KeyError("%r does not have a signal named %s" % (self, name)) - run_as_background_process(name, self.signals[name].fire, *args, **kwargs) + run_as_background_process( + name, self.server_name, self.signals[name].fire, *args, **kwargs + ) P = ParamSpec("P") diff --git a/synapse/util/pydantic_models.py b/synapse/util/pydantic_models.py index ba9e7bb7d5..4880709501 100644 --- a/synapse/util/pydantic_models.py +++ b/synapse/util/pydantic_models.py @@ -13,7 +13,11 @@ # # -from synapse._pydantic_compat import BaseModel, Extra +import re +from typing import Any, Callable, Generator + +from synapse._pydantic_compat import BaseModel, Extra, StrictStr +from synapse.types import EventID class ParseModel(BaseModel): @@ -37,3 +41,43 @@ class ParseModel(BaseModel): extra = Extra.ignore # By default, don't allow fields to be reassigned after parsing. allow_mutation = False + + +class AnyEventId(StrictStr): + """ + A validator for strings that need to be an Event ID. + + Accepts any valid grammar of Event ID from any room version. + """ + + EVENT_ID_HASH_ROOM_VERSION_3_PLUS = re.compile( + r"^([a-zA-Z0-9-_]{43}|[a-zA-Z0-9+/]{43})$" + ) + + @classmethod + def __get_validators__(cls) -> Generator[Callable[..., Any], Any, Any]: + yield from super().__get_validators__() # type: ignore + yield cls.validate_event_id + + @classmethod + def validate_event_id(cls, value: str) -> str: + if not value.startswith("$"): + raise ValueError("Event ID must start with `$`") + + if ":" in value: + # Room versions 1 and 2 + EventID.from_string(value) # throws on fail + else: + # Room versions 3+: event ID is $ + a base64 sha256 hash + # Room version 3 is base64, 4+ are base64Url + # In both cases, the base64 is unpadded. + # refs: + # - https://spec.matrix.org/v1.15/rooms/v3/ e.g. $acR1l0raoZnm60CBwAVgqbZqoO/mYU81xysh1u7XcJk + # - https://spec.matrix.org/v1.15/rooms/v4/ e.g. 
$Rqnc-F-dvnEYJTyHq_iKxU2bZ1CI92-kuZq3a5lr5Zg + b64_hash = value[1:] + if cls.EVENT_ID_HASH_ROOM_VERSION_3_PLUS.fullmatch(b64_hash) is None: + raise ValueError( + "Event ID must either have a domain part or be a valid hash" + ) + + return value diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py index 3f067b792c..f5e592d80e 100644 --- a/synapse/util/ratelimitutils.py +++ b/synapse/util/ratelimitutils.py @@ -52,7 +52,7 @@ from synapse.logging.context import ( run_in_background, ) from synapse.logging.opentracing import start_active_span -from synapse.metrics import Histogram, LaterGauge +from synapse.metrics import SERVER_NAME_LABEL, Histogram, LaterGauge from synapse.util import Clock if typing.TYPE_CHECKING: @@ -65,17 +65,17 @@ logger = logging.getLogger(__name__) rate_limit_sleep_counter = Counter( "synapse_rate_limit_sleep", "Number of requests slept by the rate limiter", - ["rate_limiter_name"], + labelnames=["rate_limiter_name", SERVER_NAME_LABEL], ) rate_limit_reject_counter = Counter( "synapse_rate_limit_reject", "Number of requests rejected by the rate limiter", - ["rate_limiter_name"], + labelnames=["rate_limiter_name", SERVER_NAME_LABEL], ) queue_wait_timer = Histogram( "synapse_rate_limit_queue_wait_time_seconds", "Amount of time spent waiting for the rate limiter to let our request through.", - ["rate_limiter_name"], + labelnames=["rate_limiter_name", SERVER_NAME_LABEL], buckets=( 0.005, 0.01, @@ -119,7 +119,10 @@ def _get_counts_from_rate_limiter_instance( # Only track metrics if they provided a `metrics_name` to # differentiate this instance of the rate limiter. if rate_limiter_instance.metrics_name: - key = (rate_limiter_instance.metrics_name,) + key = ( + rate_limiter_instance.metrics_name, + rate_limiter_instance.our_server_name, + ) counts[key] = count_func(rate_limiter_instance) return counts @@ -129,10 +132,10 @@ def _get_counts_from_rate_limiter_instance( # differentiate one really noisy homeserver from a general # ratelimit tuning problem across the federation. LaterGauge( - "synapse_rate_limit_sleep_affected_hosts", - "Number of hosts that had requests put to sleep", - ["rate_limiter_name"], - lambda: _get_counts_from_rate_limiter_instance( + name="synapse_rate_limit_sleep_affected_hosts", + desc="Number of hosts that had requests put to sleep", + labelnames=["rate_limiter_name", SERVER_NAME_LABEL], + caller=lambda: _get_counts_from_rate_limiter_instance( lambda rate_limiter_instance: sum( ratelimiter.should_sleep() for ratelimiter in rate_limiter_instance.ratelimiters.values() @@ -140,10 +143,10 @@ LaterGauge( ), ) LaterGauge( - "synapse_rate_limit_reject_affected_hosts", - "Number of hosts that had requests rejected", - ["rate_limiter_name"], - lambda: _get_counts_from_rate_limiter_instance( + name="synapse_rate_limit_reject_affected_hosts", + desc="Number of hosts that had requests rejected", + labelnames=["rate_limiter_name", SERVER_NAME_LABEL], + caller=lambda: _get_counts_from_rate_limiter_instance( lambda rate_limiter_instance: sum( ratelimiter.should_reject() for ratelimiter in rate_limiter_instance.ratelimiters.values() @@ -157,6 +160,7 @@ class FederationRateLimiter: def __init__( self, + our_server_name: str, clock: Clock, config: FederationRatelimitSettings, metrics_name: Optional[str] = None, @@ -170,11 +174,15 @@ class FederationRateLimiter: for this rate limiter. 
""" + self.our_server_name = our_server_name self.metrics_name = metrics_name def new_limiter() -> "_PerHostRatelimiter": return _PerHostRatelimiter( - clock=clock, config=config, metrics_name=metrics_name + our_server_name=our_server_name, + clock=clock, + config=config, + metrics_name=metrics_name, ) self.ratelimiters: DefaultDict[str, "_PerHostRatelimiter"] = ( @@ -205,6 +213,7 @@ class FederationRateLimiter: class _PerHostRatelimiter: def __init__( self, + our_server_name: str, clock: Clock, config: FederationRatelimitSettings, metrics_name: Optional[str] = None, @@ -218,6 +227,7 @@ class _PerHostRatelimiter: for this rate limiter. from the rest in the metrics """ + self.our_server_name = our_server_name self.clock = clock self.metrics_name = metrics_name @@ -279,7 +289,10 @@ class _PerHostRatelimiter: async def _on_enter_with_tracing(self, request_id: object) -> None: maybe_metrics_cm: ContextManager = contextlib.nullcontext() if self.metrics_name: - maybe_metrics_cm = queue_wait_timer.labels(self.metrics_name).time() + maybe_metrics_cm = queue_wait_timer.labels( + rate_limiter_name=self.metrics_name, + **{SERVER_NAME_LABEL: self.our_server_name}, + ).time() with start_active_span("ratelimit wait"), maybe_metrics_cm: await self._on_enter(request_id) @@ -296,7 +309,10 @@ class _PerHostRatelimiter: if self.should_reject(): logger.debug("Ratelimiter(%s): rejecting request", self.host) if self.metrics_name: - rate_limit_reject_counter.labels(self.metrics_name).inc() + rate_limit_reject_counter.labels( + rate_limiter_name=self.metrics_name, + **{SERVER_NAME_LABEL: self.our_server_name}, + ).inc() raise LimitExceededError( limiter_name="rc_federation", retry_after_ms=int(self.window_size / self.sleep_limit), @@ -333,7 +349,10 @@ class _PerHostRatelimiter: self.sleep_sec, ) if self.metrics_name: - rate_limit_sleep_counter.labels(self.metrics_name).inc() + rate_limit_sleep_counter.labels( + rate_limiter_name=self.metrics_name, + **{SERVER_NAME_LABEL: self.our_server_name}, + ).inc() ret_defer = run_in_background(self.clock.sleep, self.sleep_sec) self.sleeping_requests.add(request_id) diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py index 42be1c8d28..149df405b3 100644 --- a/synapse/util/retryutils.py +++ b/synapse/util/retryutils.py @@ -59,7 +59,9 @@ class NotRetryingDestination(Exception): async def get_retry_limiter( + *, destination: str, + our_server_name: str, clock: Clock, store: DataStore, ignore_backoff: bool = False, @@ -74,6 +76,7 @@ async def get_retry_limiter( Args: destination: name of homeserver + our_server_name: Our homeserver name (used to label metrics) (`hs.hostname`) clock: timing source store: datastore ignore_backoff: true to ignore the historical backoff data and @@ -82,7 +85,12 @@ async def get_retry_limiter( Example usage: try: - limiter = await get_retry_limiter(destination, clock, store) + limiter = await get_retry_limiter( + destination=destination, + our_server_name=self.server_name, + clock=clock, + store=store, + ) with limiter: response = await do_request() except NotRetryingDestination: @@ -114,11 +122,12 @@ async def get_retry_limiter( backoff_on_failure = not ignore_backoff return RetryDestinationLimiter( - destination, - clock, - store, - failure_ts, - retry_interval, + destination=destination, + our_server_name=our_server_name, + clock=clock, + store=store, + failure_ts=failure_ts, + retry_interval=retry_interval, backoff_on_failure=backoff_on_failure, **kwargs, ) @@ -151,7 +160,9 @@ async def filter_destinations_by_retry_limiter( class 
RetryDestinationLimiter: def __init__( self, + *, destination: str, + our_server_name: str, clock: Clock, store: DataStore, failure_ts: Optional[int], @@ -169,6 +180,7 @@ class RetryDestinationLimiter: Args: destination + our_server_name: Our homeserver name (used to label metrics) (`hs.hostname`) clock store failure_ts: when this destination started failing (in ms since @@ -184,6 +196,7 @@ class RetryDestinationLimiter: backoff_on_all_error_codes: Whether we should back off on any error code. """ + self.our_server_name = our_server_name self.clock = clock self.store = store self.destination = destination @@ -318,4 +331,6 @@ class RetryDestinationLimiter: logger.exception("Failed to store destination_retry_timings") # we deliberately do this in the background. - run_as_background_process("store_retry_timings", store_retry_timings) + run_as_background_process( + "store_retry_timings", self.our_server_name, store_retry_timings + ) diff --git a/synapse/util/task_scheduler.py b/synapse/util/task_scheduler.py index 5169656c73..fdcacdf128 100644 --- a/synapse/util/task_scheduler.py +++ b/synapse/util/task_scheduler.py @@ -30,7 +30,7 @@ from synapse.logging.context import ( nested_logging_context, set_current_context, ) -from synapse.metrics import LaterGauge +from synapse.metrics import SERVER_NAME_LABEL, LaterGauge from synapse.metrics.background_process_metrics import ( run_as_background_process, wrap_as_background_process, @@ -101,6 +101,9 @@ class TaskScheduler: def __init__(self, hs: "HomeServer"): self._hs = hs + self.server_name = ( + hs.hostname + ) # nb must be called this for @wrap_as_background_process self._store = hs.get_datastores().main self._clock = hs.get_clock() self._running_tasks: Set[str] = set() @@ -128,10 +131,10 @@ class TaskScheduler: ) LaterGauge( - "synapse_scheduler_running_tasks", - "The number of concurrent running tasks handled by the TaskScheduler", - labels=None, - caller=lambda: len(self._running_tasks), + name="synapse_scheduler_running_tasks", + desc="The number of concurrent running tasks handled by the TaskScheduler", + labelnames=[SERVER_NAME_LABEL], + caller=lambda: {(self.server_name,): len(self._running_tasks)}, ) def register_action( @@ -354,7 +357,7 @@ class TaskScheduler: finally: self._launching_new_tasks = False - run_as_background_process("launch_scheduled_tasks", inner) + run_as_background_process("launch_scheduled_tasks", self.server_name, inner) @wrap_as_background_process("clean_scheduled_tasks") async def _clean_scheduled_tasks(self) -> None: @@ -485,4 +488,4 @@ class TaskScheduler: self._running_tasks.add(task.id) await self.update_task(task.id, status=TaskStatus.ACTIVE) - run_as_background_process(f"task-{task.action}", wrapper) + run_as_background_process(f"task-{task.action}", self.server_name, wrapper) diff --git a/synapse/visibility.py b/synapse/visibility.py index 501fad3839..d460d8f4c2 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -116,13 +116,27 @@ async def filter_events_for_client( # We copy the events list to guarantee any modifications we make will only # happen within the function. 
events_before_filtering = events.copy() + # Default case is to *exclude* soft-failed events + events = [e for e in events if not e.internal_metadata.is_soft_failed()] client_config = await storage.main.get_admin_client_config_for_user(user_id) - if not ( - filter_send_to_client - and client_config.return_soft_failed_events - and await storage.main.is_server_admin(user_id) - ): - events = [e for e in events if not e.internal_metadata.is_soft_failed()] + if filter_send_to_client and await storage.main.is_server_admin(user_id): + if client_config.return_soft_failed_events: + # The user has requested that all events be included, so do that. + # We copy the list for mutation safety. + events = events_before_filtering.copy() + elif client_config.return_policy_server_spammy_events: + # Include events that were soft failed by a policy server (marked spammy), + # but exclude all other soft failed events. We also want to include all + # not-soft-failed events, per usual operation. + events = [ + e + for e in events_before_filtering + if not e.internal_metadata.is_soft_failed() + or e.internal_metadata.policy_server_spammy + ] + # else - no change in behaviour; use default case + # else - no change in behaviour; use default case + if len(events_before_filtering) != len(events): if filtered_event_logger.isEnabledFor(logging.DEBUG): filtered_event_logger.debug( diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 95a4683d03..b8fb21ab0d 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -23,7 +23,7 @@ from unittest.mock import AsyncMock, Mock import pymacaroons -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.auth.internal import InternalAuth from synapse.api.auth_blocking import AuthBlocking diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index 743c52d969..8ad9a5a6f7 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -25,7 +25,7 @@ from unittest.mock import patch import jsonschema -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.constants import EduTypes, EventContentFields from synapse.api.errors import SynapseError diff --git a/tests/api/test_urls.py b/tests/api/test_urls.py index ce156a05dc..fecc7e3e2d 100644 --- a/tests/api/test_urls.py +++ b/tests/api/test_urls.py @@ -13,7 +13,7 @@ # -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.urls import LoginSSORedirectURIBuilder from synapse.server import HomeServer diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py index 47d590ecea..63cb5ff46f 100644 --- a/tests/app/test_openid_listener.py +++ b/tests/app/test_openid_listener.py @@ -22,7 +22,7 @@ from unittest.mock import Mock, patch from parameterized import parameterized -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.app.generic_worker import GenericWorkerServer from synapse.app.homeserver import SynapseHomeServer diff --git a/tests/appservice/test_api.py b/tests/appservice/test_api.py index 8fcd928d31..5eba6d20c8 100644 --- a/tests/appservice/test_api.py +++ b/tests/appservice/test_api.py @@ -21,7 +21,7 @@ from typing import Any, List, Mapping, Optional, Sequence, Union from unittest.mock import Mock -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import 
MemoryReactor from synapse.appservice import ApplicationService from synapse.server import HomeServer diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py index a5bf7e0635..11319bc52d 100644 --- a/tests/appservice/test_scheduler.py +++ b/tests/appservice/test_scheduler.py @@ -24,7 +24,7 @@ from unittest.mock import AsyncMock, Mock from typing_extensions import TypeAlias from twisted.internet import defer -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.appservice import ( ApplicationService, @@ -53,11 +53,24 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): self.clock = MockClock() self.store = Mock() self.as_api = Mock() + + self.hs = Mock( + spec_set=[ + "get_datastores", + "get_clock", + "get_application_service_api", + "hostname", + ] + ) + self.hs.get_clock.return_value = self.clock + self.hs.get_datastores.return_value = Mock( + main=self.store, + ) + self.hs.get_application_service_api.return_value = self.as_api + self.recoverer = Mock() self.recoverer_fn = Mock(return_value=self.recoverer) - self.txnctrl = _TransactionController( - clock=cast(Clock, self.clock), store=self.store, as_api=self.as_api - ) + self.txnctrl = _TransactionController(self.hs) self.txnctrl.RECOVERER_CLASS = self.recoverer_fn def test_single_service_up_txn_sent(self) -> None: @@ -163,6 +176,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): self.service = Mock() self.callback = AsyncMock() self.recoverer = _Recoverer( + server_name="test_server", clock=cast(Clock, self.clock), as_api=self.as_api, store=self.store, diff --git a/tests/config/test_oauth_delegation.py b/tests/config/test_oauth_delegation.py index 713bddeb90..833cfe628b 100644 --- a/tests/config/test_oauth_delegation.py +++ b/tests/config/test_oauth_delegation.py @@ -20,6 +20,7 @@ # import os +import tempfile from unittest.mock import Mock from synapse.config import ConfigError @@ -275,3 +276,168 @@ class MSC3861OAuthDelegation(TestCase): self.config_dict["enable_3pid_changes"] = True with self.assertRaises(ConfigError): self.parse_config() + + +class MasAuthDelegation(TestCase): + """Test that the Homeserver fails to initialize if the config is invalid.""" + + def setUp(self) -> None: + self.config_dict: JsonDict = { + **default_config("test"), + "public_baseurl": BASE_URL, + "enable_registration": False, + "matrix_authentication_service": { + "enabled": True, + "endpoint": "http://localhost:1324/", + "secret": "verysecret", + }, + } + + def parse_config(self) -> HomeServerConfig: + config = HomeServerConfig() + config.parse_config_dict(self.config_dict, "", "") + return config + + def test_endpoint_has_to_be_a_url(self) -> None: + self.config_dict["matrix_authentication_service"]["endpoint"] = "not a url" + with self.assertRaises(ConfigError): + self.parse_config() + + def test_secret_and_secret_path_are_mutually_exclusive(self) -> None: + with tempfile.NamedTemporaryFile() as f: + self.config_dict["matrix_authentication_service"]["secret"] = "verysecret" + self.config_dict["matrix_authentication_service"]["secret_path"] = f.name + with self.assertRaises(ConfigError): + self.parse_config() + + def test_secret_path_loads_secret(self) -> None: + with tempfile.NamedTemporaryFile(buffering=0) as f: + f.write(b"53C237") + del self.config_dict["matrix_authentication_service"]["secret"] + self.config_dict["matrix_authentication_service"]["secret_path"] = f.name + config = self.parse_config() + 
self.assertEqual(config.mas.secret(), "53C237") + + def test_secret_path_must_exist(self) -> None: + del self.config_dict["matrix_authentication_service"]["secret"] + self.config_dict["matrix_authentication_service"]["secret_path"] = ( + "/not/a/valid/file" + ) + with self.assertRaises(ConfigError): + self.parse_config() + + def test_registration_cannot_be_enabled(self) -> None: + self.config_dict["enable_registration"] = True + with self.assertRaises(ConfigError): + self.parse_config() + + def test_user_consent_cannot_be_enabled(self) -> None: + tmpdir = self.mktemp() + os.mkdir(tmpdir) + self.config_dict["user_consent"] = { + "require_at_registration": True, + "version": "1", + "template_dir": tmpdir, + "server_notice_content": { + "msgtype": "m.text", + "body": "foo", + }, + } + with self.assertRaises(ConfigError): + self.parse_config() + + def test_password_config_cannot_be_enabled(self) -> None: + self.config_dict["password_config"] = {"enabled": True} + with self.assertRaises(ConfigError): + self.parse_config() + + @skip_unless(HAS_AUTHLIB, "requires authlib") + def test_oidc_sso_cannot_be_enabled(self) -> None: + self.config_dict["oidc_providers"] = [ + { + "idp_id": "microsoft", + "idp_name": "Microsoft", + "issuer": "https://login.microsoftonline.com//v2.0", + "client_id": "", + "client_secret": "", + "scopes": ["openid", "profile"], + "authorization_endpoint": "https://login.microsoftonline.com//oauth2/v2.0/authorize", + "token_endpoint": "https://login.microsoftonline.com//oauth2/v2.0/token", + "userinfo_endpoint": "https://graph.microsoft.com/oidc/userinfo", + } + ] + + with self.assertRaises(ConfigError): + self.parse_config() + + def test_cas_sso_cannot_be_enabled(self) -> None: + self.config_dict["cas_config"] = { + "enabled": True, + "server_url": "https://cas-server.com", + "displayname_attribute": "name", + "required_attributes": {"userGroup": "staff", "department": "None"}, + } + + with self.assertRaises(ConfigError): + self.parse_config() + + def test_auth_providers_cannot_be_enabled(self) -> None: + self.config_dict["modules"] = [ + { + "module": f"{__name__}.{CustomAuthModule.__qualname__}", + "config": {}, + } + ] + + # This requires actually setting up an HS, as the module will be run on setup, + # which should raise as the module tries to register an auth provider + config = self.parse_config() + reactor, clock = get_clock() + with self.assertRaises(ConfigError): + setup_test_homeserver( + self.addCleanup, reactor=reactor, clock=clock, config=config + ) + + @skip_unless(HAS_AUTHLIB, "requires authlib") + def test_jwt_auth_cannot_be_enabled(self) -> None: + self.config_dict["jwt_config"] = { + "enabled": True, + "secret": "my-secret-token", + "algorithm": "HS256", + } + + with self.assertRaises(ConfigError): + self.parse_config() + + def test_login_via_existing_session_cannot_be_enabled(self) -> None: + self.config_dict["login_via_existing_session"] = {"enabled": True} + with self.assertRaises(ConfigError): + self.parse_config() + + def test_captcha_cannot_be_enabled(self) -> None: + self.config_dict.update( + enable_registration_captcha=True, + recaptcha_public_key="test", + recaptcha_private_key="test", + ) + with self.assertRaises(ConfigError): + self.parse_config() + + def test_refreshable_tokens_cannot_be_enabled(self) -> None: + self.config_dict.update( + refresh_token_lifetime="24h", + refreshable_access_token_lifetime="10m", + nonrefreshable_access_token_lifetime="24h", + ) + with self.assertRaises(ConfigError): + self.parse_config() + + def 
test_session_lifetime_cannot_be_set(self) -> None: + self.config_dict["session_lifetime"] = "24h" + with self.assertRaises(ConfigError): + self.parse_config() + + def test_enable_3pid_changes_cannot_be_enabled(self) -> None: + self.config_dict["enable_3pid_changes"] = True + with self.assertRaises(ConfigError): + self.parse_config() diff --git a/tests/config/test_room_directory.py b/tests/config/test_room_directory.py index 5208381279..5f3d8be2a5 100644 --- a/tests/config/test_room_directory.py +++ b/tests/config/test_room_directory.py @@ -19,7 +19,7 @@ # import yaml -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor import synapse.rest.admin import synapse.rest.client.login diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 3bfaf1c80d..80f9bd097e 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -31,7 +31,7 @@ from signedjson.types import SigningKey, VerifyKey from twisted.internet import defer from twisted.internet.defer import Deferred, ensureDeferred -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.errors import SynapseError from synapse.crypto import keyring diff --git a/tests/events/test_auto_accept_invites.py b/tests/events/test_auto_accept_invites.py index d2100e9903..ab183f8106 100644 --- a/tests/events/test_auto_accept_invites.py +++ b/tests/events/test_auto_accept_invites.py @@ -27,7 +27,7 @@ from unittest.mock import Mock import attr from parameterized import parameterized -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.constants import EventTypes from synapse.api.errors import SynapseError diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py index e48983ddfe..a61f1369f4 100644 --- a/tests/events/test_presence_router.py +++ b/tests/events/test_presence_router.py @@ -23,7 +23,7 @@ from unittest.mock import AsyncMock, Mock import attr -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.constants import EduTypes from synapse.events.presence_router import PresenceRouter, load_legacy_presence_router diff --git a/tests/events/test_snapshot.py b/tests/events/test_snapshot.py index f96bbe7705..6d24730ed7 100644 --- a/tests/events/test_snapshot.py +++ b/tests/events/test_snapshot.py @@ -19,7 +19,7 @@ # # -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.events import EventBase from synapse.events.snapshot import EventContext diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py index 1dc6004b35..c6ebefbf38 100644 --- a/tests/events/test_utils.py +++ b/tests/events/test_utils.py @@ -822,6 +822,32 @@ class SerializeEventTestCase(stdlib_unittest.TestCase): "unsigned": {"io.element.synapse.soft_failed": True}, }, ) + self.assertEqual( + self.serialize( + MockEvent( + type="foo", + event_id="test", + room_id="!foo:bar", + content={"foo": "bar"}, + internal_metadata={ + "soft_failed": True, + "policy_server_spammy": True, + }, + ), + [], + True, + ), + { + "type": "foo", + "event_id": "test", + "room_id": "!foo:bar", + "content": {"foo": "bar"}, + "unsigned": { + "io.element.synapse.soft_failed": True, + "io.element.synapse.policy_server_spammy": True, + }, + }, + ) def test_make_serialize_config_for_admin_retains_other_fields(self) -> None: 
non_default_config = SerializeEventConfig( diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py index 1e1ed8e642..f99911b102 100644 --- a/tests/federation/test_federation_catch_up.py +++ b/tests/federation/test_federation_catch_up.py @@ -2,7 +2,7 @@ from typing import Callable, Collection, List, Optional, Tuple from unittest import mock from unittest.mock import AsyncMock, Mock -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.constants import EventTypes from synapse.events import EventBase diff --git a/tests/federation/test_federation_client.py b/tests/federation/test_federation_client.py index 585f3b798c..df688cd21f 100644 --- a/tests/federation/test_federation_client.py +++ b/tests/federation/test_federation_client.py @@ -23,7 +23,7 @@ from unittest import mock import twisted.web.client from twisted.internet import defer -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.room_versions import RoomVersions from synapse.events import EventBase diff --git a/tests/federation/test_federation_devices.py b/tests/federation/test_federation_devices.py index ba27e69479..bf6204a7e3 100644 --- a/tests/federation/test_federation_devices.py +++ b/tests/federation/test_federation_devices.py @@ -21,7 +21,7 @@ import logging from unittest.mock import AsyncMock, Mock -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.handlers.device import DeviceListUpdater from synapse.server import HomeServer diff --git a/tests/federation/test_federation_media.py b/tests/federation/test_federation_media.py index cd4905239f..b9ec2794a3 100644 --- a/tests/federation/test_federation_media.py +++ b/tests/federation/test_federation_media.py @@ -22,7 +22,7 @@ import os import shutil import tempfile -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.media.filepath import MediaFilePaths from synapse.media.media_storage import MediaStorage diff --git a/tests/federation/test_federation_out_of_band_membership.py b/tests/federation/test_federation_out_of_band_membership.py index f77b8fe300..acf343930f 100644 --- a/tests/federation/test_federation_out_of_band_membership.py +++ b/tests/federation/test_federation_out_of_band_membership.py @@ -29,7 +29,7 @@ from unittest.mock import Mock import attr from parameterized import parameterized -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.api.room_versions import RoomVersion, RoomVersions diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index 267ea0b06e..b8dd61d04f 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -24,7 +24,7 @@ from signedjson import key, sign from signedjson.types import BaseKey, SigningKey from twisted.internet import defer -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.constants import EduTypes, RoomEncryptionAlgorithms from synapse.api.presence import UserPresenceState diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py index 58ead90909..52fd32ba85 100644 --- 
a/tests/federation/test_federation_server.py +++ b/tests/federation/test_federation_server.py @@ -25,7 +25,7 @@ from unittest.mock import Mock from parameterized import parameterized -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.constants import EventTypes, Membership from synapse.api.errors import FederationError diff --git a/tests/federation/transport/test_knocking.py b/tests/federation/transport/test_knocking.py index 166a01c1a2..14345be0f3 100644 --- a/tests/federation/transport/test_knocking.py +++ b/tests/federation/transport/test_knocking.py @@ -21,7 +21,7 @@ from collections import OrderedDict from typing import Any, Dict, List, Optional -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.constants import EventTypes, JoinRules, Membership from synapse.api.room_versions import RoomVersion, RoomVersions diff --git a/tests/handlers/test_admin.py b/tests/handlers/test_admin.py index 9ff853a83d..906d241f1a 100644 --- a/tests/handlers/test_admin.py +++ b/tests/handlers/test_admin.py @@ -22,7 +22,7 @@ from collections import Counter from unittest.mock import Mock -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor import synapse.rest.admin import synapse.storage diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index 25cf5269b8..a47b03b143 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -25,7 +25,7 @@ from unittest.mock import AsyncMock, Mock from parameterized import parameterized from twisted.internet import defer -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor import synapse.rest.admin import synapse.storage diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index c417431e85..0d9940c63e 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -23,7 +23,7 @@ from unittest.mock import AsyncMock import pymacaroons -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.errors import AuthError, ResourceLimitError from synapse.rest import admin diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py index f41f7d36ad..9de5e67863 100644 --- a/tests/handlers/test_cas.py +++ b/tests/handlers/test_cas.py @@ -21,7 +21,7 @@ from typing import Any, Dict from unittest.mock import AsyncMock, Mock -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.handlers.cas import CasResponse from synapse.server import HomeServer diff --git a/tests/handlers/test_deactivate_account.py b/tests/handlers/test_deactivate_account.py index d7b54383db..b7b8387780 100644 --- a/tests/handlers/test_deactivate_account.py +++ b/tests/handlers/test_deactivate_account.py @@ -19,7 +19,7 @@ # # -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.constants import AccountDataTypes, EventTypes, JoinRules, Membership from synapse.push.rulekinds import PRIORITY_CLASS_MAP diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index 1e989ca528..195cdfeaef 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -24,7 +24,7 @@ from typing import Optional from unittest import mock from twisted.internet.defer 
import ensureDeferred -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.constants import RoomEncryptionAlgorithms from synapse.api.errors import NotFoundError, SynapseError diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index b7058d8002..4d6243ef74 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -22,7 +22,7 @@ from typing import Any, Awaitable, Callable, Dict from unittest.mock import AsyncMock, Mock -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor import synapse.api.errors import synapse.rest.admin diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 182f9dab5d..fda485d413 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -26,7 +26,7 @@ from unittest import mock from parameterized import parameterized from signedjson import key as key, sign as sign -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.constants import RoomEncryptionAlgorithms from synapse.api.errors import Codes, SynapseError diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py index 3ec46402b7..9b280659ab 100644 --- a/tests/handlers/test_e2e_room_keys.py +++ b/tests/handlers/test_e2e_room_keys.py @@ -23,7 +23,7 @@ import copy from unittest import mock -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.errors import SynapseError from synapse.server import HomeServer diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index b64a8a86a2..4de90e6578 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -24,7 +24,7 @@ from unittest import TestCase from unittest.mock import AsyncMock, Mock, patch from twisted.internet.defer import Deferred -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.constants import EventTypes from synapse.api.errors import ( diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py index 51eca56c3b..02dd60e76d 100644 --- a/tests/handlers/test_federation_event.py +++ b/tests/handlers/test_federation_event.py @@ -21,7 +21,7 @@ from typing import Optional from unittest import mock -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.errors import AuthError, StoreError from synapse.api.room_versions import RoomVersion diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py index 990c906d2c..0a1092eae4 100644 --- a/tests/handlers/test_message.py +++ b/tests/handlers/test_message.py @@ -21,7 +21,7 @@ import logging from typing import Tuple -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.constants import EventTypes from synapse.api.errors import SynapseError diff --git a/tests/handlers/test_oauth_delegation.py b/tests/handlers/test_oauth_delegation.py index 20f2306d4c..2b0638bc12 100644 --- a/tests/handlers/test_oauth_delegation.py +++ b/tests/handlers/test_oauth_delegation.py @@ -20,12 +20,16 @@ # import json +import threading +import time from http import HTTPStatus +from http.server import BaseHTTPRequestHandler, HTTPServer from io 
import BytesIO -from typing import Any, Dict, Union +from typing import Any, Coroutine, Dict, Generator, Optional, TypeVar, Union from unittest.mock import ANY, AsyncMock, Mock from urllib.parse import parse_qs +from parameterized import parameterized_class from signedjson.key import ( encode_verify_key_base64, generate_signing_key, @@ -33,8 +37,10 @@ from signedjson.key import ( ) from signedjson.sign import sign_json -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.defer import Deferred, ensureDeferred +from twisted.internet.testing import MemoryReactor +from synapse.api.auth.mas import MasDelegatedAuth from synapse.api.errors import ( AuthError, Codes, @@ -48,7 +54,7 @@ from synapse.http.site import SynapseRequest from synapse.rest import admin from synapse.rest.client import account, devices, keys, login, logout, register from synapse.server import HomeServer -from synapse.types import JsonDict, UserID +from synapse.types import JsonDict, UserID, create_requester from synapse.util import Clock from tests.server import FakeChannel @@ -109,12 +115,7 @@ async def get_json(url: str) -> JsonDict: class MSC3861OAuthDelegation(HomeserverTestCase): servlets = [ account.register_servlets, - devices.register_servlets, keys.register_servlets, - register.register_servlets, - login.register_servlets, - logout.register_servlets, - admin.register_servlets, ] def default_config(self) -> Dict[str, Any]: @@ -635,6 +636,535 @@ class MSC3861OAuthDelegation(HomeserverTestCase): self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.json_body) + def test_admin_token(self) -> None: + """The handler should return a requester with admin rights when admin_token is used.""" + self._set_introspection_returnvalue({"active": False}) + + request = Mock(args={}) + request.args[b"access_token"] = [b"admin_token_value"] + request.requestHeaders.getRawHeaders = mock_getRawHeaders() + requester = self.get_success(self.auth.get_user_by_req(request)) + self.assertEqual( + requester.user.to_string(), + OIDC_ADMIN_USERID, + ) + self.assertEqual(requester.is_guest, False) + self.assertEqual(requester.device_id, None) + self.assertEqual( + get_awaitable_result(self.auth.is_server_admin(requester)), True + ) + + # There should be no call to the introspection endpoint + self._rust_client.post.assert_not_called() + + @override_config({"mau_stats_only": True}) + def test_request_tracking(self) -> None: + """Using an access token should update the client_ips and MAU tables.""" + # To start, there are no MAU users. 
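+ # `mau_stats_only` keeps monthly-active-user tracking enabled for stats without
+ # enforcing an MAU limit, so the requests in this test should be recorded in the
+ # MAU table but never rejected.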
+ store = self.hs.get_datastores().main + mau = self.get_success(store.get_monthly_active_count()) + self.assertEqual(mau, 0) + + known_token = "token-token-GOOD-:)" + + async def mock_http_client_request( + url: str, request_body: str, **kwargs: Any + ) -> bytes: + """Mocked auth provider response.""" + token = parse_qs(request_body)["token"][0] + if token == known_token: + return json.dumps( + { + "active": True, + "scope": MATRIX_USER_SCOPE, + "sub": SUBJECT, + "username": USERNAME, + }, + ).encode("utf-8") + + return json.dumps({"active": False}).encode("utf-8") + + self._rust_client.post = mock_http_client_request + + EXAMPLE_IPV4_ADDR = "123.123.123.123" + EXAMPLE_USER_AGENT = "httprettygood" + + # First test a known access token + channel = FakeChannel(self.site, self.reactor) + # type-ignore: FakeChannel is a mock of an HTTPChannel, not a proper HTTPChannel + req = SynapseRequest(channel, self.site, self.hs.hostname) # type: ignore[arg-type] + req.client.host = EXAMPLE_IPV4_ADDR + req.requestHeaders.addRawHeader("Authorization", f"Bearer {known_token}") + req.requestHeaders.addRawHeader("User-Agent", EXAMPLE_USER_AGENT) + req.content = BytesIO(b"") + req.requestReceived( + b"GET", + b"/_matrix/client/v3/account/whoami", + b"1.1", + ) + channel.await_result() + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) + self.assertEqual(channel.json_body["user_id"], USER_ID, channel.json_body) + + # Expect to see one MAU entry, from the first request + mau = self.get_success(store.get_monthly_active_count()) + self.assertEqual(mau, 1) + + conn_infos = self.get_success( + store.get_user_ip_and_agents(UserID.from_string(USER_ID)) + ) + self.assertEqual(len(conn_infos), 1, conn_infos) + conn_info = conn_infos[0] + self.assertEqual(conn_info["access_token"], known_token) + self.assertEqual(conn_info["ip"], EXAMPLE_IPV4_ADDR) + self.assertEqual(conn_info["user_agent"], EXAMPLE_USER_AGENT) + + # Now test MAS making a request using the special __oidc_admin token + MAS_IPV4_ADDR = "127.0.0.1" + MAS_USER_AGENT = "masmasmas" + + channel = FakeChannel(self.site, self.reactor) + req = SynapseRequest(channel, self.site, self.hs.hostname) # type: ignore[arg-type] + req.client.host = MAS_IPV4_ADDR + req.requestHeaders.addRawHeader( + "Authorization", f"Bearer {self.auth._admin_token()}" + ) + req.requestHeaders.addRawHeader("User-Agent", MAS_USER_AGENT) + req.content = BytesIO(b"") + req.requestReceived( + b"GET", + b"/_matrix/client/v3/account/whoami", + b"1.1", + ) + channel.await_result() + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) + self.assertEqual( + channel.json_body["user_id"], OIDC_ADMIN_USERID, channel.json_body + ) + + # Still expect to see one MAU entry, from the first request + mau = self.get_success(store.get_monthly_active_count()) + self.assertEqual(mau, 1) + + conn_infos = self.get_success( + store.get_user_ip_and_agents(UserID.from_string(OIDC_ADMIN_USERID)) + ) + self.assertEqual(conn_infos, []) + + +class FakeMasHandler(BaseHTTPRequestHandler): + server: "FakeMasServer" + + def do_POST(self) -> None: + self.server.calls += 1 + + if self.path != "/oauth2/introspect": + self.send_response(404) + self.end_headers() + self.wfile.close() + return + + auth = self.headers.get("Authorization") + if auth is None or auth != f"Bearer {self.server.secret}": + self.send_response(401) + self.end_headers() + self.wfile.close() + return + + content_length = self.headers.get("Content-Length") + if content_length is None: + self.send_response(400) + 
self.end_headers() + self.wfile.close() + return + + raw_body = self.rfile.read(int(content_length)) + body = parse_qs(raw_body) + param = body.get(b"token") + if param is None: + self.send_response(400) + self.end_headers() + self.wfile.close() + return + + self.server.last_token_seen = param[0].decode("utf-8") + + self.send_response(200) + self.send_header("Content-Type", "application/json") + self.end_headers() + self.wfile.write(json.dumps(self.server.introspection_response).encode("utf-8")) + + def log_message(self, format: str, *args: Any) -> None: + # Don't log anything; by default, the server logs to stderr + pass + + +class FakeMasServer(HTTPServer): + """A fake MAS server for testing. + + This opens a real HTTP server on a random port, on a separate thread. + """ + + introspection_response: JsonDict = {} + """Determines what the response to the introspection endpoint will be.""" + + secret: str = "verysecret" + """The shared secret used to authenticate the introspection endpoint.""" + + last_token_seen: Optional[str] = None + """What is the last access token seen by the introspection endpoint.""" + + calls: int = 0 + """How many times has the introspection endpoint been called.""" + + _thread: threading.Thread + + def __init__(self) -> None: + super().__init__(("127.0.0.1", 0), FakeMasHandler) + + self._thread = threading.Thread( + target=self.serve_forever, + name="FakeMasServer", + kwargs={"poll_interval": 0.01}, + daemon=True, + ) + self._thread.start() + + def shutdown(self) -> None: + super().shutdown() + self._thread.join() + + @property + def endpoint(self) -> str: + return f"http://127.0.0.1:{self.server_port}/" + + +T = TypeVar("T") + + +class MasAuthDelegation(HomeserverTestCase): + server: FakeMasServer + + def till_deferred_has_result( + self, + awaitable: Union[ + "Coroutine[Deferred[Any], Any, T]", + "Generator[Deferred[Any], Any, T]", + "Deferred[T]", + ], + ) -> "Deferred[T]": + """Wait until a deferred has a result. + + This is useful because the Rust HTTP client will resolve the deferred + using reactor.callFromThread, which are only run when we call + reactor.advance. + """ + deferred = ensureDeferred(awaitable) + tries = 0 + while not deferred.called: + time.sleep(0.1) + self.reactor.advance(0) + tries += 1 + if tries > 100: + raise Exception("Timed out waiting for deferred to resolve") + + return deferred + + def default_config(self) -> Dict[str, Any]: + config = super().default_config() + config["public_baseurl"] = BASE_URL + config["disable_registration"] = True + config["matrix_authentication_service"] = { + "enabled": True, + "endpoint": self.server.endpoint, + "secret": self.server.secret, + } + return config + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + self.server = FakeMasServer() + hs = self.setup_test_homeserver() + # This triggers the server startup hooks, which starts the Tokio thread pool + reactor.run() + self._auth = checked_cast(MasDelegatedAuth, hs.get_auth()) + return hs + + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: + # Provision the user and the device we use in the tests. 
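+ # The introspection responses below reference USER_ID and DEVICE, so both must
+ # already exist locally for token validation to succeed (see
+ # test_inexistent_device / test_inexistent_user for the failure cases).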
+ store = homeserver.get_datastores().main + self.get_success(store.register_user(USER_ID)) + self.get_success( + store.store_device(USER_ID, DEVICE, initial_device_display_name=None) + ) + + def tearDown(self) -> None: + self.server.shutdown() + # MemoryReactor doesn't trigger the shutdown phases, and we want the + # Tokio thread pool to be stopped + # XXX: This logic should probably get moved somewhere else + shutdown_triggers = self.reactor.triggers.get("shutdown", {}) + for phase in ["before", "during", "after"]: + triggers = shutdown_triggers.get(phase, []) + for callback, args, kwargs in triggers: + callback(*args, **kwargs) + + def test_simple_introspection(self) -> None: + self.server.introspection_response = { + "active": True, + "sub": SUBJECT, + "scope": " ".join( + [ + MATRIX_USER_SCOPE, + f"{MATRIX_DEVICE_SCOPE_PREFIX}{DEVICE}", + ] + ), + "username": USERNAME, + "expires_in": 60, + } + + requester = self.get_success( + self.till_deferred_has_result( + self._auth.get_user_by_access_token("some_token") + ) + ) + + self.assertEqual(requester.user.to_string(), USER_ID) + self.assertEqual(requester.device_id, DEVICE) + self.assertFalse(self.get_success(self._auth.is_server_admin(requester))) + + self.assertEqual( + self.server.last_token_seen, + "some_token", + ) + + def test_unexpiring_token(self) -> None: + self.server.introspection_response = { + "active": True, + "sub": SUBJECT, + "scope": " ".join( + [ + MATRIX_USER_SCOPE, + f"{MATRIX_DEVICE_SCOPE_PREFIX}{DEVICE}", + ] + ), + "username": USERNAME, + } + + requester = self.get_success( + self.till_deferred_has_result( + self._auth.get_user_by_access_token("some_token") + ) + ) + + self.assertEqual(requester.user.to_string(), USER_ID) + self.assertEqual(requester.device_id, DEVICE) + self.assertFalse(self.get_success(self._auth.is_server_admin(requester))) + + self.assertEqual( + self.server.last_token_seen, + "some_token", + ) + + def test_inexistent_device(self) -> None: + self.server.introspection_response = { + "active": True, + "sub": SUBJECT, + "scope": " ".join( + [ + MATRIX_USER_SCOPE, + f"{MATRIX_DEVICE_SCOPE_PREFIX}ABCDEF", + ] + ), + "username": USERNAME, + "expires_in": 60, + } + + failure = self.get_failure( + self.till_deferred_has_result( + self._auth.get_user_by_access_token("some_token") + ), + InvalidClientTokenError, + ) + self.assertEqual(failure.value.code, 401) + + def test_inexistent_user(self) -> None: + self.server.introspection_response = { + "active": True, + "sub": SUBJECT, + "scope": " ".join([MATRIX_USER_SCOPE]), + "username": "inexistent_user", + "expires_in": 60, + } + + failure = self.get_failure( + self.till_deferred_has_result( + self._auth.get_user_by_access_token("some_token") + ), + AuthError, + ) + # This is a 500; it should never happen in practice + self.assertEqual(failure.value.code, 500) + + def test_missing_scope(self) -> None: + self.server.introspection_response = { + "active": True, + "sub": SUBJECT, + "scope": "openid", + "username": USERNAME, + "expires_in": 60, + } + + failure = self.get_failure( + self.till_deferred_has_result( + self._auth.get_user_by_access_token("some_token") + ), + InvalidClientTokenError, + ) + self.assertEqual(failure.value.code, 401) + + def test_invalid_response(self) -> None: + self.server.introspection_response = {} + + failure = self.get_failure( + self.till_deferred_has_result( + self._auth.get_user_by_access_token("some_token") + ), + SynapseError, + ) + self.assertEqual(failure.value.code, 503) + + def test_device_id_in_body(self) -> None: + 
self.server.introspection_response = { + "active": True, + "sub": SUBJECT, + "scope": MATRIX_USER_SCOPE, + "username": USERNAME, + "expires_in": 60, + "device_id": DEVICE, + } + + requester = self.get_success( + self.till_deferred_has_result( + self._auth.get_user_by_access_token("some_token") + ) + ) + + self.assertEqual(requester.device_id, DEVICE) + + def test_admin_scope(self) -> None: + self.server.introspection_response = { + "active": True, + "sub": SUBJECT, + "scope": " ".join([SYNAPSE_ADMIN_SCOPE, MATRIX_USER_SCOPE]), + "username": USERNAME, + "expires_in": 60, + } + + requester = self.get_success( + self.till_deferred_has_result( + self._auth.get_user_by_access_token("some_token") + ) + ) + + self.assertEqual(requester.user.to_string(), USER_ID) + self.assertTrue(self.get_success(self._auth.is_server_admin(requester))) + + def test_cached_expired_introspection(self) -> None: + """The handler should raise an error if the introspection response gives + an expiry time, the introspection response is cached and then the entry is + re-requested after it has expired.""" + + self.server.introspection_response = { + "active": True, + "sub": SUBJECT, + "scope": " ".join( + [ + MATRIX_USER_SCOPE, + f"{MATRIX_DEVICE_SCOPE_PREFIX}{DEVICE}", + ] + ), + "username": USERNAME, + "expires_in": 60, + } + + self.assertEqual(self.server.calls, 0) + + request = Mock(args={}) + request.args[b"access_token"] = [b"some_token"] + request.requestHeaders.getRawHeaders = mock_getRawHeaders() + + # The first CS-API request causes a successful introspection + self.get_success( + self.till_deferred_has_result(self._auth.get_user_by_req(request)) + ) + self.assertEqual(self.server.calls, 1) + + # Sleep for 60 seconds so the token expires. + self.reactor.advance(60.0) + + # Now the CS-API request fails because the token expired + self.assertFailure( + self.till_deferred_has_result(self._auth.get_user_by_req(request)), + InvalidClientTokenError, + ) + # Ensure another introspection request was not sent + self.assertEqual(self.server.calls, 1) + + +@parameterized_class( + ("config",), + [ + ( + { + "matrix_authentication_service": { + "enabled": True, + "endpoint": "http://localhost:1234/", + "secret": "secret", + }, + }, + ), + ] + # Run the tests with experimental delegation only if authlib is available + + [ + ( + { + "experimental_features": { + "msc3861": { + "enabled": True, + "issuer": ISSUER, + "client_id": CLIENT_ID, + "client_auth_method": "client_secret_post", + "client_secret": CLIENT_SECRET, + "admin_token": "admin_token_value", + } + } + }, + ), + ] + * HAS_AUTHLIB, +) +class DisabledEndpointsTestCase(HomeserverTestCase): + servlets = [ + account.register_servlets, + devices.register_servlets, + keys.register_servlets, + register.register_servlets, + login.register_servlets, + logout.register_servlets, + admin.register_servlets, + ] + + config: Dict[str, Any] + + def default_config(self) -> Dict[str, Any]: + config = super().default_config() + config["public_baseurl"] = BASE_URL + config["disable_registration"] = True + config.update(self.config) + return config + def expect_unauthorized( self, method: str, path: str, content: Union[bytes, str, JsonDict] = "" ) -> None: @@ -774,13 +1304,11 @@ class MSC3861OAuthDelegation(HomeserverTestCase): # Because we still support those endpoints with ASes, it checks the # access token before returning 404 - self._set_introspection_returnvalue( - { - "active": True, - "sub": SUBJECT, - "scope": " ".join([MATRIX_USER_SCOPE, MATRIX_DEVICE_SCOPE]), - "username": 
USERNAME, - }, + self.hs.get_auth().get_user_by_req = AsyncMock( # type: ignore[method-assign] + return_value=create_requester( + user_id=USER_ID, + device_id=DEVICE, + ) ) self.expect_unrecognized("POST", "/_matrix/client/v3/delete_devices", auth=True) @@ -810,118 +1338,3 @@ class MSC3861OAuthDelegation(HomeserverTestCase): self.expect_unrecognized("GET", "/_synapse/admin/v1/users/foo/admin") self.expect_unrecognized("PUT", "/_synapse/admin/v1/users/foo/admin") self.expect_unrecognized("POST", "/_synapse/admin/v1/account_validity/validity") - - def test_admin_token(self) -> None: - """The handler should return a requester with admin rights when admin_token is used.""" - self._set_introspection_returnvalue({"active": False}) - - request = Mock(args={}) - request.args[b"access_token"] = [b"admin_token_value"] - request.requestHeaders.getRawHeaders = mock_getRawHeaders() - requester = self.get_success(self.auth.get_user_by_req(request)) - self.assertEqual( - requester.user.to_string(), - OIDC_ADMIN_USERID, - ) - self.assertEqual(requester.is_guest, False) - self.assertEqual(requester.device_id, None) - self.assertEqual( - get_awaitable_result(self.auth.is_server_admin(requester)), True - ) - - # There should be no call to the introspection endpoint - self._rust_client.post.assert_not_called() - - @override_config({"mau_stats_only": True}) - def test_request_tracking(self) -> None: - """Using an access token should update the client_ips and MAU tables.""" - # To start, there are no MAU users. - store = self.hs.get_datastores().main - mau = self.get_success(store.get_monthly_active_count()) - self.assertEqual(mau, 0) - - known_token = "token-token-GOOD-:)" - - async def mock_http_client_request( - url: str, request_body: str, **kwargs: Any - ) -> bytes: - """Mocked auth provider response.""" - token = parse_qs(request_body)["token"][0] - if token == known_token: - return json.dumps( - { - "active": True, - "scope": MATRIX_USER_SCOPE, - "sub": SUBJECT, - "username": USERNAME, - }, - ).encode("utf-8") - - return json.dumps({"active": False}).encode("utf-8") - - self._rust_client.post = mock_http_client_request - - EXAMPLE_IPV4_ADDR = "123.123.123.123" - EXAMPLE_USER_AGENT = "httprettygood" - - # First test a known access token - channel = FakeChannel(self.site, self.reactor) - # type-ignore: FakeChannel is a mock of an HTTPChannel, not a proper HTTPChannel - req = SynapseRequest(channel, self.site) # type: ignore[arg-type] - req.client.host = EXAMPLE_IPV4_ADDR - req.requestHeaders.addRawHeader("Authorization", f"Bearer {known_token}") - req.requestHeaders.addRawHeader("User-Agent", EXAMPLE_USER_AGENT) - req.content = BytesIO(b"") - req.requestReceived( - b"GET", - b"/_matrix/client/v3/account/whoami", - b"1.1", - ) - channel.await_result() - self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) - self.assertEqual(channel.json_body["user_id"], USER_ID, channel.json_body) - - # Expect to see one MAU entry, from the first request - mau = self.get_success(store.get_monthly_active_count()) - self.assertEqual(mau, 1) - - conn_infos = self.get_success( - store.get_user_ip_and_agents(UserID.from_string(USER_ID)) - ) - self.assertEqual(len(conn_infos), 1, conn_infos) - conn_info = conn_infos[0] - self.assertEqual(conn_info["access_token"], known_token) - self.assertEqual(conn_info["ip"], EXAMPLE_IPV4_ADDR) - self.assertEqual(conn_info["user_agent"], EXAMPLE_USER_AGENT) - - # Now test MAS making a request using the special __oidc_admin token - MAS_IPV4_ADDR = "127.0.0.1" - MAS_USER_AGENT = 
"masmasmas" - - channel = FakeChannel(self.site, self.reactor) - req = SynapseRequest(channel, self.site) # type: ignore[arg-type] - req.client.host = MAS_IPV4_ADDR - req.requestHeaders.addRawHeader( - "Authorization", f"Bearer {self.auth._admin_token()}" - ) - req.requestHeaders.addRawHeader("User-Agent", MAS_USER_AGENT) - req.content = BytesIO(b"") - req.requestReceived( - b"GET", - b"/_matrix/client/v3/account/whoami", - b"1.1", - ) - channel.await_result() - self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) - self.assertEqual( - channel.json_body["user_id"], OIDC_ADMIN_USERID, channel.json_body - ) - - # Still expect to see one MAU entry, from the first request - mau = self.get_success(store.get_monthly_active_count()) - self.assertEqual(mau, 1) - - conn_infos = self.get_success( - store.get_user_ip_and_agents(UserID.from_string(OIDC_ADMIN_USERID)) - ) - self.assertEqual(conn_infos, []) diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index ff8e3c5cb6..db37e7d185 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -25,7 +25,7 @@ from urllib.parse import parse_qs, urlparse import pymacaroons -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.handlers.sso import MappingException from synapse.http.site import SynapseRequest diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py index ed203eb299..0a78fe0304 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py @@ -25,7 +25,7 @@ from http import HTTPStatus from typing import Any, Dict, List, Optional, Type, Union from unittest.mock import AsyncMock, Mock -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor import synapse from synapse.api.constants import LoginType diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 6b7bf112c2..51b6c60531 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -29,7 +29,7 @@ from signedjson.key import ( get_verify_key, ) -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.constants import EventTypes, Membership, PresenceState from synapse.api.presence import UserDevicePresenceState, UserPresenceState @@ -90,6 +90,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): prev_state, new_state, is_mine=True, + our_server_name=self.hs.hostname, wheel_timer=wheel_timer, now=now, persist=False, @@ -137,6 +138,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): prev_state, new_state, is_mine=True, + our_server_name=self.hs.hostname, wheel_timer=wheel_timer, now=now, persist=False, @@ -187,6 +189,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): prev_state, new_state, is_mine=True, + our_server_name=self.hs.hostname, wheel_timer=wheel_timer, now=now, persist=False, @@ -235,6 +238,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): prev_state, new_state, is_mine=True, + our_server_name=self.hs.hostname, wheel_timer=wheel_timer, now=now, persist=False, @@ -275,6 +279,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): prev_state, new_state, is_mine=False, + our_server_name=self.hs.hostname, wheel_timer=wheel_timer, now=now, persist=False, @@ -314,6 +319,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): prev_state, new_state, is_mine=True, + 
our_server_name=self.hs.hostname, wheel_timer=wheel_timer, now=now, persist=False, @@ -341,6 +347,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): prev_state, new_state, is_mine=True, + our_server_name=self.hs.hostname, wheel_timer=wheel_timer, now=now, persist=False, @@ -431,6 +438,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): prev_state, new_state, is_mine=True, + our_server_name=self.hs.hostname, wheel_timer=wheel_timer, now=now, persist=True, @@ -494,6 +502,7 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): prev_state, new_state, is_mine=True, + our_server_name=self.hs.hostname, wheel_timer=wheel_timer, now=now, persist=False, diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 2b9b56da95..93934e9ff7 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -23,7 +23,7 @@ from unittest.mock import AsyncMock, Mock from parameterized import parameterized -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor import synapse.types from synapse.api.errors import AuthError, SynapseError diff --git a/tests/handlers/test_receipts.py b/tests/handlers/test_receipts.py index 7c5bec2b76..cf04ac6e00 100644 --- a/tests/handlers/test_receipts.py +++ b/tests/handlers/test_receipts.py @@ -22,7 +22,7 @@ from copy import deepcopy from typing import List -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.constants import EduTypes, ReceiptTypes from synapse.server import HomeServer diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 99bd0de834..43ded2fc10 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -22,7 +22,7 @@ from typing import Any, Collection, List, Optional, Tuple from unittest.mock import AsyncMock, Mock -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.auth.internal import InternalAuth from synapse.api.constants import UserTypes diff --git a/tests/handlers/test_room_member.py b/tests/handlers/test_room_member.py index d87fe9d62c..3084f180f5 100644 --- a/tests/handlers/test_room_member.py +++ b/tests/handlers/test_room_member.py @@ -1,6 +1,6 @@ from unittest.mock import AsyncMock, patch -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor import synapse.rest.admin import synapse.rest.client.login diff --git a/tests/handlers/test_room_policy.py b/tests/handlers/test_room_policy.py index 26642c18ea..3ea6f13cce 100644 --- a/tests/handlers/test_room_policy.py +++ b/tests/handlers/test_room_policy.py @@ -15,7 +15,7 @@ from typing import Optional from unittest import mock -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.events import EventBase, make_event_from_dict from synapse.rest import admin diff --git a/tests/handlers/test_room_summary.py b/tests/handlers/test_room_summary.py index bf18c1e72a..27646d7365 100644 --- a/tests/handlers/test_room_summary.py +++ b/tests/handlers/test_room_summary.py @@ -22,7 +22,7 @@ from typing import Any, Dict, Iterable, List, Optional, Set, Tuple from unittest import mock from twisted.internet.defer import ensureDeferred -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.constants import ( EventContentFields, @@ 
-45,6 +45,7 @@ from synapse.types import JsonDict, UserID, create_requester from synapse.util import Clock from tests import unittest +from tests.unittest import override_config def _create_event( @@ -245,6 +246,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): ) self._assert_hierarchy(result, expected) + @override_config({"rc_room_creation": {"burst_count": 1000, "per_second": 1}}) def test_large_space(self) -> None: """Test a space with a large number of rooms.""" rooms = [self.room] @@ -527,6 +529,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): ) self._assert_hierarchy(result, expected) + @override_config({"rc_room_creation": {"burst_count": 1000, "per_second": 1}}) def test_pagination(self) -> None: """Test simple pagination works.""" room_ids = [] @@ -564,6 +567,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): self._assert_hierarchy(result, expected) self.assertNotIn("next_batch", result) + @override_config({"rc_room_creation": {"burst_count": 1000, "per_second": 1}}) def test_invalid_pagination_token(self) -> None: """An invalid pagination token, or changing other parameters, shoudl be rejected.""" room_ids = [] @@ -615,6 +619,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): SynapseError, ) + @override_config({"rc_room_creation": {"burst_count": 1000, "per_second": 1}}) def test_max_depth(self) -> None: """Create a deep tree to test the max depth against.""" spaces = [self.space] diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py index 1aca354826..98a4276a3a 100644 --- a/tests/handlers/test_saml.py +++ b/tests/handlers/test_saml.py @@ -24,7 +24,7 @@ from unittest.mock import AsyncMock, Mock import attr -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.errors import RedirectException from synapse.module_api import ModuleApi diff --git a/tests/handlers/test_sliding_sync.py b/tests/handlers/test_sliding_sync.py index 7144c58217..8c390f0c57 100644 --- a/tests/handlers/test_sliding_sync.py +++ b/tests/handlers/test_sliding_sync.py @@ -24,7 +24,7 @@ from unittest.mock import patch import attr from parameterized import parameterized, parameterized_class -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.constants import ( EventTypes, diff --git a/tests/handlers/test_sso.py b/tests/handlers/test_sso.py index 25e9130aaf..896e4fac9a 100644 --- a/tests/handlers/test_sso.py +++ b/tests/handlers/test_sso.py @@ -21,7 +21,7 @@ from http import HTTPStatus from typing import BinaryIO, Callable, Dict, List, Optional, Tuple from unittest.mock import Mock -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from twisted.web.http_headers import Headers from synapse.api.errors import Codes, SynapseError diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index bdb6fdb120..cd17cd86e0 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -20,7 +20,7 @@ from typing import Any, Dict, List, Optional, Tuple, cast -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.rest import admin from synapse.rest.client import login, room diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index 6b202dfbd5..cea61bed6a 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -24,7 +24,7 @@ from unittest.mock 
import AsyncMock, Mock, patch from parameterized import parameterized, parameterized_class from twisted.internet import defer -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.constants import AccountDataTypes, EventTypes, JoinRules from synapse.api.errors import Codes, ResourceLimitError diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 1126b6f183..614b12c62a 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -26,7 +26,7 @@ from unittest.mock import ANY, AsyncMock, Mock, call from netaddr import IPSet -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from twisted.web.resource import Resource from synapse.api.constants import EduTypes @@ -92,6 +92,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): user_agent=b"SynapseInTrialTest/0.0.0", ip_allowlist=None, ip_blocklist=IPSet(), + proxy_config=None, ) # the tests assume that we are starting at unix time 1000 diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 0da423142c..7458fe0885 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -21,7 +21,7 @@ from typing import Any, Tuple from unittest.mock import AsyncMock, Mock, patch from urllib.parse import quote -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor import synapse.rest.admin from synapse.api.constants import UserTypes diff --git a/tests/handlers/test_worker_lock.py b/tests/handlers/test_worker_lock.py index 968d119a50..3d3904eac7 100644 --- a/tests/handlers/test_worker_lock.py +++ b/tests/handlers/test_worker_lock.py @@ -23,7 +23,7 @@ import logging import platform from twisted.internet import defer -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.server import HomeServer from synapse.util import Clock diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index a1243b053d..12428e64a9 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -45,6 +45,7 @@ from twisted.web.http_headers import Headers from twisted.web.iweb import IPolicyForHTTPS, IResponse from synapse.config.homeserver import HomeServerConfig +from synapse.config.server import parse_proxy_config from synapse.crypto.context_factory import FederationPolicyForHTTPS from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent from synapse.http.federation.srv_resolver import Server, SrvResolver @@ -280,6 +281,7 @@ class MatrixFederationAgentTests(unittest.TestCase): user_agent=b"test-agent", # Note that this is unused since _well_known_resolver is provided. ip_allowlist=IPSet(), ip_blocklist=IPSet(), + proxy_config=parse_proxy_config({}), _srv_resolver=self.mock_resolver, _well_known_resolver=self.well_known_resolver, ) @@ -1023,6 +1025,7 @@ class MatrixFederationAgentTests(unittest.TestCase): user_agent=b"test-agent", # This is unused since _well_known_resolver is passed below. 
ip_allowlist=IPSet(), ip_blocklist=IPSet(), + proxy_config=None, _srv_resolver=self.mock_resolver, _well_known_resolver=WellKnownResolver( server_name="OUR_STUB_HOMESERVER_NAME", diff --git a/tests/http/server/_base.py b/tests/http/server/_base.py index dff5a5d262..393f3ab0bd 100644 --- a/tests/http/server/_base.py +++ b/tests/http/server/_base.py @@ -40,8 +40,8 @@ from unittest.mock import Mock from twisted.internet.defer import Deferred from twisted.internet.error import ConnectionDone +from twisted.internet.testing import MemoryReactorClock from twisted.python.failure import Failure -from twisted.test.proto_helpers import MemoryReactorClock from twisted.web.server import Site from synapse.http.server import ( diff --git a/tests/http/test_client.py b/tests/http/test_client.py index ac6470ebbd..a02f6fc728 100644 --- a/tests/http/test_client.py +++ b/tests/http/test_client.py @@ -27,8 +27,8 @@ from netaddr import IPSet from twisted.internet.defer import Deferred from twisted.internet.error import DNSLookupError +from twisted.internet.testing import AccumulatingProtocol from twisted.python.failure import Failure -from twisted.test.proto_helpers import AccumulatingProtocol from twisted.web.client import Agent, ResponseDone from twisted.web.iweb import UNKNOWN_LENGTH diff --git a/tests/http/test_matrixfederationclient.py b/tests/http/test_matrixfederationclient.py index d5ebf10eac..224883b635 100644 --- a/tests/http/test_matrixfederationclient.py +++ b/tests/http/test_matrixfederationclient.py @@ -27,7 +27,7 @@ from parameterized import parameterized from twisted.internet import defer from twisted.internet.defer import Deferred, TimeoutError from twisted.internet.error import ConnectingCancelledError, DNSLookupError -from twisted.test.proto_helpers import MemoryReactor, StringTransport +from twisted.internet.testing import MemoryReactor, StringTransport from twisted.web.client import Agent, ResponseNeverReceived from twisted.web.http import HTTPChannel from twisted.web.http_headers import Headers diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py index 2ef8a95c45..5bc5d18d81 100644 --- a/tests/http/test_proxyagent.py +++ b/tests/http/test_proxyagent.py @@ -39,6 +39,7 @@ from twisted.internet.protocol import Factory, Protocol from twisted.protocols.tls import TLSMemoryBIOProtocol from twisted.web.http import HTTPChannel +from synapse.config.server import ProxyConfig, parse_proxy_config from synapse.http.client import BlocklistingReactorWrapper from synapse.http.connectproxyclient import BasicProxyCredentials from synapse.http.proxyagent import ProxyAgent, parse_proxy @@ -241,7 +242,7 @@ class TestBasicProxyCredentials(TestCase): ) -class MatrixFederationAgentTests(TestCase): +class ProxyAgentTests(TestCase): def setUp(self) -> None: self.reactor = ThreadedMemoryReactorClock() @@ -379,27 +380,40 @@ class MatrixFederationAgentTests(TestCase): self.assertEqual(body, b"result") def test_http_request(self) -> None: - agent = ProxyAgent(self.reactor) + agent = ProxyAgent(reactor=self.reactor) self._test_request_direct_connection(agent, b"http", b"test.com", b"") def test_https_request(self) -> None: - agent = ProxyAgent(self.reactor, contextFactory=get_test_https_policy()) + agent = ProxyAgent(reactor=self.reactor, contextFactory=get_test_https_policy()) self._test_request_direct_connection(agent, b"https", b"test.com", b"abc") - def test_http_request_use_proxy_empty_environment(self) -> None: - agent = ProxyAgent(self.reactor, use_proxy=True) + def 
test_http_request_proxy_config_empty_environment(self) -> None: + agent = ProxyAgent( + reactor=self.reactor, + proxy_config=parse_proxy_config({}), + ) self._test_request_direct_connection(agent, b"http", b"test.com", b"") @patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "NO_PROXY": "test.com"}) def test_http_request_via_uppercase_no_proxy(self) -> None: - agent = ProxyAgent(self.reactor, use_proxy=True) + """ + Ensure hosts listed in the NO_PROXY environment variable are not sent via the + proxy. + """ + agent = ProxyAgent( + reactor=self.reactor, + proxy_config=parse_proxy_config({}), + ) self._test_request_direct_connection(agent, b"http", b"test.com", b"") @patch.dict( os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "test.com,unused.com"} ) def test_http_request_via_no_proxy(self) -> None: - agent = ProxyAgent(self.reactor, use_proxy=True) + agent = ProxyAgent( + reactor=self.reactor, + proxy_config=parse_proxy_config({}), + ) self._test_request_direct_connection(agent, b"http", b"test.com", b"") @patch.dict( @@ -407,23 +421,26 @@ class MatrixFederationAgentTests(TestCase): ) def test_https_request_via_no_proxy(self) -> None: agent = ProxyAgent( - self.reactor, + reactor=self.reactor, contextFactory=get_test_https_policy(), - use_proxy=True, + proxy_config=parse_proxy_config({}), ) self._test_request_direct_connection(agent, b"https", b"test.com", b"abc") @patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "*"}) def test_http_request_via_no_proxy_star(self) -> None: - agent = ProxyAgent(self.reactor, use_proxy=True) + agent = ProxyAgent( + reactor=self.reactor, + proxy_config=parse_proxy_config({}), + ) self._test_request_direct_connection(agent, b"http", b"test.com", b"") @patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "*"}) def test_https_request_via_no_proxy_star(self) -> None: agent = ProxyAgent( - self.reactor, + reactor=self.reactor, contextFactory=get_test_https_policy(), - use_proxy=True, + proxy_config=parse_proxy_config({}), ) self._test_request_direct_connection(agent, b"https", b"test.com", b"abc") @@ -433,9 +450,72 @@ class MatrixFederationAgentTests(TestCase): Tests that requests can be made through a proxy. 
""" self._do_http_request_via_proxy( - expect_proxy_ssl=False, expected_auth_credentials=None + proxy_config=parse_proxy_config({}), + expect_proxy_ssl=False, + expected_auth_credentials=None, ) + def test_given_http_proxy_config(self) -> None: + self._do_http_request_via_proxy( + proxy_config=parse_proxy_config({"http_proxy": "proxy.com:8888"}), + expect_proxy_ssl=False, + expected_auth_credentials=None, + ) + + def test_given_https_proxy_config(self) -> None: + self._do_https_request_via_proxy( + proxy_config=parse_proxy_config({"https_proxy": "proxy.com"}), + expect_proxy_ssl=False, + expected_auth_credentials=None, + ) + + def test_given_no_proxy_hosts_config(self) -> None: + agent = ProxyAgent( + reactor=self.reactor, + proxy_config=parse_proxy_config( + {"http_proxy": "proxy.com:8888", "no_proxy_hosts": ["test.com"]} + ), + ) + self._test_request_direct_connection(agent, b"http", b"test.com", b"") + + @patch.dict( + os.environ, + {"http_proxy": "unused.com", "no_proxy": "unused.com"}, + ) + def test_given_http_proxy_config_overrides_environment_config(self) -> None: + """Tests that the given `http_proxy` in file config overrides the environment config.""" + self._do_http_request_via_proxy( + proxy_config=parse_proxy_config({"http_proxy": "proxy.com:8888"}), + expect_proxy_ssl=False, + expected_auth_credentials=None, + ) + + @patch.dict( + os.environ, + {"https_proxy": "unused.com", "no_proxy": "unused.com"}, + ) + def test_given_https_proxy_config_overrides_environment_config(self) -> None: + """Tests that the given `https_proxy` in file config overrides the environment config.""" + self._do_https_request_via_proxy( + proxy_config=parse_proxy_config({"https_proxy": "proxy.com"}), + expect_proxy_ssl=False, + expected_auth_credentials=None, + ) + + @patch.dict( + os.environ, + {"https_proxy": "unused.com", "no_proxy": "unused.com"}, + ) + def test_given_no_proxy_config_overrides_environment_config(self) -> None: + """Tests that the given `no_proxy_hosts` in file config overrides the `no_proxy` environment config.""" + agent = ProxyAgent( + reactor=self.reactor, + proxy_config=parse_proxy_config( + {"http_proxy": "proxy.com:8888", "no_proxy_hosts": ["test.com"]} + ), + ) + self._test_request_direct_connection(agent, b"http", b"test.com", b"") + @patch.dict( os.environ, {"http_proxy": "bob:pinkponies@proxy.com:8888", "no_proxy": "unused.com"}, @@ -445,7 +525,9 @@ class MatrixFederationAgentTests(TestCase): Tests that authenticated requests can be made through a proxy. 
""" self._do_http_request_via_proxy( - expect_proxy_ssl=False, expected_auth_credentials=b"bob:pinkponies" + proxy_config=parse_proxy_config({}), + expect_proxy_ssl=False, + expected_auth_credentials=b"bob:pinkponies", ) @patch.dict( @@ -453,7 +535,9 @@ class MatrixFederationAgentTests(TestCase): ) def test_http_request_via_https_proxy(self) -> None: self._do_http_request_via_proxy( - expect_proxy_ssl=True, expected_auth_credentials=None + proxy_config=parse_proxy_config({}), + expect_proxy_ssl=True, + expected_auth_credentials=None, ) @patch.dict( @@ -465,14 +549,18 @@ class MatrixFederationAgentTests(TestCase): ) def test_http_request_via_https_proxy_with_auth(self) -> None: self._do_http_request_via_proxy( - expect_proxy_ssl=True, expected_auth_credentials=b"bob:pinkponies" + proxy_config=parse_proxy_config({}), + expect_proxy_ssl=True, + expected_auth_credentials=b"bob:pinkponies", ) @patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "unused.com"}) def test_https_request_via_proxy(self) -> None: """Tests that TLS-encrypted requests can be made through a proxy""" self._do_https_request_via_proxy( - expect_proxy_ssl=False, expected_auth_credentials=None + proxy_config=parse_proxy_config({}), + expect_proxy_ssl=False, + expected_auth_credentials=None, ) @patch.dict( @@ -482,7 +570,9 @@ class MatrixFederationAgentTests(TestCase): def test_https_request_via_proxy_with_auth(self) -> None: """Tests that authenticated, TLS-encrypted requests can be made through a proxy""" self._do_https_request_via_proxy( - expect_proxy_ssl=False, expected_auth_credentials=b"bob:pinkponies" + proxy_config=parse_proxy_config({}), + expect_proxy_ssl=False, + expected_auth_credentials=b"bob:pinkponies", ) @patch.dict( @@ -491,7 +581,9 @@ class MatrixFederationAgentTests(TestCase): def test_https_request_via_https_proxy(self) -> None: """Tests that TLS-encrypted requests can be made through a proxy""" self._do_https_request_via_proxy( - expect_proxy_ssl=True, expected_auth_credentials=None + proxy_config=parse_proxy_config({}), + expect_proxy_ssl=True, + expected_auth_credentials=None, ) @patch.dict( @@ -501,11 +593,14 @@ class MatrixFederationAgentTests(TestCase): def test_https_request_via_https_proxy_with_auth(self) -> None: """Tests that authenticated, TLS-encrypted requests can be made through a proxy""" self._do_https_request_via_proxy( - expect_proxy_ssl=True, expected_auth_credentials=b"bob:pinkponies" + proxy_config=parse_proxy_config({}), + expect_proxy_ssl=True, + expected_auth_credentials=b"bob:pinkponies", ) def _do_http_request_via_proxy( self, + proxy_config: ProxyConfig, expect_proxy_ssl: bool = False, expected_auth_credentials: Optional[bytes] = None, ) -> None: @@ -517,10 +612,15 @@ class MatrixFederationAgentTests(TestCase): """ if expect_proxy_ssl: agent = ProxyAgent( - self.reactor, use_proxy=True, contextFactory=get_test_https_policy() + reactor=self.reactor, + proxy_config=proxy_config, + contextFactory=get_test_https_policy(), ) else: - agent = ProxyAgent(self.reactor, use_proxy=True) + agent = ProxyAgent( + reactor=self.reactor, + proxy_config=proxy_config, + ) self.reactor.lookups["proxy.com"] = "1.2.3.5" d = agent.request(b"GET", b"http://test.com") @@ -580,6 +680,7 @@ class MatrixFederationAgentTests(TestCase): def _do_https_request_via_proxy( self, + proxy_config: ProxyConfig, expect_proxy_ssl: bool = False, expected_auth_credentials: Optional[bytes] = None, ) -> None: @@ -590,9 +691,9 @@ class MatrixFederationAgentTests(TestCase): expected_auth_credentials: credentials 
to authenticate at proxy """ agent = ProxyAgent( - self.reactor, + reactor=self.reactor, contextFactory=get_test_https_policy(), - use_proxy=True, + proxy_config=proxy_config, ) self.reactor.lookups["proxy.com"] = "1.2.3.5" @@ -713,11 +814,11 @@ class MatrixFederationAgentTests(TestCase): def test_http_request_via_proxy_with_blocklist(self) -> None: # The blocklist includes the configured proxy IP. agent = ProxyAgent( - BlocklistingReactorWrapper( + reactor=BlocklistingReactorWrapper( self.reactor, ip_allowlist=None, ip_blocklist=IPSet(["1.0.0.0/8"]) ), - self.reactor, - use_proxy=True, + proxy_reactor=self.reactor, + proxy_config=parse_proxy_config({}), ) self.reactor.lookups["proxy.com"] = "1.2.3.5" @@ -759,12 +860,12 @@ class MatrixFederationAgentTests(TestCase): def test_https_request_via_uppercase_proxy_with_blocklist(self) -> None: # The blocklist includes the configured proxy IP. agent = ProxyAgent( - BlocklistingReactorWrapper( + reactor=BlocklistingReactorWrapper( self.reactor, ip_allowlist=None, ip_blocklist=IPSet(["1.0.0.0/8"]) ), - self.reactor, + proxy_reactor=self.reactor, contextFactory=get_test_https_policy(), - use_proxy=True, + proxy_config=parse_proxy_config({}), ) self.reactor.lookups["proxy.com"] = "1.2.3.5" @@ -852,7 +953,10 @@ class MatrixFederationAgentTests(TestCase): @patch.dict(os.environ, {"http_proxy": "proxy.com:8888"}) def test_proxy_with_no_scheme(self) -> None: - http_proxy_agent = ProxyAgent(self.reactor, use_proxy=True) + http_proxy_agent = ProxyAgent( + reactor=self.reactor, + proxy_config=parse_proxy_config({}), + ) proxy_ep = checked_cast(HostnameEndpoint, http_proxy_agent.http_proxy_endpoint) self.assertEqual(proxy_ep._hostText, "proxy.com") self.assertEqual(proxy_ep._port, 8888) @@ -860,18 +964,27 @@ class MatrixFederationAgentTests(TestCase): @patch.dict(os.environ, {"http_proxy": "socks://proxy.com:8888"}) def test_proxy_with_unsupported_scheme(self) -> None: with self.assertRaises(ValueError): - ProxyAgent(self.reactor, use_proxy=True) + ProxyAgent( + reactor=self.reactor, + proxy_config=parse_proxy_config({}), + ) @patch.dict(os.environ, {"http_proxy": "http://proxy.com:8888"}) def test_proxy_with_http_scheme(self) -> None: - http_proxy_agent = ProxyAgent(self.reactor, use_proxy=True) + http_proxy_agent = ProxyAgent( + reactor=self.reactor, + proxy_config=parse_proxy_config({}), + ) proxy_ep = checked_cast(HostnameEndpoint, http_proxy_agent.http_proxy_endpoint) self.assertEqual(proxy_ep._hostText, "proxy.com") self.assertEqual(proxy_ep._port, 8888) @patch.dict(os.environ, {"http_proxy": "https://proxy.com:8888"}) def test_proxy_with_https_scheme(self) -> None: - https_proxy_agent = ProxyAgent(self.reactor, use_proxy=True) + https_proxy_agent = ProxyAgent( + reactor=self.reactor, + proxy_config=parse_proxy_config({}), + ) proxy_ep = checked_cast(_WrapperEndpoint, https_proxy_agent.http_proxy_endpoint) self.assertEqual(proxy_ep._wrappedEndpoint._hostText, "proxy.com") self.assertEqual(proxy_ep._wrappedEndpoint._port, 8888) diff --git a/tests/http/test_simple_client.py b/tests/http/test_simple_client.py index b7806fa947..c5ead59988 100644 --- a/tests/http/test_simple_client.py +++ b/tests/http/test_simple_client.py @@ -24,7 +24,7 @@ from netaddr import IPSet from twisted.internet import defer from twisted.internet.error import DNSLookupError -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.http import RequestTimedOutError from synapse.http.client import SimpleHttpClient diff --git 
a/tests/http/test_site.py b/tests/http/test_site.py index fc620c705a..2eca4587e7 100644 --- a/tests/http/test_site.py +++ b/tests/http/test_site.py @@ -20,7 +20,7 @@ # from twisted.internet.address import IPv6Address -from twisted.test.proto_helpers import MemoryReactor, StringTransport +from twisted.internet.testing import MemoryReactor, StringTransport from synapse.app.homeserver import SynapseHomeServer from synapse.server import HomeServer diff --git a/tests/logging/test_loggers.py b/tests/logging/test_loggers.py new file mode 100644 index 0000000000..9a9bf14376 --- /dev/null +++ b/tests/logging/test_loggers.py @@ -0,0 +1,127 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2025 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# . +# +# +# +import logging + +from synapse.logging.loggers import ExplicitlyConfiguredLogger + +from tests.unittest import TestCase + + +class ExplicitlyConfiguredLoggerTestCase(TestCase): + def _create_explicitly_configured_logger(self) -> logging.Logger: + original_logger_class = logging.getLoggerClass() + logging.setLoggerClass(ExplicitlyConfiguredLogger) + logger = logging.getLogger("test") + # Restore the original logger class + logging.setLoggerClass(original_logger_class) + + return logger + + def test_no_logs_when_not_set(self) -> None: + """ + Test to make sure that nothing is logged when the logger is *not* explicitly + configured. + """ + root_logger = logging.getLogger() + root_logger.setLevel(logging.DEBUG) + + logger = self._create_explicitly_configured_logger() + + with self.assertLogs(logger=logger, level=logging.NOTSET) as cm: + # XXX: We have to set this again because of a Python bug: + # https://github.com/python/cpython/issues/136958 (feel free to remove once + # that is resolved and we update to a newer Python version that includes the + # fix) + logger.setLevel(logging.NOTSET) + + logger.debug("debug message") + logger.info("info message") + logger.warning("warning message") + logger.error("error message") + + # Nothing should be logged since the logger is *not* explicitly configured + # + # FIXME: Remove this whole block once we update to Python 3.10 or later and + # have access to `assertNoLogs` (replace `assertLogs` with `assertNoLogs`) + self.assertIncludes( + set(cm.output), + set(), + exact=True, + ) + # Stub log message to avoid `assertLogs` failing since it expects at least + # one log message to be logged. + logger.setLevel(logging.INFO) + logger.info("stub message so `assertLogs` doesn't fail") + + def test_logs_when_explicitly_configured(self) -> None: + """ + Test to make sure that logs are emitted when the logger is explicitly configured. 
+ """ + root_logger = logging.getLogger() + root_logger.setLevel(logging.INFO) + + logger = self._create_explicitly_configured_logger() + + with self.assertLogs(logger=logger, level=logging.DEBUG) as cm: + logger.debug("debug message") + logger.info("info message") + logger.warning("warning message") + logger.error("error message") + + self.assertIncludes( + set(cm.output), + { + "DEBUG:test:debug message", + "INFO:test:info message", + "WARNING:test:warning message", + "ERROR:test:error message", + }, + exact=True, + ) + + def test_is_enabled_for_not_set(self) -> None: + """ + Test to make sure `logger.isEnabledFor(...)` returns False when the logger is + not explicitly configured. + """ + + logger = self._create_explicitly_configured_logger() + + # Unset the logger (not configured) + logger.setLevel(logging.NOTSET) + + # The logger shouldn't be enabled for any level + self.assertFalse(logger.isEnabledFor(logging.DEBUG)) + self.assertFalse(logger.isEnabledFor(logging.INFO)) + self.assertFalse(logger.isEnabledFor(logging.WARNING)) + self.assertFalse(logger.isEnabledFor(logging.ERROR)) + + def test_is_enabled_for_info(self) -> None: + """ + Test to make sure `logger.isEnabledFor(...)` returns True any levels above the + explicitly configured level. + """ + + logger = self._create_explicitly_configured_logger() + + # Explicitly configure the logger to `INFO` level + logger.setLevel(logging.INFO) + + # The logger should be enabled for INFO and above once explicitly configured + self.assertFalse(logger.isEnabledFor(logging.DEBUG)) + self.assertTrue(logger.isEnabledFor(logging.INFO)) + self.assertTrue(logger.isEnabledFor(logging.WARNING)) + self.assertTrue(logger.isEnabledFor(logging.ERROR)) diff --git a/tests/logging/test_opentracing.py b/tests/logging/test_opentracing.py index c7ef2bd7a4..7c5046baba 100644 --- a/tests/logging/test_opentracing.py +++ b/tests/logging/test_opentracing.py @@ -22,7 +22,7 @@ from typing import Awaitable, cast from twisted.internet import defer -from twisted.test.proto_helpers import MemoryReactorClock +from twisted.internet.testing import MemoryReactorClock from synapse.logging.context import ( LoggingContext, diff --git a/tests/logging/test_remote_handler.py b/tests/logging/test_remote_handler.py index 4178a8d831..e0fd12ccf7 100644 --- a/tests/logging/test_remote_handler.py +++ b/tests/logging/test_remote_handler.py @@ -21,7 +21,7 @@ from typing import Tuple from twisted.internet.protocol import Protocol -from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock +from twisted.internet.testing import AccumulatingProtocol, MemoryReactorClock from synapse.logging import RemoteHandler diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py index d9cbbbd51e..60de8d786f 100644 --- a/tests/logging/test_terse_json.py +++ b/tests/logging/test_terse_json.py @@ -171,8 +171,9 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase): site.site_tag = "test-site" site.server_version_string = "Server v1" site.reactor = Mock() + request = SynapseRequest( - cast(HTTPChannel, FakeChannel(site, self.reactor)), site + cast(HTTPChannel, FakeChannel(site, self.reactor)), site, "test_server" ) # Call requestReceived to finish instantiating the object. 
request.content = BytesIO() diff --git a/tests/media/test_media_retention.py b/tests/media/test_media_retention.py index 89cf61430a..6e01b9aecb 100644 --- a/tests/media/test_media_retention.py +++ b/tests/media/test_media_retention.py @@ -24,7 +24,7 @@ from typing import Iterable, Optional from matrix_common.types.mxc_uri import MXCUri -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.rest import admin from synapse.rest.client import login, register, room diff --git a/tests/media/test_media_storage.py b/tests/media/test_media_storage.py index 2f7cf4569b..bf334c0371 100644 --- a/tests/media/test_media_storage.py +++ b/tests/media/test_media_storage.py @@ -33,8 +33,8 @@ from PIL import Image as Image from twisted.internet import defer from twisted.internet.defer import Deferred +from twisted.internet.testing import MemoryReactor from twisted.python.failure import Failure -from twisted.test.proto_helpers import MemoryReactor from twisted.web.http_headers import Headers from twisted.web.iweb import UNKNOWN_LENGTH, IResponse from twisted.web.resource import Resource diff --git a/tests/media/test_oembed.py b/tests/media/test_oembed.py index b8265ff9ca..afae7e048c 100644 --- a/tests/media/test_oembed.py +++ b/tests/media/test_oembed.py @@ -24,7 +24,7 @@ from typing import Any from parameterized import parameterized -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.media.oembed import OEmbedProvider, OEmbedResult from synapse.server import HomeServer diff --git a/tests/media/test_url_previewer.py b/tests/media/test_url_previewer.py index 0ae414d408..bd7190e3e9 100644 --- a/tests/media/test_url_previewer.py +++ b/tests/media/test_url_previewer.py @@ -20,7 +20,7 @@ # import os -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.server import HomeServer from synapse.util import Clock diff --git a/tests/metrics/test_background_process_metrics.py b/tests/metrics/test_background_process_metrics.py index f0f6cb2912..1f47601b95 100644 --- a/tests/metrics/test_background_process_metrics.py +++ b/tests/metrics/test_background_process_metrics.py @@ -14,6 +14,8 @@ class TestBackgroundProcessMetrics(StdlibTestCase): mock_logging_context = Mock(spec=LoggingContext) mock_logging_context.get_resource_usage.return_value = usage - process = _BackgroundProcess("test process", mock_logging_context) + process = _BackgroundProcess( + desc="test process", server_name="test_server", ctx=mock_logging_context + ) # Should not raise process.update_metrics() diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py index e92d5f6dfa..61874564a6 100644 --- a/tests/metrics/test_metrics.py +++ b/tests/metrics/test_metrics.py @@ -59,7 +59,8 @@ class TestMauLimit(unittest.TestCase): foo: int bar: int - gauge: InFlightGauge[MetricEntry] = InFlightGauge( + # This is a test and does not matter if it uses `SERVER_NAME_LABEL`. 
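Several hunks here belong to the homeserver-scoped metrics refactor: metrics now carry a `server_name` label, apparently enforced by a lint, hence the `# type: ignore[missing-server-name-label]` on the test-only gauge just below. As a rough illustration of the idea using plain `prometheus_client` rather than Synapse's own metric wrappers (the metric name here is made up):

```python
from prometheus_client import CollectorRegistry, Gauge, generate_latest

registry = CollectorRegistry()
cache_size = Gauge(
    "example_cache_size",
    "Example of a per-homeserver metric",
    labelnames=["name", "server_name"],
    registry=registry,
)

# Two homeservers in the same process report distinct series for the same metric.
cache_size.labels(name="some_cache", server_name="hs1").set(1)
cache_size.labels(name="some_cache", server_name="hs2").set(0)

print(generate_latest(registry).decode("ascii"))
```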
+ gauge: InFlightGauge[MetricEntry] = InFlightGauge( # type: ignore[missing-server-name-label] "test1", "", labels=["test_label"], sub_metrics=["foo", "bar"] ) @@ -159,26 +160,149 @@ class CacheMetricsTests(unittest.HomeserverTestCase): name=CACHE_NAME, server_name=self.hs.hostname, max_entries=777 ) - items = { - x.split(b"{")[0].decode("ascii"): x.split(b" ")[1].decode("ascii") - for x in filter( - lambda x: b"cache_metrics_test_fgjkbdfg" in x, - generate_latest(REGISTRY).split(b"\n"), - ) - } + metrics_map = get_latest_metrics() - self.assertEqual(items["synapse_util_caches_cache_size"], "0.0") - self.assertEqual(items["synapse_util_caches_cache_max_size"], "777.0") + cache_size_metric = f'synapse_util_caches_cache_size{{name="{CACHE_NAME}",server_name="{self.hs.hostname}"}}' + cache_max_size_metric = f'synapse_util_caches_cache_max_size{{name="{CACHE_NAME}",server_name="{self.hs.hostname}"}}' + + cache_size_metric_value = metrics_map.get(cache_size_metric) + self.assertIsNotNone( + cache_size_metric_value, + f"Missing metric {cache_size_metric} in cache metrics {metrics_map}", + ) + cache_max_size_metric_value = metrics_map.get(cache_max_size_metric) + self.assertIsNotNone( + cache_max_size_metric_value, + f"Missing metric {cache_max_size_metric} in cache metrics {metrics_map}", + ) + + self.assertEqual(cache_size_metric_value, "0.0") + self.assertEqual(cache_max_size_metric_value, "777.0") cache.prefill("1", "hi") - items = { - x.split(b"{")[0].decode("ascii"): x.split(b" ")[1].decode("ascii") - for x in filter( - lambda x: b"cache_metrics_test_fgjkbdfg" in x, - generate_latest(REGISTRY).split(b"\n"), - ) - } + metrics_map = get_latest_metrics() - self.assertEqual(items["synapse_util_caches_cache_size"], "1.0") - self.assertEqual(items["synapse_util_caches_cache_max_size"], "777.0") + cache_size_metric_value = metrics_map.get(cache_size_metric) + self.assertIsNotNone( + cache_size_metric_value, + f"Missing metric {cache_size_metric} in cache metrics {metrics_map}", + ) + cache_max_size_metric_value = metrics_map.get(cache_max_size_metric) + self.assertIsNotNone( + cache_max_size_metric_value, + f"Missing metric {cache_max_size_metric} in cache metrics {metrics_map}", + ) + + self.assertEqual(cache_size_metric_value, "1.0") + self.assertEqual(cache_max_size_metric_value, "777.0") + + def test_cache_metric_multiple_servers(self) -> None: + """ + Test that cache metrics are reported correctly across multiple servers. We will + have an metrics entry for each homeserver that is labeled with the `server_name` + label. 
+ """ + CACHE_NAME = "cache_metric_multiple_servers_test" + cache1: DeferredCache[str, str] = DeferredCache( + name=CACHE_NAME, server_name="hs1", max_entries=777 + ) + cache2: DeferredCache[str, str] = DeferredCache( + name=CACHE_NAME, server_name="hs2", max_entries=777 + ) + + metrics_map = get_latest_metrics() + + hs1_cache_size_metric = ( + f'synapse_util_caches_cache_size{{name="{CACHE_NAME}",server_name="hs1"}}' + ) + hs2_cache_size_metric = ( + f'synapse_util_caches_cache_size{{name="{CACHE_NAME}",server_name="hs2"}}' + ) + hs1_cache_max_size_metric = f'synapse_util_caches_cache_max_size{{name="{CACHE_NAME}",server_name="hs1"}}' + hs2_cache_max_size_metric = f'synapse_util_caches_cache_max_size{{name="{CACHE_NAME}",server_name="hs2"}}' + + # Find the metrics for the caches from both homeservers + hs1_cache_size_metric_value = metrics_map.get(hs1_cache_size_metric) + self.assertIsNotNone( + hs1_cache_size_metric_value, + f"Missing metric {hs1_cache_size_metric} in cache metrics {metrics_map}", + ) + hs2_cache_size_metric_value = metrics_map.get(hs2_cache_size_metric) + self.assertIsNotNone( + hs2_cache_size_metric_value, + f"Missing metric {hs2_cache_size_metric} in cache metrics {metrics_map}", + ) + hs1_cache_max_size_metric_value = metrics_map.get(hs1_cache_max_size_metric) + self.assertIsNotNone( + hs1_cache_max_size_metric_value, + f"Missing metric {hs1_cache_max_size_metric} in cache metrics {metrics_map}", + ) + hs2_cache_max_size_metric_value = metrics_map.get(hs2_cache_max_size_metric) + self.assertIsNotNone( + hs2_cache_max_size_metric_value, + f"Missing metric {hs2_cache_max_size_metric} in cache metrics {metrics_map}", + ) + + # Sanity check the metric values + self.assertEqual(hs1_cache_size_metric_value, "0.0") + self.assertEqual(hs2_cache_size_metric_value, "0.0") + self.assertEqual(hs1_cache_max_size_metric_value, "777.0") + self.assertEqual(hs2_cache_max_size_metric_value, "777.0") + + # Add something to both caches to change the numbers + cache1.prefill("1", "hi") + cache2.prefill("2", "ho") + + metrics_map = get_latest_metrics() + + # Find the metrics for the caches from both homeservers + hs1_cache_size_metric_value = metrics_map.get(hs1_cache_size_metric) + self.assertIsNotNone( + hs1_cache_size_metric_value, + f"Missing metric {hs1_cache_size_metric} in cache metrics {metrics_map}", + ) + hs2_cache_size_metric_value = metrics_map.get(hs2_cache_size_metric) + self.assertIsNotNone( + hs2_cache_size_metric_value, + f"Missing metric {hs2_cache_size_metric} in cache metrics {metrics_map}", + ) + hs1_cache_max_size_metric_value = metrics_map.get(hs1_cache_max_size_metric) + self.assertIsNotNone( + hs1_cache_max_size_metric_value, + f"Missing metric {hs1_cache_max_size_metric} in cache metrics {metrics_map}", + ) + hs2_cache_max_size_metric_value = metrics_map.get(hs2_cache_max_size_metric) + self.assertIsNotNone( + hs2_cache_max_size_metric_value, + f"Missing metric {hs2_cache_max_size_metric} in cache metrics {metrics_map}", + ) + + # Sanity check the metric values + self.assertEqual(hs1_cache_size_metric_value, "1.0") + self.assertEqual(hs2_cache_size_metric_value, "1.0") + self.assertEqual(hs1_cache_max_size_metric_value, "777.0") + self.assertEqual(hs2_cache_max_size_metric_value, "777.0") + + +def get_latest_metrics() -> Dict[str, str]: + """ + Collect the latest metrics from the registry and parse them into an easy to use map. + The key includes the metric name and labels. 
+ + Example output: + { + "synapse_util_caches_cache_size": "0.0", + "synapse_util_caches_cache_max_size{name="some_cache",server_name="hs1"}": "777.0", + ... + } + """ + metric_map = { + x.split(b" ")[0].decode("ascii"): x.split(b" ")[1].decode("ascii") + for x in filter( + lambda x: len(x) > 0 and not x.startswith(b"#"), + generate_latest(REGISTRY).split(b"\n"), + ) + } + + return metric_map diff --git a/tests/metrics/test_phone_home_stats.py b/tests/metrics/test_phone_home_stats.py index 5339d649df..cf18d8635d 100644 --- a/tests/metrics/test_phone_home_stats.py +++ b/tests/metrics/test_phone_home_stats.py @@ -14,7 +14,7 @@ import logging from unittest.mock import AsyncMock -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.app.phone_stats_home import ( PHONE_HOME_INTERVAL_SECONDS, diff --git a/tests/module_api/test_account_data_manager.py b/tests/module_api/test_account_data_manager.py index 1a1d5609b2..6539871c11 100644 --- a/tests/module_api/test_account_data_manager.py +++ b/tests/module_api/test_account_data_manager.py @@ -18,7 +18,7 @@ # [This file includes modifications made by New Vector Limited] # # -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.errors import SynapseError from synapse.rest import admin diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py index 9972af3aa3..6b761de36d 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py @@ -22,7 +22,7 @@ from typing import Any, Dict, Optional from unittest.mock import AsyncMock, Mock from twisted.internet import defer -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.constants import EduTypes, EventTypes from synapse.api.errors import NotFoundError diff --git a/tests/module_api/test_event_unsigned_addition.py b/tests/module_api/test_event_unsigned_addition.py index c429eff4d6..52e3858e6f 100644 --- a/tests/module_api/test_event_unsigned_addition.py +++ b/tests/module_api/test_event_unsigned_addition.py @@ -18,7 +18,7 @@ # [This file includes modifications made by New Vector Limited] # # -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.events import EventBase from synapse.rest import admin, login, room diff --git a/tests/module_api/test_spamchecker.py b/tests/module_api/test_spamchecker.py index 926fe30b43..fa19232ee9 100644 --- a/tests/module_api/test_spamchecker.py +++ b/tests/module_api/test_spamchecker.py @@ -14,7 +14,7 @@ # from typing import Literal, Union -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.config.server import DEFAULT_ROOM_VERSION from synapse.rest import admin, login, room, room_upgrade_rest_servlet diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py index 16c1292812..fad5c7affb 100644 --- a/tests/push/test_bulk_push_rule_evaluator.py +++ b/tests/push/test_bulk_push_rule_evaluator.py @@ -24,9 +24,9 @@ from unittest.mock import AsyncMock, patch from parameterized import parameterized -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor -from synapse.api.constants import EventContentFields, RelationTypes +from synapse.api.constants import EventContentFields, EventTypes, RelationTypes from synapse.api.room_versions 
import RoomVersions from synapse.push.bulk_push_rule_evaluator import BulkPushRuleEvaluator from synapse.rest import admin @@ -206,7 +206,10 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): bulk_evaluator._action_for_event_by_user.assert_not_called() def _create_and_process( - self, bulk_evaluator: BulkPushRuleEvaluator, content: Optional[JsonDict] = None + self, + bulk_evaluator: BulkPushRuleEvaluator, + content: Optional[JsonDict] = None, + type: str = "test", ) -> bool: """Returns true iff the `mentions` trigger an event push action.""" # Create a new message event which should cause a notification. @@ -214,7 +217,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): self.event_creation_handler.create_event( self.requester, { - "type": "test", + "type": type, "room_id": self.room_id, "content": content or {}, "sender": f"@bob:{self.hs.hostname}", @@ -446,3 +449,73 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): }, ) ) + + @override_config({"experimental_features": {"msc4306_enabled": True}}) + def test_thread_subscriptions(self) -> None: + bulk_evaluator = BulkPushRuleEvaluator(self.hs) + (thread_root_id,) = self.helper.send_messages(self.room_id, 1, tok=self.token) + + self.assertFalse( + self._create_and_process( + bulk_evaluator, + { + "msgtype": "m.text", + "body": "test message before subscription", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": thread_root_id, + }, + }, + type=EventTypes.Message, + ) + ) + + self.get_success( + self.hs.get_datastores().main.subscribe_user_to_thread( + self.alice, + self.room_id, + thread_root_id, + automatic_event_orderings=None, + ) + ) + + self.assertTrue( + self._create_and_process( + bulk_evaluator, + { + "msgtype": "m.text", + "body": "test message after subscription", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": thread_root_id, + }, + }, + type="m.room.message", + ) + ) + + def test_with_disabled_thread_subscriptions(self) -> None: + """ + Test what happens with threaded events when MSC4306 is disabled. + + FUTURE: If MSC4306 becomes enabled-by-default/accepted, this test is to be removed. + """ + bulk_evaluator = BulkPushRuleEvaluator(self.hs) + (thread_root_id,) = self.helper.send_messages(self.room_id, 1, tok=self.token) + + # When MSC4306 is not enabled, a threaded message generates a notification + # by default. 
+ self.assertTrue( + self._create_and_process( + bulk_evaluator, + { + "msgtype": "m.text", + "body": "test message before subscription", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": thread_root_id, + }, + }, + type="m.room.message", + ) + ) diff --git a/tests/push/test_email.py b/tests/push/test_email.py index 4fafb71897..4d885c78eb 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -27,7 +27,7 @@ import pkg_resources from parameterized import parameterized from twisted.internet.defer import Deferred -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor import synapse.rest.admin from synapse.api.errors import Codes, SynapseError diff --git a/tests/push/test_http.py b/tests/push/test_http.py index b42fd284b6..370233c730 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -23,7 +23,7 @@ from unittest.mock import Mock from parameterized import parameterized from twisted.internet.defer import Deferred -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor import synapse.rest.admin from synapse.logging.context import make_deferred_yieldable diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py index 98a3a22154..3a351acffa 100644 --- a/tests/push/test_push_rule_evaluator.py +++ b/tests/push/test_push_rule_evaluator.py @@ -21,7 +21,7 @@ from typing import Any, Dict, List, Optional, Union, cast -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor import synapse.rest.admin from synapse.api.constants import EventTypes, HistoryVisibility, Membership @@ -150,6 +150,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): *, related_events: Optional[JsonDict] = None, msc4210: bool = False, + msc4306: bool = False, ) -> PushRuleEvaluator: event = FrozenEvent( { @@ -176,6 +177,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): room_version_feature_flags=event.room_version.msc3931_push_features, msc3931_enabled=True, msc4210_enabled=msc4210, + msc4306_enabled=msc4306, ) def test_display_name(self) -> None: @@ -806,6 +808,112 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): ) ) + def test_thread_subscription_subscribed(self) -> None: + """ + Test MSC4306 thread subscription push rules against an event in a subscribed thread. + """ + evaluator = self._get_evaluator( + { + "msgtype": "m.text", + "body": "Squawk", + "m.relates_to": { + "event_id": "$threadroot", + "rel_type": "m.thread", + }, + }, + msc4306=True, + ) + self.assertTrue( + evaluator.matches( + { + "kind": "io.element.msc4306.thread_subscription", + "subscribed": True, + }, + None, + None, + msc4306_thread_subscription_state=True, + ) + ) + self.assertFalse( + evaluator.matches( + { + "kind": "io.element.msc4306.thread_subscription", + "subscribed": False, + }, + None, + None, + msc4306_thread_subscription_state=True, + ) + ) + + def test_thread_subscription_unsubscribed(self) -> None: + """ + Test MSC4306 thread subscription push rules against an event in an unsubscribed thread. 
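The evaluator tests above and below (subscribed, unsubscribed, and unthreaded) pin down when the experimental `io.element.msc4306.thread_subscription` condition matches. Restated as a standalone truth table, purely as an illustration and not the real push rule evaluator (the function name here is made up):

```python
from typing import Optional


def thread_subscription_condition_matches(
    rule_subscribed: bool,
    thread_subscription_state: Optional[bool],
) -> bool:
    """Mirror of the behaviour the tests assert for the MSC4306 condition.

    thread_subscription_state is None for events that are not in a thread,
    otherwise whether the user is subscribed to the thread.
    """
    if thread_subscription_state is None:
        # Unthreaded events never match, whatever the rule says.
        return False
    return thread_subscription_state == rule_subscribed


assert thread_subscription_condition_matches(True, True)
assert not thread_subscription_condition_matches(False, True)
assert thread_subscription_condition_matches(False, False)
assert not thread_subscription_condition_matches(True, None)
```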
+ """ + evaluator = self._get_evaluator( + { + "msgtype": "m.text", + "body": "Squawk", + "m.relates_to": { + "event_id": "$threadroot", + "rel_type": "m.thread", + }, + }, + msc4306=True, + ) + self.assertFalse( + evaluator.matches( + { + "kind": "io.element.msc4306.thread_subscription", + "subscribed": True, + }, + None, + None, + msc4306_thread_subscription_state=False, + ) + ) + self.assertTrue( + evaluator.matches( + { + "kind": "io.element.msc4306.thread_subscription", + "subscribed": False, + }, + None, + None, + msc4306_thread_subscription_state=False, + ) + ) + + def test_thread_subscription_unthreaded(self) -> None: + """ + Test MSC4306 thread subscription push rules against an unthreaded event. + """ + evaluator = self._get_evaluator( + {"msgtype": "m.text", "body": "Squawk"}, msc4306=True + ) + self.assertFalse( + evaluator.matches( + { + "kind": "io.element.msc4306.thread_subscription", + "subscribed": True, + }, + None, + None, + msc4306_thread_subscription_state=None, + ) + ) + self.assertFalse( + evaluator.matches( + { + "kind": "io.element.msc4306.thread_subscription", + "subscribed": False, + }, + None, + None, + msc4306_thread_subscription_state=None, + ) + ) + class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase): """Tests for the bulk push rule evaluator""" diff --git a/tests/replication/_base.py b/tests/replication/_base.py index 75a36303c9..453eb7750b 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -23,8 +23,8 @@ from typing import Any, Dict, List, Optional, Set, Tuple from twisted.internet.address import IPv4Address from twisted.internet.protocol import Protocol, connectionDone +from twisted.internet.testing import MemoryReactor from twisted.python.failure import Failure -from twisted.test.proto_helpers import MemoryReactor from twisted.web.resource import Resource from synapse.app.generic_worker import GenericWorkerServer diff --git a/tests/replication/storage/_base.py b/tests/replication/storage/_base.py index 27dff0034f..97e744127c 100644 --- a/tests/replication/storage/_base.py +++ b/tests/replication/storage/_base.py @@ -22,7 +22,7 @@ from typing import Any, Callable, Iterable, Optional from unittest.mock import Mock -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.server import HomeServer from synapse.util import Clock diff --git a/tests/replication/storage/test_events.py b/tests/replication/storage/test_events.py index 1afe523d02..b3ca204995 100644 --- a/tests/replication/storage/test_events.py +++ b/tests/replication/storage/test_events.py @@ -24,7 +24,7 @@ from typing import Any, Iterable, List, Optional, Tuple from canonicaljson import encode_canonical_json from parameterized import parameterized -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.constants import ReceiptTypes from synapse.api.room_versions import RoomVersions diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py index 2a0189a4e1..cd6fe53a96 100644 --- a/tests/replication/tcp/streams/test_events.py +++ b/tests/replication/tcp/streams/test_events.py @@ -22,7 +22,7 @@ from typing import Any, List, Optional from parameterized import parameterized -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.constants import EventTypes, Membership from synapse.events import EventBase diff --git 
a/tests/replication/tcp/streams/test_thread_subscriptions.py b/tests/replication/tcp/streams/test_thread_subscriptions.py index 30c3415ad4..7283aa851e 100644 --- a/tests/replication/tcp/streams/test_thread_subscriptions.py +++ b/tests/replication/tcp/streams/test_thread_subscriptions.py @@ -12,7 +12,7 @@ # . # -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.replication.tcp.streams._base import ( _STREAM_UPDATE_TARGET_ROW_COUNT, @@ -62,7 +62,7 @@ class ThreadSubscriptionsStreamTestCase(BaseStreamTestCase): "@test_user:example.org", room_id, thread_root_id, - automatic=True, + automatic_event_orderings=None, ) ) updates.append(thread_root_id) @@ -75,7 +75,7 @@ class ThreadSubscriptionsStreamTestCase(BaseStreamTestCase): "@test_user:example.org", other_room_id, other_thread_root_id, - automatic=False, + automatic_event_orderings=None, ) ) @@ -124,7 +124,7 @@ class ThreadSubscriptionsStreamTestCase(BaseStreamTestCase): for user_id in users: self.get_success( store.subscribe_user_to_thread( - user_id, room_id, thread_root_id, automatic=True + user_id, room_id, thread_root_id, automatic_event_orderings=None ) ) diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py index c8958189f8..e2b2299106 100644 --- a/tests/replication/tcp/streams/test_typing.py +++ b/tests/replication/tcp/streams/test_typing.py @@ -18,6 +18,7 @@ # [This file includes modifications made by New Vector Limited] # # +import logging from unittest.mock import Mock from synapse.handlers.typing import RoomMember, TypingWriterHandler @@ -99,75 +100,86 @@ class TypingStreamTestCase(BaseStreamTestCase): This is emulated by jumping the stream ahead, then reconnecting (which sends the proper position and RDATA). """ - typing = self.hs.get_typing_handler() - assert isinstance(typing, TypingWriterHandler) + # FIXME: Because huge RDATA log line is triggered in this test, + # trial breaks, sometimes (flakily) failing the test run. + # ref: https://github.com/twisted/twisted/issues/12482 + # To remove this, we would need to fix the above issue and + # update, including in olddeps (so several years' wait). + server_logger = logging.getLogger("tests.server") + server_logger_was_disabled = server_logger.disabled + server_logger.disabled = True + try: + typing = self.hs.get_typing_handler() + assert isinstance(typing, TypingWriterHandler) - # Create a typing update before we reconnect so that there is a missing - # update to fetch. - typing._push_update(member=RoomMember(ROOM_ID, USER_ID), typing=True) + # Create a typing update before we reconnect so that there is a missing + # update to fetch. 
+ typing._push_update(member=RoomMember(ROOM_ID, USER_ID), typing=True) - self.reconnect() + self.reconnect() - typing._push_update(member=RoomMember(ROOM_ID, USER_ID), typing=True) + typing._push_update(member=RoomMember(ROOM_ID, USER_ID), typing=True) - self.reactor.advance(0) + self.reactor.advance(0) - # We should now see an attempt to connect to the master - request = self.handle_http_replication_attempt() - self.assert_request_is_get_repl_stream_updates(request, "typing") + # We should now see an attempt to connect to the master + request = self.handle_http_replication_attempt() + self.assert_request_is_get_repl_stream_updates(request, "typing") - self.mock_handler.on_rdata.assert_called_once() - stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0] - self.assertEqual(stream_name, "typing") - self.assertEqual(1, len(rdata_rows)) - row: TypingStream.TypingStreamRow = rdata_rows[0] - self.assertEqual(ROOM_ID, row.room_id) - self.assertEqual([USER_ID], row.user_ids) + self.mock_handler.on_rdata.assert_called_once() + stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0] + self.assertEqual(stream_name, "typing") + self.assertEqual(1, len(rdata_rows)) + row: TypingStream.TypingStreamRow = rdata_rows[0] + self.assertEqual(ROOM_ID, row.room_id) + self.assertEqual([USER_ID], row.user_ids) - # Push the stream forward a bunch so it can be reset. - for i in range(100): - typing._push_update( - member=RoomMember(ROOM_ID, "@test%s:blue" % i), typing=True + # Push the stream forward a bunch so it can be reset. + for i in range(100): + typing._push_update( + member=RoomMember(ROOM_ID, "@test%s:blue" % i), typing=True + ) + self.reactor.advance(0) + + # Disconnect. + self.disconnect() + + # Reset the typing handler + self.hs.get_replication_streams()["typing"].last_token = 0 + self.hs.get_replication_command_handler()._streams["typing"].last_token = 0 + typing._latest_room_serial = 0 + typing._typing_stream_change_cache = StreamChangeCache( + name="TypingStreamChangeCache", + server_name=self.hs.hostname, + current_stream_pos=typing._latest_room_serial, ) - self.reactor.advance(0) + typing._reset() - # Disconnect. - self.disconnect() + # Reconnect. + self.reconnect() + self.pump(0.1) - # Reset the typing handler - self.hs.get_replication_streams()["typing"].last_token = 0 - self.hs.get_replication_command_handler()._streams["typing"].last_token = 0 - typing._latest_room_serial = 0 - typing._typing_stream_change_cache = StreamChangeCache( - name="TypingStreamChangeCache", - server_name=self.hs.hostname, - current_stream_pos=typing._latest_room_serial, - ) - typing._reset() + # We should now see an attempt to connect to the master + request = self.handle_http_replication_attempt() + self.assert_request_is_get_repl_stream_updates(request, "typing") - # Reconnect. - self.reconnect() - self.pump(0.1) + # Reset the test code. + self.mock_handler.on_rdata.reset_mock() + self.mock_handler.on_rdata.assert_not_called() - # We should now see an attempt to connect to the master - request = self.handle_http_replication_attempt() - self.assert_request_is_get_repl_stream_updates(request, "typing") + # Push additional data. + typing._push_update(member=RoomMember(ROOM_ID_2, USER_ID_2), typing=False) + self.reactor.advance(0) - # Reset the test code. 
- self.mock_handler.on_rdata.reset_mock() - self.mock_handler.on_rdata.assert_not_called() + self.mock_handler.on_rdata.assert_called_once() + stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0] + self.assertEqual(stream_name, "typing") + self.assertEqual(1, len(rdata_rows)) + row = rdata_rows[0] + self.assertEqual(ROOM_ID_2, row.room_id) + self.assertEqual([], row.user_ids) - # Push additional data. - typing._push_update(member=RoomMember(ROOM_ID_2, USER_ID_2), typing=False) - self.reactor.advance(0) - - self.mock_handler.on_rdata.assert_called_once() - stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0] - self.assertEqual(stream_name, "typing") - self.assertEqual(1, len(rdata_rows)) - row = rdata_rows[0] - self.assertEqual(ROOM_ID_2, row.room_id) - self.assertEqual([], row.user_ids) - - # The token should have been reset. - self.assertEqual(token, 1) + # The token should have been reset. + self.assertEqual(token, 1) + finally: + server_logger.disabled = server_logger_was_disabled diff --git a/tests/replication/test_auth.py b/tests/replication/test_auth.py index 7820de8acc..640ed4e8f3 100644 --- a/tests/replication/test_auth.py +++ b/tests/replication/test_auth.py @@ -20,7 +20,7 @@ # import logging -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.rest.client import register from synapse.server import HomeServer diff --git a/tests/replication/test_federation_ack.py b/tests/replication/test_federation_ack.py index 14c9483f2b..440c1d45af 100644 --- a/tests/replication/test_federation_ack.py +++ b/tests/replication/test_federation_ack.py @@ -21,7 +21,7 @@ from unittest import mock -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.app.generic_worker import GenericWorkerServer from synapse.replication.tcp.commands import FederationAckCommand diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py index 5b7ed95a23..1fed4ec631 100644 --- a/tests/replication/test_federation_sender_shard.py +++ b/tests/replication/test_federation_sender_shard.py @@ -28,7 +28,7 @@ from signedjson.key import ( get_verify_key, ) -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.constants import EventTypes, Membership from synapse.api.room_versions import RoomVersion @@ -74,6 +74,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): user_agent=b"SynapseInTrialTest/0.0.0", ip_allowlist=None, ip_blocklist=IPSet(), + proxy_config=None, ) def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py index f36af877c4..228a803c1d 100644 --- a/tests/replication/test_multi_media_repo.py +++ b/tests/replication/test_multi_media_repo.py @@ -23,7 +23,7 @@ import os from typing import Any, Optional, Tuple from twisted.internet.protocol import Factory -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from twisted.web.http import HTTPChannel from twisted.web.server import Request diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py index 1b0bdc262a..d63054c631 100644 --- a/tests/replication/test_pusher_shard.py +++ b/tests/replication/test_pusher_shard.py @@ -22,7 +22,7 @@ import logging from 
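The typing-stream test above (and a profile test later in this patch) silences a noisy logger with a hand-rolled disable/restore in `try`/`finally`, working around trial choking on enormous log lines (twisted#12482). The same pattern as a small reusable helper, purely for illustration; the helper name is made up and not part of the patch:

```python
import logging
from contextlib import contextmanager
from typing import Iterator


@contextmanager
def logger_disabled(name: str) -> Iterator[None]:
    """Temporarily set `disabled` on the named logger, restoring it afterwards."""
    logger = logging.getLogger(name)
    was_disabled = logger.disabled
    logger.disabled = True
    try:
        yield
    finally:
        logger.disabled = was_disabled


# Usage equivalent to the try/finally block in the test above:
# with logger_disabled("tests.server"):
#     ...  # body that would otherwise emit a huge RDATA log line
```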
unittest.mock import Mock from twisted.internet import defer -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.rest import admin from synapse.rest.client import login, room diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py index ce6ad75901..797ad003ef 100644 --- a/tests/replication/test_sharded_event_persister.py +++ b/tests/replication/test_sharded_event_persister.py @@ -21,7 +21,7 @@ import logging from unittest.mock import patch -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.rest import admin from synapse.rest.client import login, room, sync diff --git a/tests/replication/test_sharded_receipts.py b/tests/replication/test_sharded_receipts.py index e400267819..6b3ecdad78 100644 --- a/tests/replication/test_sharded_receipts.py +++ b/tests/replication/test_sharded_receipts.py @@ -20,7 +20,7 @@ # import logging -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.constants import ReceiptTypes from synapse.rest import admin diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index fc2a6c569b..b74e8388e9 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -24,7 +24,7 @@ from typing import Dict, cast from parameterized import parameterized -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from twisted.web.resource import Resource import synapse.rest.admin diff --git a/tests/rest/admin/test_background_updates.py b/tests/rest/admin/test_background_updates.py index f33aada64b..dd116e79f1 100644 --- a/tests/rest/admin/test_background_updates.py +++ b/tests/rest/admin/test_background_updates.py @@ -22,7 +22,7 @@ from typing import Collection from parameterized import parameterized -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor import synapse.rest.admin from synapse.api.errors import Codes diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py index 660fa465bd..c564e0c9a7 100644 --- a/tests/rest/admin/test_device.py +++ b/tests/rest/admin/test_device.py @@ -22,7 +22,7 @@ import urllib.parse from parameterized import parameterized -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor import synapse.rest.admin from synapse.api.errors import Codes diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py index 6047ce1f4a..a6f958658f 100644 --- a/tests/rest/admin/test_event_reports.py +++ b/tests/rest/admin/test_event_reports.py @@ -20,7 +20,7 @@ # from typing import List -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor import synapse.rest.admin from synapse.api.errors import Codes diff --git a/tests/rest/admin/test_federation.py b/tests/rest/admin/test_federation.py index d5ae3345f5..cfea480bf0 100644 --- a/tests/rest/admin/test_federation.py +++ b/tests/rest/admin/test_federation.py @@ -22,7 +22,7 @@ from typing import List, Optional from parameterized import parameterized -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor import synapse.rest.admin from synapse.api.errors import Codes diff --git a/tests/rest/admin/test_media.py 
b/tests/rest/admin/test_media.py index da0e9749aa..f863b5f8e7 100644 --- a/tests/rest/admin/test_media.py +++ b/tests/rest/admin/test_media.py @@ -24,7 +24,7 @@ from typing import Dict from parameterized import parameterized -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from twisted.web.resource import Resource import synapse.rest.admin diff --git a/tests/rest/admin/test_registration_tokens.py b/tests/rest/admin/test_registration_tokens.py index 67d1db8ff8..b8e111c804 100644 --- a/tests/rest/admin/test_registration_tokens.py +++ b/tests/rest/admin/test_registration_tokens.py @@ -22,7 +22,7 @@ import random import string from typing import Optional -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor import synapse.rest.admin from synapse.api.errors import Codes diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 20fe939e4e..6454f18571 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -28,7 +28,7 @@ from unittest.mock import AsyncMock, Mock from parameterized import parameterized from twisted.internet.task import deferLater -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor import synapse.rest.admin from synapse.api.constants import EventTypes, Membership, RoomTypes @@ -2917,6 +2917,39 @@ class MakeRoomAdminTestCase(unittest.HomeserverTestCase): ) self.assertEquals(pl["users"][self.admin_user], 100) + def test_v12_room_with_many_user_pls(self) -> None: + """Test that you can be promoted to the admin user's PL in v12 rooms that contain a range of user PLs.""" + room_id = self.helper.create_room_as( + self.creator, + tok=self.creator_tok, + room_version=RoomVersions.V12.identifier, + is_public=True, + extra_content={ + "power_level_content_override": { + "users": { + self.second_user_id: 50, + }, + }, + }, + ) + + self.helper.join(room_id, self.admin_user, tok=self.admin_user_tok) + self.helper.join(room_id, self.second_user_id, tok=self.second_tok) + + channel = self.make_request( + "POST", + f"/_synapse/admin/v1/rooms/{room_id}/make_room_admin", + content={}, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + + pl = self.helper.get_state( + room_id, EventTypes.PowerLevels, tok=self.creator_tok + ) + self.assertEquals(pl["users"][self.admin_user], 100) + class BlockRoomTestCase(unittest.HomeserverTestCase): servlets = [ diff --git a/tests/rest/admin/test_scheduled_tasks.py b/tests/rest/admin/test_scheduled_tasks.py index 9654e9322b..ea7afc0101 100644 --- a/tests/rest/admin/test_scheduled_tasks.py +++ b/tests/rest/admin/test_scheduled_tasks.py @@ -15,7 +15,7 @@ # from typing import Mapping, Optional, Tuple -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor import synapse.rest.admin from synapse.api.errors import Codes diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py index 150caeeee2..1f77e31d48 100644 --- a/tests/rest/admin/test_server_notice.py +++ b/tests/rest/admin/test_server_notice.py @@ -20,7 +20,7 @@ # from typing import List, Sequence -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor import synapse.rest.admin from synapse.api.errors import Codes diff --git a/tests/rest/admin/test_statistics.py b/tests/rest/admin/test_statistics.py index 
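The new `test_v12_room_with_many_user_pls` above exercises the Make Room Admin fix in this release: in a room v12, the creator's power level is treated as the highest in the room even when other users have explicit `m.room.power_levels` entries. A toy restatement of that selection logic, as an illustration only and not Synapse's implementation:

```python
from typing import Mapping, Optional


def pick_most_powerful_user(
    explicit_power_levels: Mapping[str, int], v12_creator: Optional[str]
) -> str:
    """Pick the most powerful user in the room (whose standing the admin is raised to)."""
    if v12_creator is not None:
        # Room v12: the creator implicitly outranks every explicit power_levels entry.
        return v12_creator
    # Older room versions: fall back to the largest explicit power level.
    return max(explicit_power_levels, key=lambda user_id: explicit_power_levels[user_id])


assert pick_most_powerful_user({"@mod:hs": 50, "@admin:hs": 100}, None) == "@admin:hs"
assert pick_most_powerful_user({"@mod:hs": 50}, "@creator:hs") == "@creator:hs"
```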
07ec49c4e5..10efc4ef8b 100644 --- a/tests/rest/admin/test_statistics.py +++ b/tests/rest/admin/test_statistics.py @@ -21,7 +21,7 @@ # from typing import Dict, List, Optional -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from twisted.web.resource import Resource import synapse.rest.admin diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 52491ff3ba..4432b6a7a0 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -32,7 +32,7 @@ from unittest.mock import AsyncMock, Mock, patch from parameterized import parameterized, parameterized_class -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from twisted.web.resource import Resource import synapse.rest.admin diff --git a/tests/rest/admin/test_username_available.py b/tests/rest/admin/test_username_available.py index 4dd5de33d3..9c3ab3e64c 100644 --- a/tests/rest/admin/test_username_available.py +++ b/tests/rest/admin/test_username_available.py @@ -20,7 +20,7 @@ # from typing import Optional -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor import synapse.rest.admin from synapse.api.errors import Codes, SynapseError diff --git a/tests/rest/client/sliding_sync/test_connection_tracking.py b/tests/rest/client/sliding_sync/test_connection_tracking.py index 5b819103c2..f8ce1104a8 100644 --- a/tests/rest/client/sliding_sync/test_connection_tracking.py +++ b/tests/rest/client/sliding_sync/test_connection_tracking.py @@ -15,7 +15,7 @@ import logging from parameterized import parameterized, parameterized_class -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor import synapse.rest.admin from synapse.api.constants import EventTypes diff --git a/tests/rest/client/sliding_sync/test_extension_account_data.py b/tests/rest/client/sliding_sync/test_extension_account_data.py index 799fbb1856..5949065722 100644 --- a/tests/rest/client/sliding_sync/test_extension_account_data.py +++ b/tests/rest/client/sliding_sync/test_extension_account_data.py @@ -17,7 +17,7 @@ import logging from parameterized import parameterized, parameterized_class from typing_extensions import assert_never -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor import synapse.rest.admin from synapse.api.constants import AccountDataTypes diff --git a/tests/rest/client/sliding_sync/test_extension_e2ee.py b/tests/rest/client/sliding_sync/test_extension_e2ee.py index 7ce6592d8f..baf6a5882e 100644 --- a/tests/rest/client/sliding_sync/test_extension_e2ee.py +++ b/tests/rest/client/sliding_sync/test_extension_e2ee.py @@ -15,7 +15,7 @@ import logging from parameterized import parameterized_class -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor import synapse.rest.admin from synapse.rest.client import devices, login, room, sync diff --git a/tests/rest/client/sliding_sync/test_extension_receipts.py b/tests/rest/client/sliding_sync/test_extension_receipts.py index 6e7700b533..1bba3038db 100644 --- a/tests/rest/client/sliding_sync/test_extension_receipts.py +++ b/tests/rest/client/sliding_sync/test_extension_receipts.py @@ -15,7 +15,7 @@ import logging from parameterized import parameterized_class -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor import 
synapse.rest.admin from synapse.api.constants import EduTypes, ReceiptTypes diff --git a/tests/rest/client/sliding_sync/test_extension_to_device.py b/tests/rest/client/sliding_sync/test_extension_to_device.py index 790abb739d..151a5be665 100644 --- a/tests/rest/client/sliding_sync/test_extension_to_device.py +++ b/tests/rest/client/sliding_sync/test_extension_to_device.py @@ -16,7 +16,7 @@ from typing import List from parameterized import parameterized_class -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor import synapse.rest.admin from synapse.rest.client import login, sendtodevice, sync diff --git a/tests/rest/client/sliding_sync/test_extension_typing.py b/tests/rest/client/sliding_sync/test_extension_typing.py index f87c3c8b17..37c90d6ec2 100644 --- a/tests/rest/client/sliding_sync/test_extension_typing.py +++ b/tests/rest/client/sliding_sync/test_extension_typing.py @@ -15,7 +15,7 @@ import logging from parameterized import parameterized_class -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor import synapse.rest.admin from synapse.api.constants import EduTypes diff --git a/tests/rest/client/sliding_sync/test_extensions.py b/tests/rest/client/sliding_sync/test_extensions.py index 30230e5c4b..0643596e59 100644 --- a/tests/rest/client/sliding_sync/test_extensions.py +++ b/tests/rest/client/sliding_sync/test_extensions.py @@ -17,7 +17,7 @@ from typing import Literal from parameterized import parameterized, parameterized_class from typing_extensions import assert_never -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor import synapse.rest.admin from synapse.api.constants import ReceiptTypes diff --git a/tests/rest/client/sliding_sync/test_lists_filters.py b/tests/rest/client/sliding_sync/test_lists_filters.py index c59f6aedc4..57d00a2a7a 100644 --- a/tests/rest/client/sliding_sync/test_lists_filters.py +++ b/tests/rest/client/sliding_sync/test_lists_filters.py @@ -15,7 +15,7 @@ import logging from parameterized import parameterized_class -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor import synapse.rest.admin from synapse.api.constants import ( diff --git a/tests/rest/client/sliding_sync/test_room_subscriptions.py b/tests/rest/client/sliding_sync/test_room_subscriptions.py index 285fdaaf78..b78e4f2045 100644 --- a/tests/rest/client/sliding_sync/test_room_subscriptions.py +++ b/tests/rest/client/sliding_sync/test_room_subscriptions.py @@ -16,7 +16,7 @@ from http import HTTPStatus from parameterized import parameterized_class -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor import synapse.rest.admin from synapse.api.constants import EventTypes, HistoryVisibility diff --git a/tests/rest/client/sliding_sync/test_rooms_invites.py b/tests/rest/client/sliding_sync/test_rooms_invites.py index 882762ca29..a0f4ccd2cc 100644 --- a/tests/rest/client/sliding_sync/test_rooms_invites.py +++ b/tests/rest/client/sliding_sync/test_rooms_invites.py @@ -15,7 +15,7 @@ import logging from parameterized import parameterized_class -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor import synapse.rest.admin from synapse.api.constants import EventTypes, HistoryVisibility diff --git a/tests/rest/client/sliding_sync/test_rooms_meta.py b/tests/rest/client/sliding_sync/test_rooms_meta.py 
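Most of the remaining hunks are a mechanical import move: the test doubles previously imported from `twisted.test.proto_helpers` (`MemoryReactor`, `MemoryReactorClock`, `StringTransport`, `AccumulatingProtocol`) are now imported from `twisted.internet.testing`, their non-deprecated home. In the Twisted releases in use both paths should resolve to the same classes, which is the assumption the quick check below makes:

```python
# Old and new import paths are expected to resolve to the same class in current
# Twisted; only the new spelling is kept by this patch.
from twisted.internet.testing import MemoryReactor as NewMemoryReactor
from twisted.test.proto_helpers import MemoryReactor as OldMemoryReactor

assert NewMemoryReactor is OldMemoryReactor

reactor = NewMemoryReactor()  # a fake reactor that just records calls
```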
index 0a8b2c02c2..4559bc7646 100644 --- a/tests/rest/client/sliding_sync/test_rooms_meta.py +++ b/tests/rest/client/sliding_sync/test_rooms_meta.py @@ -15,7 +15,7 @@ import logging from parameterized import parameterized, parameterized_class -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor import synapse.rest.admin from synapse.api.constants import EventContentFields, EventTypes, Membership diff --git a/tests/rest/client/sliding_sync/test_rooms_required_state.py b/tests/rest/client/sliding_sync/test_rooms_required_state.py index ba46c5a93c..cfff167c6e 100644 --- a/tests/rest/client/sliding_sync/test_rooms_required_state.py +++ b/tests/rest/client/sliding_sync/test_rooms_required_state.py @@ -16,7 +16,7 @@ import logging from parameterized import parameterized, parameterized_class -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor import synapse.rest.admin from synapse.api.constants import EventContentFields, EventTypes, JoinRules, Membership diff --git a/tests/rest/client/sliding_sync/test_rooms_timeline.py b/tests/rest/client/sliding_sync/test_rooms_timeline.py index 535420209b..3d950eb20b 100644 --- a/tests/rest/client/sliding_sync/test_rooms_timeline.py +++ b/tests/rest/client/sliding_sync/test_rooms_timeline.py @@ -16,7 +16,7 @@ from typing import List, Optional from parameterized import parameterized_class -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor import synapse.rest.admin from synapse.api.constants import EventTypes diff --git a/tests/rest/client/sliding_sync/test_sliding_sync.py b/tests/rest/client/sliding_sync/test_sliding_sync.py index 412dbcf77b..ea4ee16359 100644 --- a/tests/rest/client/sliding_sync/test_sliding_sync.py +++ b/tests/rest/client/sliding_sync/test_sliding_sync.py @@ -18,7 +18,7 @@ from unittest.mock import AsyncMock from parameterized import parameterized, parameterized_class from typing_extensions import assert_never -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor import synapse.rest.admin from synapse.api.constants import ( diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py index 5343b10e92..02d02ae78e 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py @@ -28,7 +28,7 @@ from unittest.mock import Mock import pkg_resources from twisted.internet.interfaces import IReactorTCP -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor import synapse.rest.admin from synapse.api.constants import LoginType, Membership diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py index 0b5daf4bb4..4fe506845c 100644 --- a/tests/rest/client/test_auth.py +++ b/tests/rest/client/test_auth.py @@ -23,7 +23,7 @@ from http import HTTPStatus from typing import Any, Dict, List, Optional, Tuple, Union from twisted.internet.defer import succeed -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from twisted.web.resource import Resource import synapse.rest.admin diff --git a/tests/rest/client/test_capabilities.py b/tests/rest/client/test_capabilities.py index cdf31155fd..8ae1cc935a 100644 --- a/tests/rest/client/test_capabilities.py +++ b/tests/rest/client/test_capabilities.py @@ -19,7 +19,7 @@ # from http import HTTPStatus -from twisted.test.proto_helpers import 
MemoryReactor +from twisted.internet.testing import MemoryReactor import synapse.rest.admin from synapse.api.room_versions import KNOWN_ROOM_VERSIONS diff --git a/tests/rest/client/test_consent.py b/tests/rest/client/test_consent.py index 5f4168c56c..1a64b3984f 100644 --- a/tests/rest/client/test_consent.py +++ b/tests/rest/client/test_consent.py @@ -21,7 +21,7 @@ import os from http import HTTPStatus -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor import synapse.rest.admin from synapse.api.urls import ConsentURIBuilder diff --git a/tests/rest/client/test_delayed_events.py b/tests/rest/client/test_delayed_events.py index 9f9d241f12..4b338d333f 100644 --- a/tests/rest/client/test_delayed_events.py +++ b/tests/rest/client/test_delayed_events.py @@ -19,7 +19,7 @@ from typing import List from parameterized import parameterized -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.errors import Codes from synapse.rest import admin diff --git a/tests/rest/client/test_devices.py b/tests/rest/client/test_devices.py index b7230488e4..2c498e97e1 100644 --- a/tests/rest/client/test_devices.py +++ b/tests/rest/client/test_devices.py @@ -21,7 +21,7 @@ from http import HTTPStatus from twisted.internet.defer import ensureDeferred -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.errors import NotFoundError from synapse.appservice import ApplicationService diff --git a/tests/rest/client/test_directory.py b/tests/rest/client/test_directory.py index 6e499093cf..6548ac6fa8 100644 --- a/tests/rest/client/test_directory.py +++ b/tests/rest/client/test_directory.py @@ -19,7 +19,7 @@ # from http import HTTPStatus -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.appservice import ApplicationService from synapse.rest import admin diff --git a/tests/rest/client/test_ephemeral_message.py b/tests/rest/client/test_ephemeral_message.py index 2d98fda67f..5b5c220825 100644 --- a/tests/rest/client/test_ephemeral_message.py +++ b/tests/rest/client/test_ephemeral_message.py @@ -19,7 +19,7 @@ # from http import HTTPStatus -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.constants import EventContentFields, EventTypes from synapse.rest import admin diff --git a/tests/rest/client/test_events.py b/tests/rest/client/test_events.py index 039144fdbe..142509bbf7 100644 --- a/tests/rest/client/test_events.py +++ b/tests/rest/client/test_events.py @@ -23,7 +23,7 @@ from unittest.mock import Mock -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor import synapse.rest.admin from synapse.api.constants import EduTypes diff --git a/tests/rest/client/test_filter.py b/tests/rest/client/test_filter.py index 9cfc6b224f..4153fb322d 100644 --- a/tests/rest/client/test_filter.py +++ b/tests/rest/client/test_filter.py @@ -19,7 +19,7 @@ # # -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.errors import Codes from synapse.rest.client import filter diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py index 63c2c5923e..87af18f473 100644 --- a/tests/rest/client/test_identity.py +++ b/tests/rest/client/test_identity.py @@ -20,7 +20,7 @@ from http import 
HTTPStatus -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor import synapse.rest.admin from synapse.rest.client import login, room diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index b8bcc235e9..d3a7905ef2 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -37,7 +37,7 @@ from urllib.parse import urlencode import pymacaroons -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from twisted.web.resource import Resource import synapse.rest.admin diff --git a/tests/rest/client/test_login_token_request.py b/tests/rest/client/test_login_token_request.py index fbacf9d869..202d2cf351 100644 --- a/tests/rest/client/test_login_token_request.py +++ b/tests/rest/client/test_login_token_request.py @@ -19,7 +19,7 @@ # # -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.rest import admin from synapse.rest.client import login, login_token_request, versions diff --git a/tests/rest/client/test_media.py b/tests/rest/client/test_media.py index 7aa1f2406c..e6ed47f83a 100644 --- a/tests/rest/client/test_media.py +++ b/tests/rest/client/test_media.py @@ -38,8 +38,8 @@ from twisted.internet.address import IPv4Address, IPv6Address from twisted.internet.defer import Deferred from twisted.internet.error import DNSLookupError from twisted.internet.interfaces import IAddress, IResolutionReceiver +from twisted.internet.testing import AccumulatingProtocol, MemoryReactor from twisted.python.failure import Failure -from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactor from twisted.web.http_headers import Headers from twisted.web.iweb import UNKNOWN_LENGTH, IResponse from twisted.web.resource import Resource diff --git a/tests/rest/client/test_mutual_rooms.py b/tests/rest/client/test_mutual_rooms.py index 637722ca0a..2e37284680 100644 --- a/tests/rest/client/test_mutual_rooms.py +++ b/tests/rest/client/test_mutual_rooms.py @@ -20,7 +20,7 @@ # from urllib.parse import quote -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor import synapse.rest.admin from synapse.rest.client import login, mutual_rooms, room diff --git a/tests/rest/client/test_notifications.py b/tests/rest/client/test_notifications.py index e4b0455ce8..ec66567817 100644 --- a/tests/rest/client/test_notifications.py +++ b/tests/rest/client/test_notifications.py @@ -21,7 +21,7 @@ from typing import List, Optional, Tuple from unittest.mock import AsyncMock, Mock -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor import synapse.rest.admin from synapse.rest.client import login, notifications, receipts, room diff --git a/tests/rest/client/test_owned_state.py b/tests/rest/client/test_owned_state.py index 5fb5767676..386b95d616 100644 --- a/tests/rest/client/test_owned_state.py +++ b/tests/rest/client/test_owned_state.py @@ -2,7 +2,7 @@ from http import HTTPStatus from parameterized import parameterized_class -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.errors import Codes from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions diff --git a/tests/rest/client/test_password_policy.py b/tests/rest/client/test_password_policy.py index f0ef733f7b..33bab684e3 100644 --- 
a/tests/rest/client/test_password_policy.py +++ b/tests/rest/client/test_password_policy.py @@ -21,7 +21,7 @@ from http import HTTPStatus -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.constants import LoginType from synapse.api.errors import Codes diff --git a/tests/rest/client/test_power_levels.py b/tests/rest/client/test_power_levels.py index 1584c2e96c..39ea9acef6 100644 --- a/tests/rest/client/test_power_levels.py +++ b/tests/rest/client/test_power_levels.py @@ -20,7 +20,7 @@ # from http import HTTPStatus -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.errors import Codes from synapse.events.utils import CANONICALJSON_MAX_INT, CANONICALJSON_MIN_INT diff --git a/tests/rest/client/test_presence.py b/tests/rest/client/test_presence.py index 6b9c70974a..7138cc92c2 100644 --- a/tests/rest/client/test_presence.py +++ b/tests/rest/client/test_presence.py @@ -20,7 +20,7 @@ from http import HTTPStatus from unittest.mock import AsyncMock, Mock -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.handlers.presence import PresenceHandler from synapse.rest.client import presence diff --git a/tests/rest/client/test_profile.py b/tests/rest/client/test_profile.py index 49776d8e8c..936e573bcd 100644 --- a/tests/rest/client/test_profile.py +++ b/tests/rest/client/test_profile.py @@ -21,13 +21,14 @@ """Tests REST events for /profile paths.""" +import logging import urllib.parse from http import HTTPStatus from typing import Any, Dict, Optional from canonicaljson import encode_canonical_json -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.errors import Codes from synapse.rest import admin @@ -648,87 +649,99 @@ class ProfileTestCase(unittest.HomeserverTestCase): """ Attempts to set a custom field that would push the overall profile too large. """ - # Get right to the boundary: - # len("displayname") + len("owner") + 5 = 21 for the displayname - # 1 + 65498 + 5 for key "a" = 65504 - # 2 braces, 1 comma - # 3 + 21 + 65498 = 65522 < 65536. - key = "a" - channel = self.make_request( - "PUT", - f"/_matrix/client/v3/profile/{self.owner}/{key}", - content={key: "a" * 65498}, - access_token=self.owner_tok, - ) - self.assertEqual(channel.code, 200, channel.result) + # FIXME: Because we emit huge SQL log lines and trial can't handle these, + # sometimes (flakily) failing the test run, + # disable SQL logging for this test. + # ref: https://github.com/twisted/twisted/issues/12482 + # To remove this, we would need to fix the above issue and + # update, including in olddeps (so several years' wait). + sql_logger = logging.getLogger("synapse.storage.SQL") + sql_logger_was_disabled = sql_logger.disabled + sql_logger.disabled = True + try: + # Get right to the boundary: + # len("displayname") + len("owner") + 5 = 21 for the displayname + # 1 + 65498 + 5 for key "a" = 65504 + # 2 braces, 1 comma + # 3 + 21 + 65498 = 65522 < 65536. + key = "a" + channel = self.make_request( + "PUT", + f"/_matrix/client/v3/profile/{self.owner}/{key}", + content={key: "a" * 65498}, + access_token=self.owner_tok, + ) + self.assertEqual(channel.code, 200, channel.result) - # Get the entire profile. 
- channel = self.make_request( - "GET", - f"/_matrix/client/v3/profile/{self.owner}", - access_token=self.owner_tok, - ) - self.assertEqual(channel.code, 200, channel.result) - canonical_json = encode_canonical_json(channel.json_body) - # 6 is the minimum bytes to store a value: 4 quotes, 1 colon, 1 comma, an empty key. - # Be one below that so we can prove we're at the boundary. - self.assertEqual(len(canonical_json), MAX_PROFILE_SIZE - 8) + # Get the entire profile. + channel = self.make_request( + "GET", + f"/_matrix/client/v3/profile/{self.owner}", + access_token=self.owner_tok, + ) + self.assertEqual(channel.code, 200, channel.result) + canonical_json = encode_canonical_json(channel.json_body) + # 6 is the minimum bytes to store a value: 4 quotes, 1 colon, 1 comma, an empty key. + # Be one below that so we can prove we're at the boundary. + self.assertEqual(len(canonical_json), MAX_PROFILE_SIZE - 8) - # Postgres stores JSONB with whitespace, while SQLite doesn't. - if USE_POSTGRES_FOR_TESTS: - ADDITIONAL_CHARS = 0 - else: - ADDITIONAL_CHARS = 1 + # Postgres stores JSONB with whitespace, while SQLite doesn't. + if USE_POSTGRES_FOR_TESTS: + ADDITIONAL_CHARS = 0 + else: + ADDITIONAL_CHARS = 1 - # The next one should fail, note the value has a (JSON) length of 2. - key = "b" - channel = self.make_request( - "PUT", - f"/_matrix/client/v3/profile/{self.owner}/{key}", - content={key: "1" + "a" * ADDITIONAL_CHARS}, - access_token=self.owner_tok, - ) - self.assertEqual(channel.code, 400, channel.result) - self.assertEqual(channel.json_body["errcode"], Codes.PROFILE_TOO_LARGE) + # The next one should fail, note the value has a (JSON) length of 2. + key = "b" + channel = self.make_request( + "PUT", + f"/_matrix/client/v3/profile/{self.owner}/{key}", + content={key: "1" + "a" * ADDITIONAL_CHARS}, + access_token=self.owner_tok, + ) + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.json_body["errcode"], Codes.PROFILE_TOO_LARGE) - # Setting an avatar or (longer) display name should not work. - channel = self.make_request( - "PUT", - f"/profile/{self.owner}/displayname", - content={"displayname": "owner12345678" + "a" * ADDITIONAL_CHARS}, - access_token=self.owner_tok, - ) - self.assertEqual(channel.code, 400, channel.result) - self.assertEqual(channel.json_body["errcode"], Codes.PROFILE_TOO_LARGE) + # Setting an avatar or (longer) display name should not work. + channel = self.make_request( + "PUT", + f"/profile/{self.owner}/displayname", + content={"displayname": "owner12345678" + "a" * ADDITIONAL_CHARS}, + access_token=self.owner_tok, + ) + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.json_body["errcode"], Codes.PROFILE_TOO_LARGE) - channel = self.make_request( - "PUT", - f"/profile/{self.owner}/avatar_url", - content={"avatar_url": "mxc://foo/bar"}, - access_token=self.owner_tok, - ) - self.assertEqual(channel.code, 400, channel.result) - self.assertEqual(channel.json_body["errcode"], Codes.PROFILE_TOO_LARGE) + channel = self.make_request( + "PUT", + f"/profile/{self.owner}/avatar_url", + content={"avatar_url": "mxc://foo/bar"}, + access_token=self.owner_tok, + ) + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.json_body["errcode"], Codes.PROFILE_TOO_LARGE) - # Removing a single byte should work. 
- key = "b" - channel = self.make_request( - "PUT", - f"/_matrix/client/v3/profile/{self.owner}/{key}", - content={key: "" + "a" * ADDITIONAL_CHARS}, - access_token=self.owner_tok, - ) - self.assertEqual(channel.code, 200, channel.result) + # Removing a single byte should work. + key = "b" + channel = self.make_request( + "PUT", + f"/_matrix/client/v3/profile/{self.owner}/{key}", + content={key: "" + "a" * ADDITIONAL_CHARS}, + access_token=self.owner_tok, + ) + self.assertEqual(channel.code, 200, channel.result) - # Finally, setting a field that already exists to a value that is <= in length should work. - key = "a" - channel = self.make_request( - "PUT", - f"/_matrix/client/v3/profile/{self.owner}/{key}", - content={key: ""}, - access_token=self.owner_tok, - ) - self.assertEqual(channel.code, 200, channel.result) + # Finally, setting a field that already exists to a value that is <= in length should work. + key = "a" + channel = self.make_request( + "PUT", + f"/_matrix/client/v3/profile/{self.owner}/{key}", + content={key: ""}, + access_token=self.owner_tok, + ) + self.assertEqual(channel.code, 200, channel.result) + finally: + sql_logger.disabled = sql_logger_was_disabled def test_set_custom_field_displayname(self) -> None: channel = self.make_request( diff --git a/tests/rest/client/test_read_marker.py b/tests/rest/client/test_read_marker.py index 0b4ad685b3..a27eb9453b 100644 --- a/tests/rest/client/test_read_marker.py +++ b/tests/rest/client/test_read_marker.py @@ -18,7 +18,7 @@ # [This file includes modifications made by New Vector Limited] # # -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor import synapse.rest.admin from synapse.api.constants import EventTypes diff --git a/tests/rest/client/test_receipts.py b/tests/rest/client/test_receipts.py index f0648289f1..ae4818c412 100644 --- a/tests/rest/client/test_receipts.py +++ b/tests/rest/client/test_receipts.py @@ -21,7 +21,7 @@ from http import HTTPStatus from typing import Optional -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor import synapse.rest.admin from synapse.api.constants import EduTypes, EventTypes, HistoryVisibility, ReceiptTypes diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py index b25e184786..d435a9e393 100644 --- a/tests/rest/client/test_redactions.py +++ b/tests/rest/client/test_redactions.py @@ -22,7 +22,7 @@ from typing import List, Optional from parameterized import parameterized -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.constants import EventTypes, RelationTypes from synapse.api.room_versions import RoomVersion, RoomVersions diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py index 638fbf0062..f0745cf298 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py @@ -26,7 +26,7 @@ from unittest.mock import AsyncMock import pkg_resources -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor import synapse.rest.admin from synapse.api.constants import ( diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 8f2f44739c..fd1e87296c 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -23,7 +23,7 @@ import urllib.parse from typing import Any, Callable, Dict, List, Optional, Tuple from unittest.mock 
import AsyncMock, patch -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.constants import AccountDataTypes, EventTypes, RelationTypes from synapse.rest import admin diff --git a/tests/rest/client/test_rendezvous.py b/tests/rest/client/test_rendezvous.py index 83a5cbdc15..01401f73da 100644 --- a/tests/rest/client/test_rendezvous.py +++ b/tests/rest/client/test_rendezvous.py @@ -22,7 +22,7 @@ from typing import Dict from urllib.parse import urlparse -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from twisted.web.resource import Resource from synapse.rest.client import rendezvous diff --git a/tests/rest/client/test_reporting.py b/tests/rest/client/test_reporting.py index 37517e5c21..5e5af34b42 100644 --- a/tests/rest/client/test_reporting.py +++ b/tests/rest/client/test_reporting.py @@ -20,7 +20,7 @@ # from typing import Optional -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor import synapse.rest.admin from synapse.rest.client import login, reporting, room diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py index 1e5a1b0a4d..24b007f779 100644 --- a/tests/rest/client/test_retention.py +++ b/tests/rest/client/test_retention.py @@ -20,7 +20,7 @@ from typing import Any, Dict from unittest.mock import Mock -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.constants import EventTypes from synapse.rest import admin diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 48d33b8e17..24a28fbdd2 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -31,7 +31,7 @@ from urllib import parse as urlparse from parameterized import param, parameterized -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor import synapse.rest.admin from synapse.api.constants import ( @@ -43,8 +43,9 @@ from synapse.api.constants import ( RoomTypes, ) from synapse.api.errors import Codes, HttpResponseException +from synapse.api.room_versions import RoomVersions from synapse.appservice import ApplicationService -from synapse.events import EventBase +from synapse.events import EventBase, make_event_from_dict from synapse.events.snapshot import EventContext from synapse.rest import admin from synapse.rest.client import ( @@ -4499,3 +4500,985 @@ class RoomParticipantTestCase(unittest.HomeserverTestCase): self.store.get_room_participation(self.user2, self.room1) ) self.assertFalse(participant) + + +class MSC4293RedactOnBanKickTestCase(unittest.FederatingHomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + admin.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + super().prepare(reactor, clock, hs) + self.creator = self.register_user("creator", "test") + self.creator_tok = self.login("creator", "test") + + self.bad_user_id = self.register_user("bad", "test") + self.bad_tok = self.login("bad", "test") + + self.room_id = self.helper.create_room_as(self.creator, tok=self.creator_tok) + + self.store = hs.get_datastores().main + self._storage_controllers = hs.get_storage_controllers() + + self.federation_event_handler = self.hs.get_federation_event_handler() + + 
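# enable the experimental MSC4293 (redact on ban/kick) behaviour exercised by these tests +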
self.hs.config.experimental.msc4293_enabled = True + + def _check_redactions( + self, + original_events: List[EventBase], + pulled_events: List[JsonDict], + expect_redaction: bool, + reason: Optional[str] = None, + ) -> None: + """ + Checks a set of original events against a second set of the same events, pulled + from the /messages api. If expect_redaction is true, we expect that the second + set of events will be redacted, and the test will fail if that is not the case. + Otherwise, verifies that the events have not been redacted and fails if not. + + Args: + original_events: A list of the original events sent + pulled_events: A list of the same events as the original events, fetched + over the /messages api + expect_redaction: Whether or not the pulled_events should be redacted + reason: If the events are expected to be redacted, the expected reason + for the redaction + + """ + if expect_redaction: + redacted_count = 0 + for pulled_event in pulled_events: + for old_event in original_events: + if pulled_event["event_id"] != old_event.event_id: + continue + # we have a matching event, check that it is redacted + event_content = pulled_event["content"] + if event_content: + self.fail(f"Expected event {pulled_event} to be redacted") + redacting_event = pulled_event.get("redacted_because") + if not redacting_event: + self.fail( + f"Expected event {pulled_event} to have a redacting event." + ) + # check that the redacting event records the expected reason, and the + # redact_events flag + content = redacting_event["content"] + self.assertEqual(content["reason"], reason) + self.assertEqual(content["org.matrix.msc4293.redact_events"], True) + redacted_count += 1 + # all provided events should be redacted + self.assertEqual(len(original_events), redacted_count) + else: + unredacted_events = 0 + for pulled_event in pulled_events: + for old_event in original_events: + if pulled_event["event_id"] != old_event.event_id: + continue + # we have a matching event, make sure it is not redacted + redacted_because = pulled_event.get("redacted_because") + if redacted_because: + self.fail("Event should not have been redacted") + self.assertEqual(old_event.content, pulled_event["content"]) + unredacted_events += 1 + # all provided events should not have been redacted + self.assertEqual(unredacted_events, len(original_events)) + + def test_banning_local_member_with_flag_redacts_their_events(self) -> None: + self.helper.join(self.room_id, self.bad_user_id, tok=self.bad_tok) + + # bad user sends some messages + originals = [] + for i in range(5): + event = {"body": f"bothersome noise {i}", "msgtype": "m.text"} + res = self.helper.send_event( + self.room_id, "m.room.message", event, tok=self.bad_tok, expect_code=200 + ) + originals.append(res["event_id"]) + + # grab original events for comparison + original_events = [self.get_success(self.store.get_event(x)) for x in originals] + + # creator bans user with redaction flag set + content = { + "reason": "flooding", + "org.matrix.msc4293.redact_events": True, + } + self.helper.change_membership( + self.room_id, + self.creator, + self.bad_user_id, + "ban", + content, + self.creator_tok, + ) + + filter = json.dumps({"types": [EventTypes.Message]}) + channel = self.make_request( + "GET", + f"rooms/{self.room_id}/messages?filter={filter}&limit=50", + access_token=self.creator_tok, + ) + self.assertEqual(channel.code, 200) + self._check_redactions( + original_events, + channel.json_body["chunk"], + expect_redaction=True, + reason="flooding", + ) + + def 
test_banning_remote_member_with_flag_redacts_their_events(self) -> None: + bad_user = "@remote_bad_user:" + self.OTHER_SERVER_NAME + channel = self.make_signed_federation_request( + "GET", + f"/_matrix/federation/v1/make_join/{self.room_id}/{bad_user}?ver=10", + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) + join_result = channel.json_body + + join_event_dict = join_result["event"] + self.add_hashes_and_signatures_from_other_server( + join_event_dict, + RoomVersions.V10, + ) + channel = self.make_signed_federation_request( + "PUT", + f"/_matrix/federation/v2/send_join/{self.room_id}/x", + content=join_event_dict, + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) + + # the room should show that the bad user is a member + r = self.get_success( + self._storage_controllers.state.get_current_state(self.room_id) + ) + self.assertEqual(r[("m.room.member", bad_user)].membership, "join") + + auth_ids = [ + r[("m.room.create", "")].event_id, + r[("m.room.power_levels", "")].event_id, + r[("m.room.member", "@remote_bad_user:other.example.com")].event_id, + ] + original_messages = [] + for i in range(5): + remote_message = make_event_from_dict( + self.add_hashes_and_signatures_from_other_server( + { + "room_id": self.room_id, + "sender": bad_user, + "depth": 1000, + "origin_server_ts": 1, + "type": "m.room.message", + "content": {"body": f"remote bummer{i}"}, + "auth_events": auth_ids, + "prev_events": auth_ids, + } + ), + room_version=RoomVersions.V10, + ) + + self.get_success( + self.federation_event_handler.on_receive_pdu( + self.OTHER_SERVER_NAME, remote_message + ) + ) + original_messages.append(remote_message) + + # creator bans bad user with redaction flag set + content = { + "reason": "bummer messages", + "org.matrix.msc4293.redact_events": True, + } + res = self.helper.change_membership( + self.room_id, self.creator, bad_user, "ban", content, self.creator_tok + ) + ban_event_id = res["event_id"] + + filter = json.dumps({"types": [EventTypes.Message]}) + channel = self.make_request( + "GET", + f"rooms/{self.room_id}/messages?filter={filter}&limit=50", + access_token=self.creator_tok, + ) + self.assertEqual(channel.code, 200) + self._check_redactions( + original_messages, + channel.json_body["chunk"], + expect_redaction=True, + reason="bummer messages", + ) + + # any future messages that are soft-failed are also redacted - send messages referencing + # dag before ban, they should be soft-failed but also redacted + new_original_messages = [] + for i in range(5): + remote_message = make_event_from_dict( + self.add_hashes_and_signatures_from_other_server( + { + "room_id": self.room_id, + "sender": bad_user, + "depth": 1000, + "origin_server_ts": 1, + "type": "m.room.message", + "content": {"body": f"soft-fail remote bummer{i}"}, + "auth_events": auth_ids, + "prev_events": auth_ids, + } + ), + room_version=RoomVersions.V10, + ) + + self.get_success( + self.federation_event_handler.on_receive_pdu( + self.OTHER_SERVER_NAME, remote_message + ) + ) + new_original_messages.append(remote_message) + + # pull them from the db to check because they should be soft-failed and thus not available over + # cs-api + for message in new_original_messages: + original = self.get_success(self.store.get_event(message.event_id)) + if not original: + self.fail("Expected to find remote message in DB") + redacted_because = original.unsigned.get("redacted_because") + if not redacted_because: + self.fail("Did not find redacted_because field") + 
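# the ban membership event should be recorded as the event that redacted the soft-failed message +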
self.assertEqual(redacted_because.event_id, ban_event_id) + + def test_unbanning_remote_user_stops_redaction_action(self) -> None: + bad_user = "@remote_bad_user:" + self.OTHER_SERVER_NAME + channel = self.make_signed_federation_request( + "GET", + f"/_matrix/federation/v1/make_join/{self.room_id}/{bad_user}?ver=10", + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) + join_result = channel.json_body + + join_event_dict = join_result["event"] + self.add_hashes_and_signatures_from_other_server( + join_event_dict, + RoomVersions.V10, + ) + channel = self.make_signed_federation_request( + "PUT", + f"/_matrix/federation/v2/send_join/{self.room_id}/x", + content=join_event_dict, + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) + + # the room should show that the bad user is a member + r = self.get_success( + self._storage_controllers.state.get_current_state(self.room_id) + ) + self.assertEqual(r[("m.room.member", bad_user)].membership, "join") + + auth_ids = [ + r[("m.room.create", "")].event_id, + r[("m.room.power_levels", "")].event_id, + r[("m.room.member", "@remote_bad_user:other.example.com")].event_id, + ] + original_messages = [] + for i in range(5): + remote_message = make_event_from_dict( + self.add_hashes_and_signatures_from_other_server( + { + "room_id": self.room_id, + "sender": bad_user, + "depth": 1000, + "origin_server_ts": 1, + "type": "m.room.message", + "content": {"body": f"annoying messages {i}"}, + "auth_events": auth_ids, + "prev_events": auth_ids, + } + ), + room_version=RoomVersions.V10, + ) + + self.get_success( + self.federation_event_handler.on_receive_pdu( + self.OTHER_SERVER_NAME, remote_message + ) + ) + original_messages.append(remote_message) + + # creator bans bad user with redaction flag set + content = { + "reason": "this dude sucks", + "org.matrix.msc4293.redact_events": True, + } + self.helper.change_membership( + self.room_id, self.creator, bad_user, "ban", content, self.creator_tok + ) + + filter = json.dumps({"types": [EventTypes.Message]}) + channel = self.make_request( + "GET", + f"rooms/{self.room_id}/messages?filter={filter}&limit=50", + access_token=self.creator_tok, + ) + self.assertEqual(channel.code, 200) + self._check_redactions( + original_messages, + channel.json_body["chunk"], + True, + reason="this dude sucks", + ) + + # unban user + self.helper.change_membership( + self.room_id, self.creator, bad_user, "unban", {}, self.creator_tok + ) + + # user should be able to join again + channel = self.make_signed_federation_request( + "GET", + f"/_matrix/federation/v1/make_join/{self.room_id}/{bad_user}?ver=10", + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) + join_result = channel.json_body + + join_event_dict = join_result["event"] + self.add_hashes_and_signatures_from_other_server( + join_event_dict, + RoomVersions.V10, + ) + channel = self.make_signed_federation_request( + "PUT", + f"/_matrix/federation/v2/send_join/{self.room_id}/x", + content=join_event_dict, + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) + + # the room should show that the bad user is a member again + new_state = self.get_success( + self._storage_controllers.state.get_current_state(self.room_id) + ) + self.assertEqual(new_state[("m.room.member", bad_user)].membership, "join") + + new_state = self.get_success( + self._storage_controllers.state.get_current_state(self.room_id) + ) + auth_ids = [ + new_state[("m.room.create", "")].event_id, + new_state[("m.room.power_levels", "")].event_id, + 
new_state[("m.room.member", "@remote_bad_user:other.example.com")].event_id, + ] + + # messages after unban and join proceed unredacted + new_original_messages = [] + for i in range(5): + remote_message = make_event_from_dict( + self.add_hashes_and_signatures_from_other_server( + { + "room_id": self.room_id, + "sender": bad_user, + "depth": 1000, + "origin_server_ts": 1, + "type": "m.room.message", + "content": {"body": f"no longer a bummer {i}"}, + "auth_events": auth_ids, + "prev_events": auth_ids, + } + ), + room_version=RoomVersions.V10, + ) + + self.get_success( + self.federation_event_handler.on_receive_pdu( + self.OTHER_SERVER_NAME, remote_message + ) + ) + new_original_messages.append(remote_message) + + filter = json.dumps({"types": [EventTypes.Message]}) + channel = self.make_request( + "GET", + f"rooms/{self.room_id}/messages?filter={filter}&limit=50", + access_token=self.creator_tok, + ) + self.assertEqual(channel.code, 200) + self._check_redactions(new_original_messages, channel.json_body["chunk"], False) + + def test_redaction_flag_ignored_for_user_if_banner_lacks_redaction_power( + self, + ) -> None: + # change power levels so creator can ban but not redact + self.helper.send_state( + self.room_id, + "m.room.power_levels", + {"events_default": 0, "redact": 100, "users": {self.creator: 75}}, + tok=self.creator_tok, + ) + self.helper.join(self.room_id, self.bad_user_id, tok=self.bad_tok) + + # bad user sends some messages + original_ids = [] + for i in range(15): + event = {"body": f"being a menace {i}", "msgtype": "m.text"} + res = self.helper.send_event( + self.room_id, "m.room.message", event, tok=self.bad_tok, expect_code=200 + ) + original_ids.append(res["event_id"]) + + # grab original events before ban + originals = [self.get_success(self.store.get_event(x)) for x in original_ids] + + # creator bans bad user with redaction flag + content = { + "reason": "flooding", + "org.matrix.msc4293.redact_events": True, + } + self.helper.change_membership( + self.room_id, + self.creator, + self.bad_user_id, + "ban", + content, + self.creator_tok, + ) + + filter = json.dumps({"types": [EventTypes.Message]}) + channel = self.make_request( + "GET", + f"rooms/{self.room_id}/messages?filter={filter}&limit=50", + access_token=self.creator_tok, + ) + self.assertEqual(channel.code, 200) + # messages are not redacted + self._check_redactions(originals, channel.json_body["chunk"], False) + + def test_kicking_local_member_with_flag_redacts_their_events(self) -> None: + self.helper.join(self.room_id, self.bad_user_id, tok=self.bad_tok) + + # bad user sends some messages + originals = [] + for i in range(5): + event = {"body": f"bothersome noise {i}", "msgtype": "m.text"} + res = self.helper.send_event( + self.room_id, "m.room.message", event, tok=self.bad_tok, expect_code=200 + ) + originals.append(res["event_id"]) + + # grab original events for comparison + original_events = [self.get_success(self.store.get_event(x)) for x in originals] + + # creator kicks user with redaction flag set + content = { + "reason": "flooding", + "org.matrix.msc4293.redact_events": True, + } + self.helper.change_membership( + self.room_id, + self.creator, + self.bad_user_id, + "kick", + content, + self.creator_tok, + ) + + filter = json.dumps({"types": [EventTypes.Message]}) + channel = self.make_request( + "GET", + f"rooms/{self.room_id}/messages?filter={filter}&limit=50", + access_token=self.creator_tok, + ) + self.assertEqual(channel.code, 200) + self._check_redactions( + original_events, + 
channel.json_body["chunk"], + expect_redaction=True, + reason="flooding", + ) + + def test_kicking_remote_member_with_flag_redacts_their_events(self) -> None: + bad_user = "@remote_bad_user:" + self.OTHER_SERVER_NAME + channel = self.make_signed_federation_request( + "GET", + f"/_matrix/federation/v1/make_join/{self.room_id}/{bad_user}?ver=10", + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) + join_result = channel.json_body + + join_event_dict = join_result["event"] + self.add_hashes_and_signatures_from_other_server( + join_event_dict, + RoomVersions.V10, + ) + channel = self.make_signed_federation_request( + "PUT", + f"/_matrix/federation/v2/send_join/{self.room_id}/x", + content=join_event_dict, + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) + + # the room should show that the bad user is a member + r = self.get_success( + self._storage_controllers.state.get_current_state(self.room_id) + ) + self.assertEqual(r[("m.room.member", bad_user)].membership, "join") + + auth_ids = [ + r[("m.room.create", "")].event_id, + r[("m.room.power_levels", "")].event_id, + r[("m.room.member", "@remote_bad_user:other.example.com")].event_id, + ] + original_messages = [] + for i in range(5): + remote_message = make_event_from_dict( + self.add_hashes_and_signatures_from_other_server( + { + "room_id": self.room_id, + "sender": bad_user, + "depth": 1000, + "origin_server_ts": 1, + "type": "m.room.message", + "content": {"body": f"remote bummer{i}"}, + "auth_events": auth_ids, + "prev_events": auth_ids, + } + ), + room_version=RoomVersions.V10, + ) + + self.get_success( + self.federation_event_handler.on_receive_pdu( + self.OTHER_SERVER_NAME, remote_message + ) + ) + original_messages.append(remote_message) + + # creator kicks bad user with redaction flag set + content = { + "reason": "bummer messages", + "org.matrix.msc4293.redact_events": True, + } + res = self.helper.change_membership( + self.room_id, self.creator, bad_user, "kick", content, self.creator_tok + ) + ban_event_id = res["event_id"] + + filter = json.dumps({"types": [EventTypes.Message]}) + channel = self.make_request( + "GET", + f"rooms/{self.room_id}/messages?filter={filter}&limit=50", + access_token=self.creator_tok, + ) + self.assertEqual(channel.code, 200) + self._check_redactions( + original_messages, + channel.json_body["chunk"], + expect_redaction=True, + reason="bummer messages", + ) + + # any future messages that are soft-failed are also redacted - send messages referencing + # dag before ban, they should be soft-failed but also redacted + new_original_messages = [] + for i in range(5): + remote_message = make_event_from_dict( + self.add_hashes_and_signatures_from_other_server( + { + "room_id": self.room_id, + "sender": bad_user, + "depth": 1000, + "origin_server_ts": 1, + "type": "m.room.message", + "content": {"body": f"soft-fail remote bummer{i}"}, + "auth_events": auth_ids, + "prev_events": auth_ids, + } + ), + room_version=RoomVersions.V10, + ) + + self.get_success( + self.federation_event_handler.on_receive_pdu( + self.OTHER_SERVER_NAME, remote_message + ) + ) + new_original_messages.append(remote_message) + + # pull them from the db to check because they should be soft-failed and thus not available over + # cs-api + for message in new_original_messages: + original = self.get_success(self.store.get_event(message.event_id)) + if not original: + self.fail("Expected to find remote message in DB") + self.assertEqual(original.unsigned["redacted_by"], ban_event_id) + + def 
test_rejoining_kicked_remote_user_stops_redaction_action(self) -> None: + bad_user = "@remote_bad_user:" + self.OTHER_SERVER_NAME + channel = self.make_signed_federation_request( + "GET", + f"/_matrix/federation/v1/make_join/{self.room_id}/{bad_user}?ver=10", + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) + join_result = channel.json_body + + join_event_dict = join_result["event"] + self.add_hashes_and_signatures_from_other_server( + join_event_dict, + RoomVersions.V10, + ) + channel = self.make_signed_federation_request( + "PUT", + f"/_matrix/federation/v2/send_join/{self.room_id}/x", + content=join_event_dict, + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) + + # the room should show that the bad user is a member + r = self.get_success( + self._storage_controllers.state.get_current_state(self.room_id) + ) + self.assertEqual(r[("m.room.member", bad_user)].membership, "join") + + auth_ids = [ + r[("m.room.create", "")].event_id, + r[("m.room.power_levels", "")].event_id, + r[("m.room.member", "@remote_bad_user:other.example.com")].event_id, + ] + original_messages = [] + for i in range(5): + remote_message = make_event_from_dict( + self.add_hashes_and_signatures_from_other_server( + { + "room_id": self.room_id, + "sender": bad_user, + "depth": 1000, + "origin_server_ts": 1, + "type": "m.room.message", + "content": {"body": f"annoying messages {i}"}, + "auth_events": auth_ids, + "prev_events": auth_ids, + } + ), + room_version=RoomVersions.V10, + ) + + self.get_success( + self.federation_event_handler.on_receive_pdu( + self.OTHER_SERVER_NAME, remote_message + ) + ) + original_messages.append(remote_message) + + # creator kicks bad user with redaction flag set + content = { + "reason": "this dude sucks", + "org.matrix.msc4293.redact_events": True, + } + self.helper.change_membership( + self.room_id, self.creator, bad_user, "kick", content, self.creator_tok + ) + + filter = json.dumps({"types": [EventTypes.Message]}) + channel = self.make_request( + "GET", + f"rooms/{self.room_id}/messages?filter={filter}&limit=50", + access_token=self.creator_tok, + ) + self.assertEqual(channel.code, 200) + self._check_redactions( + original_messages, + channel.json_body["chunk"], + True, + reason="this dude sucks", + ) + + # user re-joins after kick + channel = self.make_signed_federation_request( + "GET", + f"/_matrix/federation/v1/make_join/{self.room_id}/{bad_user}?ver=10", + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) + join_result = channel.json_body + + join_event_dict = join_result["event"] + self.add_hashes_and_signatures_from_other_server( + join_event_dict, + RoomVersions.V10, + ) + channel = self.make_signed_federation_request( + "PUT", + f"/_matrix/federation/v2/send_join/{self.room_id}/x", + content=join_event_dict, + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) + + # the room should show that the bad user is a member again + new_state = self.get_success( + self._storage_controllers.state.get_current_state(self.room_id) + ) + self.assertEqual(new_state[("m.room.member", bad_user)].membership, "join") + + new_state = self.get_success( + self._storage_controllers.state.get_current_state(self.room_id) + ) + auth_ids = [ + new_state[("m.room.create", "")].event_id, + new_state[("m.room.power_levels", "")].event_id, + new_state[("m.room.member", "@remote_bad_user:other.example.com")].event_id, + ] + + # messages after kick and re-join proceed unredacted + new_original_messages = [] + for i in range(5): + 
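# build and sign a message as the remote user, then feed it to the federation event handler +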
remote_message = make_event_from_dict( + self.add_hashes_and_signatures_from_other_server( + { + "room_id": self.room_id, + "sender": bad_user, + "depth": 1000, + "origin_server_ts": 1, + "type": "m.room.message", + "content": {"body": f"no longer a bummer {i}"}, + "auth_events": auth_ids, + "prev_events": auth_ids, + } + ), + room_version=RoomVersions.V10, + ) + + self.get_success( + self.federation_event_handler.on_receive_pdu( + self.OTHER_SERVER_NAME, remote_message + ) + ) + new_original_messages.append(remote_message) + + filter = json.dumps({"types": [EventTypes.Message]}) + channel = self.make_request( + "GET", + f"rooms/{self.room_id}/messages?filter={filter}&limit=50", + access_token=self.creator_tok, + ) + self.assertEqual(channel.code, 200) + self._check_redactions(new_original_messages, channel.json_body["chunk"], False) + + def test_redaction_flag_ignored_for_user_if_kicker_lacks_redaction_power( + self, + ) -> None: + # change power levels so creator can kick but not redact + self.helper.send_state( + self.room_id, + "m.room.power_levels", + {"events_default": 0, "redact": 100, "users": {self.creator: 75}}, + tok=self.creator_tok, + ) + self.helper.join(self.room_id, self.bad_user_id, tok=self.bad_tok) + + # bad user sends some messages + original_ids = [] + for i in range(15): + event = {"body": f"being a menace {i}", "msgtype": "m.text"} + res = self.helper.send_event( + self.room_id, "m.room.message", event, tok=self.bad_tok, expect_code=200 + ) + original_ids.append(res["event_id"]) + + # grab original events before ban + originals = [self.get_success(self.store.get_event(x)) for x in original_ids] + + # creator kicks bad user with redaction flag + content = { + "reason": "flooding", + "org.matrix.msc4293.redact_events": True, + } + self.helper.change_membership( + self.room_id, + self.creator, + self.bad_user_id, + "kick", + content, + self.creator_tok, + ) + + filter = json.dumps({"types": [EventTypes.Message]}) + channel = self.make_request( + "GET", + f"rooms/{self.room_id}/messages?filter={filter}&limit=50", + access_token=self.creator_tok, + ) + self.assertEqual(channel.code, 200) + # messages are not redacted + self._check_redactions(originals, channel.json_body["chunk"], False) + + def test_MSC4293_flag_ignored_in_other_membership_events(self) -> None: + self.helper.join(self.room_id, self.bad_user_id, tok=self.bad_tok) + + # bad user sends some messages + original_ids = [] + for i in range(15): + event = {"body": f"being a menace {i}", "msgtype": "m.text"} + res = self.helper.send_event( + self.room_id, "m.room.message", event, tok=self.bad_tok, expect_code=200 + ) + original_ids.append(res["event_id"]) + + # grab original events before ban + originals = [self.get_success(self.store.get_event(x)) for x in original_ids] + + # bad user leaves on their own with flag + content = { + "org.matrix.msc4293.redact_events": True, + } + self.helper.change_membership( + self.room_id, + self.bad_user_id, + self.bad_user_id, + "leave", + content, + self.bad_tok, + ) + + # their messages are not redacted + filter = json.dumps({"types": [EventTypes.Message]}) + channel = self.make_request( + "GET", + f"rooms/{self.room_id}/messages?filter={filter}&limit=50", + access_token=self.creator_tok, + ) + self.assertEqual(channel.code, 200) + self._check_redactions(originals, channel.json_body["chunk"], False) + + # bad user is invited with flag in invite event + content = { + "org.matrix.msc4293.redact_events": True, + } + self.helper.change_membership( + self.room_id, + self.creator, 
+ self.bad_user_id, + "invite", + content, + self.creator_tok, + ) + + # their messages are still not redacted + filter = json.dumps({"types": [EventTypes.Message]}) + channel = self.make_request( + "GET", + f"rooms/{self.room_id}/messages?filter={filter}&limit=50", + access_token=self.creator_tok, + ) + self.assertEqual(channel.code, 200) + self._check_redactions(originals, channel.json_body["chunk"], False) + + # bad user joins with flag in join event + content = { + "org.matrix.msc4293.redact_events": True, + } + self.helper.change_membership( + self.room_id, + self.bad_user_id, + self.bad_user_id, + "join", + content, + self.bad_tok, + ) + + # and still their messages are not redacted + filter = json.dumps({"types": [EventTypes.Message]}) + channel = self.make_request( + "GET", + f"rooms/{self.room_id}/messages?filter={filter}&limit=50", + access_token=self.creator_tok, + ) + self.assertEqual(channel.code, 200) + self._check_redactions(originals, channel.json_body["chunk"], False) + + def test_MSC4293_redaction_applied_via_kick_api(self) -> None: + """ + Test that the MSC4293 field is passed through and applied when using /kick + """ + self.helper.join(self.room_id, self.bad_user_id, tok=self.bad_tok) + + # bad user sends some messages + original_ids = [] + for i in range(15): + event = {"body": f"being a menace {i}", "msgtype": "m.text"} + res = self.helper.send_event( + self.room_id, "m.room.message", event, tok=self.bad_tok, expect_code=200 + ) + original_ids.append(res["event_id"]) + + # grab original events before kick + originals = [self.get_success(self.store.get_event(x)) for x in original_ids] + + channel = self.make_request( + "POST", + f"/_matrix/client/v3/rooms/{self.room_id}/kick", + access_token=self.creator_tok, + content={ + "reason": "being annoying", + "org.matrix.msc4293.redact_events": True, + "user_id": self.bad_user_id, + }, + shorthand=False, + ) + self.assertEqual(channel.code, 200) + + filter = json.dumps({"types": [EventTypes.Message]}) + channel = self.make_request( + "GET", + f"rooms/{self.room_id}/messages?filter={filter}&limit=50", + access_token=self.creator_tok, + ) + self.assertEqual(channel.code, 200) + self._check_redactions( + originals, + channel.json_body["chunk"], + expect_redaction=True, + reason="being annoying", + ) + + def test_MSC4293_redaction_applied_via_ban_api(self) -> None: + """ + Test that the MSC4293 field is passed through and applied when using /ban + """ + self.helper.join(self.room_id, self.bad_user_id, tok=self.bad_tok) + + # bad user sends some messages + original_ids = [] + for i in range(15): + event = {"body": f"being a menace {i}", "msgtype": "m.text"} + res = self.helper.send_event( + self.room_id, "m.room.message", event, tok=self.bad_tok, expect_code=200 + ) + original_ids.append(res["event_id"]) + + # grab original events before ban + originals = [self.get_success(self.store.get_event(x)) for x in original_ids] + + channel = self.make_request( + "POST", + f"/_matrix/client/v3/rooms/{self.room_id}/ban", + access_token=self.creator_tok, + content={ + "reason": "being disruptive", + "org.matrix.msc4293.redact_events": True, + "user_id": self.bad_user_id, + }, + shorthand=False, + ) + self.assertEqual(channel.code, 200) + + filter = json.dumps({"types": [EventTypes.Message]}) + channel = self.make_request( + "GET", + f"rooms/{self.room_id}/messages?filter={filter}&limit=50", + access_token=self.creator_tok, + ) + self.assertEqual(channel.code, 200) + self._check_redactions( + originals, + channel.json_body["chunk"], + 
expect_redaction=True, + reason="being disruptive", + ) diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py index 2287f233b4..b990a8600b 100644 --- a/tests/rest/client/test_shadow_banned.py +++ b/tests/rest/client/test_shadow_banned.py @@ -21,7 +21,7 @@ from unittest.mock import Mock, patch -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor import synapse.rest.admin from synapse.api.constants import EduTypes, EventTypes diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py index c52a5b2e79..e612df3be9 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py @@ -24,7 +24,7 @@ from typing import List from parameterized import parameterized, parameterized_class -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor import synapse.rest.admin from synapse.api.constants import ( diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py index d10df1a90f..f14ca8237a 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py @@ -22,7 +22,7 @@ import threading from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union from unittest.mock import AsyncMock, Mock -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.constants import EventTypes, LoginType, Membership from synapse.api.errors import SynapseError diff --git a/tests/rest/client/test_thread_subscriptions.py b/tests/rest/client/test_thread_subscriptions.py index a5c38753cb..3fbf3c5bfa 100644 --- a/tests/rest/client/test_thread_subscriptions.py +++ b/tests/rest/client/test_thread_subscriptions.py @@ -13,8 +13,9 @@ from http import HTTPStatus -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor +from synapse.api.errors import Codes from synapse.rest import admin from synapse.rest.client import login, profile, room, thread_subscriptions from synapse.server import HomeServer @@ -49,15 +50,16 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase): # Create a room and send a message to use as a thread root self.room_id = self.helper.create_room_as(self.user_id, tok=self.token) self.helper.join(self.room_id, self.other_user_id, tok=self.other_token) - response = self.helper.send(self.room_id, body="Root message", tok=self.token) - self.root_event_id = response["event_id"] + (self.root_event_id,) = self.helper.send_messages( + self.room_id, 1, tok=self.token + ) # Send a message in the thread - self.helper.send_event( - room_id=self.room_id, - type="m.room.message", - content={ - "body": "Thread message", + self.threaded_events = self.helper.send_messages( + self.room_id, + 2, + content_fn=lambda idx: { + "body": f"Thread message {idx}", "msgtype": "m.text", "m.relates_to": { "rel_type": "m.thread", @@ -106,9 +108,7 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase): channel = self.make_request( "PUT", f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription", - { - "automatic": False, - }, + {}, access_token=self.token, ) self.assertEqual(channel.code, HTTPStatus.OK) @@ -127,7 +127,7 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase): channel = self.make_request( "PUT", f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription", - {"automatic": True}, + {"automatic": 
self.threaded_events[0]}, access_token=self.token, ) self.assertEqual(channel.code, HTTPStatus.OK) @@ -148,11 +148,11 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase): "PUT", f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription", { - "automatic": True, + "automatic": self.threaded_events[0], }, access_token=self.token, ) - self.assertEqual(channel.code, HTTPStatus.OK) + self.assertEqual(channel.code, HTTPStatus.OK, channel.text_body) # Assert the subscription was saved channel = self.make_request( @@ -167,7 +167,7 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase): channel = self.make_request( "PUT", f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription", - {"automatic": False}, + {}, access_token=self.token, ) self.assertEqual(channel.code, HTTPStatus.OK) @@ -187,7 +187,7 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase): "PUT", f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription", { - "automatic": True, + "automatic": self.threaded_events[0], }, access_token=self.token, ) @@ -202,7 +202,6 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, HTTPStatus.OK) self.assertEqual(channel.json_body, {"automatic": True}) - # Now also register a manual subscription channel = self.make_request( "DELETE", f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription", @@ -210,7 +209,6 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, HTTPStatus.OK) - # Assert the manual subscription was not overridden channel = self.make_request( "GET", f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription", @@ -224,7 +222,7 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase): channel = self.make_request( "PUT", f"{PREFIX}/{self.room_id}/thread/$nonexistent:example.org/subscription", - {"automatic": True}, + {}, access_token=self.token, ) self.assertEqual(channel.code, HTTPStatus.NOT_FOUND) @@ -238,7 +236,7 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase): channel = self.make_request( "PUT", f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription", - {"automatic": True}, + {}, access_token=no_access_token, ) self.assertEqual(channel.code, HTTPStatus.NOT_FOUND) @@ -249,8 +247,105 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase): channel = self.make_request( "PUT", f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription", - # non-boolean `automatic` - {"automatic": "true"}, + # non-Event ID `automatic` + {"automatic": True}, access_token=self.token, ) self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST) + + channel = self.make_request( + "PUT", + f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription", + # non-Event ID `automatic` + {"automatic": "$malformedEventId"}, + access_token=self.token, + ) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST) + + def test_auto_subscribe_cause_event_not_in_thread(self) -> None: + """ + Test making an automatic subscription, where the cause event is not + actually in the thread. + This is an error. 
+ """ + (unrelated_event_id,) = self.helper.send_messages( + self.room_id, 1, tok=self.token + ) + channel = self.make_request( + "PUT", + f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription", + {"automatic": unrelated_event_id}, + access_token=self.token, + ) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.text_body) + self.assertEqual(channel.json_body["errcode"], Codes.MSC4306_NOT_IN_THREAD) + + def test_auto_resubscription_conflict(self) -> None: + """ + Test that an automatic subscription that conflicts with an unsubscription + is skipped. + """ + # Reuse the test that subscribes and unsubscribes + self.test_unsubscribe() + + # Now no matter which event we present as the cause of an automatic subscription, + # the automatic subscription is skipped. + # This is because the unsubscription happened after all of the events. + for event in self.threaded_events: + channel = self.make_request( + "PUT", + f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription", + { + "automatic": event, + }, + access_token=self.token, + ) + self.assertEqual(channel.code, HTTPStatus.CONFLICT, channel.text_body) + self.assertEqual( + channel.json_body["errcode"], + Codes.MSC4306_CONFLICTING_UNSUBSCRIPTION, + channel.text_body, + ) + + # Check the subscription was not made + channel = self.make_request( + "GET", + f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription", + access_token=self.token, + ) + self.assertEqual(channel.code, HTTPStatus.NOT_FOUND) + + # But if a new event is sent after the unsubscription took place, + # that one can be used for an automatic subscription + (later_event_id,) = self.helper.send_messages( + self.room_id, + 1, + content_fn=lambda _: { + "body": "Thread message after unsubscription", + "msgtype": "m.text", + "m.relates_to": { + "rel_type": "m.thread", + "event_id": self.root_event_id, + }, + }, + tok=self.token, + ) + + channel = self.make_request( + "PUT", + f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription", + { + "automatic": later_event_id, + }, + access_token=self.token, + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.text_body) + + # Check the subscription was made + channel = self.make_request( + "GET", + f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription", + access_token=self.token, + ) + self.assertEqual(channel.code, HTTPStatus.OK) + self.assertEqual(channel.json_body, {"automatic": True}) diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py index af1eecbb34..5f42acb391 100644 --- a/tests/rest/client/test_transactions.py +++ b/tests/rest/client/test_transactions.py @@ -90,7 +90,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase): ) -> Generator["defer.Deferred[Any]", object, None]: @defer.inlineCallbacks def cb() -> Generator["defer.Deferred[object]", object, Tuple[int, JsonDict]]: - yield Clock(reactor).sleep(0) + yield defer.ensureDeferred(Clock(reactor).sleep(0)) return 1, {} @defer.inlineCallbacks diff --git a/tests/rest/client/test_typing.py b/tests/rest/client/test_typing.py index 805c49b540..ce2504156c 100644 --- a/tests/rest/client/test_typing.py +++ b/tests/rest/client/test_typing.py @@ -21,7 +21,7 @@ """Tests REST events for /rooms paths.""" -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.constants import EduTypes from synapse.rest.client import room diff --git a/tests/rest/client/test_upgrade_room.py 
b/tests/rest/client/test_upgrade_room.py index dbc493e970..66fddc5475 100644 --- a/tests/rest/client/test_upgrade_room.py +++ b/tests/rest/client/test_upgrade_room.py @@ -21,7 +21,7 @@ from typing import Optional from unittest.mock import patch -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.constants import EventContentFields, EventTypes, Membership, RoomTypes from synapse.config.server import DEFAULT_ROOM_VERSION diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index 280486da08..bb214759d9 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -29,12 +29,14 @@ from http import HTTPStatus from typing import ( Any, AnyStr, + Callable, Dict, Iterable, Literal, Mapping, MutableMapping, Optional, + Sequence, Tuple, overload, ) @@ -42,10 +44,10 @@ from urllib.parse import urlencode import attr -from twisted.test.proto_helpers import MemoryReactorClock +from twisted.internet.testing import MemoryReactorClock from twisted.web.server import Site -from synapse.api.constants import Membership, ReceiptTypes +from synapse.api.constants import EventTypes, Membership, ReceiptTypes from synapse.api.errors import Codes from synapse.server import HomeServer from synapse.types import JsonDict @@ -185,7 +187,7 @@ class RestHelper: def join( self, room: str, - user: Optional[str] = None, + user: str, expect_code: int = HTTPStatus.OK, tok: Optional[str] = None, appservice_user_id: Optional[str] = None, @@ -394,6 +396,32 @@ class RestHelper: custom_headers=custom_headers, ) + def send_messages( + self, + room_id: str, + num_events: int, + content_fn: Callable[[int], JsonDict] = lambda idx: { + "msgtype": "m.text", + "body": f"Test event {idx}", + }, + tok: Optional[str] = None, + ) -> Sequence[str]: + """ + Helper to send a handful of sequential events and return their event IDs as a sequence. 
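+ + Args: + room_id: the room to send the events into + num_events: the number of message events to send + content_fn: optional callable mapping each event's index to its content dict + tok: the access token to send the events with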
+ """ + event_ids = [] + + for event_index in range(num_events): + response = self.send_event( + room_id, + EventTypes.Message, + content_fn(event_index), + tok=tok, + ) + event_ids.append(response["event_id"]) + + return event_ids + def send_event( self, room_id: str, diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py index 21e12b2a2f..3717d70b6b 100644 --- a/tests/rest/key/v2/test_remote_key_resource.py +++ b/tests/rest/key/v2/test_remote_key_resource.py @@ -27,7 +27,7 @@ from canonicaljson import encode_canonical_json from signedjson.sign import sign_json from signedjson.types import SigningKey -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from twisted.web.resource import NoResource, Resource from synapse.crypto.keyring import PerspectivesKeyFetcher @@ -99,7 +99,7 @@ class RemoteKeyResourceTestCase(BaseRemoteKeyResourceTestCase): """ channel = FakeChannel(self.site, self.reactor) # channel is a `FakeChannel` but `HTTPChannel` is expected - req = SynapseRequest(channel, self.site) # type: ignore[arg-type] + req = SynapseRequest(channel, self.site, self.hs.hostname) # type: ignore[arg-type] req.content = BytesIO(b"") req.requestReceived( b"GET", @@ -201,7 +201,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase): channel = FakeChannel(self.site, self.reactor) # channel is a `FakeChannel` but `HTTPChannel` is expected - req = SynapseRequest(channel, self.site) # type: ignore[arg-type] + req = SynapseRequest(channel, self.site, self.hs.hostname) # type: ignore[arg-type] req.content = BytesIO(encode_canonical_json(data)) req.requestReceived( diff --git a/tests/rest/media/test_domain_blocking.py b/tests/rest/media/test_domain_blocking.py index 26453f70dd..3feade4a4b 100644 --- a/tests/rest/media/test_domain_blocking.py +++ b/tests/rest/media/test_domain_blocking.py @@ -20,7 +20,7 @@ # from typing import Dict -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from twisted.web.resource import Resource from synapse.media._base import FileInfo diff --git a/tests/rest/media/test_url_preview.py b/tests/rest/media/test_url_preview.py index 2a7bee19f9..e096780ce2 100644 --- a/tests/rest/media/test_url_preview.py +++ b/tests/rest/media/test_url_preview.py @@ -29,7 +29,7 @@ from twisted.internet._resolver import HostResolution from twisted.internet.address import IPv4Address, IPv6Address from twisted.internet.error import DNSLookupError from twisted.internet.interfaces import IAddress, IResolutionReceiver -from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactor +from twisted.internet.testing import AccumulatingProtocol, MemoryReactor from twisted.web.resource import Resource from synapse.config.oembed import OEmbedEndpointConfig diff --git a/tests/rest/synapse/mas/test_devices.py b/tests/rest/synapse/mas/test_devices.py index a7cd58d8ff..458878c13c 100644 --- a/tests/rest/synapse/mas/test_devices.py +++ b/tests/rest/synapse/mas/test_devices.py @@ -11,7 +11,7 @@ # See the GNU Affero General Public License for more details: # . 
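The `RestHelper.send_messages` helper added to `tests/rest/client/utils.py` above sends a batch of sequential events and returns their event IDs. A minimal usage sketch, assuming an existing `RestHelper`, room, thread root and access token; the wrapper function and its parameters are illustrative, not part of the diff:

```python
from typing import List

from synapse.types import JsonDict

from tests.rest.client.utils import RestHelper


def send_thread_replies(
    helper: RestHelper, room_id: str, root_event_id: str, count: int, tok: str
) -> List[str]:
    """Send `count` threaded replies to `root_event_id` and return their event IDs."""

    def content_fn(idx: int) -> JsonDict:
        return {
            "msgtype": "m.text",
            "body": f"Reply {idx}",
            "m.relates_to": {"rel_type": "m.thread", "event_id": root_event_id},
        }

    # `send_messages` returns the event IDs of the events it created, in order.
    return list(
        helper.send_messages(room_id, count, content_fn=content_fn, tok=tok)
    )
```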
-from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.server import HomeServer from synapse.types import UserID diff --git a/tests/rest/synapse/mas/test_users.py b/tests/rest/synapse/mas/test_users.py index 378f29fd4c..b236aceaf2 100644 --- a/tests/rest/synapse/mas/test_users.py +++ b/tests/rest/synapse/mas/test_users.py @@ -13,7 +13,7 @@ from urllib.parse import urlencode -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.appservice import ApplicationService from synapse.server import HomeServer diff --git a/tests/server.py b/tests/server.py index 0c519bc4c9..3a81a4c6d9 100644 --- a/tests/server.py +++ b/tests/server.py @@ -78,9 +78,9 @@ from twisted.internet.interfaces import ( ITransport, ) from twisted.internet.protocol import ClientFactory, DatagramProtocol, Factory +from twisted.internet.testing import AccumulatingProtocol, MemoryReactorClock from twisted.python import threadpool from twisted.python.failure import Failure -from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock from twisted.web.http_headers import Headers from twisted.web.resource import IResource from twisted.web.server import Request, Site @@ -97,6 +97,7 @@ from synapse.module_api.callbacks.third_party_event_rules_callbacks import ( load_legacy_third_party_event_rules, ) from synapse.server import HomeServer +from synapse.server_notices.consent_server_notices import ConfigError from synapse.storage import DataStore from synapse.storage.database import LoggingDatabaseConnection, make_pool from synapse.storage.engines import BaseDatabaseEngine, create_engine @@ -432,7 +433,7 @@ def make_request( channel = FakeChannel(site, reactor, ip=client_ip) - req = request(channel, site) + req = request(channel, site, our_server_name="test_server") channel.request = req req.content = BytesIO(content) @@ -702,6 +703,7 @@ def make_fake_db_pool( reactor: ISynapseReactor, db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine, + server_name: str, ) -> adbapi.ConnectionPool: """Wrapper for `make_pool` which builds a pool which runs db queries synchronously. @@ -710,7 +712,9 @@ def make_fake_db_pool( is a drop-in replacement for the normal `make_pool` which builds such a connection pool. """ - pool = make_pool(reactor, db_config, engine) + pool = make_pool( + reactor=reactor, db_config=db_config, engine=engine, server_name=server_name + ) def runWithConnection( func: Callable[..., R], *args: Any, **kwargs: Any @@ -1084,12 +1088,19 @@ def setup_test_homeserver( "args": {"database": test_db_location, "cp_min": 1, "cp_max": 1}, } + server_name = config.server.server_name + if not isinstance(server_name, str): + raise ConfigError("Must be a string", ("server_name",)) + # Check if we have set up a DB that we can use as a template. 
global PREPPED_SQLITE_DB_CONN if PREPPED_SQLITE_DB_CONN is None: temp_engine = create_engine(database_config) PREPPED_SQLITE_DB_CONN = LoggingDatabaseConnection( - sqlite3.connect(":memory:"), temp_engine, "PREPPED_CONN" + conn=sqlite3.connect(":memory:"), + engine=temp_engine, + default_txn_name="PREPPED_CONN", + server_name=server_name, ) database = DatabaseConnectionConfig("master", database_config) diff --git a/tests/server_notices/__init__.py b/tests/server_notices/__init__.py index e69de29bb2..1d23a126de 100644 --- a/tests/server_notices/__init__.py +++ b/tests/server_notices/__init__.py @@ -0,0 +1,240 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2025 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# . +# +# + +from twisted.test.proto_helpers import MemoryReactor + +import synapse.rest.admin +from synapse.rest.client import login, room, sync +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock + +from tests import unittest +from tests.unittest import override_config +from tests.utils import default_config + +DEFAULT_SERVER_NOTICES_CONFIG = { + "system_mxid_localpart": "notices", + "system_mxid_display_name": "test display name", + "system_mxid_avatar_url": None, + "room_name": "Server Notices", + "auto_join": False, +} + + +class ServerNoticesTests(unittest.HomeserverTestCase): + servlets = [ + sync.register_servlets, + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def default_config(self) -> JsonDict: + config = default_config("test") + + config.update({"server_notices": DEFAULT_SERVER_NOTICES_CONFIG}) + + # apply any additional config which was specified via the override_config + # decorator. + if self._extra_config is not None: + config.update(self._extra_config) + + return config + + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: + self._admin_user_id = self.register_user( + "server_notices_admin", "abc123", admin=True + ) + self._admin_user_access_token = self.login("server_notices_admin", "abc123") + + self._test_user_id = self.register_user("server_notices_test_user", "abc123") + self._test_user_access_token = self.login("server_notices_test_user", "abc123") + + self._server_notice_content = { + "msgtype": "m.text", + "formatted_body": "
Do the hussle.
", + "body": "Do the hussle.", + "format": "org.matrix.custom.html", + } + + def _send_server_notice( + self, + admin_access_token: str, + target_user_id: str, + notice_content: JsonDict, + ) -> None: + # Send a server notice. + channel = self.make_request( + "POST", + "/_synapse/admin/v1/send_server_notice", + content={ + "user_id": target_user_id, + "content": notice_content, + }, + access_token=admin_access_token, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + def _check_user_received_server_notice( + self, + target_user_id: str, + target_access_token: str, + expected_content: JsonDict, + user_accepts_invite: bool, + ) -> None: + # Have the target user sync. + channel = self.make_request( + "GET", "/_matrix/client/v3/sync", access_token=target_access_token + ) + self.assertEqual(channel.code, 200, channel.json_body) + sync_body = channel.json_body + + if user_accepts_invite: + # Get the Room ID to join + room_id = list(sync_body["rooms"]["invite"].keys())[0] + + # Join the room + self.helper.join(room_id, target_user_id, tok=target_access_token) + + for _ in range(5): + # Sync until we're joined to the room. + channel = self.make_request( + "GET", "/_matrix/client/v3/sync", access_token=target_access_token + ) + self.assertEqual(channel.code, 200, channel.json_body) + sync_body = channel.json_body + + if "join" in sync_body["rooms"] and len(sync_body["rooms"]["join"]) > 0: + # Retrieve the server notices message. + room_id = list(sync_body["rooms"]["join"].keys())[0] + room = sync_body["rooms"]["join"][room_id] + messages = [ + x + for x in room["timeline"]["events"] + if x["type"] == "m.room.message" + ] + break + + # Sleep and try again. + self.get_success(self.clock.sleep(0.1)) + else: + self.fail( + f"Failed to join the server notices room. No 'join' field in sync_body['rooms']: {sync_body['rooms']}" + ) + + # Should be the expected server notices content. + self.assertDictEqual(messages[-1]["content"], expected_content) + + def test_send_server_notice(self) -> None: + """ + Test the happy path of sending a server notice to a user. + """ + # Send a server notice. The server notice room does not yet exist. + self._send_server_notice( + self._admin_user_access_token, + self._test_user_id, + self._server_notice_content, + ) + + self._check_user_received_server_notice( + self._test_user_id, + self._test_user_access_token, + self._server_notice_content, + # User must accept the invite manually. + True, + ) + + # Send another server notice. In this case, the room already exists. + self._send_server_notice( + self._admin_user_access_token, + self._test_user_id, + self._server_notice_content, + ) + + self._check_user_received_server_notice( + self._test_user_id, + self._test_user_access_token, + self._server_notice_content, + # User is already in the room, no need to join it. + False, + ) + + @override_config( + { + "server_notices": { + **DEFAULT_SERVER_NOTICES_CONFIG, + "auto_join": True, + } + } + ) + def test_send_server_notice_auto_join(self) -> None: + """ + Test the happy path of sending a server notice to a user, with auto_join enabled. + """ + # Send a server notice. The server notice room does not yet exist. + self._send_server_notice( + self._admin_user_access_token, + self._test_user_id, + self._server_notice_content, + ) + + self._check_user_received_server_notice( + self._test_user_id, + self._test_user_access_token, + self._server_notice_content, + # User does not need to join the room manually. They should be auto-joined. 
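For context, the admin endpoint these tests exercise can also be called directly. A sketch using `requests` against a local homeserver; the base URL, token and target user are placeholders, while the path and body shape mirror `_send_server_notice` above:

```python
import requests

BASE_URL = "http://localhost:8008"  # placeholder homeserver URL
ADMIN_TOKEN = "REPLACE_ME"  # placeholder admin access token

response = requests.post(
    f"{BASE_URL}/_synapse/admin/v1/send_server_notice",
    headers={"Authorization": f"Bearer {ADMIN_TOKEN}"},
    json={
        "user_id": "@alice:example.org",  # placeholder target user
        "content": {
            "msgtype": "m.text",
            "body": "Scheduled maintenance at 20:00 UTC.",
        },
    },
    timeout=10,
)
response.raise_for_status()
print(response.json())
```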
+ False, + ) + + @override_config( + { + "server_notices": { + **DEFAULT_SERVER_NOTICES_CONFIG, + "auto_join": True, + } + } + ) + def test_send_server_notice_suspended_user_auto_join(self) -> None: + """Test sending a server notice to a user that's suspended, with auto-join enabled. + + This is a regression test for https://github.com/element-hq/synapse/pull/18750, where + previously the suspended user would not be allowed to join the server notices room. + """ + # Suspend the target user. + channel = self.make_request( + "PUT", + f"/_synapse/admin/v1/suspend/{self._test_user_id}", + content={"suspend": True}, + access_token=self._admin_user_access_token, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Send a server notice. The server notices room will be created and the user auto-joined. + self._send_server_notice( + self._admin_user_access_token, + self._test_user_id, + self._server_notice_content, + ) + + self._check_user_received_server_notice( + self._test_user_id, + self._test_user_access_token, + self._server_notice_content, + # User does not need to join the room manually. They should be auto-joined. + False, + ) diff --git a/tests/server_notices/test_consent.py b/tests/server_notices/test_consent.py index 53e14e049e..db4a6370e8 100644 --- a/tests/server_notices/test_consent.py +++ b/tests/server_notices/test_consent.py @@ -20,7 +20,7 @@ import os -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor import synapse.rest.admin from synapse.rest.client import login, room, sync diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 997ee7b91b..0da12f14cd 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -20,7 +20,7 @@ from typing import Tuple from unittest.mock import AsyncMock, Mock -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.constants import EventTypes, LimitBlockingTypes, ServerNoticeMsgType from synapse.api.errors import ResourceLimitError diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py index eade461e79..b4f2b98cc4 100644 --- a/tests/state/test_v2.py +++ b/tests/state/test_v2.py @@ -66,8 +66,8 @@ ORIGIN_SERVER_TS = 0 class FakeClock: - def sleep(self, msec: float) -> "defer.Deferred[None]": - return defer.succeed(None) + async def sleep(self, msec: float) -> None: + return None class FakeEvent: diff --git a/tests/state/test_v21.py b/tests/state/test_v21.py index 5eecde1c51..5e46b69fef 100644 --- a/tests/state/test_v21.py +++ b/tests/state/test_v21.py @@ -66,8 +66,8 @@ def monotonic_timestamp() -> int: class FakeClock: - def sleep(self, msec: float) -> "defer.Deferred[None]": - return defer.succeed(None) + async def sleep(self, duration_ms: float) -> None: + defer.succeed(None) class StateResV21TestCase(unittest.HomeserverTestCase): diff --git a/tests/storage/databases/main/test_deviceinbox.py b/tests/storage/databases/main/test_deviceinbox.py index 7556111f16..d3ddeaa57e 100644 --- a/tests/storage/databases/main/test_deviceinbox.py +++ b/tests/storage/databases/main/test_deviceinbox.py @@ -22,7 +22,7 @@ from unittest.mock import patch from twisted.internet import defer -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.rest import admin from synapse.rest.client import devices 
diff --git a/tests/storage/databases/main/test_end_to_end_keys.py b/tests/storage/databases/main/test_end_to_end_keys.py index 1ed1d01cea..3992fc3264 100644 --- a/tests/storage/databases/main/test_end_to_end_keys.py +++ b/tests/storage/databases/main/test_end_to_end_keys.py @@ -20,7 +20,7 @@ # from typing import List, Optional, Tuple -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.server import HomeServer from synapse.storage._base import db_to_json diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py index 18039a07e2..f23609aee3 100644 --- a/tests/storage/databases/main/test_events_worker.py +++ b/tests/storage/databases/main/test_events_worker.py @@ -25,7 +25,7 @@ from unittest import mock from twisted.enterprise.adbapi import ConnectionPool from twisted.internet.defer import CancelledError, Deferred, ensureDeferred -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.room_versions import EventFormatVersions, RoomVersions from synapse.events import make_event_from_dict diff --git a/tests/storage/databases/main/test_lock.py b/tests/storage/databases/main/test_lock.py index 1df71e723e..e18e0f2792 100644 --- a/tests/storage/databases/main/test_lock.py +++ b/tests/storage/databases/main/test_lock.py @@ -23,7 +23,7 @@ from twisted.internet import defer, reactor from twisted.internet.base import ReactorBase from twisted.internet.defer import Deferred -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.server import HomeServer from synapse.storage.databases.main.lock import _LOCK_TIMEOUT_MS, _RENEWAL_INTERVAL_MS diff --git a/tests/storage/test_event_metrics.py b/tests/storage/databases/main/test_metrics.py similarity index 67% rename from tests/storage/test_event_metrics.py rename to tests/storage/databases/main/test_metrics.py index fc6e02545f..be59e1b67e 100644 --- a/tests/storage/test_event_metrics.py +++ b/tests/storage/databases/main/test_metrics.py @@ -65,24 +65,24 @@ class ExtremStatisticsTestCase(HomeserverTestCase): ) expected = [ - b'synapse_forward_extremities_bucket{le="1.0"} 0.0', - b'synapse_forward_extremities_bucket{le="2.0"} 2.0', - b'synapse_forward_extremities_bucket{le="3.0"} 2.0', - b'synapse_forward_extremities_bucket{le="5.0"} 2.0', - b'synapse_forward_extremities_bucket{le="7.0"} 3.0', - b'synapse_forward_extremities_bucket{le="10.0"} 3.0', - b'synapse_forward_extremities_bucket{le="15.0"} 3.0', - b'synapse_forward_extremities_bucket{le="20.0"} 3.0', - b'synapse_forward_extremities_bucket{le="50.0"} 3.0', - b'synapse_forward_extremities_bucket{le="100.0"} 3.0', - b'synapse_forward_extremities_bucket{le="200.0"} 3.0', - b'synapse_forward_extremities_bucket{le="500.0"} 3.0', + b'synapse_forward_extremities_bucket{le="1.0",server_name="test"} 0.0', + b'synapse_forward_extremities_bucket{le="2.0",server_name="test"} 2.0', + b'synapse_forward_extremities_bucket{le="3.0",server_name="test"} 2.0', + b'synapse_forward_extremities_bucket{le="5.0",server_name="test"} 2.0', + b'synapse_forward_extremities_bucket{le="7.0",server_name="test"} 3.0', + b'synapse_forward_extremities_bucket{le="10.0",server_name="test"} 3.0', + b'synapse_forward_extremities_bucket{le="15.0",server_name="test"} 3.0', + b'synapse_forward_extremities_bucket{le="20.0",server_name="test"} 3.0', + 
b'synapse_forward_extremities_bucket{le="50.0",server_name="test"} 3.0', + b'synapse_forward_extremities_bucket{le="100.0",server_name="test"} 3.0', + b'synapse_forward_extremities_bucket{le="200.0",server_name="test"} 3.0', + b'synapse_forward_extremities_bucket{le="500.0",server_name="test"} 3.0', # per https://docs.google.com/document/d/1KwV0mAXwwbvvifBvDKH_LU1YjyXE_wxCkHNoCGq1GX0/edit#heading=h.wghdjzzh72j9, # "inf" is valid: "this includes variants such as inf" - b'synapse_forward_extremities_bucket{le="inf"} 3.0', + b'synapse_forward_extremities_bucket{le="inf",server_name="test"} 3.0', b"# TYPE synapse_forward_extremities_gcount gauge", - b"synapse_forward_extremities_gcount 3.0", + b'synapse_forward_extremities_gcount{server_name="test"} 3.0', b"# TYPE synapse_forward_extremities_gsum gauge", - b"synapse_forward_extremities_gsum 10.0", + b'synapse_forward_extremities_gsum{server_name="test"} 10.0', ] self.assertEqual(items, expected) diff --git a/tests/storage/databases/main/test_receipts.py b/tests/storage/databases/main/test_receipts.py index da2ec26421..4141f868d6 100644 --- a/tests/storage/databases/main/test_receipts.py +++ b/tests/storage/databases/main/test_receipts.py @@ -21,7 +21,7 @@ from typing import Any, Dict, Optional, Sequence, Tuple -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.rest import admin from synapse.rest.client import login, room diff --git a/tests/storage/databases/main/test_room.py b/tests/storage/databases/main/test_room.py index 88a5aa8cb1..dda4294e63 100644 --- a/tests/storage/databases/main/test_room.py +++ b/tests/storage/databases/main/test_room.py @@ -21,7 +21,7 @@ import json -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.constants import RoomTypes from synapse.rest import admin diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index 49dc973a36..7d260b7915 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -22,7 +22,7 @@ import secrets from typing import Generator, List, Tuple, cast -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.server import HomeServer from synapse.util import Clock diff --git a/tests/storage/test_account_data.py b/tests/storage/test_account_data.py index 0e52dd26ce..794cefd04d 100644 --- a/tests/storage/test_account_data.py +++ b/tests/storage/test_account_data.py @@ -21,7 +21,7 @@ from typing import Iterable, Optional, Set -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.constants import AccountDataTypes from synapse.api.errors import Codes, SynapseError diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 10533f45d7..759fad6af1 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -27,7 +27,7 @@ from unittest.mock import AsyncMock, Mock import yaml from twisted.internet import defer -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.appservice import ApplicationService, ApplicationServiceState from synapse.config._base import ConfigError @@ -63,9 +63,15 @@ class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase): self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob") # must be done after inserts database = 
self.hs.get_datastores().databases[0] + self.server_name = self.hs.hostname self.store = ApplicationServiceStore( database, - make_conn(database._database_config, database.engine, "test"), + make_conn( + db_config=database._database_config, + engine=database.engine, + default_txn_name="test", + server_name=self.server_name, + ), self.hs, ) @@ -138,9 +144,17 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): self.db_pool = database._db_pool self.engine = database.engine + server_name = self.hs.hostname db_config = self.hs.config.database.get_single_database() self.store = TestTransactionStore( - database, make_conn(db_config, self.engine, "test"), self.hs + database, + make_conn( + db_config=db_config, + engine=self.engine, + default_txn_name="test", + server_name=server_name, + ), + self.hs, ) def _add_service(self, url: str, as_token: str, id: str) -> None: @@ -488,10 +502,16 @@ class ApplicationServiceStoreConfigTestCase(unittest.HomeserverTestCase): self.hs.config.appservice.app_service_config_files = [f1, f2] self.hs.config.caches.event_cache_size = 1 + server_name = self.hs.hostname database = self.hs.get_datastores().databases[0] ApplicationServiceStore( database, - make_conn(database._database_config, database.engine, "test"), + make_conn( + db_config=database._database_config, + engine=database.engine, + default_txn_name="test", + server_name=server_name, + ), self.hs, ) @@ -503,10 +523,16 @@ class ApplicationServiceStoreConfigTestCase(unittest.HomeserverTestCase): self.hs.config.caches.event_cache_size = 1 with self.assertRaises(ConfigError) as cm: + server_name = self.hs.hostname database = self.hs.get_datastores().databases[0] ApplicationServiceStore( database, - make_conn(database._database_config, database.engine, "test"), + make_conn( + db_config=database._database_config, + engine=database.engine, + default_txn_name="test", + server_name=server_name, + ), self.hs, ) @@ -523,10 +549,16 @@ class ApplicationServiceStoreConfigTestCase(unittest.HomeserverTestCase): self.hs.config.caches.event_cache_size = 1 with self.assertRaises(ConfigError) as cm: + server_name = self.hs.hostname database = self.hs.get_datastores().databases[0] ApplicationServiceStore( database, - make_conn(database._database_config, database.engine, "test"), + make_conn( + db_config=database._database_config, + engine=database.engine, + default_txn_name="test", + server_name=server_name, + ), self.hs, ) diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index b28db6a4ad..89a3b54a25 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -25,7 +25,7 @@ from unittest.mock import AsyncMock, Mock import yaml from twisted.internet.defer import Deferred, ensureDeferred -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.server import HomeServer from synapse.storage.background_updates import ( diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py index d5b9996284..94fb8e01a1 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py @@ -22,7 +22,7 @@ import os.path from unittest.mock import Mock, patch -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor import synapse.rest.admin from synapse.api.constants import EventTypes diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index 
13f78ee2d2..de95272b52 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -24,7 +24,7 @@ from unittest.mock import AsyncMock from parameterized import parameterized -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor import synapse.rest.admin from synapse.http.site import XForwardedForRequest diff --git a/tests/storage/test_database.py b/tests/storage/test_database.py index 8af0d6265b..5e5937ff17 100644 --- a/tests/storage/test_database.py +++ b/tests/storage/test_database.py @@ -24,7 +24,7 @@ from unittest.mock import Mock, call from twisted.internet import defer from twisted.internet.defer import CancelledError, Deferred -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.server import HomeServer from synapse.storage.database import ( diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index 74edca7523..e8ea813668 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -21,7 +21,7 @@ from typing import Collection, List, Tuple -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor import synapse.api.errors from synapse.api.constants import EduTypes diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py index f1602fdc86..26bf6cf391 100644 --- a/tests/storage/test_directory.py +++ b/tests/storage/test_directory.py @@ -19,7 +19,7 @@ # # -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.server import HomeServer from synapse.types import RoomAlias, RoomID diff --git a/tests/storage/test_e2e_room_keys.py b/tests/storage/test_e2e_room_keys.py index 931d37e85a..f390d11e41 100644 --- a/tests/storage/test_e2e_room_keys.py +++ b/tests/storage/test_e2e_room_keys.py @@ -19,7 +19,7 @@ # # -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.server import HomeServer from synapse.storage.databases.main.e2e_room_keys import RoomKey diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py index bd594d3c1f..e46999022a 100644 --- a/tests/storage/test_end_to_end_keys.py +++ b/tests/storage/test_end_to_end_keys.py @@ -19,7 +19,7 @@ # # -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.server import HomeServer from synapse.util import Clock diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py index c4e216c308..b2480a139d 100644 --- a/tests/storage/test_event_chain.py +++ b/tests/storage/test_event_chain.py @@ -23,7 +23,7 @@ from typing import Dict, List, Set, Tuple, cast from parameterized import parameterized -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from twisted.trial import unittest from synapse.api.constants import EventTypes diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 11426e2b97..2f79068f6b 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -37,7 +37,7 @@ from typing import ( import attr from parameterized import parameterized -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.constants import EventTypes from 
synapse.api.room_versions import ( diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index 233066bf82..640490a6e5 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -21,7 +21,7 @@ from typing import Optional, Tuple -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.constants import MAIN_TIMELINE, RelationTypes from synapse.rest import admin diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py index 2a43f762a8..6d2e4e4bbe 100644 --- a/tests/storage/test_events.py +++ b/tests/storage/test_events.py @@ -22,7 +22,7 @@ import logging from typing import List, Optional -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.constants import EventTypes, Membership from synapse.api.room_versions import RoomVersions diff --git a/tests/storage/test_events_bg_updates.py b/tests/storage/test_events_bg_updates.py index ecdf413e3b..7bbb5849a0 100644 --- a/tests/storage/test_events_bg_updates.py +++ b/tests/storage/test_events_bg_updates.py @@ -15,7 +15,7 @@ from typing import Dict -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.constants import MAX_DEPTH from synapse.api.room_versions import RoomVersion, RoomVersions diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index 12b89cecb6..9e949af482 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -20,7 +20,7 @@ # from typing import Dict, List, Optional -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.server import HomeServer from synapse.storage.database import ( @@ -80,10 +80,11 @@ class MultiWriterIdGeneratorBase(HomeserverTestCase): ) -> MultiWriterIdGenerator: def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator: return MultiWriterIdGenerator( - conn, - self.db_pool, + db_conn=conn, + db=self.db_pool, notifier=self.hs.get_replication_notifier(), stream_name="test_stream", + server_name=self.hs.hostname, instance_name=instance_name, tables=[(table, "instance_name", "stream_id") for table in self.tables], sequence_name="foobar_seq", diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index 15ae582051..78ef2e67a2 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -20,7 +20,7 @@ from typing import Any, Dict, List from unittest.mock import AsyncMock -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.constants import UserTypes from synapse.server import HomeServer diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py index 9df8ea4ee6..0f14e00e51 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py @@ -19,7 +19,7 @@ # # -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.server import HomeServer from synapse.storage.database import LoggingTransaction diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py index 0aa14fd1f4..38d0cd6eb2 100644 --- a/tests/storage/test_purge.py +++ b/tests/storage/test_purge.py @@ -18,7 +18,7 @@ # # -from twisted.test.proto_helpers 
import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.errors import NotFoundError, SynapseError from synapse.rest.client import room diff --git a/tests/storage/test_receipts.py b/tests/storage/test_receipts.py index 0b984c7ebc..8f7f736175 100644 --- a/tests/storage/test_receipts.py +++ b/tests/storage/test_receipts.py @@ -21,7 +21,7 @@ from typing import Collection, Optional -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.constants import ReceiptTypes from synapse.server import HomeServer diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index e2e48a5295..a9c0d7d9a9 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -22,7 +22,7 @@ from typing import List, Optional, cast from canonicaljson import json -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.constants import EventTypes, Membership from synapse.api.room_versions import RoomVersion, RoomVersions diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 14e3871dc1..992ccc779b 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -18,7 +18,7 @@ # [This file includes modifications made by New Vector Limited] # # -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.constants import UserTypes from synapse.api.errors import ThreepidValidationError diff --git a/tests/storage/test_relations.py b/tests/storage/test_relations.py index a7f7c840f3..0f3e3fe7eb 100644 --- a/tests/storage/test_relations.py +++ b/tests/storage/test_relations.py @@ -19,7 +19,7 @@ # # -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.constants import MAIN_TIMELINE from synapse.server import HomeServer diff --git a/tests/storage/test_rollback_worker.py b/tests/storage/test_rollback_worker.py index 909aee043e..af69b93cf8 100644 --- a/tests/storage/test_rollback_worker.py +++ b/tests/storage/test_rollback_worker.py @@ -21,7 +21,7 @@ from typing import List from unittest import mock -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.app.generic_worker import GenericWorkerServer from synapse.server import HomeServer @@ -69,9 +69,10 @@ class WorkerSchemaTests(HomeserverTestCase): db_pool = self.hs.get_datastores().main.db_pool db_conn = LoggingDatabaseConnection( - db_pool._db_pool.connect(), - db_pool.engine, - "tests", + conn=db_pool._db_pool.connect(), + engine=db_pool.engine, + default_txn_name="tests", + server_name="test_server", ) cur = db_conn.cursor() @@ -85,9 +86,10 @@ class WorkerSchemaTests(HomeserverTestCase): """Test that workers don't start if the DB has an older schema version""" db_pool = self.hs.get_datastores().main.db_pool db_conn = LoggingDatabaseConnection( - db_pool._db_pool.connect(), - db_pool.engine, - "tests", + conn=db_pool._db_pool.connect(), + engine=db_pool.engine, + default_txn_name="tests", + server_name="test_server", ) cur = db_conn.cursor() @@ -105,9 +107,10 @@ class WorkerSchemaTests(HomeserverTestCase): """ db_pool = self.hs.get_datastores().main.db_pool db_conn = LoggingDatabaseConnection( - db_pool._db_pool.connect(), - db_pool.engine, - "tests", + conn=db_pool._db_pool.connect(), + engine=db_pool.engine, + 
default_txn_name="tests", + server_name="test_server", ) # Set the schema version of the database to the current version diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index 34d6fdb71e..a8a75d2973 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -19,7 +19,7 @@ # # -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.room_versions import RoomVersions from synapse.server import HomeServer diff --git a/tests/storage/test_room_search.py b/tests/storage/test_room_search.py index 340642b7e7..f7eaa83ec6 100644 --- a/tests/storage/test_room_search.py +++ b/tests/storage/test_room_search.py @@ -22,7 +22,7 @@ from typing import List, Tuple from unittest.case import SkipTest -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor import synapse.rest.admin from synapse.api.constants import EventTypes diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 330fea0e62..fd489022a8 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -22,7 +22,7 @@ import logging from typing import List, Optional, Tuple, cast -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.constants import EventContentFields, EventTypes, JoinRules, Membership from synapse.api.room_versions import RoomVersions diff --git a/tests/storage/test_sliding_sync_tables.py b/tests/storage/test_sliding_sync_tables.py index 53212f7c45..1a7a0b4c5c 100644 --- a/tests/storage/test_sliding_sync_tables.py +++ b/tests/storage/test_sliding_sync_tables.py @@ -23,7 +23,7 @@ from typing import Dict, List, Optional, Tuple, cast import attr from parameterized import parameterized -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.constants import EventContentFields, EventTypes, Membership, RoomTypes from synapse.api.room_versions import RoomVersions diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py index 48f8d1c340..cbf68b3032 100644 --- a/tests/storage/test_state.py +++ b/tests/storage/test_state.py @@ -24,7 +24,7 @@ from typing import List, Tuple, cast from immutabledict import immutabledict -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.constants import EventTypes, Membership from synapse.api.room_versions import RoomVersions diff --git a/tests/storage/test_state_deletion.py b/tests/storage/test_state_deletion.py index a4d318ae20..58cd118567 100644 --- a/tests/storage/test_state_deletion.py +++ b/tests/storage/test_state_deletion.py @@ -15,7 +15,7 @@ import logging -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.rest import admin from synapse.rest.client import login, room diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py index 0f58dc8a0a..ba2af1e044 100644 --- a/tests/storage/test_stream.py +++ b/tests/storage/test_stream.py @@ -25,7 +25,7 @@ from unittest.mock import AsyncMock, patch from immutabledict import immutabledict -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.constants import ( Direction, diff --git a/tests/storage/test_thread_subscriptions.py b/tests/storage/test_thread_subscriptions.py index 
dd0b804f1f..2a5c440cf4 100644 --- a/tests/storage/test_thread_subscriptions.py +++ b/tests/storage/test_thread_subscriptions.py @@ -12,13 +12,18 @@ # . # -from typing import Optional +from typing import Optional, Union -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.server import HomeServer from synapse.storage.database import LoggingTransaction +from synapse.storage.databases.main.thread_subscriptions import ( + AutomaticSubscriptionConflicted, + ThreadSubscriptionsWorkerStore, +) from synapse.storage.engines.sqlite import Sqlite3Engine +from synapse.types import EventOrderings from synapse.util import Clock from tests import unittest @@ -97,10 +102,10 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase): self, thread_root_id: str, *, - automatic: bool, + automatic_event_orderings: Optional[EventOrderings], room_id: Optional[str] = None, user_id: Optional[str] = None, - ) -> Optional[int]: + ) -> Optional[Union[int, AutomaticSubscriptionConflicted]]: if user_id is None: user_id = self.user_id @@ -112,7 +117,7 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase): user_id, room_id, thread_root_id, - automatic=automatic, + automatic_event_orderings=automatic_event_orderings, ) ) @@ -149,7 +154,7 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase): # Subscribe self._subscribe( self.thread_root_id, - automatic=True, + automatic_event_orderings=EventOrderings(1, 1), ) # Assert subscription went through @@ -164,7 +169,7 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase): # Now make it a manual subscription self._subscribe( self.thread_root_id, - automatic=False, + automatic_event_orderings=None, ) # Assert the manual subscription overrode the automatic one @@ -178,8 +183,10 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase): def test_purge_thread_subscriptions_for_user(self) -> None: """Test purging all thread subscription settings for a user.""" # Set subscription settings for multiple threads - self._subscribe(self.thread_root_id, automatic=True) - self._subscribe(self.other_thread_root_id, automatic=False) + self._subscribe( + self.thread_root_id, automatic_event_orderings=EventOrderings(1, 1) + ) + self._subscribe(self.other_thread_root_id, automatic_event_orderings=None) subscriptions = self.get_success( self.store.get_updated_thread_subscriptions_for_user( @@ -217,20 +224,32 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase): def test_get_updated_thread_subscriptions(self) -> None: """Test getting updated thread subscriptions since a stream ID.""" - stream_id1 = self._subscribe(self.thread_root_id, automatic=False) - stream_id2 = self._subscribe(self.other_thread_root_id, automatic=True) - assert stream_id1 is not None - assert stream_id2 is not None + stream_id1 = self._subscribe( + self.thread_root_id, automatic_event_orderings=EventOrderings(1, 1) + ) + stream_id2 = self._subscribe( + self.other_thread_root_id, automatic_event_orderings=EventOrderings(2, 2) + ) + assert stream_id1 is not None and not isinstance( + stream_id1, AutomaticSubscriptionConflicted + ) + assert stream_id2 is not None and not isinstance( + stream_id2, AutomaticSubscriptionConflicted + ) # Get updates since initial ID (should include both changes) updates = self.get_success( - self.store.get_updated_thread_subscriptions(0, stream_id2, 10) + self.store.get_updated_thread_subscriptions( + from_id=0, to_id=stream_id2, limit=10 + ) ) self.assertEqual(len(updates), 
2) # Get updates since first change (should include only the second change) updates = self.get_success( - self.store.get_updated_thread_subscriptions(stream_id1, stream_id2, 10) + self.store.get_updated_thread_subscriptions( + from_id=stream_id1, to_id=stream_id2, limit=10 + ) ) self.assertEqual( updates, @@ -242,21 +261,27 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase): other_user_id = "@other_user:test" # Set thread subscription for main user - stream_id1 = self._subscribe(self.thread_root_id, automatic=True) - assert stream_id1 is not None + stream_id1 = self._subscribe( + self.thread_root_id, automatic_event_orderings=EventOrderings(1, 1) + ) + assert stream_id1 is not None and not isinstance( + stream_id1, AutomaticSubscriptionConflicted + ) # Set thread subscription for other user stream_id2 = self._subscribe( self.other_thread_root_id, - automatic=True, + automatic_event_orderings=EventOrderings(1, 1), user_id=other_user_id, ) - assert stream_id2 is not None + assert stream_id2 is not None and not isinstance( + stream_id2, AutomaticSubscriptionConflicted + ) # Get updates for main user updates = self.get_success( self.store.get_updated_thread_subscriptions_for_user( - self.user_id, 0, stream_id2, 10 + self.user_id, from_id=0, to_id=stream_id2, limit=10 ) ) self.assertEqual(updates, [(stream_id1, self.room_id, self.thread_root_id)]) @@ -264,9 +289,80 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase): # Get updates for other user updates = self.get_success( self.store.get_updated_thread_subscriptions_for_user( - other_user_id, 0, max(stream_id1, stream_id2), 10 + other_user_id, from_id=0, to_id=max(stream_id1, stream_id2), limit=10 ) ) self.assertEqual( updates, [(stream_id2, self.room_id, self.other_thread_root_id)] ) + + def test_should_skip_autosubscription_after_unsubscription(self) -> None: + """ + Tests the comparison logic for whether an autoscription should be skipped + due to a chronologically earlier but logically later unsubscription. + """ + + func = ThreadSubscriptionsWorkerStore._should_skip_autosubscription_after_unsubscription + + # Order of arguments: + # automatic cause event: stream order, then topological order + # unsubscribe maximums: stream order, then tological order + + # both orderings agree that the unsub is after the cause event + self.assertTrue( + func(autosub=EventOrderings(1, 1), unsubscribed_at=EventOrderings(2, 2)) + ) + + # topological ordering is inconsistent with stream ordering, + # in that case favour stream ordering because it's what /sync uses + self.assertTrue( + func(autosub=EventOrderings(1, 2), unsubscribed_at=EventOrderings(2, 1)) + ) + + # the automatic subscription is caused by a backfilled event here + # unfortunately we must fall back to topological ordering here + self.assertTrue( + func(autosub=EventOrderings(-50, 2), unsubscribed_at=EventOrderings(2, 3)) + ) + self.assertFalse( + func(autosub=EventOrderings(-50, 2), unsubscribed_at=EventOrderings(2, 1)) + ) + + def test_get_subscribers_to_thread(self) -> None: + """ + Test getting all subscribers to a thread at once. + + To check cache invalidations are correct, we do multiple + step-by-step rounds of subscription changes and assertions. 
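The ordering assertions above pin down when an automatic subscription must be skipped because of a logically later unsubscription. A minimal sketch of that comparison, not the actual store implementation, with a local stand-in for `EventOrderings(stream, topological)`:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Orderings:
    """Local stand-in for synapse.types.EventOrderings."""

    stream: int
    topological: int


def should_skip_autosubscription(
    autosub: Orderings, unsubscribed_at: Orderings
) -> bool:
    """Return True if the unsubscription logically post-dates the causal event."""
    # Backfilled events carry negative stream orderings, which makes stream
    # order meaningless; fall back to topological ordering in that case.
    if autosub.stream < 0 or unsubscribed_at.stream < 0:
        return unsubscribed_at.topological >= autosub.topological
    # Otherwise favour stream ordering, since that is what /sync uses.
    return unsubscribed_at.stream >= autosub.stream
```

This reproduces the four cases asserted above: agreement between the two orderings, disagreement resolved in favour of stream ordering, and the backfilled-event cases resolved topologically.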
+ """ + other_user_id = "@other_user:test" + + subscribers = self.get_success( + self.store.get_subscribers_to_thread(self.room_id, self.thread_root_id) + ) + self.assertEqual(subscribers, frozenset()) + + self._subscribe( + self.thread_root_id, automatic_event_orderings=None, user_id=self.user_id + ) + + subscribers = self.get_success( + self.store.get_subscribers_to_thread(self.room_id, self.thread_root_id) + ) + self.assertEqual(subscribers, frozenset((self.user_id,))) + + self._subscribe( + self.thread_root_id, automatic_event_orderings=None, user_id=other_user_id + ) + + subscribers = self.get_success( + self.store.get_subscribers_to_thread(self.room_id, self.thread_root_id) + ) + self.assertEqual(subscribers, frozenset((self.user_id, other_user_id))) + + self._unsubscribe(self.thread_root_id, user_id=self.user_id) + + subscribers = self.get_success( + self.store.get_subscribers_to_thread(self.room_id, self.thread_root_id) + ) + self.assertEqual(subscribers, frozenset((other_user_id,))) diff --git a/tests/storage/test_transactions.py b/tests/storage/test_transactions.py index 4d2402f144..7b2ac9fce1 100644 --- a/tests/storage/test_transactions.py +++ b/tests/storage/test_transactions.py @@ -18,7 +18,7 @@ # # -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.server import HomeServer from synapse.storage.databases.main.transactions import DestinationRetryTimings diff --git a/tests/storage/test_txn_limit.py b/tests/storage/test_txn_limit.py index 5d58521810..4722da5005 100644 --- a/tests/storage/test_txn_limit.py +++ b/tests/storage/test_txn_limit.py @@ -19,7 +19,7 @@ # # -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.server import HomeServer from synapse.storage.types import Cursor diff --git a/tests/storage/test_unsafe_locale.py b/tests/storage/test_unsafe_locale.py index 4f652fc179..3c012642aa 100644 --- a/tests/storage/test_unsafe_locale.py +++ b/tests/storage/test_unsafe_locale.py @@ -36,8 +36,14 @@ class UnsafeLocaleTest(HomeserverTestCase): def test_unsafe_locale(self, mock_db_locale: MagicMock) -> None: mock_db_locale.return_value = ("B", "B") database = self.hs.get_datastores().databases[0] + server_name = self.hs.hostname - db_conn = make_conn(database._database_config, database.engine, "test_unsafe") + db_conn = make_conn( + db_config=database._database_config, + engine=database.engine, + default_txn_name="test_unsafe", + server_name=server_name, + ) with self.assertRaises(IncorrectDatabaseSetup): database.engine.check_database(db_conn) with self.assertRaises(IncorrectDatabaseSetup): @@ -47,8 +53,14 @@ class UnsafeLocaleTest(HomeserverTestCase): def test_safe_locale(self) -> None: database = self.hs.get_datastores().databases[0] assert isinstance(database.engine, PostgresEngine) + server_name = self.hs.hostname - db_conn = make_conn(database._database_config, database.engine, "test_unsafe") + db_conn = make_conn( + db_config=database._database_config, + engine=database.engine, + default_txn_name="test_unsafe", + server_name=server_name, + ) with db_conn.cursor() as txn: res = database.engine.get_db_locale(txn) self.assertEqual(res, ("C", "C")) diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index 80f491aff9..255de298f3 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -23,7 +23,7 @@ from typing import Any, Dict, List, Optional, Set, Tuple, cast from 
unittest import mock from unittest.mock import Mock, patch -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.constants import EventTypes, Membership, UserTypes from synapse.appservice import ApplicationService diff --git a/tests/storage/test_user_filters.py b/tests/storage/test_user_filters.py index 177da340e5..8d928aa55c 100644 --- a/tests/storage/test_user_filters.py +++ b/tests/storage/test_user_filters.py @@ -20,7 +20,7 @@ # -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.server import HomeServer from synapse.storage.database import LoggingTransaction diff --git a/tests/test_distributor.py b/tests/test_distributor.py index 18792fdee3..19dafe64ed 100644 --- a/tests/test_distributor.py +++ b/tests/test_distributor.py @@ -28,7 +28,7 @@ from . import unittest class DistributorTestCase(unittest.TestCase): def setUp(self) -> None: - self.dist = Distributor() + self.dist = Distributor(server_name="test_server") def test_signal_dispatch(self) -> None: self.dist.declare("alert") diff --git a/tests/test_mau.py b/tests/test_mau.py index 472965e022..1000870aa9 100644 --- a/tests/test_mau.py +++ b/tests/test_mau.py @@ -22,7 +22,7 @@ from typing import List, Optional -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType from synapse.api.errors import Codes, HttpResponseException, SynapseError diff --git a/tests/test_phone_home.py b/tests/test_phone_home.py index 16206d5a97..0b230ed0f5 100644 --- a/tests/test_phone_home.py +++ b/tests/test_phone_home.py @@ -22,7 +22,7 @@ import resource from unittest import mock -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.app.phone_stats_home import phone_stats_home from synapse.rest import admin diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py index d656136972..8065ae4b8a 100644 --- a/tests/test_terms_auth.py +++ b/tests/test_terms_auth.py @@ -21,7 +21,7 @@ from unittest.mock import Mock from twisted.internet.interfaces import IReactorTime -from twisted.test.proto_helpers import MemoryReactor, MemoryReactorClock +from twisted.internet.testing import MemoryReactor, MemoryReactorClock from synapse.rest.client.register import register_servlets from synapse.server import HomeServer diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py index 35b3245708..c1eaf9a575 100644 --- a/tests/test_utils/event_injection.py +++ b/tests/test_utils/event_injection.py @@ -105,6 +105,13 @@ async def create_event( builder, prev_event_ids=prev_event_ids ) + # Copy over writable internal_metadata, if set + # Dev note: This isn't everything that's writable. 
`for k,v` doesn't work here :( + if kwargs.get("internal_metadata", {}).get("soft_failed", False): + event.internal_metadata.soft_failed = True + if kwargs.get("internal_metadata", {}).get("policy_server_spammy", False): + event.internal_metadata.policy_server_spammy = True + context = await unpersisted_context.persist(event) return event, context diff --git a/tests/test_visibility.py b/tests/test_visibility.py index 89cbe4e54b..285e28e0f9 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -21,7 +21,9 @@ import logging from typing import Optional from unittest.mock import patch -from synapse.api.constants import EventUnsignedContentFields +from twisted.test.proto_helpers import MemoryReactor + +from synapse.api.constants import AccountDataTypes, EventUnsignedContentFields from synapse.api.room_versions import RoomVersions from synapse.events import EventBase, make_event_from_dict from synapse.events.snapshot import EventContext @@ -29,6 +31,7 @@ from synapse.rest import admin from synapse.rest.client import login, room from synapse.server import HomeServer from synapse.types import create_requester +from synapse.util import Clock from synapse.visibility import filter_events_for_client, filter_events_for_server from tests import unittest @@ -272,6 +275,210 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): return event +class FilterEventsForServerAdminsTestCase(HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: + self.register_user("admin", "password", admin=True) + self.tok = self.login("admin", "password") + self.room_id = self.helper.create_room_as("admin", tok=self.tok) + self.get_success( + inject_visibility_event(self.hs, self.room_id, "@admin:test", "joined") + ) + self.regular_event = self.get_success( + inject_message_event(self.hs, self.room_id, "@admin:test", body="regular") + ) + self.soft_failed_event = self.get_success( + inject_message_event( + self.hs, + self.room_id, + "@admin:test", + body="soft failed", + soft_failed=True, + ) + ) + self.spammy_event = self.get_success( + inject_message_event( + self.hs, + self.room_id, + "@admin:test", + body="spammy", + soft_failed=True, + policy_server_spammy=True, + ) + ) + + def test_normal_operation_as_admin(self) -> None: + # `filter_events_for_client` shouldn't include soft failed events by default + # for admins. 
+ + # Reload events from DB + events_to_filter = [ + self.get_success( + self.hs.get_storage_controllers().main.get_event( + e.event_id, + get_prev_content=True, + ) + ) + for e in [self.regular_event, self.soft_failed_event] + ] + + # Do filter & assert + filtered_events = self.get_success( + filter_events_for_client( + self.hs.get_storage_controllers(), + "@admin:test", + events_to_filter, + ) + ) + self.assertEqual( + [e.event_id for e in [self.regular_event]], + [e.event_id for e in filtered_events], + ) + + def test_see_soft_failed_events(self) -> None: + # `filter_events_for_client` should include soft failed events when configured + + # Reload events from DB + events_to_filter = [ + self.get_success( + self.hs.get_storage_controllers().main.get_event( + e.event_id, + get_prev_content=True, + ) + ) + for e in [self.regular_event, self.soft_failed_event] + ] + + # Inject client config + self.get_success( + self.hs.get_account_data_handler().add_account_data_for_user( + "@admin:test", + AccountDataTypes.SYNAPSE_ADMIN_CLIENT_CONFIG, + {"return_soft_failed_events": True}, + ) + ) + + # Sanity check + self.assertEqual(True, events_to_filter[1].internal_metadata.soft_failed) + + # Do filter & assert + filtered_events = self.get_success( + filter_events_for_client( + self.hs.get_storage_controllers(), + "@admin:test", + events_to_filter, + ) + ) + self.assertEqual( + [e.event_id for e in [self.regular_event, self.soft_failed_event]], + [e.event_id for e in filtered_events], + ) + + def test_see_policy_server_spammy_events(self) -> None: + # `filter_events_for_client` should include policy server-flagged events, but + # not other soft-failed events, when asked. + + # Reload events from DB + events_to_filter = [ + self.get_success( + self.hs.get_storage_controllers().main.get_event( + e.event_id, + get_prev_content=True, + ) + ) + for e in [self.regular_event, self.soft_failed_event, self.spammy_event] + ] + + # Inject client config + self.get_success( + self.hs.get_account_data_handler().add_account_data_for_user( + "@admin:test", + AccountDataTypes.SYNAPSE_ADMIN_CLIENT_CONFIG, + { + "return_soft_failed_events": False, + "return_policy_server_spammy_events": True, + }, + ) + ) + + # Sanity checks + self.assertEqual(True, events_to_filter[1].internal_metadata.soft_failed) + self.assertEqual(True, events_to_filter[2].internal_metadata.soft_failed) + self.assertEqual( + True, events_to_filter[2].internal_metadata.policy_server_spammy + ) + + # Do filter & assert + filtered_events = self.get_success( + filter_events_for_client( + self.hs.get_storage_controllers(), + "@admin:test", + events_to_filter, + ) + ) + self.assertEqual( + [e.event_id for e in [self.regular_event, self.spammy_event]], + [e.event_id for e in filtered_events], + ) + + def test_see_soft_failed_and_policy_server_spammy_events(self) -> None: + # `filter_events_for_client` should include both types of soft failed events + # when configured. 
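The opt-in used throughout these visibility tests is ordinary account data on the admin's own account. A condensed sketch of that pattern as a standalone coroutine; the function name is illustrative, while the handler call and flag names come from the tests above:

```python
from synapse.api.constants import AccountDataTypes
from synapse.server import HomeServer


async def enable_admin_event_visibility(hs: HomeServer, admin_user_id: str) -> None:
    """Opt an admin in to seeing soft-failed and policy-server-flagged events."""
    # Both flags default to off. Setting only `return_policy_server_spammy_events`
    # surfaces just the events the policy server flagged as spammy, without
    # exposing every soft-failed event.
    await hs.get_account_data_handler().add_account_data_for_user(
        admin_user_id,
        AccountDataTypes.SYNAPSE_ADMIN_CLIENT_CONFIG,
        {
            "return_soft_failed_events": True,
            "return_policy_server_spammy_events": True,
        },
    )
```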
+ + # Reload events from DB + events_to_filter = [ + self.get_success( + self.hs.get_storage_controllers().main.get_event( + e.event_id, + get_prev_content=True, + ) + ) + for e in [self.regular_event, self.soft_failed_event, self.spammy_event] + ] + + # Inject client config + self.get_success( + self.hs.get_account_data_handler().add_account_data_for_user( + "@admin:test", + AccountDataTypes.SYNAPSE_ADMIN_CLIENT_CONFIG, + { + "return_soft_failed_events": True, + "return_policy_server_spammy_events": True, + }, + ) + ) + + # Sanity checks + self.assertEqual(True, events_to_filter[1].internal_metadata.soft_failed) + self.assertEqual(True, events_to_filter[2].internal_metadata.soft_failed) + self.assertEqual( + True, events_to_filter[2].internal_metadata.policy_server_spammy + ) + + # Do filter & assert + filtered_events = self.get_success( + filter_events_for_client( + self.hs.get_storage_controllers(), + "@admin:test", + events_to_filter, + ) + ) + self.assertEqual( + [ + e.event_id + for e in [self.regular_event, self.soft_failed_event, self.spammy_event] + ], + [e.event_id for e in filtered_events], + ) + + class FilterEventsForClientTestCase(HomeserverTestCase): servlets = [ admin.register_servlets, @@ -487,6 +694,8 @@ async def inject_message_event( room_id: str, sender: str, body: Optional[str] = "testytest", + soft_failed: Optional[bool] = False, + policy_server_spammy: Optional[bool] = False, ) -> EventBase: return await inject_event( hs, @@ -494,4 +703,8 @@ async def inject_message_event( sender=sender, room_id=room_id, content={"body": body, "msgtype": "m.text"}, + internal_metadata={ + "soft_failed": soft_failed, + "policy_server_spammy": policy_server_spammy, + }, ) diff --git a/tests/unittest.py b/tests/unittest.py index 24077d79d6..5e6957dc6d 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -54,9 +54,9 @@ import unpaddedbase64 from typing_extensions import Concatenate, ParamSpec from twisted.internet.defer import Deferred, ensureDeferred +from twisted.internet.testing import MemoryReactor, MemoryReactorClock from twisted.python.failure import Failure from twisted.python.threadpool import ThreadPool -from twisted.test.proto_helpers import MemoryReactor, MemoryReactorClock from twisted.trial import unittest from twisted.web.resource import Resource from twisted.web.server import Request diff --git a/tests/util/test_batching_queue.py b/tests/util/test_batching_queue.py index 2dcf3a3412..532582cf87 100644 --- a/tests/util/test_batching_queue.py +++ b/tests/util/test_batching_queue.py @@ -42,15 +42,18 @@ class BatchingQueueTestCase(TestCase): # We ensure that we remove any existing metrics for "test_queue". 
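+        # (These metrics are now homeserver-scoped, so both the queue name and
+        # the server name label values have to be supplied when clearing them.)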
try: - number_queued.remove("test_queue") - number_of_keys.remove("test_queue") - number_in_flight.remove("test_queue") + number_queued.remove("test_queue", "test_server") + number_of_keys.remove("test_queue", "test_server") + number_in_flight.remove("test_queue", "test_server") except KeyError: pass self._pending_calls: List[Tuple[List[str], defer.Deferred]] = [] self.queue: BatchingQueue[str, str] = BatchingQueue( - "test_queue", hs_clock, self._process_queue + name="test_queue", + server_name="test_server", + clock=hs_clock, + process_batch_callback=self._process_queue, ) async def _process_queue(self, values: List[str]) -> str: diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py index f7c5f5faca..af36e685d7 100644 --- a/tests/util/test_logcontext.py +++ b/tests/util/test_logcontext.py @@ -51,20 +51,18 @@ class LoggingContextTestCase(unittest.TestCase): with LoggingContext("test"): self._check_test_key("test") - @defer.inlineCallbacks - def test_sleep(self) -> Generator["defer.Deferred[object]", object, None]: + async def test_sleep(self) -> None: clock = Clock(reactor) - @defer.inlineCallbacks - def competing_callback() -> Generator["defer.Deferred[object]", object, None]: + async def competing_callback() -> None: with LoggingContext("competing"): - yield clock.sleep(0) + await clock.sleep(0) self._check_test_key("competing") - reactor.callLater(0, competing_callback) + reactor.callLater(0, lambda: defer.ensureDeferred(competing_callback())) with LoggingContext("one"): - yield clock.sleep(0) + await clock.sleep(0) self._check_test_key("one") def _test_run_in_background(self, function: Callable[[], object]) -> defer.Deferred: @@ -108,9 +106,8 @@ class LoggingContextTestCase(unittest.TestCase): return d2 def test_run_in_background_with_blocking_fn(self) -> defer.Deferred: - @defer.inlineCallbacks - def blocking_function() -> Generator["defer.Deferred[object]", object, None]: - yield Clock(reactor).sleep(0) + async def blocking_function() -> None: + await Clock(reactor).sleep(0) return self._test_run_in_background(blocking_function) @@ -133,7 +130,7 @@ class LoggingContextTestCase(unittest.TestCase): def test_run_in_background_with_coroutine(self) -> defer.Deferred: async def testfunc() -> None: self._check_test_key("one") - d = Clock(reactor).sleep(0) + d = defer.ensureDeferred(Clock(reactor).sleep(0)) self.assertIs(current_context(), SENTINEL_CONTEXT) await d self._check_test_key("one") diff --git a/tests/util/test_ratelimitutils.py b/tests/util/test_ratelimitutils.py index 7bb45f9bf2..20281d04fe 100644 --- a/tests/util/test_ratelimitutils.py +++ b/tests/util/test_ratelimitutils.py @@ -37,7 +37,7 @@ class FederationRateLimiterTestCase(TestCase): """A simple test with the default values""" reactor, clock = get_clock() rc_config = build_rc_config() - ratelimiter = FederationRateLimiter(clock, rc_config) + ratelimiter = FederationRateLimiter("test_server", clock, rc_config) with ratelimiter.ratelimit("testhost") as d1: # shouldn't block @@ -47,7 +47,7 @@ class FederationRateLimiterTestCase(TestCase): """Test what happens when we hit the concurrent limit""" reactor, clock = get_clock() rc_config = build_rc_config({"rc_federation": {"concurrent": 2}}) - ratelimiter = FederationRateLimiter(clock, rc_config) + ratelimiter = FederationRateLimiter("test_server", clock, rc_config) with ratelimiter.ratelimit("testhost") as d1: # shouldn't block @@ -74,7 +74,7 @@ class FederationRateLimiterTestCase(TestCase): rc_config = build_rc_config( {"rc_federation": {"sleep_limit": 2, 
"sleep_delay": 500}} ) - ratelimiter = FederationRateLimiter(clock, rc_config) + ratelimiter = FederationRateLimiter("test_server", clock, rc_config) with ratelimiter.ratelimit("testhost") as d1: # shouldn't block @@ -105,7 +105,7 @@ class FederationRateLimiterTestCase(TestCase): } } ) - ratelimiter = FederationRateLimiter(clock, rc_config) + ratelimiter = FederationRateLimiter("test_server", clock, rc_config) with ratelimiter.ratelimit("testhost") as d: # shouldn't block diff --git a/tests/util/test_retryutils.py b/tests/util/test_retryutils.py index 2c286c19a2..82baff5883 100644 --- a/tests/util/test_retryutils.py +++ b/tests/util/test_retryutils.py @@ -31,7 +31,14 @@ class RetryLimiterTestCase(HomeserverTestCase): def test_new_destination(self) -> None: """A happy-path case with a new destination and a successful operation""" store = self.hs.get_datastores().main - limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store)) + limiter = self.get_success( + get_retry_limiter( + destination="test_dest", + our_server_name=self.hs.hostname, + clock=self.clock, + store=store, + ) + ) # advance the clock a bit before making the request self.pump(1) @@ -46,7 +53,14 @@ class RetryLimiterTestCase(HomeserverTestCase): """General test case which walks through the process of a failing request""" store = self.hs.get_datastores().main - limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store)) + limiter = self.get_success( + get_retry_limiter( + destination="test_dest", + our_server_name=self.hs.hostname, + clock=self.clock, + store=store, + ) + ) min_retry_interval_ms = ( self.hs.config.federation.destination_min_retry_interval_ms @@ -72,7 +86,13 @@ class RetryLimiterTestCase(HomeserverTestCase): # now if we try again we should get a failure self.get_failure( - get_retry_limiter("test_dest", self.clock, store), NotRetryingDestination + get_retry_limiter( + destination="test_dest", + our_server_name=self.hs.hostname, + clock=self.clock, + store=store, + ), + NotRetryingDestination, ) # @@ -80,7 +100,14 @@ class RetryLimiterTestCase(HomeserverTestCase): # self.pump(min_retry_interval_ms) - limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store)) + limiter = self.get_success( + get_retry_limiter( + destination="test_dest", + our_server_name=self.hs.hostname, + clock=self.clock, + store=store, + ) + ) self.pump(1) try: @@ -108,7 +135,14 @@ class RetryLimiterTestCase(HomeserverTestCase): # one more go, with success # self.reactor.advance(min_retry_interval_ms * retry_multiplier * 2.0) - limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store)) + limiter = self.get_success( + get_retry_limiter( + destination="test_dest", + our_server_name=self.hs.hostname, + clock=self.clock, + store=store, + ) + ) self.pump(1) with limiter: @@ -129,9 +163,10 @@ class RetryLimiterTestCase(HomeserverTestCase): limiter = self.get_success( get_retry_limiter( - "test_dest", - self.clock, - store, + destination="test_dest", + our_server_name=self.hs.hostname, + clock=self.clock, + store=store, notifier=notifier, replication_client=replication_client, ) @@ -199,7 +234,14 @@ class RetryLimiterTestCase(HomeserverTestCase): self.hs.config.federation.destination_max_retry_interval_ms ) - self.get_success(get_retry_limiter("test_dest", self.clock, store)) + self.get_success( + get_retry_limiter( + destination="test_dest", + our_server_name=self.hs.hostname, + clock=self.clock, + store=store, + ) + ) self.pump(1) failure_ts = self.clock.time_msec() @@ 
-216,12 +258,25 @@ class RetryLimiterTestCase(HomeserverTestCase): # Check it fails self.get_failure( - get_retry_limiter("test_dest", self.clock, store), NotRetryingDestination + get_retry_limiter( + destination="test_dest", + our_server_name=self.hs.hostname, + clock=self.clock, + store=store, + ), + NotRetryingDestination, ) # Get past retry_interval and we can try again, and still throw an error to continue the backoff self.reactor.advance(destination_max_retry_interval_ms / 1000 + 1) - limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store)) + limiter = self.get_success( + get_retry_limiter( + destination="test_dest", + our_server_name=self.hs.hostname, + clock=self.clock, + store=store, + ) + ) self.pump(1) try: with limiter: @@ -239,5 +294,11 @@ class RetryLimiterTestCase(HomeserverTestCase): # Check it fails self.get_failure( - get_retry_limiter("test_dest", self.clock, store), NotRetryingDestination + get_retry_limiter( + destination="test_dest", + our_server_name=self.hs.hostname, + clock=self.clock, + store=store, + ), + NotRetryingDestination, ) diff --git a/tests/util/test_task_scheduler.py b/tests/util/test_task_scheduler.py index 7f6e63bd49..2171f91b4d 100644 --- a/tests/util/test_task_scheduler.py +++ b/tests/util/test_task_scheduler.py @@ -21,7 +21,7 @@ from typing import List, Optional, Tuple from twisted.internet.task import deferLater -from twisted.test.proto_helpers import MemoryReactor +from twisted.internet.testing import MemoryReactor from synapse.server import HomeServer from synapse.types import JsonMapping, ScheduledTask, TaskStatus diff --git a/tests/utils.py b/tests/utils.py index 0006bd7a8d..d1b66d4159 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -113,7 +113,12 @@ def setupdb() -> None: port=POSTGRES_PORT, password=POSTGRES_PASSWORD, ) - logging_conn = LoggingDatabaseConnection(db_conn, db_engine, "tests") + logging_conn = LoggingDatabaseConnection( + conn=db_conn, + engine=db_engine, + default_txn_name="tests", + server_name="test_server", + ) prepare_database(logging_conn, db_engine, None) logging_conn.close()
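
A note on the new `FilterEventsForServerAdminsTestCase`: the opt-in for seeing soft-failed events is just per-user account data of type `AccountDataTypes.SYNAPSE_ADMIN_CLIENT_CONFIG`. A minimal sketch of setting it outside the test harness is below; the helper name and the user ID are illustrative only, while the handler call, the account data type and the flag names are taken from the tests above.

```python
from synapse.api.constants import AccountDataTypes
from synapse.server import HomeServer


async def enable_soft_failed_visibility(hs: HomeServer, admin_user_id: str) -> None:
    """Opt an admin user into seeing soft-failed / policy-server-flagged events.

    Mirrors what the tests do via the account data handler.
    """
    await hs.get_account_data_handler().add_account_data_for_user(
        admin_user_id,  # e.g. "@admin:example.com" (illustrative)
        AccountDataTypes.SYNAPSE_ADMIN_CLIENT_CONFIG,
        {
            "return_soft_failed_events": True,
            "return_policy_server_spammy_events": True,
        },
    )
```

Judging by `test_normal_operation_as_admin`, leaving the keys unset (or writing them back as `False`) keeps the default behaviour of filtering soft-failed events out.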
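
Most of the remaining test churn (batching queue, federation ratelimiter, retry limiter, database connection) follows from these utilities and their metrics becoming homeserver-scoped: constructors now take a server name, and labelled metrics carry an extra label value. As a rough illustration of the label mechanics using plain `prometheus_client` (the metric and label names here are invented, not Synapse's real ones):

```python
from prometheus_client import Gauge

# A gauge keyed by both the queue name and the owning homeserver, roughly the
# shape the batching-queue metrics now take.
number_queued = Gauge(
    "example_batching_queue_number_queued",
    "Number of items queued, per queue and per homeserver",
    labelnames=["name", "server_name"],
)

# Values are recorded against the full label set...
number_queued.labels(name="test_queue", server_name="test_server").set(3)

# ...so removing the series also requires every label value, in label order.
number_queued.remove("test_queue", "test_server")
```

`.remove()` raises `KeyError` if that label set was never created, which is why the batching-queue test keeps its `try`/`except KeyError` around the clean-up.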
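
The `test_logcontext.py` changes migrate `@defer.inlineCallbacks` generators to coroutines; the `defer.ensureDeferred(Clock(reactor).sleep(0))` change also suggests `Clock.sleep` now returns a coroutine rather than a bare `Deferred`. One wrinkle is that `reactor.callLater` only accepts a plain callable, so a coroutine has to be kicked off via `defer.ensureDeferred`, as the test now does. A stripped-down sketch of that pattern, independent of the Synapse test harness:

```python
from twisted.internet import defer, reactor, task


async def competing_callback() -> None:
    # Yield control back to the reactor for one iteration, roughly what
    # `clock.sleep(0)` does in the tests.
    await task.deferLater(reactor, 0, lambda: None)
    print("competing callback ran")


def schedule() -> None:
    # `callLater` expects a synchronous callable, so start the coroutine via
    # `ensureDeferred` rather than passing it in directly.
    reactor.callLater(0, lambda: defer.ensureDeferred(competing_callback()))


schedule()
reactor.callLater(0.1, reactor.stop)
reactor.run()
```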