Allow long lived syncs to be cancelled if client has gone away (#19499)
This commit is contained in:
1
changelog.d/19499.misc
Normal file
1
changelog.d/19499.misc
Normal file
@@ -0,0 +1 @@
|
||||
Cancel long-running sync requests if the client has gone away.
|
||||
@@ -411,6 +411,12 @@ indent-style = "space"
|
||||
skip-magic-trailing-comma = false
|
||||
line-ending = "auto"
|
||||
|
||||
[tool.ruff.lint.flake8-bugbear]
|
||||
extend-immutable-calls = [
|
||||
# Durations are immutable
|
||||
"synapse.util.duration.Duration",
|
||||
]
|
||||
|
||||
[tool.maturin]
|
||||
manifest-path = "rust/Cargo.toml"
|
||||
module-name = "synapse.synapse_rust"
|
||||
|
||||
@@ -45,6 +45,7 @@ from synapse.synapse_rust.http_client import HttpClient
|
||||
from synapse.types import JsonDict, Requester, UserID, create_requester
|
||||
from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
|
||||
from synapse.util.caches.response_cache import ResponseCache, ResponseCacheContext
|
||||
from synapse.util.duration import Duration
|
||||
from synapse.util.json import json_decoder
|
||||
|
||||
from . import introspection_response_timer
|
||||
@@ -139,7 +140,7 @@ class MasDelegatedAuth(BaseAuth):
|
||||
clock=self._clock,
|
||||
name="mas_token_introspection",
|
||||
server_name=self.server_name,
|
||||
timeout_ms=120_000,
|
||||
timeout=Duration(minutes=2),
|
||||
# don't log because the keys are access tokens
|
||||
enable_logging=False,
|
||||
)
|
||||
|
||||
@@ -49,6 +49,7 @@ from synapse.synapse_rust.http_client import HttpClient
|
||||
from synapse.types import Requester, UserID, create_requester
|
||||
from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
|
||||
from synapse.util.caches.response_cache import ResponseCache, ResponseCacheContext
|
||||
from synapse.util.duration import Duration
|
||||
from synapse.util.json import json_decoder
|
||||
|
||||
from . import introspection_response_timer
|
||||
@@ -205,7 +206,7 @@ class MSC3861DelegatedAuth(BaseAuth):
|
||||
clock=self._clock,
|
||||
name="token_introspection",
|
||||
server_name=self.server_name,
|
||||
timeout_ms=120_000,
|
||||
timeout=Duration(minutes=2),
|
||||
# don't log because the keys are access tokens
|
||||
enable_logging=False,
|
||||
)
|
||||
|
||||
@@ -46,6 +46,7 @@ from synapse.logging import opentracing
|
||||
from synapse.metrics import SERVER_NAME_LABEL
|
||||
from synapse.types import DeviceListUpdates, JsonDict, JsonMapping, ThirdPartyInstanceID
|
||||
from synapse.util.caches.response_cache import ResponseCache
|
||||
from synapse.util.duration import Duration
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
@@ -132,7 +133,7 @@ class ApplicationServiceApi(SimpleHttpClient):
|
||||
clock=hs.get_clock(),
|
||||
name="as_protocol_meta",
|
||||
server_name=self.server_name,
|
||||
timeout_ms=HOUR_IN_MS,
|
||||
timeout=Duration(hours=1),
|
||||
)
|
||||
|
||||
def _get_headers(self, service: "ApplicationService") -> dict[bytes, list[bytes]]:
|
||||
|
||||
@@ -29,6 +29,7 @@ import attr
|
||||
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util.check_dependencies import check_requirements
|
||||
from synapse.util.duration import Duration
|
||||
|
||||
from ._base import Config, ConfigError
|
||||
|
||||
@@ -108,7 +109,7 @@ class CacheConfig(Config):
|
||||
global_factor: float
|
||||
track_memory_usage: bool
|
||||
expiry_time_msec: int | None
|
||||
sync_response_cache_duration: int
|
||||
sync_response_cache_duration: Duration
|
||||
|
||||
@staticmethod
|
||||
def reset() -> None:
|
||||
@@ -207,10 +208,14 @@ class CacheConfig(Config):
|
||||
min_cache_ttl = self.cache_autotuning.get("min_cache_ttl")
|
||||
self.cache_autotuning["min_cache_ttl"] = self.parse_duration(min_cache_ttl)
|
||||
|
||||
self.sync_response_cache_duration = self.parse_duration(
|
||||
sync_response_cache_duration_ms = self.parse_duration(
|
||||
cache_config.get("sync_response_cache_duration", "2m")
|
||||
)
|
||||
|
||||
self.sync_response_cache_duration = Duration(
|
||||
milliseconds=sync_response_cache_duration_ms
|
||||
)
|
||||
|
||||
def resize_all_caches(self) -> None:
|
||||
"""Ensure all cache sizes are up-to-date.
|
||||
|
||||
|
||||
@@ -166,7 +166,7 @@ class FederationServer(FederationBase):
|
||||
clock=hs.get_clock(),
|
||||
name="fed_txn_handler",
|
||||
server_name=self.server_name,
|
||||
timeout_ms=30000,
|
||||
timeout=Duration(seconds=30),
|
||||
)
|
||||
|
||||
self.transaction_actions = TransactionActions(self.store)
|
||||
@@ -179,13 +179,13 @@ class FederationServer(FederationBase):
|
||||
clock=hs.get_clock(),
|
||||
name="state_resp",
|
||||
server_name=self.server_name,
|
||||
timeout_ms=30000,
|
||||
timeout=Duration(seconds=30),
|
||||
)
|
||||
self._state_ids_resp_cache: ResponseCache[tuple[str, str]] = ResponseCache(
|
||||
clock=hs.get_clock(),
|
||||
name="state_ids_resp",
|
||||
server_name=self.server_name,
|
||||
timeout_ms=30000,
|
||||
timeout=Duration(seconds=30),
|
||||
)
|
||||
|
||||
self._federation_metrics_domains = (
|
||||
|
||||
@@ -185,7 +185,7 @@ class RoomCreationHandler:
|
||||
clock=hs.get_clock(),
|
||||
name="room_upgrade",
|
||||
server_name=self.server_name,
|
||||
timeout_ms=FIVE_MINUTES_IN_MS,
|
||||
timeout=Duration(minutes=5),
|
||||
)
|
||||
self._server_notices_mxid = hs.config.servernotices.server_notices_mxid
|
||||
|
||||
|
||||
@@ -44,6 +44,7 @@ from synapse.storage.databases.main.room import LargestRoomStats
|
||||
from synapse.types import JsonDict, JsonMapping, ThirdPartyInstanceID
|
||||
from synapse.util.caches.descriptors import _CacheContext, cached
|
||||
from synapse.util.caches.response_cache import ResponseCache
|
||||
from synapse.util.duration import Duration
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
@@ -79,7 +80,7 @@ class RoomListHandler:
|
||||
clock=hs.get_clock(),
|
||||
name="remote_room_list",
|
||||
server_name=self.server_name,
|
||||
timeout_ms=30 * 1000,
|
||||
timeout=Duration(seconds=30),
|
||||
)
|
||||
|
||||
async def get_local_public_room_list(
|
||||
|
||||
@@ -77,6 +77,7 @@ from synapse.util.async_helpers import concurrently_execute
|
||||
from synapse.util.caches.expiringcache import ExpiringCache
|
||||
from synapse.util.caches.lrucache import LruCache
|
||||
from synapse.util.caches.response_cache import ResponseCache, ResponseCacheContext
|
||||
from synapse.util.cancellation import cancellable
|
||||
from synapse.util.metrics import Measure
|
||||
from synapse.visibility import filter_and_transform_events_for_client
|
||||
|
||||
@@ -307,7 +308,7 @@ class SyncHandler:
|
||||
clock=hs.get_clock(),
|
||||
name="sync",
|
||||
server_name=self.server_name,
|
||||
timeout_ms=hs.config.caches.sync_response_cache_duration,
|
||||
timeout=hs.config.caches.sync_response_cache_duration,
|
||||
)
|
||||
|
||||
# ExpiringCache((User, Device)) -> LruCache(user_id => event_id)
|
||||
@@ -367,6 +368,10 @@ class SyncHandler:
|
||||
logger.debug("Returning sync response for %s", user_id)
|
||||
return res
|
||||
|
||||
# TODO: We mark this as cancellable, and we have tests for it, but we
|
||||
# haven't gone through and exhaustively checked that all the code paths in
|
||||
# this method are actually cancellable.
|
||||
@cancellable
|
||||
async def _wait_for_sync_for_user(
|
||||
self,
|
||||
sync_config: SyncConfig,
|
||||
|
||||
@@ -102,7 +102,7 @@ class FollowerTypingHandler:
|
||||
self._room_typing: dict[str, set[str]] = {}
|
||||
|
||||
self._member_last_federation_poke: dict[RoomMember, int] = {}
|
||||
self.wheel_timer: WheelTimer[RoomMember] = WheelTimer(bucket_size=5000)
|
||||
self.wheel_timer: WheelTimer[RoomMember] = WheelTimer()
|
||||
self._latest_room_serial = 0
|
||||
|
||||
self._rooms_updated: set[str] = set()
|
||||
@@ -120,7 +120,7 @@ class FollowerTypingHandler:
|
||||
self._rooms_updated = set()
|
||||
|
||||
self._member_last_federation_poke = {}
|
||||
self.wheel_timer = WheelTimer(bucket_size=5000)
|
||||
self.wheel_timer = WheelTimer()
|
||||
|
||||
@wrap_as_background_process("typing._handle_timeouts")
|
||||
async def _handle_timeouts(self) -> None:
|
||||
|
||||
@@ -130,7 +130,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
|
||||
clock=hs.get_clock(),
|
||||
name="repl." + self.NAME,
|
||||
server_name=self.server_name,
|
||||
timeout_ms=30 * 60 * 1000,
|
||||
timeout=Duration(minutes=30),
|
||||
)
|
||||
|
||||
# We reserve `instance_name` as a parameter to sending requests, so we
|
||||
|
||||
@@ -59,6 +59,7 @@ from synapse.rest.admin.experimental_features import ExperimentalFeature
|
||||
from synapse.types import JsonDict, Requester, SlidingSyncStreamToken, StreamToken
|
||||
from synapse.types.rest.client import SlidingSyncBody
|
||||
from synapse.util.caches.lrucache import LruCache
|
||||
from synapse.util.cancellation import cancellable
|
||||
from synapse.util.json import json_decoder
|
||||
|
||||
from ._base import client_patterns, set_timeline_upper_limit
|
||||
@@ -138,6 +139,7 @@ class SyncRestServlet(RestServlet):
|
||||
cfg=hs.config.ratelimiting.rc_presence_per_user,
|
||||
)
|
||||
|
||||
@cancellable
|
||||
async def on_GET(self, request: SynapseRequest) -> tuple[int, JsonDict]:
|
||||
# This will always be set by the time Twisted calls us.
|
||||
assert request.args is not None
|
||||
|
||||
@@ -57,6 +57,7 @@ from synapse.logging.context import (
|
||||
run_coroutine_in_background,
|
||||
run_in_background,
|
||||
)
|
||||
from synapse.util.cancellation import cancellable
|
||||
from synapse.util.clock import Clock
|
||||
from synapse.util.duration import Duration
|
||||
|
||||
@@ -83,6 +84,13 @@ class AbstractObservableDeferred(Generic[_T], metaclass=abc.ABCMeta):
|
||||
"""
|
||||
...
|
||||
|
||||
@abc.abstractmethod
|
||||
def has_observers(self) -> bool:
|
||||
"""Returns True if there are any observers currently observing this
|
||||
ObservableDeferred.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class ObservableDeferred(Generic[_T], AbstractObservableDeferred[_T]):
|
||||
"""Wraps a deferred object so that we can add observer deferreds. These
|
||||
@@ -122,6 +130,11 @@ class ObservableDeferred(Generic[_T], AbstractObservableDeferred[_T]):
|
||||
for observer in observers:
|
||||
try:
|
||||
observer.callback(r)
|
||||
except defer.CancelledError:
|
||||
# We do not want to propagate cancellations to the original
|
||||
# deferred, or to other observers, so we can just ignore
|
||||
# this.
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"%r threw an exception on .callback(%r), ignoring...",
|
||||
@@ -145,6 +158,11 @@ class ObservableDeferred(Generic[_T], AbstractObservableDeferred[_T]):
|
||||
f.value.__failure__ = f
|
||||
try:
|
||||
observer.errback(f)
|
||||
except defer.CancelledError:
|
||||
# We do not want to propagate cancellations to the original
|
||||
# deferred, or to other observers, so we can just ignore
|
||||
# this.
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"%r threw an exception on .errback(%r), ignoring...",
|
||||
@@ -160,6 +178,7 @@ class ObservableDeferred(Generic[_T], AbstractObservableDeferred[_T]):
|
||||
|
||||
deferred.addCallbacks(callback, errback)
|
||||
|
||||
@cancellable
|
||||
def observe(self) -> "defer.Deferred[_T]":
|
||||
"""Observe the underlying deferred.
|
||||
|
||||
@@ -169,7 +188,7 @@ class ObservableDeferred(Generic[_T], AbstractObservableDeferred[_T]):
|
||||
"""
|
||||
if not self._result:
|
||||
assert isinstance(self._observers, list)
|
||||
d: "defer.Deferred[_T]" = defer.Deferred()
|
||||
d: "defer.Deferred[_T]" = defer.Deferred(canceller=self._remove_observer)
|
||||
self._observers.append(d)
|
||||
return d
|
||||
elif self._result[0]:
|
||||
@@ -180,6 +199,12 @@ class ObservableDeferred(Generic[_T], AbstractObservableDeferred[_T]):
|
||||
def observers(self) -> "Collection[defer.Deferred[_T]]":
|
||||
return self._observers
|
||||
|
||||
def has_observers(self) -> bool:
|
||||
"""Returns True if there are any observers currently observing this
|
||||
ObservableDeferred.
|
||||
"""
|
||||
return bool(self._observers)
|
||||
|
||||
def has_called(self) -> bool:
|
||||
return self._result is not None
|
||||
|
||||
@@ -204,6 +229,28 @@ class ObservableDeferred(Generic[_T], AbstractObservableDeferred[_T]):
|
||||
self._deferred,
|
||||
)
|
||||
|
||||
def _remove_observer(self, observer: "defer.Deferred[_T]") -> None:
|
||||
"""Removes an observer from the list of observers.
|
||||
|
||||
Used as a canceller for the observer deferreds, so that if an observer
|
||||
is cancelled it is removed from the list of observers.
|
||||
"""
|
||||
if self._result is not None:
|
||||
# The underlying deferred has already resolved, so the observer has
|
||||
# already been resolved. Nothing to do.
|
||||
return
|
||||
|
||||
assert isinstance(self._observers, list)
|
||||
try:
|
||||
self._observers.remove(observer)
|
||||
except ValueError:
|
||||
# The observer was not in the list. This can happen if the underlying
|
||||
# deferred resolves at around the same time as we try to remove the
|
||||
# observer. In this case, it's possible that we tried to remove the
|
||||
# observer just after it was added to the list, but before it was
|
||||
# resolved and removed from the list by the callback/errback above.
|
||||
pass
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
@@ -962,6 +1009,27 @@ def delay_cancellation(awaitable: Awaitable[T]) -> Awaitable[T]:
|
||||
return new_deferred
|
||||
|
||||
|
||||
def observe_deferred(d: "defer.Deferred[T]") -> "defer.Deferred[T]":
|
||||
"""Returns a new `Deferred` that observes the given `Deferred`.
|
||||
|
||||
The returned `Deferred` will resolve with the same result as the given
|
||||
`Deferred`, but will not "chain" on the deferred so that using the returned
|
||||
deferred does not affect the given `Deferred` in any way.
|
||||
"""
|
||||
new_deferred: "defer.Deferred[T]" = defer.Deferred()
|
||||
|
||||
def callback(r: T) -> T:
|
||||
new_deferred.callback(r)
|
||||
return r
|
||||
|
||||
def errback(f: Failure) -> Failure:
|
||||
new_deferred.errback(f)
|
||||
return f
|
||||
|
||||
d.addCallbacks(callback, errback)
|
||||
return new_deferred
|
||||
|
||||
|
||||
class AwakenableSleeper:
|
||||
"""Allows explicitly waking up deferreds related to an entity that are
|
||||
currently sleeping.
|
||||
|
||||
@@ -39,10 +39,15 @@ from synapse.logging.opentracing import (
|
||||
start_active_span,
|
||||
start_active_span_follows_from,
|
||||
)
|
||||
from synapse.util.async_helpers import AbstractObservableDeferred, ObservableDeferred
|
||||
from synapse.util.async_helpers import (
|
||||
ObservableDeferred,
|
||||
delay_cancellation,
|
||||
)
|
||||
from synapse.util.caches import EvictionReason, register_cache
|
||||
from synapse.util.cancellation import cancellable, is_function_cancellable
|
||||
from synapse.util.clock import Clock
|
||||
from synapse.util.duration import Duration
|
||||
from synapse.util.wheel_timer import WheelTimer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -79,8 +84,8 @@ class ResponseCacheContext(Generic[KV]):
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class ResponseCacheEntry:
|
||||
result: AbstractObservableDeferred
|
||||
class ResponseCacheEntry(Generic[KV]):
|
||||
result: ObservableDeferred[KV]
|
||||
"""The (possibly incomplete) result of the operation.
|
||||
|
||||
Note that we continue to store an ObservableDeferred even after the operation
|
||||
@@ -91,6 +96,15 @@ class ResponseCacheEntry:
|
||||
opentracing_span_context: "opentracing.SpanContext | None"
|
||||
"""The opentracing span which generated/is generating the result"""
|
||||
|
||||
cancellable: bool
|
||||
"""Whether the deferred is safe to be cancelled."""
|
||||
|
||||
last_observer_removed_time_ms: int | None = None
|
||||
"""The last time that an observer was removed from this entry.
|
||||
|
||||
Used to determine when to evict the entry if it has no observers.
|
||||
"""
|
||||
|
||||
|
||||
class ResponseCache(Generic[KV]):
|
||||
"""
|
||||
@@ -98,6 +112,22 @@ class ResponseCache(Generic[KV]):
|
||||
returned from the cache. This means that if the client retries the request
|
||||
while the response is still being computed, that original response will be
|
||||
used rather than trying to compute a new response.
|
||||
|
||||
If a timeout is not specified then the cache entry will be kept while the
|
||||
wrapped function is still running, and will be removed immediately once it
|
||||
completes.
|
||||
|
||||
If a timeout is specified then the cache entry will be kept for the duration
|
||||
of the timeout after the wrapped function completes. If the wrapped function
|
||||
is cancellable and during processing nothing waits on the result for longer
|
||||
than the timeout then the wrapped function will be cancelled and the cache
|
||||
entry will be removed.
|
||||
|
||||
This behaviour is useful for caching responses to requests which are
|
||||
expensive to compute, but which may be retried by clients if they time out.
|
||||
For example, /sync requests which may take a long time to compute, and which
|
||||
clients will retry. However, if the client stops retrying for a while then
|
||||
we want to stop processing the request and free up the resources.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -106,7 +136,7 @@ class ResponseCache(Generic[KV]):
|
||||
clock: Clock,
|
||||
name: str,
|
||||
server_name: str,
|
||||
timeout_ms: float = 0,
|
||||
timeout: Duration | None = None,
|
||||
enable_logging: bool = True,
|
||||
):
|
||||
"""
|
||||
@@ -121,7 +151,7 @@ class ResponseCache(Generic[KV]):
|
||||
self._result_cache: dict[KV, ResponseCacheEntry] = {}
|
||||
|
||||
self.clock = clock
|
||||
self.timeout = Duration(milliseconds=timeout_ms)
|
||||
self.timeout = timeout
|
||||
|
||||
self._name = name
|
||||
self._metrics = register_cache(
|
||||
@@ -133,6 +163,13 @@ class ResponseCache(Generic[KV]):
|
||||
)
|
||||
self._enable_logging = enable_logging
|
||||
|
||||
self._prune_timer: WheelTimer[KV] | None = None
|
||||
if self.timeout:
|
||||
# Set up the timers for pruning inflight entries. The times here are
|
||||
# how often we check for entries to prune.
|
||||
self._prune_timer = WheelTimer(bucket_size=self.timeout / 10)
|
||||
self.clock.looping_call(self._prune_inflight_entries, self.timeout / 10)
|
||||
|
||||
def size(self) -> int:
|
||||
return len(self._result_cache)
|
||||
|
||||
@@ -172,6 +209,7 @@ class ResponseCache(Generic[KV]):
|
||||
context: ResponseCacheContext[KV],
|
||||
deferred: "defer.Deferred[RV]",
|
||||
opentracing_span_context: "opentracing.SpanContext | None",
|
||||
cancellable: bool,
|
||||
) -> ResponseCacheEntry:
|
||||
"""Set the entry for the given key to the given deferred.
|
||||
|
||||
@@ -183,13 +221,16 @@ class ResponseCache(Generic[KV]):
|
||||
context: Information about the cache miss
|
||||
deferred: The deferred which resolves to the result.
|
||||
opentracing_span_context: An opentracing span wrapping the calculation
|
||||
cancellable: Whether the deferred is safe to be cancelled
|
||||
|
||||
Returns:
|
||||
The cache entry object.
|
||||
"""
|
||||
result = ObservableDeferred(deferred, consumeErrors=True)
|
||||
key = context.cache_key
|
||||
entry = ResponseCacheEntry(result, opentracing_span_context)
|
||||
entry = ResponseCacheEntry(
|
||||
result, opentracing_span_context, cancellable=cancellable
|
||||
)
|
||||
self._result_cache[key] = entry
|
||||
|
||||
def on_complete(r: RV) -> RV:
|
||||
@@ -233,6 +274,7 @@ class ResponseCache(Generic[KV]):
|
||||
self._metrics.inc_evictions(EvictionReason.time)
|
||||
self._result_cache.pop(key, None)
|
||||
|
||||
@cancellable
|
||||
async def wrap(
|
||||
self,
|
||||
key: KV,
|
||||
@@ -301,8 +343,44 @@ class ResponseCache(Generic[KV]):
|
||||
return await callback(*args, **kwargs)
|
||||
|
||||
d = run_in_background(cb)
|
||||
entry = self._set(context, d, span_context)
|
||||
return await make_deferred_yieldable(entry.result.observe())
|
||||
entry = self._set(
|
||||
context, d, span_context, cancellable=is_function_cancellable(callback)
|
||||
)
|
||||
try:
|
||||
return await make_deferred_yieldable(entry.result.observe())
|
||||
except defer.CancelledError:
|
||||
pass
|
||||
|
||||
# We've been cancelled.
|
||||
#
|
||||
# Since we've kicked off the background operation, we can't just
|
||||
# give up and return here and need to wait for the background
|
||||
# operation to stop. We don't want to stop the background process
|
||||
# immediately to give a chance for retries to come in and wait for
|
||||
# the result.
|
||||
#
|
||||
# Instead, we temporarily swallow the cancellation and mark the
|
||||
# cache key as one to potentially timeout.
|
||||
|
||||
# Update the `last_observer_removed_time_ms` so that the pruning
|
||||
# mechanism can kick in if needed.
|
||||
now = self.clock.time_msec()
|
||||
entry.last_observer_removed_time_ms = now
|
||||
if self._prune_timer is not None and self.timeout:
|
||||
self._prune_timer.insert(now, key, now + self.timeout.as_millis())
|
||||
|
||||
# Wait on the original deferred, which will continue to run in the
|
||||
# background until it completes. We don't want to add an observer as
|
||||
# this would prevent the entry from being pruned.
|
||||
#
|
||||
# Note that this deferred has been consumed by the
|
||||
# ObservableDeferred, so we don't know what it will return. That
|
||||
# doesn't matter as we just want to throw a CancelledError once it completes anyway.
|
||||
try:
|
||||
await make_deferred_yieldable(delay_cancellation(d))
|
||||
except Exception:
|
||||
pass
|
||||
raise defer.CancelledError()
|
||||
|
||||
result = entry.result.observe()
|
||||
if self._enable_logging:
|
||||
@@ -320,4 +398,60 @@ class ResponseCache(Generic[KV]):
|
||||
f"ResponseCache[{self._name}].wait",
|
||||
contexts=(span_context,) if span_context else (),
|
||||
):
|
||||
return await make_deferred_yieldable(result)
|
||||
try:
|
||||
return await make_deferred_yieldable(result)
|
||||
except defer.CancelledError:
|
||||
# If we're cancelled then we update the
|
||||
# `last_observer_removed_time_ms` so that the pruning mechanism
|
||||
# can kick in if needed.
|
||||
now = self.clock.time_msec()
|
||||
entry.last_observer_removed_time_ms = now
|
||||
if self._prune_timer is not None and self.timeout:
|
||||
self._prune_timer.insert(now, key, now + self.timeout.as_millis())
|
||||
raise
|
||||
|
||||
def _prune_inflight_entries(self) -> None:
|
||||
"""Prune entries which have been in the cache for too long without
|
||||
observers"""
|
||||
assert self._prune_timer is not None
|
||||
assert self.timeout is not None
|
||||
|
||||
now = self.clock.time_msec()
|
||||
keys_to_check = self._prune_timer.fetch(now)
|
||||
|
||||
# Loop through the keys and check if they should be evicted. We evict
|
||||
# entries which have no active observers, and which have been in the
|
||||
# cache for longer than the timeout since the last observer was removed.
|
||||
for key in keys_to_check:
|
||||
entry = self._result_cache.get(key)
|
||||
if not entry:
|
||||
continue
|
||||
|
||||
if not entry.cancellable:
|
||||
# this entry is not cancellable, so we should keep it in the cache until it completes.
|
||||
continue
|
||||
|
||||
if entry.result.has_called():
|
||||
# this entry has already completed, so we should have scheduled it for
|
||||
# removal at the right time. We can just skip it here and wait for the
|
||||
# scheduled call to remove it.
|
||||
continue
|
||||
|
||||
if entry.result.has_observers():
|
||||
# this entry has observers, so we should keep it in the cache for now.
|
||||
continue
|
||||
|
||||
if entry.last_observer_removed_time_ms is None:
|
||||
# this should never happen, but just in case, we should keep the entry
|
||||
# in the cache until we have a valid last_observer_removed_time_ms to
|
||||
# compare against.
|
||||
continue
|
||||
|
||||
if now - entry.last_observer_removed_time_ms > self.timeout.as_millis():
|
||||
self._metrics.inc_evictions(EvictionReason.time)
|
||||
self._result_cache.pop(key, None)
|
||||
try:
|
||||
entry.result.cancel()
|
||||
except Exception:
|
||||
# we ignore exceptions from cancel, as it is best effort anyway.
|
||||
pass
|
||||
|
||||
@@ -62,6 +62,16 @@ this setting won't inherit the log level from the parent logger.
|
||||
logging.setLoggerClass(original_logger_class)
|
||||
|
||||
|
||||
def _try_wakeup_deferred(d: Deferred) -> None:
|
||||
"""Try to wake up a deferred, but ignore any exceptions raised by the
|
||||
callback. This is useful when we want to wake up a deferred that may have
|
||||
already been cancelled, and we don't care about the result."""
|
||||
try:
|
||||
d.callback(None)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
class Clock:
|
||||
"""
|
||||
A Clock wraps a Twisted reactor and provides utilities on top of it.
|
||||
@@ -114,7 +124,11 @@ class Clock:
|
||||
with context.PreserveLoggingContext():
|
||||
# We can ignore the lint here since this class is the one location callLater should
|
||||
# be called.
|
||||
self._reactor.callLater(duration.as_secs(), d.callback, duration.as_secs()) # type: ignore[call-later-not-tracked]
|
||||
self._reactor.callLater(
|
||||
duration.as_secs(),
|
||||
lambda _: _try_wakeup_deferred(d),
|
||||
duration.as_secs(),
|
||||
) # type: ignore[call-later-not-tracked]
|
||||
await d
|
||||
|
||||
def time(self) -> float:
|
||||
|
||||
@@ -23,6 +23,8 @@ from typing import Generic, Hashable, TypeVar
|
||||
|
||||
import attr
|
||||
|
||||
from synapse.util.duration import Duration
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar("T", bound=Hashable)
|
||||
@@ -39,13 +41,13 @@ class WheelTimer(Generic[T]):
|
||||
expired.
|
||||
"""
|
||||
|
||||
def __init__(self, bucket_size: int = 5000) -> None:
|
||||
def __init__(self, bucket_size: Duration = Duration(seconds=5)) -> None:
|
||||
"""
|
||||
Args:
|
||||
bucket_size: Size of buckets in ms. Corresponds roughly to the
|
||||
accuracy of the timer.
|
||||
"""
|
||||
self.bucket_size: int = bucket_size
|
||||
self.bucket_size = bucket_size
|
||||
self.entries: list[_Entry[T]] = []
|
||||
|
||||
def insert(self, now: int, obj: T, then: int) -> None:
|
||||
@@ -56,8 +58,8 @@ class WheelTimer(Generic[T]):
|
||||
obj: Object to be inserted
|
||||
then: When to return the object strictly after.
|
||||
"""
|
||||
then_key = int(then / self.bucket_size) + 1
|
||||
now_key = int(now / self.bucket_size)
|
||||
then_key = int(then / self.bucket_size.as_millis()) + 1
|
||||
now_key = int(now / self.bucket_size.as_millis())
|
||||
|
||||
if self.entries:
|
||||
min_key = self.entries[0].end_key
|
||||
@@ -100,7 +102,7 @@ class WheelTimer(Generic[T]):
|
||||
Returns:
|
||||
List of objects that have timed out
|
||||
"""
|
||||
now_key = int(now / self.bucket_size)
|
||||
now_key = int(now / self.bucket_size.as_millis())
|
||||
|
||||
ret: list[T] = []
|
||||
while self.entries and self.entries[0].end_key <= now_key:
|
||||
|
||||
@@ -41,6 +41,7 @@ from tests import unittest
|
||||
from tests.federation.transport.test_knocking import (
|
||||
KnockingStrippedStateEventHelperMixin,
|
||||
)
|
||||
from tests.rest.client.test_rooms import make_request_with_cancellation_test
|
||||
from tests.server import TimedOutException
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -1145,3 +1146,65 @@ class ExcludeRoomTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
self.assertNotIn(self.excluded_room_id, channel.json_body["rooms"]["join"])
|
||||
self.assertIn(self.included_room_id, channel.json_body["rooms"]["join"])
|
||||
|
||||
|
||||
class SyncCancellationTestCase(unittest.HomeserverTestCase):
|
||||
servlets = [
|
||||
synapse.rest.admin.register_servlets,
|
||||
login.register_servlets,
|
||||
sync.register_servlets,
|
||||
room.register_servlets,
|
||||
]
|
||||
|
||||
def test_initial_sync(self) -> None:
|
||||
"""Tests that an initial sync request can be cancelled."""
|
||||
user_id = self.register_user("user", "password")
|
||||
tok = self.login("user", "password")
|
||||
|
||||
# Populate the account with a few rooms
|
||||
for _ in range(5):
|
||||
room_id = self.helper.create_room_as(user_id, tok=tok)
|
||||
self.helper.send(room_id, tok=tok)
|
||||
|
||||
channel = make_request_with_cancellation_test(
|
||||
"test_initial_sync",
|
||||
self.reactor,
|
||||
self.site,
|
||||
"GET",
|
||||
"/_matrix/client/v3/sync",
|
||||
token=tok,
|
||||
)
|
||||
|
||||
self.assertEqual(200, channel.code, msg=channel.result["body"])
|
||||
|
||||
def test_incremental_sync(self) -> None:
|
||||
"""Tests that an incremental sync request can be cancelled."""
|
||||
user_id = self.register_user("user", "password")
|
||||
tok = self.login("user", "password")
|
||||
|
||||
# Populate the account with a few rooms
|
||||
room_ids = []
|
||||
for _ in range(5):
|
||||
room_id = self.helper.create_room_as(user_id, tok=tok)
|
||||
self.helper.send(room_id, tok=tok)
|
||||
room_ids.append(room_id)
|
||||
|
||||
# Do an initial sync to get a since token.
|
||||
channel = self.make_request("GET", "/sync", access_token=tok)
|
||||
self.assertEqual(200, channel.code, msg=channel.result)
|
||||
since = channel.json_body["next_batch"]
|
||||
|
||||
# Send some more messages to generate activity in the rooms.
|
||||
for room_id in room_ids:
|
||||
self.helper.send(room_id, tok=tok)
|
||||
|
||||
channel = make_request_with_cancellation_test(
|
||||
"test_incremental_sync",
|
||||
self.reactor,
|
||||
self.site,
|
||||
"GET",
|
||||
f"/_matrix/client/v3/sync?since={since}&timeout=10000",
|
||||
token=tok,
|
||||
)
|
||||
|
||||
self.assertEqual(200, channel.code, msg=channel.result["body"])
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
#
|
||||
#
|
||||
|
||||
from functools import wraps
|
||||
from unittest.mock import Mock
|
||||
|
||||
from parameterized import parameterized
|
||||
@@ -26,6 +27,7 @@ from parameterized import parameterized
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.util.caches.response_cache import ResponseCache, ResponseCacheContext
|
||||
from synapse.util.cancellation import cancellable
|
||||
from synapse.util.duration import Duration
|
||||
|
||||
from tests.server import get_clock
|
||||
@@ -48,15 +50,23 @@ class ResponseCacheTestCase(TestCase):
|
||||
|
||||
def with_cache(self, name: str, ms: int = 0) -> ResponseCache:
|
||||
return ResponseCache(
|
||||
clock=self.clock, name=name, server_name="test_server", timeout_ms=ms
|
||||
clock=self.clock,
|
||||
name=name,
|
||||
server_name="test_server",
|
||||
timeout=Duration(milliseconds=ms),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def instant_return(o: str) -> str:
|
||||
return o
|
||||
|
||||
async def delayed_return(self, o: str) -> str:
|
||||
await self.clock.sleep(Duration(seconds=1))
|
||||
@cancellable
|
||||
async def delayed_return(
|
||||
self,
|
||||
o: str,
|
||||
duration: Duration = Duration(seconds=1), # noqa
|
||||
) -> str:
|
||||
await self.clock.sleep(duration)
|
||||
return o
|
||||
|
||||
def test_cache_hit(self) -> None:
|
||||
@@ -223,3 +233,332 @@ class ResponseCacheTestCase(TestCase):
|
||||
self.assertCountEqual(
|
||||
[], cache.keys(), "cache should not have the result now"
|
||||
)
|
||||
|
||||
def test_cache_func_errors(self) -> None:
|
||||
"""If the callback raises an error, the error should be raised to all
|
||||
callers and the result should not be cached"""
|
||||
cache = self.with_cache("error_cache", ms=3000)
|
||||
|
||||
expected_error = Exception("oh no")
|
||||
|
||||
async def erring(o: str) -> str:
|
||||
await self.clock.sleep(Duration(seconds=1))
|
||||
raise expected_error
|
||||
|
||||
wrap_d = defer.ensureDeferred(cache.wrap(0, erring, "ignored"))
|
||||
self.assertNoResult(wrap_d)
|
||||
|
||||
# a second call should also return a pending deferred
|
||||
wrap2_d = defer.ensureDeferred(cache.wrap(0, erring, "ignored"))
|
||||
self.assertNoResult(wrap2_d)
|
||||
|
||||
# let the call complete
|
||||
self.reactor.advance(1)
|
||||
|
||||
# both results should have completed with the error
|
||||
self.assertFailure(wrap_d, Exception)
|
||||
self.assertFailure(wrap2_d, Exception)
|
||||
|
||||
def test_cache_cancel_first_wait(self) -> None:
|
||||
"""Test that cancellation of the deferred returned by wrap() on the
|
||||
first call does not immediately cause a cancellation error to be raised
|
||||
when its cancelled and the wrapped function continues execution (unless
|
||||
it times out).
|
||||
"""
|
||||
cache = self.with_cache("cancel_cache", ms=3000)
|
||||
|
||||
expected_result = "howdy"
|
||||
|
||||
wrap_d = defer.ensureDeferred(
|
||||
cache.wrap(0, self.delayed_return, expected_result)
|
||||
)
|
||||
|
||||
# cancel the deferred before it has a chance to return
|
||||
wrap_d.cancel()
|
||||
|
||||
# The cancel should be ignored for now, and the inner function should
|
||||
# still be running.
|
||||
self.assertNoResult(wrap_d)
|
||||
|
||||
# Advance the clock until the inner function should have returned, but
|
||||
# not long enough for the cache entry to have expired.
|
||||
self.reactor.advance(2)
|
||||
|
||||
# The deferred we're waiting on should now return a cancelled error.
|
||||
self.assertFailure(wrap_d, defer.CancelledError)
|
||||
|
||||
# However future callers should get the result.
|
||||
wrap_d2 = defer.ensureDeferred(
|
||||
cache.wrap(0, self.delayed_return, expected_result)
|
||||
)
|
||||
self.assertEqual(expected_result, self.successResultOf(wrap_d2))
|
||||
|
||||
def test_cache_cancel_first_wait_expire(self) -> None:
|
||||
"""Test that cancellation of the deferred returned by wrap() and the
|
||||
entry expiring before the wrapped function returns.
|
||||
|
||||
The wrapped function should be cancelled.
|
||||
"""
|
||||
cache = self.with_cache("cancel_expire_cache", ms=300)
|
||||
|
||||
expected_result = "howdy"
|
||||
|
||||
# Wrap the function so that we can keep track of when it completes or
|
||||
# errors.
|
||||
completed = False
|
||||
cancelled = False
|
||||
|
||||
@wraps(self.delayed_return)
|
||||
async def wrapped(o: str) -> str:
|
||||
nonlocal completed, cancelled
|
||||
|
||||
try:
|
||||
return await self.delayed_return(o)
|
||||
except defer.CancelledError:
|
||||
cancelled = True
|
||||
raise
|
||||
finally:
|
||||
completed = True
|
||||
|
||||
wrap_d = defer.ensureDeferred(cache.wrap(0, wrapped, expected_result))
|
||||
|
||||
# cancel the deferred before it has a chance to return
|
||||
wrap_d.cancel()
|
||||
|
||||
# The cancel should be ignored for now, and the inner function should
|
||||
# still be running.
|
||||
self.assertNoResult(wrap_d)
|
||||
self.assertFalse(completed, "wrapped function should not have completed yet")
|
||||
|
||||
# Advance the clock until the cache entry should have expired, but not
|
||||
# long enough for the inner function to have returned.
|
||||
self.reactor.advance(0.7)
|
||||
|
||||
# The deferred we're waiting on should now return a cancelled error.
|
||||
self.assertFailure(wrap_d, defer.CancelledError)
|
||||
self.assertTrue(completed, "wrapped function should have completed")
|
||||
self.assertTrue(cancelled, "wrapped function should have been cancelled")
|
||||
|
||||
def test_cache_cancel_first_wait_other_observers(self) -> None:
|
||||
"""Test that cancellation of the deferred returned by wrap() does not
|
||||
cause a cancellation error to be raised if there are other observers
|
||||
still waiting on the result.
|
||||
"""
|
||||
cache = self.with_cache("cancel_other_cache", ms=300)
|
||||
|
||||
expected_result = "howdy"
|
||||
|
||||
# Wrap the function so that we can keep track of when it completes or
|
||||
# errors.
|
||||
completed = False
|
||||
cancelled = False
|
||||
|
||||
@wraps(self.delayed_return)
|
||||
async def wrapped(o: str) -> str:
|
||||
nonlocal completed, cancelled
|
||||
|
||||
try:
|
||||
return await self.delayed_return(o)
|
||||
except defer.CancelledError:
|
||||
cancelled = True
|
||||
raise
|
||||
finally:
|
||||
completed = True
|
||||
|
||||
wrap_d1 = defer.ensureDeferred(cache.wrap(0, wrapped, expected_result))
|
||||
wrap_d2 = defer.ensureDeferred(cache.wrap(0, wrapped, expected_result))
|
||||
|
||||
# cancel the first deferred before it has a chance to return
|
||||
wrap_d1.cancel()
|
||||
|
||||
# The cancel should be ignored for now, and the inner function should
|
||||
# still be running.
|
||||
self.assertNoResult(wrap_d1)
|
||||
self.assertNoResult(wrap_d2)
|
||||
self.assertFalse(completed, "wrapped function should not have completed yet")
|
||||
|
||||
# Advance the clock until the cache entry should have expired, but not
|
||||
# long enough for the inner function to have returned.
|
||||
self.reactor.advance(0.7)
|
||||
|
||||
# Neither deferred should have returned yet, since the inner function
|
||||
# should still be running.
|
||||
self.assertNoResult(wrap_d1)
|
||||
self.assertNoResult(wrap_d2)
|
||||
self.assertFalse(completed, "wrapped function should not have completed yet")
|
||||
|
||||
# Now advance the clock until the inner function should have returned.
|
||||
self.reactor.advance(2.5)
|
||||
|
||||
# The wrapped function should have completed without cancellation.
|
||||
self.assertTrue(completed, "wrapped function should have completed")
|
||||
self.assertFalse(cancelled, "wrapped function should not have been cancelled")
|
||||
|
||||
# The first deferred we're waiting on should now return a cancelled error.
|
||||
self.assertFailure(wrap_d1, defer.CancelledError)
|
||||
|
||||
# The second deferred should return the result.
|
||||
self.assertEqual(expected_result, self.successResultOf(wrap_d2))
|
||||
|
||||
def test_cache_add_and_cancel(self) -> None:
|
||||
"""Test that waiting on the cache and cancelling repeatedly keeps the
|
||||
cache entry alive.
|
||||
"""
|
||||
cache = self.with_cache("cancel_add_cache", ms=300)
|
||||
|
||||
expected_result = "howdy"
|
||||
|
||||
# Wrap the function so that we can keep track of when it completes or
|
||||
# errors.
|
||||
completed = False
|
||||
cancelled = False
|
||||
|
||||
@wraps(self.delayed_return)
|
||||
async def wrapped(o: str) -> str:
|
||||
nonlocal completed, cancelled
|
||||
|
||||
try:
|
||||
return await self.delayed_return(o)
|
||||
except defer.CancelledError:
|
||||
cancelled = True
|
||||
raise
|
||||
finally:
|
||||
completed = True
|
||||
|
||||
# Repeatedly await for the result and cancel it, which should keep the
|
||||
# cache entry alive even though the total time exceeds the cache
|
||||
# timeout.
|
||||
deferreds = []
|
||||
for _ in range(8):
|
||||
# Await the deferred.
|
||||
wrap_d = defer.ensureDeferred(cache.wrap(0, wrapped, expected_result))
|
||||
|
||||
# cancel the deferred before it has a chance to return
|
||||
self.reactor.advance(0.05)
|
||||
wrap_d.cancel()
|
||||
deferreds.append(wrap_d)
|
||||
|
||||
# The cancel should not cause the inner function to be cancelled
|
||||
# yet.
|
||||
self.assertFalse(
|
||||
completed, "wrapped function should not have completed yet"
|
||||
)
|
||||
self.assertFalse(
|
||||
cancelled, "wrapped function should not have been cancelled yet"
|
||||
)
|
||||
|
||||
# Advance the clock until the cache entry should have expired, but not
|
||||
# long enough for the inner function to have returned.
|
||||
self.reactor.advance(0.05)
|
||||
|
||||
# Now advance the clock until the inner function should have returned.
|
||||
self.reactor.advance(0.2)
|
||||
|
||||
# All the deferreds we're waiting on should now return a cancelled error.
|
||||
for wrap_d in deferreds:
|
||||
self.assertFailure(wrap_d, defer.CancelledError)
|
||||
|
||||
# The wrapped function should have completed without cancellation.
|
||||
self.assertTrue(completed, "wrapped function should have completed")
|
||||
self.assertFalse(cancelled, "wrapped function should not have been cancelled")
|
||||
|
||||
# Querying the cache should return the completed result
|
||||
wrap_d = defer.ensureDeferred(cache.wrap(0, wrapped, expected_result))
|
||||
self.assertEqual(expected_result, self.successResultOf(wrap_d))
|
||||
|
||||
def test_cache_cancel_non_cancellable(self) -> None:
|
||||
"""Test that cancellation of the deferred returned by wrap() on a
|
||||
non-cancellable entry does not cause a cancellation error to be raised
|
||||
when it's cancelled and the wrapped function continues execution.
|
||||
"""
|
||||
cache = self.with_cache("cancel_non_cancellable_cache", ms=300)
|
||||
|
||||
expected_result = "howdy"
|
||||
|
||||
# Wrap the function so that we can keep track of when it completes or
|
||||
# errors.
|
||||
completed = False
|
||||
cancelled = False
|
||||
|
||||
async def wrapped(o: str) -> str:
|
||||
nonlocal completed, cancelled
|
||||
|
||||
try:
|
||||
return await self.delayed_return(o)
|
||||
except defer.CancelledError:
|
||||
cancelled = True
|
||||
raise
|
||||
finally:
|
||||
completed = True
|
||||
|
||||
wrap_d = defer.ensureDeferred(cache.wrap(0, wrapped, expected_result))
|
||||
|
||||
# cancel the deferred before it has a chance to return
|
||||
wrap_d.cancel()
|
||||
|
||||
# The cancel should be ignored for now, and the inner function should
|
||||
# still be running.
|
||||
self.assertNoResult(wrap_d)
|
||||
self.assertFalse(completed, "wrapped function should not have completed yet")
|
||||
|
||||
# Advance the clock until the inner function should have returned, but
|
||||
# not long enough for the cache entry to have expired.
|
||||
self.reactor.advance(2)
|
||||
|
||||
# The deferred we're waiting on should be cancelled, but a new call to
|
||||
# the cache should return the result.
|
||||
self.assertFailure(wrap_d, defer.CancelledError)
|
||||
wrap_d2 = defer.ensureDeferred(cache.wrap(0, wrapped, expected_result))
|
||||
self.assertEqual(expected_result, self.successResultOf(wrap_d2))
|
||||
|
||||
def test_cache_cancel_then_error(self) -> None:
|
||||
"""Test that cancellation of the deferred returned by wrap() that then
|
||||
subsequently errors is correctly propagated to a second caller.
|
||||
"""
|
||||
|
||||
cache = self.with_cache("cancel_then_error_cache", ms=3000)
|
||||
|
||||
expected_error = Exception("oh no")
|
||||
|
||||
# Wrap the function so that we can keep track of when it completes or
|
||||
# errors.
|
||||
completed = False
|
||||
cancelled = False
|
||||
|
||||
@wraps(self.delayed_return)
|
||||
async def wrapped(o: str) -> str:
|
||||
nonlocal completed, cancelled
|
||||
|
||||
try:
|
||||
await self.delayed_return(o)
|
||||
raise expected_error
|
||||
except defer.CancelledError:
|
||||
cancelled = True
|
||||
raise
|
||||
finally:
|
||||
completed = True
|
||||
|
||||
wrap_d1 = defer.ensureDeferred(cache.wrap(0, wrapped, "ignored"))
|
||||
wrap_d2 = defer.ensureDeferred(cache.wrap(0, wrapped, "ignored"))
|
||||
|
||||
# cancel the first deferred before it has a chance to return
|
||||
wrap_d1.cancel()
|
||||
|
||||
# The cancel should be ignored for now, and the inner function should
|
||||
# still be running.
|
||||
self.assertNoResult(wrap_d1)
|
||||
self.assertNoResult(wrap_d2)
|
||||
self.assertFalse(completed, "wrapped function should not have completed yet")
|
||||
|
||||
# Advance the clock until the inner function should have returned.
|
||||
self.reactor.advance(2)
|
||||
|
||||
# The wrapped function should have completed with an error without cancellation.
|
||||
self.assertTrue(completed, "wrapped function should have completed")
|
||||
self.assertFalse(cancelled, "wrapped function should not have been cancelled")
|
||||
|
||||
# The first deferred we're waiting on should now return a cancelled error.
|
||||
self.assertFailure(wrap_d1, defer.CancelledError)
|
||||
|
||||
# The second deferred should return the error.
|
||||
self.assertFailure(wrap_d2, Exception)
|
||||
|
||||
@@ -120,7 +120,7 @@ class ObservableDeferredTest(TestCase):
|
||||
assert results[1] is not None
|
||||
self.assertEqual(str(results[1].value), "gah!", "observer 2 errback result")
|
||||
|
||||
def test_cancellation(self) -> None:
|
||||
def test_cancellation_observer(self) -> None:
|
||||
"""Test that cancelling an observer does not affect other observers."""
|
||||
origin_d: "Deferred[int]" = Deferred()
|
||||
observable = ObservableDeferred(origin_d, consumeErrors=True)
|
||||
@@ -138,6 +138,10 @@ class ObservableDeferredTest(TestCase):
|
||||
self.assertFalse(observer1.called)
|
||||
self.failureResultOf(observer2, CancelledError)
|
||||
self.assertFalse(observer3.called)
|
||||
# check that we remove the cancelled observer from the list of observers
|
||||
# as a clean up.
|
||||
self.assertEqual(len(observable.observers()), 2)
|
||||
self.assertNotIn(observer2, observable.observers())
|
||||
|
||||
# other observers resolve as normal
|
||||
origin_d.callback(123)
|
||||
@@ -148,6 +152,22 @@ class ObservableDeferredTest(TestCase):
|
||||
observer4 = observable.observe()
|
||||
self.assertEqual(observer4.result, 123, "observer 4 callback result")
|
||||
|
||||
def test_cancellation_observee(self) -> None:
|
||||
"""Test that cancelling the original deferred cancels all observers."""
|
||||
origin_d: "Deferred[int]" = Deferred()
|
||||
observable = ObservableDeferred(origin_d, consumeErrors=True)
|
||||
|
||||
observer1 = observable.observe()
|
||||
observer2 = observable.observe()
|
||||
|
||||
self.assertFalse(observer1.called)
|
||||
self.assertFalse(observer2.called)
|
||||
|
||||
# cancel the original deferred
|
||||
origin_d.cancel()
|
||||
self.failureResultOf(observer1, CancelledError)
|
||||
self.failureResultOf(observer2, CancelledError)
|
||||
|
||||
|
||||
class TimeoutDeferredTest(TestCase):
|
||||
def setUp(self) -> None:
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
#
|
||||
#
|
||||
|
||||
from synapse.util.duration import Duration
|
||||
from synapse.util.wheel_timer import WheelTimer
|
||||
|
||||
from .. import unittest
|
||||
@@ -26,7 +27,7 @@ from .. import unittest
|
||||
|
||||
class WheelTimerTestCase(unittest.TestCase):
|
||||
def test_single_insert_fetch(self) -> None:
|
||||
wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
|
||||
wheel: WheelTimer[object] = WheelTimer(bucket_size=Duration(milliseconds=5))
|
||||
|
||||
wheel.insert(100, "1", 150)
|
||||
|
||||
@@ -39,7 +40,7 @@ class WheelTimerTestCase(unittest.TestCase):
|
||||
self.assertListEqual(wheel.fetch(170), [])
|
||||
|
||||
def test_multi_insert(self) -> None:
|
||||
wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
|
||||
wheel: WheelTimer[object] = WheelTimer(bucket_size=Duration(milliseconds=5))
|
||||
|
||||
wheel.insert(100, "1", 150)
|
||||
wheel.insert(105, "2", 130)
|
||||
@@ -54,13 +55,13 @@ class WheelTimerTestCase(unittest.TestCase):
|
||||
self.assertListEqual(wheel.fetch(210), [])
|
||||
|
||||
def test_insert_past(self) -> None:
|
||||
wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
|
||||
wheel: WheelTimer[object] = WheelTimer(bucket_size=Duration(milliseconds=5))
|
||||
|
||||
wheel.insert(100, "1", 50)
|
||||
self.assertListEqual(wheel.fetch(120), ["1"])
|
||||
|
||||
def test_insert_past_multi(self) -> None:
|
||||
wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
|
||||
wheel: WheelTimer[object] = WheelTimer(bucket_size=Duration(milliseconds=5))
|
||||
|
||||
wheel.insert(100, "1", 150)
|
||||
wheel.insert(100, "2", 140)
|
||||
@@ -72,7 +73,7 @@ class WheelTimerTestCase(unittest.TestCase):
|
||||
self.assertListEqual(wheel.fetch(240), [])
|
||||
|
||||
def test_multi_insert_then_past(self) -> None:
|
||||
wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
|
||||
wheel: WheelTimer[object] = WheelTimer(bucket_size=Duration(milliseconds=5))
|
||||
|
||||
wheel.insert(100, "1", 150)
|
||||
wheel.insert(100, "2", 160)
|
||||
|
||||
Reference in New Issue
Block a user