Replace instances of request.getClientAddress with new method
This commit is contained in:
@@ -359,7 +359,8 @@ class BaseAuth:
|
||||
return None
|
||||
|
||||
if app_service.ip_range_whitelist:
|
||||
ip_address = IPAddress(request.getClientAddress().host)
|
||||
ip_address_str = self.get_ip_address_from_request(request)
|
||||
ip_address = IPAddress(ip_address_str)
|
||||
if ip_address not in app_service.ip_range_whitelist:
|
||||
return None
|
||||
|
||||
|
||||
@@ -567,7 +567,7 @@ class AuthHandler:
|
||||
await self.store.set_ui_auth_clientdict(sid, clientdict)
|
||||
|
||||
user_agent = get_request_user_agent(request)
|
||||
clientip = request.getClientAddress().host
|
||||
clientip = self.auth.get_ip_address_from_request(request)
|
||||
|
||||
await self.store.add_user_agent_ip_to_ui_auth_session(
|
||||
session.session_id, user_agent, clientip
|
||||
|
||||
@@ -57,6 +57,7 @@ id_server_scheme = "https://"
|
||||
|
||||
class IdentityHandler:
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self._auth = hs.get_auth()
|
||||
self.store = hs.get_datastores().main
|
||||
# An HTTP client for contacting trusted URLs.
|
||||
self.http_client = SimpleHttpClient(hs)
|
||||
@@ -97,9 +98,8 @@ class IdentityHandler:
|
||||
address: The actual threepid ID, e.g. the phone number or email address
|
||||
"""
|
||||
|
||||
await self._3pid_validation_ratelimiter_ip.ratelimit(
|
||||
None, (medium, request.getClientAddress().host)
|
||||
)
|
||||
ip_address = self._auth.get_ip_address_from_request(request)
|
||||
await self._3pid_validation_ratelimiter_ip.ratelimit(None, (medium, ip_address))
|
||||
await self._3pid_validation_ratelimiter_address.ratelimit(
|
||||
None, (medium, address)
|
||||
)
|
||||
|
||||
@@ -205,6 +205,7 @@ class SsoHandler:
|
||||
self.server_name = hs.hostname
|
||||
self._is_mine_server_name = hs.is_mine_server_name
|
||||
self._registration_handler = hs.get_registration_handler()
|
||||
self._auth = hs.get_auth()
|
||||
self._auth_handler = hs.get_auth_handler()
|
||||
self._device_handler = hs.get_device_handler()
|
||||
self._error_template = hs.config.sso.sso_error_template
|
||||
@@ -505,12 +506,13 @@ class SsoHandler:
|
||||
auth_provider_session_id,
|
||||
)
|
||||
|
||||
ip_address = self._auth.get_ip_address_from_request(request)
|
||||
user_id = await self._register_mapped_user(
|
||||
attributes,
|
||||
auth_provider_id,
|
||||
remote_user_id,
|
||||
get_request_user_agent(request),
|
||||
request.getClientAddress().host,
|
||||
ip_address,
|
||||
)
|
||||
new_user = True
|
||||
elif self._sso_update_profile_information:
|
||||
@@ -1080,6 +1082,8 @@ class SsoHandler:
|
||||
if session.use_avatar:
|
||||
attributes.picture = session.avatar_url
|
||||
|
||||
ip_address = self._auth.get_ip_address_from_request(request)
|
||||
|
||||
# the following will raise a 400 error if the username has been taken in the
|
||||
# meantime.
|
||||
user_id = await self._register_mapped_user(
|
||||
@@ -1087,7 +1091,7 @@ class SsoHandler:
|
||||
session.auth_provider_id,
|
||||
session.remote_user_id,
|
||||
get_request_user_agent(request),
|
||||
request.getClientAddress().host,
|
||||
ip_address,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
|
||||
@@ -134,6 +134,8 @@ class AuthRestServlet(RestServlet):
|
||||
if not session:
|
||||
raise SynapseError(400, "No session supplied")
|
||||
|
||||
ip_address = self.auth.get_ip_address_from_request(request)
|
||||
|
||||
if stagetype == LoginType.RECAPTCHA:
|
||||
response = parse_string(request, "g-recaptcha-response")
|
||||
|
||||
@@ -144,7 +146,9 @@ class AuthRestServlet(RestServlet):
|
||||
|
||||
try:
|
||||
await self.auth_handler.add_oob_auth(
|
||||
LoginType.RECAPTCHA, authdict, request.getClientAddress().host
|
||||
LoginType.RECAPTCHA,
|
||||
authdict,
|
||||
ip_address,
|
||||
)
|
||||
except LoginError as e:
|
||||
# Authentication failed, let user try again
|
||||
@@ -164,7 +168,9 @@ class AuthRestServlet(RestServlet):
|
||||
|
||||
try:
|
||||
await self.auth_handler.add_oob_auth(
|
||||
LoginType.TERMS, authdict, request.getClientAddress().host
|
||||
LoginType.TERMS,
|
||||
authdict,
|
||||
ip_address,
|
||||
)
|
||||
except LoginError as e:
|
||||
# Authentication failed, let user try again
|
||||
@@ -195,7 +201,7 @@ class AuthRestServlet(RestServlet):
|
||||
await self.auth_handler.add_oob_auth(
|
||||
LoginType.REGISTRATION_TOKEN,
|
||||
authdict,
|
||||
request.getClientAddress().host,
|
||||
ip_address,
|
||||
)
|
||||
except LoginError as e:
|
||||
html = self.registration_token_template.render(
|
||||
|
||||
@@ -205,6 +205,7 @@ class LoginRestServlet(RestServlet):
|
||||
)
|
||||
|
||||
request_info = request.request_info()
|
||||
ip_address = self.auth.get_ip_address_from_request(request)
|
||||
|
||||
try:
|
||||
if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE:
|
||||
@@ -224,9 +225,7 @@ class LoginRestServlet(RestServlet):
|
||||
)
|
||||
|
||||
if appservice.is_rate_limited():
|
||||
await self._address_ratelimiter.ratelimit(
|
||||
None, request.getClientAddress().host
|
||||
)
|
||||
await self._address_ratelimiter.ratelimit(None, ip_address)
|
||||
|
||||
result = await self._do_appservice_login(
|
||||
login_submission,
|
||||
@@ -238,27 +237,21 @@ class LoginRestServlet(RestServlet):
|
||||
self.jwt_enabled
|
||||
and login_submission["type"] == LoginRestServlet.JWT_TYPE
|
||||
):
|
||||
await self._address_ratelimiter.ratelimit(
|
||||
None, request.getClientAddress().host
|
||||
)
|
||||
await self._address_ratelimiter.ratelimit(None, ip_address)
|
||||
result = await self._do_jwt_login(
|
||||
login_submission,
|
||||
should_issue_refresh_token=should_issue_refresh_token,
|
||||
request_info=request_info,
|
||||
)
|
||||
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
|
||||
await self._address_ratelimiter.ratelimit(
|
||||
None, request.getClientAddress().host
|
||||
)
|
||||
await self._address_ratelimiter.ratelimit(None, ip_address)
|
||||
result = await self._do_token_login(
|
||||
login_submission,
|
||||
should_issue_refresh_token=should_issue_refresh_token,
|
||||
request_info=request_info,
|
||||
)
|
||||
else:
|
||||
await self._address_ratelimiter.ratelimit(
|
||||
None, request.getClientAddress().host
|
||||
)
|
||||
await self._address_ratelimiter.ratelimit(None, ip_address)
|
||||
result = await self._do_other_login(
|
||||
login_submission,
|
||||
should_issue_refresh_token=should_issue_refresh_token,
|
||||
|
||||
@@ -192,7 +192,8 @@ class ThumbnailResource(RestServlet):
|
||||
respond_404(request)
|
||||
return
|
||||
|
||||
ip_address = request.getClientAddress().host
|
||||
ip_address = self.auth.get_ip_address_from_request(request)
|
||||
|
||||
remote_resp_function = (
|
||||
self.thumbnailer.select_or_generate_remote_thumbnail
|
||||
if self.dynamic_thumbnails
|
||||
@@ -263,7 +264,8 @@ class DownloadResource(RestServlet):
|
||||
request, media_id, file_name, max_timeout_ms
|
||||
)
|
||||
else:
|
||||
ip_address = request.getClientAddress().host
|
||||
ip_address = self.auth.get_ip_address_from_request(request)
|
||||
|
||||
await self.media_repo.get_remote_media(
|
||||
request,
|
||||
server_name,
|
||||
|
||||
@@ -329,6 +329,7 @@ class UsernameAvailabilityRestServlet(RestServlet):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.hs = hs
|
||||
self._auth = hs.get_auth()
|
||||
self.server_name = hs.hostname
|
||||
self.registration_handler = hs.get_registration_handler()
|
||||
self.ratelimiter = FederationRateLimiter(
|
||||
@@ -361,7 +362,7 @@ class UsernameAvailabilityRestServlet(RestServlet):
|
||||
if self.inhibit_user_in_use_error:
|
||||
return 200, {"available": True}
|
||||
|
||||
ip = request.getClientAddress().host
|
||||
ip = self._auth.get_ip_address_from_request(request)
|
||||
with self.ratelimiter.ratelimit(ip) as wait_deferred:
|
||||
await wait_deferred
|
||||
|
||||
@@ -395,6 +396,7 @@ class RegistrationTokenValidityRestServlet(RestServlet):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.hs = hs
|
||||
self._auth = hs.get_auth()
|
||||
self.store = hs.get_datastores().main
|
||||
self.ratelimiter = Ratelimiter(
|
||||
store=self.store,
|
||||
@@ -403,7 +405,8 @@ class RegistrationTokenValidityRestServlet(RestServlet):
|
||||
)
|
||||
|
||||
async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
|
||||
await self.ratelimiter.ratelimit(None, (request.getClientAddress().host,))
|
||||
ip_address = self._auth.get_ip_address_from_request(request)
|
||||
await self.ratelimiter.ratelimit(None, (ip_address,))
|
||||
|
||||
if not self.hs.config.registration.enable_registration:
|
||||
raise SynapseError(
|
||||
@@ -456,7 +459,7 @@ class RegisterRestServlet(RestServlet):
|
||||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
client_addr = request.getClientAddress().host
|
||||
client_addr = self.auth.get_ip_address_from_request(request)
|
||||
|
||||
await self.ratelimiter.ratelimit(None, client_addr, update=False)
|
||||
|
||||
@@ -916,7 +919,7 @@ class RegisterAppServiceOnlyRestServlet(RestServlet):
|
||||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
client_addr = request.getClientAddress().host
|
||||
client_addr = self.auth.get_ip_address_from_request(request)
|
||||
|
||||
await self.ratelimiter.ratelimit(None, client_addr, update=False)
|
||||
|
||||
|
||||
@@ -49,6 +49,7 @@ class DownloadResource(RestServlet):
|
||||
|
||||
def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
|
||||
super().__init__()
|
||||
self._auth = hs.get_auth()
|
||||
self.media_repo = media_repo
|
||||
self._is_mine_server_name = hs.is_mine_server_name
|
||||
|
||||
@@ -97,7 +98,7 @@ class DownloadResource(RestServlet):
|
||||
respond_404(request)
|
||||
return
|
||||
|
||||
ip_address = request.getClientAddress().host
|
||||
ip_address = self._auth.get_ip_address_from_request(request)
|
||||
await self.media_repo.get_remote_media(
|
||||
request,
|
||||
server_name,
|
||||
|
||||
@@ -58,6 +58,7 @@ class ThumbnailResource(RestServlet):
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self._auth = hs.get_auth()
|
||||
self.store = hs.get_datastores().main
|
||||
self.media_repo = media_repo
|
||||
self.media_storage = media_storage
|
||||
@@ -120,7 +121,7 @@ class ThumbnailResource(RestServlet):
|
||||
respond_404(request)
|
||||
return
|
||||
|
||||
ip_address = request.getClientAddress().host
|
||||
ip_address = self._auth.get_ip_address_from_request(request)
|
||||
remote_resp_function = (
|
||||
self.thumbnail_provider.select_or_generate_remote_thumbnail
|
||||
if self.dynamic_thumbnails
|
||||
|
||||
Reference in New Issue
Block a user