1
0

Replace instances of request.getClientAddress with new method

This commit is contained in:
Andrew Morgan
2025-10-01 16:29:14 +01:00
parent 4e333c310a
commit d912558c4f
10 changed files with 41 additions and 30 deletions

View File

@@ -359,7 +359,8 @@ class BaseAuth:
return None
if app_service.ip_range_whitelist:
ip_address = IPAddress(request.getClientAddress().host)
ip_address_str = self.get_ip_address_from_request(request)
ip_address = IPAddress(ip_address_str)
if ip_address not in app_service.ip_range_whitelist:
return None

View File

@@ -567,7 +567,7 @@ class AuthHandler:
await self.store.set_ui_auth_clientdict(sid, clientdict)
user_agent = get_request_user_agent(request)
clientip = request.getClientAddress().host
clientip = self.auth.get_ip_address_from_request(request)
await self.store.add_user_agent_ip_to_ui_auth_session(
session.session_id, user_agent, clientip

View File

@@ -57,6 +57,7 @@ id_server_scheme = "https://"
class IdentityHandler:
def __init__(self, hs: "HomeServer"):
self._auth = hs.get_auth()
self.store = hs.get_datastores().main
# An HTTP client for contacting trusted URLs.
self.http_client = SimpleHttpClient(hs)
@@ -97,9 +98,8 @@ class IdentityHandler:
address: The actual threepid ID, e.g. the phone number or email address
"""
await self._3pid_validation_ratelimiter_ip.ratelimit(
None, (medium, request.getClientAddress().host)
)
ip_address = self._auth.get_ip_address_from_request(request)
await self._3pid_validation_ratelimiter_ip.ratelimit(None, (medium, ip_address))
await self._3pid_validation_ratelimiter_address.ratelimit(
None, (medium, address)
)

View File

@@ -205,6 +205,7 @@ class SsoHandler:
self.server_name = hs.hostname
self._is_mine_server_name = hs.is_mine_server_name
self._registration_handler = hs.get_registration_handler()
self._auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
self._error_template = hs.config.sso.sso_error_template
@@ -505,12 +506,13 @@ class SsoHandler:
auth_provider_session_id,
)
ip_address = self._auth.get_ip_address_from_request(request)
user_id = await self._register_mapped_user(
attributes,
auth_provider_id,
remote_user_id,
get_request_user_agent(request),
request.getClientAddress().host,
ip_address,
)
new_user = True
elif self._sso_update_profile_information:
@@ -1080,6 +1082,8 @@ class SsoHandler:
if session.use_avatar:
attributes.picture = session.avatar_url
ip_address = self._auth.get_ip_address_from_request(request)
# the following will raise a 400 error if the username has been taken in the
# meantime.
user_id = await self._register_mapped_user(
@@ -1087,7 +1091,7 @@ class SsoHandler:
session.auth_provider_id,
session.remote_user_id,
get_request_user_agent(request),
request.getClientAddress().host,
ip_address,
)
logger.info(

View File

@@ -134,6 +134,8 @@ class AuthRestServlet(RestServlet):
if not session:
raise SynapseError(400, "No session supplied")
ip_address = self.auth.get_ip_address_from_request(request)
if stagetype == LoginType.RECAPTCHA:
response = parse_string(request, "g-recaptcha-response")
@@ -144,7 +146,9 @@ class AuthRestServlet(RestServlet):
try:
await self.auth_handler.add_oob_auth(
LoginType.RECAPTCHA, authdict, request.getClientAddress().host
LoginType.RECAPTCHA,
authdict,
ip_address,
)
except LoginError as e:
# Authentication failed, let user try again
@@ -164,7 +168,9 @@ class AuthRestServlet(RestServlet):
try:
await self.auth_handler.add_oob_auth(
LoginType.TERMS, authdict, request.getClientAddress().host
LoginType.TERMS,
authdict,
ip_address,
)
except LoginError as e:
# Authentication failed, let user try again
@@ -195,7 +201,7 @@ class AuthRestServlet(RestServlet):
await self.auth_handler.add_oob_auth(
LoginType.REGISTRATION_TOKEN,
authdict,
request.getClientAddress().host,
ip_address,
)
except LoginError as e:
html = self.registration_token_template.render(

View File

@@ -205,6 +205,7 @@ class LoginRestServlet(RestServlet):
)
request_info = request.request_info()
ip_address = self.auth.get_ip_address_from_request(request)
try:
if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE:
@@ -224,9 +225,7 @@ class LoginRestServlet(RestServlet):
)
if appservice.is_rate_limited():
await self._address_ratelimiter.ratelimit(
None, request.getClientAddress().host
)
await self._address_ratelimiter.ratelimit(None, ip_address)
result = await self._do_appservice_login(
login_submission,
@@ -238,27 +237,21 @@ class LoginRestServlet(RestServlet):
self.jwt_enabled
and login_submission["type"] == LoginRestServlet.JWT_TYPE
):
await self._address_ratelimiter.ratelimit(
None, request.getClientAddress().host
)
await self._address_ratelimiter.ratelimit(None, ip_address)
result = await self._do_jwt_login(
login_submission,
should_issue_refresh_token=should_issue_refresh_token,
request_info=request_info,
)
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
await self._address_ratelimiter.ratelimit(
None, request.getClientAddress().host
)
await self._address_ratelimiter.ratelimit(None, ip_address)
result = await self._do_token_login(
login_submission,
should_issue_refresh_token=should_issue_refresh_token,
request_info=request_info,
)
else:
await self._address_ratelimiter.ratelimit(
None, request.getClientAddress().host
)
await self._address_ratelimiter.ratelimit(None, ip_address)
result = await self._do_other_login(
login_submission,
should_issue_refresh_token=should_issue_refresh_token,

View File

@@ -192,7 +192,8 @@ class ThumbnailResource(RestServlet):
respond_404(request)
return
ip_address = request.getClientAddress().host
ip_address = self.auth.get_ip_address_from_request(request)
remote_resp_function = (
self.thumbnailer.select_or_generate_remote_thumbnail
if self.dynamic_thumbnails
@@ -263,7 +264,8 @@ class DownloadResource(RestServlet):
request, media_id, file_name, max_timeout_ms
)
else:
ip_address = request.getClientAddress().host
ip_address = self.auth.get_ip_address_from_request(request)
await self.media_repo.get_remote_media(
request,
server_name,

View File

@@ -329,6 +329,7 @@ class UsernameAvailabilityRestServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self._auth = hs.get_auth()
self.server_name = hs.hostname
self.registration_handler = hs.get_registration_handler()
self.ratelimiter = FederationRateLimiter(
@@ -361,7 +362,7 @@ class UsernameAvailabilityRestServlet(RestServlet):
if self.inhibit_user_in_use_error:
return 200, {"available": True}
ip = request.getClientAddress().host
ip = self._auth.get_ip_address_from_request(request)
with self.ratelimiter.ratelimit(ip) as wait_deferred:
await wait_deferred
@@ -395,6 +396,7 @@ class RegistrationTokenValidityRestServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self._auth = hs.get_auth()
self.store = hs.get_datastores().main
self.ratelimiter = Ratelimiter(
store=self.store,
@@ -403,7 +405,8 @@ class RegistrationTokenValidityRestServlet(RestServlet):
)
async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
await self.ratelimiter.ratelimit(None, (request.getClientAddress().host,))
ip_address = self._auth.get_ip_address_from_request(request)
await self.ratelimiter.ratelimit(None, (ip_address,))
if not self.hs.config.registration.enable_registration:
raise SynapseError(
@@ -456,7 +459,7 @@ class RegisterRestServlet(RestServlet):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
body = parse_json_object_from_request(request)
client_addr = request.getClientAddress().host
client_addr = self.auth.get_ip_address_from_request(request)
await self.ratelimiter.ratelimit(None, client_addr, update=False)
@@ -916,7 +919,7 @@ class RegisterAppServiceOnlyRestServlet(RestServlet):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
body = parse_json_object_from_request(request)
client_addr = request.getClientAddress().host
client_addr = self.auth.get_ip_address_from_request(request)
await self.ratelimiter.ratelimit(None, client_addr, update=False)

View File

@@ -49,6 +49,7 @@ class DownloadResource(RestServlet):
def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
super().__init__()
self._auth = hs.get_auth()
self.media_repo = media_repo
self._is_mine_server_name = hs.is_mine_server_name
@@ -97,7 +98,7 @@ class DownloadResource(RestServlet):
respond_404(request)
return
ip_address = request.getClientAddress().host
ip_address = self._auth.get_ip_address_from_request(request)
await self.media_repo.get_remote_media(
request,
server_name,

View File

@@ -58,6 +58,7 @@ class ThumbnailResource(RestServlet):
):
super().__init__()
self._auth = hs.get_auth()
self.store = hs.get_datastores().main
self.media_repo = media_repo
self.media_storage = media_storage
@@ -120,7 +121,7 @@ class ThumbnailResource(RestServlet):
respond_404(request)
return
ip_address = request.getClientAddress().host
ip_address = self._auth.get_ip_address_from_request(request)
remote_resp_function = (
self.thumbnail_provider.select_or_generate_remote_thumbnail
if self.dynamic_thumbnails