diff --git a/changelog.d/9561.misc b/changelog.d/9561.misc new file mode 100644 index 0000000000..6c529a82ee --- /dev/null +++ b/changelog.d/9561.misc @@ -0,0 +1 @@ +Increase the threshold before which outbound federation to a server goes into "catch up" mode, which is expensive for the remote server to handle. diff --git a/changelog.d/9601.feature b/changelog.d/9601.feature new file mode 100644 index 0000000000..5078d63ffa --- /dev/null +++ b/changelog.d/9601.feature @@ -0,0 +1 @@ +Optimise handling of incomplete room history for incoming federation. diff --git a/changelog.d/9608.misc b/changelog.d/9608.misc new file mode 100644 index 0000000000..14c7b78dd9 --- /dev/null +++ b/changelog.d/9608.misc @@ -0,0 +1 @@ +Fix incorrect type hints. diff --git a/changelog.d/9618.misc b/changelog.d/9618.misc new file mode 100644 index 0000000000..14c7b78dd9 --- /dev/null +++ b/changelog.d/9618.misc @@ -0,0 +1 @@ +Fix incorrect type hints. diff --git a/changelog.d/9623.bugfix b/changelog.d/9623.bugfix new file mode 100644 index 0000000000..ecccb46105 --- /dev/null +++ b/changelog.d/9623.bugfix @@ -0,0 +1 @@ +Fix Internal Server Error on `GET /_synapse/client/saml2/authn_response` request. diff --git a/stubs/txredisapi.pyi b/stubs/txredisapi.pyi index 34787e0b1e..080ca40287 100644 --- a/stubs/txredisapi.pyi +++ b/stubs/txredisapi.pyi @@ -19,7 +19,7 @@ from typing import Any, List, Optional, Type, Union from twisted.internet import protocol -class RedisProtocol: +class RedisProtocol(protocol.Protocol): def publish(self, channel: str, message: bytes): ... async def ping(self) -> None: ... 
async def set( diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index deb519f3ef..cc0d765e5f 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -17,6 +17,7 @@ import datetime import logging from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Tuple, cast +import attr from prometheus_client import Counter from synapse.api.errors import ( @@ -93,6 +94,10 @@ class PerDestinationQueue: self._destination = destination self.transmission_loop_running = False + # Flag to signal to any running transmission loop that there is new data + # queued up to be sent. + self._new_data_to_send = False + # True whilst we are sending events that the remote homeserver missed # because it was unreachable. We start in this state so we can perform # catch-up at startup. @@ -108,7 +113,7 @@ class PerDestinationQueue: # destination (we are the only updater so this is safe) self._last_successful_stream_ordering = None # type: Optional[int] - # a list of pending PDUs + # a queue of pending PDUs self._pending_pdus = [] # type: List[EventBase] # XXX this is never actually used: see @@ -208,6 +213,10 @@ class PerDestinationQueue: transaction in the background. """ + # Mark that we (may) have new things to send, so that any running + # transmission loop will recheck whether there is stuff to send. + self._new_data_to_send = True + if self.transmission_loop_running: # XXX: this can get stuck on by a never-ending # request at which point pending_pdus just keeps growing. 
@@ -250,125 +259,41 @@ class PerDestinationQueue: pending_pdus = [] while True: - # We have to keep 2 free slots for presence and rr_edus - limit = MAX_EDUS_PER_TRANSACTION - 2 + self._new_data_to_send = False - device_update_edus, dev_list_id = await self._get_device_update_edus( - limit - ) - - limit -= len(device_update_edus) - - ( - to_device_edus, - device_stream_id, - ) = await self._get_to_device_message_edus(limit) - - pending_edus = device_update_edus + to_device_edus - - # BEGIN CRITICAL SECTION - # - # In order to avoid a race condition, we need to make sure that - # the following code (from popping the queues up to the point - # where we decide if we actually have any pending messages) is - # atomic - otherwise new PDUs or EDUs might arrive in the - # meantime, but not get sent because we hold the - # transmission_loop_running flag. - - pending_pdus = self._pending_pdus - - # We can only include at most 50 PDUs per transactions - pending_pdus, self._pending_pdus = pending_pdus[:50], pending_pdus[50:] - - pending_edus.extend(self._get_rr_edus(force_flush=False)) - pending_presence = self._pending_presence - self._pending_presence = {} - if pending_presence: - pending_edus.append( - Edu( - origin=self._server_name, - destination=self._destination, - edu_type="m.presence", - content={ - "push": [ - format_user_presence_state( - presence, self._clock.time_msec() - ) - for presence in pending_presence.values() - ] - }, - ) - ) - - pending_edus.extend( - self._pop_pending_edus(MAX_EDUS_PER_TRANSACTION - len(pending_edus)) - ) - while ( - len(pending_edus) < MAX_EDUS_PER_TRANSACTION - and self._pending_edus_keyed + async with _TransactionQueueManager(self) as ( + pending_pdus, + pending_edus, ): - _, val = self._pending_edus_keyed.popitem() - pending_edus.append(val) + if not pending_pdus and not pending_edus: + logger.debug("TX [%s] Nothing to send", self._destination) - if pending_pdus: - logger.debug( - "TX [%s] len(pending_pdus_by_dest[dest]) = %d", - 
self._destination, - len(pending_pdus), + # If we were told about new things to send while we were + # checking for things to send, we look again. + # Otherwise new PDUs or EDUs might arrive in the meantime, + # but not get sent because we hold the + # `transmission_loop_running` flag. + if self._new_data_to_send: + continue + else: + return + + if pending_pdus: + logger.debug( + "TX [%s] len(pending_pdus_by_dest[dest]) = %d", + self._destination, + len(pending_pdus), + ) + + await self._transaction_manager.send_new_transaction( + self._destination, pending_pdus, pending_edus ) - if not pending_pdus and not pending_edus: - logger.debug("TX [%s] Nothing to send", self._destination) - self._last_device_stream_id = device_stream_id - return - - # if we've decided to send a transaction anyway, and we have room, we - # may as well send any pending RRs - if len(pending_edus) < MAX_EDUS_PER_TRANSACTION: - pending_edus.extend(self._get_rr_edus(force_flush=True)) - - # END CRITICAL SECTION - - success = await self._transaction_manager.send_new_transaction( - self._destination, pending_pdus, pending_edus - ) - if success: sent_transactions_counter.inc() sent_edus_counter.inc(len(pending_edus)) for edu in pending_edus: sent_edus_by_type.labels(edu.edu_type).inc() - # Remove the acknowledged device messages from the database - # Only bother if we actually sent some device messages - if to_device_edus: - await self._store.delete_device_msgs_for_remote( - self._destination, device_stream_id - ) - # also mark the device updates as sent - if device_update_edus: - logger.info( - "Marking as sent %r %r", self._destination, dev_list_id - ) - await self._store.mark_as_sent_devices_by_remote( - self._destination, dev_list_id - ) - - self._last_device_stream_id = device_stream_id - self._last_device_list_stream_id = dev_list_id - - if pending_pdus: - # we sent some PDUs and it was successful, so update our - # last_successful_stream_ordering in the destinations table. 
- final_pdu = pending_pdus[-1] - last_successful_stream_ordering = ( - final_pdu.internal_metadata.stream_ordering - ) - assert last_successful_stream_ordering - await self._store.set_destination_last_successful_stream_ordering( - self._destination, last_successful_stream_ordering - ) - else: - break except NotRetryingDestination as e: logger.debug( "TX [%s] not ready for retry yet (next retry at %s) - " @@ -401,7 +326,7 @@ class PerDestinationQueue: self._pending_presence = {} self._pending_rrs = {} - self._start_catching_up() + self._start_catching_up() except FederationDeniedError as e: logger.info(e) except HttpResponseException as e: @@ -412,7 +337,6 @@ class PerDestinationQueue: e, ) - self._start_catching_up() except RequestSendFailed as e: logger.warning( "TX [%s] Failed to send transaction: %s", self._destination, e @@ -422,16 +346,12 @@ class PerDestinationQueue: logger.info( "Failed to send event %s to %s", p.event_id, self._destination ) - - self._start_catching_up() except Exception: logger.exception("TX [%s] Failed to send transaction", self._destination) for p in pending_pdus: logger.info( "Failed to send event %s to %s", p.event_id, self._destination ) - - self._start_catching_up() finally: # We want to be *very* sure we clear this after we stop processing self.transmission_loop_running = False @@ -499,13 +419,10 @@ class PerDestinationQueue: rooms = [p.room_id for p in catchup_pdus] logger.info("Catching up rooms to %s: %r", self._destination, rooms) - success = await self._transaction_manager.send_new_transaction( + await self._transaction_manager.send_new_transaction( self._destination, catchup_pdus, [] ) - if not success: - return - sent_transactions_counter.inc() final_pdu = catchup_pdus[-1] self._last_successful_stream_ordering = cast( @@ -584,3 +501,135 @@ class PerDestinationQueue: """ self._catching_up = True self._pending_pdus = [] + + +@attr.s(slots=True) +class _TransactionQueueManager: + """A helper async context manager for pulling 
stuff off the queues and + tracking what was last successfully sent, etc. + """ + + queue = attr.ib(type=PerDestinationQueue) + + _device_stream_id = attr.ib(type=Optional[int], default=None) + _device_list_id = attr.ib(type=Optional[int], default=None) + _last_stream_ordering = attr.ib(type=Optional[int], default=None) + _pdus = attr.ib(type=List[EventBase], factory=list) + + async def __aenter__(self) -> Tuple[List[EventBase], List[Edu]]: + # First we calculate the EDUs we want to send, if any. + + # We start by fetching device-related EDUs, i.e. device updates and + # to-device messages. We have to keep 2 free slots for presence and rr_edus. + limit = MAX_EDUS_PER_TRANSACTION - 2 + + device_update_edus, dev_list_id = await self.queue._get_device_update_edus( + limit + ) + + if device_update_edus: + self._device_list_id = dev_list_id + else: + self.queue._last_device_list_stream_id = dev_list_id + + limit -= len(device_update_edus) + + ( + to_device_edus, + device_stream_id, + ) = await self.queue._get_to_device_message_edus(limit) + + if to_device_edus: + self._device_stream_id = device_stream_id + else: + self.queue._last_device_stream_id = device_stream_id + + pending_edus = device_update_edus + to_device_edus + + # Now add the read receipt EDU. + pending_edus.extend(self.queue._get_rr_edus(force_flush=False)) + + # And presence EDU. + if self.queue._pending_presence: + pending_edus.append( + Edu( + origin=self.queue._server_name, + destination=self.queue._destination, + edu_type="m.presence", + content={ + "push": [ + format_user_presence_state( + presence, self.queue._clock.time_msec() + ) + for presence in self.queue._pending_presence.values() + ] + }, + ) + ) + self.queue._pending_presence = {} + + # Finally add any other types of EDUs if there is room. 
+ pending_edus.extend( + self.queue._pop_pending_edus(MAX_EDUS_PER_TRANSACTION - len(pending_edus)) + ) + while ( + len(pending_edus) < MAX_EDUS_PER_TRANSACTION + and self.queue._pending_edus_keyed + ): + _, val = self.queue._pending_edus_keyed.popitem() + pending_edus.append(val) + + # Now we look for any PDUs to send, by getting up to 50 PDUs from the + # queue + self._pdus = self.queue._pending_pdus[:50] + + if not self._pdus and not pending_edus: + return [], [] + + # if we've decided to send a transaction anyway, and we have room, we + # may as well send any pending RRs + if len(pending_edus) < MAX_EDUS_PER_TRANSACTION: + pending_edus.extend(self.queue._get_rr_edus(force_flush=True)) + + if self._pdus: + self._last_stream_ordering = self._pdus[ + -1 + ].internal_metadata.stream_ordering + assert self._last_stream_ordering + + return self._pdus, pending_edus + + async def __aexit__(self, exc_type, exc, tb): + if exc_type is not None: + # Failed to send transaction, so we bail out. + return + + # Successfully sent transactions, so we remove pending PDUs from the queue + if self._pdus: + self.queue._pending_pdus = self.queue._pending_pdus[len(self._pdus) :] + + # The transaction was sent successfully, so we record where we have sent up + # to in the various streams + + if self._device_stream_id: + await self.queue._store.delete_device_msgs_for_remote( + self.queue._destination, self._device_stream_id + ) + self.queue._last_device_stream_id = self._device_stream_id + + # also mark the device updates as sent + if self._device_list_id: + logger.info( + "Marking as sent %r %r", self.queue._destination, self._device_list_id + ) + await self.queue._store.mark_as_sent_devices_by_remote( + self.queue._destination, self._device_list_id + ) + self.queue._last_device_list_stream_id = self._device_list_id + + if self._last_stream_ordering: + # we sent some PDUs and it was successful, so update our + # last_successful_stream_ordering in the destinations table. 
+ await self.queue._store.set_destination_last_successful_stream_ordering( + self.queue._destination, self._last_stream_ordering + ) diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py index 2a9cd063c4..07b740c2f2 100644 --- a/synapse/federation/sender/transaction_manager.py +++ b/synapse/federation/sender/transaction_manager.py @@ -69,15 +69,12 @@ class TransactionManager: destination: str, pdus: List[EventBase], edus: List[Edu], - ) -> bool: + ) -> None: """ Args: destination: The destination to send to (e.g. 'example.org') pdus: In-order list of PDUs to send edus: List of EDUs to send - - Returns: - True iff the transaction was successful """ # Make a transaction-sending opentracing span. This span follows on from @@ -96,8 +93,6 @@ class TransactionManager: edu.strip_context() with start_active_span_follows_from("send_transaction", span_contexts): - success = True - logger.debug("TX [%s] _attempt_new_transaction", destination) txn_id = str(self._next_txn_id) @@ -152,44 +147,29 @@ class TransactionManager: response = await self._transport_layer.send_transaction( transaction, json_data_cb ) - code = 200 except HttpResponseException as e: code = e.code response = e.response - if e.code in (401, 404, 429) or 500 <= e.code: - logger.info( - "TX [%s] {%s} got %d response", destination, txn_id, code - ) - raise e + set_tag(tags.ERROR, True) - logger.info("TX [%s] {%s} got %d response", destination, txn_id, code) + logger.info("TX [%s] {%s} got %d response", destination, txn_id, code) + raise - if code == 200: - for e_id, r in response.get("pdus", {}).items(): - if "error" in r: - logger.warning( - "TX [%s] {%s} Remote returned error for %s: %s", - destination, - txn_id, - e_id, - r, - ) - else: - for p in pdus: + logger.info("TX [%s] {%s} got 200 response", destination, txn_id) + + for e_id, r in response.get("pdus", {}).items(): + if "error" in r: logger.warning( - "TX [%s] {%s} Failed to send event %s", + 
"TX [%s] {%s} Remote returned error for %s: %s", destination, txn_id, - p.event_id, + e_id, + r, ) - success = False - if success and pdus and destination in self._federation_metrics_domains: + if pdus and destination in self._federation_metrics_domains: last_pdu = pdus[-1] last_pdu_ts_metric.labels(server_name=destination).set( last_pdu.origin_server_ts / 1000 ) - - set_tag(tags.ERROR, not success) - return success diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 19a55f0971..ec2ce679c2 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -356,17 +356,16 @@ class FederationHandler(BaseHandler): # Ask the remote server for the states we don't # know about for p in prevs - seen: - logger.info( - "Requesting state at missing prev_event %s", - event_id, - ) + logger.info("Requesting state after missing prev_event %s", p) with nested_logging_context(p): # note that if any of the missing prevs share missing state or # auth events, the requests to fetch those events are deduped # by the get_pdu_cache in federation_client. - (remote_state, _,) = await self._get_state_for_room( - origin, room_id, p, include_event_in_state=True + remote_state = ( + await self._get_state_after_missing_prev_event( + origin, room_id, p + ) ) remote_state_map = { @@ -541,7 +540,6 @@ class FederationHandler(BaseHandler): destination: str, room_id: str, event_id: str, - include_event_in_state: bool = False, ) -> Tuple[List[EventBase], List[EventBase]]: """Requests all of the room state at a given event from a remote homeserver. @@ -549,11 +547,9 @@ class FederationHandler(BaseHandler): destination: The remote homeserver to query for the state. room_id: The id of the room we're interested in. event_id: The id of the event we want the state at. - include_event_in_state: if true, the event itself will be included in the - returned state event list. 
Returns: - A list of events in the state, possibly including the event itself, and + A list of events in the state, not including the event itself, and a list of events in the auth chain for the given event. """ ( @@ -565,9 +561,6 @@ class FederationHandler(BaseHandler): desired_events = set(state_event_ids + auth_event_ids) - if include_event_in_state: - desired_events.add(event_id) - event_map = await self._get_events_from_store_or_dest( destination, room_id, desired_events ) @@ -584,13 +577,6 @@ class FederationHandler(BaseHandler): event_map[e_id] for e_id in state_event_ids if e_id in event_map ] - if include_event_in_state: - remote_event = event_map.get(event_id) - if not remote_event: - raise Exception("Unable to get missing prev_event %s" % (event_id,)) - if remote_event.is_state() and remote_event.rejected_reason is None: - remote_state.append(remote_event) - auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map] auth_chain.sort(key=lambda e: e.depth) @@ -664,6 +650,131 @@ class FederationHandler(BaseHandler): return fetched_events + async def _get_state_after_missing_prev_event( + self, + destination: str, + room_id: str, + event_id: str, + ) -> List[EventBase]: + """Requests all of the room state at a given event from a remote homeserver. + + Args: + destination: The remote homeserver to query for the state. + room_id: The id of the room we're interested in. + event_id: The id of the event we want the state at. + + Returns: + A list of events in the state, including the event itself + """ + # TODO: This function is basically the same as _get_state_for_room. Can + # we make backfill() use it, rather than having two code paths? I think the + # only difference is that backfill() persists the prev events separately. 
+ + ( + state_event_ids, + auth_event_ids, + ) = await self.federation_client.get_room_state_ids( + destination, room_id, event_id=event_id + ) + + logger.debug( + "state_ids returned %i state events, %i auth events", + len(state_event_ids), + len(auth_event_ids), + ) + + # start by just trying to fetch the events from the store + desired_events = set(state_event_ids) + desired_events.add(event_id) + logger.debug("Fetching %i events from cache/store", len(desired_events)) + fetched_events = await self.store.get_events( + desired_events, allow_rejected=True + ) + + missing_desired_events = desired_events - fetched_events.keys() + logger.debug( + "We are missing %i events (got %i)", + len(missing_desired_events), + len(fetched_events), + ) + + # We probably won't need most of the auth events, so let's just check which + # we have for now, rather than thrashing the event cache with them all + # unnecessarily. + + # TODO: we probably won't actually need all of the auth events, since we + # already have a bunch of the state events. It would be nice if the + # federation api gave us a way of finding out which we actually need. + + missing_auth_events = set(auth_event_ids) - fetched_events.keys() + missing_auth_events.difference_update( + await self.store.have_seen_events(missing_auth_events) + ) + logger.debug("We are also missing %i auth events", len(missing_auth_events)) + + missing_events = missing_desired_events | missing_auth_events + logger.debug("Fetching %i events from remote", len(missing_events)) + await self._get_events_and_persist( + destination=destination, room_id=room_id, events=missing_events + ) + + # we need to make sure we re-load from the database to get the rejected + # state correct. + fetched_events.update( + (await self.store.get_events(missing_desired_events, allow_rejected=True)) + ) + + # check for events which were in the wrong room. 
+ # + # this can happen if a remote server claims that the state or + # auth_events at an event in room A are actually events in room B + + bad_events = [ + (event_id, event.room_id) + for event_id, event in fetched_events.items() + if event.room_id != room_id + ] + + for bad_event_id, bad_room_id in bad_events: + # This is a bogus situation, but since we may only discover it a long time + # after it happened, we try our best to carry on, by just omitting the + # bad events from the returned state set. + logger.warning( + "Remote server %s claims event %s in room %s is an auth/state " + "event in room %s", + destination, + bad_event_id, + bad_room_id, + room_id, + ) + + del fetched_events[bad_event_id] + + # if we couldn't get the prev event in question, that's a problem. + remote_event = fetched_events.get(event_id) + if not remote_event: + raise Exception("Unable to get missing prev_event %s" % (event_id,)) + + # missing state at that event is a warning, not a blocker + # XXX: this doesn't sound right? it means that we'll end up with incomplete + # state. 
+ failed_to_fetch = desired_events - fetched_events.keys() + if failed_to_fetch: + logger.warning( + "Failed to fetch missing state events for %s %s", + event_id, + failed_to_fetch, + ) + + remote_state = [ + fetched_events[e_id] for e_id in state_event_ids if e_id in fetched_events + ] + + if remote_event.is_state() and remote_event.rejected_reason is None: + remote_state.append(remote_event) + + return remote_state + async def _process_received_pdu( self, origin: str, @@ -843,7 +954,6 @@ class FederationHandler(BaseHandler): destination=dest, room_id=room_id, event_id=e_id, - include_event_in_state=False, ) auth_events.update({a.event_id: a for a in auth}) auth_events.update({s.event_id: s for s in state}) diff --git a/synapse/http/client.py b/synapse/http/client.py index d4ab3a2732..1e01e0a9f2 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -45,7 +45,9 @@ from twisted.internet.interfaces import ( IHostResolution, IReactorPluggableNameResolver, IResolutionReceiver, + ITCPTransport, ) +from twisted.internet.protocol import connectionDone from twisted.internet.task import Cooperator from twisted.python.failure import Failure from twisted.web._newclient import ResponseDone @@ -760,6 +762,8 @@ class BodyExceededMaxSize(Exception): class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol): """A protocol which immediately errors upon receiving data.""" + transport = None # type: Optional[ITCPTransport] + def __init__(self, deferred: defer.Deferred): self.deferred = deferred @@ -771,18 +775,21 @@ class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol): self.deferred.errback(BodyExceededMaxSize()) # Close the connection (forcefully) since all the data will get # discarded anyway. 
+ assert self.transport is not None self.transport.abortConnection() def dataReceived(self, data: bytes) -> None: self._maybe_fail() - def connectionLost(self, reason: Failure) -> None: + def connectionLost(self, reason: Failure = connectionDone) -> None: self._maybe_fail() class _ReadBodyWithMaxSizeProtocol(protocol.Protocol): """A protocol which reads body to a stream, erroring if the body exceeds a maximum size.""" + transport = None # type: Optional[ITCPTransport] + def __init__( self, stream: BinaryIO, deferred: defer.Deferred, max_size: Optional[int] ): @@ -805,9 +812,10 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol): self.deferred.errback(BodyExceededMaxSize()) # Close the connection (forcefully) since all the data will get # discarded anyway. + assert self.transport is not None self.transport.abortConnection() - def connectionLost(self, reason: Failure) -> None: + def connectionLost(self, reason: Failure = connectionDone) -> None: # If the maximum size was already exceeded, there's nothing to do. 
if self.deferred.called: return diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index ee909f3fc5..a8894beadf 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -302,7 +302,7 @@ class ReplicationCommandHandler: hs, outbound_redis_connection ) hs.get_reactor().connectTCP( - hs.config.redis.redis_host, + hs.config.redis.redis_host.encode(), hs.config.redis.redis_port, self._factory, ) @@ -311,7 +311,7 @@ class ReplicationCommandHandler: self._factory = DirectTcpReplicationClientFactory(hs, client_name, self) host = hs.config.worker_replication_host port = hs.config.worker_replication_port - hs.get_reactor().connectTCP(host, port, self._factory) + hs.get_reactor().connectTCP(host.encode(), port, self._factory) def get_streams(self) -> Dict[str, Stream]: """Get a map from stream name to all streams.""" diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 8e4734b59c..825900f64c 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -56,6 +56,7 @@ from prometheus_client import Counter from zope.interface import Interface, implementer from twisted.internet import task +from twisted.internet.tcp import Connection from twisted.protocols.basic import LineOnlyReceiver from twisted.python.failure import Failure @@ -145,6 +146,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): (if they send a `PING` command) """ + # The transport is going to be an ITCPTransport, but that doesn't have the + # (un)registerProducer methods, those are only on the implementation. 
+ transport = None # type: Connection + delimiter = b"\n" # Valid commands we expect to receive @@ -189,6 +194,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): connected_connections.append(self) # Register connection for metrics + assert self.transport is not None self.transport.registerProducer(self, True) # For the *Producing callbacks self._send_pending_commands() @@ -213,6 +219,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): logger.info( "[%s] Failed to close connection gracefully, aborting", self.id() ) + assert self.transport is not None self.transport.abortConnection() else: if now - self.last_sent_command >= PING_TIME: @@ -302,6 +309,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): def close(self): logger.warning("[%s] Closing connection", self.id()) self.time_we_closed = self.clock.time_msec() + assert self.transport is not None self.transport.loseConnection() self.on_connection_closed() @@ -399,6 +407,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): def connectionLost(self, reason): logger.info("[%s] Replication connection closed: %r", self.id(), reason) if isinstance(reason, Failure): + assert reason.type is not None connection_close_counter.labels(reason.type.__name__).inc() else: connection_close_counter.labels(reason.__class__.__name__).inc() diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py index 7cccde097d..2f4d407f94 100644 --- a/synapse/replication/tcp/redis.py +++ b/synapse/replication/tcp/redis.py @@ -365,6 +365,6 @@ def lazyConnection( factory.continueTrying = reconnect reactor = hs.get_reactor() - reactor.connectTCP(host, port, factory, timeout=30, bindAddress=None) + reactor.connectTCP(host.encode(), port, factory, timeout=30, bindAddress=None) return factory.handler diff --git a/synapse/rest/synapse/client/saml2/response_resource.py b/synapse/rest/synapse/client/saml2/response_resource.py index f6668fb5e3..4dfadf1bfb 100644 --- 
a/synapse/rest/synapse/client/saml2/response_resource.py +++ b/synapse/rest/synapse/client/saml2/response_resource.py @@ -14,24 +14,30 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import TYPE_CHECKING + from synapse.http.server import DirectServeHtmlResource +if TYPE_CHECKING: + from synapse.server import HomeServer + class SAML2ResponseResource(DirectServeHtmlResource): """A Twisted web resource which handles the SAML response""" isLeaf = 1 - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self._saml_handler = hs.get_saml_handler() + self._sso_handler = hs.get_sso_handler() async def _async_render_GET(self, request): # We're not expecting any GET request on that resource if everything goes right, # but some IdPs sometimes end up responding with a 302 redirect on this endpoint. # In this case, just tell the user that something went wrong and they should # try to authenticate again. - self._saml_handler._render_error( + self._sso_handler.render_error( request, "unexpected_get", "Unexpected GET request on /saml2/authn_response" ) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 34e8ddc62f..d47c13d03f 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import itertools + import logging import threading from collections import namedtuple @@ -1044,7 +1044,8 @@ class EventsWorkerStore(SQLBaseStore): Returns: set[str]: The events we have already seen. """ - results = set() + # if the event cache contains the event, obviously we've seen it. 
+ results = {x for x in event_ids if self._get_event_cache.contains(x)} def have_seen_events_txn(txn, chunk): sql = "SELECT event_id FROM events as e WHERE " @@ -1052,12 +1053,9 @@ class EventsWorkerStore(SQLBaseStore): txn.database_engine, "e.event_id", chunk ) txn.execute(sql + clause, args) - for (event_id,) in txn: - results.add(event_id) + results.update(row[0] for row in txn) - # break the input up into chunks of 100 - input_iterator = iter(event_ids) - for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []): + for chunk in batch_iter((x for x in event_ids if x not in results), 100): await self.db_pool.runInteraction( "have_seen_events", have_seen_events_txn, chunk ) diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index b921d63d30..0309661841 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -350,11 +350,11 @@ class TransactionStore(TransactionWorkerStore): self.db_pool.simple_upsert_many_txn( txn, - "destination_rooms", - ["destination", "room_id"], - rows, - ["stream_ordering"], - [(stream_ordering,)] * len(rows), + table="destination_rooms", + key_names=("destination", "room_id"), + key_values=rows, + value_names=["stream_ordering"], + value_values=[(stream_ordering,)] * len(rows), ) async def get_destination_last_successful_stream_ordering( diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py index 1a3ccb263d..6f96cd7940 100644 --- a/tests/federation/test_federation_catch_up.py +++ b/tests/federation/test_federation_catch_up.py @@ -7,6 +7,7 @@ from synapse.federation.sender import PerDestinationQueue, TransactionManager from synapse.federation.units import Edu from synapse.rest import admin from synapse.rest.client.v1 import login, room +from synapse.util.retryutils import NotRetryingDestination from tests.test_utils import event_injection, make_awaitable from 
tests.unittest import FederatingHomeserverTestCase, override_config @@ -49,7 +50,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): else: data = json_cb() self.failed_pdus.extend(data["pdus"]) - raise IOError("Failed to connect because this is a test!") + raise NotRetryingDestination(0, 24 * 60 * 60 * 1000, txn.destination) def get_destination_room(self, room: str, destination: str = "host2") -> dict: """ diff --git a/tests/replication/_base.py b/tests/replication/_base.py index 20940c8107..67b7913666 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -13,9 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Any, Callable, Dict, List, Optional, Tuple - -import attr +from typing import Any, Callable, Dict, List, Optional, Tuple, Type from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime from twisted.internet.protocol import Protocol @@ -158,10 +156,8 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): # Set up client side protocol client_protocol = client_factory.buildProtocol(None) - request_factory = OneShotRequestFactory() - # Set up the server side protocol - channel = _PushHTTPChannel(self.reactor, request_factory, self.site) + channel = _PushHTTPChannel(self.reactor, SynapseRequest, self.site) # Connect client to server and vice versa. client_to_server_transport = FakeTransport( @@ -183,7 +179,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): server_to_client_transport.loseConnection() client_to_server_transport.loseConnection() - return request_factory.request + return channel.request def assert_request_is_get_repl_stream_updates( self, request: SynapseRequest, stream_name: str @@ -237,7 +233,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): if self.hs.config.redis.redis_enabled: # Handle attempts to connect to fake redis server. 
self.reactor.add_tcp_client_callback( - "localhost", + b"localhost", 6379, self.connect_any_redis_attempts, ) @@ -392,10 +388,8 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): # Set up client side protocol client_protocol = client_factory.buildProtocol(None) - request_factory = OneShotRequestFactory() - # Set up the server side protocol - channel = _PushHTTPChannel(self.reactor, request_factory, self._hs_to_site[hs]) + channel = _PushHTTPChannel(self.reactor, SynapseRequest, self._hs_to_site[hs]) # Connect client to server and vice versa. client_to_server_transport = FakeTransport( @@ -421,7 +415,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): clients = self.reactor.tcpClients while clients: (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) - self.assertEqual(host, "localhost") + self.assertEqual(host, b"localhost") self.assertEqual(port, 6379) client_protocol = client_factory.buildProtocol(None) @@ -453,21 +447,6 @@ class TestReplicationDataHandler(GenericWorkerReplicationHandler): self.received_rdata_rows.append((stream_name, token, r)) -@attr.s() -class OneShotRequestFactory: - """A simple request factory that generates a single `SynapseRequest` and - stores it for future use. Can only be used once. - """ - - request = attr.ib(default=None) - - def __call__(self, *args, **kwargs): - assert self.request is None - - self.request = SynapseRequest(*args, **kwargs) - return self.request - - class _PushHTTPChannel(HTTPChannel): """A HTTPChannel that wraps pull producers to push producers. 
@@ -479,7 +458,7 @@ class _PushHTTPChannel(HTTPChannel): """ def __init__( - self, reactor: IReactorTime, request_factory: Callable[..., Request], site: Site + self, reactor: IReactorTime, request_factory: Type[Request], site: Site ): super().__init__() self.reactor = reactor @@ -510,6 +489,11 @@ class _PushHTTPChannel(HTTPChannel): request.responseHeaders.setRawHeaders(b"connection", [b"close"]) return False + def requestDone(self, request): + # Store the request for inspection. + self.request = request + super().requestDone(request) + class _PullToPushProducer: """A push producer that wraps a pull producer.""" @@ -597,6 +581,8 @@ class FakeRedisPubSubServer: class FakeRedisPubSubProtocol(Protocol): """A connection from a client talking to the fake Redis server.""" + transport = None # type: Optional[FakeTransport] + def __init__(self, server: FakeRedisPubSubServer): self._server = server self._reader = hiredis.Reader() @@ -641,6 +627,8 @@ class FakeRedisPubSubProtocol(Protocol): def send(self, msg): """Send a message back to the client.""" + assert self.transport is not None + raw = self.encode(msg).encode("utf-8") self.transport.write(raw) diff --git a/tests/server.py b/tests/server.py index 863f6da738..2287d20076 100644 --- a/tests/server.py +++ b/tests/server.py @@ -16,6 +16,7 @@ from twisted.internet.interfaces import ( IReactorPluggableNameResolver, IReactorTCP, IResolverSimple, + ITransport, ) from twisted.python.failure import Failure from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock @@ -467,6 +468,7 @@ def get_clock(): return clock, hs_clock +@implementer(ITransport) @attr.s(cmp=False) class FakeTransport: """ diff --git a/tox.ini b/tox.ini index 52168cebe6..bd4739f30f 100644 --- a/tox.ini +++ b/tox.ini @@ -189,7 +189,5 @@ commands= [testenv:mypy] deps = {[base]deps} - # Type hints are broken with Twisted > 20.3.0, see https://github.com/matrix-org/synapse/issues/9513 - twisted==20.3.0 extras = all,mypy commands = mypy