Compare commits

...

8 Commits

Author SHA1 Message Date
Andrew Morgan
27e4157727 wip 2020-06-01 17:04:52 +01:00
Andrew Morgan
f6203a60e0 Make rate_hz and burst_count overridable per-request 2020-05-29 18:35:19 +01:00
Andrew Morgan
c322ba00a3 changelog 2020-05-28 22:57:21 +01:00
Andrew Morgan
6a07c2d9ad lint 2020-05-28 22:53:23 +01:00
Andrew Morgan
a0ef594905 Update unittests 2020-05-28 22:53:23 +01:00
Andrew Morgan
82eac22286 Modify servlets to pull Ratelimiters from HomeServer class 2020-05-28 22:53:23 +01:00
Andrew Morgan
0e6ee7ca17 Ratelimiters are instantiated by the HomeServer class
This makes it simple for tests to modify/nullify them.
2020-05-28 22:53:23 +01:00
Andrew Morgan
4f715beebf Refactor and comment ratelimiting. Set limits in constructor 2020-05-28 22:53:21 +01:00
19 changed files with 307 additions and 219 deletions

1
changelog.d/7595.misc Normal file
View File

@@ -0,0 +1 @@
Refactor `Ratelimiter` and try to limit the amount of related, expensive config value accesses.

View File

