1
0

Merge commit '55da8df07' into anoa/dinsic_release_1_31_0

This commit is contained in:
Andrew Morgan
2021-04-23 17:28:38 +01:00
27 changed files with 464 additions and 243 deletions

1
changelog.d/9563.misc Normal file
View File

@@ -0,0 +1 @@
Fix type hints and tests for BlacklistingAgentWrapper and BlacklistingReactorWrapper.

1
changelog.d/9587.bugfix Normal file
View File

@@ -0,0 +1 @@
Re-Activating account with admin API when local passwords are disabled.

1
changelog.d/9590.misc Normal file
View File

@@ -0,0 +1 @@
Add logging for redis connection setup.

1
changelog.d/9591.misc Normal file
View File

@@ -0,0 +1 @@
Fix incorrect type hints.

1
changelog.d/9596.misc Normal file
View File

@@ -0,0 +1 @@
Improve logging when processing incoming transactions.

1
changelog.d/9597.bugfix Normal file
View File

@@ -0,0 +1 @@
Fix a bug introduced in Synapse 1.20 which caused incoming federation transactions to stack up, causing slow recovery from outages.

View File

@@ -17,6 +17,8 @@
"""
from typing import Any, List, Optional, Type, Union
from twisted.internet import protocol
class RedisProtocol:
def publish(self, channel: str, message: bytes): ...
async def ping(self) -> None: ...
@@ -52,7 +54,7 @@ def lazyConnection(
class ConnectionHandler: ...
class RedisFactory:
class RedisFactory(protocol.ReconnectingClientFactory):
continueTrying: bool
handler: RedisProtocol
pool: List[RedisProtocol]

View File

@@ -164,7 +164,7 @@ class Auth:
async def get_user_by_req(
self,
request: Request,
request: SynapseRequest,
allow_guest: bool = False,
rights: str = "access",
allow_expired: bool = False,

View File

@@ -113,10 +113,11 @@ class FederationServer(FederationBase):
# with FederationHandlerRegistry.
hs.get_directory_handler()
self._federation_ratelimiter = hs.get_federation_ratelimiter()
self._server_linearizer = Linearizer("fed_server")
self._transaction_linearizer = Linearizer("fed_txn_handler")
# origins that we are currently processing a transaction from.
# a dict from origin to txn id.
self._active_transactions = {} # type: Dict[str, str]
# We cache results for transaction with the same ID
self._transaction_resp_cache = ResponseCache(
@@ -170,6 +171,33 @@ class FederationServer(FederationBase):
logger.debug("[%s] Got transaction", transaction_id)
# Reject malformed transactions early: reject if too many PDUs/EDUs
if len(transaction.pdus) > 50 or ( # type: ignore
hasattr(transaction, "edus") and len(transaction.edus) > 100 # type: ignore
):
logger.info("Transaction PDU or EDU count too large. Returning 400")
return 400, {}
# we only process one transaction from each origin at a time. We need to do
# this check here, rather than in _on_incoming_transaction_inner so that we
# don't cache the rejection in _transaction_resp_cache (so that if the txn
# arrives again later, we can process it).
current_transaction = self._active_transactions.get(origin)
if current_transaction and current_transaction != transaction_id:
logger.warning(
"Received another txn %s from %s while still processing %s",
transaction_id,
origin,
current_transaction,
)
return 429, {
"errcode": Codes.UNKNOWN,
"error": "Too many concurrent transactions",
}
# CRITICAL SECTION: we must now not await until we populate _active_transactions
# in _on_incoming_transaction_inner.
# We wrap in a ResponseCache so that we de-duplicate retried
# transactions.
return await self._transaction_resp_cache.wrap(
@@ -183,26 +211,18 @@ class FederationServer(FederationBase):
async def _on_incoming_transaction_inner(
self, origin: str, transaction: Transaction, request_time: int
) -> Tuple[int, Dict[str, Any]]:
# Use a linearizer to ensure that transactions from a remote are
# processed in order.
with await self._transaction_linearizer.queue(origin):
# We rate limit here *after* we've queued up the incoming requests,
# so that we don't fill up the ratelimiter with blocked requests.
#
# This is important as the ratelimiter allows N concurrent requests
# at a time, and only starts ratelimiting if there are more requests
# than that being processed at a time. If we queued up requests in
# the linearizer/response cache *after* the ratelimiting then those
# queued up requests would count as part of the allowed limit of N
# concurrent requests.
with self._federation_ratelimiter.ratelimit(origin) as d:
await d
# CRITICAL SECTION: the first thing we must do (before awaiting) is
# add an entry to _active_transactions.
assert origin not in self._active_transactions
self._active_transactions[origin] = transaction.transaction_id # type: ignore
result = await self._handle_incoming_transaction(
origin, transaction, request_time
)
return result
try:
result = await self._handle_incoming_transaction(
origin, transaction, request_time
)
return result
finally:
del self._active_transactions[origin]
async def _handle_incoming_transaction(
self, origin: str, transaction: Transaction, request_time: int
@@ -228,19 +248,6 @@ class FederationServer(FederationBase):
logger.debug("[%s] Transaction is new", transaction.transaction_id) # type: ignore
# Reject if PDU count > 50 or EDU count > 100
if len(transaction.pdus) > 50 or ( # type: ignore
hasattr(transaction, "edus") and len(transaction.edus) > 100 # type: ignore
):
logger.info("Transaction PDU or EDU count too large. Returning 400")
response = {}
await self.transaction_actions.set_response(
origin, transaction, 400, response
)
return 400, response
# We process PDUs and EDUs in parallel. This is important as we don't
# want to block things like to device messages from reaching clients
# behind the potentially expensive handling of PDUs.
@@ -336,34 +343,41 @@ class FederationServer(FederationBase):
# impose a limit to avoid going too crazy with ram/cpu.
async def process_pdus_for_room(room_id: str):
logger.debug("Processing PDUs for %s", room_id)
try:
await self.check_server_matches_acl(origin_host, room_id)
except AuthError as e:
logger.warning("Ignoring PDUs for room %s from banned server", room_id)
for pdu in pdus_by_room[room_id]:
event_id = pdu.event_id
pdu_results[event_id] = e.error_dict()
return
with nested_logging_context(room_id):
logger.debug("Processing PDUs for %s", room_id)
for pdu in pdus_by_room[room_id]:
event_id = pdu.event_id
with pdu_process_time.time():
with nested_logging_context(event_id):
try:
await self._handle_received_pdu(origin, pdu)
pdu_results[event_id] = {}
except FederationError as e:
logger.warning("Error handling PDU %s: %s", event_id, e)
pdu_results[event_id] = {"error": str(e)}
except Exception as e:
f = failure.Failure()
pdu_results[event_id] = {"error": str(e)}
logger.error(
"Failed to handle PDU %s",
event_id,
exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore
)
try:
await self.check_server_matches_acl(origin_host, room_id)
except AuthError as e:
logger.warning(
"Ignoring PDUs for room %s from banned server", room_id
)
for pdu in pdus_by_room[room_id]:
event_id = pdu.event_id
pdu_results[event_id] = e.error_dict()
return
for pdu in pdus_by_room[room_id]:
pdu_results[pdu.event_id] = await process_pdu(pdu)
async def process_pdu(pdu: EventBase) -> JsonDict:
event_id = pdu.event_id
with pdu_process_time.time():
with nested_logging_context(event_id):
try:
await self._handle_received_pdu(origin, pdu)
return {}
except FederationError as e:
logger.warning("Error handling PDU %s: %s", event_id, e)
return {"error": str(e)}
except Exception as e:
f = failure.Failure()
logger.error(
"Failed to handle PDU %s",
event_id,
exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore
)
return {"error": str(e)}
await concurrently_execute(
process_pdus_for_room, pdus_by_room.keys(), TRANSACTION_CONCURRENCY_LIMIT
@@ -942,7 +956,9 @@ class FederationHandlerRegistry:
self.edu_handlers = (
{}
) # type: Dict[str, Callable[[str, dict], Awaitable[None]]]
self.query_handlers = {} # type: Dict[str, Callable[[dict], Awaitable[None]]]
self.query_handlers = (
{}
) # type: Dict[str, Callable[[dict], Awaitable[JsonDict]]]
# Map from type to instance names that we should route EDU handling to.
# We randomly choose one instance from the list to route to for each new
@@ -976,7 +992,7 @@ class FederationHandlerRegistry:
self.edu_handlers[edu_type] = handler
def register_query_handler(
self, query_type: str, handler: Callable[[dict], defer.Deferred]
self, query_type: str, handler: Callable[[dict], Awaitable[JsonDict]]
):
"""Sets the handler callable that will be used to handle an incoming
federation query of the given type.
@@ -1049,7 +1065,7 @@ class FederationHandlerRegistry:
# Oh well, let's just log and move on.
logger.warning("No handler registered for EDU type %s", edu_type)
async def on_query(self, query_type: str, args: dict):
async def on_query(self, query_type: str, args: dict) -> JsonDict:
handler = self.query_handlers.get(query_type)
if handler:
return await handler(args)

