1
0

Compare commits

...

9 Commits

Author SHA1 Message Date
David Robertson
c929b8a073 Changelog 2022-08-21 23:30:39 +01:00
David Robertson
9b6764b2ef A bit more paramspec 2022-08-21 23:28:48 +01:00
David Robertson
dd70d11373 Remove unused config.ldap_enabled 2022-08-21 23:17:08 +01:00
David Robertson
5126d867b1 WIP: annotate setup_test_homeserver 2022-08-21 23:16:51 +01:00
David Robertson
9d4da69ffd Annotate FakeTransport 2022-08-21 23:02:58 +01:00
David Robertson
c9e80bc772 Annotate ThreadPool 2022-08-21 22:35:14 +01:00
David Robertson
48ae00e5bd Annotate ThreadedMemoryReactorClock 2022-08-21 22:27:04 +01:00
David Robertson
db1c5ffce9 annotate getResourceFor 2022-08-21 22:12:24 +01:00
David Robertson
895c09b6e4 annotate writeHeaders 2022-08-21 22:12:11 +01:00
2 changed files with 125 additions and 81 deletions

1
changelog.d/13578.misc Normal file
View File

@@ -0,0 +1 @@
Improve the type annotations in `tests.server`.

View File

@@ -22,20 +22,24 @@ import warnings
from collections import deque from collections import deque
from io import SEEK_END, BytesIO from io import SEEK_END, BytesIO
from typing import ( from typing import (
Any,
Callable, Callable,
Dict, Dict,
Iterable, Iterable,
List, List,
MutableMapping, MutableMapping,
Optional, Optional,
Sequence,
Tuple, Tuple,
Type, Type,
TypeVar,
Union, Union,
cast,
) )
from unittest.mock import Mock from unittest.mock import Mock
import attr import attr
from typing_extensions import Deque from typing_extensions import Deque, ParamSpec
from zope.interface import implementer from zope.interface import implementer
from twisted.internet import address, threads, udp from twisted.internet import address, threads, udp
@@ -44,23 +48,28 @@ from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed
from twisted.internet.error import DNSLookupError from twisted.internet.error import DNSLookupError
from twisted.internet.interfaces import ( from twisted.internet.interfaces import (
IAddress, IAddress,
IConnector,
IConsumer, IConsumer,
IHostnameResolver, IHostnameResolver,
IProtocol, IProtocol,
IPullProducer, IPullProducer,
IPushProducer, IPushProducer,
IReactorFromThreads,
IReactorPluggableNameResolver, IReactorPluggableNameResolver,
IReactorTime, IReactorTime,
IResolverSimple, IResolverSimple,
ITransport, ITransport,
) )
from twisted.internet.protocol import ClientFactory, DatagramProtocol
from twisted.python.failure import Failure from twisted.python.failure import Failure
from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
from twisted.web.http_headers import Headers from twisted.web.http_headers import Headers
from twisted.web.iweb import IRequest
from twisted.web.resource import IResource from twisted.web.resource import IResource
from twisted.web.server import Request, Site from twisted.web.server import Request, Site
from synapse.config.database import DatabaseConnectionConfig from synapse.config.database import DatabaseConnectionConfig
from synapse.config.homeserver import HomeServerConfig
from synapse.events.presence_router import load_legacy_presence_router from synapse.events.presence_router import load_legacy_presence_router
from synapse.events.spamcheck import load_legacy_spam_checkers from synapse.events.spamcheck import load_legacy_spam_checkers
from synapse.events.third_party_rules import load_legacy_third_party_event_rules from synapse.events.third_party_rules import load_legacy_third_party_event_rules
@@ -70,7 +79,7 @@ from synapse.logging.context import ContextResourceUsage
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage import DataStore from synapse.storage import DataStore
from synapse.storage.engines import PostgresEngine, create_engine from synapse.storage.engines import PostgresEngine, create_engine
from synapse.types import JsonDict from synapse.types import ISynapseReactor, JsonDict
from synapse.util import Clock from synapse.util import Clock
from tests.utils import ( from tests.utils import (
@@ -90,6 +99,8 @@ logger = logging.getLogger(__name__)
# the type of thing that can be passed into `make_request` in the headers list # the type of thing that can be passed into `make_request` in the headers list
CustomHeaderType = Tuple[Union[str, bytes], Union[str, bytes]] CustomHeaderType = Tuple[Union[str, bytes], Union[str, bytes]]
P = ParamSpec("P")
R = TypeVar("R")
class TimedOutException(Exception): class TimedOutException(Exception):
@@ -165,7 +176,9 @@ class FakeChannel:
h.addRawHeader(*i) h.addRawHeader(*i)
return h return h
def writeHeaders(self, version, code, reason, headers): def writeHeaders(
self, version: bytes, code: bytes, reason: bytes, headers: Headers
) -> None:
self.result["version"] = version self.result["version"] = version
self.result["code"] = code self.result["code"] = code
self.result["reason"] = reason self.result["reason"] = reason
@@ -275,7 +288,7 @@ class FakeSite:
self._resource = resource self._resource = resource
self.reactor = reactor self.reactor = reactor
def getResourceFor(self, request): def getResourceFor(self, request: IRequest) -> IResource:
return self._resource return self._resource
@@ -389,17 +402,17 @@ def make_request(
return channel return channel
@implementer(IReactorPluggableNameResolver) @implementer(IReactorPluggableNameResolver, IReactorFromThreads)
class ThreadedMemoryReactorClock(MemoryReactorClock): class ThreadedMemoryReactorClock(MemoryReactorClock):
""" """
A MemoryReactorClock that supports callFromThread. A MemoryReactorClock that supports callFromThread.
""" """
def __init__(self): def __init__(self) -> None:
self.threadpool = ThreadPool(self) self.threadpool = ThreadPool(self)
self._tcp_callbacks: Dict[Tuple[str, int], Callable] = {} self._tcp_callbacks: Dict[Tuple[str, int], Callable[[], None]] = {}
self._udp = [] self._udp: List[udp.Port] = []
self.lookups: Dict[str, str] = {} self.lookups: Dict[str, str] = {}
self._thread_callbacks: Deque[Callable[[], None]] = deque() self._thread_callbacks: Deque[Callable[[], None]] = deque()
@@ -407,7 +420,9 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
@implementer(IResolverSimple) @implementer(IResolverSimple)
class FakeResolver: class FakeResolver:
def getHostByName(self, name, timeout=None): def getHostByName(
self, name: str, timeout: Sequence[int] = ()
) -> "Deferred[str]":
if name not in lookups: if name not in lookups:
return fail(DNSLookupError("OH NO: unknown %s" % (name,))) return fail(DNSLookupError("OH NO: unknown %s" % (name,)))
return succeed(lookups[name]) return succeed(lookups[name])
@@ -418,13 +433,22 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
def installNameResolver(self, resolver: IHostnameResolver) -> IHostnameResolver: def installNameResolver(self, resolver: IHostnameResolver) -> IHostnameResolver:
raise NotImplementedError() raise NotImplementedError()
def listenUDP(self, port, protocol, interface="", maxPacketSize=8196): def listenUDP(
self,
port: int,
protocol: DatagramProtocol,
interface: str = "",
maxPacketSize: int = 8196,
) -> udp.Port:
p = udp.Port(port, protocol, interface, maxPacketSize, self) p = udp.Port(port, protocol, interface, maxPacketSize, self)
p.startListening() p.startListening()
self._udp.append(p) self._udp.append(p)
return p return p
def callFromThread(self, callback, *args, **kwargs): # Type-ignore: IReactorFromThreads doesn't use paramspec here.
def callFromThread( # type: ignore[override]
self, callback: Callable[P, Any], *args: P.args, **kwargs: P.kwargs
) -> None:
""" """
Make the callback fire in the next reactor iteration. Make the callback fire in the next reactor iteration.
""" """
@@ -433,10 +457,12 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
# separate queue. # separate queue.
self._thread_callbacks.append(cb) self._thread_callbacks.append(cb)
def getThreadPool(self): def getThreadPool(self) -> "ThreadPool":
return self.threadpool return self.threadpool
def add_tcp_client_callback(self, host: str, port: int, callback: Callable): def add_tcp_client_callback(
self, host: str, port: int, callback: Callable[[], None]
) -> None:
"""Add a callback that will be invoked when we receive a connection """Add a callback that will be invoked when we receive a connection
attempt to the given IP/port using `connectTCP`. attempt to the given IP/port using `connectTCP`.
@@ -445,7 +471,14 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
""" """
self._tcp_callbacks[(host, port)] = callback self._tcp_callbacks[(host, port)] = callback
def connectTCP(self, host: str, port: int, factory, timeout=30, bindAddress=None): def connectTCP(
self,
host: str,
port: int,
factory: ClientFactory,
timeout: float = 30,
bindAddress: Optional[Tuple[str, int]] = None,
) -> IConnector:
"""Fake L{IReactorTCP.connectTCP}.""" """Fake L{IReactorTCP.connectTCP}."""
conn = super().connectTCP( conn = super().connectTCP(
@@ -458,7 +491,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
return conn return conn
def advance(self, amount): def advance(self, amount: float) -> None:
# first advance our reactor's time, and run any "callLater" callbacks that # first advance our reactor's time, and run any "callLater" callbacks that
# makes ready # makes ready
super().advance(amount) super().advance(amount)
@@ -485,26 +518,32 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
class ThreadPool: class ThreadPool:
""" """
Threadless thread pool. Threadless thread pool. A stand-in for twisted.python.threadpool.ThreadPool.
""" """
def __init__(self, reactor): def __init__(self, reactor: IReactorTime):
self._reactor = reactor self._reactor = reactor
def start(self): def start(self) -> None:
pass pass
def stop(self): def stop(self) -> None:
pass pass
def callInThreadWithCallback(self, onResult, function, *args, **kwargs): def callInThreadWithCallback(
def _(res): self,
onResult: Callable[[bool, object], Any],
function: Callable[P, Any],
*args: P.args,
**kwargs: P.kwargs,
) -> "Deferred[bool]":
def _(res: object) -> None:
if isinstance(res, Failure): if isinstance(res, Failure):
onResult(False, res) onResult(False, res)
else: else:
onResult(True, res) onResult(True, res)
d = Deferred() d: "Deferred[bool]" = Deferred()
d.addCallback(lambda x: function(*args, **kwargs)) d.addCallback(lambda x: function(*args, **kwargs))
d.addBoth(_) d.addBoth(_)
self._reactor.callLater(0, d.callback, True) self._reactor.callLater(0, d.callback, True)
@@ -521,7 +560,7 @@ def _make_test_homeserver_synchronous(server: HomeServer) -> None:
for database in server.get_datastores().databases: for database in server.get_datastores().databases:
pool = database._db_pool pool = database._db_pool
def runWithConnection(func, *args, **kwargs): def runWithConnection(func: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> "Deferred[R]":
return threads.deferToThreadPool( return threads.deferToThreadPool(
pool._reactor, pool._reactor,
pool.threadpool, pool.threadpool,
@@ -531,7 +570,7 @@ def _make_test_homeserver_synchronous(server: HomeServer) -> None:
**kwargs, **kwargs,
) )
def runInteraction(interaction, *args, **kwargs): def runInteraction(interaction: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> "Deferred[R]":
return threads.deferToThreadPool( return threads.deferToThreadPool(
pool._reactor, pool._reactor,
pool.threadpool, pool.threadpool,
@@ -559,7 +598,7 @@ def get_clock() -> Tuple[ThreadedMemoryReactorClock, Clock]:
@implementer(ITransport) @implementer(ITransport)
@attr.s(cmp=False) @attr.s(cmp=False, auto_attribs=True)
class FakeTransport: class FakeTransport:
""" """
A twisted.internet.interfaces.ITransport implementation which sends all its data A twisted.internet.interfaces.ITransport implementation which sends all its data
@@ -574,35 +613,29 @@ class FakeTransport:
If you want bidirectional communication, you'll need two instances. If you want bidirectional communication, you'll need two instances.
""" """
other = attr.ib() other: IProtocol
"""The Protocol object which will receive any data written to this transport. """The Protocol object which will receive any data written to this transport."""
:type: twisted.internet.interfaces.IProtocol _reactor: IReactorTime
""" """Test reactor """
_reactor = attr.ib() _protocol: Optional[IProtocol] = None
"""Test reactor
:type: twisted.internet.interfaces.IReactorTime
"""
_protocol = attr.ib(default=None)
"""The Protocol which is producing data for this transport. Optional, but if set """The Protocol which is producing data for this transport. Optional, but if set
will get called back for connectionLost() notifications etc. will get called back for connectionLost() notifications etc.
""" """
_peer_address: Optional[IAddress] = attr.ib(default=None) _peer_address: Optional[IAddress] = None
"""The value to be returned by getPeer""" """The value to be returned by getPeer"""
_host_address: Optional[IAddress] = attr.ib(default=None) _host_address: Optional[IAddress] = None
"""The value to be returned by getHost""" """The value to be returned by getHost"""
disconnecting = False disconnecting: bool = False
disconnected = False disconnected: bool = False
connected = True connected: bool = True
buffer = attr.ib(default=b"") buffer: bytes = b""
producer = attr.ib(default=None) producer: Optional[IPushProducer] = None
autoflush = attr.ib(default=True) autoflush: bool = True
def getPeer(self) -> Optional[IAddress]: def getPeer(self) -> Optional[IAddress]:
return self._peer_address return self._peer_address
@@ -610,7 +643,7 @@ class FakeTransport:
def getHost(self) -> Optional[IAddress]: def getHost(self) -> Optional[IAddress]:
return self._host_address return self._host_address
def loseConnection(self, reason=None): def loseConnection(self, reason: Optional[Failure] = None) -> None:
if not self.disconnecting: if not self.disconnecting:
logger.info("FakeTransport: loseConnection(%s)", reason) logger.info("FakeTransport: loseConnection(%s)", reason)
self.disconnecting = True self.disconnecting = True
@@ -626,7 +659,7 @@ class FakeTransport:
self.connected = False self.connected = False
self.disconnected = True self.disconnected = True
def abortConnection(self): def abortConnection(self) -> None:
logger.info("FakeTransport: abortConnection()") logger.info("FakeTransport: abortConnection()")
if not self.disconnecting: if not self.disconnecting:
@@ -636,28 +669,28 @@ class FakeTransport:
self.disconnected = True self.disconnected = True
def pauseProducing(self): def pauseProducing(self) -> None:
if not self.producer: if not self.producer:
return return
self.producer.pauseProducing() self.producer.pauseProducing()
def resumeProducing(self): def resumeProducing(self) -> None:
if not self.producer: if not self.producer:
return return
self.producer.resumeProducing() self.producer.resumeProducing()
def unregisterProducer(self): def unregisterProducer(self) -> None:
if not self.producer: if not self.producer:
return return
self.producer = None self.producer = None
def registerProducer(self, producer, streaming): def registerProducer(self, producer: IPushProducer, streaming: bool) -> None:
self.producer = producer self.producer = producer
self.producerStreaming = streaming self.producerStreaming = streaming
def _produce(): def _produce() -> None:
if not self.producer: if not self.producer:
# we've been unregistered # we've been unregistered
return return
@@ -669,7 +702,7 @@ class FakeTransport:
if not streaming: if not streaming:
self._reactor.callLater(0.0, _produce) self._reactor.callLater(0.0, _produce)
def write(self, byt): def write(self, byt: bytes) -> None:
if self.disconnecting: if self.disconnecting:
raise Exception("Writing to disconnecting FakeTransport") raise Exception("Writing to disconnecting FakeTransport")
@@ -681,11 +714,11 @@ class FakeTransport:
if self.autoflush: if self.autoflush:
self._reactor.callLater(0.0, self.flush) self._reactor.callLater(0.0, self.flush)
def writeSequence(self, seq): def writeSequence(self, seq: Iterable[bytes]) -> None:
for x in seq: for x in seq:
self.write(x) self.write(x)
def flush(self, maxbytes=None): def flush(self, maxbytes: Optional[int] = None) -> None:
if not self.buffer: if not self.buffer:
# nothing to do. Don't write empty buffers: it upsets the # nothing to do. Don't write empty buffers: it upsets the
# TLSMemoryBIOProtocol # TLSMemoryBIOProtocol
@@ -739,14 +772,17 @@ class TestHomeServer(HomeServer):
DATASTORE_CLASS = DataStore DATASTORE_CLASS = DataStore
HS = TypeVar("HS", bound=HomeServer)
def setup_test_homeserver( def setup_test_homeserver(
cleanup_func, cleanup_func: Callable[[Callable[[], None]], Any],
name="test", name: str = "test",
config=None, config: Union[HomeServerConfig, None] = None,
reactor=None, reactor: Optional[ISynapseReactor] = None,
homeserver_to_use: Type[HomeServer] = TestHomeServer, homeserver_to_use: Type[HS] = TestHomeServer,
**kwargs, **kwargs: object,
): ) -> HS:
""" """
Setup a homeserver suitable for running tests against. Keyword arguments Setup a homeserver suitable for running tests against. Keyword arguments
are passed to the Homeserver constructor. are passed to the Homeserver constructor.
@@ -761,13 +797,12 @@ def setup_test_homeserver(
HomeserverTestCase. HomeserverTestCase.
""" """
if reactor is None: if reactor is None:
from twisted.internet import reactor from twisted.internet import reactor # type: ignore[no-redef]
if config is None: if config is None:
config = default_config(name, parse=True) config = default_config(name, parse=True)
config.caches.resize_all_caches() config.caches.resize_all_caches()
config.ldap_enabled = False
if "clock" not in kwargs: if "clock" not in kwargs:
kwargs["clock"] = MockClock() kwargs["clock"] = MockClock()
@@ -810,20 +845,25 @@ def setup_test_homeserver(
if "db_txn_limit" in kwargs: if "db_txn_limit" in kwargs:
database_config["txn_limit"] = kwargs["db_txn_limit"] database_config["txn_limit"] = kwargs["db_txn_limit"]
database = DatabaseConnectionConfig("master", database_config) database_conn_config = DatabaseConnectionConfig("master", database_config)
config.database.databases = [database] config.database.databases = [database_conn_config]
db_engine = create_engine(database.config) db_engine = create_engine(database_conn_config.config)
# Create the database before we actually try and connect to it, based off # Create the database before we actually try and connect to it, based off
# the template database we generate in setupdb() # the template database we generate in setupdb()
if isinstance(db_engine, PostgresEngine): if isinstance(db_engine, PostgresEngine):
db_conn = db_engine.module.connect( import psycopg2
database=POSTGRES_BASE_DB,
user=POSTGRES_USER, db_conn = cast(
host=POSTGRES_HOST, psycopg2.connection,
port=POSTGRES_PORT, db_engine.module.connect(
password=POSTGRES_PASSWORD, database=POSTGRES_BASE_DB,
user=POSTGRES_USER,
host=POSTGRES_HOST,
port=POSTGRES_PORT,
password=POSTGRES_PASSWORD,
),
) )
db_conn.autocommit = True db_conn.autocommit = True
cur = db_conn.cursor() cur = db_conn.cursor()
@@ -856,7 +896,7 @@ def setup_test_homeserver(
database = hs.get_datastores().databases[0] database = hs.get_datastores().databases[0]
# We need to do cleanup on PostgreSQL # We need to do cleanup on PostgreSQL
def cleanup(): def cleanup() -> None:
import psycopg2 import psycopg2
# Close all the db pools # Close all the db pools
@@ -865,12 +905,15 @@ def setup_test_homeserver(
dropped = False dropped = False
# Drop the test database # Drop the test database
db_conn = db_engine.module.connect( db_conn = cast(
database=POSTGRES_BASE_DB, psycopg2.connection,
user=POSTGRES_USER, db_engine.module.connect(
host=POSTGRES_HOST, database=POSTGRES_BASE_DB,
port=POSTGRES_PORT, user=POSTGRES_USER,
password=POSTGRES_PASSWORD, host=POSTGRES_HOST,
port=POSTGRES_PORT,
password=POSTGRES_PASSWORD,
),
) )
db_conn.autocommit = True db_conn.autocommit = True
cur = db_conn.cursor() cur = db_conn.cursor()
@@ -904,12 +947,12 @@ def setup_test_homeserver(
# Need to let the HS build an auth handler and then mess with it # Need to let the HS build an auth handler and then mess with it
# because AuthHandler's constructor requires the HS, so we can't make one # because AuthHandler's constructor requires the HS, so we can't make one
# beforehand and pass it in to the HS's constructor (chicken / egg) # beforehand and pass it in to the HS's constructor (chicken / egg)
async def hash(p): async def hash(p: str) -> str:
return hashlib.md5(p.encode("utf8")).hexdigest() return hashlib.md5(p.encode("utf8")).hexdigest()
hs.get_auth_handler().hash = hash hs.get_auth_handler().hash = hash
async def validate_hash(p, h): async def validate_hash(p: str, h: str) -> bool:
return hashlib.md5(p.encode("utf8")).hexdigest() == h return hashlib.md5(p.encode("utf8")).hexdigest() == h
hs.get_auth_handler().validate_hash = validate_hash hs.get_auth_handler().validate_hash = validate_hash