@@ -16,75 +16,161 @@ from collections import OrderedDict
from typing import Any, Optional, Tuple
from synapse.api.errors import LimitExceededError
from synapse.util import Clock
import logging
logger = logging.getLogger(__name__)
class Ratelimiter(object):
"""
Ratelimit message sending by user.
Ratelimit actions marked by arbitrary keys.
Args:
clock: A homeserver clock, for retrieving the current time
rate_hz: The long term number of actions that can be performed in a second.
burst_count: How many actions that can be performed before being limited.
"""
def __init__(self):
self.message_counts = (
OrderedDict()
) # type: OrderedDict[Any, Tuple[float, int, Optional[float]]]
def __init__(self, clock: Clock, rate_hz: float, burst_count: int):
self.clock = clock
self.rate_hz = rate_hz
self.burst_count = burst_count
def can_do_action(self, key, time_now_s, rate_hz, burst_count, update=True):
# A ordered dictionary keeping track of actions, when they were last
# performed and how often. Each entry is a mapping from a key of arbitrary type
# to a tuple representing:
# * How many times an action has occurred since a point in time
# * The point in time
# * The rate_hz of this particular entry. This can vary per-request
self.actions = (
OrderedDict()
) # type: OrderedDict[Any, Tuple[float, int, Optional[float]]]
def can_do_action(
self,
key: Any,
time_now_s: Optional[int] = None,
update: bool = True,
rate_hz: Optional[float] = None,
burst_count: Optional[int] = None,
) -> Tuple[bool, float]:
"""Can the entity (e.g. user or IP address) perform the action?
Args:
key: The key we should use when rate limiting. Can be a user ID
(when sending events), an IP address, etc.
time_now_s: The time now.
rate_hz: The long term number of messages a user can send in a
second.
burst_count: How many messages the user can send before being
limited.
update (bool): Whether to update the message rates or not. This is
useful to check if a message would be allowed to be sent before
its ready to be actually sent.
time_now_s: The current time. Optional, defaults to the current time according
to self.clock. Pretty much only used for tests.
update: Whether to count this check as performing the action
rate_hz: The long term number of actions that can be performed in a second.
Overrides the value set during instantiation if set.
burst_count: How many actions that can be performed before being limited.
Overrides the value set during instantiation if set.
Returns:
A pair of a bool indicating if they can send a message now and a
time in seconds of when they can next send a message.
A tuple containing:
* A bool indicating if they can perform the action now
* The time in seconds of when it can next be performed.
-1 if a rate_hz has not been defined for this Ratelimiter
"""
self.prune_message_counts(time_now_s)
message_count, time_start, _ignored = self.message_counts.get(
key, (0.0, time_now_s, None)
)
# Override default values if set
time_now_s = time_now_s if time_now_s is not None else self.clock.time()
rate_hz = rate_hz if rate_hz is not None else self.rate_hz
burst_count = burst_count if burst_count is not None else self.burst_count
# Remove any expired entries
self._prune_message_counts(time_now_s)
# Check if there is an existing count entry for this key
action_count, time_start, _ = self.actions.get(key, (0.0, time_now_s, None))
# Check whether performing another action is allowed
time_delta = time_now_s - time_start
sent_count = message_count - time_delta * rate_hz
if sent_count < 0:
performed_count = action_count - time_delta * rate_hz
if performed_count < 0:
# Allow, reset back to count 1
allowed = True
time_start = time_now_s
message_count = 1.0
elif sent_count > burst_count - 1.0:
action_count = 1.0
elif performed_count > burst_count - 1.0:
# Deny, we have exceeded our burst count
allowed = False
else:
# We haven't reached our limit yet
allowed = True
message_count += 1
action_count += 1.0
if update:
self.message_counts[key] = (message_count, time_start, rate_hz)
self.actions[key] = (action_count, time_start, rate_hz)
if rate_hz > 0:
time_allowed = time_start + (message_count - burst_count + 1) / rate_hz
logger.info("rate and burst: %s %s. performed_count: %s, allowed: %s", rate_hz,
burst_count, performed_count, allowed)
# Figure out the time when an action can be performed again
if self.rate_hz > 0:
time_allowed = time_start + (action_count - burst_count + 1) / rate_hz
# Don't give back a time in the past
if time_allowed < time_now_s:
time_allowed = time_now_s
else:
# This does not apply
time_allowed = -1
return allowed, time_allowed
def prune_message_counts(self, time_now_s):
for key in list(self.message_counts.keys()):
message_count, time_start, rate_hz = self.message_counts[key]
time_delta = time_now_s - time_start
if message_count - time_delta * rate_hz > 0:
break
else:
del self.message_counts[key]
def _prune_message_counts(self, time_now_s: int):
"""Remove message count entries that have not exceeded their defined
rate_hz limit
Args:
time_now_s: The current time
"""
# We create a copy of the key list here as the dictionary is modified during
# the loop
for key in list(self.actions.keys()):
action_count, time_start, rate_hz = self.actions[key]
# Rate limit = "seconds since we started limiting this action" * rate_hz
# If this limit has not been exceeded, wipe our record of this action
time_delta = time_now_s - time_start
if action_count - time_delta * rate_hz > 0:
continue
else:
del self.actions[key]
def ratelimit(
self,
key: Any,
time_now_s: Optional[int] = None,
update: bool = True,
rate_hz: Optional[float] = None,
burst_count: Optional[int] = None,
):
"""Checks if an action can be performed. If not, raises a LimitExceededError
Args:
key: An arbitrary key used to classify an action
time_now_s: The current time. Optional, defaults to the current time according
to self.clock. Pretty much only used for tests.
update: Whether to count this check as performing the action
rate_hz: The long term number of actions that can be performed in a second.
Overrides the value set during instantiation if set.
burst_count: How many actions that can be performed before being limited.
Overrides the value set during instantiation if set.
Raises:
LimitExceededError: If an action could not be performed, along with the time in
milliseconds until the action can be performed again
"""
# Override default values if set
time_now_s = time_now_s if time_now_s is not None else self.clock.time()
rate_hz = rate_hz if rate_hz is not None else self.rate_hz
burst_count = burst_count if burst_count is not None else self.burst_count
def ratelimit(self, key, time_now_s, rate_hz, burst_count, update=True):
allowed, time_allowed = self.can_do_action(
key, time_now_s, rate_hz, burst_count, update
key, time_now_s, update=update, rate_hz=rate_hz, burst_count=burst_count
)
if not allowed:

View File

@@ -12,11 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict
from ._base import Config
class RateLimitConfig(object):
def __init__(self, config, defaults={"per_second": 0.17, "burst_count": 3.0}):
def __init__(
self,
config: Dict[str, float],
defaults={"per_second": 0.17, "burst_count": 3.0},
):
self.per_second = config.get("per_second", defaults["per_second"])
self.burst_count = config.get("burst_count", defaults["burst_count"])

View File

@@ -19,8 +19,8 @@ from twisted.internet import defer
import synapse.types
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import LimitExceededError
from synapse.types import UserID
from synapse.api.ratelimiting import Ratelimiter
logger = logging.getLogger(__name__)
@@ -44,11 +44,24 @@ class BaseHandler(object):
self.notifier = hs.get_notifier()
self.state_handler = hs.get_state_handler()
self.distributor = hs.get_distributor()
self.ratelimiter = hs.get_ratelimiter()
self.admin_redaction_ratelimiter = hs.get_admin_redaction_ratelimiter()
self.clock = hs.get_clock()
self.hs = hs
# The rate_hz and burst_count are overridden on a per-user basis
self.request_ratelimiter = Ratelimiter(clock=self.clock, rate_hz=0, burst_count=0)
self._rc_message = self.hs.config.rc_message
# Check whether ratelimiting room admin message redaction is enabled
# by the presence of rate limits in the config
if self.hs.config.rc_admin_redaction:
self.admin_redaction_ratelimiter = Ratelimiter(
clock=self.clock,
rate_hz=self.hs.config.rc_admin_redaction.per_second,
burst_count=self.hs.config.rc_admin_redaction.burst_count,
)
else:
self.admin_redaction_ratelimiter = None
self.server_name = hs.hostname
self.event_builder_factory = hs.get_event_builder_factory()
@@ -83,47 +96,32 @@ class BaseHandler(object):
if requester.app_service and not requester.app_service.is_rate_limited():
return
messages_per_second = self._rc_message.per_second
burst_count = self._rc_message.burst_count
# Check if there is a per user override in the DB.
override = yield self.store.get_ratelimit_for_user(user_id)
if override:
# If overriden with a null Hz then ratelimiting has been entirely
# If overridden with a null Hz then ratelimiting has been entirely
# disabled for the user
if not override.messages_per_second:
return
messages_per_second = override.messages_per_second
burst_count = override.burst_count
else:
# We default to different values if this is an admin redaction and
# the config is set
if is_admin_redaction and self.hs.config.rc_admin_redaction:
messages_per_second = self.hs.config.rc_admin_redaction.per_second
burst_count = self.hs.config.rc_admin_redaction.burst_count
else:
messages_per_second = self.hs.config.rc_message.per_second
burst_count = self.hs.config.rc_message.burst_count
if is_admin_redaction and self.hs.config.rc_admin_redaction:
# If we have separate config for admin redactions we use a separate
# ratelimiter
allowed, time_allowed = self.admin_redaction_ratelimiter.can_do_action(
user_id,
time_now,
rate_hz=messages_per_second,
burst_count=burst_count,
update=update,
)
if is_admin_redaction and self.admin_redaction_ratelimiter:
# If we have separate config for admin redactions, use a separate
# ratelimiter as to not have user_ids clash
self.admin_redaction_ratelimiter.ratelimit(user_id, time_now, update)
else:
allowed, time_allowed = self.ratelimiter.can_do_action(
# Override rate and burst count per-user
self.request_ratelimiter.ratelimit(
user_id,
time_now,
update,
rate_hz=messages_per_second,
burst_count=burst_count,
update=update,
)
if not allowed:
raise LimitExceededError(
retry_after_ms=int(1000 * (time_allowed - time_now))
)
async def maybe_kick_guest_users(self, event, context=None):

View File

@@ -108,7 +108,12 @@ class AuthHandler(BaseHandler):
# Ratelimiter for failed auth during UIA. Uses same ratelimit config
# as per `rc_login.failed_attempts`.
self._failed_uia_attempts_ratelimiter = Ratelimiter()
# XXX: Should this be hs.get_login_failed_attempts_ratelimiter?
self._failed_uia_attempts_ratelimiter = Ratelimiter(
clock=self.clock,
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
)
self._clock = self.hs.get_clock()
@@ -196,13 +201,7 @@ class AuthHandler(BaseHandler):
user_id = requester.user.to_string()
# Check if we should be ratelimited due to too many previous failed attempts
self._failed_uia_attempts_ratelimiter.ratelimit(
user_id,
time_now_s=self._clock.time(),
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
update=False,
)
self._failed_uia_attempts_ratelimiter.ratelimit(user_id, update=False)
# build a list of supported flows
flows = [[login_type] for login_type in self._supported_ui_auth_types]
@@ -212,14 +211,8 @@ class AuthHandler(BaseHandler):
flows, request, request_body, clientip, description
)
except LoginError:
# Update the ratelimite to say we failed (`can_do_action` doesn't raise).
self._failed_uia_attempts_ratelimiter.can_do_action(
user_id,
time_now_s=self._clock.time(),
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
update=True,
)
# Update the ratelimiter to say we failed (`can_do_action` doesn't raise).
self._failed_uia_attempts_ratelimiter.can_do_action(user_id)
raise
# find the completed login type

View File

@@ -362,7 +362,6 @@ class EventCreationHandler(object):
self.profile_handler = hs.get_profile_handler()
self.event_builder_factory = hs.get_event_builder_factory()
self.server_name = hs.hostname
self.ratelimiter = hs.get_ratelimiter()
self.notifier = hs.get_notifier()
self.config = hs.config
self.require_membership_for_aliases = hs.config.require_membership_for_aliases

View File

@@ -427,12 +427,7 @@ class RegistrationHandler(BaseHandler):
time_now = self.clock.time()
self.ratelimiter.ratelimit(
address,
time_now_s=time_now,
rate_hz=self.hs.config.rc_registration.per_second,
burst_count=self.hs.config.rc_registration.burst_count,
)
self.ratelimiter.ratelimit(address)
def register_with_store(
self,

View File

@@ -16,7 +16,6 @@
import logging
from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.api.ratelimiting import Ratelimiter
from synapse.http.server import finish_request
from synapse.http.servlet import (
RestServlet,
@@ -28,6 +27,7 @@ from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.rest.well_known import WellKnownBuilder
from synapse.types import UserID
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.api.ratelimiting import Ratelimiter
logger = logging.getLogger(__name__)
@@ -87,11 +87,28 @@ class LoginRestServlet(RestServlet):
self.auth_handler = self.hs.get_auth_handler()
self.registration_handler = hs.get_registration_handler()
self.handlers = hs.get_handlers()
self._clock = hs.get_clock()
self._well_known_builder = WellKnownBuilder(hs)
self._address_ratelimiter = Ratelimiter()
self._account_ratelimiter = Ratelimiter()
self._failed_attempts_ratelimiter = Ratelimiter()
self._address_ratelimiter = Ratelimiter(
clock=hs.get_clock(),
rate_hz=self.hs.config.rc_login_address.per_second,
burst_count=self.hs.config.rc_login_address.burst_count,
)
self._account_ratelimiter = Ratelimiter(
clock=hs.get_clock(),
rate_hz=self.hs.config.rc_login_account.per_second,
burst_count=self.hs.config.rc_login_account.burst_count,
)
print(
"Creating fail ratelimiter: %s %s" % (
self.hs.config.rc_login_failed_attempts.per_second,
self.hs.config.rc_login_failed_attempts.burst_count,
),
)
self._failed_attempts_ratelimiter = Ratelimiter(
clock=hs.get_clock(),
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
)
def on_GET(self, request):
flows = []
@@ -129,13 +146,7 @@ class LoginRestServlet(RestServlet):
return 200, {}
async def on_POST(self, request):
self._address_ratelimiter.ratelimit(
request.getClientIP(),
time_now_s=self.hs.clock.time(),
rate_hz=self.hs.config.rc_login_address.per_second,
burst_count=self.hs.config.rc_login_address.burst_count,
update=True,
)
self._address_ratelimiter.ratelimit(request.getClientIP())
login_submission = parse_json_object_from_request(request)
try:
@@ -203,13 +214,7 @@ class LoginRestServlet(RestServlet):
# We also apply account rate limiting using the 3PID as a key, as
# otherwise using 3PID bypasses the ratelimiting based on user ID.
self._failed_attempts_ratelimiter.ratelimit(
(medium, address),
time_now_s=self._clock.time(),
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
update=False,
)
self._failed_attempts_ratelimiter.ratelimit((medium, address), update=False)
# Check for login providers that support 3pid login types
(
@@ -243,13 +248,7 @@ class LoginRestServlet(RestServlet):
# If it returned None but the 3PID was bound then we won't hit
# this code path, which is fine as then the per-user ratelimit
# will kick in below.
self._failed_attempts_ratelimiter.can_do_action(
(medium, address),
time_now_s=self._clock.time(),
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
update=True,
)
self._failed_attempts_ratelimiter.can_do_action((medium, address))
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
identifier = {"type": "m.id.user", "user": user_id}
@@ -267,13 +266,7 @@ class LoginRestServlet(RestServlet):
qualified_user_id = UserID(identifier["user"], self.hs.hostname).to_string()
# Check if we've hit the failed ratelimit (but don't update it)
self._failed_attempts_ratelimiter.ratelimit(
qualified_user_id.lower(),
time_now_s=self._clock.time(),
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
update=False,
)
self._failed_attempts_ratelimiter.ratelimit(qualified_user_id.lower(), update=False)
try:
canonical_user_id, callback = await self.auth_handler.validate_login(
@@ -284,13 +277,7 @@ class LoginRestServlet(RestServlet):
# limiter. Using `can_do_action` avoids us raising a ratelimit
# exception and masking the LoginError. The actual ratelimiting
# should have happened above.
self._failed_attempts_ratelimiter.can_do_action(
qualified_user_id.lower(),
time_now_s=self._clock.time(),
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
update=True,
)
self._failed_attempts_ratelimiter.can_do_action(qualified_user_id.lower())
raise
result = await self._complete_login(
@@ -323,13 +310,7 @@ class LoginRestServlet(RestServlet):
# Before we actually log them in we check if they've already logged in
# too often. This happens here rather than before as we don't
# necessarily know the user before now.
self._account_ratelimiter.ratelimit(
user_id.lower(),
time_now_s=self._clock.time(),
rate_hz=self.hs.config.rc_login_account.per_second,
burst_count=self.hs.config.rc_login_account.burst_count,
update=True,
)
self._account_ratelimiter.ratelimit(user_id.lower())
if create_non_existant_users:
user_id = await self.auth_handler.check_user_exists(user_id)

View File

@@ -396,20 +396,7 @@ class RegisterRestServlet(RestServlet):
client_addr = request.getClientIP()
time_now = self.clock.time()
allowed, time_allowed = self.ratelimiter.can_do_action(
client_addr,
time_now_s=time_now,
rate_hz=self.hs.config.rc_registration.per_second,
burst_count=self.hs.config.rc_registration.burst_count,
update=False,
)
if not allowed:
raise LimitExceededError(
retry_after_ms=int(1000 * (time_allowed - time_now))
)
self.ratelimiter.ratelimit(client_addr, update=False)
kind = b"user"
if b"kind" in request.args:

View File

@@ -24,6 +24,7 @@
import abc
import logging
import os
from typing import Optional
from twisted.mail.smtp import sendmail
@@ -242,9 +243,12 @@ class HomeServer(object):
self.clock = Clock(reactor)
self.distributor = Distributor()
self.ratelimiter = Ratelimiter()
self.admin_redaction_ratelimiter = Ratelimiter()
self.registration_ratelimiter = Ratelimiter()
self.registration_ratelimiter = Ratelimiter(
clock=self.clock,
rate_hz=config.rc_registration.per_second,
burst_count=config.rc_registration.burst_count,
)
self.datastores = None
@@ -314,15 +318,9 @@ class HomeServer(object):
def get_distributor(self):
return self.distributor
def get_ratelimiter(self):
return self.ratelimiter
def get_registration_ratelimiter(self):
def get_registration_ratelimiter(self) -> Ratelimiter:
return self.registration_ratelimiter
def get_admin_redaction_ratelimiter(self):
return self.admin_redaction_ratelimiter
def build_federation_client(self):
return FederationClient(self)

View File

@@ -43,7 +43,7 @@ class FederationRateLimiter(object):
self.ratelimiters = collections.defaultdict(new_limiter)
def ratelimit(self, host):
"""Used to ratelimit an incoming request from given host
"""Used to ratelimit an incoming request from a given host
Example usage:

View File

@@ -5,35 +5,25 @@ from tests import unittest
class TestRatelimiter(unittest.TestCase):
def test_allowed(self):
limiter = Ratelimiter()
allowed, time_allowed = limiter.can_do_action(
key="test_id", time_now_s=0, rate_hz=0.1, burst_count=1
)
limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
allowed, time_allowed = limiter.can_do_action(key="test_id", time_now_s=0)
self.assertTrue(allowed)
self.assertEquals(10.0, time_allowed)
allowed, time_allowed = limiter.can_do_action(
key="test_id", time_now_s=5, rate_hz=0.1, burst_count=1
)
allowed, time_allowed = limiter.can_do_action(key="test_id", time_now_s=5)
self.assertFalse(allowed)
self.assertEquals(10.0, time_allowed)
allowed, time_allowed = limiter.can_do_action(
key="test_id", time_now_s=10, rate_hz=0.1, burst_count=1
)
allowed, time_allowed = limiter.can_do_action(key="test_id", time_now_s=10)
self.assertTrue(allowed)
self.assertEquals(20.0, time_allowed)
def test_pruning(self):
limiter = Ratelimiter()
allowed, time_allowed = limiter.can_do_action(
key="test_id_1", time_now_s=0, rate_hz=0.1, burst_count=1
)
limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
_, _ = limiter.can_do_action(key="test_id_1", time_now_s=0)
self.assertIn("test_id_1", limiter.message_counts)
self.assertIn("test_id_1", limiter.actions)
allowed, time_allowed = limiter.can_do_action(
key="test_id_2", time_now_s=10, rate_hz=0.1, burst_count=1
)
_, _ = limiter.can_do_action(key="test_id_2", time_now_s=10)
self.assertNotIn("test_id_1", limiter.message_counts)
self.assertNotIn("test_id_1", limiter.actions)

View File

@@ -14,12 +14,13 @@
# limitations under the License.
from mock import Mock, NonCallableMock
from mock import Mock, patch
from twisted.internet import defer
import synapse.types
from synapse.api.errors import AuthError, SynapseError
from synapse.api.ratelimiting import Ratelimiter
from synapse.handlers.profile import MasterProfileHandler
from synapse.types import UserID
@@ -55,11 +56,15 @@ class ProfileTestCase(unittest.TestCase):
federation_client=self.mock_federation,
federation_server=Mock(),
federation_registry=self.mock_registry,
ratelimiter=NonCallableMock(spec_set=["can_do_action"]),
)
self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.can_do_action.return_value = (True, 0)
# Patch Ratelimiter to allow all requests
patch.object(
Ratelimiter, "can_do_action", new_callable=lambda *args, **kwargs: (True, 0.0)
)
patch.object(
Ratelimiter, "ratelimit", new_callable=lambda *args, **kwargs: None
)
self.store = hs.get_datastore()

View File

@@ -13,8 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from mock import Mock, NonCallableMock
from mock import Mock, patch
from synapse.api.ratelimiting import Ratelimiter
from tests.replication._base import BaseStreamTestCase
@@ -23,10 +24,15 @@ class BaseSlavedStoreTestCase(BaseStreamTestCase):
hs = self.setup_test_homeserver(
federation_client=Mock(),
ratelimiter=NonCallableMock(spec_set=["can_do_action"]),
)
hs.get_ratelimiter().can_do_action.return_value = (True, 0)
# Patch Ratelimiter to allow all requests
patch.object(
Ratelimiter, "can_do_action", new_callable=lambda *args, **kwargs: (True, 0.0)
)
patch.object(
Ratelimiter, "ratelimit", new_callable=lambda *args, **kwargs: None
)
return hs

View File

@@ -15,10 +15,11 @@
""" Tests REST events for /events paths."""
from mock import Mock, NonCallableMock
from mock import Mock, patch
import synapse.rest.admin
from synapse.rest.client.v1 import events, login, room
from synapse.api.ratelimiting import Ratelimiter
from tests import unittest
@@ -41,10 +42,16 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
config["auto_join_rooms"] = []
hs = self.setup_test_homeserver(
config=config, ratelimiter=NonCallableMock(spec_set=["can_do_action"])
config=config,
)
# Patch Ratelimiter to allow all requests
patch.object(
Ratelimiter, "can_do_action", new_callable=lambda *args, **kwargs: (True, 0.0)
)
patch.object(
Ratelimiter, "ratelimit", new_callable=lambda *args, **kwargs: None
)
self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.can_do_action.return_value = (True, 0)
hs.get_handlers().federation_handler = Mock()

View File

@@ -7,6 +7,7 @@ import synapse.rest.admin
from synapse.rest.client.v1 import login, logout
from synapse.rest.client.v2_alpha import devices
from synapse.rest.client.v2_alpha.account import WhoamiRestServlet
from synapse.api.ratelimiting import Ratelimiter
from tests import unittest
from tests.unittest import override_config
@@ -26,7 +27,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
]
def make_homeserver(self, reactor, clock):
self.hs = self.setup_test_homeserver()
self.hs.config.enable_registration = True
self.hs.config.registrations_require_3pid = []
@@ -35,10 +35,17 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
return self.hs
@override_config(
{
"rc_login": {
"account": {
"per_second": 0.17,
"burst_count": 5,
}
}
}
)
def test_POST_ratelimiting_per_address(self):
self.hs.config.rc_login_address.burst_count = 5
self.hs.config.rc_login_address.per_second = 0.17
# Create different users so we're sure not to be bothered by the per-user
# ratelimiter.
for i in range(0, 6):
@@ -77,10 +84,17 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result)
@override_config(
{
"rc_login": {
"account": {
"per_second": 0.17,
"burst_count": 5,
}
}
}
)
def test_POST_ratelimiting_per_account(self):
self.hs.config.rc_login_account.burst_count = 5
self.hs.config.rc_login_account.per_second = 0.17
self.register_user("kermit", "monkey")
for i in range(0, 6):
@@ -116,10 +130,23 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result)
@override_config(
{
"rc_login": {
# Prevent the generic login ratelimiter from raising first
"address": {
"per_second": 1000,
"burst_count": 1000,
},
"failed_attempts": {
"per_second": 0.17,
"burst_count": 5,
}
}
}
)
@unittest.DEBUG
def test_POST_ratelimiting_per_account_failed_attempts(self):
self.hs.config.rc_login_failed_attempts.burst_count = 5
self.hs.config.rc_login_failed_attempts.per_second = 0.17
self.register_user("kermit", "monkey")
for i in range(0, 6):
@@ -128,8 +155,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "notamonkey",
}
request_data = json.dumps(params)
request, channel = self.make_request(b"POST", LOGIN_URL, request_data)
request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
if i == 5:
@@ -149,7 +175,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "notamonkey",
}
request_data = json.dumps(params)
request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)

