Compare commits

...

13 Commits

Author SHA1 Message Date
Sean Quah
9ee8761934 Handle cancellation in DatabasePool.runInteraction()
To handle cancellation, we ensure that `after_callback`s and
`exception_callback`s are always run, since the transaction will
complete on another thread regardless of cancellation.

We also wait until everything above is done before releasing the
`CancelledError`, so that logging contexts won't get used after they
have been finished.

Signed-off-by: Sean Quah <seanq@element.io>
2022-03-09 17:49:47 +00:00
Sean Quah
97d7240725 Add tests for database callbacks after cancellation
Signed-off-by: Sean Quah <seanq@element.io>
2022-03-09 17:39:42 +00:00
Sean Quah
8594175400 Add tests for database callbacks
Signed-off-by: Sean Quah <seanq@element.io>
2022-03-09 17:03:29 +00:00
Sean Quah
3d01893d4e Remove dead code
Signed-off-by: Sean Quah <seanq@element.io>
2022-03-09 16:38:33 +00:00
Sean Quah
1b7712611c Add tests for logcontexts during @cached and @cachedList cancellation 2022-03-08 17:43:36 +00:00
Sean Quah
a1ef8ca71b Add basic cancellation tests for @cached and @cachedList decorators 2022-03-08 17:43:36 +00:00
Sean Quah
93073a4b11 Fix logcontexts when @cached and @cachedList lookups are cancelled
`@cached` and `@cachedList` must wait until the wrapped method has
completed before raising `CancelledError`s, otherwise the wrapped method
will continue running in the background with a logging context that has
been marked as finished.
2022-03-08 17:42:09 +00:00
Sean Quah
c3ed401754 Add delay_cancellation utility function
`delay_cancellation` behaves like `stop_cancellation`, except it
delays `CancelledError`s until the original `Deferred` resolves.
This is handy for unifying cleanup paths and ensuring that uncancelled
coroutines don't use finished logcontexts.

Signed-off-by: Sean Quah <seanq@element.io>
2022-03-08 17:14:56 +00:00
Sean Quah
b245f7aa1b Add newsfile 2022-03-08 17:11:51 +00:00
Sean Quah
c93a1aeae9 Add ReadWriteLock cancellation tests 2022-03-08 17:11:51 +00:00
Sean Quah
0e118a09b7 Don't cancel Deferreds that readers or writers are waiting on 2022-03-08 17:11:51 +00:00
Sean Quah
2b5f3ed4ce Fix clean up when waiting readers or writers are cancelled
Signed-off-by: Sean Quah <seanq@element.io>
2022-03-08 17:11:51 +00:00
Sean Quah
7b19bc68ce Convert ReadWriteLock to use async context managers
Has the side effect of fixing clean up for readers cancelled while
waiting. Breaks the assumption that resolution of a writer `Deferred`
means that previous readers and writers have completed, which will be
fixed in the next commit.

Signed-off-by: Sean Quah <seanq@element.io>
2022-03-08 17:11:51 +00:00
8 changed files with 735 additions and 123 deletions

1
changelog.d/12120.misc Normal file
View File

@@ -0,0 +1 @@
Add support for cancellation to `ReadWriteLock`.

View File