View File

@@ -202,7 +202,7 @@ class FederationHandler(BaseHandler):
or pdu.internal_metadata.is_outlier()
)
if already_seen:
logger.debug("[%s %s]: Already seen pdu", room_id, event_id)
logger.debug("Already seen pdu")
return
# do some initial sanity-checking of the event. In particular, make
@@ -211,18 +211,14 @@ class FederationHandler(BaseHandler):
try:
self._sanity_check_event(pdu)
except SynapseError as err:
logger.warning(
"[%s %s] Received event failed sanity checks", room_id, event_id
)
logger.warning("Received event failed sanity checks")
raise FederationError("ERROR", err.code, err.msg, affected=pdu.event_id)
# If we are currently in the process of joining this room, then we
# queue up events for later processing.
if room_id in self.room_queues:
logger.info(
"[%s %s] Queuing PDU from %s for now: join in progress",
room_id,
event_id,
"Queuing PDU from %s for now: join in progress",
origin,
)
self.room_queues[room_id].append((pdu, origin))
@@ -237,9 +233,7 @@ class FederationHandler(BaseHandler):
is_in_room = await self.auth.check_host_in_room(room_id, self.server_name)
if not is_in_room:
logger.info(
"[%s %s] Ignoring PDU from %s as we're not in the room",
room_id,
event_id,
"Ignoring PDU from %s as we're not in the room",
origin,
)
return None
@@ -251,7 +245,7 @@ class FederationHandler(BaseHandler):
# We only backfill backwards to the min depth.
min_depth = await self.get_min_depth_for_context(pdu.room_id)
logger.debug("[%s %s] min_depth: %d", room_id, event_id, min_depth)
logger.debug("min_depth: %d", min_depth)
prevs = set(pdu.prev_event_ids())
seen = await self.store.have_events_in_timeline(prevs)
@@ -268,17 +262,13 @@ class FederationHandler(BaseHandler):
# If we're missing stuff, ensure we only fetch stuff one
# at a time.
logger.info(
"[%s %s] Acquiring room lock to fetch %d missing prev_events: %s",
room_id,
event_id,
"Acquiring room lock to fetch %d missing prev_events: %s",
len(missing_prevs),
shortstr(missing_prevs),
)
with (await self._room_pdu_linearizer.queue(pdu.room_id)):
logger.info(
"[%s %s] Acquired room lock to fetch %d missing prev_events",
room_id,
event_id,
"Acquired room lock to fetch %d missing prev_events",
len(missing_prevs),
)
@@ -298,9 +288,7 @@ class FederationHandler(BaseHandler):
if not prevs - seen:
logger.info(
"[%s %s] Found all missing prev_events",
room_id,
event_id,
"Found all missing prev_events",
)
elif missing_prevs:
logger.info(
@@ -338,9 +326,7 @@ class FederationHandler(BaseHandler):
if sent_to_us_directly:
logger.warning(
"[%s %s] Rejecting: failed to fetch %d prev events: %s",
room_id,
event_id,
"Rejecting: failed to fetch %d prev events: %s",
len(prevs - seen),
shortstr(prevs - seen),
)
@@ -416,10 +402,7 @@ class FederationHandler(BaseHandler):
state = [event_map[e] for e in state_map.values()]
except Exception:
logger.warning(
"[%s %s] Error attempting to resolve state at missing "
"prev_events",
room_id,
event_id,
"Error attempting to resolve state at missing " "prev_events",
exc_info=True,
)
raise FederationError(
@@ -456,9 +439,7 @@ class FederationHandler(BaseHandler):
latest |= seen
logger.info(
"[%s %s]: Requesting missing events between %s and %s",
room_id,
event_id,
"Requesting missing events between %s and %s",
shortstr(latest),
event_id,
)
@@ -525,15 +506,11 @@ class FederationHandler(BaseHandler):
# We failed to get the missing events, but since we need to handle
# the case of `get_missing_events` not returning the necessary
# events anyway, it is safe to simply log the error and continue.
logger.warning(
"[%s %s]: Failed to get prev_events: %s", room_id, event_id, e
)
logger.warning("Failed to get prev_events: %s", e)
return
logger.info(
"[%s %s]: Got %d prev_events: %s",
room_id,
event_id,
"Got %d prev_events: %s",
len(missing_events),
shortstr(missing_events),
)
@@ -544,9 +521,7 @@ class FederationHandler(BaseHandler):
for ev in missing_events:
logger.info(
"[%s %s] Handling received prev_event %s",
room_id,
event_id,
"Handling received prev_event %s",
ev.event_id,
)
with nested_logging_context(ev.event_id):
@@ -555,9 +530,7 @@ class FederationHandler(BaseHandler):
except FederationError as e:
if e.code == 403:
logger.warning(
"[%s %s] Received prev_event %s failed history check.",
room_id,
event_id,
"Received prev_event %s failed history check.",
ev.event_id,
)
else:
@@ -709,10 +682,7 @@ class FederationHandler(BaseHandler):
(ie, we are missing one or more prev_events), the resolved state at the
event
"""
room_id = event.room_id
event_id = event.event_id
logger.debug("[%s %s] Processing event: %s", room_id, event_id, event)
logger.debug("Processing event: %s", event)
try:
await self._handle_new_event(origin, event, state=state)

View File

@@ -34,6 +34,7 @@ from pymacaroons.exceptions import (
from typing_extensions import TypedDict
from twisted.web.client import readBody
from twisted.web.http_headers import Headers
from synapse.config import ConfigError
from synapse.config.oidc_config import (
@@ -538,7 +539,7 @@ class OidcProvider:
"""
metadata = await self.load_metadata()
token_endpoint = metadata.get("token_endpoint")
headers = {
raw_headers = {
"Content-Type": "application/x-www-form-urlencoded",
"User-Agent": self._http_client.user_agent,
"Accept": "application/json",
@@ -552,10 +553,10 @@ class OidcProvider:
body = urlencode(args, True)
# Fill the body/headers with credentials
uri, headers, body = self._client_auth.prepare(
method="POST", uri=token_endpoint, headers=headers, body=body
uri, raw_headers, body = self._client_auth.prepare(
method="POST", uri=token_endpoint, headers=raw_headers, body=body
)
headers = {k: [v] for (k, v) in headers.items()}
headers = Headers({k: [v] for (k, v) in raw_headers.items()})
# Do the actual request
# We're not using the SimpleHttpClient util methods as we don't want to

View File

@@ -39,6 +39,7 @@ from zope.interface import implementer, provider
from OpenSSL import SSL
from OpenSSL.SSL import VERIFY_NONE
from twisted.internet import defer, error as twisted_error, protocol, ssl
from twisted.internet.address import IPv4Address, IPv6Address
from twisted.internet.interfaces import (
IAddress,
IHostResolution,
@@ -56,7 +57,13 @@ from twisted.web.client import (
)
from twisted.web.http import PotentialDataLoss
from twisted.web.http_headers import Headers
from twisted.web.iweb import UNKNOWN_LENGTH, IAgent, IBodyProducer, IResponse
from twisted.web.iweb import (
UNKNOWN_LENGTH,
IAgent,
IBodyProducer,
IPolicyForHTTPS,
IResponse,
)
from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.http import QuieterFileBodyProducer, RequestTimedOutError, redact_uri
@@ -151,16 +158,17 @@ class _IPBlacklistingResolver:
def resolveHostName(
self, recv: IResolutionReceiver, hostname: str, portNumber: int = 0
) -> IResolutionReceiver:
r = recv()
addresses = [] # type: List[IAddress]
def _callback() -> None:
r.resolutionBegan(None)
has_bad_ip = False
for i in addresses:
ip_address = IPAddress(i.host)
for address in addresses:
# We only expect IPv4 and IPv6 addresses since only A/AAAA lookups
# should go through this path.
if not isinstance(address, (IPv4Address, IPv6Address)):
continue
ip_address = IPAddress(address.host)
if check_against_blacklist(
ip_address, self._ip_whitelist, self._ip_blacklist
@@ -175,15 +183,15 @@ class _IPBlacklistingResolver:
# request, but all we can really do from here is claim that there were no
# valid results.
if not has_bad_ip:
for i in addresses:
r.addressResolved(i)
r.resolutionComplete()
for address in addresses:
recv.addressResolved(address)
recv.resolutionComplete()
@provider(IResolutionReceiver)
class EndpointReceiver:
@staticmethod
def resolutionBegan(resolutionInProgress: IHostResolution) -> None:
pass
recv.resolutionBegan(resolutionInProgress)
@staticmethod
def addressResolved(address: IAddress) -> None:
@@ -197,7 +205,7 @@ class _IPBlacklistingResolver:
EndpointReceiver, hostname, portNumber=portNumber
)
return r
return recv
@implementer(ISynapseReactor)
@@ -346,7 +354,7 @@ class SimpleHttpClient:
contextFactory=self.hs.get_http_client_context_factory(),
pool=pool,
use_proxy=use_proxy,
)
) # type: IAgent
if self._ip_blacklist:
# If we have an IP blacklist, we then install the blacklisting Agent
@@ -868,6 +876,7 @@ def encode_query_args(args: Optional[Mapping[str, Union[str, List[str]]]]) -> by
return query_str.encode("utf8")
@implementer(IPolicyForHTTPS)
class InsecureInterceptableContextFactory(ssl.ContextFactory):
"""
Factory for PyOpenSSL SSL contexts which accepts any certificate for any domain.

View File

@@ -32,8 +32,9 @@ from twisted.internet.endpoints import (
TCP4ClientEndpoint,
TCP6ClientEndpoint,
)
from twisted.internet.interfaces import IPushProducer, IStreamClientEndpoint, ITransport
from twisted.internet.interfaces import IPushProducer, IStreamClientEndpoint
from twisted.internet.protocol import Factory, Protocol
from twisted.internet.tcp import Connection
from twisted.python.failure import Failure
logger = logging.getLogger(__name__)
@@ -52,7 +53,9 @@ class LogProducer:
format: A callable to format the log record to a string.
"""
transport = attr.ib(type=ITransport)
# This is essentially ITCPTransport, but that is missing certain fields
# (connected and registerProducer) which are part of the implementation.
transport = attr.ib(type=Connection)
_format = attr.ib(type=Callable[[logging.LogRecord], str])
_buffer = attr.ib(type=deque)
_paused = attr.ib(default=False, type=bool, init=False)
@@ -149,8 +152,6 @@ class RemoteHandler(logging.Handler):
if self._connection_waiter:
return
self._connection_waiter = self._service.whenConnected(failAfterFailures=1)
def fail(failure: Failure) -> None:
# If the Deferred was cancelled (e.g. during shutdown) do not try to
# reconnect (this will cause an infinite loop of errors).
@@ -163,9 +164,13 @@ class RemoteHandler(logging.Handler):
self._connect()
def writer(result: Protocol) -> None:
# Force recognising transport as a Connection and not the more
# generic ITransport.
transport = result.transport # type: Connection # type: ignore
# We have a connection. If we already have a producer, and its
# transport is the same, just trigger a resumeProducing.
if self._producer and result.transport is self._producer.transport:
if self._producer and transport is self._producer.transport:
self._producer.resumeProducing()
self._connection_waiter = None
return
@@ -177,14 +182,16 @@ class RemoteHandler(logging.Handler):
# Make a new producer and start it.
self._producer = LogProducer(
buffer=self._buffer,
transport=result.transport,
transport=transport,
format=self.format,
)
result.transport.registerProducer(self._producer, True)
transport.registerProducer(self._producer, True)
self._producer.resumeProducing()
self._connection_waiter = None
self._connection_waiter.addCallbacks(writer, fail)
deferred = self._service.whenConnected(failAfterFailures=1) # type: Deferred
deferred.addCallbacks(writer, fail)
self._connection_waiter = deferred
def _handle_pressure(self) -> None:
"""

View File

@@ -16,8 +16,8 @@
import logging
from typing import TYPE_CHECKING, Dict, List, Optional
from twisted.internet.base import DelayedCall
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
from twisted.internet.interfaces import IDelayedCall
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.push import Pusher, PusherConfig, ThrottleParams
@@ -66,7 +66,7 @@ class EmailPusher(Pusher):
self.store = self.hs.get_datastore()
self.email = pusher_config.pushkey
self.timed_call = None # type: Optional[DelayedCall]
self.timed_call = None # type: Optional[IDelayedCall]
self.throttle_params = {} # type: Dict[str, ThrottleParams]
self._inited = False

View File

@@ -48,7 +48,7 @@ from synapse.replication.tcp.commands import (
UserIpCommand,
UserSyncCommand,
)
from synapse.replication.tcp.protocol import AbstractConnection
from synapse.replication.tcp.protocol import IReplicationConnection
from synapse.replication.tcp.streams import (
STREAMS_MAP,
AccountDataStream,
@@ -82,7 +82,7 @@ user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache"
# the type of the entries in _command_queues_by_stream
_StreamCommandQueue = Deque[
Tuple[Union[RdataCommand, PositionCommand], AbstractConnection]
Tuple[Union[RdataCommand, PositionCommand], IReplicationConnection]
]
@@ -174,7 +174,7 @@ class ReplicationCommandHandler:
# The currently connected connections. (The list of places we need to send
# outgoing replication commands to.)
self._connections = [] # type: List[AbstractConnection]
self._connections = [] # type: List[IReplicationConnection]
LaterGauge(
"synapse_replication_tcp_resource_total_connections",
@@ -197,7 +197,7 @@ class ReplicationCommandHandler:
# For each connection, the incoming stream names that have received a POSITION
# from that connection.
self._streams_by_connection = {} # type: Dict[AbstractConnection, Set[str]]
self._streams_by_connection = {} # type: Dict[IReplicationConnection, Set[str]]
LaterGauge(
"synapse_replication_tcp_command_queue",
@@ -220,7 +220,7 @@ class ReplicationCommandHandler:
self._server_notices_sender = hs.get_server_notices_sender()
def _add_command_to_stream_queue(
self, conn: AbstractConnection, cmd: Union[RdataCommand, PositionCommand]
self, conn: IReplicationConnection, cmd: Union[RdataCommand, PositionCommand]
) -> None:
"""Queue the given received command for processing
@@ -267,7 +267,7 @@ class ReplicationCommandHandler:
async def _process_command(
self,
cmd: Union[PositionCommand, RdataCommand],
conn: AbstractConnection,
conn: IReplicationConnection,
stream_name: str,
) -> None:
if isinstance(cmd, PositionCommand):
@@ -321,10 +321,10 @@ class ReplicationCommandHandler:
"""Get a list of streams that this instances replicates."""
return self._streams_to_replicate
def on_REPLICATE(self, conn: AbstractConnection, cmd: ReplicateCommand):
def on_REPLICATE(self, conn: IReplicationConnection, cmd: ReplicateCommand):
self.send_positions_to_connection(conn)
def send_positions_to_connection(self, conn: AbstractConnection):
def send_positions_to_connection(self, conn: IReplicationConnection):
"""Send current position of all streams this process is source of to
the connection.
"""
@@ -347,7 +347,7 @@ class ReplicationCommandHandler:
)
def on_USER_SYNC(
self, conn: AbstractConnection, cmd: UserSyncCommand
self, conn: IReplicationConnection, cmd: UserSyncCommand
) -> Optional[Awaitable[None]]:
user_sync_counter.inc()
@@ -359,21 +359,23 @@ class ReplicationCommandHandler:
return None
def on_CLEAR_USER_SYNC(
self, conn: AbstractConnection, cmd: ClearUserSyncsCommand
self, conn: IReplicationConnection, cmd: ClearUserSyncsCommand
) -> Optional[Awaitable[None]]:
if self._is_master:
return self._presence_handler.update_external_syncs_clear(cmd.instance_id)
else:
return None
def on_FEDERATION_ACK(self, conn: AbstractConnection, cmd: FederationAckCommand):
def on_FEDERATION_ACK(
self, conn: IReplicationConnection, cmd: FederationAckCommand
):
federation_ack_counter.inc()
if self._federation_sender:
self._federation_sender.federation_ack(cmd.instance_name, cmd.token)
def on_USER_IP(
self, conn: AbstractConnection, cmd: UserIpCommand
self, conn: IReplicationConnection, cmd: UserIpCommand
) -> Optional[Awaitable[None]]:
user_ip_cache_counter.inc()
@@ -395,7 +397,7 @@ class ReplicationCommandHandler:
assert self._server_notices_sender is not None
await self._server_notices_sender.on_user_ip(cmd.user_id)
def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
def on_RDATA(self, conn: IReplicationConnection, cmd: RdataCommand):
if cmd.instance_name == self._instance_name:
# Ignore RDATA that are just our own echoes
return
@@ -412,7 +414,7 @@ class ReplicationCommandHandler:
self._add_command_to_stream_queue(conn, cmd)
async def _process_rdata(
self, stream_name: str, conn: AbstractConnection, cmd: RdataCommand
self, stream_name: str, conn: IReplicationConnection, cmd: RdataCommand
) -> None:
"""Process an RDATA command
@@ -486,7 +488,7 @@ class ReplicationCommandHandler:
stream_name, instance_name, token, rows
)
def on_POSITION(self, conn: AbstractConnection, cmd: PositionCommand):
def on_POSITION(self, conn: IReplicationConnection, cmd: PositionCommand):
if cmd.instance_name == self._instance_name:
# Ignore POSITION that are just our own echoes
return
@@ -496,7 +498,7 @@ class ReplicationCommandHandler:
self._add_command_to_stream_queue(conn, cmd)
async def _process_position(
self, stream_name: str, conn: AbstractConnection, cmd: PositionCommand
self, stream_name: str, conn: IReplicationConnection, cmd: PositionCommand
) -> None:
"""Process a POSITION command
@@ -553,7 +555,9 @@ class ReplicationCommandHandler:
self._streams_by_connection.setdefault(conn, set()).add(stream_name)
def on_REMOTE_SERVER_UP(self, conn: AbstractConnection, cmd: RemoteServerUpCommand):
def on_REMOTE_SERVER_UP(
self, conn: IReplicationConnection, cmd: RemoteServerUpCommand
):
""""Called when get a new REMOTE_SERVER_UP command."""
self._replication_data_handler.on_remote_server_up(cmd.data)
@@ -576,7 +580,7 @@ class ReplicationCommandHandler:
# between two instances, but that is not currently supported).
self.send_command(cmd, ignore_conn=conn)
def new_connection(self, connection: AbstractConnection):
def new_connection(self, connection: IReplicationConnection):
"""Called when we have a new connection."""
self._connections.append(connection)
@@ -603,7 +607,7 @@ class ReplicationCommandHandler:
UserSyncCommand(self._instance_id, user_id, True, now)
)
def lost_connection(self, connection: AbstractConnection):
def lost_connection(self, connection: IReplicationConnection):
"""Called when a connection is closed/lost."""
# we no longer need _streams_by_connection for this connection.
streams = self._streams_by_connection.pop(connection, None)
@@ -624,7 +628,7 @@ class ReplicationCommandHandler:
return bool(self._connections)
def send_command(
self, cmd: Command, ignore_conn: Optional[AbstractConnection] = None
self, cmd: Command, ignore_conn: Optional[IReplicationConnection] = None
):
"""Send a command to all connected connections.

View File

@@ -46,7 +46,6 @@ indicate which side is sending, these are *not* included on the wire::
> ERROR server stopping
* connection closed by server *
"""
import abc
import fcntl
import logging
import struct
@@ -54,6 +53,7 @@ from inspect import isawaitable
from typing import TYPE_CHECKING, List, Optional
from prometheus_client import Counter
from zope.interface import Interface, implementer
from twisted.internet import task
from twisted.protocols.basic import LineOnlyReceiver
@@ -121,6 +121,14 @@ class ConnectionStates:
CLOSED = "closed"
class IReplicationConnection(Interface):
"""An interface for replication connections."""
def send_command(cmd: Command):
"""Send the command down the connection"""
@implementer(IReplicationConnection)
class BaseReplicationStreamProtocol(LineOnlyReceiver):
"""Base replication protocol shared between client and server.
@@ -495,20 +503,6 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.send_command(ReplicateCommand())
class AbstractConnection(abc.ABC):
"""An interface for replication connections."""
@abc.abstractmethod
def send_command(self, cmd: Command):
"""Send the command down the connection"""
pass
# This tells python that `BaseReplicationStreamProtocol` implements the
# interface.
AbstractConnection.register(BaseReplicationStreamProtocol)
# The following simply registers metrics for the replication connections
pending_commands = LaterGauge(

View File

@@ -19,6 +19,11 @@ from typing import TYPE_CHECKING, Generic, Optional, Type, TypeVar, cast
import attr
import txredisapi
from zope.interface import implementer
from twisted.internet.address import IPv4Address, IPv6Address
from twisted.internet.interfaces import IAddress, IConnector
from twisted.python.failure import Failure
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.metrics.background_process_metrics import (
@@ -32,7 +37,7 @@ from synapse.replication.tcp.commands import (
parse_command_from_line,
)
from synapse.replication.tcp.protocol import (
AbstractConnection,
IReplicationConnection,
tcp_inbound_commands_counter,
tcp_outbound_commands_counter,
)
@@ -62,7 +67,8 @@ class ConstantProperty(Generic[T, V]):
pass
class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
@implementer(IReplicationConnection)
class RedisSubscriber(txredisapi.SubscriberProtocol):
"""Connection to redis subscribed to replication stream.
This class fulfils two functions:
@@ -71,7 +77,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
connection, parsing *incoming* messages into replication commands, and passing them
to `ReplicationCommandHandler`
(b) it implements the AbstractConnection API, where it sends *outgoing* commands
(b) it implements the IReplicationConnection API, where it sends *outgoing* commands
onto outbound_redis_connection.
Due to the vagaries of `txredisapi` we don't want to have a custom
@@ -253,6 +259,37 @@ class SynapseRedisFactory(txredisapi.RedisFactory):
except Exception:
logger.warning("Failed to send ping to a redis connection")
# ReconnectingClientFactory has some logging (if you enable `self.noisy`), but
# it's rubbish. We add our own here.
def startedConnecting(self, connector: IConnector):
logger.info(
"Connecting to redis server %s", format_address(connector.getDestination())
)
super().startedConnecting(connector)
def clientConnectionFailed(self, connector: IConnector, reason: Failure):
logger.info(
"Connection to redis server %s failed: %s",
format_address(connector.getDestination()),
reason.value,
)
super().clientConnectionFailed(connector, reason)
def clientConnectionLost(self, connector: IConnector, reason: Failure):
logger.info(
"Connection to redis server %s lost: %s",
format_address(connector.getDestination()),
reason.value,
)
super().clientConnectionLost(connector, reason)
def format_address(address: IAddress) -> str:
if isinstance(address, (IPv4Address, IPv6Address)):
return "%s:%i" % (address.host, address.port)
return str(address)
class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):
"""This is a reconnecting factory that connects to redis and immediately

View File

@@ -15,10 +15,9 @@
import re
import twisted.web.server
import synapse.api.auth
from synapse.api.auth import Auth
from synapse.api.errors import AuthError
from synapse.http.site import SynapseRequest
from synapse.types import UserID
@@ -37,13 +36,11 @@ def admin_patterns(path_regex: str, version: str = "v1"):
return patterns
async def assert_requester_is_admin(
auth: synapse.api.auth.Auth, request: twisted.web.server.Request
) -> None:
async def assert_requester_is_admin(auth: Auth, request: SynapseRequest) -> None:
"""Verify that the requester is an admin user
Args:
auth: api.auth.Auth singleton
auth: Auth singleton
request: incoming request
Raises:
@@ -53,11 +50,11 @@ async def assert_requester_is_admin(
await assert_user_is_admin(auth, requester.user)
async def assert_user_is_admin(auth: synapse.api.auth.Auth, user_id: UserID) -> None:
async def assert_user_is_admin(auth: Auth, user_id: UserID) -> None:
"""Verify that the given user is an admin user
Args:
auth: api.auth.Auth singleton
auth: Auth singleton
user_id: user to check
Raises:

View File

@@ -17,10 +17,9 @@
import logging
from typing import TYPE_CHECKING, Tuple
from twisted.web.server import Request
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer
from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import (
admin_patterns,
assert_requester_is_admin,
@@ -50,7 +49,9 @@ class QuarantineMediaInRoom(RestServlet):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
async def on_POST(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
async def on_POST(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@@ -75,7 +76,9 @@ class QuarantineMediaByUser(RestServlet):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
async def on_POST(self, request: Request, user_id: str) -> Tuple[int, JsonDict]:
async def on_POST(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@@ -103,7 +106,7 @@ class QuarantineMediaByID(RestServlet):
self.auth = hs.get_auth()
async def on_POST(
self, request: Request, server_name: str, media_id: str
self, request: SynapseRequest, server_name: str, media_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@@ -127,7 +130,9 @@ class ProtectMediaByID(RestServlet):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
async def on_POST(self, request: Request, media_id: str) -> Tuple[int, JsonDict]:
async def on_POST(
self, request: SynapseRequest, media_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@@ -148,7 +153,9 @@ class ListMediaInRoom(RestServlet):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
async def on_GET(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
async def on_GET(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
is_admin = await self.auth.is_server_admin(requester.user)
if not is_admin:
@@ -166,7 +173,7 @@ class PurgeMediaCacheRestServlet(RestServlet):
self.media_repository = hs.get_media_repository()
self.auth = hs.get_auth()
async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
before_ts = parse_integer(request, "before_ts", required=True)
@@ -189,7 +196,7 @@ class DeleteMediaByID(RestServlet):
self.media_repository = hs.get_media_repository()
async def on_DELETE(
self, request: Request, server_name: str, media_id: str
self, request: SynapseRequest, server_name: str, media_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
@@ -218,7 +225,9 @@ class DeleteMediaByDateSize(RestServlet):
self.server_name = hs.hostname
self.media_repository = hs.get_media_repository()
async def on_POST(self, request: Request, server_name: str) -> Tuple[int, JsonDict]:
async def on_POST(
self, request: SynapseRequest, server_name: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
before_ts = parse_integer(request, "before_ts", required=True)

View File

@@ -269,7 +269,10 @@ class UserRestServletV2(RestServlet):
target_user.to_string(), False, requester, by_admin=True
)
elif not deactivate and user["deactivated"]:
if "password" not in body:
if (
"password" not in body
and self.hs.config.password_localdb_enabled
):
raise SynapseError(
400, "Must provide a password to re-activate an account."
)

View File

@@ -32,6 +32,7 @@ from synapse.http.servlet import (
assert_params_in_dict,
parse_json_object_from_request,
)
from synapse.http.site import SynapseRequest
from synapse.types import GroupID, JsonDict
from ._base import client_patterns
@@ -70,7 +71,9 @@ class GroupServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
async def on_GET(
self, request: SynapseRequest, group_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -81,7 +84,9 @@ class GroupServlet(RestServlet):
return 200, group_description
@_validate_group_id
async def on_POST(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
async def on_POST(
self, request: SynapseRequest, group_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -111,7 +116,9 @@ class GroupSummaryServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
async def on_GET(
self, request: SynapseRequest, group_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -144,7 +151,11 @@ class GroupSummaryRoomsCatServlet(RestServlet):
@_validate_group_id
async def on_PUT(
self, request: Request, group_id: str, category_id: Optional[str], room_id: str
self,
request: SynapseRequest,
group_id: str,
category_id: Optional[str],
room_id: str,
):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -176,7 +187,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
@_validate_group_id
async def on_DELETE(
self, request: Request, group_id: str, category_id: str, room_id: str
self, request: SynapseRequest, group_id: str, category_id: str, room_id: str
):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -206,7 +217,7 @@ class GroupCategoryServlet(RestServlet):
@_validate_group_id
async def on_GET(
self, request: Request, group_id: str, category_id: str
self, request: SynapseRequest, group_id: str, category_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -219,7 +230,7 @@ class GroupCategoryServlet(RestServlet):
@_validate_group_id
async def on_PUT(
self, request: Request, group_id: str, category_id: str
self, request: SynapseRequest, group_id: str, category_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -247,7 +258,7 @@ class GroupCategoryServlet(RestServlet):
@_validate_group_id
async def on_DELETE(
self, request: Request, group_id: str, category_id: str
self, request: SynapseRequest, group_id: str, category_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -274,7 +285,9 @@ class GroupCategoriesServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
async def on_GET(
self, request: SynapseRequest, group_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -298,7 +311,7 @@ class GroupRoleServlet(RestServlet):
@_validate_group_id
async def on_GET(
self, request: Request, group_id: str, role_id: str
self, request: SynapseRequest, group_id: str, role_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -311,7 +324,7 @@ class GroupRoleServlet(RestServlet):
@_validate_group_id
async def on_PUT(
self, request: Request, group_id: str, role_id: str
self, request: SynapseRequest, group_id: str, role_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -339,7 +352,7 @@ class GroupRoleServlet(RestServlet):
@_validate_group_id
async def on_DELETE(
self, request: Request, group_id: str, role_id: str
self, request: SynapseRequest, group_id: str, role_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -366,7 +379,9 @@ class GroupRolesServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
async def on_GET(
self, request: SynapseRequest, group_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -399,7 +414,11 @@ class GroupSummaryUsersRoleServlet(RestServlet):
@_validate_group_id
async def on_PUT(
self, request: Request, group_id: str, role_id: Optional[str], user_id: str
self,
request: SynapseRequest,
group_id: str,
role_id: Optional[str],
user_id: str,
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -431,7 +450,7 @@ class GroupSummaryUsersRoleServlet(RestServlet):
@_validate_group_id
async def on_DELETE(
self, request: Request, group_id: str, role_id: str, user_id: str
self, request: SynapseRequest, group_id: str, role_id: str, user_id: str
):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -458,7 +477,9 @@ class GroupRoomServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
async def on_GET(
self, request: SynapseRequest, group_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -481,7 +502,9 @@ class GroupUsersServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
async def on_GET(
self, request: SynapseRequest, group_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -504,7 +527,9 @@ class GroupInvitedUsersServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
async def on_GET(
self, request: SynapseRequest, group_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -526,7 +551,9 @@ class GroupSettingJoinPolicyServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
async def on_PUT(
self, request: SynapseRequest, group_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -554,7 +581,7 @@ class GroupCreateServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
self.server_name = hs.hostname
async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -598,7 +625,7 @@ class GroupAdminRoomsServlet(RestServlet):
@_validate_group_id
async def on_PUT(
self, request: Request, group_id: str, room_id: str
self, request: SynapseRequest, group_id: str, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -615,7 +642,7 @@ class GroupAdminRoomsServlet(RestServlet):
@_validate_group_id
async def on_DELETE(
self, request: Request, group_id: str, room_id: str
self, request: SynapseRequest, group_id: str, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -646,7 +673,7 @@ class GroupAdminRoomsConfigServlet(RestServlet):
@_validate_group_id
async def on_PUT(
self, request: Request, group_id: str, room_id: str, config_key: str
self, request: SynapseRequest, group_id: str, room_id: str, config_key: str
):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -678,7 +705,9 @@ class GroupAdminUsersInviteServlet(RestServlet):
self.is_mine_id = hs.is_mine_id
@_validate_group_id
async def on_PUT(self, request: Request, group_id, user_id) -> Tuple[int, JsonDict]:
async def on_PUT(
self, request: SynapseRequest, group_id, user_id
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -708,7 +737,9 @@ class GroupAdminUsersKickServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
async def on_PUT(self, request: Request, group_id, user_id) -> Tuple[int, JsonDict]:
async def on_PUT(
self, request: SynapseRequest, group_id, user_id
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -735,7 +766,9 @@ class GroupSelfLeaveServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
async def on_PUT(
self, request: SynapseRequest, group_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -762,7 +795,9 @@ class GroupSelfJoinServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
async def on_PUT(
self, request: SynapseRequest, group_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -789,7 +824,9 @@ class GroupSelfAcceptInviteServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
async def on_PUT(
self, request: SynapseRequest, group_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -816,7 +853,9 @@ class GroupSelfUpdatePublicityServlet(RestServlet):
self.store = hs.get_datastore()
@_validate_group_id
async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
async def on_PUT(
self, request: SynapseRequest, group_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -839,7 +878,9 @@ class PublicisedGroupsForUserServlet(RestServlet):
self.store = hs.get_datastore()
self.groups_handler = hs.get_groups_local_handler()
async def on_GET(self, request: Request, user_id: str) -> Tuple[int, JsonDict]:
async def on_GET(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
await self.auth.get_user_by_req(request, allow_guest=True)
result = await self.groups_handler.get_publicised_groups_for_user(user_id)
@@ -859,7 +900,7 @@ class PublicisedGroupsForUsersServlet(RestServlet):
self.store = hs.get_datastore()
self.groups_handler = hs.get_groups_local_handler()
async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request)
@@ -881,7 +922,7 @@ class GroupsForUserServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()

View File

@@ -20,6 +20,7 @@ from typing import TYPE_CHECKING
from twisted.web.server import Request
from synapse.http.server import DirectServeJsonResource, respond_with_json
from synapse.http.site import SynapseRequest
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
@@ -35,7 +36,7 @@ class MediaConfigResource(DirectServeJsonResource):
self.auth = hs.get_auth()
self.limits_dict = {"m.upload.size": config.max_upload_size}
async def _async_render_GET(self, request: Request) -> None:
async def _async_render_GET(self, request: SynapseRequest) -> None:
await self.auth.get_user_by_req(request)
respond_with_json(request, 200, self.limits_dict, send_cors=True)

View File

@@ -39,6 +39,7 @@ from synapse.http.server import (
respond_with_json_bytes,
)
from synapse.http.servlet import parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.rest.media.v1._base import get_filename_from_headers
@@ -185,7 +186,7 @@ class PreviewUrlResource(DirectServeJsonResource):
request.setHeader(b"Allow", b"OPTIONS, GET")
respond_with_json(request, 200, {}, send_cors=True)
async def _async_render_GET(self, request: Request) -> None:
async def _async_render_GET(self, request: SynapseRequest) -> None:
# XXX: if get_user_by_req fails, what should we do in an async render?
requester = await self.auth.get_user_by_req(request)

View File

@@ -22,6 +22,7 @@ from twisted.web.server import Request
from synapse.api.errors import Codes, SynapseError
from synapse.http.server import DirectServeJsonResource, respond_with_json
from synapse.http.servlet import parse_string
from synapse.http.site import SynapseRequest
from synapse.rest.media.v1.media_storage import SpamMediaException
if TYPE_CHECKING:
@@ -49,7 +50,7 @@ class UploadResource(DirectServeJsonResource):
async def _async_render_OPTIONS(self, request: Request) -> None:
respond_with_json(request, 200, {}, send_cors=True)
async def _async_render_POST(self, request: Request) -> None:
async def _async_render_POST(self, request: SynapseRequest) -> None:
requester = await self.auth.get_user_by_req(request)
# TODO: The checks here are a bit late. The content will have
# already been uploaded to a tmp file at this point

View File

@@ -351,11 +351,9 @@ class HomeServer(metaclass=abc.ABCMeta):
@cache_in_self
def get_http_client_context_factory(self) -> IPolicyForHTTPS:
return (
InsecureInterceptableContextFactory()
if self.config.use_insecure_ssl_client_just_for_testing_do_not_use
else RegularPolicyForHTTPS()
)
if self.config.use_insecure_ssl_client_just_for_testing_do_not_use:
return InsecureInterceptableContextFactory()
return RegularPolicyForHTTPS()
@cache_in_self
def get_simple_http_client(self) -> SimpleHttpClient:

View File

@@ -16,12 +16,23 @@ from io import BytesIO
from mock import Mock
from netaddr import IPSet
from twisted.internet.error import DNSLookupError
from twisted.python.failure import Failure
from twisted.web.client import ResponseDone
from twisted.test.proto_helpers import AccumulatingProtocol
from twisted.web.client import Agent, ResponseDone
from twisted.web.iweb import UNKNOWN_LENGTH
from synapse.http.client import BodyExceededMaxSize, read_body_with_max_size
from synapse.api.errors import SynapseError
from synapse.http.client import (
BlacklistingAgentWrapper,
BlacklistingReactorWrapper,
BodyExceededMaxSize,
read_body_with_max_size,
)
from tests.server import FakeTransport, get_clock
from tests.unittest import TestCase
@@ -119,3 +130,114 @@ class ReadBodyWithMaxSizeTests(TestCase):
# The data is never consumed.
self.assertEqual(result.getvalue(), b"")
class BlacklistingAgentTest(TestCase):
def setUp(self):
self.reactor, self.clock = get_clock()
self.safe_domain, self.safe_ip = b"safe.test", b"1.2.3.4"
self.unsafe_domain, self.unsafe_ip = b"danger.test", b"5.6.7.8"
self.allowed_domain, self.allowed_ip = b"allowed.test", b"5.1.1.1"
# Configure the reactor's DNS resolver.
for (domain, ip) in (
(self.safe_domain, self.safe_ip),
(self.unsafe_domain, self.unsafe_ip),
(self.allowed_domain, self.allowed_ip),
):
self.reactor.lookups[domain.decode()] = ip.decode()
self.reactor.lookups[ip.decode()] = ip.decode()
self.ip_whitelist = IPSet([self.allowed_ip.decode()])
self.ip_blacklist = IPSet(["5.0.0.0/8"])
def test_reactor(self):
"""Apply the blacklisting reactor and ensure it properly blocks connections to particular domains and IPs."""
agent = Agent(
BlacklistingReactorWrapper(
self.reactor,
ip_whitelist=self.ip_whitelist,
ip_blacklist=self.ip_blacklist,
),
)
# The unsafe domains and IPs should be rejected.
for domain in (self.unsafe_domain, self.unsafe_ip):
self.failureResultOf(
agent.request(b"GET", b"http://" + domain), DNSLookupError
)
# The safe domains IPs should be accepted.
for domain in (
self.safe_domain,
self.allowed_domain,
self.safe_ip,
self.allowed_ip,
):
d = agent.request(b"GET", b"http://" + domain)
# Grab the latest TCP connection.
(
host,
port,
client_factory,
_timeout,
_bindAddress,
) = self.reactor.tcpClients[-1]
# Make the connection and pump data through it.
client = client_factory.buildProtocol(None)
server = AccumulatingProtocol()
server.makeConnection(FakeTransport(client, self.reactor))
client.makeConnection(FakeTransport(server, self.reactor))
client.dataReceived(
b"HTTP/1.0 200 OK\r\nContent-Length: 0\r\nContent-Type: text/html\r\n\r\n"
)
response = self.successResultOf(d)
self.assertEqual(response.code, 200)
def test_agent(self):
"""Apply the blacklisting agent and ensure it properly blocks connections to particular IPs."""
agent = BlacklistingAgentWrapper(
Agent(self.reactor),
ip_whitelist=self.ip_whitelist,
ip_blacklist=self.ip_blacklist,
)
# The unsafe IPs should be rejected.
self.failureResultOf(
agent.request(b"GET", b"http://" + self.unsafe_ip), SynapseError
)
# The safe and unsafe domains and safe IPs should be accepted.
for domain in (
self.safe_domain,
self.unsafe_domain,
self.allowed_domain,
self.safe_ip,
self.allowed_ip,
):
d = agent.request(b"GET", b"http://" + domain)
# Grab the latest TCP connection.
(
host,
port,
client_factory,
_timeout,
_bindAddress,
) = self.reactor.tcpClients[-1]
# Make the connection and pump data through it.
client = client_factory.buildProtocol(None)
server = AccumulatingProtocol()
server.makeConnection(FakeTransport(client, self.reactor))
client.makeConnection(FakeTransport(server, self.reactor))
client.dataReceived(
b"HTTP/1.0 200 OK\r\nContent-Length: 0\r\nContent-Type: text/html\r\n\r\n"
)
response = self.successResultOf(d)
self.assertEqual(response.code, 200)

View File

@@ -17,7 +17,7 @@ import mock
from synapse.app.generic_worker import GenericWorkerServer
from synapse.replication.tcp.commands import FederationAckCommand
from synapse.replication.tcp.protocol import AbstractConnection
from synapse.replication.tcp.protocol import IReplicationConnection
from synapse.replication.tcp.streams.federation import FederationStream
from tests.unittest import HomeserverTestCase
@@ -51,8 +51,10 @@ class FederationAckTestCase(HomeserverTestCase):
"""
rch = self.hs.get_tcp_replication()
# wire up the ReplicationCommandHandler to a mock connection
mock_connection = mock.Mock(spec=AbstractConnection)
# wire up the ReplicationCommandHandler to a mock connection, which needs
# to implement IReplicationConnection. (Note that Mock doesn't understand
# interfaces, but casing an interface to a list gives the attributes.)
mock_connection = mock.Mock(spec=list(IReplicationConnection))
rch.new_connection(mock_connection)
# tell it it received an RDATA row