View File

@@ -20,12 +20,13 @@
import json
from mock import Mock, NonCallableMock
from mock import Mock, patch
from six.moves.urllib import parse as urlparse
from twisted.internet import defer
import synapse.rest.admin
from synapse.api.ratelimiting import Ratelimiter
from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.handlers.pagination import PurgeStatus
from synapse.rest.client.v1 import directory, login, profile, room
@@ -49,10 +50,15 @@ class RoomBase(unittest.HomeserverTestCase):
"red",
http_client=None,
federation_client=Mock(),
ratelimiter=NonCallableMock(spec_set=["can_do_action"]),
)
self.ratelimiter = self.hs.get_ratelimiter()
self.ratelimiter.can_do_action.return_value = (True, 0)
# Patch Ratelimiter to allow all requests
patch.object(
Ratelimiter, "can_do_action", new_callable=lambda *args, **kwargs: (True, 0.0)
)
patch.object(
Ratelimiter, "ratelimit", new_callable=lambda *args, **kwargs: None
)
self.hs.get_federation_handler = Mock(return_value=Mock())

View File

@@ -16,12 +16,13 @@
"""Tests REST events for /rooms paths."""
from mock import Mock, NonCallableMock
from mock import Mock, NonCallableMock, patch
from twisted.internet import defer
from synapse.rest.client.v1 import room
from synapse.types import UserID
from synapse.api.ratelimiting import Ratelimiter
from tests import unittest
@@ -42,14 +43,18 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
"red",
http_client=None,
federation_client=Mock(),
ratelimiter=NonCallableMock(spec_set=["can_do_action"]),
)
# Patch Ratelimiter to allow all requests
patch.object(
Ratelimiter, "can_do_action", new_callable=lambda *args, **kwargs: (True, 0.0)
)
patch.object(
Ratelimiter, "ratelimit", new_callable=lambda *args, **kwargs: None
)
self.event_source = hs.get_event_sources().sources["typing"]
self.ratelimiter = hs.get_ratelimiter()
self.ratelimiter.can_do_action.return_value = (True, 0)
hs.get_handlers().federation_handler = Mock()
def get_user_by_access_token(token=None, allow_guest=False):

View File

@@ -147,8 +147,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.json_body["error"], "Guest access is disabled")
def test_POST_ratelimiting_guest(self):
self.hs.config.rc_registration.burst_count = 5
self.hs.config.rc_registration.per_second = 0.17
self.hs.get_registration_ratelimiter().burst_count = 5
self.hs.get_registration_ratelimiter().rate_hz = 0.17
for i in range(0, 6):
url = self.url + b"?kind=guest"
@@ -169,8 +169,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result)
def test_POST_ratelimiting(self):
self.hs.config.rc_registration.burst_count = 5
self.hs.config.rc_registration.per_second = 0.17
self.hs.get_registration_ratelimiter().burst_count = 5
self.hs.get_registration_ratelimiter().rate_hz = 0.17
for i in range(0, 6):
params = {