@@ -350,7 +350,7 @@ class PaginationHandler:
"""
self._purges_in_progress_by_room.add(room_id)
try:
with await self.pagination_lock.write(room_id):
async with self.pagination_lock.write(room_id):
await self.storage.purge_events.purge_history(
room_id, token, delete_local_events
)
@@ -406,7 +406,7 @@ class PaginationHandler:
room_id: room to be purged
force: set true to skip checking for joined users.
"""
with await self.pagination_lock.write(room_id):
async with self.pagination_lock.write(room_id):
# first check that we have no users in this room
if not force:
joined = await self.store.is_host_joined(room_id, self._server_name)
@@ -448,7 +448,7 @@ class PaginationHandler:
room_token = from_token.room_key
with await self.pagination_lock.read(room_id):
async with self.pagination_lock.read(room_id):
(
membership,
member_event_id,
@@ -615,7 +615,7 @@ class PaginationHandler:
self._purges_in_progress_by_room.add(room_id)
try:
with await self.pagination_lock.write(room_id):
async with self.pagination_lock.write(room_id):
self._delete_by_id[delete_id].status = DeleteStatus.STATUS_SHUTTING_DOWN
self._delete_by_id[
delete_id

View File

@@ -41,6 +41,7 @@ from prometheus_client import Histogram
from typing_extensions import Literal
from twisted.enterprise import adbapi
from twisted.internet import defer
from synapse.api.errors import StoreError
from synapse.config.database import DatabaseConnectionConfig
@@ -55,6 +56,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.background_updates import BackgroundUpdater
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.storage.types import Connection, Cursor
from synapse.util.async_helpers import delay_cancellation
from synapse.util.iterutils import batch_iter
if TYPE_CHECKING:
@@ -732,34 +734,47 @@ class DatabasePool:
Returns:
The result of func
"""
after_callbacks: List[_CallbackListEntry] = []
exception_callbacks: List[_CallbackListEntry] = []
if not current_context():
logger.warning("Starting db txn '%s' from sentinel context", desc)
async def _runInteraction() -> R:
after_callbacks: List[_CallbackListEntry] = []
exception_callbacks: List[_CallbackListEntry] = []
try:
with opentracing.start_active_span(f"db.{desc}"):
result = await self.runWithConnection(
self.new_transaction,
desc,
after_callbacks,
exception_callbacks,
func,
*args,
db_autocommit=db_autocommit,
isolation_level=isolation_level,
**kwargs,
)
if not current_context():
logger.warning("Starting db txn '%s' from sentinel context", desc)
for after_callback, after_args, after_kwargs in after_callbacks:
after_callback(*after_args, **after_kwargs)
except Exception:
for after_callback, after_args, after_kwargs in exception_callbacks:
after_callback(*after_args, **after_kwargs)
raise
try:
with opentracing.start_active_span(f"db.{desc}"):
result = await self.runWithConnection(
self.new_transaction,
desc,
after_callbacks,
exception_callbacks,
func,
*args,
db_autocommit=db_autocommit,
isolation_level=isolation_level,
**kwargs,
)
return cast(R, result)
for after_callback, after_args, after_kwargs in after_callbacks:
after_callback(*after_args, **after_kwargs)
return cast(R, result)
except Exception:
for after_callback, after_args, after_kwargs in exception_callbacks:
after_callback(*after_args, **after_kwargs)
raise
# To handle cancellation, we ensure that `after_callback`s and
# `exception_callback`s are always run, since the transaction will complete
# on another thread regardless of cancellation.
#
# We also wait until everything above is done before releasing the
# `CancelledError`, so that logging contexts won't get used after they have been
# finished.
return await delay_cancellation(
defer.ensureDeferred(_runInteraction()), all=True
)
async def runWithConnection(
self,

View File

@@ -18,9 +18,10 @@ import collections
import inspect
import itertools
import logging
from contextlib import contextmanager
from contextlib import asynccontextmanager, contextmanager
from typing import (
Any,
AsyncIterator,
Awaitable,
Callable,
Collection,
@@ -40,7 +41,7 @@ from typing import (
)
import attr
from typing_extensions import ContextManager, Literal
from typing_extensions import AsyncContextManager, Literal
from twisted.internet import defer
from twisted.internet.defer import CancelledError
@@ -491,7 +492,7 @@ class ReadWriteLock:
Example:
with await read_write_lock.read("test_key"):
async with read_write_lock.read("test_key"):
# do some work
"""
@@ -514,7 +515,7 @@ class ReadWriteLock:
# Latest writer queued
self.key_to_current_writer: Dict[str, defer.Deferred] = {}
async def read(self, key: str) -> ContextManager:
def read(self, key: str) -> AsyncContextManager:
new_defer: "defer.Deferred[None]" = defer.Deferred()
curr_readers = self.key_to_current_readers.setdefault(key, set())
@@ -522,14 +523,16 @@ class ReadWriteLock:
curr_readers.add(new_defer)
# We wait for the latest writer to finish writing. We can safely ignore
# any existing readers... as they're readers.
if curr_writer:
await make_deferred_yieldable(curr_writer)
@contextmanager
def _ctx_manager() -> Iterator[None]:
@asynccontextmanager
async def _ctx_manager() -> AsyncIterator[None]:
try:
# We wait for the latest writer to finish writing. We can safely ignore
# any existing readers... as they're readers.
# May raise a `CancelledError` if the `Deferred` wrapping us is
# cancelled. The `Deferred` we are waiting on must not be cancelled,
# since we do not own it.
if curr_writer:
await make_deferred_yieldable(stop_cancellation(curr_writer))
yield
finally:
with PreserveLoggingContext():
@@ -538,7 +541,7 @@ class ReadWriteLock:
return _ctx_manager()
async def write(self, key: str) -> ContextManager:
def write(self, key: str) -> AsyncContextManager:
new_defer: "defer.Deferred[None]" = defer.Deferred()
curr_readers = self.key_to_current_readers.get(key, set())
@@ -549,25 +552,41 @@ class ReadWriteLock:
if curr_writer:
to_wait_on.append(curr_writer)
# We can clear the list of current readers since the new writer waits
# We can clear the list of current readers since `new_defer` waits
# for them to finish.
curr_readers.clear()
self.key_to_current_writer[key] = new_defer
await make_deferred_yieldable(defer.gatherResults(to_wait_on))
@contextmanager
def _ctx_manager() -> Iterator[None]:
@asynccontextmanager
async def _ctx_manager() -> AsyncIterator[None]:
to_wait_on_defer = defer.gatherResults(to_wait_on)
try:
# Wait for all current readers and the latest writer to finish.
# May raise a `CancelledError` if the `Deferred` wrapping us is
# cancelled. The `Deferred`s we are waiting on must not be cancelled,
# since we do not own them.
await make_deferred_yieldable(stop_cancellation(to_wait_on_defer))
yield
finally:
with PreserveLoggingContext():
new_defer.callback(None)
# `self.key_to_current_writer[key]` may be missing if there was another
# writer waiting for us and it completed entirely within the
# `new_defer.callback()` call above.
if self.key_to_current_writer.get(key) == new_defer:
self.key_to_current_writer.pop(key)
def release() -> None:
with PreserveLoggingContext():
new_defer.callback(None)
# `self.key_to_current_writer[key]` may be missing if there was another
# writer waiting for us and it completed entirely within the
# `new_defer.callback()` call above.
if self.key_to_current_writer.get(key) == new_defer:
self.key_to_current_writer.pop(key)
if to_wait_on_defer.called:
release()
else:
# We don't have the lock yet, probably because we were cancelled
# while waiting for it. We can't call `release()` yet, since
# `new_defer` must only resolve once all previous readers and
# writers have finished.
# NB: `release()` won't have a logcontext in this path.
to_wait_on_defer.addCallback(lambda _: release())
return _ctx_manager()
@@ -695,3 +714,59 @@ def stop_cancellation(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]":
new_deferred: defer.Deferred[T] = defer.Deferred()
deferred.chainDeferred(new_deferred)
return new_deferred
def delay_cancellation(deferred: "defer.Deferred[T]", all: bool) -> "defer.Deferred[T]":
"""Delay cancellation of a `Deferred` until it resolves.
Has the same effect as `stop_cancellation`, but the returned `Deferred` will not
resolve with a `CancelledError` until the original `Deferred` resolves.
Args:
deferred: The `Deferred` to protect against cancellation. Must not follow the
Synapse logcontext rules if `all` is `False`.
all: `True` to delay multiple cancellations. `False` to delay only the first
cancellation.
Returns:
A new `Deferred`, which will contain the result of the original `Deferred`.
The new `Deferred` will not propagate cancellation through to the original.
When cancelled, the new `Deferred` will wait until the original `Deferred`
resolves before failing with a `CancelledError`.
The new `Deferred` will only follow the Synapse logcontext rules if `all` is
`True` and `deferred` follows the Synapse logcontext rules. Otherwise the new
`Deferred` should be wrapped with `make_deferred_yieldable`.
"""
def cancel_errback(failure: Failure) -> Union[Failure, "defer.Deferred[T]"]:
"""Insert another `Deferred` into the chain to delay cancellation.
Called when the original `Deferred` resolves or the new `Deferred` is
cancelled.
"""
failure.trap(CancelledError)
if deferred.called and not deferred.paused:
# The `CancelledError` came from the original `Deferred`. Pass it through.
return failure
# Construct another `Deferred` that will only fail with the `CancelledError`
# once the original `Deferred` resolves.
delay_deferred: "defer.Deferred[T]" = defer.Deferred()
deferred.chainDeferred(delay_deferred)
if all:
# Intercept cancellations recursively. Each cancellation will cause another
# `Deferred` to be inserted into the chain.
delay_deferred.addErrback(cancel_errback)
# Override the result with the `CancelledError`.
delay_deferred.addBoth(lambda _: failure)
return delay_deferred
new_deferred: "defer.Deferred[T]" = defer.Deferred()
deferred.chainDeferred(new_deferred)
new_deferred.addErrback(cancel_errback)
return new_deferred

View File

@@ -40,6 +40,7 @@ from twisted.python.failure import Failure
from synapse.logging.context import make_deferred_yieldable, preserve_fn
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import delay_cancellation
from synapse.util.caches.deferred_cache import DeferredCache
from synapse.util.caches.lrucache import LruCache
@@ -322,6 +323,11 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
ret = defer.maybeDeferred(preserve_fn(self.orig), obj, *args, **kwargs)
ret = cache.set(cache_key, ret, callback=invalidate_callback)
# We started a new call to `self.orig`, so we must always wait for it to
# complete. Otherwise we might mark our current logging context as
# finished while `self.orig` is still using it in the background.
ret = delay_cancellation(ret, all=True)
return make_deferred_yieldable(ret)
wrapped = cast(_CachedFunction, _wrapped)
@@ -482,6 +488,11 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
d = defer.gatherResults(cached_defers, consumeErrors=True).addCallbacks(
lambda _: results, unwrapFirstError
)
if missing:
# We started a new call to `self.orig`, so we must always wait for it to
# complete. Otherwise we might mark our current logging context as
# finished while `self.orig` is still using it in the background.
d = delay_cancellation(d, all=True)
return make_deferred_yieldable(d)
else:
return defer.succeed(results)

View File

@@ -12,29 +12,183 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.storage.database import make_tuple_comparison_clause
from synapse.storage.engines import BaseDatabaseEngine
from typing import Callable, NoReturn, Tuple
from unittest.mock import Mock
from twisted.internet import defer
from twisted.internet.defer import CancelledError, Deferred
from twisted.test.proto_helpers import MemoryReactor
from synapse.logging.context import LoggingContext
from synapse.server import HomeServer
from synapse.storage.database import (
DatabasePool,
LoggingTransaction,
make_tuple_comparison_clause,
)
from synapse.util import Clock
from tests import unittest
def _stub_db_engine(**kwargs) -> BaseDatabaseEngine:
# returns a DatabaseEngine, circumventing the abc mechanism
# any kwargs are set as attributes on the class before instantiating it
t = type(
"TestBaseDatabaseEngine",
(BaseDatabaseEngine,),
dict(BaseDatabaseEngine.__dict__),
)
# defeat the abc mechanism
t.__abstractmethods__ = set()
for k, v in kwargs.items():
setattr(t, k, v)
return t(None, None)
class TupleComparisonClauseTestCase(unittest.TestCase):
def test_native_tuple_comparison(self):
clause, args = make_tuple_comparison_clause([("a", 1), ("b", 2)])
self.assertEqual(clause, "(a,b) > (?,?)")
self.assertEqual(args, [1, 2])
class CallbacksTestCase(unittest.HomeserverTestCase):
"""Tests for transaction callbacks."""
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.db_pool: DatabasePool = self.store.db_pool
def _run_interaction(
self, func: Callable[[LoggingTransaction, int], None]
) -> Tuple[Mock, Mock]:
"""Run the given function in a database transaction, with callbacks registered.
Args:
func: The function to be run in a transaction. The transaction will be
retried if `func` raises an `OperationalError`.
Returns:
Two mocks, which were registered as an `after_callback` and an
`exception_callback` respectively, on every transaction attempt.
"""
after_callback = Mock()
exception_callback = Mock()
def _test_txn(txn: LoggingTransaction) -> None:
txn.call_after(after_callback, 123, 456, extra=789)
txn.call_on_exception(exception_callback, 987, 654, extra=321)
func(txn)
try:
self.get_success_or_raise(
self.db_pool.runInteraction("test_transaction", _test_txn)
)
except Exception:
pass
return after_callback, exception_callback
def test_after_callback(self) -> None:
"""Test that the after callback is called when a transaction succeeds."""
after_callback, exception_callback = self._run_interaction(lambda txn: None)
after_callback.assert_called_once_with(123, 456, extra=789)
exception_callback.assert_not_called()
def test_exception_callback(self) -> None:
"""Test that the exception callback is called when a transaction fails."""
after_callback, exception_callback = self._run_interaction(lambda txn: 1 / 0)
after_callback.assert_not_called()
exception_callback.assert_called_once_with(987, 654, extra=321)
def test_failed_retry(self) -> None:
"""Test that the exception callback is called for every failed attempt."""
def _test_txn(txn: LoggingTransaction) -> NoReturn:
"""Simulate a retryable failure on every attempt."""
raise self.db_pool.engine.module.OperationalError()
after_callback, exception_callback = self._run_interaction(_test_txn)
after_callback.assert_not_called()
exception_callback.assert_has_calls(
[
((987, 654), {"extra": 321}),
((987, 654), {"extra": 321}),
((987, 654), {"extra": 321}),
((987, 654), {"extra": 321}),
((987, 654), {"extra": 321}),
((987, 654), {"extra": 321}),
]
)
self.assertEqual(exception_callback.call_count, 6) # no additional calls
def test_successful_retry(self) -> None:
"""Test callbacks for a failed transaction followed by a successful attempt."""
first_attempt = True
def _test_txn(txn: LoggingTransaction) -> None:
"""Simulate a retryable failure on the first attempt only."""
nonlocal first_attempt
if first_attempt:
first_attempt = False
raise self.db_pool.engine.module.OperationalError()
else:
return None
after_callback, exception_callback = self._run_interaction(_test_txn)
# Calling both `after_callback`s when the first attempt failed is rather
# dubious. But let's document the behaviour in a test.
after_callback.assert_has_calls(
[
((123, 456), {"extra": 789}),
((123, 456), {"extra": 789}),
]
)
self.assertEqual(after_callback.call_count, 2) # no additional calls
exception_callback.assert_not_called()
class CancellationTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.db_pool: DatabasePool = self.store.db_pool
def test_after_callback(self) -> None:
"""Test that the after callback is called when a transaction succeeds."""
d: "Deferred[None]"
after_callback = Mock()
exception_callback = Mock()
def _test_txn(txn: LoggingTransaction) -> None:
txn.call_after(after_callback, 123, 456, extra=789)
txn.call_on_exception(exception_callback, 987, 654, extra=321)
d.cancel()
d = defer.ensureDeferred(
self.db_pool.runInteraction("test_transaction", _test_txn)
)
self.get_failure(d, CancelledError)
after_callback.assert_called_once_with(123, 456, extra=789)
exception_callback.assert_not_called()
def test_exception_callback(self) -> None:
"""Test that the exception callback is called when a transaction fails."""
d: "Deferred[None]"
after_callback = Mock()
exception_callback = Mock()
def _test_txn(txn: LoggingTransaction) -> None:
txn.call_after(after_callback, 123, 456, extra=789)
txn.call_on_exception(exception_callback, 987, 654, extra=321)
d.cancel()
# Simulate a retryable failure on every attempt.
raise self.db_pool.engine.module.OperationalError()
d = defer.ensureDeferred(
self.db_pool.runInteraction("test_transaction", _test_txn)
)
self.get_failure(d, CancelledError)
after_callback.assert_not_called()
exception_callback.assert_has_calls(
[
((987, 654), {"extra": 321}),
((987, 654), {"extra": 321}),
((987, 654), {"extra": 321}),
((987, 654), {"extra": 321}),
((987, 654), {"extra": 321}),
((987, 654), {"extra": 321}),
]
)
self.assertEqual(exception_callback.call_count, 6) # no additional calls

View File

@@ -17,7 +17,7 @@ from typing import Set
from unittest import mock
from twisted.internet import defer, reactor
from twisted.internet.defer import Deferred
from twisted.internet.defer import CancelledError, Deferred
from synapse.api.errors import SynapseError
from synapse.logging.context import (
@@ -28,7 +28,7 @@ from synapse.logging.context import (
make_deferred_yieldable,
)
from synapse.util.caches import descriptors
from synapse.util.caches.descriptors import cached, lru_cache
from synapse.util.caches.descriptors import cached, cachedList, lru_cache
from tests import unittest
from tests.test_utils import get_awaitable_result
@@ -415,6 +415,74 @@ class DescriptorTestCase(unittest.TestCase):
obj.invalidate()
top_invalidate.assert_called_once()
def test_cancel(self):
"""Test that cancelling a lookup does not cancel other lookups"""
complete_lookup: "Deferred[None]" = Deferred()
class Cls:
@cached()
async def fn(self, arg1):
await complete_lookup
return str(arg1)
obj = Cls()
d1 = obj.fn(123)
d2 = obj.fn(123)
self.assertFalse(d1.called)
self.assertFalse(d2.called)
# Cancel `d1`, which is the lookup that caused `fn` to run.
d1.cancel()
# `d2` should complete normally.
complete_lookup.callback(None)
self.failureResultOf(d1, CancelledError)
self.assertEqual(d2.result, "123")
def test_cancel_logcontexts(self):
"""Test that cancellation does not break logcontexts.
* The `CancelledError` must be raised with the correct logcontext.
* The inner lookup must not resume with a finished logcontext.
* The inner lookup must not restore a finished logcontext when done.
"""
complete_lookup: "Deferred[None]" = Deferred()
class Cls:
inner_context_was_finished = False
@cached()
async def fn(self, arg1):
await make_deferred_yieldable(complete_lookup)
self.inner_context_was_finished = current_context().finished
return str(arg1)
obj = Cls()
async def do_lookup():
with LoggingContext("c1") as c1:
try:
await obj.fn(123)
self.fail("No CancelledError thrown")
except CancelledError:
self.assertEqual(
current_context(),
c1,
"CancelledError was not raised with the correct logcontext",
)
# suppress the error and succeed
d = defer.ensureDeferred(do_lookup())
d.cancel()
complete_lookup.callback(None)
self.successResultOf(d)
self.assertFalse(
obj.inner_context_was_finished, "Tried to restart a finished logcontext"
)
self.assertEqual(current_context(), SENTINEL_CONTEXT)
class CacheDecoratorTestCase(unittest.HomeserverTestCase):
"""More tests for @cached
@@ -787,3 +855,78 @@ class CachedListDescriptorTestCase(unittest.TestCase):
obj.fn.invalidate((10, 2))
invalidate0.assert_called_once()
invalidate1.assert_called_once()
def test_cancel(self):
"""Test that cancelling a lookup does not cancel other lookups"""
complete_lookup: "Deferred[None]" = Deferred()
class Cls:
@cached()
def fn(self, arg1):
pass
@cachedList("fn", "args")
async def list_fn(self, args):
await complete_lookup
return {arg: str(arg) for arg in args}
obj = Cls()
d1 = obj.list_fn([123, 456])
d2 = obj.list_fn([123, 456, 789])
self.assertFalse(d1.called)
self.assertFalse(d2.called)
d1.cancel()
# `d2` should complete normally.
complete_lookup.callback(None)
self.failureResultOf(d1, CancelledError)
self.assertEqual(d2.result, {123: "123", 456: "456", 789: "789"})
def test_cancel_logcontexts(self):
"""Test that cancellation does not break logcontexts.
* The `CancelledError` must be raised with the correct logcontext.
* The inner lookup must not resume with a finished logcontext.
* The inner lookup must not restore a finished logcontext when done.
"""
complete_lookup: "Deferred[None]" = Deferred()
class Cls:
inner_context_was_finished = False
@cached()
def fn(self, arg1):
pass
@cachedList("fn", "args")
async def list_fn(self, args):
await make_deferred_yieldable(complete_lookup)
self.inner_context_was_finished = current_context().finished
return {arg: str(arg) for arg in args}
obj = Cls()
async def do_lookup():
with LoggingContext("c1") as c1:
try:
await obj.list_fn([123])
self.fail("No CancelledError thrown")
except CancelledError:
self.assertEqual(
current_context(),
c1,
"CancelledError was not raised with the correct logcontext",
)
# suppress the error and succeed
d = defer.ensureDeferred(do_lookup())
d.cancel()
complete_lookup.callback(None)
self.successResultOf(d)
self.assertFalse(
obj.inner_context_was_finished, "Tried to restart a finished logcontext"
)
self.assertEqual(current_context(), SENTINEL_CONTEXT)

View File

@@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import AsyncContextManager, Callable, Tuple
from twisted.internet import defer
from twisted.internet.defer import Deferred
from twisted.internet.defer import CancelledError, Deferred
from synapse.util.async_helpers import ReadWriteLock
@@ -32,76 +34,120 @@ class ReadWriteLockTestCase(unittest.TestCase):
def test_rwlock(self):
rwlock = ReadWriteLock()
key = "key"
key = object()
def start_reader_or_writer(
read_or_write: Callable[[str], AsyncContextManager]
) -> Tuple["Deferred[None]", "Deferred[None]"]:
acquired_d: "Deferred[None]" = Deferred()
release_d: "Deferred[None]" = Deferred()
async def action():
async with read_or_write(key):
acquired_d.callback(None)
await release_d
defer.ensureDeferred(action())
return acquired_d, release_d
ds = [
rwlock.read(key), # 0
rwlock.read(key), # 1
rwlock.write(key), # 2
rwlock.write(key), # 3
rwlock.read(key), # 4
rwlock.read(key), # 5
rwlock.write(key), # 6
start_reader_or_writer(rwlock.read), # 0
start_reader_or_writer(rwlock.read), # 1
start_reader_or_writer(rwlock.write), # 2
start_reader_or_writer(rwlock.write), # 3
start_reader_or_writer(rwlock.read), # 4
start_reader_or_writer(rwlock.read), # 5
start_reader_or_writer(rwlock.write), # 6
]
ds = [defer.ensureDeferred(d) for d in ds]
# `Deferred`s that resolve when each reader or writer acquires the lock.
acquired_ds = [acquired_d for acquired_d, _release_d in ds]
# `Deferred`s that will trigger the release of locks when resolved.
release_ds = [release_d for _acquired_d, release_d in ds]
self._assert_called_before_not_after(ds, 2)
self._assert_called_before_not_after(acquired_ds, 2)
with ds[0].result:
self._assert_called_before_not_after(ds, 2)
self._assert_called_before_not_after(ds, 2)
self._assert_called_before_not_after(acquired_ds, 2)
release_ds[0].callback(None)
self._assert_called_before_not_after(acquired_ds, 2)
with ds[1].result:
self._assert_called_before_not_after(ds, 2)
self._assert_called_before_not_after(ds, 3)
self._assert_called_before_not_after(acquired_ds, 2)
release_ds[1].callback(None)
self._assert_called_before_not_after(acquired_ds, 3)
with ds[2].result:
self._assert_called_before_not_after(ds, 3)
self._assert_called_before_not_after(ds, 4)
self._assert_called_before_not_after(acquired_ds, 3)
release_ds[2].callback(None)
self._assert_called_before_not_after(acquired_ds, 4)
with ds[3].result:
self._assert_called_before_not_after(ds, 4)
self._assert_called_before_not_after(ds, 6)
self._assert_called_before_not_after(acquired_ds, 4)
release_ds[3].callback(None)
self._assert_called_before_not_after(acquired_ds, 6)
with ds[5].result:
self._assert_called_before_not_after(ds, 6)
self._assert_called_before_not_after(ds, 6)
self._assert_called_before_not_after(acquired_ds, 6)
release_ds[5].callback(None)
self._assert_called_before_not_after(acquired_ds, 6)
with ds[4].result:
self._assert_called_before_not_after(ds, 6)
self._assert_called_before_not_after(ds, 7)
self._assert_called_before_not_after(acquired_ds, 6)
release_ds[4].callback(None)
self._assert_called_before_not_after(acquired_ds, 7)
with ds[6].result:
pass
release_ds[6].callback(None)
d = defer.ensureDeferred(rwlock.write(key))
self.assertTrue(d.called)
with d.result:
pass
acquired_d, release_d = start_reader_or_writer(rwlock.write)
self.assertTrue(acquired_d.called)
release_d.callback(None)
d = defer.ensureDeferred(rwlock.read(key))
self.assertTrue(d.called)
with d.result:
pass
acquired_d, release_d = start_reader_or_writer(rwlock.read)
self.assertTrue(acquired_d.called)
release_d.callback(None)
def _start_reader_or_writer(
self,
read_or_write: Callable[[str], AsyncContextManager],
key: str,
name: str,
) -> Tuple["Deferred[None]", "Deferred[None]"]:
"""Starts a reader or writer which acquires the lock, blocks, then completes."""
unblock_d: "Deferred[None]" = Deferred()
async def reader_or_writer():
async with read_or_write(key):
await unblock_d
return f"{name} completed"
d = defer.ensureDeferred(reader_or_writer())
return d, unblock_d
def _start_blocking_reader(
self, rwlock: ReadWriteLock, key: str, name: str
) -> Tuple["Deferred[None]", "Deferred[None]"]:
"""Starts a reader which acquires the lock, blocks, then releases the lock."""
return self._start_reader_or_writer(rwlock.read, key, name)
def _start_blocking_writer(
self, rwlock: ReadWriteLock, key: str, name: str
) -> Tuple["Deferred[None]", "Deferred[None]"]:
"""Starts a writer which acquires the lock, blocks, then releases the lock."""
return self._start_reader_or_writer(rwlock.write, key, name)
def _start_nonblocking_reader(self, rwlock: ReadWriteLock, key: str, name: str):
"""Starts a reader which acquires the lock, then releases it immediately."""
d, unblock_d = self._start_reader_or_writer(rwlock.read, key, name)
unblock_d.callback(None)
return d
def _start_nonblocking_writer(self, rwlock: ReadWriteLock, key: str, name: str):
"""Starts a writer which acquires the lock, then releases it immediately."""
d, unblock_d = self._start_reader_or_writer(rwlock.write, key, name)
unblock_d.callback(None)
return d
def test_lock_handoff_to_nonblocking_writer(self):
"""Test a writer handing the lock to another writer that completes instantly."""
rwlock = ReadWriteLock()
key = "key"
unblock: "Deferred[None]" = Deferred()
async def blocking_write():
with await rwlock.write(key):
await unblock
async def nonblocking_write():
with await rwlock.write(key):
pass
d1 = defer.ensureDeferred(blocking_write())
d2 = defer.ensureDeferred(nonblocking_write())
d1, unblock = self._start_blocking_writer(rwlock, key, "write 1")
d2 = self._start_nonblocking_writer(rwlock, key, "write 2")
self.assertFalse(d1.called)
self.assertFalse(d2.called)
@@ -111,5 +157,172 @@ class ReadWriteLockTestCase(unittest.TestCase):
self.assertTrue(d2.called)
# The `ReadWriteLock` should operate as normal.
d3 = defer.ensureDeferred(nonblocking_write())
d3 = self._start_nonblocking_writer(rwlock, key, "write 3")
self.assertTrue(d3.called)
def test_cancellation_while_holding_read_lock(self):
"""Test cancellation while holding a read lock.
A waiting writer should be given the lock when the reader holding the lock is
cancelled.
"""
rwlock = ReadWriteLock()
key = "key"
# 1. A reader takes the lock and blocks.
reader_d, _ = self._start_blocking_reader(rwlock, key, "read")
# 2. A writer waits for the reader to complete.
writer_d = self._start_nonblocking_writer(rwlock, key, "write")
self.assertFalse(writer_d.called)
# 3. The reader is cancelled.
reader_d.cancel()
self.failureResultOf(reader_d, CancelledError)
# 4. The writer should take the lock and complete.
self.assertTrue(
writer_d.called, "Writer is stuck waiting for a cancelled reader"
)
self.assertEqual("write completed", self.successResultOf(writer_d))
def test_cancellation_while_holding_write_lock(self):
"""Test cancellation while holding a write lock.
A waiting reader should be given the lock when the writer holding the lock is
cancelled.
"""
rwlock = ReadWriteLock()
key = "key"
# 1. A writer takes the lock and blocks.
writer_d, _ = self._start_blocking_writer(rwlock, key, "write")
# 2. A reader waits for the writer to complete.
reader_d = self._start_nonblocking_reader(rwlock, key, "read")
self.assertFalse(reader_d.called)
# 3. The writer is cancelled.
writer_d.cancel()
self.failureResultOf(writer_d, CancelledError)
# 4. The reader should take the lock and complete.
self.assertTrue(
reader_d.called, "Reader is stuck waiting for a cancelled writer"
)
self.assertEqual("read completed", self.successResultOf(reader_d))
def test_cancellation_while_waiting_for_read_lock(self):
"""Test cancellation while waiting for a read lock.
Tests that cancelling a waiting reader:
* does not cancel the writer it is waiting on
* does not cancel the next writer waiting on it
* does not allow the next writer to acquire the lock before an earlier writer
has finished
* does not keep the next writer waiting indefinitely
These correspond to the asserts with explicit messages.
"""
rwlock = ReadWriteLock()
key = "key"
# 1. A writer takes the lock and blocks.
writer1_d, unblock_writer1 = self._start_blocking_writer(rwlock, key, "write 1")
# 2. A reader waits for the first writer to complete.
# This reader will be cancelled later.
reader_d = self._start_nonblocking_reader(rwlock, key, "read")
self.assertFalse(reader_d.called)
# 3. A second writer waits for both the first writer and the reader to complete.
writer2_d = self._start_nonblocking_writer(rwlock, key, "write 2")
self.assertFalse(writer2_d.called)
# 4. The waiting reader is cancelled.
# Neither of the writers should be cancelled.
# The second writer should still be waiting, but only on the first writer.
reader_d.cancel()
self.failureResultOf(reader_d, CancelledError)
self.assertFalse(writer1_d.called, "First writer was unexpectedly cancelled")
self.assertFalse(
writer2_d.called,
"Second writer was unexpectedly cancelled or given the lock before the "
"first writer finished",
)
# 5. Unblock the first writer, which should complete.
unblock_writer1.callback(None)
self.assertEqual("write 1 completed", self.successResultOf(writer1_d))
# 6. The second writer should take the lock and complete.
self.assertTrue(
writer2_d.called, "Second writer is stuck waiting for a cancelled reader"
)
self.assertEqual("write 2 completed", self.successResultOf(writer2_d))
def test_cancellation_while_waiting_for_write_lock(self):
"""Test cancellation while waiting for a write lock.
Tests that cancelling a waiting writer:
* does not cancel the reader or writer it is waiting on
* does not cancel the next writer waiting on it
* does not allow the next writer to acquire the lock before an earlier reader
and writer have finished
* does not keep the next writer waiting indefinitely
These correspond to the asserts with explicit messages.
"""
rwlock = ReadWriteLock()
key = "key"
# 1. A reader takes the lock and blocks.
reader_d, unblock_reader = self._start_blocking_reader(rwlock, key, "read")
# 2. A writer waits for the reader to complete.
writer1_d, unblock_writer1 = self._start_blocking_writer(rwlock, key, "write 1")
# 3. A second writer waits for both the reader and first writer to complete.
# This writer will be cancelled later.
writer2_d = self._start_nonblocking_writer(rwlock, key, "write 2")
self.assertFalse(writer2_d.called)
# 4. A third writer waits for the second writer to complete.
writer3_d = self._start_nonblocking_writer(rwlock, key, "write 3")
self.assertFalse(writer3_d.called)
# 5. The second writer is cancelled.
# The reader, first writer and third writer should not be cancelled.
# The first writer should still be waiting on the reader.
# The third writer should still be waiting, even though the second writer has
# been cancelled.
writer2_d.cancel()
self.failureResultOf(writer2_d, CancelledError)
self.assertFalse(reader_d.called, "Reader was unexpectedly cancelled")
self.assertFalse(writer1_d.called, "First writer was unexpectedly cancelled")
self.assertFalse(
writer3_d.called,
"Third writer was unexpectedly cancelled or given the lock before the first"
"writer finished",
)
# 6. Unblock the reader, which should complete.
# The first writer should be given the lock and block.
# The third writer should still be waiting.
unblock_reader.callback(None)
self.assertEqual("read completed", self.successResultOf(reader_d))
self.assertFalse(
writer3_d.called,
"Third writer was unexpectedly given the lock before the first writer "
"finished",
)
# 7. Unblock the first writer, which should complete.
unblock_writer1.callback(None)
self.assertEqual("write 1 completed", self.successResultOf(writer1_d))
# 8. The third writer should take the lock and complete.
self.assertTrue(
writer3_d.called, "Third writer is stuck waiting for a cancelled writer"
)
self.assertEqual("write 3 completed", self.successResultOf(writer3_d))