Speed up pruning of ratelimiter (#19129)
I noticed this in some profiling. Currently we prune the ratelimiters by copying and iterating over every entry every 60 seconds. Instead, let's use a wheel timer to track when we should potentially prune a given key, so that we (a) check fewer keys and (b) can run more frequently. This should mean we no longer get a large pause every time we prune a ratelimiter with lots of keys. Also fixes a bug where we didn't prune entries that were added via `record_action` and never subsequently updated; this affected the media and joins-per-room ratelimiters.
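For illustration, here is a minimal, self-contained sketch of the wheel-timer idea. `SimpleWheelTimer` is a hypothetical stand-in written for this description, not Synapse's `synapse.util.wheel_timer.WheelTimer`; only the `insert(now_ms, key, then_ms)` / `fetch(now_ms)` call shapes mirror how the timer is used in this change.

```python
from collections import defaultdict
from collections.abc import Hashable


class SimpleWheelTimer:
    """Sketch: bucket keys by the time at which they *might* be prunable, so the
    periodic cleanup only inspects keys whose deadline has passed instead of
    copying and scanning every entry."""

    def __init__(self, bucket_ms: int = 5_000) -> None:
        self.bucket_ms = bucket_ms
        # Bucket index -> keys that become candidates for pruning in that bucket.
        self._buckets: defaultdict[int, list[Hashable]] = defaultdict(list)

    def insert(self, now_ms: int, key: Hashable, then_ms: int) -> None:
        # Never schedule in the past; precision doesn't matter for a cleanup job.
        self._buckets[max(then_ms, now_ms) // self.bucket_ms].append(key)

    def fetch(self, now_ms: int) -> list[Hashable]:
        # Return (and drop) every key whose bucket deadline has been reached.
        # Callers must still re-check each key: it may have been refreshed since.
        current = now_ms // self.bucket_ms
        keys: list[Hashable] = []
        for bucket in [b for b in self._buckets if b <= current]:
            keys.extend(self._buckets.pop(bucket))
        return keys


# Example: schedule a key ~10.1s out and collect it once we're past that point.
timer = SimpleWheelTimer()
timer.insert(0, "some_key", 10_100)
assert timer.fetch(9_000) == []
assert timer.fetch(16_000) == ["some_key"]
```

The pruning job can then call `fetch(...)` on each run and re-check only the returned keys against the leaky-bucket maths, instead of snapshotting and walking the whole `actions` dict.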
changelog.d/19129.misc (new file)
@@ -0,0 +1 @@
+Speed up pruning of ratelimiters.
@@ -27,6 +27,7 @@ from synapse.config.ratelimiting import RatelimitSettings
 from synapse.storage.databases.main import DataStore
 from synapse.types import Requester
 from synapse.util.clock import Clock
+from synapse.util.wheel_timer import WheelTimer

 if TYPE_CHECKING:
     # To avoid circular imports:
@@ -92,9 +93,14 @@ class Ratelimiter:
         # * The number of tokens currently in the bucket,
         # * The time point when the bucket was last completely empty, and
         # * The rate_hz (leak rate) of this particular bucket.
-        self.actions: dict[Hashable, tuple[float, float, float]] = {}
+        self.actions: dict[Hashable, tuple[int, float, float]] = {}

-        self.clock.looping_call(self._prune_message_counts, 60 * 1000)
+        # Records when actions should potentially be pruned. Note that we don't
+        # need to be accurate here, as this is just a cleanup job of `actions`
+        # and doesn't affect correctness.
+        self._timer: WheelTimer[Hashable] = WheelTimer()
+
+        self.clock.looping_call(self._prune_message_counts, 15 * 1000)

     def _get_key(
         self, requester: Optional[Requester], key: Optional[Hashable]
@@ -109,9 +115,9 @@ class Ratelimiter:

     def _get_action_counts(
         self, key: Hashable, time_now_s: float
-    ) -> tuple[float, float, float]:
+    ) -> tuple[int, float, float]:
         """Retrieve the action counts, with a fallback representing an empty bucket."""
-        return self.actions.get(key, (0.0, time_now_s, 0.0))
+        return self.actions.get(key, (0, time_now_s, self.rate_hz))

     async def can_do_action(
         self,
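A note on the fallback change above (an inference from the diff, not something it states explicitly): with the old fallback, a key first seen via `record_action` was stored with `rate_hz == 0.0`, so the prune check in `_prune_message_counts` could never succeed for it, which matches the never-pruned entries described in the summary. The new fallback uses the limiter's configured `self.rate_hz` instead.

```python
# Illustration with made-up numbers: the prune condition keeps an entry while
# `action_count - time_delta * rate_hz > 0`. With rate_hz == 0.0 that is true
# for any positive action_count, no matter how much time has passed.
action_count, time_delta, rate_hz = 1.0, 1_000_000.0, 0.0
keep = action_count - time_delta * rate_hz > 0
print(keep)  # True -> the entry would never be pruned
```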
@@ -217,8 +223,11 @@ class Ratelimiter:
             allowed = True
             action_count = action_count + n_actions

-        if update:
-            self.actions[key] = (action_count, time_start, rate_hz)
+        # Only record the action if we're allowed to perform it.
+        if allowed and update:
+            self._record_action_inner(
+                key, action_count, time_start, rate_hz, time_now_s
+            )

         if rate_hz > 0:
             # Find out when the count of existing actions expires
@@ -264,7 +273,37 @@ class Ratelimiter:
         key = self._get_key(requester, key)
         time_now_s = _time_now_s if _time_now_s is not None else self.clock.time()
         action_count, time_start, rate_hz = self._get_action_counts(key, time_now_s)
-        self.actions[key] = (action_count + n_actions, time_start, rate_hz)
+        self._record_action_inner(
+            key, action_count + n_actions, time_start, rate_hz, time_now_s
+        )
+
+    def _record_action_inner(
+        self,
+        key: Hashable,
+        action_count: int,
+        time_start: float,
+        rate_hz: float,
+        time_now_s: float,
+    ) -> None:
+        """Helper to atomically update the action count for a given key."""
+        prune_time_s = time_start + action_count / rate_hz
+
+        # If the prune time is in the past, we can just remove the entry rather
+        # than inserting and immediately pruning.
+        if prune_time_s <= time_now_s:
+            self.actions.pop(key, None)
+            return
+
+        self.actions[key] = (action_count, time_start, rate_hz)
+
+        # We need to make sure that we only call prune *after* the entry
+        # expires, otherwise the scheduled prune may not actually prune it. This
+        # is just a cleanup job, so it doesn't matter if entries aren't pruned
+        # immediately after they expire. Hence we schedule the prune a little
+        # after the entry is due to expire.
+        prune_time_s += 0.1
+
+        self._timer.insert(int(time_now_s * 1000), key, int(prune_time_s * 1000))

     def _prune_message_counts(self) -> None:
         """Remove message count entries that have not exceeded their defined
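As a worked example of the scheduling arithmetic in `_record_action_inner` (illustrative numbers, not taken from the diff): with `rate_hz = 0.1` and a single action recorded at `time_start = 0`, the bucket drains at `0 + 1 / 0.1 = 10` seconds, so the key is inserted into the wheel timer roughly 10.1 seconds out; exactly when it is re-examined then depends on the 15-second pruning loop.

```python
# Illustrative numbers, mirroring the computation in _record_action_inner.
time_start, action_count, rate_hz, time_now_s = 0.0, 1, 0.1, 0.0

prune_time_s = time_start + action_count / rate_hz  # bucket is empty at ~t=10s
prune_time_s += 0.1  # schedule slightly after expiry so the re-check can succeed

# Converted to the millisecond arguments the real insert() call uses:
now_ms, then_ms = int(time_now_s * 1000), int(prune_time_s * 1000)
print(now_ms, then_ms)  # 0 10100 (roughly; the exact value is FP-rounded)
```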
@@ -272,17 +311,23 @@ class Ratelimiter:
         """
         time_now_s = self.clock.time()

-        # We create a copy of the key list here as the dictionary is modified during
-        # the loop
-        for key in list(self.actions.keys()):
-            action_count, time_start, rate_hz = self.actions[key]
+        # Pull out all the keys that *might* need pruning. We still need to
+        # verify they haven't since been updated.
+        to_prune = self._timer.fetch(int(time_now_s * 1000))
+
+        for key in to_prune:
+            value = self.actions.get(key)
+            if value is None:
+                continue
+
+            action_count, time_start, rate_hz = value

             # Rate limit = "seconds since we started limiting this action" * rate_hz
             # If this limit has not been exceeded, wipe our record of this action
             time_delta = time_now_s - time_start
             if action_count - time_delta * rate_hz > 0:
                 continue
-            else:
-                del self.actions[key]
+
+            del self.actions[key]

     async def ratelimit(
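To see why the loop above re-checks each fetched key (hypothetical timeline, not from the diff): if the same key is recorded again before its first scheduled prune time, the entry drains later than the stale timer entry suggests, so the earlier fetch must leave it alone and only a later one removes it.

```python
# Hypothetical timeline for one key with rate_hz = 0.1:
rate_hz = 0.1
# t=0:  record 1 action -> entry (1, 0.0, 0.1), timer deadline ~10.1s
# t=5:  record again    -> entry (2, 0.0, 0.1), new timer deadline ~20.1s
action_count, time_start = 2, 0.0

# The stale ~10.1s deadline still fires; the re-check keeps the entry:
assert action_count - (10.1 - time_start) * rate_hz > 0  # 0.99 > 0 -> keep

# The ~20.1s deadline finds the bucket drained; the entry is deleted:
assert not (action_count - (20.1 - time_start) * rate_hz > 0)  # < 0 -> prune
```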
@@ -228,6 +228,21 @@ class TestRatelimiter(unittest.HomeserverTestCase):

         self.assertNotIn("test_id_1", limiter.actions)

+    def test_pruning_record_action(self) -> None:
+        """Test that entries added by record_action also get pruned."""
+        limiter = Ratelimiter(
+            store=self.hs.get_datastores().main,
+            clock=self.clock,
+            cfg=RatelimitSettings(key="", per_second=0.1, burst_count=1),
+        )
+        limiter.record_action(None, key="test_id_1", n_actions=1, _time_now_s=0)
+
+        self.assertIn("test_id_1", limiter.actions)
+
+        self.reactor.advance(60)
+
+        self.assertNotIn("test_id_1", limiter.actions)
+
     def test_db_user_override(self) -> None:
         """Test that users that have ratelimiting disabled in the DB aren't
         ratelimited.
@@ -462,7 +462,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
         )
         self.assertEqual(r[("m.room.member", joining_user)].membership, "join")

-    @override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 3}})
+    @override_config({"rc_joins_per_room": {"per_second": 0.1, "burst_count": 3}})
     def test_make_join_respects_room_join_rate_limit(self) -> None:
         # In the test setup, two users join the room. Since the rate limiter burst
         # count is 3, a new make_join request to the room should be accepted.
@@ -484,7 +484,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
         )
         self.assertEqual(channel.code, HTTPStatus.TOO_MANY_REQUESTS, channel.json_body)

-    @override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 3}})
+    @override_config({"rc_joins_per_room": {"per_second": 0.1, "burst_count": 3}})
     def test_send_join_contributes_to_room_join_rate_limit_and_is_limited(self) -> None:
         # Make two make_join requests up front. (These are rate limited, but do not
         # contribute to the rate limit.)
@@ -50,7 +50,7 @@ class TestJoinsLimitedByPerRoomRateLimiter(FederatingHomeserverTestCase):

         self.intially_unjoined_room_id = f"!example:{self.OTHER_SERVER_NAME}"

-    @override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 2}})
+    @override_config({"rc_joins_per_room": {"per_second": 0.1, "burst_count": 2}})
     def test_local_user_local_joins_contribute_to_limit_and_are_limited(self) -> None:
         # The rate limiter has accumulated one token from Alice's join after the create
         # event.
@@ -76,7 +76,7 @@ class TestJoinsLimitedByPerRoomRateLimiter(FederatingHomeserverTestCase):
             by=0.5,
         )

-    @override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 2}})
+    @override_config({"rc_joins_per_room": {"per_second": 0.1, "burst_count": 2}})
     def test_local_user_profile_edits_dont_contribute_to_limit(self) -> None:
         # The rate limiter has accumulated one token from Alice's join after the create
         # event. Alice should still be able to change her displayname.
@@ -100,7 +100,7 @@ class TestJoinsLimitedByPerRoomRateLimiter(FederatingHomeserverTestCase):
             )
         )

-    @override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 1}})
+    @override_config({"rc_joins_per_room": {"per_second": 0.1, "burst_count": 1}})
     def test_remote_joins_contribute_to_rate_limit(self) -> None:
         # Join once, to fill the rate limiter bucket.
         #
@@ -248,7 +248,7 @@ class TestReplicatedJoinsLimitedByPerRoomRateLimiter(BaseMultiWorkerStreamTestCa
         self.room_id = self.helper.create_room_as(self.alice, tok=self.alice_token)
         self.intially_unjoined_room_id = "!example:otherhs"

-    @override_config({"rc_joins_per_room": {"per_second": 0, "burst_count": 2}})
+    @override_config({"rc_joins_per_room": {"per_second": 0.01, "burst_count": 2}})
     def test_local_users_joining_on_another_worker_contribute_to_rate_limit(
         self,
     ) -> None: