Compare commits

...

9 Commits

Author          SHA1        Message                                                                 Date
Patrick Cloke   e0b60a9b4e  temp                                                                    2022-03-11 11:26:03 -05:00
Patrick Cloke   b32bb82bee  temp                                                                    2022-03-11 10:35:23 -05:00
Patrick Cloke   829139c3d5  Attempt to re-connect better.                                           2022-03-11 10:35:22 -05:00
Patrick Cloke   7375bd4828  More robust-ness against dying connections.                             2022-03-11 10:34:59 -05:00
Erik Johnston   fd491969a6  Attempt some progress                                                   2022-03-11 10:34:58 -05:00
Patrick Cloke   0f3798dac7  Rip out TCP replication bits for tests and hook up Redis replication.  2022-03-11 10:33:58 -05:00
Patrick Cloke   9e1dfc68fd  Use redis for all replication tests.                                    2022-03-11 10:33:58 -05:00
Patrick Cloke   b7d7a1b0a8  Respond to Redis PING messages.                                         2022-03-11 10:33:58 -05:00
Patrick Cloke   e545948eef  Use the reactor from the HomeServer.                                    2022-03-11 10:33:58 -05:00
7 changed files with 192 additions and 83 deletions

View File

@@ -118,13 +118,21 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):
        # have successfully subscribed to the stream - otherwise we might miss the
        # POSITION response sent back by the other end.
        logger.info("Sending redis SUBSCRIBE for %s", self.synapse_stream_name)
        await make_deferred_yieldable(self.subscribe(self.synapse_stream_name))
        try:
            await make_deferred_yieldable(self.subscribe(self.synapse_stream_name))
        except txredisapi.ConnectionError:
            # The connection died, the factory will attempt to reconnect.
            return
        logger.info(
            "Successfully subscribed to redis stream, sending REPLICATE command"
        )
        # If the connection has been severed for some reason, bail.
        if not self.connected:
            return
        self.synapse_handler.new_connection(self)
        await self._async_send_command(ReplicateCommand())
        logger.info("REPLICATE successfully sent")

        # We send out our positions when there is a new connection in case the
        # other side missed updates. We do this for Redis connections as the
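
Below is a small standalone sketch (not part of the diff) of the pattern this hunk introduces: subscribe first, announce readiness only once the SUBSCRIBE has completed, and bail out quietly if the connection has died so that the reconnecting factory can build a fresh protocol. `subscribe_then_announce` and `announce_ready` are hypothetical names; only the txredisapi calls mirror the change above.

```python
from typing import Callable

import txredisapi


async def subscribe_then_announce(
    protocol: txredisapi.SubscriberProtocol,
    stream_name: str,
    announce_ready: Callable[[], None],
) -> None:
    """Subscribe to `stream_name` before signalling that commands may be sent."""
    try:
        # subscribe() returns a Deferred, which a native coroutine can await.
        await protocol.subscribe(stream_name)
    except txredisapi.ConnectionError:
        # The connection died mid-subscribe; give up on this protocol instance
        # and let the (reconnecting) factory try again.
        return

    # The connection may also have been severed without an exception.
    if not protocol.connected:
        return

    # Only now is it safe to send commands such as REPLICATE.
    announce_ready()
```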
@@ -255,7 +263,15 @@ class SynapseRedisFactory(txredisapi.RedisFactory):
            replyTimeout=replyTimeout,
            convertNumbers=convertNumbers,
        )

        self.hs = hs
        # Set the homeserver reactor as the clock; if this is not done then
        # twisted.internet.protocol.ReconnectingClientFactory.retry will default
        # to the global reactor.
        self.clock = hs.get_reactor()

        # Send pings every 30 seconds (note that get_clock() returns a Clock, not
        # a reactor).
        hs.get_clock().looping_call(self._send_ping, 30 * 1000)

    @wrap_as_background_process("redis_ping")
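
The `self.clock = hs.get_reactor()` assignment above matters because Twisted's `ReconnectingClientFactory` schedules its retries on its `clock` attribute and falls back to the global reactor when that attribute is unset, which does not work for tests that drive a fake, manually advanced reactor. A minimal sketch (not Synapse code) of the same override against a controllable clock:

```python
from twisted.internet.protocol import Protocol, ReconnectingClientFactory
from twisted.internet.task import Clock


class SketchFactory(ReconnectingClientFactory):
    maxDelay = 5  # cap the exponential backoff at 5 seconds

    def buildProtocol(self, addr):
        # A successful connection resets the backoff to its initial delay.
        self.resetDelay()
        return Protocol()


factory = SketchFactory()
factory.clock = Clock()  # retries are now scheduled on a clock the test controls
```

The factory's current backoff interval is exposed as `delay`, which is what the test code further down reads when it waits for reconnects (`timeouts.append(hs_factory.delay)`).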
@@ -353,6 +369,7 @@ def lazyConnection(
    reconnect: bool = True,
    password: Optional[str] = None,
    replyTimeout: int = 30,
    handler: Optional[txredisapi.ConnectionHandler] = None,
) -> txredisapi.ConnectionHandler:
    """Creates a connection to Redis that is lazily set up and reconnects if the
    connection is lost.

View File

@@ -549,7 +549,6 @@ class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase):
    def default_config(self):
        conf = super().default_config()
        conf["redis"] = {"enabled": "true"}
        conf["stream_writers"] = {"presence": ["presence_writer"]}
        conf["instance_map"] = {
            "presence_writer": {"host": "testserv", "port": 1001},

View File

@@ -14,20 +14,14 @@
import logging
from typing import Any, Dict, List, Optional, Tuple
from twisted.internet.address import IPv4Address
from twisted.internet.protocol import Protocol
from twisted.python.failure import Failure
from twisted.web.resource import Resource
from synapse.app.generic_worker import GenericWorkerServer
from synapse.http.site import SynapseRequest, SynapseSite
from synapse.replication.http import ReplicationRestResource
from synapse.replication.tcp.client import ReplicationDataHandler
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
from synapse.replication.tcp.resource import (
    ReplicationStreamProtocolFactory,
    ServerReplicationStreamProtocol,
)
from synapse.server import HomeServer
from tests import unittest
@@ -41,6 +35,55 @@ except ImportError:
logger = logging.getLogger(__name__)
class FakeOutboundConnector:
    """
    A fake IConnector which, when asked to reconnect, restarts the homeserver's
    outbound Redis connection.
    """

    def __init__(self, hs: HomeServer):
        self._hs = hs

    def stopConnecting(self):
        pass

    def connect(self):
        # Restart replication.
        from synapse.replication.tcp.redis import lazyConnection

        handler = self._hs.get_outbound_redis_connection()
        reactor = self._hs.get_reactor()
        reactor.connectTCP(
            self._hs.config.redis.redis_host,
            self._hs.config.redis.redis_port,
            handler._factory,
            timeout=30,
            bindAddress=None,
        )

    def getDestination(self):
        return "blah"
class FakeReplicationHandlerConnector:
    """
    A fake IConnector which, when asked to reconnect, restarts replication via the
    homeserver's replication command handler.
    """

    def __init__(self, hs: HomeServer):
        self._hs = hs

    def stopConnecting(self):
        pass

    def connect(self):
        # Restart replication.
        self._hs.get_replication_command_handler().start_replication(self._hs)

    def getDestination(self):
        return "blah"
class BaseStreamTestCase(unittest.HomeserverTestCase):
    """Base class for tests of the replication streams"""
@@ -49,16 +92,33 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
    if not hiredis:
        skip = "Requires hiredis"

    def default_config(self):
        config = super().default_config()
        config["redis"] = {"enabled": True}
        return config

    def prepare(self, reactor, clock, hs):
        # build a replication server
        server_factory = ReplicationStreamProtocolFactory(hs)
        self.streamer = hs.get_replication_streamer()
        self.server: ServerReplicationStreamProtocol = server_factory.buildProtocol(
            IPv4Address("TCP", "127.0.0.1", 0)
        )

        # A fake in-memory Redis server that the homeservers can connect to.
        self._redis_transports = []
        self._redis_server = FakeRedisPubSubServer()

        # We may have an attempt to connect to redis for the external cache already.
        self.connect_any_redis_attempts()

        # Make a new HomeServer object for the worker
        self.reactor.lookups["testserv"] = "1.2.3.4"
        self.reactor.lookups["localhost"] = "127.0.0.1"

        # Handle attempts to connect to the fake redis server.
        self.reactor.add_tcp_client_callback(
            "localhost",
            6379,
            self.connect_any_redis_attempts,
        )

        self.worker_hs = self.setup_test_homeserver(
            federation_http_client=None,
            homeserver_to_use=GenericWorkerServer,
@@ -81,18 +141,11 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
        self.test_handler = self._build_replication_data_handler()
        self.worker_hs._replication_data_handler = self.test_handler  # type: ignore[attr-defined]

        repl_handler = ReplicationCommandHandler(self.worker_hs)
        self.client = ClientReplicationStreamProtocol(
            self.worker_hs,
            "client",
            "test",
            clock,
            repl_handler,
        self.hs.get_replication_command_handler().start_replication(self.hs)
        self.worker_hs.get_replication_command_handler().start_replication(
            self.worker_hs
        )

        self._client_transport = None
        self._server_transport = None

    def create_resource_dict(self) -> Dict[str, Resource]:
        d = super().create_resource_dict()
        d["/_synapse/replication"] = ReplicationRestResource(self.hs)
@@ -109,26 +162,46 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
        return TestReplicationDataHandler(self.worker_hs)

    def reconnect(self):
        if self._client_transport:
            self.client.close()
        self.disconnect()
        print("RECONNECTING")

        if self._server_transport:
            self.server.close()
        # Make a `FakeConnector` to emulate the behavior of `connectTCP`. That
        # creates an `IConnector`, which is responsible for calling the factory's
        # `clientConnectionLost`. The reconnecting factory then calls
        # `IConnector.connect` to attempt a reconnection. The transport is meant
        # to call `connectionLost` on the `IConnector`.
        #
        # Most of that is bypassed by directly calling `retry` on the factory,
        # which schedules a `connect()` call on the connector.
        timeouts = []
        for hs in (self.hs, self.worker_hs):
            hs_factory_outbound = hs.get_outbound_redis_connection()._factory
            hs_factory_outbound.clientConnectionLost(
                FakeOutboundConnector(hs), Failure(RuntimeError(""))
            )
            timeouts.append(hs_factory_outbound.delay)

        self._client_transport = FakeTransport(self.server, self.reactor)
        self.client.makeConnection(self._client_transport)
            hs_factory = hs.get_replication_command_handler()._factory
            hs_factory.clientConnectionLost(
                FakeReplicationHandlerConnector(hs),
                Failure(RuntimeError("")),
            )
            timeouts.append(hs_factory.delay)

        self._server_transport = FakeTransport(self.client, self.reactor)
        self.server.makeConnection(self._server_transport)
        # Wait for the reconnects to happen.
        self.pump(max(timeouts) + 1)

        self.connect_any_redis_attempts()
    def disconnect(self):
        if self._client_transport:
            self._client_transport = None
            self.client.close()

        if self._server_transport:
            self._server_transport = None
            self.server.close()
        print("DISCONNECTING")
        for (
            client_to_server_transport,
            server_to_client_transport,
        ) in self._redis_transports:
            client_to_server_transport.abortConnection()
            server_to_client_transport.abortConnection()
        self._redis_transports = []

    def replicate(self):
        """Tell the master side of replication that something has happened, and then
@@ -212,6 +285,40 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
        self.assertEqual(request.method, b"GET")

    def connect_any_redis_attempts(self):
        """If redis is enabled, the workers will try to connect to a Redis
        server. We don't want to use a real Redis server in the tests, so
        connect them to the fake one instead.
        """
        clients = self.reactor.tcpClients
        while clients:
            (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
            self.assertEqual(host, "localhost")
            self.assertEqual(port, 6379)

            client_protocol = client_factory.buildProtocol(None)
            server_protocol = self._redis_server.buildProtocol(None)

            if client_protocol.__class__.__name__ == "RedisSubscriber":
                print(client_protocol, client_protocol.synapse_handler._presence_handler.hs, client_protocol.synapse_outbound_redis_connection)
            else:
                print(client_protocol, client_protocol.factory.hs)
            print()

            client_to_server_transport = FakeTransport(
                server_protocol, self.reactor, client_protocol
            )
            client_protocol.makeConnection(client_to_server_transport)

            server_to_client_transport = FakeTransport(
                client_protocol, self.reactor, server_protocol
            )
            server_protocol.makeConnection(server_to_client_transport)

            # Store the transports so they can be disconnected later.
            self._redis_transports.append(
                (client_to_server_transport, server_to_client_transport)
            )
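
The `FakeRedisPubSubServer` used above essentially has to behave like a tiny in-memory pub/sub hub: remember which connected protocol is subscribed to which channel and fan published messages back out. A sketch of that shape (assumed structure only, not Synapse's actual implementation; `deliver` is a hypothetical hook on the protocol):

```python
from collections import defaultdict
from typing import Dict, Set


class SketchPubSubHub:
    """A minimal in-memory stand-in for Redis pub/sub."""

    def __init__(self) -> None:
        self._subscribers: Dict[str, Set] = defaultdict(set)

    def subscribe(self, channel: str, protocol) -> None:
        self._subscribers[channel].add(protocol)

    def unsubscribe_all(self, protocol) -> None:
        for subscribers in self._subscribers.values():
            subscribers.discard(protocol)

    def publish(self, channel: str, message: str) -> int:
        # Redis' PUBLISH reply is the number of clients that received the message.
        receivers = self._subscribers.get(channel, set())
        for protocol in receivers:
            protocol.deliver(channel, message)  # hypothetical hook
        return len(receivers)
```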
class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
    """Base class for tests running multiple workers.
@@ -220,11 +327,14 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
    unlike `BaseStreamTestCase`.
    """

    def default_config(self):
        config = super().default_config()
        config["redis"] = {"enabled": True}
        return config

    def setUp(self):
        super().setUp()

        # build a replication server
        self.server_factory = ReplicationStreamProtocolFactory(self.hs)
        self.streamer = self.hs.get_replication_streamer()

        # A fake in-memory Redis server that the homeservers can connect to.
@@ -243,15 +353,14 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
        # handling inbound HTTP requests to that instance.
        self._hs_to_site = {self.hs: self.site}

        if self.hs.config.redis.redis_enabled:
            # Handle attempts to connect to the fake redis server.
            self.reactor.add_tcp_client_callback(
                "localhost",
                6379,
                self.connect_any_redis_attempts,
            )
        # Handle attempts to connect to the fake redis server.
        self.reactor.add_tcp_client_callback(
            "localhost",
            6379,
            self.connect_any_redis_attempts,
        )

            self.hs.get_replication_command_handler().start_replication(self.hs)
        self.hs.get_replication_command_handler().start_replication(self.hs)

        # When we see a connection attempt to the master replication listener we
        # automatically set up the connection. This is so that tests don't
@@ -335,27 +444,6 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
        store = worker_hs.get_datastores().main
        store.db_pool._db_pool = self.database_pool._db_pool

        # Set up TCP replication between master and the new worker if we don't
        # have Redis support enabled.
        if not worker_hs.config.redis.redis_enabled:
            repl_handler = ReplicationCommandHandler(worker_hs)
            client = ClientReplicationStreamProtocol(
                worker_hs,
                "client",
                "test",
                self.clock,
                repl_handler,
            )
            server = self.server_factory.buildProtocol(
                IPv4Address("TCP", "127.0.0.1", 0)
            )

            client_transport = FakeTransport(server, self.reactor)
            client.makeConnection(client_transport)

            server_transport = FakeTransport(client, self.reactor)
            server.makeConnection(server_transport)

        # Set up a resource for the worker
        resource = ReplicationRestResource(worker_hs)
@@ -374,8 +462,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
            reactor=self.reactor,
        )

        if worker_hs.config.redis.redis_enabled:
            worker_hs.get_replication_command_handler().start_replication(worker_hs)
        worker_hs.get_replication_command_handler().start_replication(worker_hs)

        return worker_hs
@@ -424,7 +511,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
        # Note: at this point we've wired everything up, but we need to return
        # before the data starts flowing over the connections as this is called
        # inside `connecTCP` before the connection has been passed back to the
        # inside `connectTCP` before the connection has been passed back to the
        # code that requested the TCP connection.

    def connect_any_redis_attempts(self):
@@ -536,8 +623,13 @@ class FakeRedisPubSubProtocol(Protocol):
            self.send("OK")
        elif command == b"GET":
            self.send(None)
        # Connection keep-alives.
        elif command == b"PING":
            self.send("PONG")
        else:
            raise Exception("Unknown command")
            raise Exception(f"Unknown command: {command}")

    def send(self, msg):
        """Send a message back to the client."""

View File

@@ -250,10 +250,14 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
        self.replicate()
        self.check("get_rooms_for_user_with_stream_ordering", (USER_ID_2,), set())

        # limit the replication rate
        repl_transport = self._server_transport
        assert isinstance(repl_transport, FakeTransport)
        repl_transport.autoflush = False
        # limit the replication rate from server -> client.
        print(len(self._redis_transports))
        for x in self._redis_transports:
            print(f"\t{x}")
        assert len(self._redis_transports) == 1
        for _, repl_transport in self._redis_transports:
            assert isinstance(repl_transport, FakeTransport)
            repl_transport.autoflush = False

        # build the join and message events and persist them in the same batch.
        logger.info("----- build test events ------")

View File

@@ -28,7 +28,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):
        return Mock(wraps=super()._build_replication_data_handler())

    def test_receipt(self):
        self.reconnect()
        # self.reconnect()

        # tell the master to send a new receipt
        self.get_success(

View File

@@ -27,10 +27,8 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
    servlets = [register.register_servlets]

    def _get_worker_hs_config(self) -> dict:
        config = self.default_config()
        config = super()._get_worker_hs_config()
        config["worker_app"] = "synapse.app.client_reader"
        config["worker_replication_host"] = "testserv"
        config["worker_replication_http_port"] = "8765"
        return config

    def test_register_single_worker(self):

View File

@@ -51,7 +51,6 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
    def default_config(self):
        conf = super().default_config()
        conf["redis"] = {"enabled": "true"}
        conf["stream_writers"] = {"events": ["worker1", "worker2"]}
        conf["instance_map"] = {
            "worker1": {"host": "testserv", "port": 1001},