Speed up pruning of ratelimiter (#19129)

I noticed this in some profiling. Basically, we prune the ratelimiters
by copying and iterating over every entry every 60 seconds. Instead,
let's use a wheel timer to track when we should potentially prune a
given key, so that we (a) check fewer keys and (b) can run more
frequently. Hopefully this means we no longer see a large pause every
time we prune a ratelimiter with lots of keys.
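
For illustration, the core idea of a wheel timer looks roughly like this
(a toy sketch only: `MiniWheelTimer` and its internals are made up for this
example and are *not* the actual `synapse.util.wheel_timer.WheelTimer` this
PR reuses, though the `insert(now, key, then)` / `fetch(now)` shape matches
how the diff below calls it):

```python
from collections import defaultdict
from collections.abc import Hashable


class MiniWheelTimer:
    """Minimal wheel timer sketch: keys are binned into fixed-width time
    buckets, so fetching expired keys only touches elapsed buckets."""

    def __init__(self, bucket_size_ms: int = 5000) -> None:
        self.bucket_size_ms = bucket_size_ms
        # bucket index -> keys scheduled to expire within that bucket
        self.buckets: defaultdict[int, list[Hashable]] = defaultdict(list)

    def insert(self, now_ms: int, key: Hashable, then_ms: int) -> None:
        # Clamp to "now" so nothing is scheduled in the past.
        self.buckets[max(then_ms, now_ms) // self.bucket_size_ms].append(key)

    def fetch(self, now_ms: int) -> list[Hashable]:
        # Return (and drop) every key whose bucket window has fully elapsed.
        now_bucket = now_ms // self.bucket_size_ms
        expired: list[Hashable] = []
        for bucket in sorted(b for b in self.buckets if b < now_bucket):
            expired.extend(self.buckets.pop(bucket))
        return expired


timer = MiniWheelTimer()
timer.insert(0, "key1", 7_500)          # due in the [5s, 10s) bucket
assert timer.fetch(5_000) == []         # that bucket hasn't elapsed yet
assert timer.fetch(10_000) == ["key1"]
```

Because keys are binned into coarse buckets, `fetch` only inspects the
buckets whose window has elapsed rather than scanning every tracked key,
which is what makes it cheap enough to run the prune loop every 15 seconds
instead of 60.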

Also fixes a bug where we didn't prune entries that were added via
`record_action` and never subsequently updated. This affected the media
and joins-per-room ratelimiters.
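
Spelling the bug out (reconstructed from the diff below; the numbers are
illustrative): `_get_action_counts` used to fall back to a `rate_hz` of `0.0`
for unknown keys, and `record_action` wrote that value back. A key only ever
touched by `record_action` therefore had a zero leak rate, so the prune check
kept it alive forever:

```python
# Old fallback for a key seen only via record_action: rate_hz == 0.0.
action_count, time_start, rate_hz = 1, 0.0, 0.0

time_now_s = 10_000.0  # arbitrarily far in the future
time_delta = time_now_s - time_start

# The old prune loop kept any entry for which this was positive; with a
# zero rate the count never decays, so the key was never removed.
assert action_count - time_delta * rate_hz > 0
```

The new fallback returns `self.rate_hz` instead, so these entries decay and
get pruned like any other.
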
Erik Johnston
2025-11-04 12:44:57 +00:00
committed by GitHub
parent 08f570f5f5
commit 5408101d21
5 changed files with 80 additions and 19 deletions

changelog.d/19129.misc (new file)

@@ -0,0 +1 @@
+Speed up pruning of ratelimiters.

synapse/api/ratelimiting.py

@@ -27,6 +27,7 @@ from synapse.config.ratelimiting import RatelimitSettings
 from synapse.storage.databases.main import DataStore
 from synapse.types import Requester
 from synapse.util.clock import Clock
+from synapse.util.wheel_timer import WheelTimer

 if TYPE_CHECKING:
     # To avoid circular imports:
@@ -92,9 +93,14 @@ class Ratelimiter:
         # * The number of tokens currently in the bucket,
         # * The time point when the bucket was last completely empty, and
         # * The rate_hz (leak rate) of this particular bucket.
-        self.actions: dict[Hashable, tuple[float, float, float]] = {}
+        self.actions: dict[Hashable, tuple[int, float, float]] = {}

-        self.clock.looping_call(self._prune_message_counts, 60 * 1000)
+        # Records when actions should potentially be pruned. Note that we don't
+        # need to be accurate here, as this is just a cleanup job of `actions`
+        # and doesn't affect correctness.
+        self._timer: WheelTimer[Hashable] = WheelTimer()
+
+        self.clock.looping_call(self._prune_message_counts, 15 * 1000)

     def _get_key(
         self, requester: Optional[Requester], key: Optional[Hashable]
@@ -109,9 +115,9 @@ class Ratelimiter:

     def _get_action_counts(
         self, key: Hashable, time_now_s: float
-    ) -> tuple[float, float, float]:
+    ) -> tuple[int, float, float]:
         """Retrieve the action counts, with a fallback representing an empty bucket."""
-        return self.actions.get(key, (0.0, time_now_s, 0.0))
+        return self.actions.get(key, (0, time_now_s, self.rate_hz))

     async def can_do_action(
         self,
@@ -217,8 +223,11 @@
             allowed = True
             action_count = action_count + n_actions

-        if update:
-            self.actions[key] = (action_count, time_start, rate_hz)
+        # Only record the action if we're allowed to perform it.
+        if allowed and update:
+            self._record_action_inner(
+                key, action_count, time_start, rate_hz, time_now_s
+            )

         if rate_hz > 0:
             # Find out when the count of existing actions expires
@@ -264,7 +273,37 @@
         key = self._get_key(requester, key)
         time_now_s = _time_now_s if _time_now_s is not None else self.clock.time()
         action_count, time_start, rate_hz = self._get_action_counts(key, time_now_s)
-        self.actions[key] = (action_count + n_actions, time_start, rate_hz)
+        self._record_action_inner(
+            key, action_count + n_actions, time_start, rate_hz, time_now_s
+        )
+
+    def _record_action_inner(
+        self,
+        key: Hashable,
+        action_count: int,
+        time_start: float,
+        rate_hz: float,
+        time_now_s: float,
+    ) -> None:
+        """Helper to atomically update the action count for a given key."""
+
+        prune_time_s = time_start + action_count / rate_hz
+
+        # If the prune time is in the past, we can just remove the entry rather
+        # than inserting and immediately pruning.
+        if prune_time_s <= time_now_s:
+            self.actions.pop(key, None)
+            return
+
+        self.actions[key] = (action_count, time_start, rate_hz)
+
+        # We need to make sure that we only call prune *after* the entry
+        # expires, otherwise the scheduled prune may not actually prune it. This
+        # is just a cleanup job, so it doesn't matter if entries aren't pruned
+        # immediately after they expire. Hence we schedule the prune a little
+        # after the entry is due to expire.
+        prune_time_s += 0.1
+        self._timer.insert(int(time_now_s * 1000), key, int(prune_time_s * 1000))

     def _prune_message_counts(self) -> None:
         """Remove message count entries that have not exceeded their defined
@@ -272,18 +311,24 @@
         """
         time_now_s = self.clock.time()

-        # We create a copy of the key list here as the dictionary is modified during
-        # the loop
-        for key in list(self.actions.keys()):
-            action_count, time_start, rate_hz = self.actions[key]
+        # Pull out all the keys that *might* need pruning. We still need to
+        # verify they haven't since been updated.
+        to_prune = self._timer.fetch(int(time_now_s * 1000))
+
+        for key in to_prune:
+            value = self.actions.get(key)
+            if value is None:
+                continue
+
+            action_count, time_start, rate_hz = value

             # Rate limit = "seconds since we started limiting this action" * rate_hz
             # If this limit has not been exceeded, wipe our record of this action
             time_delta = time_now_s - time_start
             if action_count - time_delta * rate_hz > 0:
                 continue
-            else:
-                del self.actions[key]
+
+            del self.actions[key]

     async def ratelimit(
         self,
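
As a worked example of the scheduling arithmetic in `_record_action_inner`
(the numbers are illustrative): a bucket holding 3 actions with a leak rate
of 0.1 Hz is fully drained 3 / 0.1 = 30 seconds after `time_start`, so the
key is scheduled for pruning slightly after that point, in the timer's
millisecond units:

```python
time_start = 100.0  # seconds; when the bucket was last completely empty
action_count = 3    # tokens currently in the bucket
rate_hz = 0.1       # leak rate: one token drains every 10 seconds
time_now_s = 105.0

# The bucket is fully drained action_count / rate_hz = 30s after time_start.
prune_time_s = time_start + action_count / rate_hz  # ~130.0
assert prune_time_s > time_now_s  # still live: store it and schedule a prune

# Schedule the prune slightly *after* expiry, in milliseconds.
prune_time_s += 0.1
now_ms, then_ms = int(time_now_s * 1000), int(prune_time_s * 1000)
# self._timer.insert(now_ms, key, then_ms)  -> roughly (105000, key, 130100)
```

Since the prune loop now runs every 15 seconds rather than 60, the entry is
actually deleted at most about 15 seconds after it becomes prunable.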

tests/api/test_ratelimiting.py

@@ -228,6 +228,21 @@ class TestRatelimiter(unittest.HomeserverTestCase):

         self.assertNotIn("test_id_1", limiter.actions)

+    def test_pruning_record_action(self) -> None:
+        """Test that entries added by record_action also get pruned."""
+        limiter = Ratelimiter(
+            store=self.hs.get_datastores().main,
+            clock=self.clock,
+            cfg=RatelimitSettings(key="", per_second=0.1, burst_count=1),
+        )
+
+        limiter.record_action(None, key="test_id_1", n_actions=1, _time_now_s=0)
+        self.assertIn("test_id_1", limiter.actions)
+
+        self.reactor.advance(60)
+
+        self.assertNotIn("test_id_1", limiter.actions)
+
     def test_db_user_override(self) -> None:
         """Test that users that have ratelimiting disabled in the DB aren't
         ratelimited.

tests/federation/test_federation_server.py

@@ -462,7 +462,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
         )
         self.assertEqual(r[("m.room.member", joining_user)].membership, "join")

-    @override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 3}})
+    @override_config({"rc_joins_per_room": {"per_second": 0.1, "burst_count": 3}})
     def test_make_join_respects_room_join_rate_limit(self) -> None:
         # In the test setup, two users join the room. Since the rate limiter burst
         # count is 3, a new make_join request to the room should be accepted.

@@ -484,7 +484,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
         )
         self.assertEqual(channel.code, HTTPStatus.TOO_MANY_REQUESTS, channel.json_body)

-    @override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 3}})
+    @override_config({"rc_joins_per_room": {"per_second": 0.1, "burst_count": 3}})
     def test_send_join_contributes_to_room_join_rate_limit_and_is_limited(self) -> None:
         # Make two make_join requests up front. (These are rate limited, but do not
         # contribute to the rate limit.)
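
A note on why these test configs (here and in the next file) move from
`per_second: 0` to a small positive rate: `_record_action_inner` now computes
`action_count / rate_hz`, and a zero rate presumably cannot be supported any
more; conceptually, a bucket that never drains has no finite prune time. A
tiny illustration (not test code from this PR):

```python
time_start, action_count = 0.0, 2

for rate_hz in (0.1, 0.0):  # new vs. old per_second test setting
    try:
        prune_time_s = time_start + action_count / rate_hz
        print(f"rate_hz={rate_hz}: prunable at t={prune_time_s}s")
    except ZeroDivisionError:
        print(f"rate_hz={rate_hz}: no finite prune time (bucket never drains)")
```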

tests/handlers/test_room_member.py

@@ -50,7 +50,7 @@ class TestJoinsLimitedByPerRoomRateLimiter(FederatingHomeserverTestCase):
         self.intially_unjoined_room_id = f"!example:{self.OTHER_SERVER_NAME}"

-    @override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 2}})
+    @override_config({"rc_joins_per_room": {"per_second": 0.1, "burst_count": 2}})
     def test_local_user_local_joins_contribute_to_limit_and_are_limited(self) -> None:
         # The rate limiter has accumulated one token from Alice's join after the create
         # event.

@@ -76,7 +76,7 @@ class TestJoinsLimitedByPerRoomRateLimiter(FederatingHomeserverTestCase):
             by=0.5,
         )

-    @override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 2}})
+    @override_config({"rc_joins_per_room": {"per_second": 0.1, "burst_count": 2}})
     def test_local_user_profile_edits_dont_contribute_to_limit(self) -> None:
         # The rate limiter has accumulated one token from Alice's join after the create
         # event. Alice should still be able to change her displayname.

@@ -100,7 +100,7 @@ class TestJoinsLimitedByPerRoomRateLimiter(FederatingHomeserverTestCase):
             )
         )

-    @override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 1}})
+    @override_config({"rc_joins_per_room": {"per_second": 0.1, "burst_count": 1}})
     def test_remote_joins_contribute_to_rate_limit(self) -> None:
         # Join once, to fill the rate limiter bucket.
         #

@@ -248,7 +248,7 @@ class TestReplicatedJoinsLimitedByPerRoomRateLimiter(BaseMultiWorkerStreamTestCase):
         self.room_id = self.helper.create_room_as(self.alice, tok=self.alice_token)
         self.intially_unjoined_room_id = "!example:otherhs"

-    @override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 2}})
+    @override_config({"rc_joins_per_room": {"per_second": 0.01, "burst_count": 2}})
     def test_local_users_joining_on_another_worker_contribute_to_rate_limit(
         self,
     ) -> None: