1
0

Modify servlets to pull Ratelimiters from HomeServer class

This commit is contained in:
Andrew Morgan
2020-05-28 22:38:26 +01:00
parent 0e6ee7ca17
commit 82eac22286
8 changed files with 36 additions and 62 deletions
+7 -1
View File
@@ -14,9 +14,15 @@
from ._base import Config
from typing import Dict
class RateLimitConfig(object):
def __init__(self, config, defaults={"per_second": 0.17, "burst_count": 3.0}):
def __init__(
self,
config: Dict[str, float],
defaults={"per_second": 0.17, "burst_count": 3.0},
):
self.per_second = config.get("per_second", defaults["per_second"])
self.burst_count = config.get("burst_count", defaults["burst_count"])
+20 -34
View File
@@ -19,7 +19,6 @@ from twisted.internet import defer
import synapse.types
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import LimitExceededError
from synapse.types import UserID
logger = logging.getLogger(__name__)
@@ -44,11 +43,16 @@ class BaseHandler(object):
self.notifier = hs.get_notifier()
self.state_handler = hs.get_state_handler()
self.distributor = hs.get_distributor()
self.ratelimiter = hs.get_ratelimiter()
self.admin_redaction_ratelimiter = hs.get_admin_redaction_ratelimiter()
self.clock = hs.get_clock()
self.hs = hs
self.ratelimiter = None
self.request_ratelimiter = hs.get_request_ratelimiter()
self._rc_message = self.hs.config.rc_message
# If special admin redaction ratelimiting is disabled, this will be None
self.admin_redaction_ratelimiter = hs.get_admin_redaction_ratelimiter()
self.server_name = hs.hostname
self.event_builder_factory = hs.get_event_builder_factory()
@@ -83,48 +87,30 @@ class BaseHandler(object):
if requester.app_service and not requester.app_service.is_rate_limited():
return
messages_per_second = self._rc_message.per_second
burst_count = self._rc_message.burst_count
# Check if there is a per user override in the DB.
override = yield self.store.get_ratelimit_for_user(user_id)
if override:
# If overriden with a null Hz then ratelimiting has been entirely
# If overridden with a null Hz then ratelimiting has been entirely
# disabled for the user
if not override.messages_per_second:
return
messages_per_second = override.messages_per_second
burst_count = override.burst_count
else:
# We default to different values if this is an admin redaction and
# the config is set
if is_admin_redaction and self.hs.config.rc_admin_redaction:
messages_per_second = self.hs.config.rc_admin_redaction.per_second
burst_count = self.hs.config.rc_admin_redaction.burst_count
else:
messages_per_second = self.hs.config.rc_message.per_second
burst_count = self.hs.config.rc_message.burst_count
if is_admin_redaction and self.hs.config.rc_admin_redaction:
# If we have separate config for admin redactions we use a separate
# ratelimiter
allowed, time_allowed = self.admin_redaction_ratelimiter.can_do_action(
user_id,
time_now,
rate_hz=messages_per_second,
burst_count=burst_count,
update=update,
)
if is_admin_redaction and self.admin_redaction_ratelimiter:
# If we have separate config for admin redactions, use a separate
# ratelimiter as to not have user_id's clash
self.admin_redaction_ratelimiter.ratelimit(user_id, time_now, update)
else:
allowed, time_allowed = self.ratelimiter.can_do_action(
user_id,
time_now,
rate_hz=messages_per_second,
burst_count=burst_count,
update=update,
)
if not allowed:
raise LimitExceededError(
retry_after_ms=int(1000 * (time_allowed - time_now))
)
# Override rate and burst count per-user
self.request_ratelimiter.rate_hz = messages_per_second
self.request_ratelimiter.burst_count = burst_count
self.request_ratelimiter.ratelimit(user_id, time_now, update)
async def maybe_kick_guest_users(self, event, context=None):
# Technically this function invalidates current_state by changing it.
+5 -5
View File
@@ -108,7 +108,11 @@ class AuthHandler(BaseHandler):
# Ratelimiter for failed auth during UIA. Uses same ratelimit config
# as per `rc_login.failed_attempts`.
self._failed_uia_attempts_ratelimiter = Ratelimiter()
# XXX: Should this be hs.get_login_failed_attempts_ratelimiter?
self._failed_uia_attempts_ratelimiter = Ratelimiter(
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
)
self._clock = self.hs.get_clock()
@@ -199,8 +203,6 @@ class AuthHandler(BaseHandler):
self._failed_uia_attempts_ratelimiter.ratelimit(
user_id,
time_now_s=self._clock.time(),
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
update=False,
)
@@ -216,8 +218,6 @@ class AuthHandler(BaseHandler):
self._failed_uia_attempts_ratelimiter.can_do_action(
user_id,
time_now_s=self._clock.time(),
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
update=True,
)
raise
-1
View File
@@ -362,7 +362,6 @@ class EventCreationHandler(object):
self.profile_handler = hs.get_profile_handler()
self.event_builder_factory = hs.get_event_builder_factory()
self.server_name = hs.hostname
self.ratelimiter = hs.get_ratelimiter()
self.notifier = hs.get_notifier()
self.config = hs.config
self.require_membership_for_aliases = hs.config.require_membership_for_aliases
-2
View File
@@ -430,8 +430,6 @@ class RegistrationHandler(BaseHandler):
self.ratelimiter.ratelimit(
address,
time_now_s=time_now,
rate_hz=self.hs.config.rc_registration.per_second,
burst_count=self.hs.config.rc_registration.burst_count,
)
def register_with_store(
+3 -16
View File
@@ -89,9 +89,8 @@ class LoginRestServlet(RestServlet):
self.handlers = hs.get_handlers()
self._clock = hs.get_clock()
self._well_known_builder = WellKnownBuilder(hs)
self._address_ratelimiter = Ratelimiter()
self._account_ratelimiter = Ratelimiter()
self._failed_attempts_ratelimiter = Ratelimiter()
self._account_ratelimiter = hs.get_login_ratelimiter()
self._failed_attempts_ratelimiter = hs.get_login_failed_attempts_ratelimiter()
def on_GET(self, request):
flows = []
@@ -129,11 +128,9 @@ class LoginRestServlet(RestServlet):
return 200, {}
async def on_POST(self, request):
self._address_ratelimiter.ratelimit(
self._account_ratelimiter.ratelimit(
request.getClientIP(),
time_now_s=self.hs.clock.time(),
rate_hz=self.hs.config.rc_login_address.per_second,
burst_count=self.hs.config.rc_login_address.burst_count,
update=True,
)
@@ -206,8 +203,6 @@ class LoginRestServlet(RestServlet):
self._failed_attempts_ratelimiter.ratelimit(
(medium, address),
time_now_s=self._clock.time(),
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
update=False,
)
@@ -246,8 +241,6 @@ class LoginRestServlet(RestServlet):
self._failed_attempts_ratelimiter.can_do_action(
(medium, address),
time_now_s=self._clock.time(),
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
update=True,
)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
@@ -270,8 +263,6 @@ class LoginRestServlet(RestServlet):
self._failed_attempts_ratelimiter.ratelimit(
qualified_user_id.lower(),
time_now_s=self._clock.time(),
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
update=False,
)
@@ -287,8 +278,6 @@ class LoginRestServlet(RestServlet):
self._failed_attempts_ratelimiter.can_do_action(
qualified_user_id.lower(),
time_now_s=self._clock.time(),
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
update=True,
)
raise
@@ -326,8 +315,6 @@ class LoginRestServlet(RestServlet):
self._account_ratelimiter.ratelimit(
user_id.lower(),
time_now_s=self._clock.time(),
rate_hz=self.hs.config.rc_login_account.per_second,
burst_count=self.hs.config.rc_login_account.burst_count,
update=True,
)
-2
View File
@@ -401,8 +401,6 @@ class RegisterRestServlet(RestServlet):
allowed, time_allowed = self.ratelimiter.can_do_action(
client_addr,
time_now_s=time_now,
rate_hz=self.hs.config.rc_registration.per_second,
burst_count=self.hs.config.rc_registration.burst_count,
update=False,
)
+1 -1
View File
@@ -43,7 +43,7 @@ class FederationRateLimiter(object):
self.ratelimiters = collections.defaultdict(new_limiter)
def ratelimit(self, host):
"""Used to ratelimit an incoming request from given host
"""Used to ratelimit an incoming request from a given host
Example usage: