Compare commits

9 commits: patch-1 ... dmr/typing

| Author | SHA1 | Date |
|---|---|---|
| | c929b8a073 | |
| | 9b6764b2ef | |
| | dd70d11373 | |
| | 5126d867b1 | |
| | 9d4da69ffd | |
| | c9e80bc772 | |
| | 48ae00e5bd | |
| | db1c5ffce9 | |
| | 895c09b6e4 | |
changelog.d/13578.misc (new file, 1 line)

@@ -0,0 +1 @@
+Improve the type annotations in `tests.server`.

tests/server.py (205 changed lines)
@@ -22,20 +22,24 @@ import warnings
 from collections import deque
 from io import SEEK_END, BytesIO
 from typing import (
+    Any,
     Callable,
     Dict,
     Iterable,
     List,
     MutableMapping,
     Optional,
+    Sequence,
     Tuple,
     Type,
+    TypeVar,
     Union,
+    cast,
 )
 from unittest.mock import Mock
 
 import attr
-from typing_extensions import Deque
+from typing_extensions import Deque, ParamSpec
 from zope.interface import implementer
 
 from twisted.internet import address, threads, udp
@@ -44,23 +48,28 @@ from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed
 from twisted.internet.error import DNSLookupError
 from twisted.internet.interfaces import (
     IAddress,
+    IConnector,
     IConsumer,
     IHostnameResolver,
     IProtocol,
     IPullProducer,
     IPushProducer,
+    IReactorFromThreads,
     IReactorPluggableNameResolver,
     IReactorTime,
     IResolverSimple,
     ITransport,
 )
+from twisted.internet.protocol import ClientFactory, DatagramProtocol
 from twisted.python.failure import Failure
 from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
 from twisted.web.http_headers import Headers
+from twisted.web.iweb import IRequest
 from twisted.web.resource import IResource
 from twisted.web.server import Request, Site
 
 from synapse.config.database import DatabaseConnectionConfig
+from synapse.config.homeserver import HomeServerConfig
 from synapse.events.presence_router import load_legacy_presence_router
 from synapse.events.spamcheck import load_legacy_spam_checkers
 from synapse.events.third_party_rules import load_legacy_third_party_event_rules
@@ -70,7 +79,7 @@ from synapse.logging.context import ContextResourceUsage
 from synapse.server import HomeServer
 from synapse.storage import DataStore
 from synapse.storage.engines import PostgresEngine, create_engine
-from synapse.types import JsonDict
+from synapse.types import ISynapseReactor, JsonDict
 from synapse.util import Clock
 
 from tests.utils import (
@@ -90,6 +99,8 @@ logger = logging.getLogger(__name__)
 
 # the type of thing that can be passed into `make_request` in the headers list
 CustomHeaderType = Tuple[Union[str, bytes], Union[str, bytes]]
+P = ParamSpec("P")
+R = TypeVar("R")
 
 
 class TimedOutException(Exception):
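
The `P = ParamSpec("P")` and `R = TypeVar("R")` aliases added above are what let the wrappers later in this diff (`callFromThread`, `callInThreadWithCallback`, `runWithConnection`, `runInteraction`) forward `*args`/`**kwargs` without erasing the wrapped callable's signature. A minimal standalone sketch of the pattern, using illustrative names that are not part of Synapse:

```python
from typing import Callable, TypeVar
from typing_extensions import ParamSpec

P = ParamSpec("P")
R = TypeVar("R")


def log_calls(func: Callable[P, R]) -> Callable[P, R]:
    """Wrap `func` while keeping its exact parameter and return types."""

    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        print(f"calling {func.__name__}")
        return func(*args, **kwargs)

    return wrapper


@log_calls
def add(x: int, y: int = 1) -> int:
    return x + y


total = add(2, y=3)  # inferred as int; add("2") would be rejected by mypy
```
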
@@ -165,7 +176,9 @@ class FakeChannel:
             h.addRawHeader(*i)
         return h
 
-    def writeHeaders(self, version, code, reason, headers):
+    def writeHeaders(
+        self, version: bytes, code: bytes, reason: bytes, headers: Headers
+    ) -> None:
         self.result["version"] = version
         self.result["code"] = code
         self.result["reason"] = reason
@@ -275,7 +288,7 @@ class FakeSite:
         self._resource = resource
         self.reactor = reactor
 
-    def getResourceFor(self, request):
+    def getResourceFor(self, request: IRequest) -> IResource:
         return self._resource
 
 
@@ -389,17 +402,17 @@ def make_request(
     return channel
 
 
-@implementer(IReactorPluggableNameResolver)
+@implementer(IReactorPluggableNameResolver, IReactorFromThreads)
 class ThreadedMemoryReactorClock(MemoryReactorClock):
     """
     A MemoryReactorClock that supports callFromThread.
     """
 
-    def __init__(self):
+    def __init__(self) -> None:
         self.threadpool = ThreadPool(self)
 
-        self._tcp_callbacks: Dict[Tuple[str, int], Callable] = {}
-        self._udp = []
+        self._tcp_callbacks: Dict[Tuple[str, int], Callable[[], None]] = {}
+        self._udp: List[udp.Port] = []
         self.lookups: Dict[str, str] = {}
         self._thread_callbacks: Deque[Callable[[], None]] = deque()
 
@@ -407,7 +420,9 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
 
         @implementer(IResolverSimple)
         class FakeResolver:
-            def getHostByName(self, name, timeout=None):
+            def getHostByName(
+                self, name: str, timeout: Sequence[int] = ()
+            ) -> "Deferred[str]":
                 if name not in lookups:
                     return fail(DNSLookupError("OH NO: unknown %s" % (name,)))
                 return succeed(lookups[name])
@@ -418,13 +433,22 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
     def installNameResolver(self, resolver: IHostnameResolver) -> IHostnameResolver:
         raise NotImplementedError()
 
-    def listenUDP(self, port, protocol, interface="", maxPacketSize=8196):
+    def listenUDP(
+        self,
+        port: int,
+        protocol: DatagramProtocol,
+        interface: str = "",
+        maxPacketSize: int = 8196,
+    ) -> udp.Port:
         p = udp.Port(port, protocol, interface, maxPacketSize, self)
         p.startListening()
         self._udp.append(p)
         return p
 
-    def callFromThread(self, callback, *args, **kwargs):
+    # Type-ignore: IReactorFromThreads doesn't use paramspec here.
+    def callFromThread(  # type: ignore[override]
+        self, callback: Callable[P, Any], *args: P.args, **kwargs: P.kwargs
+    ) -> None:
         """
         Make the callback fire in the next reactor iteration.
         """
@@ -433,10 +457,12 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
         # separate queue.
         self._thread_callbacks.append(cb)
 
-    def getThreadPool(self):
+    def getThreadPool(self) -> "ThreadPool":
         return self.threadpool
 
-    def add_tcp_client_callback(self, host: str, port: int, callback: Callable):
+    def add_tcp_client_callback(
+        self, host: str, port: int, callback: Callable[[], None]
+    ) -> None:
         """Add a callback that will be invoked when we receive a connection
         attempt to the given IP/port using `connectTCP`.
 
@@ -445,7 +471,14 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
         """
         self._tcp_callbacks[(host, port)] = callback
 
-    def connectTCP(self, host: str, port: int, factory, timeout=30, bindAddress=None):
+    def connectTCP(
+        self,
+        host: str,
+        port: int,
+        factory: ClientFactory,
+        timeout: float = 30,
+        bindAddress: Optional[Tuple[str, int]] = None,
+    ) -> IConnector:
         """Fake L{IReactorTCP.connectTCP}."""
 
         conn = super().connectTCP(
@@ -458,7 +491,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
 
         return conn
 
-    def advance(self, amount):
+    def advance(self, amount: float) -> None:
         # first advance our reactor's time, and run any "callLater" callbacks that
         # makes ready
         super().advance(amount)
@@ -485,26 +518,32 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
 
 class ThreadPool:
     """
-    Threadless thread pool.
+    Threadless thread pool. A stand-in for twisted.python.threadpool.ThreadPool.
     """
 
-    def __init__(self, reactor):
+    def __init__(self, reactor: IReactorTime):
         self._reactor = reactor
 
-    def start(self):
+    def start(self) -> None:
         pass
 
-    def stop(self):
+    def stop(self) -> None:
         pass
 
-    def callInThreadWithCallback(self, onResult, function, *args, **kwargs):
-        def _(res):
+    def callInThreadWithCallback(
+        self,
+        onResult: Callable[[bool, object], Any],
+        function: Callable[P, Any],
+        *args: P.args,
+        **kwargs: P.kwargs,
+    ) -> "Deferred[bool]":
+        def _(res: object) -> None:
             if isinstance(res, Failure):
                 onResult(False, res)
             else:
                 onResult(True, res)
 
-        d = Deferred()
+        d: "Deferred[bool]" = Deferred()
         d.addCallback(lambda x: function(*args, **kwargs))
         d.addBoth(_)
         self._reactor.callLater(0, d.callback, True)
@@ -521,7 +560,7 @@ def _make_test_homeserver_synchronous(server: HomeServer) -> None:
     for database in server.get_datastores().databases:
         pool = database._db_pool
 
-        def runWithConnection(func, *args, **kwargs):
+        def runWithConnection(func: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> "Deferred[R]":
             return threads.deferToThreadPool(
                 pool._reactor,
                 pool.threadpool,
@@ -531,7 +570,7 @@ def _make_test_homeserver_synchronous(server: HomeServer) -> None:
                 **kwargs,
             )
 
-        def runInteraction(interaction, *args, **kwargs):
+        def runInteraction(interaction: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> "Deferred[R]":
             return threads.deferToThreadPool(
                 pool._reactor,
                 pool.threadpool,
@@ -559,7 +598,7 @@ def get_clock() -> Tuple[ThreadedMemoryReactorClock, Clock]:
 
 
 @implementer(ITransport)
-@attr.s(cmp=False)
+@attr.s(cmp=False, auto_attribs=True)
 class FakeTransport:
     """
     A twisted.internet.interfaces.ITransport implementation which sends all its data
@@ -574,35 +613,29 @@ class FakeTransport:
     If you want bidirectional communication, you'll need two instances.
     """
 
-    other = attr.ib()
-    """The Protocol object which will receive any data written to this transport.
-
-    :type: twisted.internet.interfaces.IProtocol
-    """
-
-    _reactor = attr.ib()
-    """Test reactor
-
-    :type: twisted.internet.interfaces.IReactorTime
-    """
-
-    _protocol = attr.ib(default=None)
+    other: IProtocol
+    """The Protocol object which will receive any data written to this transport."""
+
+    _reactor: IReactorTime
+    """Test reactor """
+
+    _protocol: Optional[IProtocol] = None
     """The Protocol which is producing data for this transport. Optional, but if set
     will get called back for connectionLost() notifications etc.
     """
 
-    _peer_address: Optional[IAddress] = attr.ib(default=None)
+    _peer_address: Optional[IAddress] = None
     """The value to be returned by getPeer"""
 
-    _host_address: Optional[IAddress] = attr.ib(default=None)
+    _host_address: Optional[IAddress] = None
    """The value to be returned by getHost"""
 
-    disconnecting = False
-    disconnected = False
-    connected = True
-    buffer = attr.ib(default=b"")
-    producer = attr.ib(default=None)
-    autoflush = attr.ib(default=True)
+    disconnecting: bool = False
+    disconnected: bool = False
+    connected: bool = True
+    buffer: bytes = b""
+    producer: Optional[IPushProducer] = None
+    autoflush: bool = True
 
     def getPeer(self) -> Optional[IAddress]:
         return self._peer_address
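
With `@attr.s(auto_attribs=True)`, the class body above declares attributes as plain annotations (`other: IProtocol`, `buffer: bytes = b""`) instead of `attr.ib()` calls, and the generated `__init__` stays the same while mypy gains type information. A small standalone sketch of the two equivalent spellings, using a toy class that is not part of Synapse:

```python
import attr


@attr.s
class OldStyle:
    # explicit attr.ib() calls; attribute types are invisible to mypy
    name = attr.ib()
    count = attr.ib(default=0)


@attr.s(auto_attribs=True)
class NewStyle:
    # annotations double as attribute definitions and as type hints
    name: str
    count: int = 0


assert OldStyle("a", 1).count == NewStyle("a", 1).count
```
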
@@ -610,7 +643,7 @@ class FakeTransport:
     def getHost(self) -> Optional[IAddress]:
         return self._host_address
 
-    def loseConnection(self, reason=None):
+    def loseConnection(self, reason: Optional[Failure] = None) -> None:
         if not self.disconnecting:
             logger.info("FakeTransport: loseConnection(%s)", reason)
             self.disconnecting = True
@@ -626,7 +659,7 @@ class FakeTransport:
                 self.connected = False
                 self.disconnected = True
 
-    def abortConnection(self):
+    def abortConnection(self) -> None:
         logger.info("FakeTransport: abortConnection()")
 
         if not self.disconnecting:
@@ -636,28 +669,28 @@ class FakeTransport:
 
         self.disconnected = True
 
-    def pauseProducing(self):
+    def pauseProducing(self) -> None:
         if not self.producer:
             return
 
         self.producer.pauseProducing()
 
-    def resumeProducing(self):
+    def resumeProducing(self) -> None:
         if not self.producer:
             return
         self.producer.resumeProducing()
 
-    def unregisterProducer(self):
+    def unregisterProducer(self) -> None:
         if not self.producer:
             return
 
         self.producer = None
 
-    def registerProducer(self, producer, streaming):
+    def registerProducer(self, producer: IPushProducer, streaming: bool) -> None:
         self.producer = producer
         self.producerStreaming = streaming
 
-        def _produce():
+        def _produce() -> None:
             if not self.producer:
                 # we've been unregistered
                 return
@@ -669,7 +702,7 @@ class FakeTransport:
         if not streaming:
             self._reactor.callLater(0.0, _produce)
 
-    def write(self, byt):
+    def write(self, byt: bytes) -> None:
         if self.disconnecting:
             raise Exception("Writing to disconnecting FakeTransport")
 
@@ -681,11 +714,11 @@ class FakeTransport:
         if self.autoflush:
             self._reactor.callLater(0.0, self.flush)
 
-    def writeSequence(self, seq):
+    def writeSequence(self, seq: Iterable[bytes]) -> None:
         for x in seq:
             self.write(x)
 
-    def flush(self, maxbytes=None):
+    def flush(self, maxbytes: Optional[int] = None) -> None:
         if not self.buffer:
             # nothing to do. Don't write empty buffers: it upsets the
             # TLSMemoryBIOProtocol
@@ -739,14 +772,17 @@ class TestHomeServer(HomeServer):
     DATASTORE_CLASS = DataStore
 
 
+HS = TypeVar("HS", bound=HomeServer)
+
+
 def setup_test_homeserver(
-    cleanup_func,
-    name="test",
-    config=None,
-    reactor=None,
-    homeserver_to_use: Type[HomeServer] = TestHomeServer,
-    **kwargs,
-):
+    cleanup_func: Callable[[Callable[[], None]], Any],
+    name: str = "test",
+    config: Union[HomeServerConfig, None] = None,
+    reactor: Optional[ISynapseReactor] = None,
+    homeserver_to_use: Type[HS] = TestHomeServer,
+    **kwargs: object,
+) -> HS:
     """
     Setup a homeserver suitable for running tests against. Keyword arguments
     are passed to the Homeserver constructor.
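
Binding `HS = TypeVar("HS", bound=HomeServer)` and returning `-> HS` means the return type of `setup_test_homeserver` follows whatever class the caller passes as `homeserver_to_use`, instead of being widened to `HomeServer`. A minimal sketch of that pattern with made-up classes, not Synapse code:

```python
from typing import Type, TypeVar


class Base:
    pass


class Special(Base):
    def extra(self) -> str:
        return "only on Special"


T = TypeVar("T", bound=Base)


def build(cls: Type[T]) -> T:
    """Instantiate `cls`; the type checker tracks the concrete subclass."""
    return cls()


s = build(Special)
print(s.extra())  # accepted: `s` is inferred as Special, not Base
```
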
@@ -761,13 +797,12 @@ def setup_test_homeserver(
     HomeserverTestCase.
     """
     if reactor is None:
-        from twisted.internet import reactor
+        from twisted.internet import reactor  # type: ignore[no-redef]
 
     if config is None:
         config = default_config(name, parse=True)
 
     config.caches.resize_all_caches()
-    config.ldap_enabled = False
 
     if "clock" not in kwargs:
         kwargs["clock"] = MockClock()
@@ -810,20 +845,25 @@ def setup_test_homeserver(
     if "db_txn_limit" in kwargs:
         database_config["txn_limit"] = kwargs["db_txn_limit"]
 
-    database = DatabaseConnectionConfig("master", database_config)
-    config.database.databases = [database]
+    database_conn_config = DatabaseConnectionConfig("master", database_config)
+    config.database.databases = [database_conn_config]
 
-    db_engine = create_engine(database.config)
+    db_engine = create_engine(database_conn_config.config)
 
     # Create the database before we actually try and connect to it, based off
     # the template database we generate in setupdb()
     if isinstance(db_engine, PostgresEngine):
-        db_conn = db_engine.module.connect(
-            database=POSTGRES_BASE_DB,
-            user=POSTGRES_USER,
-            host=POSTGRES_HOST,
-            port=POSTGRES_PORT,
-            password=POSTGRES_PASSWORD,
-        )
+        import psycopg2
+
+        db_conn = cast(
+            psycopg2.connection,
+            db_engine.module.connect(
+                database=POSTGRES_BASE_DB,
+                user=POSTGRES_USER,
+                host=POSTGRES_HOST,
+                port=POSTGRES_PORT,
+                password=POSTGRES_PASSWORD,
+            ),
+        )
         db_conn.autocommit = True
         cur = db_conn.cursor()
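
The `cast()` above is purely a hint to the type checker: `db_engine.module.connect(...)` returns `Any` because the DB-API module is loaded dynamically, and the cast tells mypy which concrete connection type to assume while doing nothing at runtime. A generic, hedged illustration (the `load_driver` helper below is invented for this example):

```python
import sqlite3
from typing import Any, cast


def load_driver() -> Any:
    # stands in for db_engine.module, whose static type is unknown
    return sqlite3


# Without the cast, `conn` would be Any and mistakes on it would go unchecked.
conn = cast(sqlite3.Connection, load_driver().connect(":memory:"))
conn.execute("SELECT 1")
conn.close()
```
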
@@ -856,7 +896,7 @@ def setup_test_homeserver(
         database = hs.get_datastores().databases[0]
 
         # We need to do cleanup on PostgreSQL
-        def cleanup():
+        def cleanup() -> None:
             import psycopg2
 
             # Close all the db pools
@@ -865,12 +905,15 @@ def setup_test_homeserver(
             dropped = False
 
             # Drop the test database
-            db_conn = db_engine.module.connect(
-                database=POSTGRES_BASE_DB,
-                user=POSTGRES_USER,
-                host=POSTGRES_HOST,
-                port=POSTGRES_PORT,
-                password=POSTGRES_PASSWORD,
-            )
+            db_conn = cast(
+                psycopg2.connection,
+                db_engine.module.connect(
+                    database=POSTGRES_BASE_DB,
+                    user=POSTGRES_USER,
+                    host=POSTGRES_HOST,
+                    port=POSTGRES_PORT,
+                    password=POSTGRES_PASSWORD,
+                ),
+            )
             db_conn.autocommit = True
             cur = db_conn.cursor()
@@ -904,12 +947,12 @@ def setup_test_homeserver(
     # Need to let the HS build an auth handler and then mess with it
     # because AuthHandler's constructor requires the HS, so we can't make one
    # beforehand and pass it in to the HS's constructor (chicken / egg)
-    async def hash(p):
+    async def hash(p: str) -> str:
         return hashlib.md5(p.encode("utf8")).hexdigest()
 
     hs.get_auth_handler().hash = hash
 
-    async def validate_hash(p, h):
+    async def validate_hash(p: str, h: str) -> bool:
         return hashlib.md5(p.encode("utf8")).hexdigest() == h
 
     hs.get_auth_handler().validate_hash = validate_hash