diff --git a/tests/utils.py b/tests/utils.py index 260c2e91a2..41bf2d2e1b 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -54,13 +54,47 @@ POSTGRES_BASE_DB = "_synapse_unit_tests_base_%s" % (os.getpid(),) POSTGRES_DBNAME_FOR_INITIAL_CREATE = "postgres" -def setupdb(): - # If we're using PostgreSQL, set up the db once - if USE_POSTGRES_FOR_TESTS: - # create a PostgresEngine - db_engine = create_engine({"name": "psycopg2", "args": {}}) +def setupdb() -> None: + """ + Set up a temporary database to run tests in. Only applicable to postgres, + which uses a persistent database server rather than a database in memory. + """ + # Setting up the database is only required when using postgres + if not USE_POSTGRES_FOR_TESTS: + return - # connect to postgres to create the base database. + # create a PostgresEngine + db_engine = create_engine({"name": "psycopg2", "args": {}}) + + # connect to postgres to create the base database. + db_conn = db_engine.module.connect( + user=POSTGRES_USER, + host=POSTGRES_HOST, + password=POSTGRES_PASSWORD, + dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE, + ) + db_conn.autocommit = True + cur = db_conn.cursor() + cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,)) + cur.execute( + "CREATE DATABASE %s ENCODING 'UTF8' LC_COLLATE='C' LC_CTYPE='C' " + "template=template0;" % (POSTGRES_BASE_DB,) + ) + cur.close() + db_conn.close() + + # Set up in the db + db_conn = db_engine.module.connect( + database=POSTGRES_BASE_DB, + user=POSTGRES_USER, + host=POSTGRES_HOST, + password=POSTGRES_PASSWORD, + ) + db_conn = LoggingDatabaseConnection(db_conn, db_engine, "tests") + prepare_database(db_conn, db_engine, None) + db_conn.close() + + def _cleanup(): db_conn = db_engine.module.connect( user=POSTGRES_USER, host=POSTGRES_HOST, @@ -70,38 +104,10 @@ def setupdb(): db_conn.autocommit = True cur = db_conn.cursor() cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,)) - cur.execute( - "CREATE DATABASE %s ENCODING 'UTF8' LC_COLLATE='C' LC_CTYPE='C' " - "template=template0;" % (POSTGRES_BASE_DB,) - ) cur.close() db_conn.close() - # Set up in the db - db_conn = db_engine.module.connect( - database=POSTGRES_BASE_DB, - user=POSTGRES_USER, - host=POSTGRES_HOST, - password=POSTGRES_PASSWORD, - ) - db_conn = LoggingDatabaseConnection(db_conn, db_engine, "tests") - prepare_database(db_conn, db_engine, None) - db_conn.close() - - def _cleanup(): - db_conn = db_engine.module.connect( - user=POSTGRES_USER, - host=POSTGRES_HOST, - password=POSTGRES_PASSWORD, - dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE, - ) - db_conn.autocommit = True - cur = db_conn.cursor() - cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,)) - cur.close() - db_conn.close() - - atexit.register(_cleanup) + atexit.register(_cleanup) def default_config(name: str) -> Dict[str, Any]: