Modify servlets to pull Ratelimiters from HomeServer class
This commit is contained in:
@@ -14,9 +14,15 @@
|
||||
|
||||
from ._base import Config
|
||||
|
||||
from typing import Dict
|
||||
|
||||
|
||||
class RateLimitConfig(object):
|
||||
def __init__(self, config, defaults={"per_second": 0.17, "burst_count": 3.0}):
|
||||
def __init__(
|
||||
self,
|
||||
config: Dict[str, float],
|
||||
defaults={"per_second": 0.17, "burst_count": 3.0},
|
||||
):
|
||||
self.per_second = config.get("per_second", defaults["per_second"])
|
||||
self.burst_count = config.get("burst_count", defaults["burst_count"])
|
||||
|
||||
|
||||
+20
-34
@@ -19,7 +19,6 @@ from twisted.internet import defer
|
||||
|
||||
import synapse.types
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.api.errors import LimitExceededError
|
||||
from synapse.types import UserID
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -44,11 +43,16 @@ class BaseHandler(object):
|
||||
self.notifier = hs.get_notifier()
|
||||
self.state_handler = hs.get_state_handler()
|
||||
self.distributor = hs.get_distributor()
|
||||
self.ratelimiter = hs.get_ratelimiter()
|
||||
self.admin_redaction_ratelimiter = hs.get_admin_redaction_ratelimiter()
|
||||
self.clock = hs.get_clock()
|
||||
self.hs = hs
|
||||
|
||||
self.ratelimiter = None
|
||||
self.request_ratelimiter = hs.get_request_ratelimiter()
|
||||
self._rc_message = self.hs.config.rc_message
|
||||
|
||||
# If special admin redaction ratelimiting is disabled, this will be None
|
||||
self.admin_redaction_ratelimiter = hs.get_admin_redaction_ratelimiter()
|
||||
|
||||
self.server_name = hs.hostname
|
||||
|
||||
self.event_builder_factory = hs.get_event_builder_factory()
|
||||
@@ -83,48 +87,30 @@ class BaseHandler(object):
|
||||
if requester.app_service and not requester.app_service.is_rate_limited():
|
||||
return
|
||||
|
||||
messages_per_second = self._rc_message.per_second
|
||||
burst_count = self._rc_message.burst_count
|
||||
|
||||
# Check if there is a per user override in the DB.
|
||||
override = yield self.store.get_ratelimit_for_user(user_id)
|
||||
if override:
|
||||
# If overriden with a null Hz then ratelimiting has been entirely
|
||||
# If overridden with a null Hz then ratelimiting has been entirely
|
||||
# disabled for the user
|
||||
if not override.messages_per_second:
|
||||
return
|
||||
|
||||
messages_per_second = override.messages_per_second
|
||||
burst_count = override.burst_count
|
||||
else:
|
||||
# We default to different values if this is an admin redaction and
|
||||
# the config is set
|
||||
if is_admin_redaction and self.hs.config.rc_admin_redaction:
|
||||
messages_per_second = self.hs.config.rc_admin_redaction.per_second
|
||||
burst_count = self.hs.config.rc_admin_redaction.burst_count
|
||||
else:
|
||||
messages_per_second = self.hs.config.rc_message.per_second
|
||||
burst_count = self.hs.config.rc_message.burst_count
|
||||
|
||||
if is_admin_redaction and self.hs.config.rc_admin_redaction:
|
||||
# If we have separate config for admin redactions we use a separate
|
||||
# ratelimiter
|
||||
allowed, time_allowed = self.admin_redaction_ratelimiter.can_do_action(
|
||||
user_id,
|
||||
time_now,
|
||||
rate_hz=messages_per_second,
|
||||
burst_count=burst_count,
|
||||
update=update,
|
||||
)
|
||||
if is_admin_redaction and self.admin_redaction_ratelimiter:
|
||||
# If we have separate config for admin redactions, use a separate
|
||||
# ratelimiter as to not have user_id's clash
|
||||
self.admin_redaction_ratelimiter.ratelimit(user_id, time_now, update)
|
||||
else:
|
||||
allowed, time_allowed = self.ratelimiter.can_do_action(
|
||||
user_id,
|
||||
time_now,
|
||||
rate_hz=messages_per_second,
|
||||
burst_count=burst_count,
|
||||
update=update,
|
||||
)
|
||||
if not allowed:
|
||||
raise LimitExceededError(
|
||||
retry_after_ms=int(1000 * (time_allowed - time_now))
|
||||
)
|
||||
# Override rate and burst count per-user
|
||||
self.request_ratelimiter.rate_hz = messages_per_second
|
||||
self.request_ratelimiter.burst_count = burst_count
|
||||
|
||||
self.request_ratelimiter.ratelimit(user_id, time_now, update)
|
||||
|
||||
async def maybe_kick_guest_users(self, event, context=None):
|
||||
# Technically this function invalidates current_state by changing it.
|
||||
|
||||
@@ -108,7 +108,11 @@ class AuthHandler(BaseHandler):
|
||||
|
||||
# Ratelimiter for failed auth during UIA. Uses same ratelimit config
|
||||
# as per `rc_login.failed_attempts`.
|
||||
self._failed_uia_attempts_ratelimiter = Ratelimiter()
|
||||
# XXX: Should this be hs.get_login_failed_attempts_ratelimiter?
|
||||
self._failed_uia_attempts_ratelimiter = Ratelimiter(
|
||||
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
|
||||
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
|
||||
)
|
||||
|
||||
self._clock = self.hs.get_clock()
|
||||
|
||||
@@ -199,8 +203,6 @@ class AuthHandler(BaseHandler):
|
||||
self._failed_uia_attempts_ratelimiter.ratelimit(
|
||||
user_id,
|
||||
time_now_s=self._clock.time(),
|
||||
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
|
||||
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
|
||||
update=False,
|
||||
)
|
||||
|
||||
@@ -216,8 +218,6 @@ class AuthHandler(BaseHandler):
|
||||
self._failed_uia_attempts_ratelimiter.can_do_action(
|
||||
user_id,
|
||||
time_now_s=self._clock.time(),
|
||||
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
|
||||
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
|
||||
update=True,
|
||||
)
|
||||
raise
|
||||
|
||||
@@ -362,7 +362,6 @@ class EventCreationHandler(object):
|
||||
self.profile_handler = hs.get_profile_handler()
|
||||
self.event_builder_factory = hs.get_event_builder_factory()
|
||||
self.server_name = hs.hostname
|
||||
self.ratelimiter = hs.get_ratelimiter()
|
||||
self.notifier = hs.get_notifier()
|
||||
self.config = hs.config
|
||||
self.require_membership_for_aliases = hs.config.require_membership_for_aliases
|
||||
|
||||
@@ -430,8 +430,6 @@ class RegistrationHandler(BaseHandler):
|
||||
self.ratelimiter.ratelimit(
|
||||
address,
|
||||
time_now_s=time_now,
|
||||
rate_hz=self.hs.config.rc_registration.per_second,
|
||||
burst_count=self.hs.config.rc_registration.burst_count,
|
||||
)
|
||||
|
||||
def register_with_store(
|
||||
|
||||
@@ -89,9 +89,8 @@ class LoginRestServlet(RestServlet):
|
||||
self.handlers = hs.get_handlers()
|
||||
self._clock = hs.get_clock()
|
||||
self._well_known_builder = WellKnownBuilder(hs)
|
||||
self._address_ratelimiter = Ratelimiter()
|
||||
self._account_ratelimiter = Ratelimiter()
|
||||
self._failed_attempts_ratelimiter = Ratelimiter()
|
||||
self._account_ratelimiter = hs.get_login_ratelimiter()
|
||||
self._failed_attempts_ratelimiter = hs.get_login_failed_attempts_ratelimiter()
|
||||
|
||||
def on_GET(self, request):
|
||||
flows = []
|
||||
@@ -129,11 +128,9 @@ class LoginRestServlet(RestServlet):
|
||||
return 200, {}
|
||||
|
||||
async def on_POST(self, request):
|
||||
self._address_ratelimiter.ratelimit(
|
||||
self._account_ratelimiter.ratelimit(
|
||||
request.getClientIP(),
|
||||
time_now_s=self.hs.clock.time(),
|
||||
rate_hz=self.hs.config.rc_login_address.per_second,
|
||||
burst_count=self.hs.config.rc_login_address.burst_count,
|
||||
update=True,
|
||||
)
|
||||
|
||||
@@ -206,8 +203,6 @@ class LoginRestServlet(RestServlet):
|
||||
self._failed_attempts_ratelimiter.ratelimit(
|
||||
(medium, address),
|
||||
time_now_s=self._clock.time(),
|
||||
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
|
||||
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
|
||||
update=False,
|
||||
)
|
||||
|
||||
@@ -246,8 +241,6 @@ class LoginRestServlet(RestServlet):
|
||||
self._failed_attempts_ratelimiter.can_do_action(
|
||||
(medium, address),
|
||||
time_now_s=self._clock.time(),
|
||||
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
|
||||
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
|
||||
update=True,
|
||||
)
|
||||
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
|
||||
@@ -270,8 +263,6 @@ class LoginRestServlet(RestServlet):
|
||||
self._failed_attempts_ratelimiter.ratelimit(
|
||||
qualified_user_id.lower(),
|
||||
time_now_s=self._clock.time(),
|
||||
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
|
||||
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
|
||||
update=False,
|
||||
)
|
||||
|
||||
@@ -287,8 +278,6 @@ class LoginRestServlet(RestServlet):
|
||||
self._failed_attempts_ratelimiter.can_do_action(
|
||||
qualified_user_id.lower(),
|
||||
time_now_s=self._clock.time(),
|
||||
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
|
||||
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
|
||||
update=True,
|
||||
)
|
||||
raise
|
||||
@@ -326,8 +315,6 @@ class LoginRestServlet(RestServlet):
|
||||
self._account_ratelimiter.ratelimit(
|
||||
user_id.lower(),
|
||||
time_now_s=self._clock.time(),
|
||||
rate_hz=self.hs.config.rc_login_account.per_second,
|
||||
burst_count=self.hs.config.rc_login_account.burst_count,
|
||||
update=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -401,8 +401,6 @@ class RegisterRestServlet(RestServlet):
|
||||
allowed, time_allowed = self.ratelimiter.can_do_action(
|
||||
client_addr,
|
||||
time_now_s=time_now,
|
||||
rate_hz=self.hs.config.rc_registration.per_second,
|
||||
burst_count=self.hs.config.rc_registration.burst_count,
|
||||
update=False,
|
||||
)
|
||||
|
||||
|
||||
@@ -43,7 +43,7 @@ class FederationRateLimiter(object):
|
||||
self.ratelimiters = collections.defaultdict(new_limiter)
|
||||
|
||||
def ratelimit(self, host):
|
||||
"""Used to ratelimit an incoming request from given host
|
||||
"""Used to ratelimit an incoming request from a given host
|
||||
|
||||
Example usage:
|
||||
|
||||
|
||||
Reference in New Issue
Block a user