Add a AwakenableSleeper class
This commit is contained in:
@@ -734,3 +734,60 @@ def delay_cancellation(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]":
|
||||
new_deferred: "defer.Deferred[T]" = defer.Deferred(handle_cancel)
|
||||
deferred.chainDeferred(new_deferred)
|
||||
return new_deferred
|
||||
|
||||
|
||||
class AwakenableSleeper:
|
||||
"""Allows explicitly waking up deferreds related to an entity that are
|
||||
currently sleeping.
|
||||
"""
|
||||
|
||||
def __init__(self, reactor: IReactorTime) -> None:
|
||||
self._streams: Dict[str, Set[defer.Deferred[None]]] = {}
|
||||
self._reactor = reactor
|
||||
|
||||
def wake(self, name: str) -> None:
|
||||
"""Wake everything related to `name` that is currently sleeping."""
|
||||
stream_set = self._streams.pop(name, set())
|
||||
for deferred in set(stream_set):
|
||||
try:
|
||||
with PreserveLoggingContext():
|
||||
deferred.callback(None)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def sleep(self, name: str, delay_ms: int) -> None:
|
||||
"""Sleep for the given number of milliseconds, or return if the given
|
||||
`name` is explicitly woken up.
|
||||
"""
|
||||
|
||||
# Create a deferred that gets called in N seconds
|
||||
sleep_deferred: "defer.Deferred[None]" = defer.Deferred()
|
||||
call = self._reactor.callLater(delay_ms / 1000, sleep_deferred.callback, None)
|
||||
|
||||
# Create a deferred that will get called if `wake` is called with
|
||||
# the same `name`.
|
||||
stream_set = self._streams.setdefault(name, set())
|
||||
notify_deferred: "defer.Deferred[None]" = defer.Deferred()
|
||||
stream_set.add(notify_deferred)
|
||||
|
||||
try:
|
||||
# Wait for either the delay or for `wake` to be called.
|
||||
await make_deferred_yieldable(
|
||||
defer.DeferredList(
|
||||
[sleep_deferred, notify_deferred],
|
||||
fireOnOneCallback=True,
|
||||
fireOnOneErrback=True,
|
||||
consumeErrors=True,
|
||||
)
|
||||
)
|
||||
finally:
|
||||
# Clean up the state
|
||||
stream_set.discard(notify_deferred)
|
||||
|
||||
curr_stream_set = self._streams.get(name)
|
||||
if curr_stream_set is not None and len(curr_stream_set) == 0:
|
||||
self._streams.pop(name)
|
||||
|
||||
# Cancel the sleep if we were woken up
|
||||
if call.active():
|
||||
call.cancel()
|
||||
|
||||
@@ -28,6 +28,7 @@ from synapse.logging.context import (
|
||||
make_deferred_yieldable,
|
||||
)
|
||||
from synapse.util.async_helpers import (
|
||||
AwakenableSleeper,
|
||||
ObservableDeferred,
|
||||
concurrently_execute,
|
||||
delay_cancellation,
|
||||
@@ -35,6 +36,7 @@ from synapse.util.async_helpers import (
|
||||
timeout_deferred,
|
||||
)
|
||||
|
||||
from tests.server import get_clock
|
||||
from tests.unittest import TestCase
|
||||
|
||||
|
||||
@@ -467,3 +469,39 @@ class DelayCancellationTests(TestCase):
|
||||
# logging context.
|
||||
blocking_d.callback(None)
|
||||
self.successResultOf(d)
|
||||
|
||||
|
||||
class AwakenableSleeperTests(TestCase):
|
||||
"Tests AwakenableSleeper"
|
||||
|
||||
def test_sleep(self):
|
||||
reactor, _ = get_clock()
|
||||
sleeper = AwakenableSleeper(reactor)
|
||||
|
||||
d = defer.ensureDeferred(sleeper.sleep("name", 1000))
|
||||
|
||||
reactor.pump([0.0])
|
||||
self.assertFalse(d.called)
|
||||
|
||||
reactor.advance(0.5)
|
||||
self.assertFalse(d.called)
|
||||
|
||||
reactor.advance(0.6)
|
||||
self.assertTrue(d.called)
|
||||
|
||||
def test_explicit_wake(self):
|
||||
reactor, _ = get_clock()
|
||||
sleeper = AwakenableSleeper(reactor)
|
||||
|
||||
d = defer.ensureDeferred(sleeper.sleep("name", 1000))
|
||||
|
||||
reactor.pump([0.0])
|
||||
self.assertFalse(d.called)
|
||||
|
||||
reactor.advance(0.5)
|
||||
self.assertFalse(d.called)
|
||||
|
||||
sleeper.wake("name")
|
||||
self.assertTrue(d.called)
|
||||
|
||||
reactor.advance(0.6)
|
||||
|
||||
Reference in New Issue
Block a user