diff --git a/changelog.d/19372.bugfix b/changelog.d/19372.bugfix new file mode 100644 index 0000000000..5e09e1c6b9 --- /dev/null +++ b/changelog.d/19372.bugfix @@ -0,0 +1 @@ +Always rollback database transaction when retrying (avoid orphaned connections). diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 2d5e1d3c48..6e38b55686 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -800,9 +800,8 @@ class DatabasePool: transaction_logger.debug("[TXN START] {%s}", name) try: - i = 0 - N = 5 - while True: + MAX_NUMBER_OF_ATTEMPTS = 5 + for attempt_number in range(1, MAX_NUMBER_OF_ATTEMPTS + 1): cursor = conn.cursor( txn_name=name, after_callbacks=after_callbacks, @@ -828,34 +827,37 @@ class DatabasePool: "[TXN OPERROR] {%s} %s %d/%d", name, e, - i, - N, + attempt_number, + MAX_NUMBER_OF_ATTEMPTS, ) - if i < N: - i += 1 - try: - with opentracing.start_active_span("db.rollback"): - conn.rollback() - except self.engine.module.Error as e1: - transaction_logger.warning("[TXN EROLL] {%s} %s", name, e1) + try: + with opentracing.start_active_span("db.rollback"): + conn.rollback() + except self.engine.module.Error as e1: + transaction_logger.warning("[TXN EROLL] {%s} %s", name, e1) + # Keep retrying if we haven't reached max attempts + if attempt_number < MAX_NUMBER_OF_ATTEMPTS: continue raise except self.engine.module.DatabaseError as e: if self.engine.is_deadlock(e): transaction_logger.warning( - "[TXN DEADLOCK] {%s} %d/%d", name, i, N + "[TXN DEADLOCK] {%s} %d/%d", + name, + attempt_number, + MAX_NUMBER_OF_ATTEMPTS, ) - if i < N: - i += 1 - try: - with opentracing.start_active_span("db.rollback"): - conn.rollback() - except self.engine.module.Error as e1: - transaction_logger.warning( - "[TXN EROLL] {%s} %s", - name, - e1, - ) + try: + with opentracing.start_active_span("db.rollback"): + conn.rollback() + except self.engine.module.Error as e1: + transaction_logger.warning( + "[TXN EROLL] {%s} %s", + name, + e1, + ) + # Keep retrying if we haven't reached max attempts + if attempt_number < MAX_NUMBER_OF_ATTEMPTS: continue raise finally: @@ -892,6 +894,21 @@ class DatabasePool: # [1]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/connection.c#L465 # [2]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/cursor.c#L236 cursor.close() + else: + # To appease the linter, we mark this as unreachable. Unreachable + # because we expect the code above to always return from the loop or + # raise an exception. `mypy` just doesn't understand our logic above. + # + # The Python docs + # (https://typing.python.org/en/latest/guides/unreachable.html#marking-code-as-unreachable) + # suggest `assert False` but that also gets linted to suggest raising an + # `AssertionError`. I'm not sure this has the same "unreachable" + # semantics, but it works anyway to solve the linter complaint because + # we're raising an exception. + raise AssertionError( + "We expect this to be unreachable because the code above should either return or raise. " + "This is a logic error in Synapse itself." + ) except Exception as e: transaction_logger.debug("[TXN FAIL] {%s} %s", name, e) raise diff --git a/tests/storage/test_database.py b/tests/storage/test_database.py index ffcff3363f..6213abd753 100644 --- a/tests/storage/test_database.py +++ b/tests/storage/test_database.py @@ -20,7 +20,9 @@ # from typing import Callable -from unittest.mock import Mock, call +from unittest.mock import Mock, call, patch + +import attr from twisted.internet import defer from twisted.internet.defer import CancelledError, Deferred @@ -140,6 +142,14 @@ class ExecuteScriptTestCase(unittest.HomeserverTestCase): ) +@attr.s(slots=True, auto_attribs=True) +class TransactionMocks: + after_callback: Mock + exception_callback: Mock + commit: Mock + rollback: Mock + + class CallbacksTestCase(unittest.HomeserverTestCase): """Tests for transaction callbacks.""" @@ -149,7 +159,7 @@ class CallbacksTestCase(unittest.HomeserverTestCase): def _run_interaction( self, func: Callable[[LoggingTransaction], object] - ) -> tuple[Mock, Mock]: + ) -> TransactionMocks: """Run the given function in a database transaction, with callbacks registered. Args: @@ -163,53 +173,111 @@ class CallbacksTestCase(unittest.HomeserverTestCase): after_callback = Mock() exception_callback = Mock() + # Track commit/rollback calls on the LoggingDatabaseConnection used + # for the transaction so tests can assert whether attempts committed + # or rolled back. + commit_mock = Mock() + rollback_mock = Mock() + def _test_txn(txn: LoggingTransaction) -> None: txn.call_after(after_callback, 123, 456, extra=789) txn.call_on_exception(exception_callback, 987, 654, extra=321) func(txn) + # Wrap the real commit/rollback so we record calls but still perform + # the original behaviour. + orig_commit = LoggingDatabaseConnection.commit + orig_rollback = LoggingDatabaseConnection.rollback + + # type-ignore becauase we're just transparently passing through args/kwargs and + # returning whatever result that the original function does + def _commit(self, *a, **kw): # type: ignore[no-untyped-def] + commit_mock() + return orig_commit(self, *a, **kw) + + # type-ignore becauase we're just transparently passing through args/kwargs and + # returning whatever result that the original function does + def _rollback(self, *a, **kw): # type: ignore[no-untyped-def] + rollback_mock() + return orig_rollback(self, *a, **kw) + try: - self.get_success_or_raise( - self.db_pool.runInteraction("test_transaction", _test_txn) - ) + with ( + patch.object(LoggingDatabaseConnection, "commit", _commit), + patch.object(LoggingDatabaseConnection, "rollback", _rollback), + ): + self.get_success_or_raise( + self.db_pool.runInteraction("test_transaction", _test_txn) + ) except Exception: pass - return after_callback, exception_callback + # FIXME: Sanity check that every transaction is either committed or rolled back, + # see https://github.com/element-hq/synapse/issues/19202 + # transaction_count = after_callback.call_count + exception_callback.call_count + # self.assertEqual( + # transaction_count, + # commit_mock.call_count + rollback_mock.call_count, + # "We expect every transaction attempt to either commit or rollback. " + # f"Saw {transaction_count} transactions, but only {commit_mock.call_count} commits and {rollback_mock.call_count} rollbacks", + # ) + + return TransactionMocks( + after_callback=after_callback, + exception_callback=exception_callback, + commit=commit_mock, + rollback=rollback_mock, + ) def test_after_callback(self) -> None: """Test that the after callback is called when a transaction succeeds.""" - after_callback, exception_callback = self._run_interaction(lambda txn: None) + txn_mocks = self._run_interaction(lambda txn: None) - after_callback.assert_called_once_with(123, 456, extra=789) - exception_callback.assert_not_called() + txn_mocks.after_callback.assert_called_once_with(123, 456, extra=789) + txn_mocks.exception_callback.assert_not_called() + + # Should have commited right away + self.assertEqual(txn_mocks.commit.call_count, 1) + # (nothing was rolled back) + self.assertEqual(txn_mocks.rollback.call_count, 0) def test_exception_callback(self) -> None: """Test that the exception callback is called when a transaction fails.""" _test_txn = Mock(side_effect=ZeroDivisionError) - after_callback, exception_callback = self._run_interaction(_test_txn) + txn_mocks = self._run_interaction(_test_txn) - after_callback.assert_not_called() - exception_callback.assert_called_once_with(987, 654, extra=321) + txn_mocks.after_callback.assert_not_called() + txn_mocks.exception_callback.assert_called_once_with(987, 654, extra=321) + + # Nothing should have committed. + self.assertEqual(txn_mocks.commit.call_count, 0) + # FIXME: Every transaction should have been rolled back, see + # https://github.com/element-hq/synapse/issues/19202 + # self.assertEqual(txn_mocks.rollback.call_count, 1) def test_failed_retry(self) -> None: """Test that the exception callback is called for every failed attempt.""" # Always raise an `OperationalError`. _test_txn = Mock(side_effect=self.db_pool.engine.module.OperationalError) - after_callback, exception_callback = self._run_interaction(_test_txn) + txn_mocks = self._run_interaction(_test_txn) - after_callback.assert_not_called() - exception_callback.assert_has_calls( + txn_mocks.after_callback.assert_not_called() + txn_mocks.exception_callback.assert_has_calls( [ call(987, 654, extra=321), call(987, 654, extra=321), call(987, 654, extra=321), call(987, 654, extra=321), call(987, 654, extra=321), - call(987, 654, extra=321), ] ) - self.assertEqual(exception_callback.call_count, 6) # no additional calls + # no additional calls + self.assertEqual(txn_mocks.exception_callback.call_count, 5) + + # Nothing should have committed. + self.assertEqual(txn_mocks.commit.call_count, 0) + # Every transaction should have been rolled back. + self.assertEqual(txn_mocks.rollback.call_count, 5) def test_successful_retry(self) -> None: """Test callbacks for a failed transaction followed by a successful attempt.""" @@ -217,19 +285,25 @@ class CallbacksTestCase(unittest.HomeserverTestCase): _test_txn = Mock( side_effect=[self.db_pool.engine.module.OperationalError, None] ) - after_callback, exception_callback = self._run_interaction(_test_txn) + txn_mocks = self._run_interaction(_test_txn) # Calling both `after_callback`s when the first attempt failed is rather # surprising (https://github.com/matrix-org/synapse/issues/12184). # Let's document the behaviour in a test. - after_callback.assert_has_calls( + txn_mocks.after_callback.assert_has_calls( [ call(123, 456, extra=789), call(123, 456, extra=789), ] ) - self.assertEqual(after_callback.call_count, 2) # no additional calls - exception_callback.assert_not_called() + # no additional calls + self.assertEqual(txn_mocks.after_callback.call_count, 2) + txn_mocks.exception_callback.assert_not_called() + + # The last attempt should have committed. + self.assertEqual(txn_mocks.commit.call_count, 1) + # The first attempt should have been rolled back. + self.assertEqual(txn_mocks.rollback.call_count, 1) class CancellationTestCase(unittest.HomeserverTestCase): @@ -282,7 +356,6 @@ class CancellationTestCase(unittest.HomeserverTestCase): call(987, 654, extra=321), call(987, 654, extra=321), call(987, 654, extra=321), - call(987, 654, extra=321), ] ) - self.assertEqual(exception_callback.call_count, 6) # no additional calls + self.assertEqual(exception_callback.call_count, 5) # no additional calls