1
0

Compare commits

...

14 Commits

Author SHA1 Message Date
Andrew Morgan
c199b461fa Add types-psycopg2 package to optional dev dependencies 2021-11-16 00:38:30 +00:00
Andrew Morgan
5b1b067ea3 Changelog 2021-11-16 00:26:06 +00:00
Andrew Morgan
65a5eb5243 Remove some unused classes. 2021-11-16 00:19:55 +00:00
Andrew Morgan
9b2f54e859 Refactor 'setupdb' to remove an indent-level, and help explain its purpose a bit more. 2021-11-16 00:16:15 +00:00
Andrew Morgan
a977261542 Pass a fake Connection mock to SQLBaseStore, instead of None.
SQLBaseStore expects the connection to be non-Optional, so we can't just pass none.

Note that this parameter isn't even used by the function anyhow. I believe it's only there to satisfy the inherited class it's overriding.
2021-11-16 00:12:59 +00:00
Andrew Morgan
2c7e732233 Ignore monkey-patching functions. mypy doesn't currently allow this.
See https://github.com/python/mypy/issues/2427 for details/complaints.
2021-11-16 00:11:36 +00:00
Andrew Morgan
83be0d7e86 Remove mock of non-existent property. 2021-11-16 00:10:40 +00:00
Andrew Morgan
9f820ecdb0 Again, don't re-use variable names.
Interestingly I noticed that the reactor argument is never actually set by any calling functions. Should we just remove it?
2021-11-16 00:10:04 +00:00
Andrew Morgan
b3b8211a3c Don't re-use variable names with differing types. 2021-11-16 00:08:40 +00:00
Andrew Morgan
87efc2ea5b Make default_config only return a dict representation
This does make mypy happy, and does reduce a bit of confusion, though it's
a shame we have to duplicate the parsing code around everywhere now.

Is there a better way to solve this?
2021-11-16 00:06:12 +00:00
Andrew Morgan
ca6f4d6ff7 Add type annotation to Databases.databases
Awkwardly, this is a list of database configs, and should probably be renamed?
2021-11-15 23:59:25 +00:00
Andrew Morgan
95f71fb7c9 Denote HomeServer.DATASTORE_CLASS as an abstract property the Python 3 way
This seems to be required to make mypy happy about using it in inheriting classes
2021-11-15 23:58:13 +00:00
Andrew Morgan
74302bc097 Remove tests/utils.py from mypy exclude list 2021-11-15 23:56:57 +00:00
Andrew Morgan
6a91cfc3dc Fix types of do_execute
The control flow of this function was quite confusing. I refactored (fixed?) it a bit
and happened to fix the types along the way!
2021-11-15 23:56:25 +00:00
10 changed files with 97 additions and 179 deletions

1
changelog.d/11349.misc Normal file
View File

@@ -0,0 +1 @@
Fix type hints to allow `tests.utils` to pass `mypy`.

View File

@@ -142,7 +142,6 @@ exclude = (?x)
|tests/util/test_lrucache.py
|tests/util/test_rwlock.py
|tests/util/test_wheel_timer.py
|tests/utils.py
)$
[mypy-synapse.api.*]

View File

@@ -108,6 +108,7 @@ CONDITIONAL_REQUIREMENTS["mypy"] = [
"types-bleach>=4.1.0",
"types-jsonschema>=3.2.0",
"types-Pillow>=8.3.4",
"types-psycopg2>=2.9.1",
"types-pyOpenSSL>=20.0.7",
"types-PyYAML>=5.4.10",
"types-requests>=2.26.0",

View File

@@ -224,7 +224,10 @@ class HomeServer(metaclass=abc.ABCMeta):
# This is overridden in derived application classes
# (such as synapse.app.homeserver.SynapseHomeServer) and gives the class to be
# instantiated during setup() for future return by get_datastore()
DATASTORE_CLASS = abc.abstractproperty()
@property
@abc.abstractmethod
def DATASTORE_CLASS(self):
pass
tls_server_context_factory: Optional[IOpenSSLContextFactory]

View File

@@ -300,7 +300,9 @@ class LoggingTransaction:
from psycopg2.extras import execute_values # type: ignore
return self._do_execute(
lambda *x: execute_values(self.txn, *x, fetch=fetch), sql, *args
lambda sql, *argslist: execute_values(self.txn, sql, argslist, fetch=fetch),
sql,
*args,
)
def execute(self, sql: str, *args: Any) -> None:

View File

@@ -52,7 +52,7 @@ class Databases(Generic[DataStoreT]):
# Note we pass in the main store class here as workers use a different main
# store.
self.databases = []
self.databases: List[DatabasePool] = []
main: Optional[DataStoreT] = None
state: Optional[StateGroupDataStore] = None
persist_events: Optional[PersistEventsStore] = None

View File

@@ -67,7 +67,7 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.mock_resolver = Mock()
config_dict = default_config("test", parse=False)
config_dict = default_config("test")
config_dict["federation_custom_ca_list"] = [get_test_ca_cert_file()]
self._config = config = HomeServerConfig()
@@ -957,7 +957,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.mock_resolver.resolve_service.side_effect = generate_resolve_service([])
self.reactor.lookups["testserv"] = "1.2.3.4"
config = default_config("test", parse=True)
config_dict = default_config("test")
config = HomeServerConfig()
config.parse_config_dict(config_dict)
# Build a new agent and WellKnownResolver with a different tls factory
tls_factory = FederationPolicyForHTTPS(config)

View File

@@ -18,9 +18,11 @@ from unittest.mock import Mock
from twisted.internet import defer
from synapse.config.homeserver import HomeServerConfig
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
from synapse.storage.engines import create_engine
from synapse.storage.types import Connection
from tests import unittest
from tests.utils import TestHomeServer, default_config
@@ -47,7 +49,10 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.db_pool.runWithConnection = runWithConnection
config = default_config(name="test", parse=True)
config_dict = default_config(name="test")
config = HomeServerConfig()
config.parse_config_dict(config_dict)
hs = TestHomeServer("test", config=config)
sqlite_config = {"name": "sqlite3"}
@@ -59,7 +64,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
db = DatabasePool(Mock(), Mock(config=sqlite_config), fake_engine)
db._db_pool = self.db_pool
self.datastore = SQLBaseStore(db, None, hs)
self.datastore = SQLBaseStore(db, Mock(spec=Connection), hs)
@defer.inlineCallbacks
def test_insert_1col(self):

View File

@@ -19,6 +19,7 @@ from twisted.internet import defer
from synapse.api.auth import Auth
from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
from synapse.config.homeserver import HomeServerConfig
from synapse.events import make_event_from_dict
from synapse.events.snapshot import EventContext
from synapse.state import StateHandler, StateResolutionHandler
@@ -172,7 +173,11 @@ class StateTestCase(unittest.TestCase):
"hostname",
]
)
hs.config = default_config("tesths", True)
config_dict = default_config("tesths")
hs.config = HomeServerConfig()
hs.config.parse_config_dict(config_dict)
hs.get_datastore.return_value = self.store
hs.get_state_handler.return_value = None
hs.get_clock.return_value = MockClock()

View File

@@ -1,5 +1,4 @@
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018-2019 New Vector Ltd
# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -19,14 +18,11 @@ import os
import time
import uuid
import warnings
from typing import Type
from unittest.mock import Mock, patch
from urllib import parse as urlparse
from twisted.internet import defer
from types import ModuleType
from typing import Any, Callable, Dict, Optional, Type
from unittest.mock import Mock
from synapse.api.constants import EventTypes
from synapse.api.errors import CodeMessageException, cs_error
from synapse.api.room_versions import RoomVersions
from synapse.config.database import DatabaseConnectionConfig
from synapse.config.homeserver import HomeServerConfig
@@ -54,13 +50,47 @@ POSTGRES_BASE_DB = "_synapse_unit_tests_base_%s" % (os.getpid(),)
POSTGRES_DBNAME_FOR_INITIAL_CREATE = "postgres"
def setupdb():
# If we're using PostgreSQL, set up the db once
if USE_POSTGRES_FOR_TESTS:
# create a PostgresEngine
db_engine = create_engine({"name": "psycopg2", "args": {}})
def setupdb() -> None:
"""
Set up a temporary database to run tests in. Only applicable to postgres,
which uses a persistent database server rather than a database in memory.
"""
# Setting up the database is only required when using postgres
if not USE_POSTGRES_FOR_TESTS:
return
# connect to postgres to create the base database.
# create a PostgresEngine
db_engine = create_engine({"name": "psycopg2", "args": {}})
# connect to postgres to create the base database.
db_conn = db_engine.module.connect(
user=POSTGRES_USER,
host=POSTGRES_HOST,
password=POSTGRES_PASSWORD,
dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE,
)
db_conn.autocommit = True
cur = db_conn.cursor()
cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,))
cur.execute(
"CREATE DATABASE %s ENCODING 'UTF8' LC_COLLATE='C' LC_CTYPE='C' "
"template=template0;" % (POSTGRES_BASE_DB,)
)
cur.close()
db_conn.close()
# Set up in the db
db_conn = db_engine.module.connect(
database=POSTGRES_BASE_DB,
user=POSTGRES_USER,
host=POSTGRES_HOST,
password=POSTGRES_PASSWORD,
)
db_conn = LoggingDatabaseConnection(db_conn, db_engine, "tests")
prepare_database(db_conn, db_engine, None)
db_conn.close()
def _cleanup():
db_conn = db_engine.module.connect(
user=POSTGRES_USER,
host=POSTGRES_HOST,
@@ -70,43 +100,21 @@ def setupdb():
db_conn.autocommit = True
cur = db_conn.cursor()
cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,))
cur.execute(
"CREATE DATABASE %s ENCODING 'UTF8' LC_COLLATE='C' LC_CTYPE='C' "
"template=template0;" % (POSTGRES_BASE_DB,)
)
cur.close()
db_conn.close()
# Set up in the db
db_conn = db_engine.module.connect(
database=POSTGRES_BASE_DB,
user=POSTGRES_USER,
host=POSTGRES_HOST,
password=POSTGRES_PASSWORD,
)
db_conn = LoggingDatabaseConnection(db_conn, db_engine, "tests")
prepare_database(db_conn, db_engine, None)
db_conn.close()
def _cleanup():
db_conn = db_engine.module.connect(
user=POSTGRES_USER,
host=POSTGRES_HOST,
password=POSTGRES_PASSWORD,
dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE,
)
db_conn.autocommit = True
cur = db_conn.cursor()
cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,))
cur.close()
db_conn.close()
atexit.register(_cleanup)
atexit.register(_cleanup)
def default_config(name, parse=False):
def default_config(name: str) -> Dict[str, Any]:
"""
Create a reasonable test config.
Args:
name: The value of the 'server_name' option in the returned config.
Returns:
A sensible, default homeserver config.
"""
config_dict = {
"server_name": name,
@@ -175,11 +183,6 @@ def default_config(name, parse=False):
"listeners": [{"port": 0, "type": "http"}],
}
if parse:
config = HomeServerConfig()
config.parse_config_dict(config_dict, "", "")
return config
return config_dict
@@ -188,10 +191,10 @@ class TestHomeServer(HomeServer):
def setup_test_homeserver(
cleanup_func,
name="test",
config=None,
reactor=None,
cleanup_func: Callable[[Callable], Any],
name: str = "test",
config: Optional[HomeServerConfig] = None,
reactor: Optional[ModuleType] = None,
homeserver_to_use: Type[HomeServer] = TestHomeServer,
**kwargs,
):
@@ -209,12 +212,15 @@ def setup_test_homeserver(
HomeserverTestCase.
"""
if reactor is None:
from twisted.internet import reactor
from twisted.internet import reactor as _reactor
reactor = _reactor
if config is None:
config = default_config(name, parse=True)
config_dict = default_config(name)
config.ldap_enabled = False
config = HomeServerConfig()
config.parse_config_dict(config_dict)
if "clock" not in kwargs:
kwargs["clock"] = MockClock()
@@ -222,7 +228,7 @@ def setup_test_homeserver(
if USE_POSTGRES_FOR_TESTS:
test_db = "synapse_test_%s" % uuid.uuid4().hex
database_config = {
database_config_dict = {
"name": "psycopg2",
"args": {
"database": test_db,
@@ -234,18 +240,18 @@ def setup_test_homeserver(
},
}
else:
database_config = {
database_config_dict = {
"name": "sqlite3",
"args": {"database": ":memory:", "cp_min": 1, "cp_max": 1},
}
if "db_txn_limit" in kwargs:
database_config["txn_limit"] = kwargs["db_txn_limit"]
database_config_dict["txn_limit"] = kwargs["db_txn_limit"]
database = DatabaseConnectionConfig("master", database_config)
config.database.databases = [database]
database_config = DatabaseConnectionConfig("master", database_config_dict)
config.database.databases = [database_config]
db_engine = create_engine(database.config)
db_engine = create_engine(database_config.config)
# Create the database before we actually try and connect to it, based off
# the template database we generate in setupdb()
@@ -278,7 +284,6 @@ def setup_test_homeserver(
# Mock TLS
hs.tls_server_context_factory = Mock()
hs.tls_client_options_factory = Mock()
hs.setup()
if homeserver_to_use == TestHomeServer:
@@ -338,12 +343,12 @@ def setup_test_homeserver(
async def hash(p):
return hashlib.md5(p.encode("utf8")).hexdigest()
hs.get_auth_handler().hash = hash
hs.get_auth_handler().hash = hash # type: ignore
async def validate_hash(p, h):
return hashlib.md5(p.encode("utf8")).hexdigest() == h
hs.get_auth_handler().validate_hash = validate_hash
hs.get_auth_handler().validate_hash = validate_hash # type: ignore
return hs
@@ -357,111 +362,6 @@ def mock_getRawHeaders(headers=None):
return getRawHeaders
# This is a mock /resource/ not an entire server
class MockHttpResource:
def __init__(self, prefix=""):
self.callbacks = [] # 3-tuple of method/pattern/function
self.prefix = prefix
def trigger_get(self, path):
return self.trigger(b"GET", path, None)
@patch("twisted.web.http.Request")
@defer.inlineCallbacks
def trigger(
self, http_method, path, content, mock_request, federation_auth_origin=None
):
"""Fire an HTTP event.
Args:
http_method : The HTTP method
path : The HTTP path
content : The HTTP body
mock_request : Mocked request to pass to the event so it can get
content.
federation_auth_origin (bytes|None): domain to authenticate as, for federation
Returns:
A tuple of (code, response)
Raises:
KeyError If no event is found which will handle the path.
"""
path = self.prefix + path
# annoyingly we return a twisted http request which has chained calls
# to get at the http content, hence mock it here.
mock_content = Mock()
config = {"read.return_value": content}
mock_content.configure_mock(**config)
mock_request.content = mock_content
mock_request.method = http_method.encode("ascii")
mock_request.uri = path.encode("ascii")
mock_request.getClientIP.return_value = "-"
headers = {}
if federation_auth_origin is not None:
headers[b"Authorization"] = [
b"X-Matrix origin=%s,key=,sig=" % (federation_auth_origin,)
]
mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers)
# return the right path if the event requires it
mock_request.path = path
# add in query params to the right place
try:
mock_request.args = urlparse.parse_qs(path.split("?")[1])
mock_request.path = path.split("?")[0]
path = mock_request.path
except Exception:
pass
if isinstance(path, bytes):
path = path.decode("utf8")
for (method, pattern, func) in self.callbacks:
if http_method != method:
continue
matcher = pattern.match(path)
if matcher:
try:
args = [urlparse.unquote(u) for u in matcher.groups()]
(code, response) = yield defer.ensureDeferred(
func(mock_request, *args)
)
return code, response
except CodeMessageException as e:
return e.code, cs_error(e.msg, code=e.errcode)
raise KeyError("No event can handle %s" % path)
def register_paths(self, method, path_patterns, callback, servlet_name):
for path_pattern in path_patterns:
self.callbacks.append((method, path_pattern, callback))
class MockKey:
alg = "mock_alg"
version = "mock_version"
signature = b"\x9a\x87$"
@property
def verify_key(self):
return self
def sign(self, message):
return self
def verify(self, message, sig):
assert sig == b"\x9a\x87$"
def encode(self):
return b"<fake_encoded_key>"
class MockClock:
now = 1000