From 5408101d21a08c42359737643a6cdab5021c1eb4 Mon Sep 17 00:00:00 2001
From: Erik Johnston
Date: Tue, 4 Nov 2025 12:44:57 +0000
Subject: [PATCH] Speed up pruning of ratelimiter (#19129)

I noticed this in some profiling. Currently, we prune the ratelimiters
by copying and iterating over every entry every 60 seconds. Instead,
let's use a wheel timer to track when we should potentially prune a
given key, so that we (a) check fewer keys and (b) can run more
frequently. Hopefully this means we don't have a large pause every time
we prune a ratelimiter with lots of keys.

Also fixes a bug where we didn't prune entries that were added via
`record_action` and never subsequently updated. This affected the media
and joins-per-room ratelimiters.
---
 changelog.d/19129.misc                      |  1 +
 synapse/api/ratelimiting.py                 | 71 ++++++++++++++++++----
 tests/api/test_ratelimiting.py              | 15 +++++
 tests/federation/test_federation_server.py  |  4 +-
 tests/handlers/test_room_member.py          |  8 +--
 5 files changed, 80 insertions(+), 19 deletions(-)
 create mode 100644 changelog.d/19129.misc

diff --git a/changelog.d/19129.misc b/changelog.d/19129.misc
new file mode 100644
index 0000000000..117dbfadea
--- /dev/null
+++ b/changelog.d/19129.misc
@@ -0,0 +1 @@
+Speed up pruning of ratelimiters.
diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py
index 1a43bdff23..ee0e9181ce 100644
--- a/synapse/api/ratelimiting.py
+++ b/synapse/api/ratelimiting.py
@@ -27,6 +27,7 @@ from synapse.config.ratelimiting import RatelimitSettings
 from synapse.storage.databases.main import DataStore
 from synapse.types import Requester
 from synapse.util.clock import Clock
+from synapse.util.wheel_timer import WheelTimer

 if TYPE_CHECKING:
     # To avoid circular imports:
@@ -92,9 +93,14 @@ class Ratelimiter:
         #   * The number of tokens currently in the bucket,
         #   * The time point when the bucket was last completely empty, and
         #   * The rate_hz (leak rate) of this particular bucket.
-        self.actions: dict[Hashable, tuple[float, float, float]] = {}
+        self.actions: dict[Hashable, tuple[int, float, float]] = {}

-        self.clock.looping_call(self._prune_message_counts, 60 * 1000)
+        # Records when actions should potentially be pruned. Note that we don't
+        # need to be accurate here, as this is just a cleanup job of `actions`
+        # and doesn't affect correctness.
+        self._timer: WheelTimer[Hashable] = WheelTimer()
+
+        self.clock.looping_call(self._prune_message_counts, 15 * 1000)

     def _get_key(
         self, requester: Optional[Requester], key: Optional[Hashable]
@@ -109,9 +115,9 @@

     def _get_action_counts(
         self, key: Hashable, time_now_s: float
-    ) -> tuple[float, float, float]:
+    ) -> tuple[int, float, float]:
         """Retrieve the action counts, with a fallback representing an empty bucket."""
-        return self.actions.get(key, (0.0, time_now_s, 0.0))
+        return self.actions.get(key, (0, time_now_s, self.rate_hz))

     async def can_do_action(
         self,
@@ -217,8 +223,11 @@
             allowed = True
             action_count = action_count + n_actions

-        if update:
-            self.actions[key] = (action_count, time_start, rate_hz)
+        # Only record the action if we're allowed to perform it.
+        if allowed and update:
+            self._record_action_inner(
+                key, action_count, time_start, rate_hz, time_now_s
+            )

         if rate_hz > 0:
             # Find out when the count of existing actions expires
@@ -264,7 +273,37 @@
         key = self._get_key(requester, key)
         time_now_s = _time_now_s if _time_now_s is not None else self.clock.time()
         action_count, time_start, rate_hz = self._get_action_counts(key, time_now_s)
-        self.actions[key] = (action_count + n_actions, time_start, rate_hz)
+        self._record_action_inner(
+            key, action_count + n_actions, time_start, rate_hz, time_now_s
+        )
+
+    def _record_action_inner(
+        self,
+        key: Hashable,
+        action_count: int,
+        time_start: float,
+        rate_hz: float,
+        time_now_s: float,
+    ) -> None:
+        """Helper to atomically update the action count for a given key."""
+        prune_time_s = time_start + action_count / rate_hz
+
+        # If the prune time is in the past, we can just remove the entry rather
+        # than inserting and immediately pruning.
+        if prune_time_s <= time_now_s:
+            self.actions.pop(key, None)
+            return
+
+        self.actions[key] = (action_count, time_start, rate_hz)
+
+        # We need to make sure that we only call prune *after* the entry
+        # expires, otherwise the scheduled prune may not actually prune it. This
+        # is just a cleanup job, so it doesn't matter if entries aren't pruned
+        # immediately after they expire. Hence we schedule the prune a little
+        # after the entry is due to expire.
+        prune_time_s += 0.1
+
+        self._timer.insert(int(time_now_s * 1000), key, int(prune_time_s * 1000))

     def _prune_message_counts(self) -> None:
         """Remove message count entries that have not exceeded their defined
@@ -272,18 +311,24 @@
         """
         time_now_s = self.clock.time()

-        # We create a copy of the key list here as the dictionary is modified during
-        # the loop
-        for key in list(self.actions.keys()):
-            action_count, time_start, rate_hz = self.actions[key]
+        # Pull out all the keys that *might* need pruning. We still need to
+        # verify they haven't since been updated.
+        to_prune = self._timer.fetch(int(time_now_s * 1000))
+
+        for key in to_prune:
+            value = self.actions.get(key)
+            if value is None:
+                continue
+
+            action_count, time_start, rate_hz = value

             # Rate limit = "seconds since we started limiting this action" * rate_hz
             # If this limit has not been exceeded, wipe our record of this action
             time_delta = time_now_s - time_start
             if action_count - time_delta * rate_hz > 0:
                 continue
-            else:
-                del self.actions[key]
+
+            del self.actions[key]

     async def ratelimit(
         self,
diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py
index 2e45d4e4d2..34369a8746 100644
--- a/tests/api/test_ratelimiting.py
+++ b/tests/api/test_ratelimiting.py
@@ -228,6 +228,21 @@ class TestRatelimiter(unittest.HomeserverTestCase):

         self.assertNotIn("test_id_1", limiter.actions)

+    def test_pruning_record_action(self) -> None:
+        """Test that entries added by record_action also get pruned."""
+        limiter = Ratelimiter(
+            store=self.hs.get_datastores().main,
+            clock=self.clock,
+            cfg=RatelimitSettings(key="", per_second=0.1, burst_count=1),
+        )
+        limiter.record_action(None, key="test_id_1", n_actions=1, _time_now_s=0)
+
+        self.assertIn("test_id_1", limiter.actions)
+
+        self.reactor.advance(60)
+
+        self.assertNotIn("test_id_1", limiter.actions)
+
     def test_db_user_override(self) -> None:
         """Test that users that have ratelimiting disabled in the DB aren't
         ratelimited.
diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py index 509f1f1e82..b1371d0ac7 100644 --- a/tests/federation/test_federation_server.py +++ b/tests/federation/test_federation_server.py @@ -462,7 +462,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase): ) self.assertEqual(r[("m.room.member", joining_user)].membership, "join") - @override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 3}}) + @override_config({"rc_joins_per_room": {"per_second": 0.1, "burst_count": 3}}) def test_make_join_respects_room_join_rate_limit(self) -> None: # In the test setup, two users join the room. Since the rate limiter burst # count is 3, a new make_join request to the room should be accepted. @@ -484,7 +484,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase): ) self.assertEqual(channel.code, HTTPStatus.TOO_MANY_REQUESTS, channel.json_body) - @override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 3}}) + @override_config({"rc_joins_per_room": {"per_second": 0.1, "burst_count": 3}}) def test_send_join_contributes_to_room_join_rate_limit_and_is_limited(self) -> None: # Make two make_join requests up front. (These are rate limited, but do not # contribute to the rate limit.) diff --git a/tests/handlers/test_room_member.py b/tests/handlers/test_room_member.py index 92c7c36602..8f9e27603e 100644 --- a/tests/handlers/test_room_member.py +++ b/tests/handlers/test_room_member.py @@ -50,7 +50,7 @@ class TestJoinsLimitedByPerRoomRateLimiter(FederatingHomeserverTestCase): self.intially_unjoined_room_id = f"!example:{self.OTHER_SERVER_NAME}" - @override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 2}}) + @override_config({"rc_joins_per_room": {"per_second": 0.1, "burst_count": 2}}) def test_local_user_local_joins_contribute_to_limit_and_are_limited(self) -> None: # The rate limiter has accumulated one token from Alice's join after the create # event. @@ -76,7 +76,7 @@ class TestJoinsLimitedByPerRoomRateLimiter(FederatingHomeserverTestCase): by=0.5, ) - @override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 2}}) + @override_config({"rc_joins_per_room": {"per_second": 0.1, "burst_count": 2}}) def test_local_user_profile_edits_dont_contribute_to_limit(self) -> None: # The rate limiter has accumulated one token from Alice's join after the create # event. Alice should still be able to change her displayname. @@ -100,7 +100,7 @@ class TestJoinsLimitedByPerRoomRateLimiter(FederatingHomeserverTestCase): ) ) - @override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 1}}) + @override_config({"rc_joins_per_room": {"per_second": 0.1, "burst_count": 1}}) def test_remote_joins_contribute_to_rate_limit(self) -> None: # Join once, to fill the rate limiter bucket. # @@ -248,7 +248,7 @@ class TestReplicatedJoinsLimitedByPerRoomRateLimiter(BaseMultiWorkerStreamTestCa self.room_id = self.helper.create_room_as(self.alice, tok=self.alice_token) self.intially_unjoined_room_id = "!example:otherhs" - @override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 2}}) + @override_config({"rc_joins_per_room": {"per_second": 0.01, "burst_count": 2}}) def test_local_users_joining_on_another_worker_contribute_to_rate_limit( self, ) -> None:
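
For readers who want the shape of the change outside the diff, here is a
minimal, self-contained sketch of the scheme the patch implements, under
simplifying assumptions. BucketTimer and SketchRatelimiter below are
illustrative stand-ins invented for this sketch, not Synapse's WheelTimer or
Ratelimiter; the only behaviour taken from the patch is the
(action_count, time_start, rate_hz) bucket tuple, the
insert(now_ms, key, then_ms) / fetch(now_ms) timer calls, and the rule of
scheduling the prune slightly after the bucket drains and re-checking the
entry before deleting it.

from collections import defaultdict
from collections.abc import Hashable


class BucketTimer:
    """Illustrative stand-in for Synapse's WheelTimer: keys are grouped into
    coarse time buckets so fetch() only returns keys whose scheduled time has
    passed, instead of every key ever inserted."""

    def __init__(self, bucket_size_ms: int = 5_000) -> None:
        self.bucket_size_ms = bucket_size_ms
        self.buckets: dict[int, list[Hashable]] = defaultdict(list)

    def insert(self, now_ms: int, key: Hashable, then_ms: int) -> None:
        # Never schedule in the past; floor-divide into the containing bucket.
        bucket = max(then_ms, now_ms) // self.bucket_size_ms
        self.buckets[bucket].append(key)

    def fetch(self, now_ms: int) -> list[Hashable]:
        # Pop every bucket whose window has fully elapsed.
        current = now_ms // self.bucket_size_ms
        due = [b for b in self.buckets if b < current]
        keys: list[Hashable] = []
        for b in due:
            keys.extend(self.buckets.pop(b))
        return keys


class SketchRatelimiter:
    """Token-bucket entries plus scheduled prune times, as in the patch."""

    def __init__(self, rate_hz: float) -> None:
        self.rate_hz = rate_hz
        # key -> (action_count, time_start, rate_hz), mirroring `self.actions`.
        self.actions: dict[Hashable, tuple[int, float, float]] = {}
        self.timer = BucketTimer()

    def record(self, key: Hashable, count: int, now_s: float) -> None:
        # The bucket is fully drained at time_start + count / rate_hz, so that
        # is the earliest moment the entry could be pruned. (Simplification:
        # the real Ratelimiter carries time_start forward for existing keys.)
        prune_s = now_s + count / self.rate_hz
        self.actions[key] = (count, now_s, self.rate_hz)
        # Schedule slightly after expiry so the prune pass sees it as stale.
        self.timer.insert(int(now_s * 1000), key, int((prune_s + 0.1) * 1000))

    def prune(self, now_s: float) -> None:
        # Only keys whose scheduled time has come due are re-checked; entries
        # touched again since scheduling simply survive the check.
        for key in self.timer.fetch(int(now_s * 1000)):
            value = self.actions.get(key)
            if value is None:
                continue
            count, start, rate = value
            if count - (now_s - start) * rate <= 0:
                del self.actions[key]


if __name__ == "__main__":
    limiter = SketchRatelimiter(rate_hz=0.1)
    limiter.record("user-a", 1, now_s=0.0)
    limiter.prune(now_s=5.0)   # too early: the entry survives
    assert "user-a" in limiter.actions
    limiter.prune(now_s=60.0)  # bucket drained long ago: the entry is pruned
    assert "user-a" not in limiter.actions

The useful property is that a prune pass only touches keys whose scheduled
time has come due, so its cost scales with the number of recently expired
entries rather than with every key the ratelimiter tracks, which is why the
looping call in the patch can safely run every 15 seconds instead of every 60.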