Merge branch 'develop' of github.com:element-hq/synapse into develop

This commit is contained in:
Andrew Morgan
2025-10-01 09:40:38 +01:00
152 changed files with 2333 additions and 1052 deletions

View File

@@ -0,0 +1,193 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright (C) 2025 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
#
import gc
import weakref
from typing import Callable, Optional

from synapse.app.homeserver import SynapseHomeServer
from synapse.storage.background_updates import UpdaterStatus

from tests.server import (
    cleanup_test_reactor_system_event_triggers,
    get_clock,
    setup_test_homeserver,
)
from tests.unittest import HomeserverTestCase
class HomeserverCleanShutdownTestCase(HomeserverTestCase):
    """Ensure a `SynapseHomeServer` can be fully shut down and garbage collected.

    These tests construct their own homeserver (instead of relying on the default
    `HomeserverTestCase` machinery) so that the full lifecycle, including shutdown,
    is under the test's control.
    """

    def setUp(self) -> None:
        # Deliberately skip `HomeserverTestCase.setUp`: each test builds its own
        # homeserver via `_make_homeserver` so it can manage the full lifecycle
        # itself.
        pass

    # NOTE: ideally we'd have another test to ensure we properly shutdown with
    # real in-flight HTTP requests since those result in additional resources being
    # setup that hold strong references to the homeserver.
    # Mainly, the HTTP channel created by a real TCP connection from client to server
    # is held open between requests and care needs to be taken in Twisted to ensure it is properly
    # closed in a timely manner during shutdown. Simulating this behaviour in a unit test
    # won't be as good as a proper integration test in complement.

    def _make_homeserver(self) -> None:
        """Create the `SynapseHomeServer` under test, storing it on `self.hs`."""
        self.reactor, self.clock = get_clock()
        self.hs = setup_test_homeserver(
            cleanup_func=self.addCleanup,
            reactor=self.reactor,
            homeserver_to_use=SynapseHomeServer,
            clock=self.clock,
        )

    def _shutdown_and_assert_collected(
        self, before_shutdown: Optional[Callable[[], None]] = None
    ) -> None:
        """Shut down `self.hs` and assert that it gets garbage collected.

        Args:
            before_shutdown: optional callable invoked after the reactor run and
                system event trigger cleanup, but before `HomeServer.shutdown` is
                called. Lets tests make assertions about the homeserver's state
                at the point of shutdown.
        """
        hs_ref = weakref.ref(self.hs)

        # Run the reactor so any `callWhenRunning` functions can be cleared out.
        self.reactor.run()
        # This would normally happen as part of `HomeServer.shutdown` but the
        # `MemoryReactor` we use in tests doesn't handle this properly (see doc comment)
        cleanup_test_reactor_system_event_triggers(self.reactor)

        if before_shutdown is not None:
            before_shutdown()

        # Cleanup the homeserver.
        self.get_success(self.hs.shutdown())
        # Cleanup the internal reference in our test case
        del self.hs
        # Force garbage collection.
        gc.collect()

        # Ensure the `HomeServer` has been garbage collected by attempting to use
        # the weakref to it.
        if hs_ref() is not None:
            self.fail("HomeServer reference should not be valid at this point")

        # To help debug this test when it fails, it is useful to leverage the
        # `objgraph` module.
        # The following code serves as an example of what I have found to be useful
        # when tracking down references holding the `SynapseHomeServer` in memory:
        #
        # all_objects = gc.get_objects()
        # for obj in all_objects:
        #     try:
        #         # These are a subset of types that are typically involved with
        #         # holding the `HomeServer` in memory. You may want to inspect
        #         # other types as well.
        #         if isinstance(obj, DataStore):
        #             print(sys.getrefcount(obj), "refs to", obj)
        #             if not isinstance(obj, weakref.ProxyType):
        #                 db_obj = obj
        #         if isinstance(obj, SynapseHomeServer):
        #             print(sys.getrefcount(obj), "refs to", obj)
        #             if not isinstance(obj, weakref.ProxyType):
        #                 synapse_hs = obj
        #         if isinstance(obj, SynapseSite):
        #             print(sys.getrefcount(obj), "refs to", obj)
        #             if not isinstance(obj, weakref.ProxyType):
        #                 sysite = obj
        #         if isinstance(obj, DatabasePool):
        #             print(sys.getrefcount(obj), "refs to", obj)
        #             if not isinstance(obj, weakref.ProxyType):
        #                 dbpool = obj
        #     except Exception:
        #         pass
        #
        # print(sys.getrefcount(hs_ref()), "refs to", hs_ref())
        #
        # # The following values for `max_depth` and `too_many` have been found to
        # # render a useful amount of information without taking an overly long time
        # # to generate the result.
        # objgraph.show_backrefs(synapse_hs, max_depth=10, too_many=10)

    def test_clean_homeserver_shutdown(self) -> None:
        """Ensure the `SynapseHomeServer` can be fully shutdown and garbage collected"""
        self._make_homeserver()
        self.wait_for_background_updates()
        self._shutdown_and_assert_collected()

    def test_clean_homeserver_shutdown_mid_background_updates(self) -> None:
        """Ensure the `SynapseHomeServer` can be fully shutdown and garbage collected
        before background updates have completed"""
        self._make_homeserver()

        # Pump the background updates by a single iteration, just to ensure any extra
        # resources it uses have been started.
        store = weakref.proxy(self.hs.get_datastores().main)
        self.get_success(store.db_pool.updates.do_next_background_update(False), by=0.1)

        def _check_updates_incomplete() -> None:
            # Ensure the background updates are not complete.
            self.assertNotEqual(
                store.db_pool.updates.get_status(), UpdaterStatus.COMPLETE
            )

        self._shutdown_and_assert_collected(before_shutdown=_check_updates_incomplete)

View File

