Rate limiter: Introduce record_action
This commit is contained in:
@@ -209,6 +209,37 @@ class Ratelimiter:
|
||||
|
||||
return allowed, time_allowed
|
||||
|
||||
def record_action(
|
||||
self,
|
||||
requester: Optional[Requester],
|
||||
key: Optional[Hashable] = None,
|
||||
n_actions: int = 1,
|
||||
_time_now_s: Optional[float] = None,
|
||||
) -> None:
|
||||
"""Record that an action(s) took place, even if they violate the rate limit.
|
||||
|
||||
This is useful for tracking the frequency of events that happen across
|
||||
federation which we still want to impose local rate limits on. For instance, if
|
||||
we are alice.com monitoring a particular room, we cannot prevent bob.com
|
||||
from joining users to that room. However, we can track the number of recent
|
||||
joins in the room and refuse to serve new joins ourselves if there have been too
|
||||
many in the room across both homeservers.
|
||||
|
||||
Args:
|
||||
requester: The requester that is doing the action, if any.
|
||||
key: An arbitrary key used to classify an action. Defaults to the
|
||||
requester's user ID.
|
||||
n_actions: The number of times the user wants to do this action. If the user
|
||||
cannot do all of the actions, the user's action count is not incremented
|
||||
at all.
|
||||
_time_now_s: The current time. Optional, defaults to the current time according
|
||||
to self.clock. Only used by tests.
|
||||
"""
|
||||
key = self._get_key(requester, key)
|
||||
time_now_s = _time_now_s if _time_now_s is not None else self.clock.time()
|
||||
action_count, time_start, rate_hz = self._get_action_counts(key, time_now_s)
|
||||
self.actions[key] = (action_count + n_actions, time_start, rate_hz)
|
||||
|
||||
def _prune_message_counts(self, time_now_s: float) -> None:
|
||||
"""Remove message count entries that have not exceeded their defined
|
||||
rate_hz limit
|
||||
|
||||
@@ -314,3 +314,77 @@ class TestRatelimiter(unittest.HomeserverTestCase):
|
||||
|
||||
# Check that we get rate limited after using that token.
|
||||
self.assertFalse(consume_at(11.1))
|
||||
|
||||
def test_record_action_which_doesnt_fill_bucket(self) -> None:
|
||||
limiter = Ratelimiter(
|
||||
store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=3
|
||||
)
|
||||
|
||||
# Observe two actions, leaving room in the bucket for one more.
|
||||
limiter.record_action(requester=None, key="a", n_actions=2, _time_now_s=0.0)
|
||||
|
||||
# We should be able to take a new action now.
|
||||
success, _ = self.get_success_or_raise(
|
||||
limiter.can_do_action(requester=None, key="a", _time_now_s=0.0)
|
||||
)
|
||||
self.assertTrue(success)
|
||||
|
||||
# ... but not two.
|
||||
success, _ = self.get_success_or_raise(
|
||||
limiter.can_do_action(requester=None, key="a", _time_now_s=0.0)
|
||||
)
|
||||
self.assertFalse(success)
|
||||
|
||||
def test_record_action_which_fills_bucket(self) -> None:
|
||||
limiter = Ratelimiter(
|
||||
store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=3
|
||||
)
|
||||
|
||||
# Observe three actions, filling up the bucket.
|
||||
limiter.record_action(requester=None, key="a", n_actions=3, _time_now_s=0.0)
|
||||
|
||||
# We should be unable to take a new action now.
|
||||
success, _ = self.get_success_or_raise(
|
||||
limiter.can_do_action(requester=None, key="a", _time_now_s=0.0)
|
||||
)
|
||||
self.assertFalse(success)
|
||||
|
||||
# If we wait 10 seconds to leak a token, we should be able to take one action...
|
||||
success, _ = self.get_success_or_raise(
|
||||
limiter.can_do_action(requester=None, key="a", _time_now_s=10.0)
|
||||
)
|
||||
self.assertTrue(success)
|
||||
|
||||
# ... but not two.
|
||||
success, _ = self.get_success_or_raise(
|
||||
limiter.can_do_action(requester=None, key="a", _time_now_s=10.0)
|
||||
)
|
||||
self.assertFalse(success)
|
||||
|
||||
def test_record_action_which_overfills_bucket(self) -> None:
|
||||
limiter = Ratelimiter(
|
||||
store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=3
|
||||
)
|
||||
|
||||
# Observe four actions, exceeding the bucket.
|
||||
limiter.record_action(requester=None, key="a", n_actions=4, _time_now_s=0.0)
|
||||
|
||||
# We should be prevented from taking a new action now.
|
||||
success, _ = self.get_success_or_raise(
|
||||
limiter.can_do_action(requester=None, key="a", _time_now_s=0.0)
|
||||
)
|
||||
self.assertFalse(success)
|
||||
|
||||
# If we wait 10 seconds to leak a token, we should be unable to take an action
|
||||
# because the bucket is still full.
|
||||
success, _ = self.get_success_or_raise(
|
||||
limiter.can_do_action(requester=None, key="a", _time_now_s=10.0)
|
||||
)
|
||||
self.assertFalse(success)
|
||||
|
||||
# But after another 10 seconds we leak a second token, giving us room for
|
||||
# action.
|
||||
success, _ = self.get_success_or_raise(
|
||||
limiter.can_do_action(requester=None, key="a", _time_now_s=20.0)
|
||||
)
|
||||
self.assertTrue(success)
|
||||
|
||||
Reference in New Issue
Block a user