1
0

Always rollback transaction when retrying (#19372)

Previously, because `conn.rollback()` was inside the `if i < MAX_NUMBER_OF_RETRIES:` condition,
it never rolled back on the final retry.

Part of https://github.com/element-hq/synapse/issues/19202

There are other problems mentioned in
https://github.com/element-hq/synapse/issues/19202 but this is a nice
standalone change.
This commit is contained in:
Eric Eastwood
2026-01-15 19:35:51 -06:00
committed by GitHub
parent 6363d77ba2
commit 13c6476d6e
3 changed files with 138 additions and 47 deletions

1
changelog.d/19372.bugfix Normal file
View File

@@ -0,0 +1 @@
Always rollback database transaction when retrying (avoid orphaned connections).

View File

@@ -800,9 +800,8 @@ class DatabasePool:
transaction_logger.debug("[TXN START] {%s}", name)
try:
i = 0
N = 5
while True:
MAX_NUMBER_OF_ATTEMPTS = 5
for attempt_number in range(1, MAX_NUMBER_OF_ATTEMPTS + 1):
cursor = conn.cursor(
txn_name=name,
after_callbacks=after_callbacks,
@@ -828,34 +827,37 @@ class DatabasePool:
"[TXN OPERROR] {%s} %s %d/%d",
name,
e,
i,
N,
attempt_number,
MAX_NUMBER_OF_ATTEMPTS,
)
if i < N:
i += 1
try:
with opentracing.start_active_span("db.rollback"):
conn.rollback()
except self.engine.module.Error as e1:
transaction_logger.warning("[TXN EROLL] {%s} %s", name, e1)
try:
with opentracing.start_active_span("db.rollback"):
conn.rollback()
except self.engine.module.Error as e1:
transaction_logger.warning("[TXN EROLL] {%s} %s", name, e1)
# Keep retrying if we haven't reached max attempts
if attempt_number < MAX_NUMBER_OF_ATTEMPTS:
continue
raise
except self.engine.module.DatabaseError as e:
if self.engine.is_deadlock(e):
transaction_logger.warning(
"[TXN DEADLOCK] {%s} %d/%d", name, i, N
"[TXN DEADLOCK] {%s} %d/%d",
name,
attempt_number,
MAX_NUMBER_OF_ATTEMPTS,
)
if i < N:
i += 1
try:
with opentracing.start_active_span("db.rollback"):
conn.rollback()
except self.engine.module.Error as e1:
transaction_logger.warning(
"[TXN EROLL] {%s} %s",
name,
e1,
)
try:
with opentracing.start_active_span("db.rollback"):
conn.rollback()
except self.engine.module.Error as e1:
transaction_logger.warning(
"[TXN EROLL] {%s} %s",
name,
e1,
)
# Keep retrying if we haven't reached max attempts
if attempt_number < MAX_NUMBER_OF_ATTEMPTS:
continue
raise
finally:
@@ -892,6 +894,21 @@ class DatabasePool:
# [1]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/connection.c#L465
# [2]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/cursor.c#L236
cursor.close()
else:
# To appease the linter, we mark this as unreachable. Unreachable
# because we expect the code above to always return from the loop or
# raise an exception. `mypy` just doesn't understand our logic above.
#
# The Python docs
# (https://typing.python.org/en/latest/guides/unreachable.html#marking-code-as-unreachable)
# suggest `assert False` but that also gets linted to suggest raising an
# `AssertionError`. I'm not sure this has the same "unreachable"
# semantics, but it works anyway to solve the linter complaint because
# we're raising an exception.
raise AssertionError(
"We expect this to be unreachable because the code above should either return or raise. "
"This is a logic error in Synapse itself."
)
except Exception as e:
transaction_logger.debug("[TXN FAIL] {%s} %s", name, e)
raise

View File

@@ -20,7 +20,9 @@
#
from typing import Callable
from unittest.mock import Mock, call
from unittest.mock import Mock, call, patch
import attr
from twisted.internet import defer
from twisted.internet.defer import CancelledError, Deferred
@@ -140,6 +142,14 @@ class ExecuteScriptTestCase(unittest.HomeserverTestCase):
)
@attr.s(slots=True, auto_attribs=True)
class TransactionMocks:
after_callback: Mock
exception_callback: Mock
commit: Mock
rollback: Mock
class CallbacksTestCase(unittest.HomeserverTestCase):
"""Tests for transaction callbacks."""
@@ -149,7 +159,7 @@ class CallbacksTestCase(unittest.HomeserverTestCase):
def _run_interaction(
self, func: Callable[[LoggingTransaction], object]
) -> tuple[Mock, Mock]:
) -> TransactionMocks:
"""Run the given function in a database transaction, with callbacks registered.
Args:
@@ -163,53 +173,111 @@ class CallbacksTestCase(unittest.HomeserverTestCase):
after_callback = Mock()
exception_callback = Mock()
# Track commit/rollback calls on the LoggingDatabaseConnection used
# for the transaction so tests can assert whether attempts committed
# or rolled back.
commit_mock = Mock()
rollback_mock = Mock()
def _test_txn(txn: LoggingTransaction) -> None:
txn.call_after(after_callback, 123, 456, extra=789)
txn.call_on_exception(exception_callback, 987, 654, extra=321)
func(txn)
# Wrap the real commit/rollback so we record calls but still perform
# the original behaviour.
orig_commit = LoggingDatabaseConnection.commit
orig_rollback = LoggingDatabaseConnection.rollback
# type-ignore becauase we're just transparently passing through args/kwargs and
# returning whatever result that the original function does
def _commit(self, *a, **kw): # type: ignore[no-untyped-def]
commit_mock()
return orig_commit(self, *a, **kw)
# type-ignore becauase we're just transparently passing through args/kwargs and
# returning whatever result that the original function does
def _rollback(self, *a, **kw): # type: ignore[no-untyped-def]
rollback_mock()
return orig_rollback(self, *a, **kw)
try:
self.get_success_or_raise(
self.db_pool.runInteraction("test_transaction", _test_txn)
)
with (
patch.object(LoggingDatabaseConnection, "commit", _commit),
patch.object(LoggingDatabaseConnection, "rollback", _rollback),
):
self.get_success_or_raise(
self.db_pool.runInteraction("test_transaction", _test_txn)
)
except Exception:
pass
return after_callback, exception_callback
# FIXME: Sanity check that every transaction is either committed or rolled back,
# see https://github.com/element-hq/synapse/issues/19202
# transaction_count = after_callback.call_count + exception_callback.call_count
# self.assertEqual(
# transaction_count,
# commit_mock.call_count + rollback_mock.call_count,
# "We expect every transaction attempt to either commit or rollback. "
# f"Saw {transaction_count} transactions, but only {commit_mock.call_count} commits and {rollback_mock.call_count} rollbacks",
# )
return TransactionMocks(
after_callback=after_callback,
exception_callback=exception_callback,
commit=commit_mock,
rollback=rollback_mock,
)
def test_after_callback(self) -> None:
"""Test that the after callback is called when a transaction succeeds."""
after_callback, exception_callback = self._run_interaction(lambda txn: None)
txn_mocks = self._run_interaction(lambda txn: None)
after_callback.assert_called_once_with(123, 456, extra=789)
exception_callback.assert_not_called()
txn_mocks.after_callback.assert_called_once_with(123, 456, extra=789)
txn_mocks.exception_callback.assert_not_called()
# Should have commited right away
self.assertEqual(txn_mocks.commit.call_count, 1)
# (nothing was rolled back)
self.assertEqual(txn_mocks.rollback.call_count, 0)
def test_exception_callback(self) -> None:
"""Test that the exception callback is called when a transaction fails."""
_test_txn = Mock(side_effect=ZeroDivisionError)
after_callback, exception_callback = self._run_interaction(_test_txn)
txn_mocks = self._run_interaction(_test_txn)
after_callback.assert_not_called()
exception_callback.assert_called_once_with(987, 654, extra=321)
txn_mocks.after_callback.assert_not_called()
txn_mocks.exception_callback.assert_called_once_with(987, 654, extra=321)
# Nothing should have committed.
self.assertEqual(txn_mocks.commit.call_count, 0)
# FIXME: Every transaction should have been rolled back, see
# https://github.com/element-hq/synapse/issues/19202
# self.assertEqual(txn_mocks.rollback.call_count, 1)
def test_failed_retry(self) -> None:
"""Test that the exception callback is called for every failed attempt."""
# Always raise an `OperationalError`.
_test_txn = Mock(side_effect=self.db_pool.engine.module.OperationalError)
after_callback, exception_callback = self._run_interaction(_test_txn)
txn_mocks = self._run_interaction(_test_txn)
after_callback.assert_not_called()
exception_callback.assert_has_calls(
txn_mocks.after_callback.assert_not_called()
txn_mocks.exception_callback.assert_has_calls(
[
call(987, 654, extra=321),
call(987, 654, extra=321),
call(987, 654, extra=321),
call(987, 654, extra=321),
call(987, 654, extra=321),
call(987, 654, extra=321),
]
)
self.assertEqual(exception_callback.call_count, 6) # no additional calls
# no additional calls
self.assertEqual(txn_mocks.exception_callback.call_count, 5)
# Nothing should have committed.
self.assertEqual(txn_mocks.commit.call_count, 0)
# Every transaction should have been rolled back.
self.assertEqual(txn_mocks.rollback.call_count, 5)
def test_successful_retry(self) -> None:
"""Test callbacks for a failed transaction followed by a successful attempt."""
@@ -217,19 +285,25 @@ class CallbacksTestCase(unittest.HomeserverTestCase):
_test_txn = Mock(
side_effect=[self.db_pool.engine.module.OperationalError, None]
)
after_callback, exception_callback = self._run_interaction(_test_txn)
txn_mocks = self._run_interaction(_test_txn)
# Calling both `after_callback`s when the first attempt failed is rather
# surprising (https://github.com/matrix-org/synapse/issues/12184).
# Let's document the behaviour in a test.
after_callback.assert_has_calls(
txn_mocks.after_callback.assert_has_calls(
[
call(123, 456, extra=789),
call(123, 456, extra=789),
]
)
self.assertEqual(after_callback.call_count, 2) # no additional calls
exception_callback.assert_not_called()
# no additional calls
self.assertEqual(txn_mocks.after_callback.call_count, 2)
txn_mocks.exception_callback.assert_not_called()
# The last attempt should have committed.
self.assertEqual(txn_mocks.commit.call_count, 1)
# The first attempt should have been rolled back.
self.assertEqual(txn_mocks.rollback.call_count, 1)
class CancellationTestCase(unittest.HomeserverTestCase):
@@ -282,7 +356,6 @@ class CancellationTestCase(unittest.HomeserverTestCase):
call(987, 654, extra=321),
call(987, 654, extra=321),
call(987, 654, extra=321),
call(987, 654, extra=321),
]
)
self.assertEqual(exception_callback.call_count, 6) # no additional calls
self.assertEqual(exception_callback.call_count, 5) # no additional calls