@@ -18,7 +18,7 @@
# [This file includes modifications made by New Vector Limited]
#
#
from typing import List, Optional, Sequence, Tuple, cast
from typing import List, Optional, Sequence, Tuple
from unittest.mock import AsyncMock, Mock
from typing_extensions import TypeAlias
@@ -44,13 +44,12 @@ from synapse.types import DeviceListUpdates, JsonDict
from synapse.util.clock import Clock
from tests import unittest
from ..utils import MockClock
from tests.server import get_clock
class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
def setUp(self) -> None:
self.clock = MockClock()
self.reactor, self.clock = get_clock()
self.store = Mock()
self.as_api = Mock()
@@ -168,16 +167,18 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
)
class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
class ApplicationServiceSchedulerRecovererTestCase(unittest.HomeserverTestCase):
def setUp(self) -> None:
self.clock = MockClock()
super().setUp()
self.reactor, self.clock = get_clock()
self.as_api = Mock()
self.store = Mock()
self.service = Mock()
self.callback = AsyncMock()
self.recoverer = _Recoverer(
server_name="test_server",
clock=cast(Clock, self.clock),
hs=self.hs,
clock=self.clock,
as_api=self.as_api,
store=self.store,
service=self.service,
@@ -202,7 +203,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
txn.send = AsyncMock(return_value=True)
txn.complete = AsyncMock(return_value=None)
# wait for exp backoff
self.clock.advance_time(2)
self.reactor.advance(2)
self.assertEqual(1, txn.send.call_count)
self.assertEqual(1, txn.complete.call_count)
# 2 because it needs to get None to know there are no more txns
@@ -229,21 +230,21 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
self.assertEqual(0, self.store.get_oldest_unsent_txn.call_count)
txn.send = AsyncMock(return_value=False)
txn.complete = AsyncMock(return_value=None)
self.clock.advance_time(2)
self.reactor.advance(2)
self.assertEqual(1, txn.send.call_count)
self.assertEqual(0, txn.complete.call_count)
self.assertEqual(0, self.callback.call_count)
self.clock.advance_time(4)
self.reactor.advance(4)
self.assertEqual(2, txn.send.call_count)
self.assertEqual(0, txn.complete.call_count)
self.assertEqual(0, self.callback.call_count)
self.clock.advance_time(8)
self.reactor.advance(8)
self.assertEqual(3, txn.send.call_count)
self.assertEqual(0, txn.complete.call_count)
self.assertEqual(0, self.callback.call_count)
txn.send = AsyncMock(return_value=True) # successfully send the txn
pop_txn = True # returns the txn the first time, then no more.
self.clock.advance_time(16)
self.reactor.advance(16)
self.assertEqual(1, txn.send.call_count) # new mock reset call count
self.assertEqual(1, txn.complete.call_count)
self.callback.assert_called_once_with(self.recoverer)
@@ -268,7 +269,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
self.assertEqual(0, self.store.get_oldest_unsent_txn.call_count)
txn.send = AsyncMock(return_value=False)
txn.complete = AsyncMock(return_value=None)
self.clock.advance_time(2)
self.reactor.advance(2)
self.assertEqual(1, txn.send.call_count)
self.assertEqual(0, txn.complete.call_count)
self.assertEqual(0, self.callback.call_count)

View File

@@ -24,6 +24,7 @@ from synapse.config.cache import CacheConfig, add_resizable_cache
from synapse.types import JsonDict
from synapse.util.caches.lrucache import LruCache
from tests.server import get_clock
from tests.unittest import TestCase
@@ -32,6 +33,7 @@ class CacheConfigTests(TestCase):
# Reset caches before each test since there's global state involved.
self.config = CacheConfig(RootConfig())
self.config.reset()
_, self.clock = get_clock()
def tearDown(self) -> None:
# Also reset the caches after each test to leave state pristine.
@@ -75,7 +77,9 @@ class CacheConfigTests(TestCase):
the default cache size in the interim, and then resized once the config
is loaded.
"""
cache: LruCache = LruCache(max_size=100, server_name="test_server")
cache: LruCache = LruCache(
max_size=100, clock=self.clock, server_name="test_server"
)
add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
self.assertEqual(cache.max_size, 50)
@@ -96,7 +100,9 @@ class CacheConfigTests(TestCase):
self.config.read_config(config, config_dir_path="", data_dir_path="")
self.config.resize_all_caches()
cache: LruCache = LruCache(max_size=100, server_name="test_server")
cache: LruCache = LruCache(
max_size=100, clock=self.clock, server_name="test_server"
)
add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
self.assertEqual(cache.max_size, 200)
@@ -106,7 +112,9 @@ class CacheConfigTests(TestCase):
the default cache size in the interim, and then resized to the new
default cache size once the config is loaded.
"""
cache: LruCache = LruCache(max_size=100, server_name="test_server")
cache: LruCache = LruCache(
max_size=100, clock=self.clock, server_name="test_server"
)
add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
self.assertEqual(cache.max_size, 50)
@@ -126,7 +134,9 @@ class CacheConfigTests(TestCase):
self.config.read_config(config, config_dir_path="", data_dir_path="")
self.config.resize_all_caches()
cache: LruCache = LruCache(max_size=100, server_name="test_server")
cache: LruCache = LruCache(
max_size=100, clock=self.clock, server_name="test_server"
)
add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
self.assertEqual(cache.max_size, 150)
@@ -145,15 +155,21 @@ class CacheConfigTests(TestCase):
self.config.read_config(config, config_dir_path="", data_dir_path="")
self.config.resize_all_caches()
cache_a: LruCache = LruCache(max_size=100, server_name="test_server")
cache_a: LruCache = LruCache(
max_size=100, clock=self.clock, server_name="test_server"
)
add_resizable_cache("*cache_a*", cache_resize_callback=cache_a.set_cache_factor)
self.assertEqual(cache_a.max_size, 200)
cache_b: LruCache = LruCache(max_size=100, server_name="test_server")
cache_b: LruCache = LruCache(
max_size=100, clock=self.clock, server_name="test_server"
)
add_resizable_cache("*Cache_b*", cache_resize_callback=cache_b.set_cache_factor)
self.assertEqual(cache_b.max_size, 300)
cache_c: LruCache = LruCache(max_size=100, server_name="test_server")
cache_c: LruCache = LruCache(
max_size=100, clock=self.clock, server_name="test_server"
)
add_resizable_cache("*cache_c*", cache_resize_callback=cache_c.set_cache_factor)
self.assertEqual(cache_c.max_size, 200)
@@ -168,6 +184,7 @@ class CacheConfigTests(TestCase):
cache: LruCache = LruCache(
max_size=self.config.event_cache_size,
clock=self.clock,
apply_cache_factor_from_config=False,
server_name="test_server",
)

View File

@@ -231,7 +231,10 @@ class MSC3861OAuthDelegation(TestCase):
reactor, clock = get_clock()
with self.assertRaises(ConfigError):
setup_test_homeserver(
self.addCleanup, reactor=reactor, clock=clock, config=config
cleanup_func=self.addCleanup,
config=config,
reactor=reactor,
clock=clock,
)
def test_jwt_auth_cannot_be_enabled(self) -> None:
@@ -395,7 +398,10 @@ class MasAuthDelegation(TestCase):
reactor, clock = get_clock()
with self.assertRaises(ConfigError):
setup_test_homeserver(
self.addCleanup, reactor=reactor, clock=clock, config=config
cleanup_func=self.addCleanup,
config=config,
reactor=reactor,
clock=clock,
)
@skip_unless(HAS_AUTHLIB, "requires authlib")

View File

@@ -19,7 +19,17 @@
#
#
from typing import Dict, Iterable, List, Optional
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Dict,
Iterable,
List,
Optional,
TypeVar,
)
from unittest.mock import AsyncMock, Mock
from parameterized import parameterized
@@ -36,6 +46,7 @@ from synapse.appservice import (
TransactionUnusedFallbackKeys,
)
from synapse.handlers.appservice import ApplicationServicesHandler
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.rest.client import login, receipts, register, room, sendtodevice
from synapse.server import HomeServer
from synapse.types import (
@@ -49,9 +60,14 @@ from synapse.util.clock import Clock
from synapse.util.stringutils import random_string
from tests import unittest
from tests.server import get_clock
from tests.test_utils import event_injection
from tests.unittest import override_config
from tests.utils import MockClock
if TYPE_CHECKING:
from typing_extensions import LiteralString
R = TypeVar("R")
class AppServiceHandlerTestCase(unittest.TestCase):
@@ -61,14 +77,27 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self.mock_store = Mock()
self.mock_as_api = AsyncMock()
self.mock_scheduler = Mock()
self.reactor, self.clock = get_clock()
hs = Mock()
def test_run_as_background_process(
desc: "LiteralString",
func: Callable[..., Awaitable[Optional[R]]],
*args: Any,
**kwargs: Any,
) -> "defer.Deferred[Optional[R]]":
# Ignore linter error as this is used only for testing purposes (i.e. outside of Synapse).
return run_as_background_process(desc, "test_server", func, *args, **kwargs) # type: ignore[untracked-background-process]
hs.run_as_background_process = test_run_as_background_process
hs.get_datastores.return_value = Mock(main=self.mock_store)
self.mock_store.get_appservice_last_pos = AsyncMock(return_value=None)
self.mock_store.set_appservice_last_pos = AsyncMock(return_value=None)
self.mock_store.set_appservice_stream_type_pos = AsyncMock(return_value=None)
hs.get_application_service_api.return_value = self.mock_as_api
hs.get_application_service_scheduler.return_value = self.mock_scheduler
hs.get_clock.return_value = MockClock()
hs.get_clock.return_value = self.clock
self.handler = ApplicationServicesHandler(hs)
self.event_source = hs.get_event_sources()

View File

@@ -21,7 +21,6 @@
#
import copy
from unittest import mock
from twisted.internet.testing import MemoryReactor
@@ -50,7 +49,7 @@ room_keys = {
class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
return self.setup_test_homeserver(replication_layer=mock.Mock())
return self.setup_test_homeserver()
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.handler = hs.get_e2e_room_keys_handler()

View File

@@ -79,15 +79,17 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
) -> HomeServer:
# we mock out the keyring so as to skip the authentication check on the
# federation API call.
mock_keyring = Mock(spec=["verify_json_for_server"])
mock_keyring = Mock(spec=["verify_json_for_server", "shutdown"])
mock_keyring.verify_json_for_server = AsyncMock(return_value=True)
mock_keyring.shutdown = Mock()
# we mock out the federation client too
self.mock_federation_client = AsyncMock(spec=["put_json"])
self.mock_federation_client.put_json.return_value = (200, "OK")
self.mock_federation_client.agent = MatrixFederationAgent(
server_name="OUR_STUB_HOMESERVER_NAME",
reactor=reactor,
reactor=self.reactor,
clock=self.clock,
tls_client_options_factory=None,
user_agent=b"SynapseInTrialTest/0.0.0",
ip_allowlist=None,
@@ -96,7 +98,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
)
# the tests assume that we are starting at unix time 1000
reactor.pump((1000,))
self.reactor.pump((1000,))
self.mock_hs_notifier = Mock()
hs = self.setup_test_homeserver(

View File

@@ -65,7 +65,7 @@ from synapse.util.caches.ttlcache import TTLCache
from tests import unittest
from tests.http import dummy_address, get_test_ca_cert_file, wrap_server_factory_for_tls
from tests.server import FakeTransport, ThreadedMemoryReactorClock
from tests.server import FakeTransport, get_clock
from tests.utils import checked_cast, default_config
logger = logging.getLogger(__name__)
@@ -73,7 +73,7 @@ logger = logging.getLogger(__name__)
class MatrixFederationAgentTests(unittest.TestCase):
def setUp(self) -> None:
self.reactor = ThreadedMemoryReactorClock()
self.reactor, self.clock = get_clock()
self.mock_resolver = AsyncMock(spec=SrvResolver)
@@ -98,6 +98,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.well_known_resolver = WellKnownResolver(
server_name="OUR_STUB_HOMESERVER_NAME",
reactor=self.reactor,
clock=self.clock,
agent=Agent(self.reactor, contextFactory=self.tls_factory),
user_agent=b"test-agent",
well_known_cache=self.well_known_cache,
@@ -280,6 +281,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
return MatrixFederationAgent(
server_name="OUR_STUB_HOMESERVER_NAME",
reactor=cast(ISynapseReactor, self.reactor),
clock=self.clock,
tls_client_options_factory=self.tls_factory,
user_agent=b"test-agent", # Note that this is unused since _well_known_resolver is provided.
ip_allowlist=IPSet(),
@@ -1024,6 +1026,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
agent = MatrixFederationAgent(
server_name="OUR_STUB_HOMESERVER_NAME",
reactor=self.reactor,
clock=self.clock,
tls_client_options_factory=tls_factory,
user_agent=b"test-agent", # This is unused since _well_known_resolver is passed below.
ip_allowlist=IPSet(),
@@ -1033,6 +1036,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
_well_known_resolver=WellKnownResolver(
server_name="OUR_STUB_HOMESERVER_NAME",
reactor=cast(ISynapseReactor, self.reactor),
clock=self.clock,
agent=Agent(self.reactor, contextFactory=tls_factory),
user_agent=b"test-agent",
well_known_cache=self.well_known_cache,

View File

@@ -30,7 +30,7 @@ from synapse.http.federation.srv_resolver import Server, SrvResolver
from synapse.logging.context import LoggingContext, current_context
from tests import unittest
from tests.utils import MockClock
from tests.server import get_clock
class SrvResolverTestCase(unittest.TestCase):
@@ -105,7 +105,7 @@ class SrvResolverTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_from_cache(self) -> Generator["Deferred[object]", object, None]:
clock = MockClock()
reactor, clock = get_clock()
dns_client_mock = Mock(spec_set=["lookupService"])
dns_client_mock.lookupService = Mock(spec_set=[])

View File

@@ -63,10 +63,6 @@ def check_logcontext(context: LoggingContextOrSentinel) -> None:
class FederationClientTests(HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver(reactor=reactor, clock=clock)
return hs
def prepare(
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
) -> None:

View File

@@ -163,7 +163,9 @@ class TracingScopeTestCase(TestCase):
# implements `ISynapseThreadlessReactor` (combination of the normal Twisted
# Reactor/Clock interfaces), via inheritance from
# `twisted.internet.testing.MemoryReactor` and `twisted.internet.testing.Clock`
clock = Clock(
# Ignore `multiple-internal-clocks` linter error here since we are creating a `Clock`
# for testing purposes.
clock = Clock( # type: ignore[multiple-internal-clocks]
reactor, # type: ignore[arg-type]
server_name="test_server",
)
@@ -234,7 +236,9 @@ class TracingScopeTestCase(TestCase):
# implements `ISynapseThreadlessReactor` (combination of the normal Twisted
# Reactor/Clock interfaces), via inheritance from
# `twisted.internet.testing.MemoryReactor` and `twisted.internet.testing.Clock`
clock = Clock(
# Ignore `multiple-internal-clocks` linter error here since we are creating a `Clock`
# for testing purposes.
clock = Clock( # type: ignore[multiple-internal-clocks]
reactor, # type: ignore[arg-type]
server_name="test_server",
)

View File

@@ -37,7 +37,6 @@ from synapse.util.stringutils import (
from tests import unittest
from tests.unittest import override_config
from tests.utils import MockClock
class MediaRetentionTestCase(unittest.HomeserverTestCase):
@@ -51,12 +50,6 @@ class MediaRetentionTestCase(unittest.HomeserverTestCase):
admin.register_servlets_for_client_rest_resource,
]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
# We need to be able to test advancing time in the homeserver, so we
# replace the test homeserver's default clock with a MockClock, which
# supports advancing time.
return self.setup_test_homeserver(clock=MockClock())
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.remote_server_name = "remote.homeserver"
self.store = hs.get_datastores().main

View File

@@ -164,7 +164,10 @@ class CacheMetricsTests(unittest.HomeserverTestCase):
"""
CACHE_NAME = "cache_metrics_test_fgjkbdfg"
cache: DeferredCache[str, str] = DeferredCache(
name=CACHE_NAME, server_name=self.hs.hostname, max_entries=777
name=CACHE_NAME,
clock=self.hs.get_clock(),
server_name=self.hs.hostname,
max_entries=777,
)
metrics_map = get_latest_metrics()
@@ -212,10 +215,10 @@ class CacheMetricsTests(unittest.HomeserverTestCase):
"""
CACHE_NAME = "cache_metric_multiple_servers_test"
cache1: DeferredCache[str, str] = DeferredCache(
name=CACHE_NAME, server_name="hs1", max_entries=777
name=CACHE_NAME, clock=self.clock, server_name="hs1", max_entries=777
)
cache2: DeferredCache[str, str] = DeferredCache(
name=CACHE_NAME, server_name="hs2", max_entries=777
name=CACHE_NAME, clock=self.clock, server_name="hs2", max_entries=777
)
metrics_map = get_latest_metrics()

View File

@@ -173,7 +173,13 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
# Set up the server side protocol
server_address = IPv4Address("TCP", host, port)
channel = self.site.buildProtocol((host, port))
# The type ignore is here because mypy doesn't think the host/port tuple is of
# the correct type, even though it is the exact example given for
# `twisted.internet.interfaces.IAddress`.
# Mypy was happy with the type before we overrode `buildProtocol` in
# `SynapseSite`, probably because there was enough inheritance indirection before
# with the argument not having a type associated with it.
channel = self.site.buildProtocol((host, port)) # type: ignore[arg-type]
# hook into the channel's request factory so that we can keep a record
# of the requests
@@ -185,7 +191,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
requests.append(request)
return request
channel.requestFactory = request_factory
channel.requestFactory = request_factory # type: ignore[method-assign]
# Connect client to server and vice versa.
client_to_server_transport = FakeTransport(
@@ -427,7 +433,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
# Set up the server side protocol
server_address = IPv4Address("TCP", host, port)
channel = self._hs_to_site[hs].buildProtocol((host, port))
channel = self._hs_to_site[hs].buildProtocol((host, port)) # type: ignore[arg-type]
# Connect client to server and vice versa.
client_to_server_transport = FakeTransport(

View File

@@ -66,10 +66,11 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
def setUp(self) -> None:
super().setUp()
reactor, _ = get_clock()
reactor, clock = get_clock()
self.matrix_federation_agent = MatrixFederationAgent(
server_name="OUR_STUB_HOMESERVER_NAME",
reactor=reactor,
clock=clock,
tls_client_options_factory=None,
user_agent=b"SynapseInTrialTest/0.0.0",
ip_allowlist=None,

View File

@@ -24,6 +24,7 @@ import synapse
from synapse.module_api import cached
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import get_clock
logger = logging.getLogger(__name__)
@@ -36,6 +37,7 @@ KEY = "mykey"
class TestCache:
current_value = FIRST_VALUE
server_name = "test_server" # nb must be called this for @cached
_, clock = get_clock() # nb must be called this for @cached
@cached()
async def cached_function(self, user_id: str) -> str:

View File

@@ -29,16 +29,19 @@ from synapse.logging.context import SENTINEL_CONTEXT, LoggingContext, current_co
from synapse.rest.client.transactions import CLEANUP_PERIOD_MS, HttpTransactionCache
from synapse.types import ISynapseReactor, JsonDict
from synapse.util.clock import Clock
from synapse.util.constants import (
MILLISECONDS_PER_SECOND,
)
from tests import unittest
from tests.utils import MockClock
from tests.server import get_clock
reactor = cast(ISynapseReactor, _reactor)
class HttpTransactionCacheTestCase(unittest.TestCase):
def setUp(self) -> None:
self.clock = MockClock()
self.reactor, self.clock = get_clock()
self.hs = Mock()
self.hs.get_clock = Mock(return_value=self.clock)
self.hs.get_auth = Mock()
@@ -90,8 +93,10 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
) -> Generator["defer.Deferred[Any]", object, None]:
@defer.inlineCallbacks
def cb() -> Generator["defer.Deferred[object]", object, Tuple[int, JsonDict]]:
# Ignore `multiple-internal-clocks` linter error here since we are creating a `Clock`
# for testing purposes.
yield defer.ensureDeferred(
Clock(reactor, server_name="test_server").sleep(0)
Clock(reactor, server_name="test_server").sleep(0) # type: ignore[multiple-internal-clocks]
)
return 1, {}
@@ -180,8 +185,9 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
yield self.cache.fetch_or_execute_request(
self.mock_request, self.mock_requester, cb, "an arg"
)
# should NOT have cleaned up yet
self.clock.advance_time_msec(CLEANUP_PERIOD_MS / 2)
# Advance time just under the cleanup period.
# Should NOT have cleaned up yet
self.reactor.advance((CLEANUP_PERIOD_MS - 1) / MILLISECONDS_PER_SECOND)
yield self.cache.fetch_or_execute_request(
self.mock_request, self.mock_requester, cb, "an arg"
@@ -189,7 +195,8 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
# still using cache
cb.assert_called_once_with("an arg")
self.clock.advance_time_msec(CLEANUP_PERIOD_MS)
# Advance time just after the cleanup period.
self.reactor.advance(2 / MILLISECONDS_PER_SECOND)
yield self.cache.fetch_or_execute_request(
self.mock_request, self.mock_requester, cb, "an arg"

View File

@@ -170,7 +170,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# make a second homeserver, configured to use the first one as a key notary
self.http_client2 = Mock()
config = default_config(name="keyclient")
config = default_config(server_name="keyclient")
config["trusted_key_servers"] = [
{
"server_name": self.hs.hostname,

View File

@@ -28,6 +28,7 @@ import sqlite3
import time
import uuid
import warnings
import weakref
from collections import deque
from io import SEEK_END, BytesIO
from typing import (
@@ -56,7 +57,7 @@ from zope.interface import implementer
import twisted
from twisted.enterprise import adbapi
from twisted.internet import address, tcp, threads, udp
from twisted.internet import address, defer, tcp, threads, udp
from twisted.internet._resolver import SimpleResolverComplexifier
from twisted.internet.address import IPv4Address, IPv6Address
from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed
@@ -114,7 +115,6 @@ from tests.utils import (
POSTGRES_USER,
SQLITE_PERSIST_DB,
USE_POSTGRES_FOR_TESTS,
MockClock,
default_config,
)
@@ -525,6 +525,19 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
# overwrite it again.
self.nameResolver = SimpleResolverComplexifier(FakeResolver())
def run(self) -> None:
"""
Override the call from `MemoryReactorClock` to add an additional step that
cleans up any `whenRunningHooks` that have been called.
This is necessary for a clean shutdown to occur as these hooks can hold
references to the `SynapseHomeServer`.
"""
super().run()
# `MemoryReactorClock` never clears the hooks that have already been called.
# So manually clear the hooks here after they have been run.
self.whenRunningHooks.clear()
def installNameResolver(self, resolver: IHostnameResolver) -> IHostnameResolver:
raise NotImplementedError()
@@ -650,6 +663,19 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
super().advance(0)
def cleanup_test_reactor_system_event_triggers(
    reactor: ThreadedMemoryReactorClock,
) -> None:
    """Remove every system event trigger registered on the given test reactor.

    The `twisted.internet.test.ThreadedMemoryReactor` does not implement
    `removeSystemEventTrigger`, so it will not clean these triggers up on its
    own. Overriding `removeSystemEventTrigger` on `ThreadedMemoryReactorClock`
    to implement this was attempted, but Twisted then complains about the
    reactor being unclean and fails some tests — so we clear the trigger
    registry directly instead.
    """
    reactor.triggers.clear()
def validate_connector(connector: tcp.Connector, expected_ip: str) -> None:
"""Try to validate the obtained connector as it would happen when
synapse is running and the conection will be established.
@@ -781,14 +807,19 @@ class ThreadPool:
d: "Deferred[None]" = Deferred()
d.addCallback(lambda x: function(*args, **kwargs))
d.addBoth(_)
self._reactor.callLater(0, d.callback, True)
# mypy ignored here because:
# - this is part of the test infrastructure (outside of Synapse) so tracking
# these calls for homeserver shutdown doesn't make sense.
self._reactor.callLater(0, d.callback, True) # type: ignore[call-later-not-tracked]
return d
def get_clock() -> Tuple[ThreadedMemoryReactorClock, Clock]:
clock = ThreadedMemoryReactorClock()
hs_clock = Clock(clock, server_name="test_server")
return clock, hs_clock
# Ignore the linter error since this is an expected usage of creating a `Clock` for
# testing purposes.
reactor = ThreadedMemoryReactorClock()
hs_clock = Clock(reactor, server_name="test_server") # type: ignore[multiple-internal-clocks]
return reactor, hs_clock
@implementer(ITCPTransport)
@@ -899,10 +930,16 @@ class FakeTransport:
# some implementations of IProducer (for example, FileSender)
# don't return a deferred.
d = maybeDeferred(self.producer.resumeProducing)
d.addCallback(lambda x: self._reactor.callLater(0.1, _produce))
# mypy ignored here because:
# - this is part of the test infrastructure (outside of Synapse) so tracking
# these calls for for homeserver shutdown doesn't make sense.
d.addCallback(lambda x: self._reactor.callLater(0.1, _produce)) # type: ignore[call-later-not-tracked,call-overload]
if not streaming:
self._reactor.callLater(0.0, _produce)
# mypy ignored here because:
# - this is part of the test infrastructure (outside of Synapse) so tracking
# these calls for for homeserver shutdown doesn't make sense.
self._reactor.callLater(0.0, _produce) # type: ignore[call-later-not-tracked]
def write(self, byt: bytes) -> None:
if self.disconnecting:
@@ -914,7 +951,10 @@ class FakeTransport:
# TLSMemoryBIOProtocol) get very confused if a read comes back while they are
# still doing a write. Doing a callLater here breaks the cycle.
if self.autoflush:
self._reactor.callLater(0.0, self.flush)
# mypy ignored here because:
# - this is part of the test infrastructure (outside of Synapse) so tracking
# these calls for for homeserver shutdown doesn't make sense.
self._reactor.callLater(0.0, self.flush) # type: ignore[call-later-not-tracked]
def writeSequence(self, seq: Iterable[bytes]) -> None:
for x in seq:
@@ -944,7 +984,10 @@ class FakeTransport:
self.buffer = self.buffer[len(to_write) :]
if self.buffer and self.autoflush:
self._reactor.callLater(0.0, self.flush)
# mypy ignored here because:
# - this is part of the test infrastructure (outside of Synapse) so tracking
# these calls for for homeserver shutdown doesn't make sense.
self._reactor.callLater(0.0, self.flush) # type: ignore[call-later-not-tracked]
if not self.buffer and self.disconnecting:
logger.info("FakeTransport: Buffer now empty, completing disconnect")
@@ -1020,12 +1063,14 @@ class TestHomeServer(HomeServer):
def setup_test_homeserver(
cleanup_func: Callable[[Callable[[], None]], None],
name: str = "test",
*,
cleanup_func: Callable[[Callable[[], Optional["Deferred[None]"]]], None],
server_name: str = "test",
config: Optional[HomeServerConfig] = None,
reactor: Optional[ISynapseReactor] = None,
homeserver_to_use: Type[HomeServer] = TestHomeServer,
**kwargs: Any,
db_txn_limit: Optional[int] = None,
**extra_homeserver_attributes: Any,
) -> HomeServer:
"""
Setup a homeserver suitable for running tests against. Keyword arguments
@@ -1035,28 +1080,44 @@ def setup_test_homeserver(
Args:
cleanup_func : The function used to register a cleanup routine for
after the test.
after the test. If the function returns a Deferred, the
test case will wait until the Deferred has fired before
proceeding to the next cleanup function.
server_name: Homeserver name
config: Homeserver config
reactor: Twisted reactor
homeserver_to_use: Homeserver class to instantiate.
db_txn_limit: Gives the maximum number of database transactions to run per
connection before reconnecting. 0 means no limit. If unset, defaults to None
here which will default upstream to `0`.
**extra_homeserver_attributes: Additional keyword arguments to install as
`@cache_in_self` attributes on the homeserver. For example, `clock` will be
installed as `hs._clock`.
Calling this method directly is deprecated: you should instead derive from
HomeserverTestCase.
"""
if reactor is None:
from twisted.internet import reactor as _reactor
reactor = cast(ISynapseReactor, _reactor)
reactor = ThreadedMemoryReactorClock()
if config is None:
config = default_config(name, parse=True)
config = default_config(server_name, parse=True)
server_name = config.server.server_name
if not isinstance(server_name, str):
raise ConfigError("Must be a string", ("server_name",))
if "clock" not in extra_homeserver_attributes:
# Ignore `multiple-internal-clocks` linter error here since we are creating a `Clock`
# for testing purposes (i.e. outside of Synapse).
extra_homeserver_attributes["clock"] = Clock(reactor, server_name=server_name) # type: ignore[multiple-internal-clocks]
config.caches.resize_all_caches()
if "clock" not in kwargs:
kwargs["clock"] = MockClock()
if USE_POSTGRES_FOR_TESTS:
test_db = "synapse_test_%s" % uuid.uuid4().hex
database_config = {
database_config: JsonDict = {
"name": "psycopg2",
"args": {
"dbname": test_db,
@@ -1088,10 +1149,6 @@ def setup_test_homeserver(
"args": {"database": test_db_location, "cp_min": 1, "cp_max": 1},
}
server_name = config.server.server_name
if not isinstance(server_name, str):
raise ConfigError("Must be a string", ("server_name",))
# Check if we have set up a DB that we can use as a template.
global PREPPED_SQLITE_DB_CONN
if PREPPED_SQLITE_DB_CONN is None:
@@ -1111,8 +1168,8 @@ def setup_test_homeserver(
database_config["_TEST_PREPPED_CONN"] = PREPPED_SQLITE_DB_CONN
if "db_txn_limit" in kwargs:
database_config["txn_limit"] = kwargs["db_txn_limit"]
if db_txn_limit is not None:
database_config["txn_limit"] = db_txn_limit
database = DatabaseConnectionConfig("master", database_config)
config.database.databases = [database]
@@ -1139,17 +1196,30 @@ def setup_test_homeserver(
db_conn.close()
hs = homeserver_to_use(
name,
server_name,
config=config,
version_string="Synapse/tests",
reactor=reactor,
)
# Register the cleanup hook
cleanup_func(hs.cleanup)
# Capture the `hs` as a `weakref` here to ensure there is no scenario where uncalled
# cleanup functions result in holding the `hs` in memory.
cleanup_hs_ref = weakref.ref(hs)
def shutdown_hs_on_cleanup() -> "Deferred[None]":
cleanup_hs = cleanup_hs_ref()
deferred: "Deferred[None]" = defer.succeed(None)
if cleanup_hs is not None:
deferred = defer.ensureDeferred(cleanup_hs.shutdown())
return deferred
# Register the cleanup hook for the homeserver.
# A full `hs.shutdown()` is necessary otherwise CI tests will fail while exhibiting
# strange behaviours.
cleanup_func(shutdown_hs_on_cleanup)
# Install @cache_in_self attributes
for key, val in kwargs.items():
for key, val in extra_homeserver_attributes.items():
setattr(hs, "_" + key, val)
# Mock TLS
@@ -1175,14 +1245,18 @@ def setup_test_homeserver(
hs.get_datastores().main.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = False
if USE_POSTGRES_FOR_TESTS:
database_pool = hs.get_datastores().databases[0]
# Capture the `database_pool` as a `weakref` here to ensure there is no scenario where uncalled
# cleanup functions result in holding the `hs` in memory.
database_pool = weakref.ref(hs.get_datastores().databases[0])
# We need to do cleanup on PostgreSQL
def cleanup() -> None:
import psycopg2
# Close all the db pools
database_pool._db_pool.close()
db_pool = database_pool()
if db_pool is not None:
db_pool._db_pool.close()
dropped = False

View File

@@ -86,7 +86,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
conn_pool.runWithConnection = runWithConnection
config = default_config(name="test", parse=True)
config = default_config(server_name="test", parse=True)
hs = TestHomeServer("test", config=config)
if USE_POSTGRES_FOR_TESTS:

View File

@@ -26,9 +26,10 @@ from synapse.util.distributor import Distributor
from . import unittest
class DistributorTestCase(unittest.TestCase):
class DistributorTestCase(unittest.HomeserverTestCase):
def setUp(self) -> None:
self.dist = Distributor(server_name="test_server")
super().setUp()
self.dist = Distributor(hs=self.hs)
def test_signal_dispatch(self) -> None:
self.dist.declare("alert")

View File

@@ -55,9 +55,9 @@ class JsonResourceTests(unittest.TestCase):
reactor, clock = get_clock()
self.reactor = reactor
self.homeserver = setup_test_homeserver(
self.addCleanup,
clock=clock,
cleanup_func=self.addCleanup,
reactor=self.reactor,
clock=clock,
)
def test_handler_for_request(self) -> None:
@@ -217,9 +217,9 @@ class OptionsResourceTests(unittest.TestCase):
reactor, clock = get_clock()
self.reactor = reactor
self.homeserver = setup_test_homeserver(
self.addCleanup,
clock=clock,
cleanup_func=self.addCleanup,
reactor=self.reactor,
clock=clock,
)
class DummyResource(Resource):

View File

@@ -29,7 +29,6 @@ from typing import (
Optional,
Set,
Tuple,
cast,
)
from unittest.mock import AsyncMock, Mock
@@ -43,12 +42,11 @@ from synapse.events.snapshot import EventContext
from synapse.state import StateHandler, StateResolutionHandler, _make_state_cache_entry
from synapse.types import MutableStateMap, StateMap
from synapse.types.state import StateFilter
from synapse.util.clock import Clock
from synapse.util.macaroons import MacaroonGenerator
from tests import unittest
from .utils import MockClock, default_config
from tests.server import get_clock
from tests.utils import default_config
_next_event_id = 1000
@@ -248,7 +246,7 @@ class StateTestCase(unittest.TestCase):
"hostname",
]
)
clock = cast(Clock, MockClock())
reactor, clock = get_clock()
hs.config = default_config("tesths", True)
hs.get_datastores.return_value = Mock(
main=self.dummy_store,

View File

@@ -1,79 +0,0 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2014-2016 OpenMarket Ltd
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
#
from tests import unittest
from tests.utils import MockClock
class MockClockTestCase(unittest.TestCase):
def setUp(self) -> None:
self.clock = MockClock()
def test_advance_time(self) -> None:
start_time = self.clock.time()
self.clock.advance_time(20)
self.assertEqual(20, self.clock.time() - start_time)
def test_later(self) -> None:
invoked = [0, 0]
def _cb0() -> None:
invoked[0] = 1
self.clock.call_later(10, _cb0)
def _cb1() -> None:
invoked[1] = 1
self.clock.call_later(20, _cb1)
self.assertFalse(invoked[0])
self.clock.advance_time(15)
self.assertTrue(invoked[0])
self.assertFalse(invoked[1])
self.clock.advance_time(5)
self.assertTrue(invoked[1])
def test_cancel_later(self) -> None:
invoked = [0, 0]
def _cb0() -> None:
invoked[0] = 1
t0 = self.clock.call_later(10, _cb0)
def _cb1() -> None:
invoked[1] = 1
self.clock.call_later(20, _cb1)
self.clock.cancel_call_later(t0)
self.clock.advance_time(30)
self.assertFalse(invoked[0])
self.assertTrue(invoked[1])

View File

@@ -80,7 +80,7 @@ from synapse.logging.context import (
from synapse.rest import RegisterServletsFunc
from synapse.server import HomeServer
from synapse.storage.keys import FetchKeyResult
from synapse.types import JsonDict, Requester, UserID, create_requester
from synapse.types import ISynapseReactor, JsonDict, Requester, UserID, create_requester
from synapse.util.clock import Clock
from synapse.util.httpresourcetree import create_resource_tree
@@ -99,6 +99,8 @@ from tests.utils import checked_cast, default_config, setupdb
setupdb()
setup_logging()
logger = logging.getLogger(__name__)
TV = TypeVar("TV")
_ExcType = TypeVar("_ExcType", bound=BaseException, covariant=True)
@@ -135,7 +137,7 @@ def around(target: TV) -> Callable[[Callable[Concatenate[S, P], R]], None]:
return _around
_TConfig = TypeVar("_TConfig", Config, RootConfig)
_TConfig = TypeVar("_TConfig", Config, HomeServerConfig)
def deepcopy_config(config: _TConfig) -> _TConfig:
@@ -161,13 +163,13 @@ def deepcopy_config(config: _TConfig) -> _TConfig:
@functools.lru_cache(maxsize=8)
def _parse_config_dict(config: str) -> RootConfig:
def _parse_config_dict(config: str) -> HomeServerConfig:
config_obj = HomeServerConfig()
config_obj.parse_config_dict(json.loads(config), "", "")
return config_obj
def make_homeserver_config_obj(config: Dict[str, Any]) -> RootConfig:
def make_homeserver_config_obj(config: Dict[str, Any]) -> HomeServerConfig:
"""Creates a :class:`HomeServerConfig` instance with the given configuration dict.
This is equivalent to::
@@ -392,8 +394,8 @@ class HomeserverTestCase(TestCase):
hijacking the authentication system to return a fixed user, and then
calling the prepare function.
"""
# We need to share the reactor between the homeserver and all of our test utils.
self.reactor, self.clock = get_clock()
self._hs_args = {"clock": self.clock, "reactor": self.reactor}
self.hs = self.make_homeserver(self.reactor, self.clock)
self.hs.get_datastores().main.tests_allow_no_chain_cover_index = False
@@ -511,7 +513,7 @@ class HomeserverTestCase(TestCase):
Function to be overridden in subclasses.
"""
hs = self.setup_test_homeserver()
hs = self.setup_test_homeserver(reactor=reactor, clock=clock)
return hs
def create_test_resource(self) -> Resource:
@@ -634,7 +636,12 @@ class HomeserverTestCase(TestCase):
)
def setup_test_homeserver(
self, server_name: Optional[str] = None, **kwargs: Any
self,
server_name: Optional[str] = None,
config: Optional[JsonDict] = None,
reactor: Optional[ISynapseReactor] = None,
clock: Optional[Clock] = None,
**extra_homeserver_attributes: Any,
) -> HomeServer:
"""
Set up the test homeserver, meant to be called by the overridable
@@ -647,12 +654,15 @@ class HomeserverTestCase(TestCase):
Returns:
synapse.server.HomeServer
"""
kwargs = dict(kwargs)
kwargs.update(self._hs_args)
if "config" not in kwargs:
if config is None:
config = self.default_config()
else:
config = kwargs["config"]
# The sane default is to use the same reactor and clock as our other test utils
if reactor is None:
reactor = self.reactor
if clock is None:
clock = self.clock
# The server name can be specified using either the `name` argument or a config
# override. The `name` argument takes precedence over any config overrides.
@@ -661,19 +671,24 @@ class HomeserverTestCase(TestCase):
# Parse the config from a config dict into a HomeServerConfig
config_obj = make_homeserver_config_obj(config)
kwargs["config"] = config_obj
# The server name in the config is now `name`, if provided, or the `server_name`
# from a config override, or the default of "test". Whichever it is, we
# construct a homeserver with a matching name.
server_name = config_obj.server.server_name
kwargs["name"] = server_name
async def run_bg_updates() -> None:
with LoggingContext(name="run_bg_updates", server_name=server_name):
self.get_success(stor.db_pool.updates.run_background_updates(False))
hs = setup_test_homeserver(self.addCleanup, **kwargs)
hs = setup_test_homeserver(
cleanup_func=self.addCleanup,
server_name=server_name,
config=config_obj,
reactor=reactor,
clock=clock,
**extra_homeserver_attributes,
)
stor = hs.get_datastores().main
# Run the database background updates, when running against "master".

View File

@@ -26,20 +26,26 @@ from twisted.internet import defer
from synapse.util.caches.deferred_cache import DeferredCache
from tests.server import get_clock
from tests.unittest import TestCase
class DeferredCacheTestCase(TestCase):
def setUp(self) -> None:
super().setUp()
_, self.clock = get_clock()
def test_empty(self) -> None:
cache: DeferredCache[str, int] = DeferredCache(
name="test", server_name="test_server"
name="test", clock=self.clock, server_name="test_server"
)
with self.assertRaises(KeyError):
cache.get("foo")
def test_hit(self) -> None:
cache: DeferredCache[str, int] = DeferredCache(
name="test", server_name="test_server"
name="test", clock=self.clock, server_name="test_server"
)
cache.prefill("foo", 123)
@@ -47,7 +53,7 @@ class DeferredCacheTestCase(TestCase):
def test_hit_deferred(self) -> None:
cache: DeferredCache[str, int] = DeferredCache(
name="test", server_name="test_server"
name="test", clock=self.clock, server_name="test_server"
)
origin_d: "defer.Deferred[int]" = defer.Deferred()
set_d = cache.set("k1", origin_d)
@@ -72,7 +78,7 @@ class DeferredCacheTestCase(TestCase):
def test_callbacks(self) -> None:
"""Invalidation callbacks are called at the right time"""
cache: DeferredCache[str, int] = DeferredCache(
name="test", server_name="test_server"
name="test", clock=self.clock, server_name="test_server"
)
callbacks = set()
@@ -107,7 +113,7 @@ class DeferredCacheTestCase(TestCase):
def test_set_fail(self) -> None:
cache: DeferredCache[str, int] = DeferredCache(
name="test", server_name="test_server"
name="test", clock=self.clock, server_name="test_server"
)
callbacks = set()
@@ -146,7 +152,7 @@ class DeferredCacheTestCase(TestCase):
def test_get_immediate(self) -> None:
cache: DeferredCache[str, int] = DeferredCache(
name="test", server_name="test_server"
name="test", clock=self.clock, server_name="test_server"
)
d1: "defer.Deferred[int]" = defer.Deferred()
cache.set("key1", d1)
@@ -164,7 +170,7 @@ class DeferredCacheTestCase(TestCase):
def test_invalidate(self) -> None:
cache: DeferredCache[Tuple[str], int] = DeferredCache(
name="test", server_name="test_server"
name="test", clock=self.clock, server_name="test_server"
)
cache.prefill(("foo",), 123)
cache.invalidate(("foo",))
@@ -174,7 +180,7 @@ class DeferredCacheTestCase(TestCase):
def test_invalidate_all(self) -> None:
cache: DeferredCache[str, str] = DeferredCache(
name="testcache", server_name="test_server"
name="testcache", clock=self.clock, server_name="test_server"
)
callback_record = [False, False]
@@ -220,6 +226,7 @@ class DeferredCacheTestCase(TestCase):
def test_eviction(self) -> None:
cache: DeferredCache[int, str] = DeferredCache(
name="test",
clock=self.clock,
server_name="test_server",
max_entries=2,
apply_cache_factor_from_config=False,
@@ -238,6 +245,7 @@ class DeferredCacheTestCase(TestCase):
def test_eviction_lru(self) -> None:
cache: DeferredCache[int, str] = DeferredCache(
name="test",
clock=self.clock,
server_name="test_server",
max_entries=2,
apply_cache_factor_from_config=False,
@@ -260,6 +268,7 @@ class DeferredCacheTestCase(TestCase):
def test_eviction_iterable(self) -> None:
cache: DeferredCache[int, List[str]] = DeferredCache(
name="test",
clock=self.clock,
server_name="test_server",
max_entries=3,
apply_cache_factor_from_config=False,

View File

@@ -49,6 +49,7 @@ from synapse.util.caches import descriptors
from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
from tests import unittest
from tests.server import get_clock
from tests.test_utils import get_awaitable_result
logger = logging.getLogger(__name__)
@@ -56,7 +57,10 @@ logger = logging.getLogger(__name__)
def run_on_reactor() -> "Deferred[int]":
d: "Deferred[int]" = Deferred()
cast(IReactorTime, reactor).callLater(0, d.callback, 0)
# mypy ignored here because:
# - this is part of the test infrastructure (outside of Synapse) so tracking
# these calls for for homeserver shutdown doesn't make sense.
cast(IReactorTime, reactor).callLater(0, d.callback, 0) # type: ignore[call-later-not-tracked]
return make_deferred_yieldable(d)
@@ -67,6 +71,7 @@ class DescriptorTestCase(unittest.TestCase):
def __init__(self) -> None:
self.mock = mock.Mock()
self.server_name = "test_server"
_, self.clock = get_clock() # nb must be called this for @cached
@descriptors.cached()
def fn(self, arg1: int, arg2: int) -> str:
@@ -102,6 +107,7 @@ class DescriptorTestCase(unittest.TestCase):
def __init__(self) -> None:
self.mock = mock.Mock()
self.server_name = "test_server"
_, self.clock = get_clock() # nb must be called this for @cached
@descriptors.cached(num_args=1)
def fn(self, arg1: int, arg2: int) -> str:
@@ -148,6 +154,7 @@ class DescriptorTestCase(unittest.TestCase):
def __init__(self) -> None:
self.mock = mock.Mock()
self.server_name = "test_server"
_, self.clock = get_clock() # nb must be called this for @cached
obj = Cls()
obj.mock.return_value = "fish"
@@ -179,6 +186,7 @@ class DescriptorTestCase(unittest.TestCase):
def __init__(self) -> None:
self.mock = mock.Mock()
self.server_name = "test_server"
_, self.clock = get_clock() # nb must be called this for @cached
@descriptors.cached()
def fn(self, arg1: int, kwarg1: int = 2) -> str:
@@ -214,6 +222,7 @@ class DescriptorTestCase(unittest.TestCase):
class Cls:
server_name = "test_server" # nb must be called this for @cached
_, clock = get_clock() # nb must be called this for @cached
@cached()
def fn(self, arg1: int) -> NoReturn:
@@ -239,6 +248,7 @@ class DescriptorTestCase(unittest.TestCase):
result: Optional[Deferred] = None
call_count = 0
server_name = "test_server" # nb must be called this for @cached
_, clock = get_clock() # nb must be called this for @cached
@cached()
def fn(self, arg1: int) -> Deferred:
@@ -293,6 +303,7 @@ class DescriptorTestCase(unittest.TestCase):
class Cls:
server_name = "test_server"
_, clock = get_clock() # nb must be called this for @cached
@descriptors.cached()
def fn(self, arg1: int) -> "Deferred[int]":
@@ -337,6 +348,7 @@ class DescriptorTestCase(unittest.TestCase):
class Cls:
server_name = "test_server"
_, clock = get_clock() # nb must be called this for @cached
@descriptors.cached()
def fn(self, arg1: int) -> Deferred:
@@ -381,6 +393,7 @@ class DescriptorTestCase(unittest.TestCase):
def __init__(self) -> None:
self.mock = mock.Mock()
self.server_name = "test_server"
_, self.clock = get_clock() # nb must be called this for @cached
@descriptors.cached()
def fn(self, arg1: int, arg2: int = 2, arg3: int = 3) -> str:
@@ -419,6 +432,7 @@ class DescriptorTestCase(unittest.TestCase):
def __init__(self) -> None:
self.mock = mock.Mock()
self.server_name = "test_server"
_, self.clock = get_clock() # nb must be called this for @cached
@descriptors.cached(iterable=True)
def fn(self, arg1: int, arg2: int) -> Tuple[str, ...]:
@@ -453,6 +467,7 @@ class DescriptorTestCase(unittest.TestCase):
class Cls:
server_name = "test_server"
_, clock = get_clock() # nb must be called this for @cached
@descriptors.cached(iterable=True)
def fn(self, arg1: int) -> NoReturn:
@@ -476,6 +491,7 @@ class DescriptorTestCase(unittest.TestCase):
class Cls:
server_name = "test_server" # nb must be called this for @cached
_, clock = get_clock() # nb must be called this for @cached
@cached(cache_context=True)
async def func1(self, key: str, cache_context: _CacheContext) -> int:
@@ -504,6 +520,7 @@ class DescriptorTestCase(unittest.TestCase):
class Cls:
server_name = "test_server"
_, clock = get_clock() # nb must be called this for @cached
@cached()
async def fn(self, arg1: int) -> str:
@@ -537,6 +554,7 @@ class DescriptorTestCase(unittest.TestCase):
class Cls:
inner_context_was_finished = False
server_name = "test_server" # nb must be called this for @cached
_, clock = get_clock() # nb must be called this for @cached
@cached()
async def fn(self, arg1: int) -> str:
@@ -583,6 +601,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
def test_passthrough(self) -> Generator["Deferred[Any]", object, None]:
class A:
server_name = "test_server" # nb must be called this for @cached
_, clock = get_clock() # nb must be called this for @cached
@cached()
def func(self, key: str) -> str:
@@ -599,6 +618,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
class A:
server_name = "test_server" # nb must be called this for @cached
_, clock = get_clock() # nb must be called this for @cached
@cached()
def func(self, key: str) -> str:
@@ -619,6 +639,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
class A:
server_name = "test_server" # nb must be called this for @cached
_, clock = get_clock() # nb must be called this for @cached
@cached()
def func(self, key: str) -> str:
@@ -639,6 +660,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
def test_invalidate_missing(self) -> None:
class A:
server_name = "test_server" # nb must be called this for @cached
_, clock = get_clock() # nb must be called this for @cached
@cached()
def func(self, key: str) -> str:
@@ -652,6 +674,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
class A:
server_name = "test_server" # nb must be called this for @cached
_, clock = get_clock() # nb must be called this for @cached
@cached(max_entries=10)
def func(self, key: int) -> int:
@@ -681,6 +704,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
class A:
server_name = "test_server" # nb must be called this for @cached
_, clock = get_clock() # nb must be called this for @cached
@cached()
def func(self, key: str) -> "Deferred[int]":
@@ -701,6 +725,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
class A:
server_name = "test_server" # nb must be called this for @cached
_, clock = get_clock() # nb must be called this for @cached
@cached()
def func(self, key: str) -> str:
@@ -736,6 +761,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
class A:
server_name = "test_server" # nb must be called this for @cached
_, clock = get_clock() # nb must be called this for @cached
@cached(max_entries=2)
def func(self, key: str) -> str:
@@ -775,6 +801,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase):
class A:
server_name = "test_server" # nb must be called this for @cached
_, clock = get_clock() # nb must be called this for @cached
@cached()
def func(self, key: str) -> str:
@@ -824,6 +851,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
def __init__(self) -> None:
self.mock = mock.Mock()
self.server_name = "test_server"
_, self.clock = get_clock() # nb must be called this for @cached
@descriptors.cached()
def fn(self, arg1: int, arg2: int) -> None:
@@ -890,6 +918,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
def __init__(self) -> None:
self.mock = mock.Mock()
self.server_name = "test_server"
_, self.clock = get_clock() # nb must be called this for @cached
@descriptors.cached()
def fn(self, arg1: int) -> None:
@@ -934,6 +963,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
def __init__(self) -> None:
self.mock = mock.Mock()
self.server_name = "test_server"
_, self.clock = get_clock() # nb must be called this for @cached
@descriptors.cached()
def fn(self, arg1: int, arg2: int) -> None:
@@ -975,6 +1005,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
class Cls:
server_name = "test_server" # nb must be called this for @cached
_, clock = get_clock() # nb must be called this for @cached
@cached()
def fn(self, arg1: int) -> None:
@@ -1011,6 +1042,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
class Cls:
inner_context_was_finished = False
server_name = "test_server" # nb must be called this for @cached
_, clock = get_clock() # nb must be called this for @cached
@cached()
def fn(self, arg1: int) -> None:
@@ -1055,6 +1087,7 @@ class CachedListDescriptorTestCase(unittest.TestCase):
class Cls:
server_name = "test_server"
_, clock = get_clock() # nb must be called this for @cached
@descriptors.cached(tree=True)
def fn(self, room_id: str, event_id: str) -> None:

View File

@@ -25,7 +25,6 @@ from parameterized import parameterized_class
from twisted.internet import defer
from twisted.internet.defer import CancelledError, Deferred, ensureDeferred
from twisted.internet.task import Clock
from twisted.python.failure import Failure
from synapse.logging.context import (
@@ -152,7 +151,7 @@ class ObservableDeferredTest(TestCase):
class TimeoutDeferredTest(TestCase):
def setUp(self) -> None:
self.clock = Clock()
self.reactor, self.clock = get_clock()
def test_times_out(self) -> None:
"""Basic test case that checks that the original deferred is cancelled and that
@@ -165,12 +164,16 @@ class TimeoutDeferredTest(TestCase):
cancelled = True
non_completing_d: Deferred = Deferred(canceller)
timing_out_d = timeout_deferred(non_completing_d, 1.0, self.clock)
timing_out_d = timeout_deferred(
deferred=non_completing_d,
timeout=1.0,
clock=self.clock,
)
self.assertNoResult(timing_out_d)
self.assertFalse(cancelled, "deferred was cancelled prematurely")
self.clock.pump((1.0,))
self.reactor.pump((1.0,))
self.assertTrue(cancelled, "deferred was not cancelled by timeout")
self.failureResultOf(timing_out_d, defer.TimeoutError)
@@ -183,11 +186,15 @@ class TimeoutDeferredTest(TestCase):
raise Exception("can't cancel this deferred")
non_completing_d: Deferred = Deferred(canceller)
timing_out_d = timeout_deferred(non_completing_d, 1.0, self.clock)
timing_out_d = timeout_deferred(
deferred=non_completing_d,
timeout=1.0,
clock=self.clock,
)
self.assertNoResult(timing_out_d)
self.clock.pump((1.0,))
self.reactor.pump((1.0,))
self.failureResultOf(timing_out_d, defer.TimeoutError)
@@ -227,7 +234,7 @@ class TimeoutDeferredTest(TestCase):
timing_out_d = timeout_deferred(
deferred=incomplete_d,
timeout=1.0,
reactor=self.clock,
clock=self.clock,
)
self.assertNoResult(timing_out_d)
# We should still be in the logcontext we started in
@@ -243,7 +250,7 @@ class TimeoutDeferredTest(TestCase):
# we're pumping the reactor in the block and return us back to our current
# logcontext after the block.
with PreserveLoggingContext():
self.clock.pump(
self.reactor.pump(
# We only need to pump `1.0` (seconds) as we set
# `timeout_deferred(timeout=1.0)` above
(1.0,)
@@ -264,7 +271,7 @@ class TimeoutDeferredTest(TestCase):
self.assertEqual(current_context(), SENTINEL_CONTEXT)
class _TestException(Exception):
class _TestException(Exception): #
pass
@@ -560,8 +567,8 @@ class AwakenableSleeperTests(TestCase):
"Tests AwakenableSleeper"
def test_sleep(self) -> None:
reactor, _ = get_clock()
sleeper = AwakenableSleeper(reactor)
reactor, clock = get_clock()
sleeper = AwakenableSleeper(clock)
d = defer.ensureDeferred(sleeper.sleep("name", 1000))
@@ -575,8 +582,8 @@ class AwakenableSleeperTests(TestCase):
self.assertTrue(d.called)
def test_explicit_wake(self) -> None:
reactor, _ = get_clock()
sleeper = AwakenableSleeper(reactor)
reactor, clock = get_clock()
sleeper = AwakenableSleeper(clock)
d = defer.ensureDeferred(sleeper.sleep("name", 1000))
@@ -592,8 +599,8 @@ class AwakenableSleeperTests(TestCase):
reactor.advance(0.6)
def test_multiple_sleepers_timeout(self) -> None:
reactor, _ = get_clock()
sleeper = AwakenableSleeper(reactor)
reactor, clock = get_clock()
sleeper = AwakenableSleeper(clock)
d1 = defer.ensureDeferred(sleeper.sleep("name", 1000))
@@ -612,8 +619,8 @@ class AwakenableSleeperTests(TestCase):
self.assertTrue(d2.called)
def test_multiple_sleepers_wake(self) -> None:
reactor, _ = get_clock()
sleeper = AwakenableSleeper(reactor)
reactor, clock = get_clock()
sleeper = AwakenableSleeper(clock)
d1 = defer.ensureDeferred(sleeper.sleep("name", 1000))

View File

@@ -32,13 +32,12 @@ from synapse.util.batching_queue import (
number_queued,
)
from tests.server import get_clock
from tests.unittest import TestCase
from tests.unittest import HomeserverTestCase
class BatchingQueueTestCase(TestCase):
class BatchingQueueTestCase(HomeserverTestCase):
def setUp(self) -> None:
self.clock, hs_clock = get_clock()
super().setUp()
# We ensure that we remove any existing metrics for "test_queue".
try:
@@ -51,8 +50,8 @@ class BatchingQueueTestCase(TestCase):
self._pending_calls: List[Tuple[List[str], defer.Deferred]] = []
self.queue: BatchingQueue[str, str] = BatchingQueue(
name="test_queue",
server_name="test_server",
clock=hs_clock,
hs=self.hs,
clock=self.clock,
process_batch_callback=self._process_queue,
)
@@ -108,7 +107,7 @@ class BatchingQueueTestCase(TestCase):
self.assertFalse(queue_d.called)
# We should see a call to `_process_queue` after a reactor tick.
self.clock.pump([0])
self.reactor.pump([0])
self.assertEqual(len(self._pending_calls), 1)
self.assertEqual(self._pending_calls[0][0], ["foo"])
@@ -134,7 +133,7 @@ class BatchingQueueTestCase(TestCase):
self._assert_metrics(queued=2, keys=1, in_flight=2)
self.clock.pump([0])
self.reactor.pump([0])
# We should see only *one* call to `_process_queue`
self.assertEqual(len(self._pending_calls), 1)
@@ -158,7 +157,7 @@ class BatchingQueueTestCase(TestCase):
self.assertFalse(self._pending_calls)
queue_d1 = defer.ensureDeferred(self.queue.add_to_queue("foo1"))
self.clock.pump([0])
self.reactor.pump([0])
self.assertEqual(len(self._pending_calls), 1)
@@ -185,7 +184,7 @@ class BatchingQueueTestCase(TestCase):
self._assert_metrics(queued=2, keys=1, in_flight=2)
# We should now see a second call to `_process_queue`
self.clock.pump([0])
self.reactor.pump([0])
self.assertEqual(len(self._pending_calls), 1)
self.assertEqual(self._pending_calls[0][0], ["foo2", "foo3"])
self.assertFalse(queue_d2.called)
@@ -206,9 +205,9 @@ class BatchingQueueTestCase(TestCase):
self.assertFalse(self._pending_calls)
queue_d1 = defer.ensureDeferred(self.queue.add_to_queue("foo1", key=1))
self.clock.pump([0])
self.reactor.pump([0])
queue_d2 = defer.ensureDeferred(self.queue.add_to_queue("foo2", key=2))
self.clock.pump([0])
self.reactor.pump([0])
# We queue up another item with key=2 to check that we will keep taking
# things off the queue.
@@ -240,7 +239,7 @@ class BatchingQueueTestCase(TestCase):
self.assertFalse(queue_d3.called)
# We should now see a call `_pending_calls` for `foo3`
self.clock.pump([0])
self.reactor.pump([0])
self.assertEqual(len(self._pending_calls), 1)
self.assertEqual(self._pending_calls[0][0], ["foo3"])
self.assertFalse(queue_d3.called)

View File

@@ -23,12 +23,14 @@
from synapse.util.caches.dictionary_cache import DictionaryCache
from tests import unittest
from tests.server import get_clock
class DictCacheTestCase(unittest.TestCase):
def setUp(self) -> None:
_, clock = get_clock()
self.cache: DictionaryCache[str, str, str] = DictionaryCache(
name="foobar", server_name="test_server", max_entries=10
name="foobar", clock=clock, server_name="test_server", max_entries=10
)
def test_simple_cache_hit_full(self) -> None:

View File

@@ -19,23 +19,23 @@
#
#
from typing import List, cast
from typing import List
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.clock import Clock
from tests.utils import MockClock
from tests.server import get_clock
from .. import unittest
class ExpiringCacheTestCase(unittest.HomeserverTestCase):
def test_get_set(self) -> None:
clock = MockClock()
reactor, clock = get_clock()
cache: ExpiringCache[str, str] = ExpiringCache(
cache_name="test",
server_name="testserver",
clock=cast(Clock, clock),
hs=self.hs,
clock=clock,
max_len=1,
)
@@ -44,11 +44,12 @@ class ExpiringCacheTestCase(unittest.HomeserverTestCase):
self.assertEqual(cache["key"], "value")
def test_eviction(self) -> None:
clock = MockClock()
reactor, clock = get_clock()
cache: ExpiringCache[str, str] = ExpiringCache(
cache_name="test",
server_name="testserver",
clock=cast(Clock, clock),
hs=self.hs,
clock=clock,
max_len=2,
)
@@ -63,11 +64,12 @@ class ExpiringCacheTestCase(unittest.HomeserverTestCase):
self.assertEqual(cache.get("key3"), "value3")
def test_iterable_eviction(self) -> None:
clock = MockClock()
reactor, clock = get_clock()
cache: ExpiringCache[str, List[int]] = ExpiringCache(
cache_name="test",
server_name="testserver",
clock=cast(Clock, clock),
hs=self.hs,
clock=clock,
max_len=5,
iterable=True,
)
@@ -87,25 +89,26 @@ class ExpiringCacheTestCase(unittest.HomeserverTestCase):
self.assertEqual(cache.get("key4"), [6, 7])
def test_time_eviction(self) -> None:
clock = MockClock()
reactor, clock = get_clock()
cache: ExpiringCache[str, int] = ExpiringCache(
cache_name="test",
server_name="testserver",
clock=cast(Clock, clock),
hs=self.hs,
clock=clock,
expiry_ms=1000,
)
cache["key"] = 1
clock.advance_time(0.5)
reactor.advance(0.5)
cache["key2"] = 2
self.assertEqual(cache.get("key"), 1)
self.assertEqual(cache.get("key2"), 2)
clock.advance_time(0.9)
reactor.advance(0.9)
self.assertEqual(cache.get("key"), None)
self.assertEqual(cache.get("key2"), 2)
clock.advance_time(1)
reactor.advance(1)
self.assertEqual(cache.get("key"), None)
self.assertEqual(cache.get("key2"), None)

View File

@@ -66,7 +66,8 @@ class LoggingContextTestCase(unittest.TestCase):
"""
Test `Clock.sleep`
"""
clock = Clock(reactor, server_name="test_server")
# Ignore linter error since we are creating a `Clock` for testing purposes.
clock = Clock(reactor, server_name="test_server") # type: ignore[multiple-internal-clocks]
# Sanity check that we start in the sentinel context
self._check_test_key("sentinel")
@@ -90,7 +91,7 @@ class LoggingContextTestCase(unittest.TestCase):
# so that the test can complete and we see the underlying error.
callback_finished = True
reactor.callLater(0, lambda: defer.ensureDeferred(competing_callback()))
reactor.callLater(0, lambda: defer.ensureDeferred(competing_callback())) # type: ignore[call-later-not-tracked]
with LoggingContext(name="foo", server_name="test_server"):
await clock.sleep(0)
@@ -111,7 +112,8 @@ class LoggingContextTestCase(unittest.TestCase):
"""
Test `Clock.looping_call`
"""
clock = Clock(reactor, server_name="test_server")
# Ignore linter error since we are creating a `Clock` for testing purposes.
clock = Clock(reactor, server_name="test_server") # type: ignore[multiple-internal-clocks]
# Sanity check that we start in the sentinel context
self._check_test_key("sentinel")
@@ -161,7 +163,8 @@ class LoggingContextTestCase(unittest.TestCase):
"""
Test `Clock.looping_call_now`
"""
clock = Clock(reactor, server_name="test_server")
# Ignore linter error since we are creating a `Clock` for testing purposes.
clock = Clock(reactor, server_name="test_server") # type: ignore[multiple-internal-clocks]
# Sanity check that we start in the sentinel context
self._check_test_key("sentinel")
@@ -209,7 +212,8 @@ class LoggingContextTestCase(unittest.TestCase):
"""
Test `Clock.call_later`
"""
clock = Clock(reactor, server_name="test_server")
# Ignore linter error since we are creating a `Clock` for testing purposes.
clock = Clock(reactor, server_name="test_server") # type: ignore[multiple-internal-clocks]
# Sanity check that we start in the sentinel context
self._check_test_key("sentinel")
@@ -261,7 +265,8 @@ class LoggingContextTestCase(unittest.TestCase):
`d.callback(None)` without anything else. See the *Deferred callbacks* section
of docs/log_contexts.md for more details.
"""
clock = Clock(reactor, server_name="test_server")
# Ignore linter error since we are creating a `Clock` for testing purposes.
clock = Clock(reactor, server_name="test_server") # type: ignore[multiple-internal-clocks]
# Sanity check that we start in the sentinel context
self._check_test_key("sentinel")
@@ -318,7 +323,8 @@ class LoggingContextTestCase(unittest.TestCase):
`d.callback(None)` without anything else. See the *Deferred callbacks* section
of docs/log_contexts.md for more details.
"""
clock = Clock(reactor, server_name="test_server")
# Ignore linter error since we are creating a `Clock` for testing purposes.
clock = Clock(reactor, server_name="test_server") # type: ignore[multiple-internal-clocks]
# Sanity check that we start in the sentinel context
self._check_test_key("sentinel")
@@ -379,7 +385,8 @@ class LoggingContextTestCase(unittest.TestCase):
`d.callback(None)` without anything else. See the *Deferred callbacks* section
of docs/log_contexts.md for more details.
"""
clock = Clock(reactor, server_name="test_server")
# Ignore linter error since we are creating a `Clock` for testing purposes.
clock = Clock(reactor, server_name="test_server") # type: ignore[multiple-internal-clocks]
# Sanity check that we start in the sentinel context
self._check_test_key("sentinel")
@@ -450,7 +457,8 @@ class LoggingContextTestCase(unittest.TestCase):
self._check_test_key("sentinel")
async def _test_run_in_background(self, function: Callable[[], object]) -> None:
clock = Clock(reactor, server_name="test_server")
# Ignore linter error since we are creating a `Clock` for testing purposes.
clock = Clock(reactor, server_name="test_server") # type: ignore[multiple-internal-clocks]
# Sanity check that we start in the sentinel context
self._check_test_key("sentinel")
@@ -492,7 +500,8 @@ class LoggingContextTestCase(unittest.TestCase):
@logcontext_clean
async def test_run_in_background_with_blocking_fn(self) -> None:
async def blocking_function() -> None:
await Clock(reactor, server_name="test_server").sleep(0)
# Ignore linter error since we are creating a `Clock` for testing purposes.
await Clock(reactor, server_name="test_server").sleep(0) # type: ignore[multiple-internal-clocks]
await self._test_run_in_background(blocking_function)
@@ -525,7 +534,8 @@ class LoggingContextTestCase(unittest.TestCase):
async def testfunc() -> None:
self._check_test_key("foo")
d = defer.ensureDeferred(Clock(reactor, server_name="test_server").sleep(0))
# Ignore linter error since we are creating a `Clock` for testing purposes.
d = defer.ensureDeferred(Clock(reactor, server_name="test_server").sleep(0)) # type: ignore[multiple-internal-clocks]
self.assertIs(current_context(), SENTINEL_CONTEXT)
await d
self._check_test_key("foo")
@@ -554,7 +564,8 @@ class LoggingContextTestCase(unittest.TestCase):
This will stress the logic around incomplete deferreds in `run_coroutine_in_background`.
"""
clock = Clock(reactor, server_name="test_server")
# Ignore linter error since we are creating a `Clock` for testing purposes.
clock = Clock(reactor, server_name="test_server") # type: ignore[multiple-internal-clocks]
# Sanity check that we start in the sentinel context
self._check_test_key("sentinel")
@@ -645,7 +656,7 @@ class LoggingContextTestCase(unittest.TestCase):
# the synapse rules.
def blocking_function() -> defer.Deferred:
d: defer.Deferred = defer.Deferred()
reactor.callLater(0, d.callback, None)
reactor.callLater(0, d.callback, None) # type: ignore[call-later-not-tracked]
return d
sentinel_context = current_context()
@@ -692,7 +703,7 @@ def _chained_deferred_function() -> defer.Deferred:
def cb(res: object) -> defer.Deferred:
d2: defer.Deferred = defer.Deferred()
reactor.callLater(0, d2.callback, res)
reactor.callLater(0, d2.callback, res) # type: ignore[call-later-not-tracked]
return d2
d.addCallback(cb)

View File

@@ -29,18 +29,28 @@ from synapse.util.caches.lrucache import LruCache, setup_expire_lru_cache_entrie
from synapse.util.caches.treecache import TreeCache
from tests import unittest
from tests.server import get_clock
from tests.unittest import override_config
class LruCacheTestCase(unittest.HomeserverTestCase):
def setUp(self) -> None:
super().setUp()
_, self.clock = get_clock()
def test_get_set(self) -> None:
cache: LruCache[str, str] = LruCache(max_size=1, server_name="test_server")
cache: LruCache[str, str] = LruCache(
max_size=1, clock=self.clock, server_name="test_server"
)
cache["key"] = "value"
self.assertEqual(cache.get("key"), "value")
self.assertEqual(cache["key"], "value")
def test_eviction(self) -> None:
cache: LruCache[int, int] = LruCache(max_size=2, server_name="test_server")
cache: LruCache[int, int] = LruCache(
max_size=2, clock=self.clock, server_name="test_server"
)
cache[1] = 1
cache[2] = 2
@@ -54,7 +64,9 @@ class LruCacheTestCase(unittest.HomeserverTestCase):
self.assertEqual(cache.get(3), 3)
def test_setdefault(self) -> None:
cache: LruCache[str, int] = LruCache(max_size=1, server_name="test_server")
cache: LruCache[str, int] = LruCache(
max_size=1, clock=self.clock, server_name="test_server"
)
self.assertEqual(cache.setdefault("key", 1), 1)
self.assertEqual(cache.get("key"), 1)
self.assertEqual(cache.setdefault("key", 2), 1)
@@ -63,7 +75,9 @@ class LruCacheTestCase(unittest.HomeserverTestCase):
self.assertEqual(cache.get("key"), 2)
def test_pop(self) -> None:
cache: LruCache[str, int] = LruCache(max_size=1, server_name="test_server")
cache: LruCache[str, int] = LruCache(
max_size=1, clock=self.clock, server_name="test_server"
)
cache["key"] = 1
self.assertEqual(cache.pop("key"), 1)
self.assertEqual(cache.pop("key"), None)
@@ -71,7 +85,10 @@ class LruCacheTestCase(unittest.HomeserverTestCase):
def test_del_multi(self) -> None:
# The type here isn't quite correct as they don't handle TreeCache well.
cache: LruCache[Tuple[str, str], str] = LruCache(
max_size=4, cache_type=TreeCache, server_name="test_server"
max_size=4,
clock=self.clock,
cache_type=TreeCache,
server_name="test_server",
)
cache[("animal", "cat")] = "mew"
cache[("animal", "dog")] = "woof"
@@ -91,7 +108,9 @@ class LruCacheTestCase(unittest.HomeserverTestCase):
# Man from del_multi say "Yes".
def test_clear(self) -> None:
cache: LruCache[str, int] = LruCache(max_size=1, server_name="test_server")
cache: LruCache[str, int] = LruCache(
max_size=1, clock=self.clock, server_name="test_server"
)
cache["key"] = 1
cache.clear()
self.assertEqual(len(cache), 0)
@@ -99,7 +118,10 @@ class LruCacheTestCase(unittest.HomeserverTestCase):
@override_config({"caches": {"per_cache_factors": {"mycache": 10}}})
def test_special_size(self) -> None:
cache: LruCache = LruCache(
max_size=10, server_name="test_server", cache_name="mycache"
max_size=10,
clock=self.clock,
server_name="test_server",
cache_name="mycache",
)
self.assertEqual(cache.max_size, 100)
@@ -107,7 +129,9 @@ class LruCacheTestCase(unittest.HomeserverTestCase):
class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
def test_get(self) -> None:
m = Mock()
cache: LruCache[str, str] = LruCache(max_size=1, server_name="test_server")
cache: LruCache[str, str] = LruCache(
max_size=1, clock=self.clock, server_name="test_server"
)
cache.set("key", "value")
self.assertFalse(m.called)
@@ -126,7 +150,9 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
def test_multi_get(self) -> None:
m = Mock()
cache: LruCache[str, str] = LruCache(max_size=1, server_name="test_server")
cache: LruCache[str, str] = LruCache(
max_size=1, clock=self.clock, server_name="test_server"
)
cache.set("key", "value")
self.assertFalse(m.called)
@@ -145,7 +171,9 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
def test_set(self) -> None:
m = Mock()
cache: LruCache[str, str] = LruCache(max_size=1, server_name="test_server")
cache: LruCache[str, str] = LruCache(
max_size=1, clock=self.clock, server_name="test_server"
)
cache.set("key", "value", callbacks=[m])
self.assertFalse(m.called)
@@ -161,7 +189,9 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
def test_pop(self) -> None:
m = Mock()
cache: LruCache[str, str] = LruCache(max_size=1, server_name="test_server")
cache: LruCache[str, str] = LruCache(
max_size=1, clock=self.clock, server_name="test_server"
)
cache.set("key", "value", callbacks=[m])
self.assertFalse(m.called)
@@ -182,7 +212,10 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
m4 = Mock()
# The type here isn't quite correct as they don't handle TreeCache well.
cache: LruCache[Tuple[str, str], str] = LruCache(
max_size=4, cache_type=TreeCache, server_name="test_server"
max_size=4,
clock=self.clock,
cache_type=TreeCache,
server_name="test_server",
)
cache.set(("a", "1"), "value", callbacks=[m1])
@@ -205,7 +238,9 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
def test_clear(self) -> None:
m1 = Mock()
m2 = Mock()
cache: LruCache[str, str] = LruCache(max_size=5, server_name="test_server")
cache: LruCache[str, str] = LruCache(
max_size=5, clock=self.clock, server_name="test_server"
)
cache.set("key1", "value", callbacks=[m1])
cache.set("key2", "value", callbacks=[m2])
@@ -222,7 +257,9 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
m1 = Mock(name="m1")
m2 = Mock(name="m2")
m3 = Mock(name="m3")
cache: LruCache[str, str] = LruCache(max_size=2, server_name="test_server")
cache: LruCache[str, str] = LruCache(
max_size=2, clock=self.clock, server_name="test_server"
)
cache.set("key1", "value", callbacks=[m1])
cache.set("key2", "value", callbacks=[m2])
@@ -259,7 +296,7 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
class LruCacheSizedTestCase(unittest.HomeserverTestCase):
def test_evict(self) -> None:
cache: LruCache[str, List[int]] = LruCache(
max_size=5, size_callback=len, server_name="test_server"
max_size=5, clock=self.clock, size_callback=len, server_name="test_server"
)
cache["key1"] = [0]
cache["key2"] = [1, 2]
@@ -284,7 +321,10 @@ class LruCacheSizedTestCase(unittest.HomeserverTestCase):
def test_zero_size_drop_from_cache(self) -> None:
"""Test that `drop_from_cache` works correctly with 0-sized entries."""
cache: LruCache[str, List[int]] = LruCache(
max_size=5, size_callback=lambda x: 0, server_name="test_server"
max_size=5,
clock=self.clock,
size_callback=lambda x: 0,
server_name="test_server",
)
cache["key1"] = []
@@ -402,7 +442,10 @@ class MemoryEvictionTestCase(unittest.HomeserverTestCase):
class ExtraIndexLruCacheTestCase(unittest.HomeserverTestCase):
def test_invalidate_simple(self) -> None:
cache: LruCache[str, int] = LruCache(
max_size=10, server_name="test_server", extra_index_cb=lambda k, v: str(v)
max_size=10,
clock=self.hs.get_clock(),
server_name="test_server",
extra_index_cb=lambda k, v: str(v),
)
cache["key1"] = 1
cache["key2"] = 2
@@ -417,7 +460,10 @@ class ExtraIndexLruCacheTestCase(unittest.HomeserverTestCase):
def test_invalidate_multi(self) -> None:
cache: LruCache[str, int] = LruCache(
max_size=10, server_name="test_server", extra_index_cb=lambda k, v: str(v)
max_size=10,
clock=self.hs.get_clock(),
server_name="test_server",
extra_index_cb=lambda k, v: str(v),
)
cache["key1"] = 1
cache["key2"] = 1

View File

@@ -35,6 +35,7 @@ class RetryLimiterTestCase(HomeserverTestCase):
get_retry_limiter(
destination="test_dest",
our_server_name=self.hs.hostname,
hs=self.hs,
clock=self.clock,
store=store,
)
@@ -57,6 +58,7 @@ class RetryLimiterTestCase(HomeserverTestCase):
get_retry_limiter(
destination="test_dest",
our_server_name=self.hs.hostname,
hs=self.hs,
clock=self.clock,
store=store,
)
@@ -89,6 +91,7 @@ class RetryLimiterTestCase(HomeserverTestCase):
get_retry_limiter(
destination="test_dest",
our_server_name=self.hs.hostname,
hs=self.hs,
clock=self.clock,
store=store,
),
@@ -104,6 +107,7 @@ class RetryLimiterTestCase(HomeserverTestCase):
get_retry_limiter(
destination="test_dest",
our_server_name=self.hs.hostname,
hs=self.hs,
clock=self.clock,
store=store,
)
@@ -139,6 +143,7 @@ class RetryLimiterTestCase(HomeserverTestCase):
get_retry_limiter(
destination="test_dest",
our_server_name=self.hs.hostname,
hs=self.hs,
clock=self.clock,
store=store,
)
@@ -165,6 +170,7 @@ class RetryLimiterTestCase(HomeserverTestCase):
get_retry_limiter(
destination="test_dest",
our_server_name=self.hs.hostname,
hs=self.hs,
clock=self.clock,
store=store,
notifier=notifier,
@@ -238,6 +244,7 @@ class RetryLimiterTestCase(HomeserverTestCase):
get_retry_limiter(
destination="test_dest",
our_server_name=self.hs.hostname,
hs=self.hs,
clock=self.clock,
store=store,
)
@@ -261,6 +268,7 @@ class RetryLimiterTestCase(HomeserverTestCase):
get_retry_limiter(
destination="test_dest",
our_server_name=self.hs.hostname,
hs=self.hs,
clock=self.clock,
store=store,
),
@@ -273,6 +281,7 @@ class RetryLimiterTestCase(HomeserverTestCase):
get_retry_limiter(
destination="test_dest",
our_server_name=self.hs.hostname,
hs=self.hs,
clock=self.clock,
store=store,
)
@@ -297,6 +306,7 @@ class RetryLimiterTestCase(HomeserverTestCase):
get_retry_limiter(
destination="test_dest",
our_server_name=self.hs.hostname,
hs=self.hs,
clock=self.clock,
store=store,
),

View File

@@ -24,27 +24,19 @@ import os
import signal
from types import FrameType, TracebackType
from typing import (
Any,
Callable,
Dict,
List,
Literal,
Optional,
Tuple,
Type,
TypeVar,
Union,
overload,
)
import attr
from typing_extensions import ParamSpec
from synapse.api.constants import EventTypes
from synapse.api.room_versions import RoomVersions
from synapse.config.homeserver import HomeServerConfig
from synapse.config.server import DEFAULT_ROOM_VERSION
from synapse.logging.context import current_context, set_current_context
from synapse.server import HomeServer
from synapse.storage.database import LoggingDatabaseConnection
from synapse.storage.engines import create_engine
@@ -140,21 +132,27 @@ def setupdb() -> None:
@overload
def default_config(name: str, parse: Literal[False] = ...) -> Dict[str, object]: ...
def default_config(
server_name: str, parse: Literal[False] = ...
) -> Dict[str, object]: ...
@overload
def default_config(name: str, parse: Literal[True]) -> HomeServerConfig: ...
def default_config(server_name: str, parse: Literal[True]) -> HomeServerConfig: ...
def default_config(
name: str, parse: bool = False
server_name: str, parse: bool = False
) -> Union[Dict[str, object], HomeServerConfig]:
"""
Create a reasonable test config.
Args:
server_name: homeserver name
parse: TODO
"""
config_dict = {
"server_name": name,
"server_name": server_name,
# Setting this to an empty list turns off federation sending.
"federation_sender_instances": [],
"media_store_path": "media",
@@ -247,101 +245,6 @@ def mock_getRawHeaders(headers=None): # type: ignore[no-untyped-def]
return getRawHeaders
P = ParamSpec("P")
@attr.s(slots=True, auto_attribs=True)
class Timer:
absolute_time: float
callback: Callable[[], None]
expired: bool
# TODO: Make this generic over a ParamSpec?
@attr.s(slots=True, auto_attribs=True)
class Looper:
func: Callable[..., Any]
interval: float # seconds
last: float
args: Tuple[object, ...]
kwargs: Dict[str, object]
class MockClock:
now = 1000.0
def __init__(self) -> None:
# Timers in no particular order
self.timers: List[Timer] = []
self.loopers: List[Looper] = []
def time(self) -> float:
return self.now
def time_msec(self) -> int:
return int(self.time() * 1000)
def call_later(
self,
delay: float,
callback: Callable[P, object],
*args: P.args,
**kwargs: P.kwargs,
) -> Timer:
ctx = current_context()
def wrapped_callback() -> None:
set_current_context(ctx)
callback(*args, **kwargs)
t = Timer(self.now + delay, wrapped_callback, False)
self.timers.append(t)
return t
def looping_call(
self,
function: Callable[P, object],
interval: float,
*args: P.args,
**kwargs: P.kwargs,
) -> None:
self.loopers.append(Looper(function, interval / 1000.0, self.now, args, kwargs))
def cancel_call_later(self, timer: Timer, ignore_errs: bool = False) -> None:
if timer.expired:
if not ignore_errs:
raise Exception("Cannot cancel an expired timer")
timer.expired = True
self.timers = [t for t in self.timers if t != timer]
# For unit testing
def advance_time(self, secs: float) -> None:
self.now += secs
timers = self.timers
self.timers = []
for t in timers:
if t.expired:
raise Exception("Timer already expired")
if self.now >= t.absolute_time:
t.expired = True
t.callback()
else:
self.timers.append(t)
for looped in self.loopers:
if looped.last + looped.interval < self.now:
looped.func(*looped.args, **looped.kwargs)
looped.last = self.now
def advance_time_msec(self, ms: float) -> None:
self.advance_time(ms / 1000.0)
async def create_room(hs: HomeServer, room_id: str, creator_id: str) -> None:
"""Creates and persist a creation event for the given room"""