1
0

Rate limiter: Introduce record_action

This commit is contained in:
David Robertson
2022-06-28 15:50:32 +01:00
parent c2e3025b33
commit c594ab774b
2 changed files with 105 additions and 0 deletions

View File

@@ -209,6 +209,37 @@ class Ratelimiter:
return allowed, time_allowed
def record_action(
self,
requester: Optional[Requester],
key: Optional[Hashable] = None,
n_actions: int = 1,
_time_now_s: Optional[float] = None,
) -> None:
"""Record that an action(s) took place, even if they violate the rate limit.
This is useful for tracking the frequency of events that happen across
federation which we still want to impose local rate limits on. For instance, if
we are alice.com monitoring a particular room, we cannot prevent bob.com
from joining users to that room. However, we can track the number of recent
joins in the room and refuse to serve new joins ourselves if there have been too
many in the room across both homeservers.
Args:
requester: The requester that is doing the action, if any.
key: An arbitrary key used to classify an action. Defaults to the
requester's user ID.
n_actions: The number of times the user wants to do this action. If the user
cannot do all of the actions, the user's action count is not incremented
at all.
_time_now_s: The current time. Optional, defaults to the current time according
to self.clock. Only used by tests.
"""
key = self._get_key(requester, key)
time_now_s = _time_now_s if _time_now_s is not None else self.clock.time()
action_count, time_start, rate_hz = self._get_action_counts(key, time_now_s)
self.actions[key] = (action_count + n_actions, time_start, rate_hz)
def _prune_message_counts(self, time_now_s: float) -> None:
"""Remove message count entries that have not exceeded their defined
rate_hz limit

View File

@@ -314,3 +314,77 @@ class TestRatelimiter(unittest.HomeserverTestCase):
# Check that we get rate limited after using that token.
self.assertFalse(consume_at(11.1))
def test_record_action_which_doesnt_fill_bucket(self) -> None:
limiter = Ratelimiter(
store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=3
)
# Observe two actions, leaving room in the bucket for one more.
limiter.record_action(requester=None, key="a", n_actions=2, _time_now_s=0.0)
# We should be able to take a new action now.
success, _ = self.get_success_or_raise(
limiter.can_do_action(requester=None, key="a", _time_now_s=0.0)
)
self.assertTrue(success)
# ... but not two.
success, _ = self.get_success_or_raise(
limiter.can_do_action(requester=None, key="a", _time_now_s=0.0)
)
self.assertFalse(success)
def test_record_action_which_fills_bucket(self) -> None:
limiter = Ratelimiter(
store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=3
)
# Observe three actions, filling up the bucket.
limiter.record_action(requester=None, key="a", n_actions=3, _time_now_s=0.0)
# We should be unable to take a new action now.
success, _ = self.get_success_or_raise(
limiter.can_do_action(requester=None, key="a", _time_now_s=0.0)
)
self.assertFalse(success)
# If we wait 10 seconds to leak a token, we should be able to take one action...
success, _ = self.get_success_or_raise(
limiter.can_do_action(requester=None, key="a", _time_now_s=10.0)
)
self.assertTrue(success)
# ... but not two.
success, _ = self.get_success_or_raise(
limiter.can_do_action(requester=None, key="a", _time_now_s=10.0)
)
self.assertFalse(success)
def test_record_action_which_overfills_bucket(self) -> None:
limiter = Ratelimiter(
store=self.hs.get_datastores().main, clock=None, rate_hz=0.1, burst_count=3
)
# Observe four actions, exceeding the bucket.
limiter.record_action(requester=None, key="a", n_actions=4, _time_now_s=0.0)
# We should be prevented from taking a new action now.
success, _ = self.get_success_or_raise(
limiter.can_do_action(requester=None, key="a", _time_now_s=0.0)
)
self.assertFalse(success)
# If we wait 10 seconds to leak a token, we should be unable to take an action
# because the bucket is still full.
success, _ = self.get_success_or_raise(
limiter.can_do_action(requester=None, key="a", _time_now_s=10.0)
)
self.assertFalse(success)
# But after another 10 seconds we leak a second token, giving us room for
# action.
success, _ = self.get_success_or_raise(
limiter.can_do_action(requester=None, key="a", _time_now_s=20.0)
)
self.assertTrue(success)