Compare commits

...

1 Commits

Author SHA1 Message Date
Olivier Wilkinson (reivilibre)
d7a1948a47 Initial crack at defining a worker agent endpoint factory
TODO it's not being configured in the HTTP Client yet
2023-04-05 12:37:38 +01:00
4 changed files with 93 additions and 29 deletions

View File

@@ -18,6 +18,7 @@ import logging
from typing import Any, Dict, List, Union
import attr
from pydantic import BaseModel, StrictBool, StrictInt, StrictStr, parse_obj_as
from synapse.config._base import (
Config,
@@ -50,13 +51,23 @@ def _instance_to_list_converter(obj: Union[str, List[str]]) -> List[str]:
return obj
@attr.s(auto_attribs=True)
class InstanceLocationConfig:
class UnixSocketInstanceLocationConfig(BaseModel):
"""The path to talk to an instance via HTTP replication over Unix socket."""
socket_path: StrictStr
class TcpInstanceLocationConfig(BaseModel):
"""The host and port to talk to an instance via HTTP replication."""
host: str
port: int
tls: bool = False
host: StrictStr
port: StrictInt
tls: StrictBool = False
InstanceLocationConfig = Union[
UnixSocketInstanceLocationConfig, TcpInstanceLocationConfig
]
@attr.s
@@ -182,11 +193,10 @@ class WorkerConfig(Config):
federation_sender_instances
)
# A map from instance name to host/port of their HTTP replication endpoint.
instance_map = config.get("instance_map") or {}
self.instance_map = {
name: InstanceLocationConfig(**c) for name, c in instance_map.items()
}
# A map from instance name to connection details for their HTTP replication endpoint.
self.instance_map = parse_obj_as(
Dict[str, InstanceLocationConfig], config.get("instance_map") or {}
)
# Map from type of streams to source, c.f. WriterLocations.
writers = config.get("stream_writers") or {}

View File

@@ -198,9 +198,6 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
local_instance_name = hs.get_instance_name()
# The value of these option should match the replication listener settings
master_host = hs.config.worker.worker_replication_host
master_port = hs.config.worker.worker_replication_http_port
master_tls = hs.config.worker.worker_replication_http_tls
instance_map = hs.config.worker.instance_map
@@ -221,15 +218,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
with outgoing_gauge.track_inprogress():
if instance_name == local_instance_name:
raise Exception("Trying to send HTTP request to self")
if instance_name == "master":
host = master_host
port = master_port
tls = master_tls
elif instance_name in instance_map:
host = instance_map[instance_name].host
port = instance_map[instance_name].port
tls = instance_map[instance_name].tls
else:
if instance_name not in instance_map:
raise Exception(
"Instance %r not in 'instance_map' config" % (instance_name,)
)
@@ -279,14 +268,10 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
# Here the protocol is hard coded to be http by default or https in case the replication
# port is set to have tls true.
tls = False # TODO
scheme = "https" if tls else "http"
uri = "%s://%s:%s/_synapse/replication/%s/%s" % (
scheme,
host,
port,
cls.NAME,
"/".join(url_args),
)
joined_args = "/".join(url_args)
uri = f"{scheme}://{instance_name}/_synapse/replication/{cls.NAME}/{joined_args}"
headers: Dict[bytes, List[bytes]] = {}
# Add an authorization header, if configured.

View File

@@ -0,0 +1,67 @@
# Copyright 2023 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict
from zope.interface import implementer
from twisted.internet.endpoints import UNIXClientEndpoint
from twisted.internet.interfaces import IStreamClientEndpoint
from twisted.web.client import URI
from twisted.web.iweb import IAgentEndpointFactory
from synapse.config.workers import (
InstanceLocationConfig,
TcpInstanceLocationConfig,
UnixSocketInstanceLocationConfig,
)
from synapse.types import ISynapseReactor
@implementer(IAgentEndpointFactory)
class WorkerEndpointFactory:
def __init__(
self,
reactor: ISynapseReactor,
configs: Dict[str, InstanceLocationConfig],
tcp_endpoint_factory: IAgentEndpointFactory,
):
self.reactor = reactor
self.configs = configs
self.tcp_agent_factory = tcp_endpoint_factory
def endpointForURI(self, uri: URI) -> IStreamClientEndpoint:
worker_config = self.configs.get(uri.host)
if not worker_config:
raise ValueError(f"Don't know how to connect to worker: {uri.host}")
if isinstance(worker_config, TcpInstanceLocationConfig):
# TODO TLS support
rewritten_uri = URI(
scheme=uri.scheme,
# TODO I'd probably cache the encoded netloc and host in the TCP Config?
netloc=f"{worker_config.host}:{worker_config.port}".encode("utf-8"),
host=worker_config.host.encode("utf-8"),
port=worker_config.port,
path=uri.path,
params=uri.params,
query=uri.query,
fragment=uri.fragment,
)
return self.tcp_agent_factory.endpointForURI(rewritten_uri)
elif isinstance(worker_config, UnixSocketInstanceLocationConfig):
return UNIXClientEndpoint(self.reactor, worker_config.socket_path)
else:
raise ValueError(
f"Unknown worker connection config {worker_config} for {uri.host}"
)

View File

@@ -22,6 +22,7 @@ from twisted.test.proto_helpers import MemoryReactor
from twisted.web.resource import Resource
from synapse.app.generic_worker import GenericWorkerServer
from synapse.config.workers import TcpInstanceLocationConfig
from synapse.http.site import SynapseRequest, SynapseSite
from synapse.replication.http import ReplicationRestResource
from synapse.replication.tcp.client import ReplicationDataHandler
@@ -339,6 +340,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
# `_handle_http_replication_attempt` like we do with the master HS.
instance_name = worker_hs.get_instance_name()
instance_loc = worker_hs.config.worker.instance_map.get(instance_name)
assert isinstance(instance_loc, TcpInstanceLocationConfig)
if instance_loc:
# Ensure the host is one that has a fake DNS entry.
if instance_loc.host not in self.reactor.lookups: