Compare commits
13 Commits
v1.140.0rc...squah/canc
| Author | SHA1 | Date |
|---|---|---|
| | 9ee8761934 | |
| | 97d7240725 | |
| | 8594175400 | |
| | 3d01893d4e | |
| | 1b7712611c | |
| | a1ef8ca71b | |
| | 93073a4b11 | |
| | c3ed401754 | |
| | b245f7aa1b | |
| | c93a1aeae9 | |
| | 0e118a09b7 | |
| | 2b5f3ed4ce | |
| | 7b19bc68ce | |
changelog.d/12120.misc (new file, 1 line)
@@ -0,0 +1 @@
+Add support for cancellation to `ReadWriteLock`.
@@ -350,7 +350,7 @@ class PaginationHandler:
         """
         self._purges_in_progress_by_room.add(room_id)
         try:
-            with await self.pagination_lock.write(room_id):
+            async with self.pagination_lock.write(room_id):
                 await self.storage.purge_events.purge_history(
                     room_id, token, delete_local_events
                 )
@@ -406,7 +406,7 @@ class PaginationHandler:
             room_id: room to be purged
             force: set true to skip checking for joined users.
         """
-        with await self.pagination_lock.write(room_id):
+        async with self.pagination_lock.write(room_id):
             # first check that we have no users in this room
             if not force:
                 joined = await self.store.is_host_joined(room_id, self._server_name)
@@ -448,7 +448,7 @@ class PaginationHandler:
 
         room_token = from_token.room_key
 
-        with await self.pagination_lock.read(room_id):
+        async with self.pagination_lock.read(room_id):
             (
                 membership,
                 member_event_id,
@@ -615,7 +615,7 @@ class PaginationHandler:
 
         self._purges_in_progress_by_room.add(room_id)
         try:
-            with await self.pagination_lock.write(room_id):
+            async with self.pagination_lock.write(room_id):
                 self._delete_by_id[delete_id].status = DeleteStatus.STATUS_SHUTTING_DOWN
                 self._delete_by_id[
                     delete_id
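The four hunks above (presumably Synapse's `PaginationHandler`) only change how the pagination lock is entered: on this branch `ReadWriteLock.read`/`.write` return async context managers rather than awaitables yielding plain context managers. A minimal sketch of the new calling convention, assuming this branch's `synapse.util.async_helpers.ReadWriteLock` is importable; the `Handler` class and `purge` method are illustrative only:

```python
from synapse.util.async_helpers import ReadWriteLock


class Handler:
    """Minimal sketch of the new calling convention; not Synapse's real handler."""

    def __init__(self) -> None:
        self.pagination_lock = ReadWriteLock()

    async def purge(self, room_id: str) -> None:
        # Old style (before this branch): `with await self.pagination_lock.write(...)`.
        # New style: the lock is used directly as an async context manager, and is
        # released on exit even if the awaiting caller is cancelled.
        async with self.pagination_lock.write(room_id):
            ...  # do the purge while holding the write lock
```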
@@ -41,6 +41,7 @@ from prometheus_client import Histogram
 from typing_extensions import Literal
 
 from twisted.enterprise import adbapi
+from twisted.internet import defer
 
 from synapse.api.errors import StoreError
 from synapse.config.database import DatabaseConnectionConfig
@@ -55,6 +56,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.background_updates import BackgroundUpdater
 from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
 from synapse.storage.types import Connection, Cursor
+from synapse.util.async_helpers import delay_cancellation
 from synapse.util.iterutils import batch_iter
 
 if TYPE_CHECKING:
@@ -732,34 +734,47 @@ class DatabasePool:
         Returns:
             The result of func
         """
-        after_callbacks: List[_CallbackListEntry] = []
-        exception_callbacks: List[_CallbackListEntry] = []
-
-        if not current_context():
-            logger.warning("Starting db txn '%s' from sentinel context", desc)
-
-        try:
-            with opentracing.start_active_span(f"db.{desc}"):
-                result = await self.runWithConnection(
-                    self.new_transaction,
-                    desc,
-                    after_callbacks,
-                    exception_callbacks,
-                    func,
-                    *args,
-                    db_autocommit=db_autocommit,
-                    isolation_level=isolation_level,
-                    **kwargs,
-                )
-
-            for after_callback, after_args, after_kwargs in after_callbacks:
-                after_callback(*after_args, **after_kwargs)
-
-            return cast(R, result)
-        except Exception:
-            for after_callback, after_args, after_kwargs in exception_callbacks:
-                after_callback(*after_args, **after_kwargs)
-            raise
+        async def _runInteraction() -> R:
+            after_callbacks: List[_CallbackListEntry] = []
+            exception_callbacks: List[_CallbackListEntry] = []
+
+            if not current_context():
+                logger.warning("Starting db txn '%s' from sentinel context", desc)
+
+            try:
+                with opentracing.start_active_span(f"db.{desc}"):
+                    result = await self.runWithConnection(
+                        self.new_transaction,
+                        desc,
+                        after_callbacks,
+                        exception_callbacks,
+                        func,
+                        *args,
+                        db_autocommit=db_autocommit,
+                        isolation_level=isolation_level,
+                        **kwargs,
+                    )
+
+                for after_callback, after_args, after_kwargs in after_callbacks:
+                    after_callback(*after_args, **after_kwargs)
+
+                return cast(R, result)
+            except Exception:
+                for after_callback, after_args, after_kwargs in exception_callbacks:
+                    after_callback(*after_args, **after_kwargs)
+                raise
+
+        # To handle cancellation, we ensure that `after_callback`s and
+        # `exception_callback`s are always run, since the transaction will complete
+        # on another thread regardless of cancellation.
+        #
+        # We also wait until everything above is done before releasing the
+        # `CancelledError`, so that logging contexts won't get used after they have been
+        # finished.
+        return await delay_cancellation(
+            defer.ensureDeferred(_runInteraction()), all=True
+        )
 
     async def runWithConnection(
         self,
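The reworked `runInteraction` moves the transaction and its `after_callback`/`exception_callback` bookkeeping into an inner coroutine and shields it with `delay_cancellation(..., all=True)` (defined later in this diff), so a cancelled caller can neither skip the callbacks nor tear down a logging context the transaction is still using. A stripped-down sketch of that wrapping pattern, assuming this branch is importable; `do_work`, `run_after_callbacks` and `run_protected` are illustrative stand-ins, not Synapse APIs:

```python
from twisted.internet import defer

from synapse.util.async_helpers import delay_cancellation  # added by this branch


def run_after_callbacks() -> None:
    """Stand-in for the `after_callback`s that must always run."""
    print("after callbacks ran")


async def do_work() -> str:
    """Stand-in for the real database transaction."""
    return "result"


async def run_protected() -> str:
    async def _inner() -> str:
        result = await do_work()
        run_after_callbacks()  # runs even if the caller has already cancelled us
        return result

    # Callers may cancel the Deferred returned by this coroutine, but `_inner`
    # still runs to completion; the CancelledError is only released afterwards.
    return await delay_cancellation(defer.ensureDeferred(_inner()), all=True)
```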
@@ -18,9 +18,10 @@ import collections
 import inspect
 import itertools
 import logging
-from contextlib import contextmanager
+from contextlib import asynccontextmanager, contextmanager
 from typing import (
     Any,
+    AsyncIterator,
     Awaitable,
     Callable,
     Collection,
@@ -40,7 +41,7 @@ from typing import (
 )
 
 import attr
-from typing_extensions import ContextManager, Literal
+from typing_extensions import AsyncContextManager, Literal
 
 from twisted.internet import defer
 from twisted.internet.defer import CancelledError
@@ -491,7 +492,7 @@ class ReadWriteLock:
 
     Example:
 
-        with await read_write_lock.read("test_key"):
+        async with read_write_lock.read("test_key"):
             # do some work
     """
 
@@ -514,7 +515,7 @@ class ReadWriteLock:
         # Latest writer queued
         self.key_to_current_writer: Dict[str, defer.Deferred] = {}
 
-    async def read(self, key: str) -> ContextManager:
+    def read(self, key: str) -> AsyncContextManager:
         new_defer: "defer.Deferred[None]" = defer.Deferred()
 
         curr_readers = self.key_to_current_readers.setdefault(key, set())
@@ -522,14 +523,16 @@ class ReadWriteLock:
 
         curr_readers.add(new_defer)
 
-        # We wait for the latest writer to finish writing. We can safely ignore
-        # any existing readers... as they're readers.
-        if curr_writer:
-            await make_deferred_yieldable(curr_writer)
-
-        @contextmanager
-        def _ctx_manager() -> Iterator[None]:
+        @asynccontextmanager
+        async def _ctx_manager() -> AsyncIterator[None]:
             try:
+                # We wait for the latest writer to finish writing. We can safely ignore
+                # any existing readers... as they're readers.
+                # May raise a `CancelledError` if the `Deferred` wrapping us is
+                # cancelled. The `Deferred` we are waiting on must not be cancelled,
+                # since we do not own it.
+                if curr_writer:
+                    await make_deferred_yieldable(stop_cancellation(curr_writer))
                 yield
             finally:
                 with PreserveLoggingContext():
@@ -538,7 +541,7 @@ class ReadWriteLock:
 
         return _ctx_manager()
 
-    async def write(self, key: str) -> ContextManager:
+    def write(self, key: str) -> AsyncContextManager:
         new_defer: "defer.Deferred[None]" = defer.Deferred()
 
         curr_readers = self.key_to_current_readers.get(key, set())
@@ -549,25 +552,41 @@ class ReadWriteLock:
         if curr_writer:
             to_wait_on.append(curr_writer)
 
-        # We can clear the list of current readers since the new writer waits
+        # We can clear the list of current readers since `new_defer` waits
         # for them to finish.
         curr_readers.clear()
         self.key_to_current_writer[key] = new_defer
 
-        await make_deferred_yieldable(defer.gatherResults(to_wait_on))
-
-        @contextmanager
-        def _ctx_manager() -> Iterator[None]:
+        @asynccontextmanager
+        async def _ctx_manager() -> AsyncIterator[None]:
+            to_wait_on_defer = defer.gatherResults(to_wait_on)
             try:
+                # Wait for all current readers and the latest writer to finish.
+                # May raise a `CancelledError` if the `Deferred` wrapping us is
+                # cancelled. The `Deferred`s we are waiting on must not be cancelled,
+                # since we do not own them.
+                await make_deferred_yieldable(stop_cancellation(to_wait_on_defer))
                 yield
             finally:
-                with PreserveLoggingContext():
-                    new_defer.callback(None)
-                # `self.key_to_current_writer[key]` may be missing if there was another
-                # writer waiting for us and it completed entirely within the
-                # `new_defer.callback()` call above.
-                if self.key_to_current_writer.get(key) == new_defer:
-                    self.key_to_current_writer.pop(key)
+                def release() -> None:
+                    with PreserveLoggingContext():
+                        new_defer.callback(None)
+                    # `self.key_to_current_writer[key]` may be missing if there was another
+                    # writer waiting for us and it completed entirely within the
+                    # `new_defer.callback()` call above.
+                    if self.key_to_current_writer.get(key) == new_defer:
+                        self.key_to_current_writer.pop(key)
+
+                if to_wait_on_defer.called:
+                    release()
+                else:
+                    # We don't have the lock yet, probably because we were cancelled
+                    # while waiting for it. We can't call `release()` yet, since
+                    # `new_defer` must only resolve once all previous readers and
+                    # writers have finished.
+                    # NB: `release()` won't have a logcontext in this path.
+                    to_wait_on_defer.addCallback(lambda _: release())
 
         return _ctx_manager()
 
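Both `read` and `write` now build their guard with `contextlib.asynccontextmanager`: the wait for earlier lock holders happens inside the generator's `try` (wrapped in `stop_cancellation`, so a cancelled waiter never cancels `Deferred`s it does not own), and the `finally` block releases the lock only if it was actually acquired. A generic, self-contained sketch of that shape using plain `asyncio` rather than Synapse's Twisted code; `_turn` and `guarded` are illustrative stand-ins:

```python
import asyncio
from contextlib import asynccontextmanager
from typing import AsyncIterator

_turn = asyncio.Lock()  # stand-in for the queue of earlier readers/writers


@asynccontextmanager
async def guarded() -> AsyncIterator[None]:
    """Sketch of the acquire-in-try / release-in-finally shape used by the lock."""
    acquired = False
    try:
        # Waiting may raise CancelledError if our caller is cancelled. The diff
        # additionally wraps its wait in `stop_cancellation`, so the Deferreds it
        # waits on (owned by earlier lock holders) are never cancelled themselves.
        await _turn.acquire()
        acquired = True
        yield
    finally:
        # Runs on normal exit, on an exception in the body, and on cancellation.
        # Only release what we actually acquired, mirroring the diff's
        # `if to_wait_on_defer.called:` check.
        if acquired:
            _turn.release()


async def main() -> None:
    async with guarded():
        print("holding the lock")


# asyncio.run(main())
```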
@@ -695,3 +714,59 @@ def stop_cancellation(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]":
     new_deferred: defer.Deferred[T] = defer.Deferred()
     deferred.chainDeferred(new_deferred)
     return new_deferred
+
+
+def delay_cancellation(deferred: "defer.Deferred[T]", all: bool) -> "defer.Deferred[T]":
+    """Delay cancellation of a `Deferred` until it resolves.
+
+    Has the same effect as `stop_cancellation`, but the returned `Deferred` will not
+    resolve with a `CancelledError` until the original `Deferred` resolves.
+
+    Args:
+        deferred: The `Deferred` to protect against cancellation. Must not follow the
+            Synapse logcontext rules if `all` is `False`.
+        all: `True` to delay multiple cancellations. `False` to delay only the first
+            cancellation.
+
+    Returns:
+        A new `Deferred`, which will contain the result of the original `Deferred`.
+        The new `Deferred` will not propagate cancellation through to the original.
+        When cancelled, the new `Deferred` will wait until the original `Deferred`
+        resolves before failing with a `CancelledError`.
+
+        The new `Deferred` will only follow the Synapse logcontext rules if `all` is
+        `True` and `deferred` follows the Synapse logcontext rules. Otherwise the new
+        `Deferred` should be wrapped with `make_deferred_yieldable`.
+    """
+
+    def cancel_errback(failure: Failure) -> Union[Failure, "defer.Deferred[T]"]:
+        """Insert another `Deferred` into the chain to delay cancellation.
+
+        Called when the original `Deferred` resolves or the new `Deferred` is
+        cancelled.
+        """
+        failure.trap(CancelledError)
+
+        if deferred.called and not deferred.paused:
+            # The `CancelledError` came from the original `Deferred`. Pass it through.
+            return failure
+
+        # Construct another `Deferred` that will only fail with the `CancelledError`
+        # once the original `Deferred` resolves.
+        delay_deferred: "defer.Deferred[T]" = defer.Deferred()
+        deferred.chainDeferred(delay_deferred)
+
+        if all:
+            # Intercept cancellations recursively. Each cancellation will cause another
+            # `Deferred` to be inserted into the chain.
+            delay_deferred.addErrback(cancel_errback)
+
+        # Override the result with the `CancelledError`.
+        delay_deferred.addBoth(lambda _: failure)
+
+        return delay_deferred
+
+    new_deferred: "defer.Deferred[T]" = defer.Deferred()
+    deferred.chainDeferred(new_deferred)
+    new_deferred.addErrback(cancel_errback)
+    return new_deferred
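`delay_cancellation` differs from the existing `stop_cancellation` in when the `CancelledError` is delivered: the wrapper is failed only after the original `Deferred` has resolved. A small sketch of that behaviour, assuming this branch's `synapse.util.async_helpers` is importable (no reactor is needed, only `Deferred` plumbing):

```python
from twisted.internet import defer
from twisted.internet.defer import CancelledError

from synapse.util.async_helpers import delay_cancellation  # added by this branch

underlying: "defer.Deferred[str]" = defer.Deferred()
wrapped = delay_cancellation(underlying, all=True)

seen = []
wrapped.addErrback(lambda failure: seen.append(failure.type))

wrapped.cancel()
assert seen == []              # the CancelledError is held back...
assert not underlying.called   # ...and the original Deferred was not cancelled

underlying.callback("done")
assert seen == [CancelledError]  # ...until the original Deferred resolved
```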
@@ -40,6 +40,7 @@ from twisted.python.failure import Failure
 
 from synapse.logging.context import make_deferred_yieldable, preserve_fn
 from synapse.util import unwrapFirstError
+from synapse.util.async_helpers import delay_cancellation
 from synapse.util.caches.deferred_cache import DeferredCache
 from synapse.util.caches.lrucache import LruCache
 
@@ -322,6 +323,11 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
                 ret = defer.maybeDeferred(preserve_fn(self.orig), obj, *args, **kwargs)
                 ret = cache.set(cache_key, ret, callback=invalidate_callback)
 
+            # We started a new call to `self.orig`, so we must always wait for it to
+            # complete. Otherwise we might mark our current logging context as
+            # finished while `self.orig` is still using it in the background.
+            ret = delay_cancellation(ret, all=True)
+
             return make_deferred_yieldable(ret)
 
         wrapped = cast(_CachedFunction, _wrapped)
@@ -482,6 +488,11 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
            d = defer.gatherResults(cached_defers, consumeErrors=True).addCallbacks(
                lambda _: results, unwrapFirstError
            )
+           if missing:
+               # We started a new call to `self.orig`, so we must always wait for it to
+               # complete. Otherwise we might mark our current logging context as
+               # finished while `self.orig` is still using it in the background.
+               d = delay_cancellation(d, all=True)
            return make_deferred_yieldable(d)
        else:
            return defer.succeed(results)
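In the cache descriptors the same helper ensures that a lookup started on behalf of one caller keeps running if that caller is cancelled, so other callers waiting on the same cache entry still get their result and the running lookup never resumes on a finished logging context. A condensed sketch mirroring the `test_cancel` cases added later in this diff, assuming this branch's `@cached` decorator; `Store` and `get_thing` are illustrative names:

```python
from twisted.internet.defer import CancelledError, Deferred

from synapse.util.caches.descriptors import cached


class Store:
    def __init__(self) -> None:
        self.gate: "Deferred[None]" = Deferred()

    @cached()
    async def get_thing(self, key: int) -> str:
        await self.gate  # simulate a slow lookup
        return str(key)


store = Store()
d1 = store.get_thing(123)  # starts the lookup
d2 = store.get_thing(123)  # attaches to the same pending cache entry

d1.cancel()  # the caller that started the lookup gives up
d1.addErrback(lambda failure: failure.trap(CancelledError))  # d1 fails with CancelledError

store.gate.callback(None)  # the lookup itself still completes...
assert d2.result == "123"  # ...so the second caller is unaffected
```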
@@ -12,29 +12,183 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.storage.database import make_tuple_comparison_clause
-from synapse.storage.engines import BaseDatabaseEngine
+from typing import Callable, NoReturn, Tuple
+from unittest.mock import Mock
+
+from twisted.internet import defer
+from twisted.internet.defer import CancelledError, Deferred
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.logging.context import LoggingContext
+from synapse.server import HomeServer
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingTransaction,
+    make_tuple_comparison_clause,
+)
+from synapse.util import Clock
 
 from tests import unittest
 
 
-def _stub_db_engine(**kwargs) -> BaseDatabaseEngine:
-    # returns a DatabaseEngine, circumventing the abc mechanism
-    # any kwargs are set as attributes on the class before instantiating it
-    t = type(
-        "TestBaseDatabaseEngine",
-        (BaseDatabaseEngine,),
-        dict(BaseDatabaseEngine.__dict__),
-    )
-    # defeat the abc mechanism
-    t.__abstractmethods__ = set()
-    for k, v in kwargs.items():
-        setattr(t, k, v)
-    return t(None, None)
-
-
 class TupleComparisonClauseTestCase(unittest.TestCase):
     def test_native_tuple_comparison(self):
         clause, args = make_tuple_comparison_clause([("a", 1), ("b", 2)])
         self.assertEqual(clause, "(a,b) > (?,?)")
         self.assertEqual(args, [1, 2])
+
+
+class CallbacksTestCase(unittest.HomeserverTestCase):
+    """Tests for transaction callbacks."""
+
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.store = hs.get_datastores().main
+        self.db_pool: DatabasePool = self.store.db_pool
+
+    def _run_interaction(
+        self, func: Callable[[LoggingTransaction, int], None]
+    ) -> Tuple[Mock, Mock]:
+        """Run the given function in a database transaction, with callbacks registered.
+
+        Args:
+            func: The function to be run in a transaction. The transaction will be
+                retried if `func` raises an `OperationalError`.
+
+        Returns:
+            Two mocks, which were registered as an `after_callback` and an
+            `exception_callback` respectively, on every transaction attempt.
+        """
+        after_callback = Mock()
+        exception_callback = Mock()
+
+        def _test_txn(txn: LoggingTransaction) -> None:
+            txn.call_after(after_callback, 123, 456, extra=789)
+            txn.call_on_exception(exception_callback, 987, 654, extra=321)
+            func(txn)
+
+        try:
+            self.get_success_or_raise(
+                self.db_pool.runInteraction("test_transaction", _test_txn)
+            )
+        except Exception:
+            pass
+
+        return after_callback, exception_callback
+
+    def test_after_callback(self) -> None:
+        """Test that the after callback is called when a transaction succeeds."""
+        after_callback, exception_callback = self._run_interaction(lambda txn: None)
+
+        after_callback.assert_called_once_with(123, 456, extra=789)
+        exception_callback.assert_not_called()
+
+    def test_exception_callback(self) -> None:
+        """Test that the exception callback is called when a transaction fails."""
+        after_callback, exception_callback = self._run_interaction(lambda txn: 1 / 0)
+
+        after_callback.assert_not_called()
+        exception_callback.assert_called_once_with(987, 654, extra=321)
+
+    def test_failed_retry(self) -> None:
+        """Test that the exception callback is called for every failed attempt."""
+
+        def _test_txn(txn: LoggingTransaction) -> NoReturn:
+            """Simulate a retryable failure on every attempt."""
+            raise self.db_pool.engine.module.OperationalError()
+
+        after_callback, exception_callback = self._run_interaction(_test_txn)
+
+        after_callback.assert_not_called()
+        exception_callback.assert_has_calls(
+            [
+                ((987, 654), {"extra": 321}),
+                ((987, 654), {"extra": 321}),
+                ((987, 654), {"extra": 321}),
+                ((987, 654), {"extra": 321}),
+                ((987, 654), {"extra": 321}),
+                ((987, 654), {"extra": 321}),
+            ]
+        )
+        self.assertEqual(exception_callback.call_count, 6)  # no additional calls
+
+    def test_successful_retry(self) -> None:
+        """Test callbacks for a failed transaction followed by a successful attempt."""
+        first_attempt = True
+
+        def _test_txn(txn: LoggingTransaction) -> None:
+            """Simulate a retryable failure on the first attempt only."""
+            nonlocal first_attempt
+            if first_attempt:
+                first_attempt = False
+                raise self.db_pool.engine.module.OperationalError()
+            else:
+                return None
+
+        after_callback, exception_callback = self._run_interaction(_test_txn)
+
+        # Calling both `after_callback`s when the first attempt failed is rather
+        # dubious. But let's document the behaviour in a test.
+        after_callback.assert_has_calls(
+            [
+                ((123, 456), {"extra": 789}),
+                ((123, 456), {"extra": 789}),
+            ]
+        )
+        self.assertEqual(after_callback.call_count, 2)  # no additional calls
+        exception_callback.assert_not_called()
+
+
+class CancellationTestCase(unittest.HomeserverTestCase):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.store = hs.get_datastores().main
+        self.db_pool: DatabasePool = self.store.db_pool
+
+    def test_after_callback(self) -> None:
+        """Test that the after callback is called when a transaction succeeds."""
+        d: "Deferred[None]"
+        after_callback = Mock()
+        exception_callback = Mock()
+
+        def _test_txn(txn: LoggingTransaction) -> None:
+            txn.call_after(after_callback, 123, 456, extra=789)
+            txn.call_on_exception(exception_callback, 987, 654, extra=321)
+            d.cancel()
+
+        d = defer.ensureDeferred(
+            self.db_pool.runInteraction("test_transaction", _test_txn)
+        )
+        self.get_failure(d, CancelledError)
+
+        after_callback.assert_called_once_with(123, 456, extra=789)
+        exception_callback.assert_not_called()
+
+    def test_exception_callback(self) -> None:
+        """Test that the exception callback is called when a transaction fails."""
+        d: "Deferred[None]"
+        after_callback = Mock()
+        exception_callback = Mock()
+
+        def _test_txn(txn: LoggingTransaction) -> None:
+            txn.call_after(after_callback, 123, 456, extra=789)
+            txn.call_on_exception(exception_callback, 987, 654, extra=321)
+            d.cancel()
+            # Simulate a retryable failure on every attempt.
+            raise self.db_pool.engine.module.OperationalError()
+
+        d = defer.ensureDeferred(
+            self.db_pool.runInteraction("test_transaction", _test_txn)
+        )
+        self.get_failure(d, CancelledError)
+
+        after_callback.assert_not_called()
+        exception_callback.assert_has_calls(
+            [
+                ((987, 654), {"extra": 321}),
+                ((987, 654), {"extra": 321}),
+                ((987, 654), {"extra": 321}),
+                ((987, 654), {"extra": 321}),
+                ((987, 654), {"extra": 321}),
+                ((987, 654), {"extra": 321}),
+            ]
+        )
+        self.assertEqual(exception_callback.call_count, 6)  # no additional calls
@@ -17,7 +17,7 @@ from typing import Set
 from unittest import mock
 
 from twisted.internet import defer, reactor
-from twisted.internet.defer import Deferred
+from twisted.internet.defer import CancelledError, Deferred
 
 from synapse.api.errors import SynapseError
 from synapse.logging.context import (
@@ -28,7 +28,7 @@ from synapse.logging.context import (
     make_deferred_yieldable,
 )
 from synapse.util.caches import descriptors
-from synapse.util.caches.descriptors import cached, lru_cache
+from synapse.util.caches.descriptors import cached, cachedList, lru_cache
 
 from tests import unittest
 from tests.test_utils import get_awaitable_result
@@ -415,6 +415,74 @@ class DescriptorTestCase(unittest.TestCase):
         obj.invalidate()
         top_invalidate.assert_called_once()
 
+    def test_cancel(self):
+        """Test that cancelling a lookup does not cancel other lookups"""
+        complete_lookup: "Deferred[None]" = Deferred()
+
+        class Cls:
+            @cached()
+            async def fn(self, arg1):
+                await complete_lookup
+                return str(arg1)
+
+        obj = Cls()
+
+        d1 = obj.fn(123)
+        d2 = obj.fn(123)
+        self.assertFalse(d1.called)
+        self.assertFalse(d2.called)
+
+        # Cancel `d1`, which is the lookup that caused `fn` to run.
+        d1.cancel()
+
+        # `d2` should complete normally.
+        complete_lookup.callback(None)
+        self.failureResultOf(d1, CancelledError)
+        self.assertEqual(d2.result, "123")
+
+    def test_cancel_logcontexts(self):
+        """Test that cancellation does not break logcontexts.
+
+        * The `CancelledError` must be raised with the correct logcontext.
+        * The inner lookup must not resume with a finished logcontext.
+        * The inner lookup must not restore a finished logcontext when done.
+        """
+        complete_lookup: "Deferred[None]" = Deferred()
+
+        class Cls:
+            inner_context_was_finished = False
+
+            @cached()
+            async def fn(self, arg1):
+                await make_deferred_yieldable(complete_lookup)
+                self.inner_context_was_finished = current_context().finished
+                return str(arg1)
+
+        obj = Cls()
+
+        async def do_lookup():
+            with LoggingContext("c1") as c1:
+                try:
+                    await obj.fn(123)
+                    self.fail("No CancelledError thrown")
+                except CancelledError:
+                    self.assertEqual(
+                        current_context(),
+                        c1,
+                        "CancelledError was not raised with the correct logcontext",
+                    )
+                    # suppress the error and succeed
+
+        d = defer.ensureDeferred(do_lookup())
+        d.cancel()
+
+        complete_lookup.callback(None)
+        self.successResultOf(d)
+        self.assertFalse(
+            obj.inner_context_was_finished, "Tried to restart a finished logcontext"
+        )
+        self.assertEqual(current_context(), SENTINEL_CONTEXT)
+
 
 class CacheDecoratorTestCase(unittest.HomeserverTestCase):
     """More tests for @cached
@@ -787,3 +855,78 @@ class CachedListDescriptorTestCase(unittest.TestCase):
         obj.fn.invalidate((10, 2))
         invalidate0.assert_called_once()
         invalidate1.assert_called_once()
+
+    def test_cancel(self):
+        """Test that cancelling a lookup does not cancel other lookups"""
+        complete_lookup: "Deferred[None]" = Deferred()
+
+        class Cls:
+            @cached()
+            def fn(self, arg1):
+                pass
+
+            @cachedList("fn", "args")
+            async def list_fn(self, args):
+                await complete_lookup
+                return {arg: str(arg) for arg in args}
+
+        obj = Cls()
+
+        d1 = obj.list_fn([123, 456])
+        d2 = obj.list_fn([123, 456, 789])
+        self.assertFalse(d1.called)
+        self.assertFalse(d2.called)
+
+        d1.cancel()
+
+        # `d2` should complete normally.
+        complete_lookup.callback(None)
+        self.failureResultOf(d1, CancelledError)
+        self.assertEqual(d2.result, {123: "123", 456: "456", 789: "789"})
+
+    def test_cancel_logcontexts(self):
+        """Test that cancellation does not break logcontexts.
+
+        * The `CancelledError` must be raised with the correct logcontext.
+        * The inner lookup must not resume with a finished logcontext.
+        * The inner lookup must not restore a finished logcontext when done.
+        """
+        complete_lookup: "Deferred[None]" = Deferred()
+
+        class Cls:
+            inner_context_was_finished = False
+
+            @cached()
+            def fn(self, arg1):
+                pass
+
+            @cachedList("fn", "args")
+            async def list_fn(self, args):
+                await make_deferred_yieldable(complete_lookup)
+                self.inner_context_was_finished = current_context().finished
+                return {arg: str(arg) for arg in args}
+
+        obj = Cls()
+
+        async def do_lookup():
+            with LoggingContext("c1") as c1:
+                try:
+                    await obj.list_fn([123])
+                    self.fail("No CancelledError thrown")
+                except CancelledError:
+                    self.assertEqual(
+                        current_context(),
+                        c1,
+                        "CancelledError was not raised with the correct logcontext",
+                    )
+                    # suppress the error and succeed
+
+        d = defer.ensureDeferred(do_lookup())
+        d.cancel()
+
+        complete_lookup.callback(None)
+        self.successResultOf(d)
+        self.assertFalse(
+            obj.inner_context_was_finished, "Tried to restart a finished logcontext"
+        )
+        self.assertEqual(current_context(), SENTINEL_CONTEXT)
@@ -12,8 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import AsyncContextManager, Callable, Tuple
+
 from twisted.internet import defer
-from twisted.internet.defer import Deferred
+from twisted.internet.defer import CancelledError, Deferred
 
 from synapse.util.async_helpers import ReadWriteLock
 
@@ -32,76 +34,120 @@ class ReadWriteLockTestCase(unittest.TestCase):
 
     def test_rwlock(self):
         rwlock = ReadWriteLock()
+        key = "key"
 
-        key = object()
+        def start_reader_or_writer(
+            read_or_write: Callable[[str], AsyncContextManager]
+        ) -> Tuple["Deferred[None]", "Deferred[None]"]:
+            acquired_d: "Deferred[None]" = Deferred()
+            release_d: "Deferred[None]" = Deferred()
+
+            async def action():
+                async with read_or_write(key):
+                    acquired_d.callback(None)
+                    await release_d
+
+            defer.ensureDeferred(action())
+            return acquired_d, release_d
 
         ds = [
-            rwlock.read(key),  # 0
-            rwlock.read(key),  # 1
-            rwlock.write(key),  # 2
-            rwlock.write(key),  # 3
-            rwlock.read(key),  # 4
-            rwlock.read(key),  # 5
-            rwlock.write(key),  # 6
+            start_reader_or_writer(rwlock.read),  # 0
+            start_reader_or_writer(rwlock.read),  # 1
+            start_reader_or_writer(rwlock.write),  # 2
+            start_reader_or_writer(rwlock.write),  # 3
+            start_reader_or_writer(rwlock.read),  # 4
+            start_reader_or_writer(rwlock.read),  # 5
+            start_reader_or_writer(rwlock.write),  # 6
         ]
-        ds = [defer.ensureDeferred(d) for d in ds]
+        # `Deferred`s that resolve when each reader or writer acquires the lock.
+        acquired_ds = [acquired_d for acquired_d, _release_d in ds]
+        # `Deferred`s that will trigger the release of locks when resolved.
+        release_ds = [release_d for _acquired_d, release_d in ds]
 
-        self._assert_called_before_not_after(ds, 2)
+        self._assert_called_before_not_after(acquired_ds, 2)
 
-        with ds[0].result:
-            self._assert_called_before_not_after(ds, 2)
-        self._assert_called_before_not_after(ds, 2)
+        self._assert_called_before_not_after(acquired_ds, 2)
+        release_ds[0].callback(None)
+        self._assert_called_before_not_after(acquired_ds, 2)
 
-        with ds[1].result:
-            self._assert_called_before_not_after(ds, 2)
-        self._assert_called_before_not_after(ds, 3)
+        self._assert_called_before_not_after(acquired_ds, 2)
+        release_ds[1].callback(None)
+        self._assert_called_before_not_after(acquired_ds, 3)
 
-        with ds[2].result:
-            self._assert_called_before_not_after(ds, 3)
-        self._assert_called_before_not_after(ds, 4)
+        self._assert_called_before_not_after(acquired_ds, 3)
+        release_ds[2].callback(None)
+        self._assert_called_before_not_after(acquired_ds, 4)
 
-        with ds[3].result:
-            self._assert_called_before_not_after(ds, 4)
-        self._assert_called_before_not_after(ds, 6)
+        self._assert_called_before_not_after(acquired_ds, 4)
+        release_ds[3].callback(None)
+        self._assert_called_before_not_after(acquired_ds, 6)
 
-        with ds[5].result:
-            self._assert_called_before_not_after(ds, 6)
-        self._assert_called_before_not_after(ds, 6)
+        self._assert_called_before_not_after(acquired_ds, 6)
+        release_ds[5].callback(None)
+        self._assert_called_before_not_after(acquired_ds, 6)
 
-        with ds[4].result:
-            self._assert_called_before_not_after(ds, 6)
-        self._assert_called_before_not_after(ds, 7)
+        self._assert_called_before_not_after(acquired_ds, 6)
+        release_ds[4].callback(None)
+        self._assert_called_before_not_after(acquired_ds, 7)
 
-        with ds[6].result:
-            pass
+        release_ds[6].callback(None)
 
-        d = defer.ensureDeferred(rwlock.write(key))
-        self.assertTrue(d.called)
-        with d.result:
-            pass
+        acquired_d, release_d = start_reader_or_writer(rwlock.write)
+        self.assertTrue(acquired_d.called)
+        release_d.callback(None)
 
-        d = defer.ensureDeferred(rwlock.read(key))
-        self.assertTrue(d.called)
-        with d.result:
-            pass
+        acquired_d, release_d = start_reader_or_writer(rwlock.read)
+        self.assertTrue(acquired_d.called)
+        release_d.callback(None)
+
+    def _start_reader_or_writer(
+        self,
+        read_or_write: Callable[[str], AsyncContextManager],
+        key: str,
+        name: str,
+    ) -> Tuple["Deferred[None]", "Deferred[None]"]:
+        """Starts a reader or writer which acquires the lock, blocks, then completes."""
+        unblock_d: "Deferred[None]" = Deferred()
+
+        async def reader_or_writer():
+            async with read_or_write(key):
+                await unblock_d
+            return f"{name} completed"
+
+        d = defer.ensureDeferred(reader_or_writer())
+        return d, unblock_d
+
+    def _start_blocking_reader(
+        self, rwlock: ReadWriteLock, key: str, name: str
+    ) -> Tuple["Deferred[None]", "Deferred[None]"]:
+        """Starts a reader which acquires the lock, blocks, then releases the lock."""
+        return self._start_reader_or_writer(rwlock.read, key, name)
+
+    def _start_blocking_writer(
+        self, rwlock: ReadWriteLock, key: str, name: str
+    ) -> Tuple["Deferred[None]", "Deferred[None]"]:
+        """Starts a writer which acquires the lock, blocks, then releases the lock."""
+        return self._start_reader_or_writer(rwlock.write, key, name)
+
+    def _start_nonblocking_reader(self, rwlock: ReadWriteLock, key: str, name: str):
+        """Starts a reader which acquires the lock, then releases it immediately."""
+        d, unblock_d = self._start_reader_or_writer(rwlock.read, key, name)
+        unblock_d.callback(None)
+        return d
+
+    def _start_nonblocking_writer(self, rwlock: ReadWriteLock, key: str, name: str):
+        """Starts a writer which acquires the lock, then releases it immediately."""
+        d, unblock_d = self._start_reader_or_writer(rwlock.write, key, name)
+        unblock_d.callback(None)
+        return d
 
     def test_lock_handoff_to_nonblocking_writer(self):
         """Test a writer handing the lock to another writer that completes instantly."""
         rwlock = ReadWriteLock()
         key = "key"
 
-        unblock: "Deferred[None]" = Deferred()
-
-        async def blocking_write():
-            with await rwlock.write(key):
-                await unblock
-
-        async def nonblocking_write():
-            with await rwlock.write(key):
-                pass
-
-        d1 = defer.ensureDeferred(blocking_write())
-        d2 = defer.ensureDeferred(nonblocking_write())
+        d1, unblock = self._start_blocking_writer(rwlock, key, "write 1")
+        d2 = self._start_nonblocking_writer(rwlock, key, "write 2")
         self.assertFalse(d1.called)
         self.assertFalse(d2.called)
@@ -111,5 +157,172 @@ class ReadWriteLockTestCase(unittest.TestCase):
         self.assertTrue(d2.called)
 
         # The `ReadWriteLock` should operate as normal.
-        d3 = defer.ensureDeferred(nonblocking_write())
+        d3 = self._start_nonblocking_writer(rwlock, key, "write 3")
         self.assertTrue(d3.called)
+
+    def test_cancellation_while_holding_read_lock(self):
+        """Test cancellation while holding a read lock.
+
+        A waiting writer should be given the lock when the reader holding the lock is
+        cancelled.
+        """
+        rwlock = ReadWriteLock()
+        key = "key"
+
+        # 1. A reader takes the lock and blocks.
+        reader_d, _ = self._start_blocking_reader(rwlock, key, "read")
+
+        # 2. A writer waits for the reader to complete.
+        writer_d = self._start_nonblocking_writer(rwlock, key, "write")
+        self.assertFalse(writer_d.called)
+
+        # 3. The reader is cancelled.
+        reader_d.cancel()
+        self.failureResultOf(reader_d, CancelledError)
+
+        # 4. The writer should take the lock and complete.
+        self.assertTrue(
+            writer_d.called, "Writer is stuck waiting for a cancelled reader"
+        )
+        self.assertEqual("write completed", self.successResultOf(writer_d))
+
+    def test_cancellation_while_holding_write_lock(self):
+        """Test cancellation while holding a write lock.
+
+        A waiting reader should be given the lock when the writer holding the lock is
+        cancelled.
+        """
+        rwlock = ReadWriteLock()
+        key = "key"
+
+        # 1. A writer takes the lock and blocks.
+        writer_d, _ = self._start_blocking_writer(rwlock, key, "write")
+
+        # 2. A reader waits for the writer to complete.
+        reader_d = self._start_nonblocking_reader(rwlock, key, "read")
+        self.assertFalse(reader_d.called)
+
+        # 3. The writer is cancelled.
+        writer_d.cancel()
+        self.failureResultOf(writer_d, CancelledError)
+
+        # 4. The reader should take the lock and complete.
+        self.assertTrue(
+            reader_d.called, "Reader is stuck waiting for a cancelled writer"
+        )
+        self.assertEqual("read completed", self.successResultOf(reader_d))
+
+    def test_cancellation_while_waiting_for_read_lock(self):
+        """Test cancellation while waiting for a read lock.
+
+        Tests that cancelling a waiting reader:
+        * does not cancel the writer it is waiting on
+        * does not cancel the next writer waiting on it
+        * does not allow the next writer to acquire the lock before an earlier writer
+          has finished
+        * does not keep the next writer waiting indefinitely
+
+        These correspond to the asserts with explicit messages.
+        """
+        rwlock = ReadWriteLock()
+        key = "key"
+
+        # 1. A writer takes the lock and blocks.
+        writer1_d, unblock_writer1 = self._start_blocking_writer(rwlock, key, "write 1")
+
+        # 2. A reader waits for the first writer to complete.
+        #    This reader will be cancelled later.
+        reader_d = self._start_nonblocking_reader(rwlock, key, "read")
+        self.assertFalse(reader_d.called)
+
+        # 3. A second writer waits for both the first writer and the reader to complete.
+        writer2_d = self._start_nonblocking_writer(rwlock, key, "write 2")
+        self.assertFalse(writer2_d.called)
+
+        # 4. The waiting reader is cancelled.
+        #    Neither of the writers should be cancelled.
+        #    The second writer should still be waiting, but only on the first writer.
+        reader_d.cancel()
+        self.failureResultOf(reader_d, CancelledError)
+        self.assertFalse(writer1_d.called, "First writer was unexpectedly cancelled")
+        self.assertFalse(
+            writer2_d.called,
+            "Second writer was unexpectedly cancelled or given the lock before the "
+            "first writer finished",
+        )
+
+        # 5. Unblock the first writer, which should complete.
+        unblock_writer1.callback(None)
+        self.assertEqual("write 1 completed", self.successResultOf(writer1_d))
+
+        # 6. The second writer should take the lock and complete.
+        self.assertTrue(
+            writer2_d.called, "Second writer is stuck waiting for a cancelled reader"
+        )
+        self.assertEqual("write 2 completed", self.successResultOf(writer2_d))
+
+    def test_cancellation_while_waiting_for_write_lock(self):
+        """Test cancellation while waiting for a write lock.
+
+        Tests that cancelling a waiting writer:
+        * does not cancel the reader or writer it is waiting on
+        * does not cancel the next writer waiting on it
+        * does not allow the next writer to acquire the lock before an earlier reader
+          and writer have finished
+        * does not keep the next writer waiting indefinitely
+
+        These correspond to the asserts with explicit messages.
+        """
+        rwlock = ReadWriteLock()
+        key = "key"
+
+        # 1. A reader takes the lock and blocks.
+        reader_d, unblock_reader = self._start_blocking_reader(rwlock, key, "read")
+
+        # 2. A writer waits for the reader to complete.
+        writer1_d, unblock_writer1 = self._start_blocking_writer(rwlock, key, "write 1")
+
+        # 3. A second writer waits for both the reader and first writer to complete.
+        #    This writer will be cancelled later.
+        writer2_d = self._start_nonblocking_writer(rwlock, key, "write 2")
+        self.assertFalse(writer2_d.called)
+
+        # 4. A third writer waits for the second writer to complete.
+        writer3_d = self._start_nonblocking_writer(rwlock, key, "write 3")
+        self.assertFalse(writer3_d.called)
+
+        # 5. The second writer is cancelled.
+        #    The reader, first writer and third writer should not be cancelled.
+        #    The first writer should still be waiting on the reader.
+        #    The third writer should still be waiting, even though the second writer has
+        #    been cancelled.
+        writer2_d.cancel()
+        self.failureResultOf(writer2_d, CancelledError)
+        self.assertFalse(reader_d.called, "Reader was unexpectedly cancelled")
+        self.assertFalse(writer1_d.called, "First writer was unexpectedly cancelled")
+        self.assertFalse(
+            writer3_d.called,
+            "Third writer was unexpectedly cancelled or given the lock before the first"
+            " writer finished",
+        )
+
+        # 6. Unblock the reader, which should complete.
+        #    The first writer should be given the lock and block.
+        #    The third writer should still be waiting.
+        unblock_reader.callback(None)
+        self.assertEqual("read completed", self.successResultOf(reader_d))
+        self.assertFalse(
+            writer3_d.called,
+            "Third writer was unexpectedly given the lock before the first writer "
+            "finished",
+        )
+
+        # 7. Unblock the first writer, which should complete.
+        unblock_writer1.callback(None)
+        self.assertEqual("write 1 completed", self.successResultOf(writer1_d))
+
+        # 8. The third writer should take the lock and complete.
+        self.assertTrue(
+            writer3_d.called, "Third writer is stuck waiting for a cancelled writer"
+        )
+        self.assertEqual("write 3 completed", self.successResultOf(writer3_d))