Merge commit '55da8df07' into anoa/dinsic_release_1_31_0
This commit is contained in:
1
changelog.d/9563.misc
Normal file
1
changelog.d/9563.misc
Normal file
@@ -0,0 +1 @@
|
||||
Fix type hints and tests for BlacklistingAgentWrapper and BlacklistingReactorWrapper.
|
||||
1
changelog.d/9587.bugfix
Normal file
1
changelog.d/9587.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Re-Activating account with admin API when local passwords are disabled.
|
||||
1
changelog.d/9590.misc
Normal file
1
changelog.d/9590.misc
Normal file
@@ -0,0 +1 @@
|
||||
Add logging for redis connection setup.
|
||||
1
changelog.d/9591.misc
Normal file
1
changelog.d/9591.misc
Normal file
@@ -0,0 +1 @@
|
||||
Fix incorrect type hints.
|
||||
1
changelog.d/9596.misc
Normal file
1
changelog.d/9596.misc
Normal file
@@ -0,0 +1 @@
|
||||
Improve logging when processing incoming transactions.
|
||||
1
changelog.d/9597.bugfix
Normal file
1
changelog.d/9597.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Fix a bug introduced in Synapse 1.20 which caused incoming federation transactions to stack up, causing slow recovery from outages.
|
||||
@@ -17,6 +17,8 @@
|
||||
"""
|
||||
from typing import Any, List, Optional, Type, Union
|
||||
|
||||
from twisted.internet import protocol
|
||||
|
||||
class RedisProtocol:
|
||||
def publish(self, channel: str, message: bytes): ...
|
||||
async def ping(self) -> None: ...
|
||||
@@ -52,7 +54,7 @@ def lazyConnection(
|
||||
|
||||
class ConnectionHandler: ...
|
||||
|
||||
class RedisFactory:
|
||||
class RedisFactory(protocol.ReconnectingClientFactory):
|
||||
continueTrying: bool
|
||||
handler: RedisProtocol
|
||||
pool: List[RedisProtocol]
|
||||
|
||||
@@ -164,7 +164,7 @@ class Auth:
|
||||
|
||||
async def get_user_by_req(
|
||||
self,
|
||||
request: Request,
|
||||
request: SynapseRequest,
|
||||
allow_guest: bool = False,
|
||||
rights: str = "access",
|
||||
allow_expired: bool = False,
|
||||
|
||||
@@ -113,10 +113,11 @@ class FederationServer(FederationBase):
|
||||
# with FederationHandlerRegistry.
|
||||
hs.get_directory_handler()
|
||||
|
||||
self._federation_ratelimiter = hs.get_federation_ratelimiter()
|
||||
|
||||
self._server_linearizer = Linearizer("fed_server")
|
||||
self._transaction_linearizer = Linearizer("fed_txn_handler")
|
||||
|
||||
# origins that we are currently processing a transaction from.
|
||||
# a dict from origin to txn id.
|
||||
self._active_transactions = {} # type: Dict[str, str]
|
||||
|
||||
# We cache results for transaction with the same ID
|
||||
self._transaction_resp_cache = ResponseCache(
|
||||
@@ -170,6 +171,33 @@ class FederationServer(FederationBase):
|
||||
|
||||
logger.debug("[%s] Got transaction", transaction_id)
|
||||
|
||||
# Reject malformed transactions early: reject if too many PDUs/EDUs
|
||||
if len(transaction.pdus) > 50 or ( # type: ignore
|
||||
hasattr(transaction, "edus") and len(transaction.edus) > 100 # type: ignore
|
||||
):
|
||||
logger.info("Transaction PDU or EDU count too large. Returning 400")
|
||||
return 400, {}
|
||||
|
||||
# we only process one transaction from each origin at a time. We need to do
|
||||
# this check here, rather than in _on_incoming_transaction_inner so that we
|
||||
# don't cache the rejection in _transaction_resp_cache (so that if the txn
|
||||
# arrives again later, we can process it).
|
||||
current_transaction = self._active_transactions.get(origin)
|
||||
if current_transaction and current_transaction != transaction_id:
|
||||
logger.warning(
|
||||
"Received another txn %s from %s while still processing %s",
|
||||
transaction_id,
|
||||
origin,
|
||||
current_transaction,
|
||||
)
|
||||
return 429, {
|
||||
"errcode": Codes.UNKNOWN,
|
||||
"error": "Too many concurrent transactions",
|
||||
}
|
||||
|
||||
# CRITICAL SECTION: we must now not await until we populate _active_transactions
|
||||
# in _on_incoming_transaction_inner.
|
||||
|
||||
# We wrap in a ResponseCache so that we de-duplicate retried
|
||||
# transactions.
|
||||
return await self._transaction_resp_cache.wrap(
|
||||
@@ -183,26 +211,18 @@ class FederationServer(FederationBase):
|
||||
async def _on_incoming_transaction_inner(
|
||||
self, origin: str, transaction: Transaction, request_time: int
|
||||
) -> Tuple[int, Dict[str, Any]]:
|
||||
# Use a linearizer to ensure that transactions from a remote are
|
||||
# processed in order.
|
||||
with await self._transaction_linearizer.queue(origin):
|
||||
# We rate limit here *after* we've queued up the incoming requests,
|
||||
# so that we don't fill up the ratelimiter with blocked requests.
|
||||
#
|
||||
# This is important as the ratelimiter allows N concurrent requests
|
||||
# at a time, and only starts ratelimiting if there are more requests
|
||||
# than that being processed at a time. If we queued up requests in
|
||||
# the linearizer/response cache *after* the ratelimiting then those
|
||||
# queued up requests would count as part of the allowed limit of N
|
||||
# concurrent requests.
|
||||
with self._federation_ratelimiter.ratelimit(origin) as d:
|
||||
await d
|
||||
# CRITICAL SECTION: the first thing we must do (before awaiting) is
|
||||
# add an entry to _active_transactions.
|
||||
assert origin not in self._active_transactions
|
||||
self._active_transactions[origin] = transaction.transaction_id # type: ignore
|
||||
|
||||
result = await self._handle_incoming_transaction(
|
||||
origin, transaction, request_time
|
||||
)
|
||||
|
||||
return result
|
||||
try:
|
||||
result = await self._handle_incoming_transaction(
|
||||
origin, transaction, request_time
|
||||
)
|
||||
return result
|
||||
finally:
|
||||
del self._active_transactions[origin]
|
||||
|
||||
async def _handle_incoming_transaction(
|
||||
self, origin: str, transaction: Transaction, request_time: int
|
||||
@@ -228,19 +248,6 @@ class FederationServer(FederationBase):
|
||||
|
||||
logger.debug("[%s] Transaction is new", transaction.transaction_id) # type: ignore
|
||||
|
||||
# Reject if PDU count > 50 or EDU count > 100
|
||||
if len(transaction.pdus) > 50 or ( # type: ignore
|
||||
hasattr(transaction, "edus") and len(transaction.edus) > 100 # type: ignore
|
||||
):
|
||||
|
||||
logger.info("Transaction PDU or EDU count too large. Returning 400")
|
||||
|
||||
response = {}
|
||||
await self.transaction_actions.set_response(
|
||||
origin, transaction, 400, response
|
||||
)
|
||||
return 400, response
|
||||
|
||||
# We process PDUs and EDUs in parallel. This is important as we don't
|
||||
# want to block things like to device messages from reaching clients
|
||||
# behind the potentially expensive handling of PDUs.
|
||||
@@ -336,34 +343,41 @@ class FederationServer(FederationBase):
|
||||
# impose a limit to avoid going too crazy with ram/cpu.
|
||||
|
||||
async def process_pdus_for_room(room_id: str):
|
||||
logger.debug("Processing PDUs for %s", room_id)
|
||||
try:
|
||||
await self.check_server_matches_acl(origin_host, room_id)
|
||||
except AuthError as e:
|
||||
logger.warning("Ignoring PDUs for room %s from banned server", room_id)
|
||||
for pdu in pdus_by_room[room_id]:
|
||||
event_id = pdu.event_id
|
||||
pdu_results[event_id] = e.error_dict()
|
||||
return
|
||||
with nested_logging_context(room_id):
|
||||
logger.debug("Processing PDUs for %s", room_id)
|
||||
|
||||
for pdu in pdus_by_room[room_id]:
|
||||
event_id = pdu.event_id
|
||||
with pdu_process_time.time():
|
||||
with nested_logging_context(event_id):
|
||||
try:
|
||||
await self._handle_received_pdu(origin, pdu)
|
||||
pdu_results[event_id] = {}
|
||||
except FederationError as e:
|
||||
logger.warning("Error handling PDU %s: %s", event_id, e)
|
||||
pdu_results[event_id] = {"error": str(e)}
|
||||
except Exception as e:
|
||||
f = failure.Failure()
|
||||
pdu_results[event_id] = {"error": str(e)}
|
||||
logger.error(
|
||||
"Failed to handle PDU %s",
|
||||
event_id,
|
||||
exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore
|
||||
)
|
||||
try:
|
||||
await self.check_server_matches_acl(origin_host, room_id)
|
||||
except AuthError as e:
|
||||
logger.warning(
|
||||
"Ignoring PDUs for room %s from banned server", room_id
|
||||
)
|
||||
for pdu in pdus_by_room[room_id]:
|
||||
event_id = pdu.event_id
|
||||
pdu_results[event_id] = e.error_dict()
|
||||
return
|
||||
|
||||
for pdu in pdus_by_room[room_id]:
|
||||
pdu_results[pdu.event_id] = await process_pdu(pdu)
|
||||
|
||||
async def process_pdu(pdu: EventBase) -> JsonDict:
|
||||
event_id = pdu.event_id
|
||||
with pdu_process_time.time():
|
||||
with nested_logging_context(event_id):
|
||||
try:
|
||||
await self._handle_received_pdu(origin, pdu)
|
||||
return {}
|
||||
except FederationError as e:
|
||||
logger.warning("Error handling PDU %s: %s", event_id, e)
|
||||
return {"error": str(e)}
|
||||
except Exception as e:
|
||||
f = failure.Failure()
|
||||
logger.error(
|
||||
"Failed to handle PDU %s",
|
||||
event_id,
|
||||
exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore
|
||||
)
|
||||
return {"error": str(e)}
|
||||
|
||||
await concurrently_execute(
|
||||
process_pdus_for_room, pdus_by_room.keys(), TRANSACTION_CONCURRENCY_LIMIT
|
||||
@@ -942,7 +956,9 @@ class FederationHandlerRegistry:
|
||||
self.edu_handlers = (
|
||||
{}
|
||||
) # type: Dict[str, Callable[[str, dict], Awaitable[None]]]
|
||||
self.query_handlers = {} # type: Dict[str, Callable[[dict], Awaitable[None]]]
|
||||
self.query_handlers = (
|
||||
{}
|
||||
) # type: Dict[str, Callable[[dict], Awaitable[JsonDict]]]
|
||||
|
||||
# Map from type to instance names that we should route EDU handling to.
|
||||
# We randomly choose one instance from the list to route to for each new
|
||||
@@ -976,7 +992,7 @@ class FederationHandlerRegistry:
|
||||
self.edu_handlers[edu_type] = handler
|
||||
|
||||
def register_query_handler(
|
||||
self, query_type: str, handler: Callable[[dict], defer.Deferred]
|
||||
self, query_type: str, handler: Callable[[dict], Awaitable[JsonDict]]
|
||||
):
|
||||
"""Sets the handler callable that will be used to handle an incoming
|
||||
federation query of the given type.
|
||||
@@ -1049,7 +1065,7 @@ class FederationHandlerRegistry:
|
||||
# Oh well, let's just log and move on.
|
||||
logger.warning("No handler registered for EDU type %s", edu_type)
|
||||
|
||||
async def on_query(self, query_type: str, args: dict):
|
||||
async def on_query(self, query_type: str, args: dict) -> JsonDict:
|
||||
handler = self.query_handlers.get(query_type)
|
||||
if handler:
|
||||
return await handler(args)
|
||||
|
||||
@@ -202,7 +202,7 @@ class FederationHandler(BaseHandler):
|
||||
or pdu.internal_metadata.is_outlier()
|
||||
)
|
||||
if already_seen:
|
||||
logger.debug("[%s %s]: Already seen pdu", room_id, event_id)
|
||||
logger.debug("Already seen pdu")
|
||||
return
|
||||
|
||||
# do some initial sanity-checking of the event. In particular, make
|
||||
@@ -211,18 +211,14 @@ class FederationHandler(BaseHandler):
|
||||
try:
|
||||
self._sanity_check_event(pdu)
|
||||
except SynapseError as err:
|
||||
logger.warning(
|
||||
"[%s %s] Received event failed sanity checks", room_id, event_id
|
||||
)
|
||||
logger.warning("Received event failed sanity checks")
|
||||
raise FederationError("ERROR", err.code, err.msg, affected=pdu.event_id)
|
||||
|
||||
# If we are currently in the process of joining this room, then we
|
||||
# queue up events for later processing.
|
||||
if room_id in self.room_queues:
|
||||
logger.info(
|
||||
"[%s %s] Queuing PDU from %s for now: join in progress",
|
||||
room_id,
|
||||
event_id,
|
||||
"Queuing PDU from %s for now: join in progress",
|
||||
origin,
|
||||
)
|
||||
self.room_queues[room_id].append((pdu, origin))
|
||||
@@ -237,9 +233,7 @@ class FederationHandler(BaseHandler):
|
||||
is_in_room = await self.auth.check_host_in_room(room_id, self.server_name)
|
||||
if not is_in_room:
|
||||
logger.info(
|
||||
"[%s %s] Ignoring PDU from %s as we're not in the room",
|
||||
room_id,
|
||||
event_id,
|
||||
"Ignoring PDU from %s as we're not in the room",
|
||||
origin,
|
||||
)
|
||||
return None
|
||||
@@ -251,7 +245,7 @@ class FederationHandler(BaseHandler):
|
||||
# We only backfill backwards to the min depth.
|
||||
min_depth = await self.get_min_depth_for_context(pdu.room_id)
|
||||
|
||||
logger.debug("[%s %s] min_depth: %d", room_id, event_id, min_depth)
|
||||
logger.debug("min_depth: %d", min_depth)
|
||||
|
||||
prevs = set(pdu.prev_event_ids())
|
||||
seen = await self.store.have_events_in_timeline(prevs)
|
||||
@@ -268,17 +262,13 @@ class FederationHandler(BaseHandler):
|
||||
# If we're missing stuff, ensure we only fetch stuff one
|
||||
# at a time.
|
||||
logger.info(
|
||||
"[%s %s] Acquiring room lock to fetch %d missing prev_events: %s",
|
||||
room_id,
|
||||
event_id,
|
||||
"Acquiring room lock to fetch %d missing prev_events: %s",
|
||||
len(missing_prevs),
|
||||
shortstr(missing_prevs),
|
||||
)
|
||||
with (await self._room_pdu_linearizer.queue(pdu.room_id)):
|
||||
logger.info(
|
||||
"[%s %s] Acquired room lock to fetch %d missing prev_events",
|
||||
room_id,
|
||||
event_id,
|
||||
"Acquired room lock to fetch %d missing prev_events",
|
||||
len(missing_prevs),
|
||||
)
|
||||
|
||||
@@ -298,9 +288,7 @@ class FederationHandler(BaseHandler):
|
||||
|
||||
if not prevs - seen:
|
||||
logger.info(
|
||||
"[%s %s] Found all missing prev_events",
|
||||
room_id,
|
||||
event_id,
|
||||
"Found all missing prev_events",
|
||||
)
|
||||
elif missing_prevs:
|
||||
logger.info(
|
||||
@@ -338,9 +326,7 @@ class FederationHandler(BaseHandler):
|
||||
|
||||
if sent_to_us_directly:
|
||||
logger.warning(
|
||||
"[%s %s] Rejecting: failed to fetch %d prev events: %s",
|
||||
room_id,
|
||||
event_id,
|
||||
"Rejecting: failed to fetch %d prev events: %s",
|
||||
len(prevs - seen),
|
||||
shortstr(prevs - seen),
|
||||
)
|
||||
@@ -416,10 +402,7 @@ class FederationHandler(BaseHandler):
|
||||
state = [event_map[e] for e in state_map.values()]
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"[%s %s] Error attempting to resolve state at missing "
|
||||
"prev_events",
|
||||
room_id,
|
||||
event_id,
|
||||
"Error attempting to resolve state at missing " "prev_events",
|
||||
exc_info=True,
|
||||
)
|
||||
raise FederationError(
|
||||
@@ -456,9 +439,7 @@ class FederationHandler(BaseHandler):
|
||||
latest |= seen
|
||||
|
||||
logger.info(
|
||||
"[%s %s]: Requesting missing events between %s and %s",
|
||||
room_id,
|
||||
event_id,
|
||||
"Requesting missing events between %s and %s",
|
||||
shortstr(latest),
|
||||
event_id,
|
||||
)
|
||||
@@ -525,15 +506,11 @@ class FederationHandler(BaseHandler):
|
||||
# We failed to get the missing events, but since we need to handle
|
||||
# the case of `get_missing_events` not returning the necessary
|
||||
# events anyway, it is safe to simply log the error and continue.
|
||||
logger.warning(
|
||||
"[%s %s]: Failed to get prev_events: %s", room_id, event_id, e
|
||||
)
|
||||
logger.warning("Failed to get prev_events: %s", e)
|
||||
return
|
||||
|
||||
logger.info(
|
||||
"[%s %s]: Got %d prev_events: %s",
|
||||
room_id,
|
||||
event_id,
|
||||
"Got %d prev_events: %s",
|
||||
len(missing_events),
|
||||
shortstr(missing_events),
|
||||
)
|
||||
@@ -544,9 +521,7 @@ class FederationHandler(BaseHandler):
|
||||
|
||||
for ev in missing_events:
|
||||
logger.info(
|
||||
"[%s %s] Handling received prev_event %s",
|
||||
room_id,
|
||||
event_id,
|
||||
"Handling received prev_event %s",
|
||||
ev.event_id,
|
||||
)
|
||||
with nested_logging_context(ev.event_id):
|
||||
@@ -555,9 +530,7 @@ class FederationHandler(BaseHandler):
|
||||
except FederationError as e:
|
||||
if e.code == 403:
|
||||
logger.warning(
|
||||
"[%s %s] Received prev_event %s failed history check.",
|
||||
room_id,
|
||||
event_id,
|
||||
"Received prev_event %s failed history check.",
|
||||
ev.event_id,
|
||||
)
|
||||
else:
|
||||
@@ -709,10 +682,7 @@ class FederationHandler(BaseHandler):
|
||||
(ie, we are missing one or more prev_events), the resolved state at the
|
||||
event
|
||||
"""
|
||||
room_id = event.room_id
|
||||
event_id = event.event_id
|
||||
|
||||
logger.debug("[%s %s] Processing event: %s", room_id, event_id, event)
|
||||
logger.debug("Processing event: %s", event)
|
||||
|
||||
try:
|
||||
await self._handle_new_event(origin, event, state=state)
|
||||
|
||||
@@ -34,6 +34,7 @@ from pymacaroons.exceptions import (
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from twisted.web.client import readBody
|
||||
from twisted.web.http_headers import Headers
|
||||
|
||||
from synapse.config import ConfigError
|
||||
from synapse.config.oidc_config import (
|
||||
@@ -538,7 +539,7 @@ class OidcProvider:
|
||||
"""
|
||||
metadata = await self.load_metadata()
|
||||
token_endpoint = metadata.get("token_endpoint")
|
||||
headers = {
|
||||
raw_headers = {
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
"User-Agent": self._http_client.user_agent,
|
||||
"Accept": "application/json",
|
||||
@@ -552,10 +553,10 @@ class OidcProvider:
|
||||
body = urlencode(args, True)
|
||||
|
||||
# Fill the body/headers with credentials
|
||||
uri, headers, body = self._client_auth.prepare(
|
||||
method="POST", uri=token_endpoint, headers=headers, body=body
|
||||
uri, raw_headers, body = self._client_auth.prepare(
|
||||
method="POST", uri=token_endpoint, headers=raw_headers, body=body
|
||||
)
|
||||
headers = {k: [v] for (k, v) in headers.items()}
|
||||
headers = Headers({k: [v] for (k, v) in raw_headers.items()})
|
||||
|
||||
# Do the actual request
|
||||
# We're not using the SimpleHttpClient util methods as we don't want to
|
||||
|
||||
@@ -39,6 +39,7 @@ from zope.interface import implementer, provider
|
||||
from OpenSSL import SSL
|
||||
from OpenSSL.SSL import VERIFY_NONE
|
||||
from twisted.internet import defer, error as twisted_error, protocol, ssl
|
||||
from twisted.internet.address import IPv4Address, IPv6Address
|
||||
from twisted.internet.interfaces import (
|
||||
IAddress,
|
||||
IHostResolution,
|
||||
@@ -56,7 +57,13 @@ from twisted.web.client import (
|
||||
)
|
||||
from twisted.web.http import PotentialDataLoss
|
||||
from twisted.web.http_headers import Headers
|
||||
from twisted.web.iweb import UNKNOWN_LENGTH, IAgent, IBodyProducer, IResponse
|
||||
from twisted.web.iweb import (
|
||||
UNKNOWN_LENGTH,
|
||||
IAgent,
|
||||
IBodyProducer,
|
||||
IPolicyForHTTPS,
|
||||
IResponse,
|
||||
)
|
||||
|
||||
from synapse.api.errors import Codes, HttpResponseException, SynapseError
|
||||
from synapse.http import QuieterFileBodyProducer, RequestTimedOutError, redact_uri
|
||||
@@ -151,16 +158,17 @@ class _IPBlacklistingResolver:
|
||||
def resolveHostName(
|
||||
self, recv: IResolutionReceiver, hostname: str, portNumber: int = 0
|
||||
) -> IResolutionReceiver:
|
||||
|
||||
r = recv()
|
||||
addresses = [] # type: List[IAddress]
|
||||
|
||||
def _callback() -> None:
|
||||
r.resolutionBegan(None)
|
||||
|
||||
has_bad_ip = False
|
||||
for i in addresses:
|
||||
ip_address = IPAddress(i.host)
|
||||
for address in addresses:
|
||||
# We only expect IPv4 and IPv6 addresses since only A/AAAA lookups
|
||||
# should go through this path.
|
||||
if not isinstance(address, (IPv4Address, IPv6Address)):
|
||||
continue
|
||||
|
||||
ip_address = IPAddress(address.host)
|
||||
|
||||
if check_against_blacklist(
|
||||
ip_address, self._ip_whitelist, self._ip_blacklist
|
||||
@@ -175,15 +183,15 @@ class _IPBlacklistingResolver:
|
||||
# request, but all we can really do from here is claim that there were no
|
||||
# valid results.
|
||||
if not has_bad_ip:
|
||||
for i in addresses:
|
||||
r.addressResolved(i)
|
||||
r.resolutionComplete()
|
||||
for address in addresses:
|
||||
recv.addressResolved(address)
|
||||
recv.resolutionComplete()
|
||||
|
||||
@provider(IResolutionReceiver)
|
||||
class EndpointReceiver:
|
||||
@staticmethod
|
||||
def resolutionBegan(resolutionInProgress: IHostResolution) -> None:
|
||||
pass
|
||||
recv.resolutionBegan(resolutionInProgress)
|
||||
|
||||
@staticmethod
|
||||
def addressResolved(address: IAddress) -> None:
|
||||
@@ -197,7 +205,7 @@ class _IPBlacklistingResolver:
|
||||
EndpointReceiver, hostname, portNumber=portNumber
|
||||
)
|
||||
|
||||
return r
|
||||
return recv
|
||||
|
||||
|
||||
@implementer(ISynapseReactor)
|
||||
@@ -346,7 +354,7 @@ class SimpleHttpClient:
|
||||
contextFactory=self.hs.get_http_client_context_factory(),
|
||||
pool=pool,
|
||||
use_proxy=use_proxy,
|
||||
)
|
||||
) # type: IAgent
|
||||
|
||||
if self._ip_blacklist:
|
||||
# If we have an IP blacklist, we then install the blacklisting Agent
|
||||
@@ -868,6 +876,7 @@ def encode_query_args(args: Optional[Mapping[str, Union[str, List[str]]]]) -> by
|
||||
return query_str.encode("utf8")
|
||||
|
||||
|
||||
@implementer(IPolicyForHTTPS)
|
||||
class InsecureInterceptableContextFactory(ssl.ContextFactory):
|
||||
"""
|
||||
Factory for PyOpenSSL SSL contexts which accepts any certificate for any domain.
|
||||
|
||||
@@ -32,8 +32,9 @@ from twisted.internet.endpoints import (
|
||||
TCP4ClientEndpoint,
|
||||
TCP6ClientEndpoint,
|
||||
)
|
||||
from twisted.internet.interfaces import IPushProducer, IStreamClientEndpoint, ITransport
|
||||
from twisted.internet.interfaces import IPushProducer, IStreamClientEndpoint
|
||||
from twisted.internet.protocol import Factory, Protocol
|
||||
from twisted.internet.tcp import Connection
|
||||
from twisted.python.failure import Failure
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -52,7 +53,9 @@ class LogProducer:
|
||||
format: A callable to format the log record to a string.
|
||||
"""
|
||||
|
||||
transport = attr.ib(type=ITransport)
|
||||
# This is essentially ITCPTransport, but that is missing certain fields
|
||||
# (connected and registerProducer) which are part of the implementation.
|
||||
transport = attr.ib(type=Connection)
|
||||
_format = attr.ib(type=Callable[[logging.LogRecord], str])
|
||||
_buffer = attr.ib(type=deque)
|
||||
_paused = attr.ib(default=False, type=bool, init=False)
|
||||
@@ -149,8 +152,6 @@ class RemoteHandler(logging.Handler):
|
||||
if self._connection_waiter:
|
||||
return
|
||||
|
||||
self._connection_waiter = self._service.whenConnected(failAfterFailures=1)
|
||||
|
||||
def fail(failure: Failure) -> None:
|
||||
# If the Deferred was cancelled (e.g. during shutdown) do not try to
|
||||
# reconnect (this will cause an infinite loop of errors).
|
||||
@@ -163,9 +164,13 @@ class RemoteHandler(logging.Handler):
|
||||
self._connect()
|
||||
|
||||
def writer(result: Protocol) -> None:
|
||||
# Force recognising transport as a Connection and not the more
|
||||
# generic ITransport.
|
||||
transport = result.transport # type: Connection # type: ignore
|
||||
|
||||
# We have a connection. If we already have a producer, and its
|
||||
# transport is the same, just trigger a resumeProducing.
|
||||
if self._producer and result.transport is self._producer.transport:
|
||||
if self._producer and transport is self._producer.transport:
|
||||
self._producer.resumeProducing()
|
||||
self._connection_waiter = None
|
||||
return
|
||||
@@ -177,14 +182,16 @@ class RemoteHandler(logging.Handler):
|
||||
# Make a new producer and start it.
|
||||
self._producer = LogProducer(
|
||||
buffer=self._buffer,
|
||||
transport=result.transport,
|
||||
transport=transport,
|
||||
format=self.format,
|
||||
)
|
||||
result.transport.registerProducer(self._producer, True)
|
||||
transport.registerProducer(self._producer, True)
|
||||
self._producer.resumeProducing()
|
||||
self._connection_waiter = None
|
||||
|
||||
self._connection_waiter.addCallbacks(writer, fail)
|
||||
deferred = self._service.whenConnected(failAfterFailures=1) # type: Deferred
|
||||
deferred.addCallbacks(writer, fail)
|
||||
self._connection_waiter = deferred
|
||||
|
||||
def _handle_pressure(self) -> None:
|
||||
"""
|
||||
|
||||
@@ -16,8 +16,8 @@
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional
|
||||
|
||||
from twisted.internet.base import DelayedCall
|
||||
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
|
||||
from twisted.internet.interfaces import IDelayedCall
|
||||
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.push import Pusher, PusherConfig, ThrottleParams
|
||||
@@ -66,7 +66,7 @@ class EmailPusher(Pusher):
|
||||
|
||||
self.store = self.hs.get_datastore()
|
||||
self.email = pusher_config.pushkey
|
||||
self.timed_call = None # type: Optional[DelayedCall]
|
||||
self.timed_call = None # type: Optional[IDelayedCall]
|
||||
self.throttle_params = {} # type: Dict[str, ThrottleParams]
|
||||
self._inited = False
|
||||
|
||||
|
||||
@@ -48,7 +48,7 @@ from synapse.replication.tcp.commands import (
|
||||
UserIpCommand,
|
||||
UserSyncCommand,
|
||||
)
|
||||
from synapse.replication.tcp.protocol import AbstractConnection
|
||||
from synapse.replication.tcp.protocol import IReplicationConnection
|
||||
from synapse.replication.tcp.streams import (
|
||||
STREAMS_MAP,
|
||||
AccountDataStream,
|
||||
@@ -82,7 +82,7 @@ user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache"
|
||||
|
||||
# the type of the entries in _command_queues_by_stream
|
||||
_StreamCommandQueue = Deque[
|
||||
Tuple[Union[RdataCommand, PositionCommand], AbstractConnection]
|
||||
Tuple[Union[RdataCommand, PositionCommand], IReplicationConnection]
|
||||
]
|
||||
|
||||
|
||||
@@ -174,7 +174,7 @@ class ReplicationCommandHandler:
|
||||
|
||||
# The currently connected connections. (The list of places we need to send
|
||||
# outgoing replication commands to.)
|
||||
self._connections = [] # type: List[AbstractConnection]
|
||||
self._connections = [] # type: List[IReplicationConnection]
|
||||
|
||||
LaterGauge(
|
||||
"synapse_replication_tcp_resource_total_connections",
|
||||
@@ -197,7 +197,7 @@ class ReplicationCommandHandler:
|
||||
|
||||
# For each connection, the incoming stream names that have received a POSITION
|
||||
# from that connection.
|
||||
self._streams_by_connection = {} # type: Dict[AbstractConnection, Set[str]]
|
||||
self._streams_by_connection = {} # type: Dict[IReplicationConnection, Set[str]]
|
||||
|
||||
LaterGauge(
|
||||
"synapse_replication_tcp_command_queue",
|
||||
@@ -220,7 +220,7 @@ class ReplicationCommandHandler:
|
||||
self._server_notices_sender = hs.get_server_notices_sender()
|
||||
|
||||
def _add_command_to_stream_queue(
|
||||
self, conn: AbstractConnection, cmd: Union[RdataCommand, PositionCommand]
|
||||
self, conn: IReplicationConnection, cmd: Union[RdataCommand, PositionCommand]
|
||||
) -> None:
|
||||
"""Queue the given received command for processing
|
||||
|
||||
@@ -267,7 +267,7 @@ class ReplicationCommandHandler:
|
||||
async def _process_command(
|
||||
self,
|
||||
cmd: Union[PositionCommand, RdataCommand],
|
||||
conn: AbstractConnection,
|
||||
conn: IReplicationConnection,
|
||||
stream_name: str,
|
||||
) -> None:
|
||||
if isinstance(cmd, PositionCommand):
|
||||
@@ -321,10 +321,10 @@ class ReplicationCommandHandler:
|
||||
"""Get a list of streams that this instances replicates."""
|
||||
return self._streams_to_replicate
|
||||
|
||||
def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
|
||||
def on_REPLICATE(self, conn: IReplicationConnection, cmd: ReplicateCommand):
|
||||
self.send_positions_to_connection(conn)
|
||||
|
||||
def send_positions_to_connection(self, conn: AbstractConnection):
|
||||
def send_positions_to_connection(self, conn: IReplicationConnection):
|
||||
"""Send current position of all streams this process is source of to
|
||||
the connection.
|
||||
"""
|
||||
@@ -347,7 +347,7 @@ class ReplicationCommandHandler:
|
||||
)
|
||||
|
||||
def on_USER_SYNC(
|
||||
self, conn: AbstractConnection, cmd: UserSyncCommand
|
||||
self, conn: IReplicationConnection, cmd: UserSyncCommand
|
||||
) -> Optional[Awaitable[None]]:
|
||||
user_sync_counter.inc()
|
||||
|
||||
@@ -359,21 +359,23 @@ class ReplicationCommandHandler:
|
||||
return None
|
||||
|
||||
def on_CLEAR_USER_SYNC(
|
||||
self, conn: AbstractConnection, cmd: ClearUserSyncsCommand
|
||||
self, conn: IReplicationConnection, cmd: ClearUserSyncsCommand
|
||||
) -> Optional[Awaitable[None]]:
|
||||
if self._is_master:
|
||||
return self._presence_handler.update_external_syncs_clear(cmd.instance_id)
|
||||
else:
|
||||
return None
|
||||
|
||||
def on_FEDERATION_ACK(self, conn: AbstractConnection, cmd: FederationAckCommand):
|
||||
def on_FEDERATION_ACK(
|
||||
self, conn: IReplicationConnection, cmd: FederationAckCommand
|
||||
):
|
||||
federation_ack_counter.inc()
|
||||
|
||||
if self._federation_sender:
|
||||
self._federation_sender.federation_ack(cmd.instance_name, cmd.token)
|
||||
|
||||
def on_USER_IP(
|
||||
self, conn: AbstractConnection, cmd: UserIpCommand
|
||||
self, conn: IReplicationConnection, cmd: UserIpCommand
|
||||
) -> Optional[Awaitable[None]]:
|
||||
user_ip_cache_counter.inc()
|
||||
|
||||
@@ -395,7 +397,7 @@ class ReplicationCommandHandler:
|
||||
assert self._server_notices_sender is not None
|
||||
await self._server_notices_sender.on_user_ip(cmd.user_id)
|
||||
|
||||
def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
|
||||
def on_RDATA(self, conn: IReplicationConnection, cmd: RdataCommand):
|
||||
if cmd.instance_name == self._instance_name:
|
||||
# Ignore RDATA that are just our own echoes
|
||||
return
|
||||
@@ -412,7 +414,7 @@ class ReplicationCommandHandler:
|
||||
self._add_command_to_stream_queue(conn, cmd)
|
||||
|
||||
async def _process_rdata(
|
||||
self, stream_name: str, conn: AbstractConnection, cmd: RdataCommand
|
||||
self, stream_name: str, conn: IReplicationConnection, cmd: RdataCommand
|
||||
) -> None:
|
||||
"""Process an RDATA command
|
||||
|
||||
@@ -486,7 +488,7 @@ class ReplicationCommandHandler:
|
||||
stream_name, instance_name, token, rows
|
||||
)
|
||||
|
||||
def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
|
||||
def on_POSITION(self, conn: IReplicationConnection, cmd: PositionCommand):
|
||||
if cmd.instance_name == self._instance_name:
|
||||
# Ignore POSITION that are just our own echoes
|
||||
return
|
||||
@@ -496,7 +498,7 @@ class ReplicationCommandHandler:
|
||||
self._add_command_to_stream_queue(conn, cmd)
|
||||
|
||||
async def _process_position(
|
||||
self, stream_name: str, conn: AbstractConnection, cmd: PositionCommand
|
||||
self, stream_name: str, conn: IReplicationConnection, cmd: PositionCommand
|
||||
) -> None:
|
||||
"""Process a POSITION command
|
||||
|
||||
@@ -553,7 +555,9 @@ class ReplicationCommandHandler:
|
||||
|
||||
self._streams_by_connection.setdefault(conn, set()).add(stream_name)
|
||||
|
||||
def on_REMOTE_SERVER_UP(self, conn: AbstractConnection, cmd: RemoteServerUpCommand):
|
||||
def on_REMOTE_SERVER_UP(
|
||||
self, conn: IReplicationConnection, cmd: RemoteServerUpCommand
|
||||
):
|
||||
""""Called when get a new REMOTE_SERVER_UP command."""
|
||||
self._replication_data_handler.on_remote_server_up(cmd.data)
|
||||
|
||||
@@ -576,7 +580,7 @@ class ReplicationCommandHandler:
|
||||
# between two instances, but that is not currently supported).
|
||||
self.send_command(cmd, ignore_conn=conn)
|
||||
|
||||
def new_connection(self, connection: AbstractConnection):
|
||||
def new_connection(self, connection: IReplicationConnection):
|
||||
"""Called when we have a new connection."""
|
||||
self._connections.append(connection)
|
||||
|
||||
@@ -603,7 +607,7 @@ class ReplicationCommandHandler:
|
||||
UserSyncCommand(self._instance_id, user_id, True, now)
|
||||
)
|
||||
|
||||
def lost_connection(self, connection: AbstractConnection):
|
||||
def lost_connection(self, connection: IReplicationConnection):
|
||||
"""Called when a connection is closed/lost."""
|
||||
# we no longer need _streams_by_connection for this connection.
|
||||
streams = self._streams_by_connection.pop(connection, None)
|
||||
@@ -624,7 +628,7 @@ class ReplicationCommandHandler:
|
||||
return bool(self._connections)
|
||||
|
||||
def send_command(
|
||||
self, cmd: Command, ignore_conn: Optional[AbstractConnection] = None
|
||||
self, cmd: Command, ignore_conn: Optional[IReplicationConnection] = None
|
||||
):
|
||||
"""Send a command to all connected connections.
|
||||
|
||||
|
||||
@@ -46,7 +46,6 @@ indicate which side is sending, these are *not* included on the wire::
|
||||
> ERROR server stopping
|
||||
* connection closed by server *
|
||||
"""
|
||||
import abc
|
||||
import fcntl
|
||||
import logging
|
||||
import struct
|
||||
@@ -54,6 +53,7 @@ from inspect import isawaitable
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from prometheus_client import Counter
|
||||
from zope.interface import Interface, implementer
|
||||
|
||||
from twisted.internet import task
|
||||
from twisted.protocols.basic import LineOnlyReceiver
|
||||
@@ -121,6 +121,14 @@ class ConnectionStates:
|
||||
CLOSED = "closed"
|
||||
|
||||
|
||||
class IReplicationConnection(Interface):
|
||||
"""An interface for replication connections."""
|
||||
|
||||
def send_command(cmd: Command):
|
||||
"""Send the command down the connection"""
|
||||
|
||||
|
||||
@implementer(IReplicationConnection)
|
||||
class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
||||
"""Base replication protocol shared between client and server.
|
||||
|
||||
@@ -495,20 +503,6 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
||||
self.send_command(ReplicateCommand())
|
||||
|
||||
|
||||
class AbstractConnection(abc.ABC):
|
||||
"""An interface for replication connections."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def send_command(self, cmd: Command):
|
||||
"""Send the command down the connection"""
|
||||
pass
|
||||
|
||||
|
||||
# This tells python that `BaseReplicationStreamProtocol` implements the
|
||||
# interface.
|
||||
AbstractConnection.register(BaseReplicationStreamProtocol)
|
||||
|
||||
|
||||
# The following simply registers metrics for the replication connections
|
||||
|
||||
pending_commands = LaterGauge(
|
||||
|
||||
@@ -19,6 +19,11 @@ from typing import TYPE_CHECKING, Generic, Optional, Type, TypeVar, cast
|
||||
|
||||
import attr
|
||||
import txredisapi
|
||||
from zope.interface import implementer
|
||||
|
||||
from twisted.internet.address import IPv4Address, IPv6Address
|
||||
from twisted.internet.interfaces import IAddress, IConnector
|
||||
from twisted.python.failure import Failure
|
||||
|
||||
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
|
||||
from synapse.metrics.background_process_metrics import (
|
||||
@@ -32,7 +37,7 @@ from synapse.replication.tcp.commands import (
|
||||
parse_command_from_line,
|
||||
)
|
||||
from synapse.replication.tcp.protocol import (
|
||||
AbstractConnection,
|
||||
IReplicationConnection,
|
||||
tcp_inbound_commands_counter,
|
||||
tcp_outbound_commands_counter,
|
||||
)
|
||||
@@ -62,7 +67,8 @@ class ConstantProperty(Generic[T, V]):
|
||||
pass
|
||||
|
||||
|
||||
class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
|
||||
@implementer(IReplicationConnection)
|
||||
class RedisSubscriber(txredisapi.SubscriberProtocol):
|
||||
"""Connection to redis subscribed to replication stream.
|
||||
|
||||
This class fulfils two functions:
|
||||
@@ -71,7 +77,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
|
||||
connection, parsing *incoming* messages into replication commands, and passing them
|
||||
to `ReplicationCommandHandler`
|
||||
|
||||
(b) it implements the AbstractConnection API, where it sends *outgoing* commands
|
||||
(b) it implements the IReplicationConnection API, where it sends *outgoing* commands
|
||||
onto outbound_redis_connection.
|
||||
|
||||
Due to the vagaries of `txredisapi` we don't want to have a custom
|
||||
@@ -253,6 +259,37 @@ class SynapseRedisFactory(txredisapi.RedisFactory):
|
||||
except Exception:
|
||||
logger.warning("Failed to send ping to a redis connection")
|
||||
|
||||
# ReconnectingClientFactory has some logging (if you enable `self.noisy`), but
|
||||
# it's rubbish. We add our own here.
|
||||
|
||||
def startedConnecting(self, connector: IConnector):
|
||||
logger.info(
|
||||
"Connecting to redis server %s", format_address(connector.getDestination())
|
||||
)
|
||||
super().startedConnecting(connector)
|
||||
|
||||
def clientConnectionFailed(self, connector: IConnector, reason: Failure):
|
||||
logger.info(
|
||||
"Connection to redis server %s failed: %s",
|
||||
format_address(connector.getDestination()),
|
||||
reason.value,
|
||||
)
|
||||
super().clientConnectionFailed(connector, reason)
|
||||
|
||||
def clientConnectionLost(self, connector: IConnector, reason: Failure):
|
||||
logger.info(
|
||||
"Connection to redis server %s lost: %s",
|
||||
format_address(connector.getDestination()),
|
||||
reason.value,
|
||||
)
|
||||
super().clientConnectionLost(connector, reason)
|
||||
|
||||
|
||||
def format_address(address: IAddress) -> str:
|
||||
if isinstance(address, (IPv4Address, IPv6Address)):
|
||||
return "%s:%i" % (address.host, address.port)
|
||||
return str(address)
|
||||
|
||||
|
||||
class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):
|
||||
"""This is a reconnecting factory that connects to redis and immediately
|
||||
|
||||
@@ -15,10 +15,9 @@
|
||||
|
||||
import re
|
||||
|
||||
import twisted.web.server
|
||||
|
||||
import synapse.api.auth
|
||||
from synapse.api.auth import Auth
|
||||
from synapse.api.errors import AuthError
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.types import UserID
|
||||
|
||||
|
||||
@@ -37,13 +36,11 @@ def admin_patterns(path_regex: str, version: str = "v1"):
|
||||
return patterns
|
||||
|
||||
|
||||
async def assert_requester_is_admin(
|
||||
auth: synapse.api.auth.Auth, request: twisted.web.server.Request
|
||||
) -> None:
|
||||
async def assert_requester_is_admin(auth: Auth, request: SynapseRequest) -> None:
|
||||
"""Verify that the requester is an admin user
|
||||
|
||||
Args:
|
||||
auth: api.auth.Auth singleton
|
||||
auth: Auth singleton
|
||||
request: incoming request
|
||||
|
||||
Raises:
|
||||
@@ -53,11 +50,11 @@ async def assert_requester_is_admin(
|
||||
await assert_user_is_admin(auth, requester.user)
|
||||
|
||||
|
||||
async def assert_user_is_admin(auth: synapse.api.auth.Auth, user_id: UserID) -> None:
|
||||
async def assert_user_is_admin(auth: Auth, user_id: UserID) -> None:
|
||||
"""Verify that the given user is an admin user
|
||||
|
||||
Args:
|
||||
auth: api.auth.Auth singleton
|
||||
auth: Auth singleton
|
||||
user_id: user to check
|
||||
|
||||
Raises:
|
||||
|
||||
@@ -17,10 +17,9 @@
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from twisted.web.server import Request
|
||||
|
||||
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
|
||||
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.rest.admin._base import (
|
||||
admin_patterns,
|
||||
assert_requester_is_admin,
|
||||
@@ -50,7 +49,9 @@ class QuarantineMediaInRoom(RestServlet):
|
||||
self.store = hs.get_datastore()
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
async def on_POST(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
|
||||
async def on_POST(
|
||||
self, request: SynapseRequest, room_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
await assert_user_is_admin(self.auth, requester.user)
|
||||
|
||||
@@ -75,7 +76,9 @@ class QuarantineMediaByUser(RestServlet):
|
||||
self.store = hs.get_datastore()
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
async def on_POST(self, request: Request, user_id: str) -> Tuple[int, JsonDict]:
|
||||
async def on_POST(
|
||||
self, request: SynapseRequest, user_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
await assert_user_is_admin(self.auth, requester.user)
|
||||
|
||||
@@ -103,7 +106,7 @@ class QuarantineMediaByID(RestServlet):
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
async def on_POST(
|
||||
self, request: Request, server_name: str, media_id: str
|
||||
self, request: SynapseRequest, server_name: str, media_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
await assert_user_is_admin(self.auth, requester.user)
|
||||
@@ -127,7 +130,9 @@ class ProtectMediaByID(RestServlet):
|
||||
self.store = hs.get_datastore()
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
async def on_POST(self, request: Request, media_id: str) -> Tuple[int, JsonDict]:
|
||||
async def on_POST(
|
||||
self, request: SynapseRequest, media_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
await assert_user_is_admin(self.auth, requester.user)
|
||||
|
||||
@@ -148,7 +153,9 @@ class ListMediaInRoom(RestServlet):
|
||||
self.store = hs.get_datastore()
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
async def on_GET(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
|
||||
async def on_GET(
|
||||
self, request: SynapseRequest, room_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
is_admin = await self.auth.is_server_admin(requester.user)
|
||||
if not is_admin:
|
||||
@@ -166,7 +173,7 @@ class PurgeMediaCacheRestServlet(RestServlet):
|
||||
self.media_repository = hs.get_media_repository()
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
|
||||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
|
||||
before_ts = parse_integer(request, "before_ts", required=True)
|
||||
@@ -189,7 +196,7 @@ class DeleteMediaByID(RestServlet):
|
||||
self.media_repository = hs.get_media_repository()
|
||||
|
||||
async def on_DELETE(
|
||||
self, request: Request, server_name: str, media_id: str
|
||||
self, request: SynapseRequest, server_name: str, media_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
|
||||
@@ -218,7 +225,9 @@ class DeleteMediaByDateSize(RestServlet):
|
||||
self.server_name = hs.hostname
|
||||
self.media_repository = hs.get_media_repository()
|
||||
|
||||
async def on_POST(self, request: Request, server_name: str) -> Tuple[int, JsonDict]:
|
||||
async def on_POST(
|
||||
self, request: SynapseRequest, server_name: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
|
||||
before_ts = parse_integer(request, "before_ts", required=True)
|
||||
|
||||
@@ -269,7 +269,10 @@ class UserRestServletV2(RestServlet):
|
||||
target_user.to_string(), False, requester, by_admin=True
|
||||
)
|
||||
elif not deactivate and user["deactivated"]:
|
||||
if "password" not in body:
|
||||
if (
|
||||
"password" not in body
|
||||
and self.hs.config.password_localdb_enabled
|
||||
):
|
||||
raise SynapseError(
|
||||
400, "Must provide a password to re-activate an account."
|
||||
)
|
||||
|
||||
@@ -32,6 +32,7 @@ from synapse.http.servlet import (
|
||||
assert_params_in_dict,
|
||||
parse_json_object_from_request,
|
||||
)
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.types import GroupID, JsonDict
|
||||
|
||||
from ._base import client_patterns
|
||||
@@ -70,7 +71,9 @@ class GroupServlet(RestServlet):
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@_validate_group_id
|
||||
async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
|
||||
async def on_GET(
|
||||
self, request: SynapseRequest, group_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
@@ -81,7 +84,9 @@ class GroupServlet(RestServlet):
|
||||
return 200, group_description
|
||||
|
||||
@_validate_group_id
|
||||
async def on_POST(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
|
||||
async def on_POST(
|
||||
self, request: SynapseRequest, group_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
@@ -111,7 +116,9 @@ class GroupSummaryServlet(RestServlet):
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@_validate_group_id
|
||||
async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
|
||||
async def on_GET(
|
||||
self, request: SynapseRequest, group_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
@@ -144,7 +151,11 @@ class GroupSummaryRoomsCatServlet(RestServlet):
|
||||
|
||||
@_validate_group_id
|
||||
async def on_PUT(
|
||||
self, request: Request, group_id: str, category_id: Optional[str], room_id: str
|
||||
self,
|
||||
request: SynapseRequest,
|
||||
group_id: str,
|
||||
category_id: Optional[str],
|
||||
room_id: str,
|
||||
):
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
@@ -176,7 +187,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
|
||||
|
||||
@_validate_group_id
|
||||
async def on_DELETE(
|
||||
self, request: Request, group_id: str, category_id: str, room_id: str
|
||||
self, request: SynapseRequest, group_id: str, category_id: str, room_id: str
|
||||
):
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
@@ -206,7 +217,7 @@ class GroupCategoryServlet(RestServlet):
|
||||
|
||||
@_validate_group_id
|
||||
async def on_GET(
|
||||
self, request: Request, group_id: str, category_id: str
|
||||
self, request: SynapseRequest, group_id: str, category_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
requester_user_id = requester.user.to_string()
|
||||
@@ -219,7 +230,7 @@ class GroupCategoryServlet(RestServlet):
|
||||
|
||||
@_validate_group_id
|
||||
async def on_PUT(
|
||||
self, request: Request, group_id: str, category_id: str
|
||||
self, request: SynapseRequest, group_id: str, category_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
@@ -247,7 +258,7 @@ class GroupCategoryServlet(RestServlet):
|
||||
|
||||
@_validate_group_id
|
||||
async def on_DELETE(
|
||||
self, request: Request, group_id: str, category_id: str
|
||||
self, request: SynapseRequest, group_id: str, category_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
@@ -274,7 +285,9 @@ class GroupCategoriesServlet(RestServlet):
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@_validate_group_id
|
||||
async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
|
||||
async def on_GET(
|
||||
self, request: SynapseRequest, group_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
@@ -298,7 +311,7 @@ class GroupRoleServlet(RestServlet):
|
||||
|
||||
@_validate_group_id
|
||||
async def on_GET(
|
||||
self, request: Request, group_id: str, role_id: str
|
||||
self, request: SynapseRequest, group_id: str, role_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
requester_user_id = requester.user.to_string()
|
||||
@@ -311,7 +324,7 @@ class GroupRoleServlet(RestServlet):
|
||||
|
||||
@_validate_group_id
|
||||
async def on_PUT(
|
||||
self, request: Request, group_id: str, role_id: str
|
||||
self, request: SynapseRequest, group_id: str, role_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
@@ -339,7 +352,7 @@ class GroupRoleServlet(RestServlet):
|
||||
|
||||
@_validate_group_id
|
||||
async def on_DELETE(
|
||||
self, request: Request, group_id: str, role_id: str
|
||||
self, request: SynapseRequest, group_id: str, role_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
@@ -366,7 +379,9 @@ class GroupRolesServlet(RestServlet):
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@_validate_group_id
|
||||
async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
|
||||
async def on_GET(
|
||||
self, request: SynapseRequest, group_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
@@ -399,7 +414,11 @@ class GroupSummaryUsersRoleServlet(RestServlet):
|
||||
|
||||
@_validate_group_id
|
||||
async def on_PUT(
|
||||
self, request: Request, group_id: str, role_id: Optional[str], user_id: str
|
||||
self,
|
||||
request: SynapseRequest,
|
||||
group_id: str,
|
||||
role_id: Optional[str],
|
||||
user_id: str,
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
@@ -431,7 +450,7 @@ class GroupSummaryUsersRoleServlet(RestServlet):
|
||||
|
||||
@_validate_group_id
|
||||
async def on_DELETE(
|
||||
self, request: Request, group_id: str, role_id: str, user_id: str
|
||||
self, request: SynapseRequest, group_id: str, role_id: str, user_id: str
|
||||
):
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
@@ -458,7 +477,9 @@ class GroupRoomServlet(RestServlet):
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@_validate_group_id
|
||||
async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
|
||||
async def on_GET(
|
||||
self, request: SynapseRequest, group_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
@@ -481,7 +502,9 @@ class GroupUsersServlet(RestServlet):
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@_validate_group_id
|
||||
async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
|
||||
async def on_GET(
|
||||
self, request: SynapseRequest, group_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
@@ -504,7 +527,9 @@ class GroupInvitedUsersServlet(RestServlet):
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@_validate_group_id
|
||||
async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
|
||||
async def on_GET(
|
||||
self, request: SynapseRequest, group_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
@@ -526,7 +551,9 @@ class GroupSettingJoinPolicyServlet(RestServlet):
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@_validate_group_id
|
||||
async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
|
||||
async def on_PUT(
|
||||
self, request: SynapseRequest, group_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
@@ -554,7 +581,7 @@ class GroupCreateServlet(RestServlet):
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
self.server_name = hs.hostname
|
||||
|
||||
async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
|
||||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
@@ -598,7 +625,7 @@ class GroupAdminRoomsServlet(RestServlet):
|
||||
|
||||
@_validate_group_id
|
||||
async def on_PUT(
|
||||
self, request: Request, group_id: str, room_id: str
|
||||
self, request: SynapseRequest, group_id: str, room_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
@@ -615,7 +642,7 @@ class GroupAdminRoomsServlet(RestServlet):
|
||||
|
||||
@_validate_group_id
|
||||
async def on_DELETE(
|
||||
self, request: Request, group_id: str, room_id: str
|
||||
self, request: SynapseRequest, group_id: str, room_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
@@ -646,7 +673,7 @@ class GroupAdminRoomsConfigServlet(RestServlet):
|
||||
|
||||
@_validate_group_id
|
||||
async def on_PUT(
|
||||
self, request: Request, group_id: str, room_id: str, config_key: str
|
||||
self, request: SynapseRequest, group_id: str, room_id: str, config_key: str
|
||||
):
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
@@ -678,7 +705,9 @@ class GroupAdminUsersInviteServlet(RestServlet):
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
|
||||
@_validate_group_id
|
||||
async def on_PUT(self, request: Request, group_id, user_id) -> Tuple[int, JsonDict]:
|
||||
async def on_PUT(
|
||||
self, request: SynapseRequest, group_id, user_id
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
@@ -708,7 +737,9 @@ class GroupAdminUsersKickServlet(RestServlet):
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@_validate_group_id
|
||||
async def on_PUT(self, request: Request, group_id, user_id) -> Tuple[int, JsonDict]:
|
||||
async def on_PUT(
|
||||
self, request: SynapseRequest, group_id, user_id
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
@@ -735,7 +766,9 @@ class GroupSelfLeaveServlet(RestServlet):
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@_validate_group_id
|
||||
async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
|
||||
async def on_PUT(
|
||||
self, request: SynapseRequest, group_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
@@ -762,7 +795,9 @@ class GroupSelfJoinServlet(RestServlet):
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@_validate_group_id
|
||||
async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
|
||||
async def on_PUT(
|
||||
self, request: SynapseRequest, group_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
@@ -789,7 +824,9 @@ class GroupSelfAcceptInviteServlet(RestServlet):
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@_validate_group_id
|
||||
async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
|
||||
async def on_PUT(
|
||||
self, request: SynapseRequest, group_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
@@ -816,7 +853,9 @@ class GroupSelfUpdatePublicityServlet(RestServlet):
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
@_validate_group_id
|
||||
async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
|
||||
async def on_PUT(
|
||||
self, request: SynapseRequest, group_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
@@ -839,7 +878,9 @@ class PublicisedGroupsForUserServlet(RestServlet):
|
||||
self.store = hs.get_datastore()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
async def on_GET(self, request: Request, user_id: str) -> Tuple[int, JsonDict]:
|
||||
async def on_GET(
|
||||
self, request: SynapseRequest, user_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
|
||||
result = await self.groups_handler.get_publicised_groups_for_user(user_id)
|
||||
@@ -859,7 +900,7 @@ class PublicisedGroupsForUsersServlet(RestServlet):
|
||||
self.store = hs.get_datastore()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
|
||||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
@@ -881,7 +922,7 @@ class GroupsForUserServlet(RestServlet):
|
||||
self.clock = hs.get_clock()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
|
||||
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ from typing import TYPE_CHECKING
|
||||
from twisted.web.server import Request
|
||||
|
||||
from synapse.http.server import DirectServeJsonResource, respond_with_json
|
||||
from synapse.http.site import SynapseRequest
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.app.homeserver import HomeServer
|
||||
@@ -35,7 +36,7 @@ class MediaConfigResource(DirectServeJsonResource):
|
||||
self.auth = hs.get_auth()
|
||||
self.limits_dict = {"m.upload.size": config.max_upload_size}
|
||||
|
||||
async def _async_render_GET(self, request: Request) -> None:
|
||||
async def _async_render_GET(self, request: SynapseRequest) -> None:
|
||||
await self.auth.get_user_by_req(request)
|
||||
respond_with_json(request, 200, self.limits_dict, send_cors=True)
|
||||
|
||||
|
||||
@@ -39,6 +39,7 @@ from synapse.http.server import (
|
||||
respond_with_json_bytes,
|
||||
)
|
||||
from synapse.http.servlet import parse_integer, parse_string
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.rest.media.v1._base import get_filename_from_headers
|
||||
@@ -185,7 +186,7 @@ class PreviewUrlResource(DirectServeJsonResource):
|
||||
request.setHeader(b"Allow", b"OPTIONS, GET")
|
||||
respond_with_json(request, 200, {}, send_cors=True)
|
||||
|
||||
async def _async_render_GET(self, request: Request) -> None:
|
||||
async def _async_render_GET(self, request: SynapseRequest) -> None:
|
||||
|
||||
# XXX: if get_user_by_req fails, what should we do in an async render?
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
|
||||
@@ -22,6 +22,7 @@ from twisted.web.server import Request
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.http.server import DirectServeJsonResource, respond_with_json
|
||||
from synapse.http.servlet import parse_string
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.rest.media.v1.media_storage import SpamMediaException
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -49,7 +50,7 @@ class UploadResource(DirectServeJsonResource):
|
||||
async def _async_render_OPTIONS(self, request: Request) -> None:
|
||||
respond_with_json(request, 200, {}, send_cors=True)
|
||||
|
||||
async def _async_render_POST(self, request: Request) -> None:
|
||||
async def _async_render_POST(self, request: SynapseRequest) -> None:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
# TODO: The checks here are a bit late. The content will have
|
||||
# already been uploaded to a tmp file at this point
|
||||
|
||||
@@ -351,11 +351,9 @@ class HomeServer(metaclass=abc.ABCMeta):
|
||||
|
||||
@cache_in_self
|
||||
def get_http_client_context_factory(self) -> IPolicyForHTTPS:
|
||||
return (
|
||||
InsecureInterceptableContextFactory()
|
||||
if self.config.use_insecure_ssl_client_just_for_testing_do_not_use
|
||||
else RegularPolicyForHTTPS()
|
||||
)
|
||||
if self.config.use_insecure_ssl_client_just_for_testing_do_not_use:
|
||||
return InsecureInterceptableContextFactory()
|
||||
return RegularPolicyForHTTPS()
|
||||
|
||||
@cache_in_self
|
||||
def get_simple_http_client(self) -> SimpleHttpClient:
|
||||
|
||||
@@ -16,12 +16,23 @@ from io import BytesIO
|
||||
|
||||
from mock import Mock
|
||||
|
||||
from netaddr import IPSet
|
||||
|
||||
from twisted.internet.error import DNSLookupError
|
||||
from twisted.python.failure import Failure
|
||||
from twisted.web.client import ResponseDone
|
||||
from twisted.test.proto_helpers import AccumulatingProtocol
|
||||
from twisted.web.client import Agent, ResponseDone
|
||||
from twisted.web.iweb import UNKNOWN_LENGTH
|
||||
|
||||
from synapse.http.client import BodyExceededMaxSize, read_body_with_max_size
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.http.client import (
|
||||
BlacklistingAgentWrapper,
|
||||
BlacklistingReactorWrapper,
|
||||
BodyExceededMaxSize,
|
||||
read_body_with_max_size,
|
||||
)
|
||||
|
||||
from tests.server import FakeTransport, get_clock
|
||||
from tests.unittest import TestCase
|
||||
|
||||
|
||||
@@ -119,3 +130,114 @@ class ReadBodyWithMaxSizeTests(TestCase):
|
||||
|
||||
# The data is never consumed.
|
||||
self.assertEqual(result.getvalue(), b"")
|
||||
|
||||
|
||||
class BlacklistingAgentTest(TestCase):
|
||||
def setUp(self):
|
||||
self.reactor, self.clock = get_clock()
|
||||
|
||||
self.safe_domain, self.safe_ip = b"safe.test", b"1.2.3.4"
|
||||
self.unsafe_domain, self.unsafe_ip = b"danger.test", b"5.6.7.8"
|
||||
self.allowed_domain, self.allowed_ip = b"allowed.test", b"5.1.1.1"
|
||||
|
||||
# Configure the reactor's DNS resolver.
|
||||
for (domain, ip) in (
|
||||
(self.safe_domain, self.safe_ip),
|
||||
(self.unsafe_domain, self.unsafe_ip),
|
||||
(self.allowed_domain, self.allowed_ip),
|
||||
):
|
||||
self.reactor.lookups[domain.decode()] = ip.decode()
|
||||
self.reactor.lookups[ip.decode()] = ip.decode()
|
||||
|
||||
self.ip_whitelist = IPSet([self.allowed_ip.decode()])
|
||||
self.ip_blacklist = IPSet(["5.0.0.0/8"])
|
||||
|
||||
def test_reactor(self):
|
||||
"""Apply the blacklisting reactor and ensure it properly blocks connections to particular domains and IPs."""
|
||||
agent = Agent(
|
||||
BlacklistingReactorWrapper(
|
||||
self.reactor,
|
||||
ip_whitelist=self.ip_whitelist,
|
||||
ip_blacklist=self.ip_blacklist,
|
||||
),
|
||||
)
|
||||
|
||||
# The unsafe domains and IPs should be rejected.
|
||||
for domain in (self.unsafe_domain, self.unsafe_ip):
|
||||
self.failureResultOf(
|
||||
agent.request(b"GET", b"http://" + domain), DNSLookupError
|
||||
)
|
||||
|
||||
# The safe domains IPs should be accepted.
|
||||
for domain in (
|
||||
self.safe_domain,
|
||||
self.allowed_domain,
|
||||
self.safe_ip,
|
||||
self.allowed_ip,
|
||||
):
|
||||
d = agent.request(b"GET", b"http://" + domain)
|
||||
|
||||
# Grab the latest TCP connection.
|
||||
(
|
||||
host,
|
||||
port,
|
||||
client_factory,
|
||||
_timeout,
|
||||
_bindAddress,
|
||||
) = self.reactor.tcpClients[-1]
|
||||
|
||||
# Make the connection and pump data through it.
|
||||
client = client_factory.buildProtocol(None)
|
||||
server = AccumulatingProtocol()
|
||||
server.makeConnection(FakeTransport(client, self.reactor))
|
||||
client.makeConnection(FakeTransport(server, self.reactor))
|
||||
client.dataReceived(
|
||||
b"HTTP/1.0 200 OK\r\nContent-Length: 0\r\nContent-Type: text/html\r\n\r\n"
|
||||
)
|
||||
|
||||
response = self.successResultOf(d)
|
||||
self.assertEqual(response.code, 200)
|
||||
|
||||
def test_agent(self):
|
||||
"""Apply the blacklisting agent and ensure it properly blocks connections to particular IPs."""
|
||||
agent = BlacklistingAgentWrapper(
|
||||
Agent(self.reactor),
|
||||
ip_whitelist=self.ip_whitelist,
|
||||
ip_blacklist=self.ip_blacklist,
|
||||
)
|
||||
|
||||
# The unsafe IPs should be rejected.
|
||||
self.failureResultOf(
|
||||
agent.request(b"GET", b"http://" + self.unsafe_ip), SynapseError
|
||||
)
|
||||
|
||||
# The safe and unsafe domains and safe IPs should be accepted.
|
||||
for domain in (
|
||||
self.safe_domain,
|
||||
self.unsafe_domain,
|
||||
self.allowed_domain,
|
||||
self.safe_ip,
|
||||
self.allowed_ip,
|
||||
):
|
||||
d = agent.request(b"GET", b"http://" + domain)
|
||||
|
||||
# Grab the latest TCP connection.
|
||||
(
|
||||
host,
|
||||
port,
|
||||
client_factory,
|
||||
_timeout,
|
||||
_bindAddress,
|
||||
) = self.reactor.tcpClients[-1]
|
||||
|
||||
# Make the connection and pump data through it.
|
||||
client = client_factory.buildProtocol(None)
|
||||
server = AccumulatingProtocol()
|
||||
server.makeConnection(FakeTransport(client, self.reactor))
|
||||
client.makeConnection(FakeTransport(server, self.reactor))
|
||||
client.dataReceived(
|
||||
b"HTTP/1.0 200 OK\r\nContent-Length: 0\r\nContent-Type: text/html\r\n\r\n"
|
||||
)
|
||||
|
||||
response = self.successResultOf(d)
|
||||
self.assertEqual(response.code, 200)
|
||||
|
||||
@@ -17,7 +17,7 @@ import mock
|
||||
|
||||
from synapse.app.generic_worker import GenericWorkerServer
|
||||
from synapse.replication.tcp.commands import FederationAckCommand
|
||||
from synapse.replication.tcp.protocol import AbstractConnection
|
||||
from synapse.replication.tcp.protocol import IReplicationConnection
|
||||
from synapse.replication.tcp.streams.federation import FederationStream
|
||||
|
||||
from tests.unittest import HomeserverTestCase
|
||||
@@ -51,8 +51,10 @@ class FederationAckTestCase(HomeserverTestCase):
|
||||
"""
|
||||
rch = self.hs.get_tcp_replication()
|
||||
|
||||
# wire up the ReplicationCommandHandler to a mock connection
|
||||
mock_connection = mock.Mock(spec=AbstractConnection)
|
||||
# wire up the ReplicationCommandHandler to a mock connection, which needs
|
||||
# to implement IReplicationConnection. (Note that Mock doesn't understand
|
||||
# interfaces, but casing an interface to a list gives the attributes.)
|
||||
mock_connection = mock.Mock(spec=list(IReplicationConnection))
|
||||
rch.new_connection(mock_connection)
|
||||
|
||||
# tell it it received an RDATA row
|
||||
|
||||
Reference in New Issue
Block a user