From 90a6bd01c23fcaba1941ccd427793bd2a0130a64 Mon Sep 17 00:00:00 2001 From: Max Kratz Date: Tue, 21 Jan 2025 19:54:31 +0100 Subject: [PATCH 01/16] Contrib: Docker: updates PostgreSQL version in `docker-compose.yml` (#18089) Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> --- changelog.d/18089.bugfix | 2 ++ contrib/docker/docker-compose.yml | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) create mode 100644 changelog.d/18089.bugfix diff --git a/changelog.d/18089.bugfix b/changelog.d/18089.bugfix new file mode 100644 index 0000000000..607fab7112 --- /dev/null +++ b/changelog.d/18089.bugfix @@ -0,0 +1,2 @@ +Updates contributed `docker-compose.yml` file to PostgreSQL v15, as v12 is no longer supported by Synapse. +Contributed by @maxkratz. \ No newline at end of file diff --git a/contrib/docker/docker-compose.yml b/contrib/docker/docker-compose.yml index 36d5fd5309..9dffc852fd 100644 --- a/contrib/docker/docker-compose.yml +++ b/contrib/docker/docker-compose.yml @@ -51,7 +51,7 @@ services: - traefik.http.routers.https-synapse.tls.certResolver=le-ssl db: - image: docker.io/postgres:12-alpine + image: docker.io/postgres:15-alpine # Change that password, of course! 
environment: - POSTGRES_USER=synapse From 9c5d08fff8d66a7cc0e2ecfeeb783f933a778c2f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sven=20M=C3=A4der?= Date: Fri, 24 Jan 2025 20:58:01 +0100 Subject: [PATCH 02/16] Ratelimit presence updates (#18000) --- changelog.d/18000.bugfix | 1 + .../conf/workers-shared-extra.yaml.j2 | 5 + .../configuration/config_documentation.md | 21 ++++ synapse/api/ratelimiting.py | 8 +- synapse/config/ratelimiting.py | 6 + synapse/rest/client/presence.py | 22 +++- synapse/rest/client/sync.py | 19 +++- tests/handlers/test_presence.py | 105 +++++++++++++++++- tests/rest/client/test_presence.py | 52 +++++++++ tests/utils.py | 1 + 10 files changed, 232 insertions(+), 8 deletions(-) create mode 100644 changelog.d/18000.bugfix diff --git a/changelog.d/18000.bugfix b/changelog.d/18000.bugfix new file mode 100644 index 0000000000..a8f1545bf5 --- /dev/null +++ b/changelog.d/18000.bugfix @@ -0,0 +1 @@ +Add rate limit `rc_presence.per_user`. This prevents load from excessive presence updates sent by clients via sync api. Also rate limit `/_matrix/client/v3/presence` as per the spec. Contributed by @rda0. 
diff --git a/docker/complement/conf/workers-shared-extra.yaml.j2 b/docker/complement/conf/workers-shared-extra.yaml.j2 index 4c8fa65e62..797d58e9b3 100644 --- a/docker/complement/conf/workers-shared-extra.yaml.j2 +++ b/docker/complement/conf/workers-shared-extra.yaml.j2 @@ -89,6 +89,11 @@ rc_invites: per_second: 1000 burst_count: 1000 +rc_presence: + per_user: + per_second: 9999 + burst_count: 9999 + federation_rr_transactions_per_room_per_second: 9999 allow_device_name_lookup_over_federation: true diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md index 64392366ca..a1e671ab8e 100644 --- a/docs/usage/configuration/config_documentation.md +++ b/docs/usage/configuration/config_documentation.md @@ -1868,6 +1868,27 @@ rc_federation: concurrent: 5 ``` --- +### `rc_presence` + +This option sets ratelimiting for presence. + +The `rc_presence.per_user` option sets rate limits on how often a specific +users' presence updates are evaluated. Ratelimited presence updates sent via sync are +ignored, and no error is returned to the client. +This option also sets the rate limit for the +[`PUT /_matrix/client/v3/presence/{userId}/status`](https://spec.matrix.org/latest/client-server-api/#put_matrixclientv3presenceuseridstatus) +endpoint. + +`per_user` defaults to `per_second: 0.1`, `burst_count: 1`. + +Example configuration: +```yaml +rc_presence: + per_user: + per_second: 0.05 + burst_count: 0.5 +``` +--- ### `federation_rr_transactions_per_room_per_second` Sets outgoing federation transaction frequency for sending read-receipts, diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index b80630c5d3..229329a5ae 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -275,6 +275,7 @@ class Ratelimiter: update: bool = True, n_actions: int = 1, _time_now_s: Optional[float] = None, + pause: Optional[float] = 0.5, ) -> None: """Checks if an action can be performed. 
If not, raises a LimitExceededError @@ -298,6 +299,8 @@ class Ratelimiter: at all. _time_now_s: The current time. Optional, defaults to the current time according to self.clock. Only used by tests. + pause: Time in seconds to pause when an action is being limited. Defaults to 0.5 + to stop clients from "tight-looping" on retrying their request. Raises: LimitExceededError: If an action could not be performed, along with the time in @@ -316,9 +319,8 @@ class Ratelimiter: ) if not allowed: - # We pause for a bit here to stop clients from "tight-looping" on - # retrying their request. - await self.clock.sleep(0.5) + if pause: + await self.clock.sleep(pause) raise LimitExceededError( limiter_name=self._limiter_name, diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index 3fa33f5373..06af4da3c5 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -228,3 +228,9 @@ class RatelimitConfig(Config): config.get("remote_media_download_burst_count", "500M") ), ) + + self.rc_presence_per_user = RatelimitSettings.parse( + config, + "rc_presence.per_user", + defaults={"per_second": 0.1, "burst_count": 1}, + ) diff --git a/synapse/rest/client/presence.py b/synapse/rest/client/presence.py index ecc52956e4..104d54cd89 100644 --- a/synapse/rest/client/presence.py +++ b/synapse/rest/client/presence.py @@ -24,7 +24,8 @@ import logging from typing import TYPE_CHECKING, Tuple -from synapse.api.errors import AuthError, SynapseError +from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError +from synapse.api.ratelimiting import Ratelimiter from synapse.handlers.presence import format_user_presence_state from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request @@ -48,6 +49,14 @@ class PresenceStatusRestServlet(RestServlet): self.presence_handler = hs.get_presence_handler() self.clock = hs.get_clock() self.auth = hs.get_auth() + self.store = 
hs.get_datastores().main + + # Ratelimiter for presence updates, keyed by requester. + self._presence_per_user_limiter = Ratelimiter( + store=self.store, + clock=self.clock, + cfg=hs.config.ratelimiting.rc_presence_per_user, + ) async def on_GET( self, request: SynapseRequest, user_id: str @@ -82,6 +91,17 @@ class PresenceStatusRestServlet(RestServlet): if requester.user != user: raise AuthError(403, "Can only set your own presence state") + # ignore the presence update if the ratelimit is exceeded + try: + await self._presence_per_user_limiter.ratelimit(requester) + except LimitExceededError as e: + logger.debug("User presence ratelimit exceeded; ignoring it.") + return 429, { + "errcode": Codes.LIMIT_EXCEEDED, + "error": "Too many requests", + "retry_after_ms": e.retry_after_ms, + } + state = {} content = parse_json_object_from_request(request) diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index f4ef84a038..4fb9c0c8e7 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -24,9 +24,10 @@ from collections import defaultdict from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union from synapse.api.constants import AccountDataTypes, EduTypes, Membership, PresenceState -from synapse.api.errors import Codes, StoreError, SynapseError +from synapse.api.errors import Codes, LimitExceededError, StoreError, SynapseError from synapse.api.filtering import FilterCollection from synapse.api.presence import UserPresenceState +from synapse.api.ratelimiting import Ratelimiter from synapse.events.utils import ( SerializeEventConfig, format_event_for_client_v2_without_room_id, @@ -126,6 +127,13 @@ class SyncRestServlet(RestServlet): cache_name="sync_valid_filter", ) + # Ratelimiter for presence updates, keyed by requester. 
+ self._presence_per_user_limiter = Ratelimiter( + store=self.store, + clock=self.clock, + cfg=hs.config.ratelimiting.rc_presence_per_user, + ) + async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: # This will always be set by the time Twisted calls us. assert request.args is not None @@ -239,7 +247,14 @@ class SyncRestServlet(RestServlet): # send any outstanding server notices to the user. await self._server_notices_sender.on_user_syncing(user.to_string()) - affect_presence = set_presence != PresenceState.OFFLINE + # ignore the presence update if the ratelimit is exceeded but do not pause the request + try: + await self._presence_per_user_limiter.ratelimit(requester, pause=0.0) + except LimitExceededError: + affect_presence = False + logger.debug("User set_presence ratelimit exceeded; ignoring it.") + else: + affect_presence = set_presence != PresenceState.OFFLINE context = await self.presence_handler.user_syncing( user.to_string(), diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 598d6c13cd..4cf048f0df 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -45,7 +45,7 @@ from synapse.handlers.presence import ( handle_update, ) from synapse.rest import admin -from synapse.rest.client import room +from synapse.rest.client import login, room, sync from synapse.server import HomeServer from synapse.storage.database import LoggingDatabaseConnection from synapse.types import JsonDict, UserID, get_domain_from_id @@ -53,10 +53,15 @@ from synapse.util import Clock from tests import unittest from tests.replication._base import BaseMultiWorkerStreamTestCase +from tests.unittest import override_config class PresenceUpdateTestCase(unittest.HomeserverTestCase): - servlets = [admin.register_servlets] + servlets = [ + admin.register_servlets, + login.register_servlets, + sync.register_servlets, + ] def prepare( self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer @@ -425,6 
+430,102 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): wheel_timer.insert.assert_not_called() + # `rc_presence` is set very high during unit tests to avoid ratelimiting + # subtly impacting unrelated tests. We set the ratelimiting back to a + # reasonable value for the tests specific to presence ratelimiting. + @override_config( + {"rc_presence": {"per_user": {"per_second": 0.1, "burst_count": 1}}} + ) + def test_over_ratelimit_offline_to_online_to_unavailable(self) -> None: + """ + Send a presence update, check that it went through, immediately send another one and + check that it was ignored. + """ + self._test_ratelimit_offline_to_online_to_unavailable(ratelimited=True) + + @override_config( + {"rc_presence": {"per_user": {"per_second": 0.1, "burst_count": 1}}} + ) + def test_within_ratelimit_offline_to_online_to_unavailable(self) -> None: + """ + Send a presence update, check that it went through, advancing time a sufficient amount, + send another presence update and check that it also worked. + """ + self._test_ratelimit_offline_to_online_to_unavailable(ratelimited=False) + + @override_config( + {"rc_presence": {"per_user": {"per_second": 0.1, "burst_count": 1}}} + ) + def _test_ratelimit_offline_to_online_to_unavailable( + self, ratelimited: bool + ) -> None: + """Test rate limit for presence updates sent with sync requests. + + Args: + ratelimited: Test rate limited case. + """ + wheel_timer = Mock() + user_id = "@user:pass" + now = 5000000 + sync_url = "/sync?access_token=%s&set_presence=%s" + + # Register the user who syncs presence + user_id = self.register_user("user", "pass") + access_token = self.login("user", "pass") + + # Get the handler (which kicks off a bunch of timers). + presence_handler = self.hs.get_presence_handler() + + # Ensure the user is initially offline. 
+ prev_state = UserPresenceState.default(user_id) + new_state = prev_state.copy_and_replace( + state=PresenceState.OFFLINE, last_active_ts=now + ) + + state, persist_and_notify, federation_ping = handle_update( + prev_state, + new_state, + is_mine=True, + wheel_timer=wheel_timer, + now=now, + persist=False, + ) + + # Check that the user is offline. + state = self.get_success( + presence_handler.get_state(UserID.from_string(user_id)) + ) + self.assertEqual(state.state, PresenceState.OFFLINE) + + # Send sync request with set_presence=online. + channel = self.make_request("GET", sync_url % (access_token, "online")) + self.assertEqual(200, channel.code) + + # Assert the user is now online. + state = self.get_success( + presence_handler.get_state(UserID.from_string(user_id)) + ) + self.assertEqual(state.state, PresenceState.ONLINE) + + if not ratelimited: + # Advance time a sufficient amount to avoid rate limiting. + self.reactor.advance(30) + + # Send another sync request with set_presence=unavailable. + channel = self.make_request("GET", sync_url % (access_token, "unavailable")) + self.assertEqual(200, channel.code) + + state = self.get_success( + presence_handler.get_state(UserID.from_string(user_id)) + ) + + if ratelimited: + # Assert the user is still online and presence update was ignored. + self.assertEqual(state.state, PresenceState.ONLINE) + else: + # Assert the user is now unavailable. 
+ self.assertEqual(state.state, PresenceState.UNAVAILABLE) + class PresenceTimeoutTestCase(unittest.TestCase): """Tests different timers and that the timer does not change `status_msg` of user.""" diff --git a/tests/rest/client/test_presence.py b/tests/rest/client/test_presence.py index 5ced8319e1..6b9c70974a 100644 --- a/tests/rest/client/test_presence.py +++ b/tests/rest/client/test_presence.py @@ -29,6 +29,7 @@ from synapse.types import UserID from synapse.util import Clock from tests import unittest +from tests.unittest import override_config class PresenceTestCase(unittest.HomeserverTestCase): @@ -95,3 +96,54 @@ class PresenceTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, HTTPStatus.OK) self.assertEqual(self.presence_handler.set_state.call_count, 0) + + @override_config( + {"rc_presence": {"per_user": {"per_second": 0.1, "burst_count": 1}}} + ) + def test_put_presence_over_ratelimit(self) -> None: + """ + Multiple PUTs to the status endpoint without sufficient delay will be rate limited. + """ + self.hs.config.server.presence_enabled = True + + body = {"presence": "here", "status_msg": "beep boop"} + channel = self.make_request( + "PUT", "/presence/%s/status" % (self.user_id,), body + ) + + self.assertEqual(channel.code, HTTPStatus.OK) + + body = {"presence": "here", "status_msg": "beep boop"} + channel = self.make_request( + "PUT", "/presence/%s/status" % (self.user_id,), body + ) + + self.assertEqual(channel.code, HTTPStatus.TOO_MANY_REQUESTS) + self.assertEqual(self.presence_handler.set_state.call_count, 1) + + @override_config( + {"rc_presence": {"per_user": {"per_second": 0.1, "burst_count": 1}}} + ) + def test_put_presence_within_ratelimit(self) -> None: + """ + Multiple PUTs to the status endpoint with sufficient delay should all call set_state. 
+ """ + self.hs.config.server.presence_enabled = True + + body = {"presence": "here", "status_msg": "beep boop"} + channel = self.make_request( + "PUT", "/presence/%s/status" % (self.user_id,), body + ) + + self.assertEqual(channel.code, HTTPStatus.OK) + + # Advance time a sufficient amount to avoid rate limiting. + self.reactor.advance(30) + + body = {"presence": "here", "status_msg": "beep boop"} + channel = self.make_request( + "PUT", "/presence/%s/status" % (self.user_id,), body + ) + + self.assertEqual(channel.code, HTTPStatus.OK) + self.assertEqual(self.presence_handler.set_state.call_count, 2) diff --git a/tests/utils.py b/tests/utils.py index 9fd26ef348..3f4a7bb560 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -200,6 +200,7 @@ def default_config( "per_user": {"per_second": 10000, "burst_count": 10000}, }, "rc_3pid_validation": {"per_second": 10000, "burst_count": 10000}, + "rc_presence": {"per_user": {"per_second": 10000, "burst_count": 10000}}, "saml2_enabled": False, "public_baseurl": None, "default_identity_server": None, From 56ed412839e23a8ee79142db5a7d679192851c96 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 27 Jan 2025 15:20:41 +0000 Subject: [PATCH 03/16] Bump dawidd6/action-download-artifact from 7 to 8 (#18108) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps [dawidd6/action-download-artifact](https://github.com/dawidd6/action-download-artifact) from 7 to 8.
Release notes

Sourced from dawidd6/action-download-artifact's releases.

v8

New features

  • use_unzip boolean input (defaulting to false) - if set to true, the action will use system provided unzip utility for unpacking downloaded artifact(s) (note that the action will first download the .zip artifact file, then unpack it and remove the .zip file)

What's Changed

New Contributors

Full Changelog: https://github.com/dawidd6/action-download-artifact/compare/v7...v8

Commits

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=dawidd6/action-download-artifact&package-manager=github_actions&previous-version=7&new-version=8)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/docs-pr-netlify.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/docs-pr-netlify.yaml b/.github/workflows/docs-pr-netlify.yaml index 3962f75055..0fbf6e02b7 100644 --- a/.github/workflows/docs-pr-netlify.yaml +++ b/.github/workflows/docs-pr-netlify.yaml @@ -14,7 +14,7 @@ jobs: # There's a 'download artifact' action, but it hasn't been updated for the workflow_run action # (https://github.com/actions/download-artifact/issues/60) so instead we get this mess: - name: 📥 Download artifact - uses: dawidd6/action-download-artifact@80620a5d27ce0ae443b965134db88467fc607b43 # v7 + uses: dawidd6/action-download-artifact@20319c5641d495c8a52e688b7dc5fada6c3a9fbc # v8 with: workflow: docs-pr.yaml run_id: ${{ github.event.workflow_run.id }} From 148e93576e6a8dc588ce95f1d01888c391601b05 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 27 Jan 2025 15:23:28 +0000 Subject: [PATCH 04/16] Bump log from 0.4.22 to 0.4.25 (#18098) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps [log](https://github.com/rust-lang/log) from 0.4.22 to 0.4.25.
Release notes

Sourced from log's releases.

0.4.25

What's Changed

Full Changelog: https://github.com/rust-lang/log/compare/0.4.24...0.4.25

0.4.24 (yanked)

What's Changed

Full Changelog: https://github.com/rust-lang/log/compare/0.4.23...0.4.24

0.4.23 (yanked)

What's Changed

New Contributors

Full Changelog: https://github.com/rust-lang/log/compare/0.4.22...0.4.23

Changelog

Sourced from log's changelog.

[0.4.25] - 2025-01-14

What's Changed

Full Changelog: https://github.com/rust-lang/log/compare/0.4.24...0.4.25

[0.4.24] - 2025-01-11

What's Changed

Full Changelog: https://github.com/rust-lang/log/compare/0.4.23...0.4.24

[0.4.23] - 2025-01-10 (yanked)

What's Changed

New Contributors

Full Changelog: https://github.com/rust-lang/log/compare/0.4.22...0.4.23

Commits
  • 22be810 Merge pull request #663 from rust-lang/cargo/0.4.25
  • 0279730 prepare for 0.4.25 release
  • 4099bcb Merge pull request #662 from rust-lang/fix/cargo-features
  • 36e7e3f revert loosening of kv cargo features
  • 2282191 Merge pull request #660 from rust-lang/cargo/0.4.24
  • 2994f0a prepare for 0.4.24 release
  • 5fcb50e Merge pull request #659 from rust-lang/fix/feature-builds
  • 29fe9e6 fix up feature activation
  • b1824f2 use cargo hack in CI to test all feature combinations
  • e6b643d Merge pull request #656 from rust-lang/cargo/0.4.23
  • Additional commits viewable in compare view

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=log&package-manager=cargo&previous-version=0.4.22&new-version=0.4.25)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.lock | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 3cec1ea4a6..d7312c0125 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -216,9 +216,9 @@ checksum = "ae743338b92ff9146ce83992f766a31066a91a8c84a45e0e9f21e7cf6de6d346" [[package]] name = "log" -version = "0.4.22" +version = "0.4.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" +checksum = "04cbf5b083de1c7e0222a7a51dbfdba1cbe1c6ab0b15e29fff3f6c077fd9cd9f" [[package]] name = "memchr" From 6ec5e13ec94c484a573bad11e3d80fc6b4b2b943 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Mon, 27 Jan 2025 11:21:10 -0600 Subject: [PATCH 05/16] Fix join being denied after being invited over federation (#18075) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This also happens for rejecting an invite. Basically, any out-of-band membership transition where we first get the membership as an `outlier` and then rely on federation filling us in to de-outlier it. This PR mainly addresses automated test flakiness, bots/scripts, and options within Synapse like [`auto_accept_invites`](https://element-hq.github.io/synapse/v1.122/usage/configuration/config_documentation.html#auto_accept_invites) that are able to react quickly (before federation is able to push us events), but also helps in generic scenarios where federation is lagging. I initially thought this might be a Synapse consistency issue (see issues labeled with [`Z-Read-After-Write`](https://github.com/matrix-org/synapse/labels/Z-Read-After-Write)) but it seems to be an event auth logic problem. Workers probably do increase the number of possible race condition scenarios that make this visible though (replication and cache invalidation lag). 
Fix https://github.com/element-hq/synapse/issues/15012 (probably fixes https://github.com/matrix-org/synapse/issues/15012 (https://github.com/element-hq/synapse/issues/15012)) Related to https://github.com/matrix-org/matrix-spec/issues/2062 Problems: 1. We don't consider [out-of-band membership](https://github.com/element-hq/synapse/blob/develop/docs/development/room-dag-concepts.md#out-of-band-membership-events) (outliers) in our `event_auth` logic even though we expose them in `/sync`. 1. (This PR doesn't address this point) Perhaps we should consider authing events in the persistence queue as events already in the queue could allow subsequent events to be allowed (events come through many channels: federation transaction, remote invite, remote join, local send). But this doesn't save us in the case where the event is more delayed over federation. ### What happened before? I wrote a Complement test that stresses this exact scenario and reproduces the problem: https://github.com/matrix-org/complement/pull/757 ``` COMPLEMENT_ALWAYS_PRINT_SERVER_LOGS=1 COMPLEMENT_DIR=../complement ./scripts-dev/complement.sh -run TestSynapseConsistency ``` We have `hs1` and `hs2` running in monolith mode (no workers): 1. `@charlie1:hs2` is invited and joins the room: 1. `hs1` invites `@charlie1:hs2` to a room which we receive on `hs2` as `PUT /_matrix/federation/v1/invite/{roomId}/{eventId}` (`on_invite_request(...)`) and the invite membership is persisted as an outlier. The `room_memberships` and `local_current_membership` database tables are also updated which means they are visible down `/sync` at this point. 1. `@charlie1:hs2` decides to join because it saw the invite down `/sync`. Because `hs2` is not yet in the room, this happens as a remote join `make_join`/`send_join` which comes back with all of the auth events needed to auth successfully and now `@charlie1:hs2` is successfully joined to the room. 1. `@charlie2:hs2` is invited and tries to join the room: 1. 
`hs1` invites `@charlie2:hs2` to the room which we receive on `hs2` as `PUT /_matrix/federation/v1/invite/{roomId}/{eventId}` (`on_invite_request(...)`) and the invite membership is persisted as an outlier. The `room_memberships` and `local_current_membership` database tables are also updated which means they are visible down `/sync` at this point. 1. Because `hs2` is already participating in the room, we also see the invite come over federation in a transaction and we start processing it (not done yet, see below) 1. `@charlie2:hs2` decides to join because it saw the invite down `/sync`. Because `hs2` is already in the room, this happens as a local join but we deny the event because our `event_auth` logic thinks that we have no membership in the room :x: (expected to be able to join because we saw the invite down `/sync`) 1. We finally finish processing the `@charlie2:hs2` invite event from the federation transaction and de-outlier it. - If this finished before we tried to join we would have been fine but this is the race condition that makes this situation visible. Logs for `hs2`: ``` 🗳️ on_invite_request: handling event 🔦 _store_room_members_txn update room_memberships: 🔦 _store_room_members_txn update local_current_membership: 📨 Notifying about new event ✅ on_invite_request: handled event 🧲 do_invite_join for @user-2-charlie1:hs2 in !sfZVBdLUezpPWetrol:hs1 🔦 _store_room_members_txn update room_memberships: 🔦 _store_room_members_txn update room_memberships: 📨 Notifying about new event ... 🗳️ on_invite_request: handling event 🔦 _store_room_members_txn update room_memberships: 🔦 _store_room_members_txn update local_current_membership: 📨 Notifying about new event ✅ on_invite_request: handled event 📬 handling received PDU in room !sfZVBdLUezpPWetrol:hs1: 📮 handle_new_client_event: handling ❌ Denying new event because 403: You are not invited to this room. synapse.http.server - 130 - INFO - POST-16 - SynapseError: 403 - You are not invited to this room. 
📨 Notifying about new event ✅ handled received PDU in room !sfZVBdLUezpPWetrol:hs1: ``` --- changelog.d/18075.bugfix | 1 + synapse/event_auth.py | 4 +- synapse/events/__init__.py | 7 +- synapse/events/builder.py | 55 +- synapse/handlers/federation_event.py | 5 +- synapse/server.py | 4 +- tests/federation/test_federation_devices.py | 161 +++++ .../test_federation_out_of_band_membership.py | 671 ++++++++++++++++++ tests/federation/test_federation_server.py | 247 ++++++- tests/handlers/test_federation_event.py | 2 +- tests/handlers/test_presence.py | 104 ++- tests/handlers/test_sync.py | 11 +- .../test_federation_sender_shard.py | 110 ++- tests/rest/client/test_rooms.py | 4 +- tests/test_federation.py | 378 ---------- tests/utils.py | 19 +- 16 files changed, 1341 insertions(+), 442 deletions(-) create mode 100644 changelog.d/18075.bugfix create mode 100644 tests/federation/test_federation_devices.py create mode 100644 tests/federation/test_federation_out_of_band_membership.py delete mode 100644 tests/test_federation.py diff --git a/changelog.d/18075.bugfix b/changelog.d/18075.bugfix new file mode 100644 index 0000000000..95b486bed1 --- /dev/null +++ b/changelog.d/18075.bugfix @@ -0,0 +1 @@ +Fix join being denied after being invited over federation. Also fixes other out-of-band membership transitions. diff --git a/synapse/event_auth.py b/synapse/event_auth.py index c208b900c5..3fe344ac93 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -566,6 +566,7 @@ def _is_membership_change_allowed( logger.debug( "_is_membership_change_allowed: %s", { + "caller_membership": caller.membership if caller else None, "caller_in_room": caller_in_room, "caller_invited": caller_invited, "caller_knocked": caller_knocked, @@ -677,7 +678,8 @@ def _is_membership_change_allowed( and join_rule == JoinRules.KNOCK_RESTRICTED ) ): - if not caller_in_room and not caller_invited: + # You can only join the room if you are invited or are already in the room. 
+ if not (caller_in_room or caller_invited): raise AuthError(403, "You are not invited to this room.") else: # TODO (erikj): may_join list diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 2e56b671f0..8e9d27138c 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -42,7 +42,7 @@ import attr from typing_extensions import Literal from unpaddedbase64 import encode_base64 -from synapse.api.constants import RelationTypes +from synapse.api.constants import EventTypes, RelationTypes from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions from synapse.synapse_rust.events import EventInternalMetadata from synapse.types import JsonDict, StrCollection @@ -325,12 +325,17 @@ class EventBase(metaclass=abc.ABCMeta): def __repr__(self) -> str: rejection = f"REJECTED={self.rejected_reason}, " if self.rejected_reason else "" + conditional_membership_string = "" + if self.get("type") == EventTypes.Member: + conditional_membership_string = f"membership={self.membership}, " + return ( f"<{self.__class__.__name__} " f"{rejection}" f"event_id={self.event_id}, " f"type={self.get('type')}, " f"state_key={self.get('state_key')}, " + f"{conditional_membership_string}" f"outlier={self.internal_metadata.is_outlier()}" ">" ) diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 10ef01131b..76df083d69 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -24,7 +24,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union import attr from signedjson.types import SigningKey -from synapse.api.constants import MAX_DEPTH +from synapse.api.constants import MAX_DEPTH, EventTypes from synapse.api.room_versions import ( KNOWN_EVENT_FORMAT_VERSIONS, EventFormatVersions, @@ -109,6 +109,19 @@ class EventBuilder: def is_state(self) -> bool: return self._state_key is not None + def is_mine_id(self, user_id: str) -> bool: + """Determines whether a user ID or room alias 
originates from this homeserver. + + Returns: + `True` if the hostname part of the user ID or room alias matches this + homeserver. + `False` otherwise, or if the user ID or room alias is malformed. + """ + localpart_hostname = user_id.split(":", 1) + if len(localpart_hostname) < 2: + return False + return localpart_hostname[1] == self._hostname + async def build( self, prev_event_ids: List[str], @@ -142,6 +155,46 @@ class EventBuilder: self, state_ids ) + # Check for out-of-band membership that may have been exposed on `/sync` but + # the events have not been de-outliered yet so they won't be part of the + # room state yet. + # + # This helps in situations where a remote homeserver invites a local user to + # a room that we're already participating in; and we've persisted the invite + # as an out-of-band membership (outlier), but it hasn't been pushed to us as + # part of a `/send` transaction yet and de-outliered. This also helps for + # any of the other out-of-band membership transitions. + # + # As an optimization, we could check if the room state already includes a + # non-`leave` membership event, then we can assume the membership event has + # been de-outliered and we don't need to check for an out-of-band + # membership. But we don't have the necessary information from a + # `StateMap[str]` and we'll just have to take the hit of this extra lookup + # for any membership event for now. + if self.type == EventTypes.Member and self.is_mine_id(self.state_key): + ( + _membership, + member_event_id, + ) = await self._store.get_local_current_membership_for_user_in_room( + user_id=self.state_key, + room_id=self.room_id, + ) + # There is no need to check if the membership is actually an + # out-of-band membership (`outlier`) as we would end up with the + # same result either way (adding the member event to the + # `auth_event_ids`). 
+ if ( + member_event_id is not None + # We only need to be careful about duplicating the event in the + # `auth_event_ids` list (duplicate `type`/`state_key` is part of the + # authorization rules) + and member_event_id not in auth_event_ids + ): + auth_event_ids.append(member_event_id) + # Also make sure to point to the previous membership event that will + # allow this one to happen so the computed state works out. + prev_event_ids.append(member_event_id) + format_version = self.room_version.event_format # The types of auth/prev events changes between event versions. prev_events: Union[StrCollection, List[Tuple[str, Dict[str, str]]]] diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index c85deaed56..1b535ea2cb 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -2272,8 +2272,9 @@ class FederationEventHandler: event_and_contexts, backfilled=backfilled ) - # After persistence we always need to notify replication there may - # be new data. + # After persistence, we never notify clients (wake up `/sync` streams) about + # backfilled events but it's important to let all the workers know about any + # new event (backfilled or not) because TODO self._notifier.notify_replication() if self._ephemeral_messages_enabled: diff --git a/synapse/server.py b/synapse/server.py index 462e15cc2f..bd2faa61b9 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -391,7 +391,7 @@ class HomeServer(metaclass=abc.ABCMeta): def is_mine(self, domain_specific_string: DomainSpecificString) -> bool: return domain_specific_string.domain == self.hostname - def is_mine_id(self, string: str) -> bool: + def is_mine_id(self, user_id: str) -> bool: """Determines whether a user ID or room alias originates from this homeserver. Returns: @@ -399,7 +399,7 @@ class HomeServer(metaclass=abc.ABCMeta): homeserver. `False` otherwise, or if the user ID or room alias is malformed. 
""" - localpart_hostname = string.split(":", 1) + localpart_hostname = user_id.split(":", 1) if len(localpart_hostname) < 2: return False return localpart_hostname[1] == self.hostname diff --git a/tests/federation/test_federation_devices.py b/tests/federation/test_federation_devices.py new file mode 100644 index 0000000000..ba27e69479 --- /dev/null +++ b/tests/federation/test_federation_devices.py @@ -0,0 +1,161 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2024 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# . +# +# Originally licensed under the Apache License, Version 2.0: +# . +# +# [This file includes modifications made by New Vector Limited] +# +# + +import logging +from unittest.mock import AsyncMock, Mock + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.handlers.device import DeviceListUpdater +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock +from synapse.util.retryutils import NotRetryingDestination + +from tests import unittest + +logger = logging.getLogger(__name__) + + +class DeviceListResyncTestCase(unittest.HomeserverTestCase): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = self.hs.get_datastores().main + + def test_retry_device_list_resync(self) -> None: + """Tests that device lists are marked as stale if they couldn't be synced, and + that stale device lists are retried periodically. + """ + remote_user_id = "@john:test_remote" + remote_origin = "test_remote" + + # Track the number of attempts to resync the user's device list. 
+ self.resync_attempts = 0 + + # When this function is called, increment the number of resync attempts (only if + # we're querying devices for the right user ID), then raise a + # NotRetryingDestination error to fail the resync gracefully. + def query_user_devices( + destination: str, user_id: str, timeout: int = 30000 + ) -> JsonDict: + if user_id == remote_user_id: + self.resync_attempts += 1 + + raise NotRetryingDestination(0, 0, destination) + + # Register the mock on the federation client. + federation_client = self.hs.get_federation_client() + federation_client.query_user_devices = Mock(side_effect=query_user_devices) # type: ignore[method-assign] + + # Register a mock on the store so that the incoming update doesn't fail because + # we don't share a room with the user. + self.store.get_rooms_for_user = AsyncMock(return_value=["!someroom:test"]) + + # Manually inject a fake device list update. We need this update to include at + # least one prev_id so that the user's device list will need to be retried. + device_list_updater = self.hs.get_device_handler().device_list_updater + assert isinstance(device_list_updater, DeviceListUpdater) + self.get_success( + device_list_updater.incoming_device_list_update( + origin=remote_origin, + edu_content={ + "deleted": False, + "device_display_name": "Mobile", + "device_id": "QBUAZIFURK", + "prev_id": [5], + "stream_id": 6, + "user_id": remote_user_id, + }, + ) + ) + + # Check that there was one resync attempt. + self.assertEqual(self.resync_attempts, 1) + + # Check that the resync attempt failed and caused the user's device list to be + # marked as stale. + need_resync = self.get_success( + self.store.get_user_ids_requiring_device_list_resync() + ) + self.assertIn(remote_user_id, need_resync) + + # Check that waiting for 30 seconds caused Synapse to retry resyncing the device + # list. 
+ self.reactor.advance(30) + self.assertEqual(self.resync_attempts, 2) + + def test_cross_signing_keys_retry(self) -> None: + """Tests that resyncing a device list correctly processes cross-signing keys from + the remote server. + """ + remote_user_id = "@john:test_remote" + remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY" + remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ" + + # Register mock device list retrieval on the federation client. + federation_client = self.hs.get_federation_client() + federation_client.query_user_devices = AsyncMock( # type: ignore[method-assign] + return_value={ + "user_id": remote_user_id, + "stream_id": 1, + "devices": [], + "master_key": { + "user_id": remote_user_id, + "usage": ["master"], + "keys": {"ed25519:" + remote_master_key: remote_master_key}, + }, + "self_signing_key": { + "user_id": remote_user_id, + "usage": ["self_signing"], + "keys": { + "ed25519:" + remote_self_signing_key: remote_self_signing_key + }, + }, + } + ) + + # Resync the device list. + device_handler = self.hs.get_device_handler() + self.get_success( + device_handler.device_list_updater.multi_user_device_resync( + [remote_user_id] + ), + ) + + # Retrieve the cross-signing keys for this user. + keys = self.get_success( + self.store.get_e2e_cross_signing_keys_bulk(user_ids=[remote_user_id]), + ) + self.assertIn(remote_user_id, keys) + key = keys[remote_user_id] + assert key is not None + + # Check that the master key is the one returned by the mock. + master_key = key["master"] + self.assertEqual(len(master_key["keys"]), 1) + self.assertTrue("ed25519:" + remote_master_key in master_key["keys"].keys()) + self.assertTrue(remote_master_key in master_key["keys"].values()) + + # Check that the self-signing key is the one returned by the mock. 
+ self_signing_key = key["self_signing"] + self.assertEqual(len(self_signing_key["keys"]), 1) + self.assertTrue( + "ed25519:" + remote_self_signing_key in self_signing_key["keys"].keys(), + ) + self.assertTrue(remote_self_signing_key in self_signing_key["keys"].values()) diff --git a/tests/federation/test_federation_out_of_band_membership.py b/tests/federation/test_federation_out_of_band_membership.py new file mode 100644 index 0000000000..a4a266cf06 --- /dev/null +++ b/tests/federation/test_federation_out_of_band_membership.py @@ -0,0 +1,671 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright 2020 The Matrix.org Foundation C.I.C. +# Copyright (C) 2023 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# . +# +# Originally licensed under the Apache License, Version 2.0: +# . 
+# +# [This file includes modifications made by New Vector Limited] +# +# + +import logging +import time +import urllib.parse +from http import HTTPStatus +from typing import Any, Callable, Optional, Set, Tuple, TypeVar, Union +from unittest.mock import Mock + +import attr +from parameterized import parameterized + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.api.constants import EventContentFields, EventTypes, Membership +from synapse.api.room_versions import RoomVersion, RoomVersions +from synapse.events import EventBase, make_event_from_dict +from synapse.events.utils import strip_event +from synapse.federation.federation_base import ( + event_from_pdu_json, +) +from synapse.federation.transport.client import SendJoinResponse +from synapse.http.matrixfederationclient import ( + ByteParser, +) +from synapse.http.types import QueryParams +from synapse.rest import admin +from synapse.rest.client import login, room, sync +from synapse.server import HomeServer +from synapse.types import JsonDict, MutableStateMap, StateMap +from synapse.types.handlers.sliding_sync import ( + StateValues, +) +from synapse.util import Clock + +from tests import unittest +from tests.utils import test_timeout + +logger = logging.getLogger(__name__) + + +def required_state_json_to_state_map(required_state: Any) -> StateMap[EventBase]: + state_map: MutableStateMap[EventBase] = {} + + # Scrutinize JSON values to ensure it's in the expected format + if isinstance(required_state, list): + for state_event_dict in required_state: + # Yell because we're in a test and this is unexpected + assert isinstance( + state_event_dict, dict + ), "`required_state` should be a list of event dicts" + + event_type = state_event_dict["type"] + event_state_key = state_event_dict["state_key"] + + # Yell because we're in a test and this is unexpected + assert isinstance( + event_type, str + ), "Each event in `required_state` should have a string `type`" + assert isinstance( + 
event_state_key, str + ), "Each event in `required_state` should have a string `state_key`" + + state_map[(event_type, event_state_key)] = make_event_from_dict( + state_event_dict + ) + else: + # Yell because we're in a test and this is unexpected + raise AssertionError("`required_state` should be a list of event dicts") + + return state_map + + +@attr.s(slots=True, auto_attribs=True) +class RemoteRoomJoinResult: + remote_room_id: str + room_version: RoomVersion + remote_room_creator_user_id: str + local_user1_id: str + local_user1_tok: str + state_map: StateMap[EventBase] + + +class OutOfBandMembershipTests(unittest.FederatingHomeserverTestCase): + """ + Tests to make sure that interactions with out-of-band membership (outliers) works as + expected. + + - invites received over federation, before we join the room + - *rejections* for said invites + + See the "Out-of-band membership events" section in + `docs/development/room-dag-concepts.md` for more information. + """ + + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + sync.register_servlets, + ] + + sync_endpoint = "/_matrix/client/unstable/org.matrix.simplified_msc3575/sync" + + def default_config(self) -> JsonDict: + conf = super().default_config() + # Federation sending is disabled by default in the test environment + # so we need to enable it like this. + conf["federation_sender_instances"] = ["master"] + + return conf + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + self.federation_http_client = Mock( + # The problem with using `spec=MatrixFederationHttpClient` here is that it + # requires everything to be mocked which is a lot of work that I don't want + # to do when the code only uses a few methods (`get_json` and `put_json`). 
+ ) + return self.setup_test_homeserver( + federation_http_client=self.federation_http_client + ) + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + super().prepare(reactor, clock, hs) + + self.store = self.hs.get_datastores().main + self.storage_controllers = hs.get_storage_controllers() + + def do_sync( + self, sync_body: JsonDict, *, since: Optional[str] = None, tok: str + ) -> Tuple[JsonDict, str]: + """Do a sliding sync request with given body. + + Asserts the request was successful. + + Attributes: + sync_body: The full request body to use + since: Optional since token + tok: Access token to use + + Returns: + A tuple of the response body and the `pos` field. + """ + + sync_path = self.sync_endpoint + if since: + sync_path += f"?pos={since}" + + channel = self.make_request( + method="POST", + path=sync_path, + content=sync_body, + access_token=tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + return channel.json_body, channel.json_body["pos"] + + def _invite_local_user_to_remote_room_and_join(self) -> RemoteRoomJoinResult: + """ + Helper to reproduce this scenario: + + 1. The remote user invites our local user to a room on their remote server (which + creates an out-of-band invite membership for user1 on our local server). + 2. The local user notices the invite from `/sync`. + 3. The local user joins the room. + 4. The local user can see that they are now joined to the room from `/sync`. 
+ """ + + # Create a local user + local_user1_id = self.register_user("user1", "pass") + local_user1_tok = self.login(local_user1_id, "pass") + + # Create a remote room + room_creator_user_id = f"@remote-user:{self.OTHER_SERVER_NAME}" + remote_room_id = f"!remote-room:{self.OTHER_SERVER_NAME}" + room_version = RoomVersions.V10 + + room_create_event = make_event_from_dict( + self.add_hashes_and_signatures_from_other_server( + { + "room_id": remote_room_id, + "sender": room_creator_user_id, + "depth": 1, + "origin_server_ts": 1, + "type": EventTypes.Create, + "state_key": "", + "content": { + # The `ROOM_CREATOR` field could be removed if we used a room + # version > 10 (in favor of relying on `sender`) + EventContentFields.ROOM_CREATOR: room_creator_user_id, + EventContentFields.ROOM_VERSION: room_version.identifier, + }, + "auth_events": [], + "prev_events": [], + } + ), + room_version=room_version, + ) + + creator_membership_event = make_event_from_dict( + self.add_hashes_and_signatures_from_other_server( + { + "room_id": remote_room_id, + "sender": room_creator_user_id, + "depth": 2, + "origin_server_ts": 2, + "type": EventTypes.Member, + "state_key": room_creator_user_id, + "content": {"membership": Membership.JOIN}, + "auth_events": [room_create_event.event_id], + "prev_events": [room_create_event.event_id], + } + ), + room_version=room_version, + ) + + # From the remote homeserver, invite user1 on the local homserver + user1_invite_membership_event = make_event_from_dict( + self.add_hashes_and_signatures_from_other_server( + { + "room_id": remote_room_id, + "sender": room_creator_user_id, + "depth": 3, + "origin_server_ts": 3, + "type": EventTypes.Member, + "state_key": local_user1_id, + "content": {"membership": Membership.INVITE}, + "auth_events": [ + room_create_event.event_id, + creator_membership_event.event_id, + ], + "prev_events": [creator_membership_event.event_id], + } + ), + room_version=room_version, + ) + channel = 
self.make_signed_federation_request( + "PUT", + f"/_matrix/federation/v2/invite/{remote_room_id}/{user1_invite_membership_event.event_id}", + content={ + "event": user1_invite_membership_event.get_dict(), + "invite_room_state": [ + strip_event(room_create_event), + ], + "room_version": room_version.identifier, + }, + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) + + sync_body = { + "lists": { + "foo-list": { + "ranges": [[0, 1]], + "required_state": [(EventTypes.Member, StateValues.WILDCARD)], + "timeline_limit": 0, + } + } + } + + # Sync until the local user1 can see the invite + with test_timeout( + 3, + "Unable to find user1's invite event in the room", + ): + while True: + response_body, _ = self.do_sync(sync_body, tok=local_user1_tok) + if ( + remote_room_id in response_body["rooms"].keys() + # If they have `invite_state` for the room, they are invited + and len( + response_body["rooms"][remote_room_id].get("invite_state", []) + ) + > 0 + ): + break + + # Prevent tight-looping to allow the `test_timeout` to work + time.sleep(0.1) + + user1_join_membership_event_template = make_event_from_dict( + { + "room_id": remote_room_id, + "sender": local_user1_id, + "depth": 4, + "origin_server_ts": 4, + "type": EventTypes.Member, + "state_key": local_user1_id, + "content": {"membership": Membership.JOIN}, + "auth_events": [ + room_create_event.event_id, + user1_invite_membership_event.event_id, + ], + "prev_events": [user1_invite_membership_event.event_id], + }, + room_version=room_version, + ) + + T = TypeVar("T") + + # Mock the remote homeserver responding to our HTTP requests + # + # We're going to mock the following endpoints so that user1 can join the remote room: + # - GET /_matrix/federation/v1/make_join/{room_id}/{user_id} + # - PUT /_matrix/federation/v2/send_join/{room_id}/{user_id} + # + async def get_json( + destination: str, + path: str, + args: Optional[QueryParams] = None, + retry_on_dns_fail: bool = True, + timeout: Optional[int] = 
None, + ignore_backoff: bool = False, + try_trailing_slash_on_400: bool = False, + parser: Optional[ByteParser[T]] = None, + ) -> Union[JsonDict, T]: + if ( + path + == f"/_matrix/federation/v1/make_join/{urllib.parse.quote_plus(remote_room_id)}/{urllib.parse.quote_plus(local_user1_id)}" + ): + return { + "event": user1_join_membership_event_template.get_pdu_json(), + "room_version": room_version.identifier, + } + + raise NotImplementedError( + "We have not mocked a response for `get_json(...)` for the following endpoint yet: " + + f"{destination}{path}" + ) + + self.federation_http_client.get_json.side_effect = get_json + + # PDU's that hs1 sent to hs2 + collected_pdus_from_hs1_federation_send: Set[str] = set() + + async def put_json( + destination: str, + path: str, + args: Optional[QueryParams] = None, + data: Optional[JsonDict] = None, + json_data_callback: Optional[Callable[[], JsonDict]] = None, + long_retries: bool = False, + timeout: Optional[int] = None, + ignore_backoff: bool = False, + backoff_on_404: bool = False, + try_trailing_slash_on_400: bool = False, + parser: Optional[ByteParser[T]] = None, + backoff_on_all_error_codes: bool = False, + ) -> Union[JsonDict, T, SendJoinResponse]: + if ( + path.startswith( + f"/_matrix/federation/v2/send_join/{urllib.parse.quote_plus(remote_room_id)}/" + ) + and data is not None + and data.get("type") == EventTypes.Member + and data.get("state_key") == local_user1_id + # We're assuming this is a `ByteParser[SendJoinResponse]` + and parser is not None + ): + # As the remote server, we need to sign the event before sending it back + user1_join_membership_event_signed = make_event_from_dict( + self.add_hashes_and_signatures_from_other_server(data), + room_version=room_version, + ) + + # Since they passed in a `parser`, we need to return the type that + # they're expecting instead of just a `JsonDict` + return SendJoinResponse( + auth_events=[ + room_create_event, + user1_invite_membership_event, + ], + state=[ + 
room_create_event, + creator_membership_event, + user1_invite_membership_event, + ], + event_dict=user1_join_membership_event_signed.get_pdu_json(), + event=user1_join_membership_event_signed, + members_omitted=False, + servers_in_room=[ + self.OTHER_SERVER_NAME, + ], + ) + + if path.startswith("/_matrix/federation/v1/send/") and data is not None: + for pdu in data.get("pdus", []): + event = event_from_pdu_json(pdu, room_version) + collected_pdus_from_hs1_federation_send.add(event.event_id) + + # Just acknowledge everything hs1 is trying to send hs2 + return { + event_from_pdu_json(pdu, room_version).event_id: {} + for pdu in data.get("pdus", []) + } + + raise NotImplementedError( + "We have not mocked a response for `put_json(...)` for the following endpoint yet: " + + f"{destination}{path} with the following body data: {data}" + ) + + self.federation_http_client.put_json.side_effect = put_json + + # User1 joins the room + self.helper.join(remote_room_id, local_user1_id, tok=local_user1_tok) + + # Reset the mocks now that user1 has joined the room + self.federation_http_client.get_json.side_effect = None + self.federation_http_client.put_json.side_effect = None + + # Sync until the local user1 can see that they are now joined to the room + with test_timeout( + 3, + "Unable to find user1's join event in the room", + ): + while True: + response_body, _ = self.do_sync(sync_body, tok=local_user1_tok) + if remote_room_id in response_body["rooms"].keys(): + required_state_map = required_state_json_to_state_map( + response_body["rooms"][remote_room_id]["required_state"] + ) + if ( + required_state_map.get((EventTypes.Member, local_user1_id)) + is not None + ): + break + + # Prevent tight-looping to allow the `test_timeout` to work + time.sleep(0.1) + + # Nothing needs to be sent from hs1 to hs2 since we already let the other + # homeserver know by doing the `/make_join` and `/send_join` dance. 
+ self.assertIncludes( + collected_pdus_from_hs1_federation_send, + set(), + exact=True, + message="Didn't expect any events to be sent from hs1 over federation to hs2", + ) + + return RemoteRoomJoinResult( + remote_room_id=remote_room_id, + room_version=room_version, + remote_room_creator_user_id=room_creator_user_id, + local_user1_id=local_user1_id, + local_user1_tok=local_user1_tok, + state_map=self.get_success( + self.storage_controllers.state.get_current_state(remote_room_id) + ), + ) + + def test_can_join_from_out_of_band_invite(self) -> None: + """ + Test to make sure that we can join a room that we were invited to over + federation; even if our server has never participated in the room before. + """ + self._invite_local_user_to_remote_room_and_join() + + @parameterized.expand( + [("accept invite", Membership.JOIN), ("reject invite", Membership.LEAVE)] + ) + def test_can_x_from_out_of_band_invite_after_we_are_already_participating_in_the_room( + self, _test_description: str, membership_action: str + ) -> None: + """ + Test to make sure that we can do either a) join the room (accept the invite) or + b) reject the invite after being invited to over federation; even if we are + already participating in the room. + + This is a regression test to make sure we stress the scenario where even though + we are already participating in the room, local users can still react to invites + regardless of whether the remote server has told us about the invite event (via + a federation `/send` transaction) and we have de-outliered the invite event. + Previously, we would mistakenly throw an error saying the user wasn't in the + room when they tried to join or reject the invite. 
+ """ + remote_room_join_result = self._invite_local_user_to_remote_room_and_join() + remote_room_id = remote_room_join_result.remote_room_id + room_version = remote_room_join_result.room_version + + # Create another local user + local_user2_id = self.register_user("user2", "pass") + local_user2_tok = self.login(local_user2_id, "pass") + + T = TypeVar("T") + + # PDU's that hs1 sent to hs2 + collected_pdus_from_hs1_federation_send: Set[str] = set() + + async def put_json( + destination: str, + path: str, + args: Optional[QueryParams] = None, + data: Optional[JsonDict] = None, + json_data_callback: Optional[Callable[[], JsonDict]] = None, + long_retries: bool = False, + timeout: Optional[int] = None, + ignore_backoff: bool = False, + backoff_on_404: bool = False, + try_trailing_slash_on_400: bool = False, + parser: Optional[ByteParser[T]] = None, + backoff_on_all_error_codes: bool = False, + ) -> Union[JsonDict, T]: + if path.startswith("/_matrix/federation/v1/send/") and data is not None: + for pdu in data.get("pdus", []): + event = event_from_pdu_json(pdu, room_version) + collected_pdus_from_hs1_federation_send.add(event.event_id) + + # Just acknowledge everything hs1 is trying to send hs2 + return { + event_from_pdu_json(pdu, room_version).event_id: {} + for pdu in data.get("pdus", []) + } + + raise NotImplementedError( + "We have not mocked a response for `put_json(...)` for the following endpoint yet: " + + f"{destination}{path} with the following body data: {data}" + ) + + self.federation_http_client.put_json.side_effect = put_json + + # From the remote homeserver, invite user2 on the local homserver + user2_invite_membership_event = make_event_from_dict( + self.add_hashes_and_signatures_from_other_server( + { + "room_id": remote_room_id, + "sender": remote_room_join_result.remote_room_creator_user_id, + "depth": 5, + "origin_server_ts": 5, + "type": EventTypes.Member, + "state_key": local_user2_id, + "content": {"membership": Membership.INVITE}, + 
"auth_events": [ + remote_room_join_result.state_map[ + (EventTypes.Create, "") + ].event_id, + remote_room_join_result.state_map[ + ( + EventTypes.Member, + remote_room_join_result.remote_room_creator_user_id, + ) + ].event_id, + ], + "prev_events": [ + remote_room_join_result.state_map[ + (EventTypes.Member, remote_room_join_result.local_user1_id) + ].event_id + ], + } + ), + room_version=room_version, + ) + channel = self.make_signed_federation_request( + "PUT", + f"/_matrix/federation/v2/invite/{remote_room_id}/{user2_invite_membership_event.event_id}", + content={ + "event": user2_invite_membership_event.get_dict(), + "invite_room_state": [ + strip_event( + remote_room_join_result.state_map[(EventTypes.Create, "")] + ), + ], + "room_version": room_version.identifier, + }, + ) + self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) + + sync_body = { + "lists": { + "foo-list": { + "ranges": [[0, 1]], + "required_state": [(EventTypes.Member, StateValues.WILDCARD)], + "timeline_limit": 0, + } + } + } + + # Sync until the local user2 can see the invite + with test_timeout( + 3, + "Unable to find user2's invite event in the room", + ): + while True: + response_body, _ = self.do_sync(sync_body, tok=local_user2_tok) + if ( + remote_room_id in response_body["rooms"].keys() + # If they have `invite_state` for the room, they are invited + and len( + response_body["rooms"][remote_room_id].get("invite_state", []) + ) + > 0 + ): + break + + # Prevent tight-looping to allow the `test_timeout` to work + time.sleep(0.1) + + if membership_action == Membership.JOIN: + # User2 joins the room + join_event = self.helper.join( + remote_room_join_result.remote_room_id, + local_user2_id, + tok=local_user2_tok, + ) + expected_pdu_event_id = join_event["event_id"] + elif membership_action == Membership.LEAVE: + # User2 rejects the invite + leave_event = self.helper.leave( + remote_room_join_result.remote_room_id, + local_user2_id, + tok=local_user2_tok, + ) + 
expected_pdu_event_id = leave_event["event_id"] + else: + raise NotImplementedError( + "This test does not support this membership action yet" + ) + + # Sync until the local user2 can see their new membership in the room + with test_timeout( + 3, + "Unable to find user2's new membership event in the room", + ): + while True: + response_body, _ = self.do_sync(sync_body, tok=local_user2_tok) + if membership_action == Membership.JOIN: + if remote_room_id in response_body["rooms"].keys(): + required_state_map = required_state_json_to_state_map( + response_body["rooms"][remote_room_id]["required_state"] + ) + if ( + required_state_map.get((EventTypes.Member, local_user2_id)) + is not None + ): + break + elif membership_action == Membership.LEAVE: + if remote_room_id not in response_body["rooms"].keys(): + break + else: + raise NotImplementedError( + "This test does not support this membership action yet" + ) + + # Prevent tight-looping to allow the `test_timeout` to work + time.sleep(0.1) + + # Make sure that we let hs2 know about the new membership event + self.assertIncludes( + collected_pdus_from_hs1_federation_send, + {expected_pdu_event_id}, + exact=True, + message="Expected to find the event ID of the user2 membership to be sent from hs1 over federation to hs2", + ) diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py index 88261450b1..42dc844734 100644 --- a/tests/federation/test_federation_server.py +++ b/tests/federation/test_federation_server.py @@ -20,14 +20,21 @@ # import logging from http import HTTPStatus +from typing import Optional, Union +from unittest.mock import Mock from parameterized import parameterized from twisted.test.proto_helpers import MemoryReactor -from synapse.api.room_versions import KNOWN_ROOM_VERSIONS +from synapse.api.constants import EventTypes, Membership +from synapse.api.errors import FederationError +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions from 
synapse.config.server import DEFAULT_ROOM_VERSION from synapse.events import EventBase, make_event_from_dict +from synapse.federation.federation_base import event_from_pdu_json +from synapse.http.types import QueryParams +from synapse.logging.context import LoggingContext from synapse.rest import admin from synapse.rest.client import login, room from synapse.server import HomeServer @@ -85,6 +92,163 @@ class FederationServerTests(unittest.FederatingHomeserverTestCase): self.assertEqual(500, channel.code, channel.result) +def _create_acl_event(content: JsonDict) -> EventBase: + return make_event_from_dict( + { + "room_id": "!a:b", + "event_id": "$a:b", + "type": "m.room.server_acls", + "sender": "@a:b", + "content": content, + } + ) + + +class MessageAcceptTests(unittest.FederatingHomeserverTestCase): + """ + Tests to make sure that we don't accept flawed events from federation (incoming). + """ + + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + self.http_client = Mock() + return self.setup_test_homeserver(federation_http_client=self.http_client) + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + super().prepare(reactor, clock, hs) + + self.store = self.hs.get_datastores().main + self.storage_controllers = hs.get_storage_controllers() + self.federation_event_handler = self.hs.get_federation_event_handler() + + # Create a local room + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + self.room_id = self.helper.create_room_as( + user1_id, tok=user1_tok, is_public=True + ) + + state_map = self.get_success( + self.storage_controllers.state.get_current_state(self.room_id) + ) + + # Figure out what the forward extremities in the room are (the most recent + # events that aren't tied into the DAG) + forward_extremity_event_ids = self.get_success( + 
self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id) + ) + + # Join a remote user to the room that will attempt to send bad events + self.remote_bad_user_id = f"@baduser:{self.OTHER_SERVER_NAME}" + self.remote_bad_user_join_event = make_event_from_dict( + self.add_hashes_and_signatures_from_other_server( + { + "room_id": self.room_id, + "sender": self.remote_bad_user_id, + "state_key": self.remote_bad_user_id, + "depth": 1000, + "origin_server_ts": 1, + "type": EventTypes.Member, + "content": {"membership": Membership.JOIN}, + "auth_events": [ + state_map[(EventTypes.Create, "")].event_id, + state_map[(EventTypes.JoinRules, "")].event_id, + ], + "prev_events": list(forward_extremity_event_ids), + } + ), + room_version=RoomVersions.V10, + ) + + # Send the join, it should return None (which is not an error) + self.assertEqual( + self.get_success( + self.federation_event_handler.on_receive_pdu( + self.OTHER_SERVER_NAME, self.remote_bad_user_join_event + ) + ), + None, + ) + + # Make sure we actually joined the room + self.assertEqual( + self.get_success(self.store.get_latest_event_ids_in_room(self.room_id)), + {self.remote_bad_user_join_event.event_id}, + ) + + def test_cant_hide_direct_ancestors(self) -> None: + """ + If you send a message, you must be able to provide the direct + prev_events that said event references. 
+ """ + + async def post_json( + destination: str, + path: str, + data: Optional[JsonDict] = None, + long_retries: bool = False, + timeout: Optional[int] = None, + ignore_backoff: bool = False, + args: Optional[QueryParams] = None, + ) -> Union[JsonDict, list]: + # If it asks us for new missing events, give them NOTHING + if path.startswith("/_matrix/federation/v1/get_missing_events/"): + return {"events": []} + return {} + + self.http_client.post_json = post_json + + # Figure out what the forward extremities in the room are (the most recent + # events that aren't tied into the DAG) + forward_extremity_event_ids = self.get_success( + self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id) + ) + + # Now lie about an event's prev_events + lying_event = make_event_from_dict( + self.add_hashes_and_signatures_from_other_server( + { + "room_id": self.room_id, + "sender": self.remote_bad_user_id, + "depth": 1000, + "origin_server_ts": 1, + "type": "m.room.message", + "content": {"body": "hewwo?"}, + "auth_events": [], + "prev_events": ["$missing_prev_event"] + + list(forward_extremity_event_ids), + } + ), + room_version=RoomVersions.V10, + ) + + with LoggingContext("test-context"): + failure = self.get_failure( + self.federation_event_handler.on_receive_pdu( + self.OTHER_SERVER_NAME, lying_event + ), + FederationError, + ) + + # on_receive_pdu should throw an error + self.assertEqual( + failure.value.args[0], + ( + "ERROR 403: Your server isn't divulging details about prev_events " + "referenced in this event." 
+ ), + ) + + # Make sure the invalid event isn't there + extrem = self.get_success(self.store.get_latest_event_ids_in_room(self.room_id)) + self.assertEqual(extrem, {self.remote_bad_user_join_event.event_id}) + + class ServerACLsTestCase(unittest.TestCase): def test_blocked_server(self) -> None: e = _create_acl_event({"allow": ["*"], "deny": ["evil.com"]}) @@ -355,13 +519,76 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase): # is probably sufficient to reassure that the bucket is updated. -def _create_acl_event(content: JsonDict) -> EventBase: - return make_event_from_dict( - { - "room_id": "!a:b", - "event_id": "$a:b", - "type": "m.room.server_acls", - "sender": "@a:b", - "content": content, +class StripUnsignedFromEventsTestCase(unittest.TestCase): + """ + Test to make sure that we handle the raw JSON events from federation carefully and + strip anything that shouldn't be there. + """ + + def test_strip_unauthorized_unsigned_values(self) -> None: + event1 = { + "sender": "@baduser:test.serv", + "state_key": "@baduser:test.serv", + "event_id": "$event1:test.serv", + "depth": 1000, + "origin_server_ts": 1, + "type": "m.room.member", + "origin": "test.servx", + "content": {"membership": "join"}, + "auth_events": [], + "unsigned": {"malicious garbage": "hackz", "more warez": "more hackz"}, } - ) + filtered_event = event_from_pdu_json(event1, RoomVersions.V1) + # Make sure unauthorized fields are stripped from unsigned + self.assertNotIn("more warez", filtered_event.unsigned) + + def test_strip_event_maintains_allowed_fields(self) -> None: + event2 = { + "sender": "@baduser:test.serv", + "state_key": "@baduser:test.serv", + "event_id": "$event2:test.serv", + "depth": 1000, + "origin_server_ts": 1, + "type": "m.room.member", + "origin": "test.servx", + "auth_events": [], + "content": {"membership": "join"}, + "unsigned": { + "malicious garbage": "hackz", + "more warez": "more hackz", + "age": 14, + "invite_room_state": [], + }, + } + + 
filtered_event2 = event_from_pdu_json(event2, RoomVersions.V1) + self.assertIn("age", filtered_event2.unsigned) + self.assertEqual(14, filtered_event2.unsigned["age"]) + self.assertNotIn("more warez", filtered_event2.unsigned) + # Invite_room_state is allowed in events of type m.room.member + self.assertIn("invite_room_state", filtered_event2.unsigned) + self.assertEqual([], filtered_event2.unsigned["invite_room_state"]) + + def test_strip_event_removes_fields_based_on_event_type(self) -> None: + event3 = { + "sender": "@baduser:test.serv", + "state_key": "@baduser:test.serv", + "event_id": "$event3:test.serv", + "depth": 1000, + "origin_server_ts": 1, + "type": "m.room.power_levels", + "origin": "test.servx", + "content": {}, + "auth_events": [], + "unsigned": { + "malicious garbage": "hackz", + "more warez": "more hackz", + "age": 14, + "invite_room_state": [], + }, + } + filtered_event3 = event_from_pdu_json(event3, RoomVersions.V1) + self.assertIn("age", filtered_event3.unsigned) + # Invite_room_state field is only permitted in event type m.room.member + self.assertNotIn("invite_room_state", filtered_event3.unsigned) + self.assertNotIn("more warez", filtered_event3.unsigned) diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py index 5db10fa74c..61b0efb87e 100644 --- a/tests/handlers/test_federation_event.py +++ b/tests/handlers/test_federation_event.py @@ -375,7 +375,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): In this test, we pretend we are processing a "pulled" event via backfill. The pulled event succesfully processes and the backward - extremeties are updated along with clearing out any failed pull attempts + extremities are updated along with clearing out any failed pull attempts for those old extremities. 
We check that we correctly cleared failed pull attempts of the diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 4cf048f0df..6b7bf112c2 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -23,14 +23,21 @@ from typing import Optional, cast from unittest.mock import Mock, call from parameterized import parameterized -from signedjson.key import generate_signing_key +from signedjson.key import ( + encode_verify_key_base64, + generate_signing_key, + get_verify_key, +) from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import EventTypes, Membership, PresenceState from synapse.api.presence import UserDevicePresenceState, UserPresenceState -from synapse.api.room_versions import KNOWN_ROOM_VERSIONS -from synapse.events.builder import EventBuilder +from synapse.api.room_versions import ( + RoomVersion, +) +from synapse.crypto.event_signing import add_hashes_and_signatures +from synapse.events import EventBase, make_event_from_dict from synapse.federation.sender import FederationSender from synapse.handlers.presence import ( BUSY_ONLINE_TIMEOUT, @@ -48,6 +55,7 @@ from synapse.rest import admin from synapse.rest.client import login, room, sync from synapse.server import HomeServer from synapse.storage.database import LoggingDatabaseConnection +from synapse.storage.keys import FetchKeyResult from synapse.types import JsonDict, UserID, get_domain_from_id from synapse.util import Clock @@ -1926,6 +1934,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): # self.event_builder_for_2.hostname = "test2" self.store = hs.get_datastores().main + self.storage_controllers = hs.get_storage_controllers() self.state = hs.get_state_handler() self._event_auth_handler = hs.get_event_auth_handler() @@ -2041,29 +2050,35 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): hostname = get_domain_from_id(user_id) - room_version = self.get_success(self.store.get_room_version_id(room_id)) 
+ room_version = self.get_success(self.store.get_room_version(room_id)) - builder = EventBuilder( - state=self.state, - event_auth_handler=self._event_auth_handler, - store=self.store, - clock=self.clock, - hostname=hostname, - signing_key=self.random_signing_key, - room_version=KNOWN_ROOM_VERSIONS[room_version], - room_id=room_id, - type=EventTypes.Member, - sender=user_id, - state_key=user_id, - content={"membership": Membership.JOIN}, + state_map = self.get_success( + self.storage_controllers.state.get_current_state(room_id) ) - prev_event_ids = self.get_success( - self.store.get_latest_event_ids_in_room(room_id) + # Figure out what the forward extremities in the room are (the most recent + # events that aren't tied into the DAG) + forward_extremity_event_ids = self.get_success( + self.hs.get_datastores().main.get_latest_event_ids_in_room(room_id) ) - event = self.get_success( - builder.build(prev_event_ids=list(prev_event_ids), auth_event_ids=None) + event = self.create_fake_event_from_remote_server( + remote_server_name=hostname, + event_dict={ + "room_id": room_id, + "sender": user_id, + "type": EventTypes.Member, + "state_key": user_id, + "depth": 1000, + "origin_server_ts": 1, + "content": {"membership": Membership.JOIN}, + "auth_events": [ + state_map[(EventTypes.Create, "")].event_id, + state_map[(EventTypes.JoinRules, "")].event_id, + ], + "prev_events": list(forward_extremity_event_ids), + }, + room_version=room_version, ) self.get_success(self.federation_event_handler.on_receive_pdu(hostname, event)) @@ -2071,3 +2086,50 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): # Check that it was successfully persisted. 
self.get_success(self.store.get_event(event.event_id)) self.get_success(self.store.get_event(event.event_id)) + + def create_fake_event_from_remote_server( + self, remote_server_name: str, event_dict: JsonDict, room_version: RoomVersion + ) -> EventBase: + """ + This is similar to what `FederatingHomeserverTestCase` is doing but we don't + need all of the extra baggage and we want to be able to create an event from + many remote servers. + """ + + # poke the other server's signing key into the key store, so that we don't + # make requests for it + other_server_signature_key = generate_signing_key("test") + verify_key = get_verify_key(other_server_signature_key) + verify_key_id = "%s:%s" % (verify_key.alg, verify_key.version) + + self.get_success( + self.hs.get_datastores().main.store_server_keys_response( + remote_server_name, + from_server=remote_server_name, + ts_added_ms=self.clock.time_msec(), + verify_keys={ + verify_key_id: FetchKeyResult( + verify_key=verify_key, + valid_until_ts=self.clock.time_msec() + 10000, + ), + }, + response_json={ + "verify_keys": { + verify_key_id: {"key": encode_verify_key_base64(verify_key)} + } + }, + ) + ) + + add_hashes_and_signatures( + room_version=room_version, + event_dict=event_dict, + signature_name=remote_server_name, + signing_key=other_server_signature_key, + ) + event = make_event_from_dict( + event_dict, + room_version=room_version, + ) + + return event diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index 9dd0e98971..6b202dfbd5 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -17,6 +17,7 @@ # [This file includes modifications made by New Vector Limited] # # +from http import HTTPStatus from typing import Collection, ContextManager, List, Optional from unittest.mock import AsyncMock, Mock, patch @@ -347,7 +348,15 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): # the prev_events used when creating the join event, such that the ban does not # precede the 
join. with self._patch_get_latest_events([last_room_creation_event_id]): - self.helper.join(room_id, eve, tok=eve_token) + self.helper.join( + room_id, + eve, + tok=eve_token, + # Previously, this join would succeed but now we expect it to fail at + # this point. The rest of the test is for the case when this used to + # succeed. + expect_code=HTTPStatus.FORBIDDEN, + ) # Eve makes a second, incremental sync. eve_incremental_sync_after_join: SyncResult = self.get_success( diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py index 4429d0f4e2..58a7a9dc72 100644 --- a/tests/replication/test_federation_sender_shard.py +++ b/tests/replication/test_federation_sender_shard.py @@ -22,14 +22,26 @@ import logging from unittest.mock import AsyncMock, Mock from netaddr import IPSet +from signedjson.key import ( + encode_verify_key_base64, + generate_signing_key, + get_verify_key, +) + +from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import EventTypes, Membership -from synapse.events.builder import EventBuilderFactory +from synapse.api.room_versions import RoomVersion +from synapse.crypto.event_signing import add_hashes_and_signatures +from synapse.events import EventBase, make_event_from_dict from synapse.handlers.typing import TypingWriterHandler from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent from synapse.rest.admin import register_servlets_for_client_rest_resource from synapse.rest.client import login, room -from synapse.types import UserID, create_requester +from synapse.server import HomeServer +from synapse.storage.keys import FetchKeyResult +from synapse.types import JsonDict, UserID, create_requester +from synapse.util import Clock from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.server import get_clock @@ -63,6 +75,9 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): ip_blocklist=IPSet(), ) + def 
prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.storage_controllers = hs.get_storage_controllers() + def test_send_event_single_sender(self) -> None: """Test that using a single federation sender worker correctly sends a new event. @@ -243,35 +258,92 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): self.assertTrue(sent_on_1) self.assertTrue(sent_on_2) + def create_fake_event_from_remote_server( + self, remote_server_name: str, event_dict: JsonDict, room_version: RoomVersion + ) -> EventBase: + """ + This is similar to what `FederatingHomeserverTestCase` is doing but we don't + need all of the extra baggage and we want to be able to create an event from + many remote servers. + """ + + # poke the other server's signing key into the key store, so that we don't + # make requests for it + other_server_signature_key = generate_signing_key("test") + verify_key = get_verify_key(other_server_signature_key) + verify_key_id = "%s:%s" % (verify_key.alg, verify_key.version) + + self.get_success( + self.hs.get_datastores().main.store_server_keys_response( + remote_server_name, + from_server=remote_server_name, + ts_added_ms=self.clock.time_msec(), + verify_keys={ + verify_key_id: FetchKeyResult( + verify_key=verify_key, + valid_until_ts=self.clock.time_msec() + 10000, + ), + }, + response_json={ + "verify_keys": { + verify_key_id: {"key": encode_verify_key_base64(verify_key)} + } + }, + ) + ) + + add_hashes_and_signatures( + room_version=room_version, + event_dict=event_dict, + signature_name=remote_server_name, + signing_key=other_server_signature_key, + ) + event = make_event_from_dict( + event_dict, + room_version=room_version, + ) + + return event + def create_room_with_remote_server( self, user: str, token: str, remote_server: str = "other_server" ) -> str: - room = self.helper.create_room_as(user, tok=token) + room_id = self.helper.create_room_as(user, tok=token) store = self.hs.get_datastores().main federation = 
self.hs.get_federation_event_handler() - prev_event_ids = self.get_success(store.get_latest_event_ids_in_room(room)) - room_version = self.get_success(store.get_room_version(room)) + room_version = self.get_success(store.get_room_version(room_id)) - factory = EventBuilderFactory(self.hs) - factory.hostname = remote_server + state_map = self.get_success( + self.storage_controllers.state.get_current_state(room_id) + ) + + # Figure out what the forward extremities in the room are (the most recent + # events that aren't tied into the DAG) + prev_event_ids = self.get_success(store.get_latest_event_ids_in_room(room_id)) user_id = UserID("user", remote_server).to_string() - event_dict = { - "type": EventTypes.Member, - "state_key": user_id, - "content": {"membership": Membership.JOIN}, - "sender": user_id, - "room_id": room, - } - - builder = factory.for_room_version(room_version, event_dict) - join_event = self.get_success( - builder.build(prev_event_ids=list(prev_event_ids), auth_event_ids=None) + join_event = self.create_fake_event_from_remote_server( + remote_server_name=remote_server, + event_dict={ + "room_id": room_id, + "sender": user_id, + "type": EventTypes.Member, + "state_key": user_id, + "depth": 1000, + "origin_server_ts": 1, + "content": {"membership": Membership.JOIN}, + "auth_events": [ + state_map[(EventTypes.Create, "")].event_id, + state_map[(EventTypes.JoinRules, "")].event_id, + ], + "prev_events": list(prev_event_ids), + }, + room_version=room_version, ) self.get_success(federation.on_send_membership_event(remote_server, join_event)) self.replicate() - return room + return room_id diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 4cf1a3dc51..833bd6fff8 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -742,7 +742,7 @@ class RoomsCreateTestCase(RoomBase): self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert 
channel.resource_usage is not None - self.assertEqual(33, channel.resource_usage.db_txn_count) + self.assertEqual(34, channel.resource_usage.db_txn_count) def test_post_room_initial_state(self) -> None: # POST with initial_state config key, expect new room id @@ -755,7 +755,7 @@ class RoomsCreateTestCase(RoomBase): self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(35, channel.resource_usage.db_txn_count) + self.assertEqual(36, channel.resource_usage.db_txn_count) def test_post_room_visibility_key(self) -> None: # POST with visibility config key, expect new room id diff --git a/tests/test_federation.py b/tests/test_federation.py deleted file mode 100644 index 94b0fa9856..0000000000 --- a/tests/test_federation.py +++ /dev/null @@ -1,378 +0,0 @@ -# -# This file is licensed under the Affero General Public License (AGPL) version 3. -# -# Copyright 2020 The Matrix.org Foundation C.I.C. -# Copyright (C) 2023 New Vector, Ltd -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as -# published by the Free Software Foundation, either version 3 of the -# License, or (at your option) any later version. -# -# See the GNU Affero General Public License for more details: -# . -# -# Originally licensed under the Apache License, Version 2.0: -# . 
-# -# [This file includes modifications made by New Vector Limited] -# -# - -from typing import Collection, List, Optional, Union -from unittest.mock import AsyncMock, Mock - -from twisted.test.proto_helpers import MemoryReactor - -from synapse.api.errors import FederationError -from synapse.api.room_versions import RoomVersion, RoomVersions -from synapse.events import EventBase, make_event_from_dict -from synapse.events.snapshot import EventContext -from synapse.federation.federation_base import event_from_pdu_json -from synapse.handlers.device import DeviceListUpdater -from synapse.http.types import QueryParams -from synapse.logging.context import LoggingContext -from synapse.server import HomeServer -from synapse.types import JsonDict, UserID, create_requester -from synapse.util import Clock -from synapse.util.retryutils import NotRetryingDestination - -from tests import unittest - - -class MessageAcceptTests(unittest.HomeserverTestCase): - def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - self.http_client = Mock() - return self.setup_test_homeserver(federation_http_client=self.http_client) - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - user_id = UserID("us", "test") - our_user = create_requester(user_id) - room_creator = self.hs.get_room_creation_handler() - self.room_id = self.get_success( - room_creator.create_room( - our_user, room_creator._presets_dict["public_chat"], ratelimit=False - ) - )[0] - - self.store = self.hs.get_datastores().main - - # Figure out what the most recent event is - most_recent = next( - iter( - self.get_success( - self.hs.get_datastores().main.get_latest_event_ids_in_room( - self.room_id - ) - ) - ) - ) - - join_event = make_event_from_dict( - { - "room_id": self.room_id, - "sender": "@baduser:test.serv", - "state_key": "@baduser:test.serv", - "event_id": "$join:test.serv", - "depth": 1000, - "origin_server_ts": 1, - "type": "m.room.member", - "origin": 
"test.servx", - "content": {"membership": "join"}, - "auth_events": [], - "prev_state": [(most_recent, {})], - "prev_events": [(most_recent, {})], - } - ) - - self.handler = self.hs.get_federation_handler() - federation_event_handler = self.hs.get_federation_event_handler() - - async def _check_event_auth( - origin: Optional[str], event: EventBase, context: EventContext - ) -> None: - pass - - federation_event_handler._check_event_auth = _check_event_auth # type: ignore[method-assign] - self.client = self.hs.get_federation_client() - - async def _check_sigs_and_hash_for_pulled_events_and_fetch( - dest: str, pdus: Collection[EventBase], room_version: RoomVersion - ) -> List[EventBase]: - return list(pdus) - - self.client._check_sigs_and_hash_for_pulled_events_and_fetch = ( # type: ignore[method-assign] - _check_sigs_and_hash_for_pulled_events_and_fetch # type: ignore[assignment] - ) - - # Send the join, it should return None (which is not an error) - self.assertEqual( - self.get_success( - federation_event_handler.on_receive_pdu("test.serv", join_event) - ), - None, - ) - - # Make sure we actually joined the room - self.assertEqual( - self.get_success(self.store.get_latest_event_ids_in_room(self.room_id)), - {"$join:test.serv"}, - ) - - def test_cant_hide_direct_ancestors(self) -> None: - """ - If you send a message, you must be able to provide the direct - prev_events that said event references. 
- """ - - async def post_json( - destination: str, - path: str, - data: Optional[JsonDict] = None, - long_retries: bool = False, - timeout: Optional[int] = None, - ignore_backoff: bool = False, - args: Optional[QueryParams] = None, - ) -> Union[JsonDict, list]: - # If it asks us for new missing events, give them NOTHING - if path.startswith("/_matrix/federation/v1/get_missing_events/"): - return {"events": []} - return {} - - self.http_client.post_json = post_json - - # Figure out what the most recent event is - most_recent = next( - iter( - self.get_success(self.store.get_latest_event_ids_in_room(self.room_id)) - ) - ) - - # Now lie about an event - lying_event = make_event_from_dict( - { - "room_id": self.room_id, - "sender": "@baduser:test.serv", - "event_id": "one:test.serv", - "depth": 1000, - "origin_server_ts": 1, - "type": "m.room.message", - "origin": "test.serv", - "content": {"body": "hewwo?"}, - "auth_events": [], - "prev_events": [("two:test.serv", {}), (most_recent, {})], - } - ) - - federation_event_handler = self.hs.get_federation_event_handler() - with LoggingContext("test-context"): - failure = self.get_failure( - federation_event_handler.on_receive_pdu("test.serv", lying_event), - FederationError, - ) - - # on_receive_pdu should throw an error - self.assertEqual( - failure.value.args[0], - ( - "ERROR 403: Your server isn't divulging details about prev_events " - "referenced in this event." - ), - ) - - # Make sure the invalid event isn't there - extrem = self.get_success(self.store.get_latest_event_ids_in_room(self.room_id)) - self.assertEqual(extrem, {"$join:test.serv"}) - - def test_retry_device_list_resync(self) -> None: - """Tests that device lists are marked as stale if they couldn't be synced, and - that stale device lists are retried periodically. - """ - remote_user_id = "@john:test_remote" - remote_origin = "test_remote" - - # Track the number of attempts to resync the user's device list. 
- self.resync_attempts = 0 - - # When this function is called, increment the number of resync attempts (only if - # we're querying devices for the right user ID), then raise a - # NotRetryingDestination error to fail the resync gracefully. - def query_user_devices( - destination: str, user_id: str, timeout: int = 30000 - ) -> JsonDict: - if user_id == remote_user_id: - self.resync_attempts += 1 - - raise NotRetryingDestination(0, 0, destination) - - # Register the mock on the federation client. - federation_client = self.hs.get_federation_client() - federation_client.query_user_devices = Mock(side_effect=query_user_devices) # type: ignore[method-assign] - - # Register a mock on the store so that the incoming update doesn't fail because - # we don't share a room with the user. - store = self.hs.get_datastores().main - store.get_rooms_for_user = AsyncMock(return_value=["!someroom:test"]) - - # Manually inject a fake device list update. We need this update to include at - # least one prev_id so that the user's device list will need to be retried. - device_list_updater = self.hs.get_device_handler().device_list_updater - assert isinstance(device_list_updater, DeviceListUpdater) - self.get_success( - device_list_updater.incoming_device_list_update( - origin=remote_origin, - edu_content={ - "deleted": False, - "device_display_name": "Mobile", - "device_id": "QBUAZIFURK", - "prev_id": [5], - "stream_id": 6, - "user_id": remote_user_id, - }, - ) - ) - - # Check that there was one resync attempt. - self.assertEqual(self.resync_attempts, 1) - - # Check that the resync attempt failed and caused the user's device list to be - # marked as stale. - need_resync = self.get_success( - store.get_user_ids_requiring_device_list_resync() - ) - self.assertIn(remote_user_id, need_resync) - - # Check that waiting for 30 seconds caused Synapse to retry resyncing the device - # list. 
- self.reactor.advance(30) - self.assertEqual(self.resync_attempts, 2) - - def test_cross_signing_keys_retry(self) -> None: - """Tests that resyncing a device list correctly processes cross-signing keys from - the remote server. - """ - remote_user_id = "@john:test_remote" - remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY" - remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ" - - # Register mock device list retrieval on the federation client. - federation_client = self.hs.get_federation_client() - federation_client.query_user_devices = AsyncMock( # type: ignore[method-assign] - return_value={ - "user_id": remote_user_id, - "stream_id": 1, - "devices": [], - "master_key": { - "user_id": remote_user_id, - "usage": ["master"], - "keys": {"ed25519:" + remote_master_key: remote_master_key}, - }, - "self_signing_key": { - "user_id": remote_user_id, - "usage": ["self_signing"], - "keys": { - "ed25519:" + remote_self_signing_key: remote_self_signing_key - }, - }, - } - ) - - # Resync the device list. - device_handler = self.hs.get_device_handler() - self.get_success( - device_handler.device_list_updater.multi_user_device_resync( - [remote_user_id] - ), - ) - - # Retrieve the cross-signing keys for this user. - keys = self.get_success( - self.store.get_e2e_cross_signing_keys_bulk(user_ids=[remote_user_id]), - ) - self.assertIn(remote_user_id, keys) - key = keys[remote_user_id] - assert key is not None - - # Check that the master key is the one returned by the mock. - master_key = key["master"] - self.assertEqual(len(master_key["keys"]), 1) - self.assertTrue("ed25519:" + remote_master_key in master_key["keys"].keys()) - self.assertTrue(remote_master_key in master_key["keys"].values()) - - # Check that the self-signing key is the one returned by the mock. 
- self_signing_key = key["self_signing"] - self.assertEqual(len(self_signing_key["keys"]), 1) - self.assertTrue( - "ed25519:" + remote_self_signing_key in self_signing_key["keys"].keys(), - ) - self.assertTrue(remote_self_signing_key in self_signing_key["keys"].values()) - - -class StripUnsignedFromEventsTestCase(unittest.TestCase): - def test_strip_unauthorized_unsigned_values(self) -> None: - event1 = { - "sender": "@baduser:test.serv", - "state_key": "@baduser:test.serv", - "event_id": "$event1:test.serv", - "depth": 1000, - "origin_server_ts": 1, - "type": "m.room.member", - "origin": "test.servx", - "content": {"membership": "join"}, - "auth_events": [], - "unsigned": {"malicious garbage": "hackz", "more warez": "more hackz"}, - } - filtered_event = event_from_pdu_json(event1, RoomVersions.V1) - # Make sure unauthorized fields are stripped from unsigned - self.assertNotIn("more warez", filtered_event.unsigned) - - def test_strip_event_maintains_allowed_fields(self) -> None: - event2 = { - "sender": "@baduser:test.serv", - "state_key": "@baduser:test.serv", - "event_id": "$event2:test.serv", - "depth": 1000, - "origin_server_ts": 1, - "type": "m.room.member", - "origin": "test.servx", - "auth_events": [], - "content": {"membership": "join"}, - "unsigned": { - "malicious garbage": "hackz", - "more warez": "more hackz", - "age": 14, - "invite_room_state": [], - }, - } - - filtered_event2 = event_from_pdu_json(event2, RoomVersions.V1) - self.assertIn("age", filtered_event2.unsigned) - self.assertEqual(14, filtered_event2.unsigned["age"]) - self.assertNotIn("more warez", filtered_event2.unsigned) - # Invite_room_state is allowed in events of type m.room.member - self.assertIn("invite_room_state", filtered_event2.unsigned) - self.assertEqual([], filtered_event2.unsigned["invite_room_state"]) - - def test_strip_event_removes_fields_based_on_event_type(self) -> None: - event3 = { - "sender": "@baduser:test.serv", - "state_key": "@baduser:test.serv", - "event_id": 
"$event3:test.serv", - "depth": 1000, - "origin_server_ts": 1, - "type": "m.room.power_levels", - "origin": "test.servx", - "content": {}, - "auth_events": [], - "unsigned": { - "malicious garbage": "hackz", - "more warez": "more hackz", - "age": 14, - "invite_room_state": [], - }, - } - filtered_event3 = event_from_pdu_json(event3, RoomVersions.V1) - self.assertIn("age", filtered_event3.unsigned) - # Invite_room_state field is only permitted in event type m.room.member - self.assertNotIn("invite_room_state", filtered_event3.unsigned) - self.assertNotIn("more warez", filtered_event3.unsigned) diff --git a/tests/utils.py b/tests/utils.py index 3f4a7bb560..d4aebc3069 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -400,11 +400,24 @@ class TestTimeout(Exception): class test_timeout: + """ + FIXME: This implementation is not robust against other code tight-looping and + preventing the signals propagating and timing out the test. You may need to add + `time.sleep(0.1)` to your code in order to allow this timeout to work correctly. 
+ + ```py + with test_timeout(3): + while True: + my_checking_func() + time.sleep(0.1) + ``` + """ + def __init__(self, seconds: int, error_message: Optional[str] = None) -> None: - if error_message is None: - error_message = "test timed out after {}s.".format(seconds) + self.error_message = f"Test timed out after {seconds}s" + if error_message is not None: + self.error_message += f": {error_message}" self.seconds = seconds - self.error_message = error_message def handle_timeout(self, signum: int, frame: Optional[FrameType]) -> None: raise TestTimeout(self.error_message) From b41a9ebb38163f0bd9b08ba411d31a2bb515d9ef Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Mon, 27 Jan 2025 18:39:51 +0000 Subject: [PATCH 06/16] OIDC: increase length of generated `nonce` parameter from 30->32 chars (#18109) --- changelog.d/18109.misc | 1 + synapse/handlers/oidc.py | 16 +++++++++++++++- 2 files changed, 16 insertions(+), 1 deletion(-) create mode 100644 changelog.d/18109.misc diff --git a/changelog.d/18109.misc b/changelog.d/18109.misc new file mode 100644 index 0000000000..c310e76f78 --- /dev/null +++ b/changelog.d/18109.misc @@ -0,0 +1 @@ +Increase the length of the generated `nonce` parameter when performing OIDC logins to comply with the TI-Messenger spec. \ No newline at end of file diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py index 22b59829fa..701e828081 100644 --- a/synapse/handlers/oidc.py +++ b/synapse/handlers/oidc.py @@ -1002,7 +1002,21 @@ class OidcProvider: """ state = generate_token() - nonce = generate_token() + + # Generate a nonce 32 characters long. When encoded with base64url later on, + # the nonce will be 43 characters when sent to the identity provider. + # + # While RFC7636 does not specify a minimum length for the `nonce` + # parameter, the TI-Messenger IDP_FD spec v1.7.3 does require it to be + # between 43 and 128 characters.
This spec concerns using Matrix for + # communication in German healthcare. + # + # As increasing the length only strengthens security, we use this length + # to allow TI-Messenger deployments using Synapse to satisfy this + # external spec. + # + # See https://github.com/element-hq/synapse/pull/18109 for more context. + nonce = generate_token(length=32) code_verifier = "" if not client_redirect_url: From c53999dab871fa10d34487b0bf654eee6827a711 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 27 Jan 2025 21:04:41 +0000 Subject: [PATCH 07/16] Bump types-bleach from 6.1.0.20240331 to 6.2.0.20241123 (#18082) Bumps [types-bleach](https://github.com/python/typeshed) from 6.1.0.20240331 to 6.2.0.20241123.
Commits

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=types-bleach&package-manager=pip&previous-version=6.1.0.20240331&new-version=6.2.0.20241123)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- poetry.lock | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/poetry.lock b/poetry.lock index 9764904b81..36622aa13a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2706,13 +2706,13 @@ twisted = "*" [[package]] name = "types-bleach" -version = "6.1.0.20240331" +version = "6.2.0.20241123" description = "Typing stubs for bleach" optional = false python-versions = ">=3.8" files = [ - {file = "types-bleach-6.1.0.20240331.tar.gz", hash = "sha256:2ee858a84fb06fc2225ff56ba2f7f6c88b65638659efae0d7bfd6b24a1b5a524"}, - {file = "types_bleach-6.1.0.20240331-py3-none-any.whl", hash = "sha256:399bc59bfd20a36a56595f13f805e56c8a08e5a5c07903e5cf6fafb5a5107dd4"}, + {file = "types_bleach-6.2.0.20241123-py3-none-any.whl", hash = "sha256:c6e58b3646665ca7c6b29890375390f4569e84f0cf5c171e0fe1ddb71a7be86a"}, + {file = "types_bleach-6.2.0.20241123.tar.gz", hash = "sha256:dac5fe9015173514da3ac810c1a935619a3ccbcc5d66c4cbf4707eac00539057"}, ] [package.dependencies] From 579f4ac1cd6b0d15a12fabcca9d34d5ba3242ce2 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 27 Jan 2025 21:24:07 +0000 Subject: [PATCH 08/16] Bump serde_json from 1.0.135 to 1.0.137 (#18099) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps [serde_json](https://github.com/serde-rs/json) from 1.0.135 to 1.0.137.
Release notes

Sourced from serde_json's releases.

v1.0.137

  • Turn on "float_roundtrip" and "unbounded_depth" features for serde_json in play.rust-lang.org (#1231)

v1.0.136

  • Optimize serde_json::value::Serializer::serialize_map by using Map::with_capacity (#1230, thanks @​goffrie)
Commits
  • eb49e28 Release 1.0.137
  • 51c48ab Merge pull request #1231 from dtolnay/playground
  • 7d8f15b Enable "float_roundtrip" and "unbounded_depth" features in playground
  • a46f14c Release 1.0.136
  • eb9f3f6 Merge pull request #1230 from goffrie/patch-1
  • 513e5b2 Use Map::with_capacity in value::Serializer::serialize_map
  • See full diff in compare view

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=serde_json&package-manager=cargo&previous-version=1.0.135&new-version=1.0.137)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- Cargo.lock | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d7312c0125..82ed544226 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -449,9 +449,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.135" +version = "1.0.137" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b0d7ba2887406110130a978386c4e1befb98c674b4fba677954e4db976630d9" +checksum = "930cfb6e6abf99298aaad7d29abbef7a9999a9a8806a40088f55f0dcec03146b" dependencies = [ "itoa", "memchr", From 8f27b3af0723feb201933266ad0e8688ce3cb37e Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 27 Jan 2025 21:28:00 +0000 Subject: [PATCH 09/16] Bump python-multipart from 0.0.18 to 0.0.20 (#18096) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps [python-multipart](https://github.com/Kludex/python-multipart) from 0.0.18 to 0.0.20.
Release notes

Sourced from python-multipart's releases.

Version 0.0.20

What's Changed

New Contributors

Full Changelog: https://github.com/Kludex/python-multipart/compare/0.0.19...0.0.20

Version 0.0.19

What's Changed


Full Changelog: https://github.com/Kludex/python-multipart/compare/0.0.18...0.0.19

Changelog

Sourced from python-multipart's changelog.

0.0.20 (2024-12-16)

  • Handle messages containing only end boundary #142.

0.0.19 (2024-11-30)

  • Don't warn when CRLF is found after last boundary on MultipartParser #193.
Commits

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=python-multipart&package-manager=pip&previous-version=0.0.18&new-version=0.0.20)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- poetry.lock | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/poetry.lock b/poetry.lock index 36622aa13a..c9c567d6e3 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1960,13 +1960,13 @@ six = ">=1.5" [[package]] name = "python-multipart" -version = "0.0.18" +version = "0.0.20" description = "A streaming multipart parser for Python" optional = false python-versions = ">=3.8" files = [ - {file = "python_multipart-0.0.18-py3-none-any.whl", hash = "sha256:efe91480f485f6a361427a541db4796f9e1591afc0fb8e7a4ba06bfbc6708996"}, - {file = "python_multipart-0.0.18.tar.gz", hash = "sha256:7a68db60c8bfb82e460637fa4750727b45af1d5e2ed215593f917f64694d34fe"}, + {file = "python_multipart-0.0.20-py3-none-any.whl", hash = "sha256:8a62d3a8335e06589fe01f2a3e178cdcc632f3fbe0d492ad9ee0ec35aab1f104"}, + {file = "python_multipart-0.0.20.tar.gz", hash = "sha256:8dd0cab45b8e23064ae09147625994d090fa46f5b0d1e13af944c331a7fa9d13"}, ] [[package]] From 628351b98de9e8f7340b139ef77cd530da286d2d Mon Sep 17 00:00:00 2001 From: Will Hunt Date: Tue, 28 Jan 2025 00:37:24 +0000 Subject: [PATCH 10/16] Never autojoin deactivated & suspended users. (#18073) This PR changes the logic so that deactivated users are always ignored. Suspended users were already effectively ignored as Synapse forbids a join while suspended. 
--------- Co-authored-by: Devon Hudson --- changelog.d/18073.bugfix | 1 + synapse/events/auto_accept_invites.py | 85 ++++++----- tests/events/test_auto_accept_invites.py | 181 ++++++++++++++++++++++- 3 files changed, 232 insertions(+), 35 deletions(-) create mode 100644 changelog.d/18073.bugfix diff --git a/changelog.d/18073.bugfix b/changelog.d/18073.bugfix new file mode 100644 index 0000000000..eeb56a7a61 --- /dev/null +++ b/changelog.d/18073.bugfix @@ -0,0 +1 @@ +Deactivated users will no longer automatically accept an invite when `auto_accept_invites` is enabled. \ No newline at end of file diff --git a/synapse/events/auto_accept_invites.py b/synapse/events/auto_accept_invites.py index d88ec51d9d..4295107c47 100644 --- a/synapse/events/auto_accept_invites.py +++ b/synapse/events/auto_accept_invites.py @@ -66,50 +66,67 @@ class InviteAutoAccepter: event: The incoming event. """ # Check if the event is an invite for a local user. - is_invite_for_local_user = ( - event.type == EventTypes.Member - and event.is_state() - and event.membership == Membership.INVITE - and self._api.is_mine(event.state_key) - ) + if ( + event.type != EventTypes.Member + or event.is_state() is False + or event.membership != Membership.INVITE + or self._api.is_mine(event.state_key) is False + ): + return # Only accept invites for direct messages if the configuration mandates it. is_direct_message = event.content.get("is_direct", False) - is_allowed_by_direct_message_rules = ( - not self._config.accept_invites_only_for_direct_messages - or is_direct_message is True - ) + if ( + self._config.accept_invites_only_for_direct_messages + and is_direct_message is False + ): + return # Only accept invites from remote users if the configuration mandates it. 
is_from_local_user = self._api.is_mine(event.sender) - is_allowed_by_local_user_rules = ( - not self._config.accept_invites_only_from_local_users - or is_from_local_user is True + if ( + self._config.accept_invites_only_from_local_users + and is_from_local_user is False + ): + return + + # Check the user is activated. + recipient = await self._api.get_userinfo_by_id(event.state_key) + + # Ignore if the user doesn't exist. + if recipient is None: + return + + # Never accept invites for deactivated users. + if recipient.is_deactivated: + return + + # Never accept invites for suspended users. + if recipient.suspended: + return + + # Never accept invites for locked users. + if recipient.locked: + return + + # Make the user join the room. We run this as a background process to circumvent a race condition + # that occurs when responding to invites over federation (see https://github.com/matrix-org/synapse-auto-accept-invite/issues/12) + run_as_background_process( + "retry_make_join", + self._retry_make_join, + event.state_key, + event.state_key, + event.room_id, + "join", + bg_start_span=False, ) - if ( - is_invite_for_local_user - and is_allowed_by_direct_message_rules - and is_allowed_by_local_user_rules - ): - # Make the user join the room. We run this as a background process to circumvent a race condition - # that occurs when responding to invites over federation (see https://github.com/matrix-org/synapse-auto-accept-invite/issues/12) - run_as_background_process( - "retry_make_join", - self._retry_make_join, - event.state_key, - event.state_key, - event.room_id, - "join", - bg_start_span=False, + if is_direct_message: + # Mark this room as a direct message! + await self._mark_room_as_direct_message( + event.state_key, event.sender, event.room_id ) - if is_direct_message: - # Mark this room as a direct message! 
- await self._mark_room_as_direct_message( - event.state_key, event.sender, event.room_id - ) - async def _mark_room_as_direct_message( self, user_id: str, dm_user_id: str, room_id: str ) -> None: diff --git a/tests/events/test_auto_accept_invites.py b/tests/events/test_auto_accept_invites.py index 7fb4d4fa90..d4e87b1b7f 100644 --- a/tests/events/test_auto_accept_invites.py +++ b/tests/events/test_auto_accept_invites.py @@ -39,7 +39,7 @@ from synapse.module_api import ModuleApi from synapse.rest import admin from synapse.rest.client import login, room from synapse.server import HomeServer -from synapse.types import StreamToken, create_requester +from synapse.types import StreamToken, UserID, UserInfo, create_requester from synapse.util import Clock from tests.handlers.test_sync import generate_sync_config @@ -349,6 +349,169 @@ class AutoAcceptInvitesTestCase(FederatingHomeserverTestCase): join_updates, _ = sync_join(self, invited_user_id) self.assertEqual(len(join_updates), 0) + @override_config( + { + "auto_accept_invites": { + "enabled": True, + }, + } + ) + async def test_ignore_invite_for_missing_user(self) -> None: + """Tests that receiving an invite for a missing user is ignored.""" + inviting_user_id = self.register_user("inviter", "pass") + inviting_user_tok = self.login("inviter", "pass") + + # A local user who receives an invite + invited_user_id = "@fake:" + self.hs.config.server.server_name + + # Create a room and send an invite to the other user + room_id = self.helper.create_room_as( + inviting_user_id, + tok=inviting_user_tok, + ) + + self.helper.invite( + room_id, + inviting_user_id, + invited_user_id, + tok=inviting_user_tok, + ) + + join_updates, _ = sync_join(self, inviting_user_id) + # Assert that the last event in the room was not a member event for the target user. 
+ self.assertEqual( + join_updates[0].timeline.events[-1].content["membership"], "invite" + ) + + @override_config( + { + "auto_accept_invites": { + "enabled": True, + }, + } + ) + async def test_ignore_invite_for_deactivated_user(self) -> None: + """Tests that receiving an invite for a deactivated user is ignored.""" + inviting_user_id = self.register_user("inviter", "pass", admin=True) + inviting_user_tok = self.login("inviter", "pass") + + # A local user who receives an invite + invited_user_id = self.register_user("invitee", "pass") + + # Create a room and send an invite to the other user + room_id = self.helper.create_room_as( + inviting_user_id, + tok=inviting_user_tok, + ) + + channel = self.make_request( + "PUT", + "/_synapse/admin/v2/users/%s" % invited_user_id, + {"deactivated": True}, + access_token=inviting_user_tok, + ) + + assert channel.code == 200 + + self.helper.invite( + room_id, + inviting_user_id, + invited_user_id, + tok=inviting_user_tok, + ) + + join_updates, b = sync_join(self, inviting_user_id) + # Assert that the last event in the room was not a member event for the target user. 
+ self.assertEqual( + join_updates[0].timeline.events[-1].content["membership"], "invite" + ) + + @override_config( + { + "auto_accept_invites": { + "enabled": True, + }, + } + ) + async def test_ignore_invite_for_suspended_user(self) -> None: + """Tests that receiving an invite for a suspended user is ignored.""" + inviting_user_id = self.register_user("inviter", "pass", admin=True) + inviting_user_tok = self.login("inviter", "pass") + + # A local user who receives an invite + invited_user_id = self.register_user("invitee", "pass") + + # Create a room and send an invite to the other user + room_id = self.helper.create_room_as( + inviting_user_id, + tok=inviting_user_tok, + ) + + channel = self.make_request( + "PUT", + f"/_synapse/admin/v1/suspend/{invited_user_id}", + {"suspend": True}, + access_token=inviting_user_tok, + ) + + assert channel.code == 200 + + self.helper.invite( + room_id, + inviting_user_id, + invited_user_id, + tok=inviting_user_tok, + ) + + join_updates, b = sync_join(self, inviting_user_id) + # Assert that the last event in the room was not a member event for the target user. 
+ self.assertEqual( + join_updates[0].timeline.events[-1].content["membership"], "invite" + ) + + @override_config( + { + "auto_accept_invites": { + "enabled": True, + }, + } + ) + async def test_ignore_invite_for_locked_user(self) -> None: + """Tests that receiving an invite for a locked user is ignored.""" + inviting_user_id = self.register_user("inviter", "pass", admin=True) + inviting_user_tok = self.login("inviter", "pass") + + # A local user who receives an invite + invited_user_id = self.register_user("invitee", "pass") + + # Create a room and send an invite to the other user + room_id = self.helper.create_room_as( + inviting_user_id, + tok=inviting_user_tok, + ) + + channel = self.make_request( + "PUT", + f"/_synapse/admin/v2/users/{invited_user_id}", + {"locked": True}, + access_token=inviting_user_tok, + ) + + assert channel.code == 200 + + self.helper.invite( + room_id, + inviting_user_id, + invited_user_id, + tok=inviting_user_tok, + ) + + join_updates, b = sync_join(self, inviting_user_id) + # Assert that the last event in the room was not a member event for the target user. 
+ self.assertEqual( + join_updates[0].timeline.events[-1].content["membership"], "invite" + ) + _request_key = 0 @@ -647,6 +810,22 @@ def create_module( module_api.is_mine.side_effect = lambda a: a.split(":")[1] == "test" module_api.worker_name = worker_name module_api.sleep.return_value = make_multiple_awaitable(None) + module_api.get_userinfo_by_id.return_value = UserInfo( + user_id=UserID.from_string("@user:test"), + is_admin=False, + is_guest=False, + consent_server_notice_sent=None, + consent_ts=None, + consent_version=None, + appservice_id=None, + creation_ts=0, + user_type=None, + is_deactivated=False, + locked=False, + is_shadow_banned=False, + approved=True, + suspended=False, + ) if config_override is None: config_override = {} From 3d8535b1def2fde26a10bdf15d0b9dbd941a6005 Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Tue, 28 Jan 2025 08:37:58 -0700 Subject: [PATCH 11/16] 1.123.0 --- CHANGES.md | 7 +++++++ debian/changelog | 6 ++++++ pyproject.toml | 2 +- 3 files changed, 14 insertions(+), 1 deletion(-) diff --git a/CHANGES.md b/CHANGES.md index f3ad507de8..cc6426751d 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,3 +1,10 @@ +# Synapse 1.123.0 (2025-01-28) + +No significant changes since 1.123.0rc1. + + + + # Synapse 1.123.0rc1 (2025-01-21) ### Features diff --git a/debian/changelog b/debian/changelog index 2a18b776f6..a470dff676 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,3 +1,9 @@ +matrix-synapse-py3 (1.123.0) stable; urgency=medium + + * New Synapse release 1.123.0. + + -- Synapse Packaging team Tue, 28 Jan 2025 08:37:34 -0700 + matrix-synapse-py3 (1.123.0~rc1) stable; urgency=medium * New Synapse release 1.123.0rc1. 
diff --git a/pyproject.toml b/pyproject.toml index e6899207da..1cd874716e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -97,7 +97,7 @@ module-name = "synapse.synapse_rust" [tool.poetry] name = "matrix-synapse" -version = "1.123.0rc1" +version = "1.123.0" description = "Homeserver for the Matrix decentralised comms protocol" authors = ["Matrix.org Team and Contributors "] license = "AGPL-3.0-or-later" From a0b70473fc08611aa7ecdc50ce4404e882d9e7f9 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Wed, 29 Jan 2025 18:14:02 -0600 Subject: [PATCH 12/16] Raise an error if someone is using an incorrect suffix in a config duration string (#18112) Previously, a value like `5q` would be interpreted as 5 milliseconds. We should just raise an error instead of letting someone run with a misconfiguration. --- changelog.d/18112.bugfix | 1 + synapse/config/_base.py | 19 +++++++++++++++++-- 2 files changed, 18 insertions(+), 2 deletions(-) create mode 100644 changelog.d/18112.bugfix diff --git a/changelog.d/18112.bugfix b/changelog.d/18112.bugfix new file mode 100644 index 0000000000..61c94280d8 --- /dev/null +++ b/changelog.d/18112.bugfix @@ -0,0 +1 @@ +Raise an error if someone is using an incorrect suffix in a config duration string. diff --git a/synapse/config/_base.py b/synapse/config/_base.py index adce34c03a..912346a423 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -221,9 +221,13 @@ class Config: The number of milliseconds in the duration. Raises: - TypeError, if given something other than an integer or a string + TypeError: if given something other than an integer or a string, or the + duration is using an incorrect suffix. ValueError: if given a string not of the form described above. """ + # For integers, we prefer to use `type(value) is int` instead of + # `isinstance(value, int)` because we want to exclude subclasses of int, such as + # bool. 
if type(value) is int: # noqa: E721 return value elif isinstance(value, str): @@ -246,9 +250,20 @@ class Config: if suffix in sizes: value = value[:-1] size = sizes[suffix] + elif suffix.isdigit(): + # No suffix is treated as milliseconds. + value = value + size = 1 + else: + raise TypeError( + f"Bad duration suffix {value} (expected no suffix or one of these suffixes: {sizes.keys()})" + ) + return int(value) * size else: - raise TypeError(f"Bad duration {value!r}") + raise TypeError( + f"Bad duration type {value!r} (expected int or string duration)" + ) @staticmethod def abspath(file_path: str) -> str: From ac1bf682ff012ee8af5153318eec5d25ed786e90 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Thu, 30 Jan 2025 20:48:12 +0000 Subject: [PATCH 13/16] Allow (un)block_room storage functions to be called on workers (#18119) This is so workers can call these functions. This was preventing the [Delete Room Admin API](https://element-hq.github.io/synapse/latest/admin_api/rooms.html#version-2-new-version) from succeeding when `block: true` was specified. This was because we had `run_background_tasks_on` configured to run on a separate worker. As workers weren't able to call the `block_room` storage function before this PR, the (delete room) task failed when taken off the queue by the worker. 
--- changelog.d/18119.bugfix | 1 + synapse/storage/databases/main/room.py | 88 +++++++++++++------------- 2 files changed, 45 insertions(+), 44 deletions(-) create mode 100644 changelog.d/18119.bugfix diff --git a/changelog.d/18119.bugfix b/changelog.d/18119.bugfix new file mode 100644 index 0000000000..c8ac53f9d4 --- /dev/null +++ b/changelog.d/18119.bugfix @@ -0,0 +1 @@ +Fix a bug where the [Delete Room Admin API](https://element-hq.github.io/synapse/latest/admin_api/rooms.html#version-2-new-version) would fail if the `block` parameter was set to `true` and a worker other than the main process was configured to handle background tasks. \ No newline at end of file diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 2522bebd72..d673adba16 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -1181,6 +1181,50 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): return total_media_quarantined + async def block_room(self, room_id: str, user_id: str) -> None: + """Marks the room as blocked. + + Can be called multiple times (though we'll only track the last user to + block this room). + + Can be called on a room unknown to this homeserver. + + Args: + room_id: Room to block + user_id: Who blocked it + """ + await self.db_pool.simple_upsert( + table="blocked_rooms", + keyvalues={"room_id": room_id}, + values={}, + insertion_values={"user_id": user_id}, + desc="block_room", + ) + await self.db_pool.runInteraction( + "block_room_invalidation", + self._invalidate_cache_and_stream, + self.is_room_blocked, + (room_id,), + ) + + async def unblock_room(self, room_id: str) -> None: + """Remove the room from blocking list. 
+ + Args: + room_id: Room to unblock + """ + await self.db_pool.simple_delete( + table="blocked_rooms", + keyvalues={"room_id": room_id}, + desc="unblock_room", + ) + await self.db_pool.runInteraction( + "block_room_invalidation", + self._invalidate_cache_and_stream, + self.is_room_blocked, + (room_id,), + ) + async def get_rooms_for_retention_period_in_range( self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False ) -> Dict[str, RetentionPolicy]: @@ -2500,50 +2544,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore): ) return next_id - async def block_room(self, room_id: str, user_id: str) -> None: - """Marks the room as blocked. - - Can be called multiple times (though we'll only track the last user to - block this room). - - Can be called on a room unknown to this homeserver. - - Args: - room_id: Room to block - user_id: Who blocked it - """ - await self.db_pool.simple_upsert( - table="blocked_rooms", - keyvalues={"room_id": room_id}, - values={}, - insertion_values={"user_id": user_id}, - desc="block_room", - ) - await self.db_pool.runInteraction( - "block_room_invalidation", - self._invalidate_cache_and_stream, - self.is_room_blocked, - (room_id,), - ) - - async def unblock_room(self, room_id: str) -> None: - """Remove the room from blocking list. - - Args: - room_id: Room to unblock - """ - await self.db_pool.simple_delete( - table="blocked_rooms", - keyvalues={"room_id": room_id}, - desc="unblock_room", - ) - await self.db_pool.runInteraction( - "block_room_invalidation", - self._invalidate_cache_and_stream, - self.is_room_blocked, - (room_id,), - ) - async def clear_partial_state_room(self, room_id: str) -> Optional[int]: """Clears the partial state flag for a room. 
From aa6e5c2ecbe58dec1c38c33949fc70c88f39c242 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 3 Feb 2025 18:29:15 +0100 Subject: [PATCH 14/16] Add locking to more safely delete state groups: Part 1 (#18107) Currently we don't really have anything that stops us from deleting state groups when an in-flight event references it. This is a fairly rare race currently, but we want to be able to more aggressively delete state groups so it is important to address this to ensure that the database remains valid. This implements the locking, but doesn't actually use it. See the class docstring of the new data store for an explanation for how this works. --------- Co-authored-by: Devon Hudson --- changelog.d/18107.bugfix | 1 + synapse/handlers/federation_event.py | 18 +- synapse/state/__init__.py | 61 ++- synapse/storage/controllers/persist_events.py | 32 +- synapse/storage/databases/__init__.py | 10 +- synapse/storage/databases/state/deletion.py | 446 ++++++++++++++++++ synapse/storage/databases/state/store.py | 39 +- synapse/storage/schema/__init__.py | 5 +- .../delta/89/01_state_groups_deletion.sql | 39 ++ tests/handlers/test_federation_event.py | 9 +- tests/rest/client/test_rooms.py | 4 +- tests/storage/test_state_deletion.py | 411 ++++++++++++++++ tests/test_state.py | 18 +- 13 files changed, 1047 insertions(+), 46 deletions(-) create mode 100644 changelog.d/18107.bugfix create mode 100644 synapse/storage/databases/state/deletion.py create mode 100644 synapse/storage/schema/state/delta/89/01_state_groups_deletion.sql create mode 100644 tests/storage/test_state_deletion.py diff --git a/changelog.d/18107.bugfix b/changelog.d/18107.bugfix new file mode 100644 index 0000000000..4d0c19fab9 --- /dev/null +++ b/changelog.d/18107.bugfix @@ -0,0 +1 @@ +Fix rare edge case where state groups could be deleted while we are persisting new events that reference them. 
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 1b535ea2cb..1e738f484f 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -151,6 +151,8 @@ class FederationEventHandler: def __init__(self, hs: "HomeServer"): self._clock = hs.get_clock() self._store = hs.get_datastores().main + self._state_store = hs.get_datastores().state + self._state_deletion_store = hs.get_datastores().state_deletion self._storage_controllers = hs.get_storage_controllers() self._state_storage_controller = self._storage_controllers.state @@ -580,7 +582,9 @@ class FederationEventHandler: room_version.identifier, state_maps_to_resolve, event_map=None, - state_res_store=StateResolutionStore(self._store), + state_res_store=StateResolutionStore( + self._store, self._state_deletion_store + ), ) ) else: @@ -1179,7 +1183,9 @@ class FederationEventHandler: room_version, state_maps, event_map={event_id: event}, - state_res_store=StateResolutionStore(self._store), + state_res_store=StateResolutionStore( + self._store, self._state_deletion_store + ), ) except Exception as e: @@ -1874,7 +1880,9 @@ class FederationEventHandler: room_version, [local_state_id_map, claimed_auth_events_id_map], event_map=None, - state_res_store=StateResolutionStore(self._store), + state_res_store=StateResolutionStore( + self._store, self._state_deletion_store + ), ) ) else: @@ -2014,7 +2022,9 @@ class FederationEventHandler: room_version, state_sets, event_map=None, - state_res_store=StateResolutionStore(self._store), + state_res_store=StateResolutionStore( + self._store, self._state_deletion_store + ), ) ) else: diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 72b291889b..5b746f2037 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -59,11 +59,13 @@ from synapse.types.state import StateFilter from synapse.util.async_helpers import Linearizer from synapse.util.caches.expiringcache import 
ExpiringCache from synapse.util.metrics import Measure, measure_func +from synapse.util.stringutils import shortstr if TYPE_CHECKING: from synapse.server import HomeServer from synapse.storage.controllers import StateStorageController from synapse.storage.databases.main import DataStore + from synapse.storage.databases.state.deletion import StateDeletionDataStore logger = logging.getLogger(__name__) metrics_logger = logging.getLogger("synapse.state.metrics") @@ -194,6 +196,8 @@ class StateHandler: self._storage_controllers = hs.get_storage_controllers() self._events_shard_config = hs.config.worker.events_shard_config self._instance_name = hs.get_instance_name() + self._state_store = hs.get_datastores().state + self._state_deletion_store = hs.get_datastores().state_deletion self._update_current_state_client = ( ReplicationUpdateCurrentStateRestServlet.make_client(hs) @@ -475,7 +479,10 @@ class StateHandler: @trace @measure_func() async def resolve_state_groups_for_events( - self, room_id: str, event_ids: StrCollection, await_full_state: bool = True + self, + room_id: str, + event_ids: StrCollection, + await_full_state: bool = True, ) -> _StateCacheEntry: """Given a list of event_ids this method fetches the state at each event, resolves conflicts between them and returns them. @@ -511,6 +518,17 @@ class StateHandler: ) = await self._state_storage_controller.get_state_group_delta( state_group_id ) + + if prev_group: + # Ensure that we still have the prev group, and ensure we don't + # delete it while we're persisting the event. 
+ missing_state_group = await self._state_deletion_store.check_state_groups_and_bump_deletion( + {prev_group} + ) + if missing_state_group: + prev_group = None + delta_ids = None + return _StateCacheEntry( state=None, state_group=state_group_id, @@ -531,7 +549,9 @@ class StateHandler: room_version, state_to_resolve, None, - state_res_store=StateResolutionStore(self.store), + state_res_store=StateResolutionStore( + self.store, self._state_deletion_store + ), ) return result @@ -663,7 +683,25 @@ class StateResolutionHandler: async with self.resolve_linearizer.queue(group_names): cache = self._state_cache.get(group_names, None) if cache: - return cache + # Check that the returned cache entry doesn't point to deleted + # state groups. + state_groups_to_check = set() + if cache.state_group is not None: + state_groups_to_check.add(cache.state_group) + + if cache.prev_group is not None: + state_groups_to_check.add(cache.prev_group) + + missing_state_groups = await state_res_store.state_deletion_store.check_state_groups_and_bump_deletion( + state_groups_to_check + ) + + if not missing_state_groups: + return cache + else: + # There are missing state groups, so let's remove the stale + # entry and continue as if it was a cache miss. + self._state_cache.pop(group_names, None) logger.info( "Resolving state for %s with groups %s", @@ -671,6 +709,16 @@ class StateResolutionHandler: list(group_names), ) + # We double check that none of the state groups have been deleted. + # They shouldn't be as all these state groups should be referenced. + missing_state_groups = await state_res_store.state_deletion_store.check_state_groups_and_bump_deletion( + group_names + ) + if missing_state_groups: + raise Exception( + f"State groups have been deleted: {shortstr(missing_state_groups)}" + ) + state_groups_histogram.observe(len(state_groups_ids)) new_state = await self.resolve_events_with_store( @@ -884,7 +932,8 @@ class StateResolutionStore: in well defined way. 
""" - store: "DataStore" + main_store: "DataStore" + state_deletion_store: "StateDeletionDataStore" def get_events( self, event_ids: StrCollection, allow_rejected: bool = False @@ -899,7 +948,7 @@ class StateResolutionStore: An awaitable which resolves to a dict from event_id to event. """ - return self.store.get_events( + return self.main_store.get_events( event_ids, redact_behaviour=EventRedactBehaviour.as_is, get_prev_content=False, @@ -920,4 +969,4 @@ class StateResolutionStore: An awaitable that resolves to a set of event IDs. """ - return self.store.get_auth_chain_difference(room_id, state_sets) + return self.main_store.get_auth_chain_difference(room_id, state_sets) diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py index 879ee9039e..7963905479 100644 --- a/synapse/storage/controllers/persist_events.py +++ b/synapse/storage/controllers/persist_events.py @@ -332,6 +332,7 @@ class EventsPersistenceStorageController: # store for now. 
self.main_store = stores.main self.state_store = stores.state + self._state_deletion_store = stores.state_deletion assert stores.persist_events self.persist_events_store = stores.persist_events @@ -549,7 +550,9 @@ class EventsPersistenceStorageController: room_version, state_maps_by_state_group, event_map=None, - state_res_store=StateResolutionStore(self.main_store), + state_res_store=StateResolutionStore( + self.main_store, self._state_deletion_store + ), ) return await res.get_state(self._state_controller, StateFilter.all()) @@ -635,15 +638,20 @@ class EventsPersistenceStorageController: room_id, [e for e, _ in chunk] ) - await self.persist_events_store._persist_events_and_state_updates( - room_id, - chunk, - state_delta_for_room=state_delta_for_room, - new_forward_extremities=new_forward_extremities, - use_negative_stream_ordering=backfilled, - inhibit_local_membership_updates=backfilled, - new_event_links=new_event_links, - ) + # Stop the state groups from being deleted while we're persisting + # them. 
+ async with self._state_deletion_store.persisting_state_group_references( + events_and_contexts + ): + await self.persist_events_store._persist_events_and_state_updates( + room_id, + chunk, + state_delta_for_room=state_delta_for_room, + new_forward_extremities=new_forward_extremities, + use_negative_stream_ordering=backfilled, + inhibit_local_membership_updates=backfilled, + new_event_links=new_event_links, + ) return replaced_events @@ -965,7 +973,9 @@ class EventsPersistenceStorageController: room_version, state_groups, events_map, - state_res_store=StateResolutionStore(self.main_store), + state_res_store=StateResolutionStore( + self.main_store, self._state_deletion_store + ), ) state_resolutions_during_persistence.inc() diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py index dd9fc01fb0..81886ff765 100644 --- a/synapse/storage/databases/__init__.py +++ b/synapse/storage/databases/__init__.py @@ -26,6 +26,7 @@ from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool, make_conn from synapse.storage.databases.main.events import PersistEventsStore from synapse.storage.databases.state import StateGroupDataStore +from synapse.storage.databases.state.deletion import StateDeletionDataStore from synapse.storage.engines import create_engine from synapse.storage.prepare_database import prepare_database @@ -49,12 +50,14 @@ class Databases(Generic[DataStoreT]): main state persist_events + state_deletion """ databases: List[DatabasePool] main: "DataStore" # FIXME: https://github.com/matrix-org/synapse/issues/11165: actually an instance of `main_store_class` state: StateGroupDataStore persist_events: Optional[PersistEventsStore] + state_deletion: StateDeletionDataStore def __init__(self, main_store_class: Type[DataStoreT], hs: "HomeServer"): # Note we pass in the main store class here as workers use a different main @@ -63,6 +66,7 @@ class Databases(Generic[DataStoreT]): self.databases = [] main: 
Optional[DataStoreT] = None state: Optional[StateGroupDataStore] = None + state_deletion: Optional[StateDeletionDataStore] = None persist_events: Optional[PersistEventsStore] = None for database_config in hs.config.database.databases: @@ -114,7 +118,8 @@ class Databases(Generic[DataStoreT]): if state: raise Exception("'state' data store already configured") - state = StateGroupDataStore(database, db_conn, hs) + state_deletion = StateDeletionDataStore(database, db_conn, hs) + state = StateGroupDataStore(database, db_conn, hs, state_deletion) db_conn.commit() @@ -135,7 +140,7 @@ class Databases(Generic[DataStoreT]): if not main: raise Exception("No 'main' database configured") - if not state: + if not state or not state_deletion: raise Exception("No 'state' database configured") # We use local variables here to ensure that the databases do not have @@ -143,3 +148,4 @@ class Databases(Generic[DataStoreT]): self.main = main # type: ignore[assignment] self.state = state self.persist_events = persist_events + self.state_deletion = state_deletion diff --git a/synapse/storage/databases/state/deletion.py b/synapse/storage/databases/state/deletion.py new file mode 100644 index 0000000000..07dbbc8e75 --- /dev/null +++ b/synapse/storage/databases/state/deletion.py @@ -0,0 +1,446 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2025 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# . 
+# + + +import contextlib +from typing import ( + TYPE_CHECKING, + AbstractSet, + AsyncIterator, + Collection, + Mapping, + Set, + Tuple, +) + +from synapse.events import EventBase +from synapse.events.snapshot import EventContext +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, + make_in_list_sql_clause, +) +from synapse.storage.engines import PostgresEngine +from synapse.util.stringutils import shortstr + +if TYPE_CHECKING: + from synapse.server import HomeServer + + +class StateDeletionDataStore: + """Manages deletion of state groups in a safe manner. + + Deleting state groups is challenging as before we actually delete them we + need to ensure that there are no in-flight events that refer to the state + groups that we want to delete. + + To handle this, we take two approaches. First, before we persist any event + we ensure that the state group still exists and mark in the + `state_groups_persisting` table that the state group is about to be used. + (Note that we have to have the extra table here as state groups and events + can be in different databases, and thus we can't check for the existence of + state groups in the persist event transaction). Once the event has been + persisted, we can remove the row from `state_groups_persisting`. So long as + we check that table before deleting state groups, we can ensure that we + never persist events that reference deleted state groups, maintaining + database integrity. + + However, we want to avoid throwing exceptions so deep in the process of + persisting events. So instead of deleting state groups immediately, we mark + them as pending/proposed for deletion and wait for a certain amount of time + before performing the deletion. When we come to handle new events that + reference state groups, we check if they are pending deletion and bump the + time for when they'll be deleted (to give a chance for the event to be + persisted, or not). 
+ + When deleting, we need to check that state groups remain unreferenced. There + is a race here where we a) fetch state groups that are ready for deletion, + b) check they're unreferenced, c) the state group becomes referenced but + then gets marked as pending deletion again, d) during the deletion + transaction we recheck `state_groups_pending_deletion` table again and see + that it exists and so continue with the deletion. To prevent this from + happening we add a `sequence_number` column to + `state_groups_pending_deletion`, and during deletion we ensure that for a + state group we're about to delete that the sequence number doesn't change + between steps (a) and (d). So long as we always bump the sequence number + whenever a state group may become referenced the race can never happen. + """ + + # How long to wait before we delete state groups. This should be long enough + # for any in-flight events to be persisted. If events take longer to persist + # and any of the state groups they reference have been deleted, then the + # event will fail to persist (as well as any event in the same batch). + DELAY_BEFORE_DELETION_MS = 10 * 60 * 1000 + + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): + self._clock = hs.get_clock() + self.db_pool = database + self._instance_name = hs.get_instance_name() + + # TODO: Clear from `state_groups_persisting` any holdovers from previous + # running instance. + + async def check_state_groups_and_bump_deletion( + self, state_groups: AbstractSet[int] + ) -> Collection[int]: + """Checks to make sure that the state groups haven't been deleted, and + if they're pending deletion we delay it (allowing time for any event + that will use them to finish persisting). + + Returns: + The state groups that are missing, if any. 
+ """ + + return await self.db_pool.runInteraction( + "check_state_groups_and_bump_deletion", + self._check_state_groups_and_bump_deletion_txn, + state_groups, + ) + + def _check_state_groups_and_bump_deletion_txn( + self, txn: LoggingTransaction, state_groups: AbstractSet[int] + ) -> Collection[int]: + existing_state_groups = self._get_existing_groups_with_lock(txn, state_groups) + + self._bump_deletion_txn(txn, existing_state_groups) + + missing_state_groups = state_groups - existing_state_groups + if missing_state_groups: + return missing_state_groups + + return () + + def _bump_deletion_txn( + self, txn: LoggingTransaction, state_groups: Collection[int] + ) -> None: + """Update any pending deletions of the state group that they may now be + referenced.""" + + if not state_groups: + return + + now = self._clock.time_msec() + if isinstance(self.db_pool.engine, PostgresEngine): + clause, args = make_in_list_sql_clause( + self.db_pool.engine, "state_group", state_groups + ) + sql = f""" + UPDATE state_groups_pending_deletion + SET sequence_number = DEFAULT, insertion_ts = ? 
+ WHERE {clause} + """ + args.insert(0, now) + txn.execute(sql, args) + else: + rows = self.db_pool.simple_select_many_txn( + txn, + table="state_groups_pending_deletion", + column="state_group", + iterable=state_groups, + keyvalues={}, + retcols=("state_group",), + ) + if not rows: + return + + state_groups_to_update = [state_group for (state_group,) in rows] + + self.db_pool.simple_delete_many_txn( + txn, + table="state_groups_pending_deletion", + column="state_group", + values=state_groups_to_update, + keyvalues={}, + ) + self.db_pool.simple_insert_many_txn( + txn, + table="state_groups_pending_deletion", + keys=("state_group", "insertion_ts"), + values=[(state_group, now) for state_group in state_groups_to_update], + ) + + def _get_existing_groups_with_lock( + self, txn: LoggingTransaction, state_groups: Collection[int] + ) -> AbstractSet[int]: + """Return which of the given state groups are in the database, and locks + those rows with `KEY SHARE` to ensure they don't get concurrently + deleted.""" + clause, args = make_in_list_sql_clause(self.db_pool.engine, "id", state_groups) + + sql = f""" + SELECT id FROM state_groups + WHERE {clause} + """ + if isinstance(self.db_pool.engine, PostgresEngine): + # On postgres we add a row level lock to the rows to ensure that we + # conflict with any concurrent DELETEs. 
`FOR KEY SHARE` lock will + # not conflict with other reads. + sql += """ + FOR KEY SHARE + """ + + txn.execute(sql, args) + return {state_group for (state_group,) in txn} + + @contextlib.asynccontextmanager + async def persisting_state_group_references( + self, event_and_contexts: Collection[Tuple[EventBase, EventContext]] + ) -> AsyncIterator[None]: + """Wraps the persistence of the given events and contexts, ensuring that + any state groups referenced still exist and that they don't get deleted + during this.""" + + referenced_state_groups: Set[int] = set() + for event, ctx in event_and_contexts: + if ctx.rejected or event.internal_metadata.is_outlier(): + continue + + assert ctx.state_group is not None + + referenced_state_groups.add(ctx.state_group) + + if ctx.state_group_before_event: + referenced_state_groups.add(ctx.state_group_before_event) + + if not referenced_state_groups: + # We don't reference any state groups, so nothing to do + yield + return + + await self.db_pool.runInteraction( + "mark_state_groups_as_persisting", + self._mark_state_groups_as_persisting_txn, + referenced_state_groups, + ) + + error = True + try: + yield None + error = False + finally: + await self.db_pool.runInteraction( + "finish_persisting", + self._finish_persisting_txn, + referenced_state_groups, + error=error, + ) + + def _mark_state_groups_as_persisting_txn( + self, txn: LoggingTransaction, state_groups: Set[int] + ) -> None: + """Marks the given state groups as being persisted.""" + + existing_state_groups = self._get_existing_groups_with_lock(txn, state_groups) + missing_state_groups = state_groups - existing_state_groups + if missing_state_groups: + raise Exception( + f"state groups have been deleted: {shortstr(missing_state_groups)}" + ) + + self.db_pool.simple_insert_many_txn( + txn, + table="state_groups_persisting", + keys=("state_group", "instance_name"), + values=[(state_group, self._instance_name) for state_group in state_groups], + ) + + def _finish_persisting_txn( 
+ self, txn: LoggingTransaction, state_groups: Collection[int], error: bool + ) -> None: + """Mark the state groups as having finished persistence. + + If `error` is true then we assume the state groups were not persisted, + and so we do not clear them from the pending deletion table. + """ + self.db_pool.simple_delete_many_txn( + txn, + table="state_groups_persisting", + column="state_group", + values=state_groups, + keyvalues={"instance_name": self._instance_name}, + ) + + if error: + # The state groups may or may not have been persisted, so we need to + # bump the deletion to ensure we recheck if they have become + # referenced. + self._bump_deletion_txn(txn, state_groups) + return + + self.db_pool.simple_delete_many_batch_txn( + txn, + table="state_groups_pending_deletion", + keys=("state_group",), + values=[(state_group,) for state_group in state_groups], + ) + + async def mark_state_groups_as_pending_deletion( + self, state_groups: Collection[int] + ) -> None: + """Mark the given state groups as pending deletion""" + + now = self._clock.time_msec() + + await self.db_pool.simple_upsert_many( + table="state_groups_pending_deletion", + key_names=("state_group",), + key_values=[(state_group,) for state_group in state_groups], + value_names=("insertion_ts",), + value_values=[(now,) for _ in state_groups], + desc="mark_state_groups_as_pending_deletion", + ) + + async def get_pending_deletions( + self, state_groups: Collection[int] + ) -> Mapping[int, int]: + """Get which state groups are pending deletion. 
+ + Returns: + a mapping from state groups that are pending deletion to their + sequence number + """ + + rows = await self.db_pool.simple_select_many_batch( + table="state_groups_pending_deletion", + column="state_group", + iterable=state_groups, + retcols=("state_group", "sequence_number"), + keyvalues={}, + desc="get_pending_deletions", + ) + + return dict(rows) + + def get_state_groups_ready_for_potential_deletion_txn( + self, + txn: LoggingTransaction, + state_groups_to_sequence_numbers: Mapping[int, int], + ) -> Collection[int]: + """Given a set of state groups, return which state groups can + potentially be deleted. + + The state groups must have been checked to see if they remain + unreferenced before calling this function. + + Note: This must be called within the same transaction that the state + groups are deleted. + + Args: + state_groups_to_sequence_numbers: The state groups, and the sequence + numbers from before the state groups were checked to see if they + were unreferenced. + + Returns: + The subset of state groups that can safely be deleted + + """ + + if not state_groups_to_sequence_numbers: + return state_groups_to_sequence_numbers + + if isinstance(self.db_pool.engine, PostgresEngine): + # On postgres we want to lock the rows FOR UPDATE as early as + # possible to help conflicts. 
+ clause, args = make_in_list_sql_clause( + self.db_pool.engine, "id", state_groups_to_sequence_numbers + ) + sql = f""" + SELECT id FROM state_groups + WHERE {clause} + FOR UPDATE + """ + txn.execute(sql, args) + + # Check the deletion status in the DB of the given state groups + clause, args = make_in_list_sql_clause( + self.db_pool.engine, + column="state_group", + iterable=state_groups_to_sequence_numbers, + ) + + sql = f""" + SELECT state_group, insertion_ts, sequence_number FROM ( + SELECT state_group, insertion_ts, sequence_number FROM state_groups_pending_deletion + UNION + SELECT state_group, null, null FROM state_groups_persisting + ) AS s + WHERE {clause} + """ + + txn.execute(sql, args) + + # The above query will return potentially two rows per state group (one + # for each table), so we track which state groups have enough time + # elapsed and which are not ready to be deleted. + ready_to_be_deleted = set() + not_ready_to_be_deleted = set() + + now = self._clock.time_msec() + for state_group, insertion_ts, sequence_number in txn: + if insertion_ts is None: + # A null insertion_ts means that we are currently persisting + # events that reference the state group, so we don't delete + # them. + not_ready_to_be_deleted.add(state_group) + continue + + # We know this can't be None if insertion_ts is not None + assert sequence_number is not None + + # Check if the sequence number has changed, if it has then it + # indicates that the state group may have become referenced since we + # checked. + if state_groups_to_sequence_numbers[state_group] != sequence_number: + not_ready_to_be_deleted.add(state_group) + continue + + if now - insertion_ts < self.DELAY_BEFORE_DELETION_MS: + # Not enough time has elapsed to allow us to delete. 
+ not_ready_to_be_deleted.add(state_group) + continue + + ready_to_be_deleted.add(state_group) + + can_be_deleted = ready_to_be_deleted - not_ready_to_be_deleted + if not_ready_to_be_deleted: + # If there are any state groups that aren't ready to be deleted, + # then we also need to remove any state groups that are referenced + # by them. + clause, args = make_in_list_sql_clause( + self.db_pool.engine, + column="state_group", + iterable=state_groups_to_sequence_numbers, + ) + sql = f""" + WITH RECURSIVE ancestors(state_group) AS ( + SELECT DISTINCT prev_state_group + FROM state_group_edges WHERE {clause} + UNION + SELECT prev_state_group + FROM state_group_edges + INNER JOIN ancestors USING (state_group) + ) + SELECT state_group FROM ancestors + """ + txn.execute(sql, args) + + can_be_deleted.difference_update(state_group for (state_group,) in txn) + + return can_be_deleted diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index 9944f90015..7e986e0601 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -36,7 +36,10 @@ import attr from synapse.api.constants import EventTypes from synapse.events import EventBase -from synapse.events.snapshot import UnpersistedEventContext, UnpersistedEventContextBase +from synapse.events.snapshot import ( + UnpersistedEventContext, + UnpersistedEventContextBase, +) from synapse.logging.opentracing import tag_args, trace from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( @@ -55,6 +58,7 @@ from synapse.util.cancellation import cancellable if TYPE_CHECKING: from synapse.server import HomeServer + from synapse.storage.databases.state.deletion import StateDeletionDataStore logger = logging.getLogger(__name__) @@ -83,8 +87,10 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): database: DatabasePool, db_conn: LoggingDatabaseConnection, hs: "HomeServer", + state_deletion_store: 
"StateDeletionDataStore", ): super().__init__(database, db_conn, hs) + self._state_deletion_store = state_deletion_store # Originally the state store used a single DictionaryCache to cache the # event IDs for the state types in a given state group to avoid hammering @@ -467,14 +473,15 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): Returns: A list of state groups """ - is_in_db = self.db_pool.simple_select_one_onecol_txn( - txn, - table="state_groups", - keyvalues={"id": prev_group}, - retcol="id", - allow_none=True, + + # We need to check that the prev group isn't about to be deleted + is_missing = ( + self._state_deletion_store._check_state_groups_and_bump_deletion_txn( + txn, + {prev_group}, + ) ) - if not is_in_db: + if is_missing: raise Exception( "Trying to persist state with unpersisted prev_group: %r" % (prev_group,) @@ -546,6 +553,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): for key, state_id in context.state_delta_due_to_event.items() ], ) + return events_and_context return await self.db_pool.runInteraction( @@ -601,14 +609,15 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): The state group if successfully created, or None if the state needs to be persisted as a full state. 
""" - is_in_db = self.db_pool.simple_select_one_onecol_txn( - txn, - table="state_groups", - keyvalues={"id": prev_group}, - retcol="id", - allow_none=True, + + # We need to check that the prev group isn't about to be deleted + is_missing = ( + self._state_deletion_store._check_state_groups_and_bump_deletion_txn( + txn, + {prev_group}, + ) ) - if not is_in_db: + if is_missing: raise Exception( "Trying to persist state with unpersisted prev_group: %r" % (prev_group,) diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index 934e1ccced..49e648a92f 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -19,7 +19,7 @@ # # -SCHEMA_VERSION = 88 # remember to update the list below when updating +SCHEMA_VERSION = 89 # remember to update the list below when updating """Represents the expectations made by the codebase about the database schema This should be incremented whenever the codebase changes its requirements on the @@ -155,6 +155,9 @@ Changes in SCHEMA_VERSION = 88 be posted in response to a resettable timeout or an on-demand action. - Add background update to fix data integrity issue in the `sliding_sync_membership_snapshots` -> `forgotten` column + +Changes in SCHEMA_VERSION = 89 + - Add `state_groups_pending_deletion` and `state_groups_persisting` tables. """ diff --git a/synapse/storage/schema/state/delta/89/01_state_groups_deletion.sql b/synapse/storage/schema/state/delta/89/01_state_groups_deletion.sql new file mode 100644 index 0000000000..d4cb27a3a2 --- /dev/null +++ b/synapse/storage/schema/state/delta/89/01_state_groups_deletion.sql @@ -0,0 +1,39 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3. 
+-- +-- Copyright (C) 2025 New Vector, Ltd +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- . + +-- See the `StateDeletionDataStore` for details of these tables. + +-- We add state groups to this table when we want to later delete them. The +-- `insertion_ts` column indicates when the state group was proposed for +-- deletion (rather than when it should be deleted). +CREATE TABLE IF NOT EXISTS state_groups_pending_deletion ( + sequence_number $%AUTO_INCREMENT_PRIMARY_KEY%$, + state_group BIGINT NOT NULL, + insertion_ts BIGINT NOT NULL +); + +CREATE UNIQUE INDEX state_groups_pending_deletion_state_group ON state_groups_pending_deletion(state_group); +CREATE INDEX state_groups_pending_deletion_insertion_ts ON state_groups_pending_deletion(insertion_ts); + + +-- Holds the state groups the worker is currently persisting. +-- +-- The `sequence_number` column of the `state_groups_pending_deletion` table +-- *must* be updated whenever a state group may have become referenced. 
+CREATE TABLE IF NOT EXISTS state_groups_persisting ( + state_group BIGINT NOT NULL, + instance_name TEXT NOT NULL, + PRIMARY KEY (state_group, instance_name) +); + +CREATE INDEX state_groups_persisting_instance_name ON state_groups_persisting(instance_name); diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py index 61b0efb87e..51eca56c3b 100644 --- a/tests/handlers/test_federation_event.py +++ b/tests/handlers/test_federation_event.py @@ -807,6 +807,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): OTHER_USER = f"@user:{self.OTHER_SERVER_NAME}" main_store = self.hs.get_datastores().main + state_deletion_store = self.hs.get_datastores().state_deletion # Create the room. kermit_user_id = self.register_user("kermit", "test") @@ -958,7 +959,9 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): bert_member_event.event_id: bert_member_event, rejected_kick_event.event_id: rejected_kick_event, }, - state_res_store=StateResolutionStore(main_store), + state_res_store=StateResolutionStore( + main_store, state_deletion_store + ), ) ), [bert_member_event.event_id, rejected_kick_event.event_id], @@ -1003,7 +1006,9 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): rejected_power_levels_event.event_id, ], event_map={}, - state_res_store=StateResolutionStore(main_store), + state_res_store=StateResolutionStore( + main_store, state_deletion_store + ), full_conflicted_set=set(), ) ), diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 833bd6fff8..3bb539bf87 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -742,7 +742,7 @@ class RoomsCreateTestCase(RoomBase): self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(34, channel.resource_usage.db_txn_count) + self.assertEqual(36, 
channel.resource_usage.db_txn_count) def test_post_room_initial_state(self) -> None: # POST with initial_state config key, expect new room id @@ -755,7 +755,7 @@ class RoomsCreateTestCase(RoomBase): self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(36, channel.resource_usage.db_txn_count) + self.assertEqual(38, channel.resource_usage.db_txn_count) def test_post_room_visibility_key(self) -> None: # POST with visibility config key, expect new room id diff --git a/tests/storage/test_state_deletion.py b/tests/storage/test_state_deletion.py new file mode 100644 index 0000000000..19b290b554 --- /dev/null +++ b/tests/storage/test_state_deletion.py @@ -0,0 +1,411 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2025 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# . 
+# + + +import logging + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.rest import admin +from synapse.rest.client import login, room +from synapse.server import HomeServer +from synapse.util import Clock + +from tests.test_utils.event_injection import create_event +from tests.unittest import HomeserverTestCase + +logger = logging.getLogger(__name__) + + +class StateDeletionStoreTestCase(HomeserverTestCase): + """Tests for the StateDeletionStore.""" + + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main + self.state_store = hs.get_datastores().state + self.state_deletion_store = hs.get_datastores().state_deletion + + self.user_id = self.register_user("test", "password") + tok = self.login("test", "password") + self.room_id = self.helper.create_room_as(self.user_id, tok=tok) + + def check_if_can_be_deleted(self, state_group: int) -> bool: + """Check if the state group is pending deletion.""" + + state_group_to_sequence_number = self.get_success( + self.state_deletion_store.get_pending_deletions([state_group]) + ) + + can_be_deleted = self.get_success( + self.state_deletion_store.db_pool.runInteraction( + "test_existing_pending_deletion_is_cleared", + self.state_deletion_store.get_state_groups_ready_for_potential_deletion_txn, + state_group_to_sequence_number, + ) + ) + + return state_group in can_be_deleted + + def test_no_deletion(self) -> None: + """Test that calling persisting_state_group_references is fine if + nothing is pending deletion""" + event, context = self.get_success( + create_event( + self.hs, + room_id=self.room_id, + type="m.test", + sender=self.user_id, + ) + ) + + ctx_mgr = self.state_deletion_store.persisting_state_group_references( + [(event, context)] + ) + + self.get_success(ctx_mgr.__aenter__()) + self.get_success(ctx_mgr.__aexit__(None, None, None)) 
+ + def test_no_deletion_error(self) -> None: + """Test that calling persisting_state_group_references is fine if + nothing is pending deletion, but an error occurs.""" + + event, context = self.get_success( + create_event( + self.hs, + room_id=self.room_id, + type="m.test", + sender=self.user_id, + ) + ) + + ctx_mgr = self.state_deletion_store.persisting_state_group_references( + [(event, context)] + ) + + self.get_success(ctx_mgr.__aenter__()) + self.get_success(ctx_mgr.__aexit__(Exception, Exception("test"), None)) + + def test_existing_pending_deletion_is_cleared(self) -> None: + """Test that the pending deletion flag gets cleared when the state group + gets persisted.""" + + event, context = self.get_success( + create_event( + self.hs, + room_id=self.room_id, + type="m.test", + state_key="", + sender=self.user_id, + ) + ) + assert context.state_group is not None + + # Mark a state group that we're referencing as pending deletion. + self.get_success( + self.state_deletion_store.mark_state_groups_as_pending_deletion( + [context.state_group] + ) + ) + + ctx_mgr = self.state_deletion_store.persisting_state_group_references( + [(event, context)] + ) + + self.get_success(ctx_mgr.__aenter__()) + self.get_success(ctx_mgr.__aexit__(None, None, None)) + + # The pending deletion flag should be cleared + pending_deletion = self.get_success( + self.state_deletion_store.db_pool.simple_select_one_onecol( + table="state_groups_pending_deletion", + keyvalues={"state_group": context.state_group}, + retcol="1", + allow_none=True, + desc="test_existing_pending_deletion_is_cleared", + ) + ) + self.assertIsNone(pending_deletion) + + def test_pending_deletion_is_cleared_during_persist(self) -> None: + """Test that the pending deletion flag is cleared when a state group + gets marked for deletion during persistence""" + + event, context = self.get_success( + create_event( + self.hs, + room_id=self.room_id, + type="m.test", + state_key="", + sender=self.user_id, + ) + ) + assert 
context.state_group is not None + + ctx_mgr = self.state_deletion_store.persisting_state_group_references( + [(event, context)] + ) + self.get_success(ctx_mgr.__aenter__()) + + # Mark the state group that we're referencing as pending deletion, + # *after* we have started persisting. + self.get_success( + self.state_deletion_store.mark_state_groups_as_pending_deletion( + [context.state_group] + ) + ) + + self.get_success(ctx_mgr.__aexit__(None, None, None)) + + # The pending deletion flag should be cleared + pending_deletion = self.get_success( + self.state_deletion_store.db_pool.simple_select_one_onecol( + table="state_groups_pending_deletion", + keyvalues={"state_group": context.state_group}, + retcol="1", + allow_none=True, + desc="test_existing_pending_deletion_is_cleared", + ) + ) + self.assertIsNone(pending_deletion) + + def test_deletion_check(self) -> None: + """Test that the `get_state_groups_ready_for_potential_deletion_txn` + check is correct during different points of the lifecycle of persisting + an event.""" + event, context = self.get_success( + create_event( + self.hs, + room_id=self.room_id, + type="m.test", + state_key="", + sender=self.user_id, + ) + ) + assert context.state_group is not None + + self.get_success( + self.state_deletion_store.mark_state_groups_as_pending_deletion( + [context.state_group] + ) + ) + + # We shouldn't be able to delete the state group as not enough time has passed + can_be_deleted = self.check_if_can_be_deleted(context.state_group) + self.assertFalse(can_be_deleted) + + # After enough time we can delete the state group + self.reactor.advance( + 1 + self.state_deletion_store.DELAY_BEFORE_DELETION_MS / 1000 + ) + can_be_deleted = self.check_if_can_be_deleted(context.state_group) + self.assertTrue(can_be_deleted) + + ctx_mgr = self.state_deletion_store.persisting_state_group_references( + [(event, context)] + ) + self.get_success(ctx_mgr.__aenter__()) + + # But once we start persisting we can't delete the state group +
can_be_deleted = self.check_if_can_be_deleted(context.state_group) + self.assertFalse(can_be_deleted) + + self.get_success(ctx_mgr.__aexit__(None, None, None)) + + # The pending deletion flag should remain cleared after persistence has + # finished. + can_be_deleted = self.check_if_can_be_deleted(context.state_group) + self.assertFalse(can_be_deleted) + + def test_deletion_error_during_persistence(self) -> None: + """Test that state groups remain marked as pending deletion if persisting + the event fails.""" + + event, context = self.get_success( + create_event( + self.hs, + room_id=self.room_id, + type="m.test", + state_key="", + sender=self.user_id, + ) + ) + assert context.state_group is not None + + # Mark a state group that we're referencing as pending deletion. + self.get_success( + self.state_deletion_store.mark_state_groups_as_pending_deletion( + [context.state_group] + ) + ) + + ctx_mgr = self.state_deletion_store.persisting_state_group_references( + [(event, context)] + ) + + self.get_success(ctx_mgr.__aenter__()) + self.get_success(ctx_mgr.__aexit__(Exception, Exception("test"), None)) + + # We should be able to delete the state group after a certain amount of + # time + self.reactor.advance( + 1 + self.state_deletion_store.DELAY_BEFORE_DELETION_MS / 1000 + ) + can_be_deleted = self.check_if_can_be_deleted(context.state_group) + self.assertTrue(can_be_deleted) + + def test_race_between_check_and_insert(self) -> None: + """Check that we correctly handle the race where we go to delete a + state group, check that it is unreferenced, and then it becomes + referenced just before we delete it.""" + + event, context = self.get_success( + create_event( + self.hs, + room_id=self.room_id, + type="m.test", + state_key="", + sender=self.user_id, + ) + ) + assert context.state_group is not None + + # Mark a state group that we're referencing as pending deletion. 
+ self.get_success( + self.state_deletion_store.mark_state_groups_as_pending_deletion( + [context.state_group] + ) + ) + + # Advance time enough so we can delete the state group + self.reactor.advance( + 1 + self.state_deletion_store.DELAY_BEFORE_DELETION_MS / 1000 + ) + + # Check that we'd be able to delete this state group. + state_group_to_sequence_number = self.get_success( + self.state_deletion_store.get_pending_deletions([context.state_group]) + ) + + can_be_deleted = self.get_success( + self.state_deletion_store.db_pool.runInteraction( + "test_existing_pending_deletion_is_cleared", + self.state_deletion_store.get_state_groups_ready_for_potential_deletion_txn, + state_group_to_sequence_number, + ) + ) + self.assertCountEqual(can_be_deleted, [context.state_group]) + + # ... in the real world we'd check that the state group isn't referenced here ... + + # Now we persist the event to reference the state group, *after* we + # check that the state group wasn't referenced + ctx_mgr = self.state_deletion_store.persisting_state_group_references( + [(event, context)] + ) + + self.get_success(ctx_mgr.__aenter__()) + self.get_success(ctx_mgr.__aexit__(Exception, Exception("test"), None)) + + # We simulate a pause (required to hit the race) + self.reactor.advance( + 1 + self.state_deletion_store.DELAY_BEFORE_DELETION_MS / 1000 + ) + + # We should no longer be able to delete the state group, without having + # to recheck if its referenced. 
+ can_be_deleted = self.get_success( + self.state_deletion_store.db_pool.runInteraction( + "test_existing_pending_deletion_is_cleared", + self.state_deletion_store.get_state_groups_ready_for_potential_deletion_txn, + state_group_to_sequence_number, + ) + ) + self.assertCountEqual(can_be_deleted, []) + + def test_remove_ancestors_from_can_delete(self) -> None: + """Test that if a state group is not ready to be deleted, we also don't + delete anything that is refernced by it""" + + event, context = self.get_success( + create_event( + self.hs, + room_id=self.room_id, + type="m.test", + state_key="", + sender=self.user_id, + ) + ) + assert context.state_group is not None + + # Create a new state group that refernces the one from the event + new_state_group = self.get_success( + self.state_store.store_state_group( + event.event_id, + event.room_id, + prev_group=context.state_group, + delta_ids={}, + current_state_ids=None, + ) + ) + + # Mark them both as pending deletion + self.get_success( + self.state_deletion_store.mark_state_groups_as_pending_deletion( + [context.state_group, new_state_group] + ) + ) + + # Advance time enough so we can delete the state group so they're both + # ready for deletion. + self.reactor.advance( + 1 + self.state_deletion_store.DELAY_BEFORE_DELETION_MS / 1000 + ) + + # We can now delete both state groups + self.assertTrue(self.check_if_can_be_deleted(context.state_group)) + self.assertTrue(self.check_if_can_be_deleted(new_state_group)) + + # Use the new_state_group to bump its deletion time + self.get_success( + self.state_store.store_state_group( + event.event_id, + event.room_id, + prev_group=new_state_group, + delta_ids={}, + current_state_ids=None, + ) + ) + + # We should now not be able to delete either of the state groups. 
+ state_group_to_sequence_number = self.get_success( + self.state_deletion_store.get_pending_deletions( + [context.state_group, new_state_group] + ) + ) + + # We shouldn't be able to delete the state group as not enough time has passed + can_be_deleted = self.get_success( + self.state_deletion_store.db_pool.runInteraction( + "test_existing_pending_deletion_is_cleared", + self.state_deletion_store.get_state_groups_ready_for_potential_deletion_txn, + state_group_to_sequence_number, + ) + ) + self.assertCountEqual(can_be_deleted, []) diff --git a/tests/test_state.py b/tests/test_state.py index 311a590693..dce56fe78a 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -31,7 +31,7 @@ from typing import ( Tuple, cast, ) -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock from twisted.internet import defer @@ -221,7 +221,16 @@ class Graph: class StateTestCase(unittest.TestCase): def setUp(self) -> None: self.dummy_store = _DummyStore() - storage_controllers = Mock(main=self.dummy_store, state=self.dummy_store) + + # Add a dummy epoch store that always returns that we have all the + # necessary state groups.
+ dummy_deletion_store = AsyncMock() + dummy_deletion_store.check_state_groups_and_bump_deletion.return_value = [] + + storage_controllers = Mock( + main=self.dummy_store, + state=self.dummy_store, + ) hs = Mock( spec_set=[ "config", @@ -241,7 +250,10 @@ class StateTestCase(unittest.TestCase): ) clock = cast(Clock, MockClock()) hs.config = default_config("tesths", True) - hs.get_datastores.return_value = Mock(main=self.dummy_store) + hs.get_datastores.return_value = Mock( + main=self.dummy_store, + state_deletion=dummy_deletion_store, + ) hs.get_state_handler.return_value = None hs.get_clock.return_value = clock hs.get_macaroon_generator.return_value = MacaroonGenerator( From 27dbb1b4290b9de64e24a11f892777378810b595 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 3 Feb 2025 18:58:55 +0100 Subject: [PATCH 15/16] Add locking to more safely delete state groups: Part 2 (#18130) This actually makes it so that deleting state groups goes via the new mechanism. c.f. #18107 --- changelog.d/18130.bugfix | 1 + synapse/storage/controllers/purge_events.py | 80 +++++++++++++++++++-- synapse/storage/databases/state/deletion.py | 65 +++++++++++++++++ synapse/storage/databases/state/store.py | 28 ++++++-- tests/rest/client/utils.py | 2 +- tests/storage/test_purge.py | 67 +++++++++++++++++ tests/storage/test_state_deletion.py | 68 +++++++++++++++++- 7 files changed, 297 insertions(+), 14 deletions(-) create mode 100644 changelog.d/18130.bugfix diff --git a/changelog.d/18130.bugfix b/changelog.d/18130.bugfix new file mode 100644 index 0000000000..4d0c19fab9 --- /dev/null +++ b/changelog.d/18130.bugfix @@ -0,0 +1 @@ +Fix rare edge case where state groups could be deleted while we are persisting new events that reference them. 
diff --git a/synapse/storage/controllers/purge_events.py b/synapse/storage/controllers/purge_events.py index 15c04ffef8..2d6f80f770 100644 --- a/synapse/storage/controllers/purge_events.py +++ b/synapse/storage/controllers/purge_events.py @@ -21,9 +21,10 @@ import itertools import logging -from typing import TYPE_CHECKING, Set +from typing import TYPE_CHECKING, Collection, Mapping, Set from synapse.logging.context import nested_logging_context +from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage.databases import Databases if TYPE_CHECKING: @@ -38,6 +39,11 @@ class PurgeEventsStorageController: def __init__(self, hs: "HomeServer", stores: Databases): self.stores = stores + if hs.config.worker.run_background_tasks: + self._delete_state_loop_call = hs.get_clock().looping_call( + self._delete_state_groups_loop, 60 * 1000 + ) + async def purge_room(self, room_id: str) -> None: """Deletes all record of a room""" @@ -68,11 +74,15 @@ class PurgeEventsStorageController: logger.info("[purge] finding state groups that can be deleted") sg_to_delete = await self._find_unreferenced_groups(state_groups) - await self.stores.state.purge_unreferenced_state_groups( - room_id, sg_to_delete + # Mark these state groups as pending deletion, they will actually + # get deleted automatically later. + await self.stores.state_deletion.mark_state_groups_as_pending_deletion( + sg_to_delete ) - async def _find_unreferenced_groups(self, state_groups: Set[int]) -> Set[int]: + async def _find_unreferenced_groups( + self, state_groups: Collection[int] + ) -> Set[int]: """Used when purging history to figure out which state groups can be deleted. 
@@ -121,3 +131,65 @@ class PurgeEventsStorageController: to_delete = state_groups_seen - referenced_groups return to_delete + + @wrap_as_background_process("_delete_state_groups_loop") + async def _delete_state_groups_loop(self) -> None: + """Background task that deletes any state groups that may be pending + deletion.""" + + while True: + next_to_delete = await self.stores.state_deletion.get_next_state_group_collection_to_delete() + if next_to_delete is None: + break + + (room_id, groups_to_sequences) = next_to_delete + made_progress = await self._delete_state_groups( + room_id, groups_to_sequences + ) + + # If no progress was made in deleting the state groups, then we + # break to allow a pause before trying again next time we get + # called. + if not made_progress: + break + + async def _delete_state_groups( + self, room_id: str, groups_to_sequences: Mapping[int, int] + ) -> bool: + """Tries to delete the given state groups. + + Returns: + Whether we made progress in deleting the state groups (or marking + them as referenced). + """ + + # We double check if any of the state groups have become referenced. + # This shouldn't happen, as any usages should cause the state group to + # be removed as pending deletion. + referenced_state_groups = await self.stores.main.get_referenced_state_groups( + groups_to_sequences + ) + + if referenced_state_groups: + # We mark any state groups that have become referenced as being + # used. + await self.stores.state_deletion.mark_state_groups_as_used( + referenced_state_groups + ) + + # Update list of state groups to remove referenced ones + groups_to_sequences = { + state_group: sequence_number + for state_group, sequence_number in groups_to_sequences.items() + if state_group not in referenced_state_groups + } + + if not groups_to_sequences: + # We made progress here as long as we marked some state groups as + # now referenced. 
+ return len(referenced_state_groups) > 0 + + return await self.stores.state.purge_unreferenced_state_groups( + room_id, + groups_to_sequences, + ) diff --git a/synapse/storage/databases/state/deletion.py b/synapse/storage/databases/state/deletion.py index 07dbbc8e75..4853e5aa2f 100644 --- a/synapse/storage/databases/state/deletion.py +++ b/synapse/storage/databases/state/deletion.py @@ -20,6 +20,7 @@ from typing import ( AsyncIterator, Collection, Mapping, + Optional, Set, Tuple, ) @@ -307,6 +308,17 @@ class StateDeletionDataStore: desc="mark_state_groups_as_pending_deletion", ) + async def mark_state_groups_as_used(self, state_groups: Collection[int]) -> None: + """Mark the given state groups as now being referenced""" + + await self.db_pool.simple_delete_many( + table="state_groups_pending_deletion", + column="state_group", + iterable=state_groups, + keyvalues={}, + desc="mark_state_groups_as_used", + ) + async def get_pending_deletions( self, state_groups: Collection[int] ) -> Mapping[int, int]: @@ -444,3 +456,56 @@ class StateDeletionDataStore: can_be_deleted.difference_update(state_group for (state_group,) in txn) return can_be_deleted + + async def get_next_state_group_collection_to_delete( + self, + ) -> Optional[Tuple[str, Mapping[int, int]]]: + """Get the next set of state groups to try and delete + + Returns: + 2-tuple of room_id and mapping of state groups to sequence number. + """ + return await self.db_pool.runInteraction( + "get_next_state_group_collection_to_delete", + self._get_next_state_group_collection_to_delete_txn, + ) + + def _get_next_state_group_collection_to_delete_txn( + self, + txn: LoggingTransaction, + ) -> Optional[Tuple[str, Mapping[int, int]]]: + """Implementation of `get_next_state_group_collection_to_delete`""" + + # We want to return chunks of state groups that were marked for deletion + # at the same time (this isn't necessary, just more efficient). 
We do + # this by looking for the oldest insertion_ts, and then pulling out all + # rows that have the same insertion_ts (and room ID). + now = self._clock.time_msec() + + sql = """ + SELECT room_id, insertion_ts + FROM state_groups_pending_deletion AS sd + INNER JOIN state_groups AS sg ON (id = sd.state_group) + LEFT JOIN state_groups_persisting AS sp USING (state_group) + WHERE insertion_ts < ? AND sp.state_group IS NULL + ORDER BY insertion_ts + LIMIT 1 + """ + txn.execute(sql, (now - self.DELAY_BEFORE_DELETION_MS,)) + row = txn.fetchone() + if not row: + return None + + (room_id, insertion_ts) = row + + sql = """ + SELECT state_group, sequence_number + FROM state_groups_pending_deletion AS sd + INNER JOIN state_groups AS sg ON (id = sd.state_group) + LEFT JOIN state_groups_persisting AS sp USING (state_group) + WHERE room_id = ? AND insertion_ts = ? AND sp.state_group IS NULL + ORDER BY insertion_ts + """ + txn.execute(sql, (room_id, insertion_ts)) + + return room_id, dict(txn) diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index 7e986e0601..0f47642ae5 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -22,10 +22,10 @@ import logging from typing import ( TYPE_CHECKING, - Collection, Dict, Iterable, List, + Mapping, Optional, Set, Tuple, @@ -735,8 +735,10 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): ) async def purge_unreferenced_state_groups( - self, room_id: str, state_groups_to_delete: Collection[int] - ) -> None: + self, + room_id: str, + state_groups_to_sequence_numbers: Mapping[int, int], + ) -> bool: """Deletes no longer referenced state groups and de-deltas any state groups that reference them. @@ -744,21 +746,31 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): room_id: The room the state groups belong to (must all be in the same room). state_groups_to_delete: Set of all state groups to delete. 
+ + Returns: + Whether any state groups were actually deleted. """ - await self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "purge_unreferenced_state_groups", self._purge_unreferenced_state_groups, room_id, - state_groups_to_delete, + state_groups_to_sequence_numbers, ) def _purge_unreferenced_state_groups( self, txn: LoggingTransaction, room_id: str, - state_groups_to_delete: Collection[int], - ) -> None: + state_groups_to_sequence_numbers: Mapping[int, int], + ) -> bool: + state_groups_to_delete = self._state_deletion_store.get_state_groups_ready_for_potential_deletion_txn( + txn, state_groups_to_sequence_numbers + ) + + if not state_groups_to_delete: + return False + logger.info( "[purge] found %i state groups to delete", len(state_groups_to_delete) ) @@ -821,6 +833,8 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): [(sg,) for sg in state_groups_to_delete], ) + return True + @trace @tag_args async def get_previous_state_groups( diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index dbd6049f9f..e766630afb 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -548,7 +548,7 @@ class RestHelper: room_id: str, event_type: str, body: Dict[str, Any], - tok: Optional[str], + tok: Optional[str] = None, expect_code: int = HTTPStatus.OK, state_key: str = "", ) -> JsonDict: diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py index 080d5640a5..efd8d25bd1 100644 --- a/tests/storage/test_purge.py +++ b/tests/storage/test_purge.py @@ -23,6 +23,7 @@ from twisted.test.proto_helpers import MemoryReactor from synapse.api.errors import NotFoundError, SynapseError from synapse.rest.client import room from synapse.server import HomeServer +from synapse.types.state import StateFilter from synapse.util import Clock from tests.unittest import HomeserverTestCase @@ -40,6 +41,8 @@ class PurgeTests(HomeserverTestCase): self.room_id = self.helper.create_room_as(self.user_id) self.store 
= hs.get_datastores().main + self.state_store = hs.get_datastores().state + self.state_deletion_store = hs.get_datastores().state_deletion self._storage_controllers = self.hs.get_storage_controllers() def test_purge_history(self) -> None: @@ -128,3 +131,67 @@ class PurgeTests(HomeserverTestCase): self.store._invalidate_local_get_event_cache(create_event.event_id) self.get_failure(self.store.get_event(create_event.event_id), NotFoundError) self.get_failure(self.store.get_event(first["event_id"]), NotFoundError) + + def test_purge_history_deletes_state_groups(self) -> None: + """Test that unreferenced state groups get cleaned up after purge""" + + # Send four state changes to the room. + first = self.helper.send_state( + self.room_id, event_type="m.foo", body={"test": 1} + ) + second = self.helper.send_state( + self.room_id, event_type="m.foo", body={"test": 2} + ) + third = self.helper.send_state( + self.room_id, event_type="m.foo", body={"test": 3} + ) + last = self.helper.send_state( + self.room_id, event_type="m.foo", body={"test": 4} + ) + + # Get references to the state groups + event_to_groups = self.get_success( + self.store._get_state_group_for_events( + [ + first["event_id"], + second["event_id"], + third["event_id"], + last["event_id"], + ] + ) + ) + + # Get the topological token + token = self.get_success( + self.store.get_topological_token_for_event(last["event_id"]) + ) + token_str = self.get_success(token.to_string(self.hs.get_datastores().main)) + + # Purge everything before this topological token + self.get_success( + self._storage_controllers.purge_events.purge_history( + self.room_id, token_str, True + ) + ) + + # Advance so that the background jobs to delete the state groups runs + self.reactor.advance( + 1 + self.state_deletion_store.DELAY_BEFORE_DELETION_MS / 1000 + ) + + # We expect all the state groups associated with events above, except + # the last one, should return no state. 
+ state_groups = self.get_success( + self.state_store._get_state_groups_from_groups( + list(event_to_groups.values()), StateFilter.all() + ) + ) + first_state = state_groups[event_to_groups[first["event_id"]]] + second_state = state_groups[event_to_groups[second["event_id"]]] + third_state = state_groups[event_to_groups[third["event_id"]]] + last_state = state_groups[event_to_groups[last["event_id"]]] + + self.assertEqual(first_state, {}) + self.assertEqual(second_state, {}) + self.assertEqual(third_state, {}) + self.assertNotEqual(last_state, {}) diff --git a/tests/storage/test_state_deletion.py b/tests/storage/test_state_deletion.py index 19b290b554..a4d318ae20 100644 --- a/tests/storage/test_state_deletion.py +++ b/tests/storage/test_state_deletion.py @@ -41,6 +41,11 @@ class StateDeletionStoreTestCase(HomeserverTestCase): self.store = hs.get_datastores().main self.state_store = hs.get_datastores().state self.state_deletion_store = hs.get_datastores().state_deletion + self.purge_events = hs.get_storage_controllers().purge_events + + # We want to disable the automatic deletion of state groups in the + # background, so we can do controlled tests. 
+ self.purge_events._delete_state_loop_call.stop() self.user_id = self.register_user("test", "password") tok = self.login("test", "password") @@ -341,7 +346,7 @@ class StateDeletionStoreTestCase(HomeserverTestCase): def test_remove_ancestors_from_can_delete(self) -> None: """Test that if a state group is not ready to be deleted, we also don't - delete anything that is refernced by it""" + delete anything that is referenced by it""" event, context = self.get_success( create_event( @@ -354,7 +359,7 @@ class StateDeletionStoreTestCase(HomeserverTestCase): ) assert context.state_group is not None - # Create a new state group that refernces the one from the event + # Create a new state group that references the one from the event new_state_group = self.get_success( self.state_store.store_state_group( event.event_id, @@ -409,3 +414,62 @@ class StateDeletionStoreTestCase(HomeserverTestCase): ) ) self.assertCountEqual(can_be_deleted, []) + + def test_newly_referenced_state_group_gets_removed_from_pending(self) -> None: + """Check that if a state group marked for deletion becomes referenced + (without being removed from pending deletion table), it gets removed + from pending deletion table.""" + + event, context = self.get_success( + create_event( + self.hs, + room_id=self.room_id, + type="m.test", + state_key="", + sender=self.user_id, + ) + ) + assert context.state_group is not None + + # Mark a state group that we're referencing as pending deletion. + self.get_success( + self.state_deletion_store.mark_state_groups_as_pending_deletion( + [context.state_group] + ) + ) + + # Advance time enough so that the state group becomes ready for + # deletion. + self.reactor.advance( + 1 + self.state_deletion_store.DELAY_BEFORE_DELETION_MS / 1000 + ) + + # Manually insert into the table to mimic the state group getting used.
+ self.get_success( + self.store.db_pool.simple_insert( + table="event_to_state_groups", + values={"state_group": context.state_group, "event_id": event.event_id}, + desc="test_newly_referenced_state_group_gets_removed_from_pending", + ) + ) + + # Manually run the background task to delete pending state groups. + self.get_success(self.purge_events._delete_state_groups_loop()) + + # The pending deletion flag should be cleared... + pending_deletion = self.get_success( + self.state_deletion_store.db_pool.simple_select_one_onecol( + table="state_groups_pending_deletion", + keyvalues={"state_group": context.state_group}, + retcol="1", + allow_none=True, + desc="test_newly_referenced_state_group_gets_removed_from_pending", + ) + ) + self.assertIsNone(pending_deletion) + + # .. but the state should not have been deleted. + state = self.get_success( + self.state_store._get_state_for_groups([context.state_group]) + ) + self.assertGreater(len(state[context.state_group]), 0) From c46d452c7cdf7289ae7e8677b44a88c747761dce Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 3 Feb 2025 20:04:19 +0100 Subject: [PATCH 16/16] Fix bug where purging history could lead to increase in disk space usage (#18131) When purging history, we try and delete any state groups that become unreferenced (i.e. there are no longer any events that directly reference them). When we delete a state group that is referenced by another state group, we "de-delta" that state group so that it no longer refers to the state group that is deleted. There are two bugs with this approach that we fix here: 1. There is a common pattern where we end up storing two state groups when persisting a state event: the state before and after the new state event, where the latter is stored as a delta to the former. When deleting state groups we only deleted the "new" state and left (and potentially de-deltaed) the old state. This was due to a bug/typo when trying to find referenced state groups. 2. 
There are times where we store unreferenced state groups in the DB, during the purging of history these would not get rechecked and instead always de-deltaed. Instead, we should check for this case and delete any unreferenced state groups rather than de-deltaing them. The effect of the above bugs is that when purging history we'd end up with lots of unreferenced state groups that had been de-deltaed (i.e. stored as the full state). This can lead to dramatic increases in storage space used. --- changelog.d/18131.bugfix | 1 + synapse/storage/controllers/purge_events.py | 10 +++ synapse/storage/databases/state/store.py | 31 ++++++++- tests/storage/test_purge.py | 75 +++++++++++++++++++++ 4 files changed, 116 insertions(+), 1 deletion(-) create mode 100644 changelog.d/18131.bugfix diff --git a/changelog.d/18131.bugfix b/changelog.d/18131.bugfix new file mode 100644 index 0000000000..4d0c19fab9 --- /dev/null +++ b/changelog.d/18131.bugfix @@ -0,0 +1 @@ +Fix rare edge case where state groups could be deleted while we are persisting new events that reference them. diff --git a/synapse/storage/controllers/purge_events.py b/synapse/storage/controllers/purge_events.py index 2d6f80f770..47cec8c469 100644 --- a/synapse/storage/controllers/purge_events.py +++ b/synapse/storage/controllers/purge_events.py @@ -128,6 +128,16 @@ class PurgeEventsStorageController: next_to_search |= prevs state_groups_seen |= prevs + # We also check to see if anything referencing the state groups are + # also unreferenced. This helps ensure that we delete unreferenced + # state groups, if we don't then we will de-delta them when we + # delete the other state groups leading to increased DB usage. 
+ next_edges = await self.stores.state.get_next_state_groups(current_search) + nexts = set(next_edges.keys()) + nexts -= state_groups_seen + next_to_search |= nexts + state_groups_seen |= nexts + to_delete = state_groups_seen - referenced_groups return to_delete diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index 0f47642ae5..8c7980e719 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -853,7 +853,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): List[Tuple[int, int]], await self.db_pool.simple_select_many_batch( table="state_group_edges", - column="prev_state_group", + column="state_group", iterable=state_groups, keyvalues={}, retcols=("state_group", "prev_state_group"), @@ -863,6 +863,35 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): return dict(rows) + @trace + @tag_args + async def get_next_state_groups( + self, state_groups: Iterable[int] + ) -> Dict[int, int]: + """Fetch the groups that have the given state groups as their previous + state groups. + + Args: + state_groups + + Returns: + A mapping from state group to previous state group. 
+ """ + + rows = cast( + List[Tuple[int, int]], + await self.db_pool.simple_select_many_batch( + table="state_group_edges", + column="prev_state_group", + iterable=state_groups, + keyvalues={}, + retcols=("state_group", "prev_state_group"), + desc="get_next_state_groups", + ), + ) + + return dict(rows) + async def purge_room_state(self, room_id: str) -> None: return await self.db_pool.runInteraction( "purge_room_state", diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py index efd8d25bd1..5d6a8518c0 100644 --- a/tests/storage/test_purge.py +++ b/tests/storage/test_purge.py @@ -195,3 +195,78 @@ class PurgeTests(HomeserverTestCase): self.assertEqual(second_state, {}) self.assertEqual(third_state, {}) self.assertNotEqual(last_state, {}) + + def test_purge_unreferenced_state_group(self) -> None: + """Test that purging a room also gets rid of unreferenced state groups + it encounters during the purge. + + This is important, as otherwise these unreferenced state groups get + "de-deltaed" during the purge process, consuming lots of disk space. + """ + + self.helper.send(self.room_id, body="test1") + state1 = self.helper.send_state( + self.room_id, "org.matrix.test", body={"number": 2} + ) + state2 = self.helper.send_state( + self.room_id, "org.matrix.test", body={"number": 3} + ) + self.helper.send(self.room_id, body="test4") + last = self.helper.send(self.room_id, body="test5") + + # Create an unreferenced state group that has a prev group of one of the + # to-be-purged events. 
+ prev_group = self.get_success( + self.store._get_state_group_for_event(state1["event_id"]) + ) + unreferenced_state_group = self.get_success( + self.state_store.store_state_group( + event_id=last["event_id"], + room_id=self.room_id, + prev_group=prev_group, + delta_ids={("org.matrix.test", ""): state2["event_id"]}, + current_state_ids=None, + ) + ) + + # Get the topological token + token = self.get_success( + self.store.get_topological_token_for_event(last["event_id"]) + ) + token_str = self.get_success(token.to_string(self.hs.get_datastores().main)) + + # Purge everything before this topological token + self.get_success( + self._storage_controllers.purge_events.purge_history( + self.room_id, token_str, True + ) + ) + + # Advance so that the background jobs to delete the state groups runs + self.reactor.advance( + 1 + self.state_deletion_store.DELAY_BEFORE_DELETION_MS / 1000 + ) + + # We expect that the unreferenced state group has been deleted. + row = self.get_success( + self.state_store.db_pool.simple_select_one_onecol( + table="state_groups", + keyvalues={"id": unreferenced_state_group}, + retcol="id", + allow_none=True, + desc="test_purge_unreferenced_state_group", + ) + ) + self.assertIsNone(row) + + # We expect there to now only be one state group for the room, which is + # the state group of the last event (as the only outlier). + state_groups = self.get_success( + self.state_store.db_pool.simple_select_onecol( + table="state_groups", + keyvalues={"room_id": self.room_id}, + retcol="id", + desc="test_purge_unreferenced_state_group", + ) + ) + self.assertEqual(len(state_groups), 1)