Compare commits: v1.140.0rc...dmr/typing (24 commits)
| Author | SHA1 | Date |
|---|---|---|
| | d59916b06a | |
| | e007d87eb1 | |
| | 0ee260f4d4 | |
| | 0f7af76d45 | |
| | ca41febb7e | |
| | 9a42f0785f | |
| | 0d78696b21 | |
| | a8f7a3ac07 | |
| | 31e42e9b74 | |
| | 436081dfcd | |
| | 78f24160bf | |
| | d091b71ea2 | |
| | 92af1306db | |
| | 97d72ef11a | |
| | 395d624e12 | |
| | 3338364dbe | |
| | d6ff8bdf96 | |
| | a07c7335cc | |
| | 39d547bf74 | |
| | fb5809aca4 | |
| | 575e1c4309 | |
| | 29f721f4b2 | |
| | 797b2bd9c7 | |
| | a95ab83713 | |
changelog.d/13028.misc (new file, 1 line)
@@ -0,0 +1 @@
+Add type annotations to `tests.utils`.
mypy.ini (9 lines changed)
@@ -73,9 +73,6 @@ exclude = (?x)
   |tests/util/test_linearizer.py
   |tests/util/test_logcontext.py
   |tests/util/test_lrucache.py
   |tests/util/test_rwlock.py
   |tests/util/test_wheel_timer.py
   |tests/utils.py
   )$

[mypy-synapse.federation.transport.client]
@@ -129,6 +126,12 @@ disallow_untyped_defs = True
[mypy-tests.federation.transport.test_client]
disallow_untyped_defs = True

+[mypy-tests.util.*]
+disallow_untyped_defs = True
+
+[mypy-tests.utils]
+disallow_untyped_defs = True

;; Dependencies without annotations
;; Before ignoring a module, check to see if type stubs are available.
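For illustration, a hypothetical Python module under one of the new `disallow_untyped_defs = True` sections: mypy reports the unannotated definition and accepts the fully annotated one. This is a sketch, not part of the changeset.

# Hypothetical example module covered by a `disallow_untyped_defs = True` section.

def scale(value, factor=2):  # flagged by mypy: missing type annotations
    return value * factor


def scale_typed(value: int, factor: int = 2) -> int:  # accepted: fully annotated
    return value * factor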
@@ -20,6 +20,7 @@ from typing import Any, Callable, Dict, Generator, Optional
import attr
from frozendict import frozendict
from matrix_common.versionstring import get_distribution_version_string
+from typing_extensions import ParamSpec

from twisted.internet import defer, task
from twisted.internet.defer import Deferred
@@ -82,6 +83,9 @@ def unwrapFirstError(failure: Failure) -> Failure:
    return failure.value.subFailure  # type: ignore[union-attr] # Issue in Twisted's annotations


+P = ParamSpec("P")


@attr.s(slots=True)
class Clock:
    """
@@ -110,7 +114,7 @@ class Clock:
        return int(self.time() * 1000)

    def looping_call(
-        self, f: Callable, msec: float, *args: Any, **kwargs: Any
+        self, f: Callable[P, object], msec: float, *args: P.args, **kwargs: P.kwargs
    ) -> LoopingCall:
        """Call a function repeatedly.
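The `Callable[P, object]` signature above relies on ParamSpec. A minimal hypothetical sketch of the pattern (assuming `typing_extensions` is installed, as in this codebase), showing how a wrapper's `*args`/`**kwargs` are checked against the wrapped callable:

from typing import Callable, TypeVar

from typing_extensions import ParamSpec

P = ParamSpec("P")
R = TypeVar("R")


def call_with(f: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
    # The ParamSpec ties *args/**kwargs to f's real parameters, so mypy
    # checks call sites of call_with() against f's signature.
    return f(*args, **kwargs)


def greet(name: str, excited: bool = False) -> str:
    return f"Hello {name}" + ("!" if excited else "")


call_with(greet, "Alice", excited=True)  # accepted
# call_with(greet, 123)  # rejected by mypy: int is not compatible with str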
@@ -109,7 +109,7 @@ GLOBAL_ROOT = ListNode["_Node"].create_root_node()

@wrap_as_background_process("LruCache._expire_old_entries")
async def _expire_old_entries(
-    clock: Clock, expiry_seconds: int, autotune_config: Optional[dict]
+    clock: Clock, expiry_seconds: float, autotune_config: Optional[dict]
) -> None:
    """Walks the global cache list to find cache entries that haven't been
    accessed in the given number of seconds, or if a given memory threshold has been breached.
@@ -1,5 +1,5 @@
from contextlib import contextmanager
-from typing import Generator, Optional
+from typing import Generator, NoReturn, Optional
from unittest.mock import patch

from synapse.util.check_dependencies import (
@@ -12,17 +12,19 @@ from tests.unittest import TestCase


class DummyDistribution(metadata.Distribution):
-    def __init__(self, version: object):
+    def __init__(self, version: Optional[str]):
        self._version = version

+    # Type-ignore: This really can return None. More context here:
+    # https://github.com/python/importlib_metadata/issues/371
    @property
-    def version(self):
+    def version(self) -> Optional[str]:  # type: ignore[override]
        return self._version

-    def locate_file(self, path):
+    def locate_file(self, path: object) -> NoReturn:
        raise NotImplementedError()

-    def read_text(self, filename):
+    def read_text(self, filename: object) -> NoReturn:
        raise NotImplementedError()


@@ -42,7 +44,7 @@ class TestDependencyChecker(TestCase):
    ) -> Generator[None, None, None]:
        """Pretend that looking up any distribution yields the given `distribution`."""

-        def mock_distribution(name: str):
+        def mock_distribution(name: str) -> DummyDistribution:
            if distribution is None:
                raise metadata.PackageNotFoundError
            else:
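A minimal hypothetical sketch of the `NoReturn` idiom used for the stub methods above: annotating a function that can only raise tells the type checker that code after a call to it is unreachable.

from typing import NoReturn


def unsupported(feature: str) -> NoReturn:
    # Never returns normally; type checkers treat code after a call as unreachable.
    raise NotImplementedError(f"{feature} is not supported")


def handle(kind: str) -> int:
    if kind == "known":
        return 1
    unsupported(kind)
    # No "missing return" error here: mypy knows this point is unreachable.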
@@ -19,7 +19,7 @@ from tests.unittest import TestCase


class ChunkSeqTests(TestCase):
-    def test_short_seq(self):
+    def test_short_seq(self) -> None:
        parts = chunk_seq("123", 8)

        self.assertEqual(
@@ -27,7 +27,7 @@ class ChunkSeqTests(TestCase):
            ["123"],
        )

-    def test_long_seq(self):
+    def test_long_seq(self) -> None:
        parts = chunk_seq("abcdefghijklmnop", 8)

        self.assertEqual(
@@ -35,7 +35,7 @@ class ChunkSeqTests(TestCase):
            ["abcdefgh", "ijklmnop"],
        )

-    def test_uneven_parts(self):
+    def test_uneven_parts(self) -> None:
        parts = chunk_seq("abcdefghijklmnop", 5)

        self.assertEqual(
@@ -43,7 +43,7 @@ class ChunkSeqTests(TestCase):
            ["abcde", "fghij", "klmno", "p"],
        )

-    def test_empty_input(self):
+    def test_empty_input(self) -> None:
        parts: Iterable[Sequence] = chunk_seq([], 5)

        self.assertEqual(
@@ -53,13 +53,13 @@ class ChunkSeqTests(TestCase):


class SortTopologically(TestCase):
-    def test_empty(self):
+    def test_empty(self) -> None:
        "Test that an empty graph works correctly"

        graph: Dict[int, List[int]] = {}
        self.assertEqual(list(sorted_topologically([], graph)), [])

-    def test_handle_empty_graph(self):
+    def test_handle_empty_graph(self) -> None:
        "Test that a graph where a node doesn't have an entry is treated as empty"

        graph: Dict[int, List[int]] = {}
@@ -67,7 +67,7 @@ class SortTopologically(TestCase):
        # For disconnected nodes the output is simply sorted.
        self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2])

-    def test_disconnected(self):
+    def test_disconnected(self) -> None:
        "Test that a graph with no edges work"

        graph: Dict[int, List[int]] = {1: [], 2: []}
@@ -75,20 +75,20 @@ class SortTopologically(TestCase):
        # For disconnected nodes the output is simply sorted.
        self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2])

-    def test_linear(self):
+    def test_linear(self) -> None:
        "Test that a simple `4 -> 3 -> 2 -> 1` graph works"

        graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [2], 4: [3]}

        self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])

-    def test_subset(self):
+    def test_subset(self) -> None:
        "Test that only sorting a subset of the graph works"
        graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [2], 4: [3]}

        self.assertEqual(list(sorted_topologically([4, 3], graph)), [3, 4])

-    def test_fork(self):
+    def test_fork(self) -> None:
        "Test that a forked graph works"
        graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [1], 4: [2, 3]}

@@ -96,13 +96,13 @@ class SortTopologically(TestCase):
        # always get the same one.
        self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])

-    def test_duplicates(self):
+    def test_duplicates(self) -> None:
        "Test that a graph with duplicate edges work"
        graph: Dict[int, List[int]] = {1: [], 2: [1, 1], 3: [2, 2], 4: [3]}

        self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])

-    def test_multiple_paths(self):
+    def test_multiple_paths(self) -> None:
        "Test that a graph with multiple paths between two nodes work"
        graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [2], 4: [3, 2, 1]}
@@ -23,7 +23,7 @@ class TestException(Exception):


class LogFormatterTestCase(unittest.TestCase):
-    def test_formatter(self):
+    def test_formatter(self) -> None:
        formatter = LogFormatter()

        try:
@@ -13,16 +13,19 @@
# limitations under the License.
from typing import Optional

from twisted.internet.defer import Deferred

from synapse.config.homeserver import HomeServerConfig
from synapse.config.ratelimiting import FederationRateLimitConfig
from synapse.util.ratelimitutils import FederationRateLimiter

-from tests.server import get_clock
+from tests.server import ThreadedMemoryReactorClock, get_clock
from tests.unittest import TestCase
from tests.utils import default_config


class FederationRateLimiterTestCase(TestCase):
-    def test_ratelimit(self):
+    def test_ratelimit(self) -> None:
        """A simple test with the default values"""
        reactor, clock = get_clock()
        rc_config = build_rc_config()
@@ -32,7 +35,7 @@ class FederationRateLimiterTestCase(TestCase):
        # shouldn't block
        self.successResultOf(d1)

-    def test_concurrent_limit(self):
+    def test_concurrent_limit(self) -> None:
        """Test what happens when we hit the concurrent limit"""
        reactor, clock = get_clock()
        rc_config = build_rc_config({"rc_federation": {"concurrent": 2}})
@@ -56,7 +59,7 @@ class FederationRateLimiterTestCase(TestCase):
        cm2.__exit__(None, None, None)
        self.successResultOf(d3)

-    def test_sleep_limit(self):
+    def test_sleep_limit(self) -> None:
        """Test what happens when we hit the sleep limit"""
        reactor, clock = get_clock()
        rc_config = build_rc_config(
@@ -79,7 +82,9 @@ class FederationRateLimiterTestCase(TestCase):
        self.assertAlmostEqual(sleep_time, 500, places=3)


-def _await_resolution(reactor, d):
+def _await_resolution(
+    reactor: ThreadedMemoryReactorClock, d: "Deferred[None]"
+) -> float:
    """advance the clock until the deferred completes.

    Returns the number of milliseconds it took to complete.
@@ -90,7 +95,7 @@ def _await_resolution(reactor, d):
    return (reactor.seconds() - start_time) * 1000


-def build_rc_config(settings: Optional[dict] = None):
+def build_rc_config(settings: Optional[dict] = None) -> FederationRateLimitConfig:
    config_dict = default_config("test")
    config_dict.update(settings or {})
    config = HomeServerConfig()
@@ -22,7 +22,7 @@ from tests.unittest import HomeserverTestCase


class RetryLimiterTestCase(HomeserverTestCase):
-    def test_new_destination(self):
+    def test_new_destination(self) -> None:
        """A happy-path case with a new destination and a successful operation"""
        store = self.hs.get_datastores().main
        limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))
@@ -36,7 +36,7 @@ class RetryLimiterTestCase(HomeserverTestCase):
        new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
        self.assertIsNone(new_timings)

-    def test_limiter(self):
+    def test_limiter(self) -> None:
        """General test case which walks through the process of a failing request"""
        store = self.hs.get_datastores().main
@@ -49,7 +49,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
        acquired_d: "Deferred[None]" = Deferred()
        unblock_d: "Deferred[None]" = Deferred()

-        async def reader_or_writer():
+        async def reader_or_writer() -> str:
            async with read_or_write(key):
                acquired_d.callback(None)
                await unblock_d
@@ -134,7 +134,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
                d.called, msg="deferred %d was unexpectedly resolved" % (i + n)
            )

-    def test_rwlock(self):
+    def test_rwlock(self) -> None:
        rwlock = ReadWriteLock()
        key = "key"

@@ -197,7 +197,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
        _, acquired_d = self._start_nonblocking_reader(rwlock, key, "last reader")
        self.assertTrue(acquired_d.called)

-    def test_lock_handoff_to_nonblocking_writer(self):
+    def test_lock_handoff_to_nonblocking_writer(self) -> None:
        """Test a writer handing the lock to another writer that completes instantly."""
        rwlock = ReadWriteLock()
        key = "key"
@@ -216,7 +216,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
        d3, _ = self._start_nonblocking_writer(rwlock, key, "write 3 completed")
        self.assertTrue(d3.called)

-    def test_cancellation_while_holding_read_lock(self):
+    def test_cancellation_while_holding_read_lock(self) -> None:
        """Test cancellation while holding a read lock.

        A waiting writer should be given the lock when the reader holding the lock is
@@ -242,7 +242,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
        )
        self.assertEqual("write completed", self.successResultOf(writer_d))

-    def test_cancellation_while_holding_write_lock(self):
+    def test_cancellation_while_holding_write_lock(self) -> None:
        """Test cancellation while holding a write lock.

        A waiting reader should be given the lock when the writer holding the lock is
@@ -268,7 +268,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
        )
        self.assertEqual("read completed", self.successResultOf(reader_d))

-    def test_cancellation_while_waiting_for_read_lock(self):
+    def test_cancellation_while_waiting_for_read_lock(self) -> None:
        """Test cancellation while waiting for a read lock.

        Tests that cancelling a waiting reader:
@@ -319,7 +319,7 @@ class ReadWriteLockTestCase(unittest.TestCase):
        )
        self.assertEqual("write 2 completed", self.successResultOf(writer2_d))

-    def test_cancellation_while_waiting_for_write_lock(self):
+    def test_cancellation_while_waiting_for_write_lock(self) -> None:
        """Test cancellation while waiting for a write lock.

        Tests that cancelling a waiting writer:
@@ -8,7 +8,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
    Tests for StreamChangeCache.
    """

-    def test_prefilled_cache(self):
+    def test_prefilled_cache(self) -> None:
        """
        Providing a prefilled cache to StreamChangeCache will result in a cache
        with the prefilled-cache entered in.
@@ -16,7 +16,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
        cache = StreamChangeCache("#test", 1, prefilled_cache={"user@foo.com": 2})
        self.assertTrue(cache.has_entity_changed("user@foo.com", 1))

-    def test_has_entity_changed(self):
+    def test_has_entity_changed(self) -> None:
        """
        StreamChangeCache.entity_has_changed will mark entities as changed, and
        has_entity_changed will observe the changed entities.
@@ -52,7 +52,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
        self.assertTrue(cache.has_entity_changed("user@foo.com", 0))
        self.assertTrue(cache.has_entity_changed("not@here.website", 0))

-    def test_entity_has_changed_pops_off_start(self):
+    def test_entity_has_changed_pops_off_start(self) -> None:
        """
        StreamChangeCache.entity_has_changed will respect the max size and
        purge the oldest items upon reaching that max size.
@@ -86,7 +86,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
        )
        self.assertIsNone(cache.get_all_entities_changed(1))

-    def test_get_all_entities_changed(self):
+    def test_get_all_entities_changed(self) -> None:
        """
        StreamChangeCache.get_all_entities_changed will return all changed
        entities since the given position. If the position is before the start
@@ -142,7 +142,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
        r = cache.get_all_entities_changed(3)
        self.assertTrue(r == ok1 or r == ok2)

-    def test_has_any_entity_changed(self):
+    def test_has_any_entity_changed(self) -> None:
        """
        StreamChangeCache.has_any_entity_changed will return True if any
        entities have been changed since the provided stream position, and
@@ -168,7 +168,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
        self.assertFalse(cache.has_any_entity_changed(2))
        self.assertFalse(cache.has_any_entity_changed(3))

-    def test_get_entities_changed(self):
+    def test_get_entities_changed(self) -> None:
        """
        StreamChangeCache.get_entities_changed will return the entities in the
        given list that have changed since the provided stream ID. If the
@@ -228,7 +228,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
            {"bar@baz.net"},
        )

-    def test_max_pos(self):
+    def test_max_pos(self) -> None:
        """
        StreamChangeCache.get_max_pos_of_last_change will return the most
        recent point where the entity could have changed. If the entity is not
@@ -19,7 +19,7 @@ from .. import unittest


class StringUtilsTestCase(unittest.TestCase):
-    def test_client_secret_regex(self):
+    def test_client_secret_regex(self) -> None:
        """Ensure that client_secret does not contain illegal characters"""
        good = [
            "abcde12345",
@@ -46,7 +46,7 @@ class StringUtilsTestCase(unittest.TestCase):
            with self.assertRaises(SynapseError):
                assert_valid_client_secret(client_secret)

-    def test_base62_encode(self):
+    def test_base62_encode(self) -> None:
        self.assertEqual("0", base62_encode(0))
        self.assertEqual("10", base62_encode(62))
        self.assertEqual("1c", base62_encode(100))
@@ -18,31 +18,31 @@ from tests.unittest import HomeserverTestCase


class CanonicaliseEmailTests(HomeserverTestCase):
-    def test_no_at(self):
+    def test_no_at(self) -> None:
        with self.assertRaises(ValueError):
            canonicalise_email("address-without-at.bar")

-    def test_two_at(self):
+    def test_two_at(self) -> None:
        with self.assertRaises(ValueError):
            canonicalise_email("foo@foo@test.bar")

-    def test_bad_format(self):
+    def test_bad_format(self) -> None:
        with self.assertRaises(ValueError):
            canonicalise_email("user@bad.example.net@good.example.com")

-    def test_valid_format(self):
+    def test_valid_format(self) -> None:
        self.assertEqual(canonicalise_email("foo@test.bar"), "foo@test.bar")

-    def test_domain_to_lower(self):
+    def test_domain_to_lower(self) -> None:
        self.assertEqual(canonicalise_email("foo@TEST.BAR"), "foo@test.bar")

-    def test_domain_with_umlaut(self):
+    def test_domain_with_umlaut(self) -> None:
        self.assertEqual(canonicalise_email("foo@Öumlaut.com"), "foo@öumlaut.com")

-    def test_address_casefold(self):
+    def test_address_casefold(self) -> None:
        self.assertEqual(
            canonicalise_email("Strauß@Example.com"), "strauss@example.com"
        )

-    def test_address_trim(self):
+    def test_address_trim(self) -> None:
        self.assertEqual(canonicalise_email(" foo@test.bar "), "foo@test.bar")
@@ -19,7 +19,7 @@ from .. import unittest


class TreeCacheTestCase(unittest.TestCase):
-    def test_get_set_onelevel(self):
+    def test_get_set_onelevel(self) -> None:
        cache = TreeCache()
        cache[("a",)] = "A"
        cache[("b",)] = "B"
@@ -27,7 +27,7 @@ class TreeCacheTestCase(unittest.TestCase):
        self.assertEqual(cache.get(("b",)), "B")
        self.assertEqual(len(cache), 2)

-    def test_pop_onelevel(self):
+    def test_pop_onelevel(self) -> None:
        cache = TreeCache()
        cache[("a",)] = "A"
        cache[("b",)] = "B"
@@ -36,7 +36,7 @@ class TreeCacheTestCase(unittest.TestCase):
        self.assertEqual(cache.get(("b",)), "B")
        self.assertEqual(len(cache), 1)

-    def test_get_set_twolevel(self):
+    def test_get_set_twolevel(self) -> None:
        cache = TreeCache()
        cache[("a", "a")] = "AA"
        cache[("a", "b")] = "AB"
@@ -46,7 +46,7 @@ class TreeCacheTestCase(unittest.TestCase):
        self.assertEqual(cache.get(("b", "a")), "BA")
        self.assertEqual(len(cache), 3)

-    def test_pop_twolevel(self):
+    def test_pop_twolevel(self) -> None:
        cache = TreeCache()
        cache[("a", "a")] = "AA"
        cache[("a", "b")] = "AB"
@@ -58,7 +58,7 @@ class TreeCacheTestCase(unittest.TestCase):
        self.assertEqual(cache.pop(("b", "a")), None)
        self.assertEqual(len(cache), 1)

-    def test_pop_mixedlevel(self):
+    def test_pop_mixedlevel(self) -> None:
        cache = TreeCache()
        cache[("a", "a")] = "AA"
        cache[("a", "b")] = "AB"
@@ -72,14 +72,14 @@ class TreeCacheTestCase(unittest.TestCase):

        self.assertEqual({"AA", "AB"}, set(iterate_tree_cache_entry(popped)))

-    def test_clear(self):
+    def test_clear(self) -> None:
        cache = TreeCache()
        cache[("a",)] = "A"
        cache[("b",)] = "B"
        cache.clear()
        self.assertEqual(len(cache), 0)

-    def test_contains(self):
+    def test_contains(self) -> None:
        cache = TreeCache()
        cache[("a",)] = "A"
        self.assertTrue(("a",) in cache)
@@ -18,8 +18,8 @@ from .. import unittest


class WheelTimerTestCase(unittest.TestCase):
-    def test_single_insert_fetch(self):
-        wheel = WheelTimer(bucket_size=5)
+    def test_single_insert_fetch(self) -> None:
+        wheel: WheelTimer[object] = WheelTimer(bucket_size=5)

        obj = object()
        wheel.insert(100, obj, 150)
@@ -32,8 +32,8 @@ class WheelTimerTestCase(unittest.TestCase):
        self.assertListEqual(wheel.fetch(156), [obj])
        self.assertListEqual(wheel.fetch(170), [])

-    def test_multi_insert(self):
-        wheel = WheelTimer(bucket_size=5)
+    def test_multi_insert(self) -> None:
+        wheel: WheelTimer[object] = WheelTimer(bucket_size=5)

        obj1 = object()
        obj2 = object()
@@ -50,15 +50,15 @@ class WheelTimerTestCase(unittest.TestCase):
        self.assertListEqual(wheel.fetch(200), [obj3])
        self.assertListEqual(wheel.fetch(210), [])

-    def test_insert_past(self):
-        wheel = WheelTimer(bucket_size=5)
+    def test_insert_past(self) -> None:
+        wheel: WheelTimer["object"] = WheelTimer(bucket_size=5)

        obj = object()
        wheel.insert(100, obj, 50)
        self.assertListEqual(wheel.fetch(120), [obj])

-    def test_insert_past_multi(self):
-        wheel = WheelTimer(bucket_size=5)
+    def test_insert_past_multi(self) -> None:
+        wheel: WheelTimer[object] = WheelTimer(bucket_size=5)

        obj1 = object()
        obj2 = object()
tests/utils.py (194 lines changed)
@@ -15,14 +15,30 @@

import atexit
import os
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    List,
    ParamSpec,
    Tuple,
    Union,
    cast,
    overload,
)

import attr
from typing_extensions import Literal

from synapse.api.constants import EventTypes
from synapse.api.room_versions import RoomVersions
from synapse.config.homeserver import HomeServerConfig
from synapse.config.server import DEFAULT_ROOM_VERSION
from synapse.logging.context import current_context, set_current_context
from synapse.server import HomeServer
from synapse.storage.database import LoggingDatabaseConnection
-from synapse.storage.engines import create_engine
+from synapse.storage.engines import PostgresEngine, create_engine
from synapse.storage.prepare_database import prepare_database

# set this to True to run the tests against postgres instead of sqlite.
@@ -48,21 +64,27 @@ SQLITE_PERSIST_DB = os.environ.get("SYNAPSE_TEST_PERSIST_SQLITE_DB") is not None

# the dbname we will connect to in order to create the base database.
POSTGRES_DBNAME_FOR_INITIAL_CREATE = "postgres"
+if TYPE_CHECKING:
+    import psycopg2.extensions


-def setupdb():
+def setupdb() -> None:
    # If we're using PostgreSQL, set up the db once
    if USE_POSTGRES_FOR_TESTS:
        # create a PostgresEngine
-        db_engine = create_engine({"name": "psycopg2", "args": {}})
+        db_engine = cast(
+            PostgresEngine, create_engine({"name": "psycopg2", "args": {}})
+        )
        # connect to postgres to create the base database.
-        db_conn = db_engine.module.connect(
-            user=POSTGRES_USER,
-            host=POSTGRES_HOST,
-            port=POSTGRES_PORT,
-            password=POSTGRES_PASSWORD,
-            dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE,
+        db_conn = cast(
+            "psycopg2.extensions.connection",
+            db_engine.module.connect(
+                user=POSTGRES_USER,
+                host=POSTGRES_HOST,
+                port=POSTGRES_PORT,
+                password=POSTGRES_PASSWORD,
+                dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE,
+            ),
        )
        db_conn.autocommit = True
        cur = db_conn.cursor()
@@ -75,24 +97,30 @@ def setupdb():
        db_conn.close()

        # Set up in the db
-        db_conn = db_engine.module.connect(
-            database=POSTGRES_BASE_DB,
-            user=POSTGRES_USER,
-            host=POSTGRES_HOST,
-            port=POSTGRES_PORT,
-            password=POSTGRES_PASSWORD,
-        )
-        db_conn = LoggingDatabaseConnection(db_conn, db_engine, "tests")
-        prepare_database(db_conn, db_engine, None)
-        db_conn.close()
-
-        def _cleanup():
-            db_conn = db_engine.module.connect(
+        db_conn = cast(
+            "psycopg2.extensions.connection",
+            db_engine.module.connect(
+                database=POSTGRES_BASE_DB,
+                user=POSTGRES_USER,
+                host=POSTGRES_HOST,
+                port=POSTGRES_PORT,
+                password=POSTGRES_PASSWORD,
+                dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE,
+            ),
+        )
+        logging_conn = LoggingDatabaseConnection(db_conn, db_engine, "tests")
+        prepare_database(logging_conn, db_engine, None)
+        logging_conn.close()
+
+        def _cleanup() -> None:
+            db_conn = cast(
+                "psycopg2.extensions.connection",
+                db_engine.module.connect(
+                    user=POSTGRES_USER,
+                    host=POSTGRES_HOST,
+                    port=POSTGRES_PORT,
+                    password=POSTGRES_PASSWORD,
+                    dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE,
+                ),
+            )
            db_conn.autocommit = True
            cur = db_conn.cursor()
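A minimal hypothetical sketch of the `typing.cast` pattern used in `setupdb` above. `cast` does nothing at runtime; it only tells the type checker what to assume about a value coming from an untyped API such as a database driver's `connect()`:

from typing import Any, Dict, cast


def untyped_connect() -> Any:
    # Stand-in for a third-party call whose return type mypy sees as Any.
    return {"autocommit": "off"}


conn = cast(Dict[str, str], untyped_connect())
# mypy now treats `conn` as Dict[str, str]; at runtime cast() just returns its argument.
print(conn["autocommit"])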
@@ -103,7 +131,19 @@ def setupdb():
    atexit.register(_cleanup)


-def default_config(name, parse=False):
+@overload
+def default_config(name: str, parse: Literal[False] = ...) -> Dict[str, object]:
+    ...
+
+
+@overload
+def default_config(name: str, parse: Literal[True]) -> HomeServerConfig:
+    ...
+
+
+def default_config(
+    name: str, parse: bool = False
+) -> Union[Dict[str, object], HomeServerConfig]:
    """
    Create a reasonable test config.
    """
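A minimal hypothetical sketch of the `@overload` plus `Literal` pattern used for `default_config` above: the declared return type depends on the value of a boolean flag, so callers get a precise type without casting.

from typing import Dict, Union, overload

from typing_extensions import Literal


@overload
def load(parse: Literal[True]) -> int:
    ...


@overload
def load(parse: Literal[False] = ...) -> Dict[str, str]:
    ...


def load(parse: bool = False) -> Union[int, Dict[str, str]]:
    data = {"answer": "42"}
    return int(data["answer"]) if parse else data


number = load(parse=True)  # mypy infers int
mapping = load()  # mypy infers Dict[str, str]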
@@ -181,90 +221,122 @@ def default_config(name, parse=False):
    return config_dict


-def mock_getRawHeaders(headers=None):
+def mock_getRawHeaders(headers=None):  # type: ignore[no-untyped-def]
    headers = headers if headers is not None else {}

-    def getRawHeaders(name, default=None):
+    def getRawHeaders(name, default=None):  # type: ignore[no-untyped-def]
+        # If the requested header is present, the real twisted function returns
+        # List[str] if name is a str and List[bytes] if name is a bytes.
+        # This mock doesn't support that behaviour.
+        # Fortunately, none of the current callers of mock_getRawHeaders() provide a
+        # headers dict, so we don't encounter this discrepancy in practice.
        return headers.get(name, default)

    return getRawHeaders


+P = ParamSpec("P")
+
+
+@attr.s(slots=True, auto_attribs=True)
+class Timer:
+    absolute_time: float
+    callback: Callable[[], None]
+    expired: bool
+
+
+# TODO: Make this generic over a ParamSpec?
+@attr.s(slots=True, auto_attribs=True)
+class Looper:
+    func: Callable[..., Any]
+    interval: float  # seconds
+    last: float
+    args: Tuple[object, ...]
+    kwargs: Dict[str, object]

class MockClock:
-    now = 1000
+    now = 1000.0

-    def __init__(self):
-        # list of lists of [absolute_time, callback, expired] in no particular
-        # order
-        self.timers = []
-        self.loopers = []
+    def __init__(self) -> None:
+        # Timers in no particular order
+        self.timers: List[Timer] = []
+        self.loopers: List[Looper] = []

-    def time(self):
+    def time(self) -> float:
        return self.now

-    def time_msec(self):
-        return self.time() * 1000
+    def time_msec(self) -> int:
+        return int(self.time() * 1000)

-    def call_later(self, delay, callback, *args, **kwargs):
+    def call_later(
+        self,
+        delay: float,
+        callback: Callable[P, object],
+        *args: P.args,
+        **kwargs: P.kwargs,
+    ) -> Timer:
        ctx = current_context()

-        def wrapped_callback():
+        def wrapped_callback() -> None:
            set_current_context(ctx)
            callback(*args, **kwargs)

-        t = [self.now + delay, wrapped_callback, False]
+        t = Timer(self.now + delay, wrapped_callback, False)
        self.timers.append(t)

        return t

-    def looping_call(self, function, interval, *args, **kwargs):
-        self.loopers.append([function, interval / 1000.0, self.now, args, kwargs])
+    def looping_call(
+        self,
+        function: Callable[P, object],
+        interval: float,
+        *args: P.args,
+        **kwargs: P.kwargs,
+    ) -> None:
+        # This type-ignore should be redundant once we use a mypy release with
+        # https://github.com/python/mypy/pull/12668.
+        self.loopers.append(Looper(function, interval / 1000.0, self.now, args, kwargs))  # type: ignore[arg-type]

-    def cancel_call_later(self, timer, ignore_errs=False):
-        if timer[2]:
+    def cancel_call_later(self, timer: Timer, ignore_errs: bool = False) -> None:
+        if timer.expired:
            if not ignore_errs:
                raise Exception("Cannot cancel an expired timer")

-        timer[2] = True
+        timer.expired = True
        self.timers = [t for t in self.timers if t != timer]

    # For unit testing
-    def advance_time(self, secs):
+    def advance_time(self, secs: float) -> None:
        self.now += secs

        timers = self.timers
        self.timers = []

        for t in timers:
-            time, callback, expired = t
-
-            if expired:
+            if t.expired:
                raise Exception("Timer already expired")

-            if self.now >= time:
-                t[2] = True
-                callback()
+            if self.now >= t.absolute_time:
+                t.expired = True
+                t.callback()
            else:
                self.timers.append(t)

        for looped in self.loopers:
-            func, interval, last, args, kwargs = looped
-            if last + interval < self.now:
-                func(*args, **kwargs)
-                looped[2] = self.now
+            if looped.last + looped.interval < self.now:
+                looped.func(*looped.args, **looped.kwargs)
+                looped.last = self.now

-    def advance_time_msec(self, ms):
+    def advance_time_msec(self, ms: float) -> None:
        self.advance_time(ms / 1000.0)

    def time_bound_deferred(self, d, *args, **kwargs):
        # We don't bother timing things out for now.
        return d


-async def create_room(hs, room_id: str, creator_id: str):
+async def create_room(hs: HomeServer, room_id: str, creator_id: str) -> None:
    """Creates and persist a creation event for the given room"""

    persistence_store = hs.get_storage_controllers().persistence
    assert persistence_store is not None
    store = hs.get_datastores().main
    event_builder_factory = hs.get_event_builder_factory()
    event_creation_handler = hs.get_event_creation_handler()