From 2c73e8daef1c50abe9741f1fa321a34f9fdc5909 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 26 Feb 2026 22:41:06 +0100 Subject: [PATCH] Allow long lived syncs to be cancelled if client has gone away (#19499) --- changelog.d/19499.misc | 1 + pyproject.toml | 6 + synapse/api/auth/mas.py | 3 +- synapse/api/auth/msc3861_delegated.py | 3 +- synapse/appservice/api.py | 3 +- synapse/config/cache.py | 9 +- synapse/federation/federation_server.py | 6 +- synapse/handlers/room.py | 2 +- synapse/handlers/room_list.py | 3 +- synapse/handlers/sync.py | 7 +- synapse/handlers/typing.py | 4 +- synapse/replication/http/_base.py | 2 +- synapse/rest/client/sync.py | 2 + synapse/util/async_helpers.py | 70 ++++- synapse/util/caches/response_cache.py | 152 +++++++++- synapse/util/clock.py | 16 +- synapse/util/wheel_timer.py | 12 +- tests/rest/client/test_sync.py | 63 +++++ tests/util/caches/test_response_cache.py | 345 ++++++++++++++++++++++- tests/util/test_async_helpers.py | 22 +- tests/util/test_wheel_timer.py | 11 +- 21 files changed, 703 insertions(+), 39 deletions(-) create mode 100644 changelog.d/19499.misc diff --git a/changelog.d/19499.misc b/changelog.d/19499.misc new file mode 100644 index 0000000000..c2641c1e58 --- /dev/null +++ b/changelog.d/19499.misc @@ -0,0 +1 @@ +Cancel long-running sync requests if the client has gone away. 
diff --git a/pyproject.toml b/pyproject.toml index 0b8dff9058..2d55d3c7a5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -411,6 +411,12 @@ indent-style = "space" skip-magic-trailing-comma = false line-ending = "auto" +[tool.ruff.lint.flake8-bugbear] +extend-immutable-calls = [ + # Durations are immutable + "synapse.util.duration.Duration", +] + [tool.maturin] manifest-path = "rust/Cargo.toml" module-name = "synapse.synapse_rust" diff --git a/synapse/api/auth/mas.py b/synapse/api/auth/mas.py index c4cca3723c..79c15a5329 100644 --- a/synapse/api/auth/mas.py +++ b/synapse/api/auth/mas.py @@ -45,6 +45,7 @@ from synapse.synapse_rust.http_client import HttpClient from synapse.types import JsonDict, Requester, UserID, create_requester from synapse.util.caches.cached_call import RetryOnExceptionCachedCall from synapse.util.caches.response_cache import ResponseCache, ResponseCacheContext +from synapse.util.duration import Duration from synapse.util.json import json_decoder from . import introspection_response_timer @@ -139,7 +140,7 @@ class MasDelegatedAuth(BaseAuth): clock=self._clock, name="mas_token_introspection", server_name=self.server_name, - timeout_ms=120_000, + timeout=Duration(minutes=2), # don't log because the keys are access tokens enable_logging=False, ) diff --git a/synapse/api/auth/msc3861_delegated.py b/synapse/api/auth/msc3861_delegated.py index 7999d6e459..27ab4af805 100644 --- a/synapse/api/auth/msc3861_delegated.py +++ b/synapse/api/auth/msc3861_delegated.py @@ -49,6 +49,7 @@ from synapse.synapse_rust.http_client import HttpClient from synapse.types import Requester, UserID, create_requester from synapse.util.caches.cached_call import RetryOnExceptionCachedCall from synapse.util.caches.response_cache import ResponseCache, ResponseCacheContext +from synapse.util.duration import Duration from synapse.util.json import json_decoder from . 
import introspection_response_timer @@ -205,7 +206,7 @@ class MSC3861DelegatedAuth(BaseAuth): clock=self._clock, name="token_introspection", server_name=self.server_name, - timeout_ms=120_000, + timeout=Duration(minutes=2), # don't log because the keys are access tokens enable_logging=False, ) diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 71094de9be..2bbf77a352 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -46,6 +46,7 @@ from synapse.logging import opentracing from synapse.metrics import SERVER_NAME_LABEL from synapse.types import DeviceListUpdates, JsonDict, JsonMapping, ThirdPartyInstanceID from synapse.util.caches.response_cache import ResponseCache +from synapse.util.duration import Duration if TYPE_CHECKING: from synapse.server import HomeServer @@ -132,7 +133,7 @@ class ApplicationServiceApi(SimpleHttpClient): clock=hs.get_clock(), name="as_protocol_meta", server_name=self.server_name, - timeout_ms=HOUR_IN_MS, + timeout=Duration(hours=1), ) def _get_headers(self, service: "ApplicationService") -> dict[bytes, list[bytes]]: diff --git a/synapse/config/cache.py b/synapse/config/cache.py index c9ce826e1a..ac94b17ff6 100644 --- a/synapse/config/cache.py +++ b/synapse/config/cache.py @@ -29,6 +29,7 @@ import attr from synapse.types import JsonDict from synapse.util.check_dependencies import check_requirements +from synapse.util.duration import Duration from ._base import Config, ConfigError @@ -108,7 +109,7 @@ class CacheConfig(Config): global_factor: float track_memory_usage: bool expiry_time_msec: int | None - sync_response_cache_duration: int + sync_response_cache_duration: Duration @staticmethod def reset() -> None: @@ -207,10 +208,14 @@ class CacheConfig(Config): min_cache_ttl = self.cache_autotuning.get("min_cache_ttl") self.cache_autotuning["min_cache_ttl"] = self.parse_duration(min_cache_ttl) - self.sync_response_cache_duration = self.parse_duration( + sync_response_cache_duration_ms = 
self.parse_duration( cache_config.get("sync_response_cache_duration", "2m") ) + self.sync_response_cache_duration = Duration( + milliseconds=sync_response_cache_duration_ms + ) + def resize_all_caches(self) -> None: """Ensure all cache sizes are up-to-date. diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index b909f1e595..1912d545f5 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -166,7 +166,7 @@ class FederationServer(FederationBase): clock=hs.get_clock(), name="fed_txn_handler", server_name=self.server_name, - timeout_ms=30000, + timeout=Duration(seconds=30), ) self.transaction_actions = TransactionActions(self.store) @@ -179,13 +179,13 @@ class FederationServer(FederationBase): clock=hs.get_clock(), name="state_resp", server_name=self.server_name, - timeout_ms=30000, + timeout=Duration(seconds=30), ) self._state_ids_resp_cache: ResponseCache[tuple[str, str]] = ResponseCache( clock=hs.get_clock(), name="state_ids_resp", server_name=self.server_name, - timeout_ms=30000, + timeout=Duration(seconds=30), ) self._federation_metrics_domains = ( diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index e03a912319..1c3489a00e 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -185,7 +185,7 @@ class RoomCreationHandler: clock=hs.get_clock(), name="room_upgrade", server_name=self.server_name, - timeout_ms=FIVE_MINUTES_IN_MS, + timeout=Duration(minutes=5), ) self._server_notices_mxid = hs.config.servernotices.server_notices_mxid diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index 6377931b39..b25fd0a1e7 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -44,6 +44,7 @@ from synapse.storage.databases.main.room import LargestRoomStats from synapse.types import JsonDict, JsonMapping, ThirdPartyInstanceID from synapse.util.caches.descriptors import _CacheContext, cached from 
synapse.util.caches.response_cache import ResponseCache +from synapse.util.duration import Duration if TYPE_CHECKING: from synapse.server import HomeServer @@ -79,7 +80,7 @@ class RoomListHandler: clock=hs.get_clock(), name="remote_room_list", server_name=self.server_name, - timeout_ms=30 * 1000, + timeout=Duration(seconds=30), ) async def get_local_public_room_list( diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 72e91d66ac..2f405004de 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -77,6 +77,7 @@ from synapse.util.async_helpers import concurrently_execute from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.lrucache import LruCache from synapse.util.caches.response_cache import ResponseCache, ResponseCacheContext +from synapse.util.cancellation import cancellable from synapse.util.metrics import Measure from synapse.visibility import filter_and_transform_events_for_client @@ -307,7 +308,7 @@ class SyncHandler: clock=hs.get_clock(), name="sync", server_name=self.server_name, - timeout_ms=hs.config.caches.sync_response_cache_duration, + timeout=hs.config.caches.sync_response_cache_duration, ) # ExpiringCache((User, Device)) -> LruCache(user_id => event_id) @@ -367,6 +368,10 @@ class SyncHandler: logger.debug("Returning sync response for %s", user_id) return res + # TODO: We mark this as cancellable, and we have tests for it, but we + # haven't gone through and exhaustively checked that all the code paths in + # this method are actually cancellable. 
+ @cancellable async def _wait_for_sync_for_user( self, sync_config: SyncConfig, diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index e66396fecc..6daf304432 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -102,7 +102,7 @@ class FollowerTypingHandler: self._room_typing: dict[str, set[str]] = {} self._member_last_federation_poke: dict[RoomMember, int] = {} - self.wheel_timer: WheelTimer[RoomMember] = WheelTimer(bucket_size=5000) + self.wheel_timer: WheelTimer[RoomMember] = WheelTimer() self._latest_room_serial = 0 self._rooms_updated: set[str] = set() @@ -120,7 +120,7 @@ class FollowerTypingHandler: self._rooms_updated = set() self._member_last_federation_poke = {} - self.wheel_timer = WheelTimer(bucket_size=5000) + self.wheel_timer = WheelTimer() @wrap_as_background_process("typing._handle_timeouts") async def _handle_timeouts(self) -> None: diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index 2bab9c2d71..87d6e80898 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -130,7 +130,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): clock=hs.get_clock(), name="repl." 
+ self.NAME, server_name=self.server_name, - timeout_ms=30 * 60 * 1000, + timeout=Duration(minutes=30), ) # We reserve `instance_name` as a parameter to sending requests, so we diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index 458bf08a19..91f2f16771 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -59,6 +59,7 @@ from synapse.rest.admin.experimental_features import ExperimentalFeature from synapse.types import JsonDict, Requester, SlidingSyncStreamToken, StreamToken from synapse.types.rest.client import SlidingSyncBody from synapse.util.caches.lrucache import LruCache +from synapse.util.cancellation import cancellable from synapse.util.json import json_decoder from ._base import client_patterns, set_timeline_upper_limit @@ -138,6 +139,7 @@ class SyncRestServlet(RestServlet): cfg=hs.config.ratelimiting.rc_presence_per_user, ) + @cancellable async def on_GET(self, request: SynapseRequest) -> tuple[int, JsonDict]: # This will always be set by the time Twisted calls us. assert request.args is not None diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 818f8b1a69..53fb24ec5b 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -57,6 +57,7 @@ from synapse.logging.context import ( run_coroutine_in_background, run_in_background, ) +from synapse.util.cancellation import cancellable from synapse.util.clock import Clock from synapse.util.duration import Duration @@ -83,6 +84,13 @@ class AbstractObservableDeferred(Generic[_T], metaclass=abc.ABCMeta): """ ... + @abc.abstractmethod + def has_observers(self) -> bool: + """Returns True if there are any observers currently observing this + ObservableDeferred. + """ + ... + class ObservableDeferred(Generic[_T], AbstractObservableDeferred[_T]): """Wraps a deferred object so that we can add observer deferreds. 
These @@ -122,6 +130,11 @@ class ObservableDeferred(Generic[_T], AbstractObservableDeferred[_T]): for observer in observers: try: observer.callback(r) + except defer.CancelledError: + # We do not want to propagate cancellations to the original + # deferred, or to other observers, so we can just ignore + # this. + pass except Exception as e: logger.exception( "%r threw an exception on .callback(%r), ignoring...", @@ -145,6 +158,11 @@ class ObservableDeferred(Generic[_T], AbstractObservableDeferred[_T]): f.value.__failure__ = f try: observer.errback(f) + except defer.CancelledError: + # We do not want to propagate cancellations to the original + # deferred, or to other observers, so we can just ignore + # this. + pass except Exception as e: logger.exception( "%r threw an exception on .errback(%r), ignoring...", @@ -160,6 +178,7 @@ class ObservableDeferred(Generic[_T], AbstractObservableDeferred[_T]): deferred.addCallbacks(callback, errback) + @cancellable def observe(self) -> "defer.Deferred[_T]": """Observe the underlying deferred. @@ -169,7 +188,7 @@ class ObservableDeferred(Generic[_T], AbstractObservableDeferred[_T]): """ if not self._result: assert isinstance(self._observers, list) - d: "defer.Deferred[_T]" = defer.Deferred() + d: "defer.Deferred[_T]" = defer.Deferred(canceller=self._remove_observer) self._observers.append(d) return d elif self._result[0]: @@ -180,6 +199,12 @@ class ObservableDeferred(Generic[_T], AbstractObservableDeferred[_T]): def observers(self) -> "Collection[defer.Deferred[_T]]": return self._observers + def has_observers(self) -> bool: + """Returns True if there are any observers currently observing this + ObservableDeferred. 
+ """ + return bool(self._observers) + def has_called(self) -> bool: return self._result is not None @@ -204,6 +229,28 @@ class ObservableDeferred(Generic[_T], AbstractObservableDeferred[_T]): self._deferred, ) + def _remove_observer(self, observer: "defer.Deferred[_T]") -> None: + """Removes an observer from the list of observers. + + Used as a canceller for the observer deferreds, so that if an observer + is cancelled it is removed from the list of observers. + """ + if self._result is not None: + # The underlying deferred has already resolved, so the observer has + # already been resolved. Nothing to do. + return + + assert isinstance(self._observers, list) + try: + self._observers.remove(observer) + except ValueError: + # The observer was not in the list. This can happen if the underlying + # deferred resolves at around the same time as we try to remove the + # observer. In this case, it's possible that we tried to remove the + # observer just after it was added to the list, but before it was + # resolved and removed from the list by the callback/errback above. + pass + T = TypeVar("T") @@ -962,6 +1009,27 @@ def delay_cancellation(awaitable: Awaitable[T]) -> Awaitable[T]: return new_deferred +def observe_deferred(d: "defer.Deferred[T]") -> "defer.Deferred[T]": + """Returns a new `Deferred` that observes the given `Deferred`. + + The returned `Deferred` will resolve with the same result as the given + `Deferred`, but will not "chain" on the deferred so that using the returned + deferred does not affect the given `Deferred` in any way. + """ + new_deferred: "defer.Deferred[T]" = defer.Deferred() + + def callback(r: T) -> T: + new_deferred.callback(r) + return r + + def errback(f: Failure) -> Failure: + new_deferred.errback(f) + return f + + d.addCallbacks(callback, errback) + return new_deferred + + class AwakenableSleeper: """Allows explicitly waking up deferreds related to an entity that are currently sleeping. 
diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py index 0289e13f6a..70cea9b77c 100644 --- a/synapse/util/caches/response_cache.py +++ b/synapse/util/caches/response_cache.py @@ -39,10 +39,15 @@ from synapse.logging.opentracing import ( start_active_span, start_active_span_follows_from, ) -from synapse.util.async_helpers import AbstractObservableDeferred, ObservableDeferred +from synapse.util.async_helpers import ( + ObservableDeferred, + delay_cancellation, +) from synapse.util.caches import EvictionReason, register_cache +from synapse.util.cancellation import cancellable, is_function_cancellable from synapse.util.clock import Clock from synapse.util.duration import Duration +from synapse.util.wheel_timer import WheelTimer logger = logging.getLogger(__name__) @@ -79,8 +84,8 @@ class ResponseCacheContext(Generic[KV]): @attr.s(auto_attribs=True) -class ResponseCacheEntry: - result: AbstractObservableDeferred +class ResponseCacheEntry(Generic[KV]): + result: ObservableDeferred[KV] """The (possibly incomplete) result of the operation. Note that we continue to store an ObservableDeferred even after the operation @@ -91,6 +96,15 @@ class ResponseCacheEntry: opentracing_span_context: "opentracing.SpanContext | None" """The opentracing span which generated/is generating the result""" + cancellable: bool + """Whether the deferred is safe to be cancelled.""" + + last_observer_removed_time_ms: int | None = None + """The last time that an observer was removed from this entry. + + Used to determine when to evict the entry if it has no observers. + """ + class ResponseCache(Generic[KV]): """ @@ -98,6 +112,22 @@ class ResponseCache(Generic[KV]): returned from the cache. This means that if the client retries the request while the response is still being computed, that original response will be used rather than trying to compute a new response. 
+ + If a timeout is not specified then the cache entry will be kept while the + wrapped function is still running, and will be removed immediately once it + completes. + + If a timeout is specified then the cache entry will be kept for the duration + of the timeout after the wrapped function completes. If the wrapped function + is cancellable and during processing nothing waits on the result for longer + than the timeout then the wrapped function will be cancelled and the cache + entry will be removed. + + This behaviour is useful for caching responses to requests which are + expensive to compute, but which may be retried by clients if they time out. + For example, /sync requests which may take a long time to compute, and which + clients will retry. However, if the client stops retrying for a while then + we want to stop processing the request and free up the resources. """ def __init__( @@ -106,7 +136,7 @@ class ResponseCache(Generic[KV]): clock: Clock, name: str, server_name: str, - timeout_ms: float = 0, + timeout: Duration | None = None, enable_logging: bool = True, ): """ @@ -121,7 +151,7 @@ class ResponseCache(Generic[KV]): self._result_cache: dict[KV, ResponseCacheEntry] = {} self.clock = clock - self.timeout = Duration(milliseconds=timeout_ms) + self.timeout = timeout self._name = name self._metrics = register_cache( @@ -133,6 +163,13 @@ class ResponseCache(Generic[KV]): ) self._enable_logging = enable_logging + self._prune_timer: WheelTimer[KV] | None = None + if self.timeout: + # Set up the timers for pruning inflight entries. The times here are + # how often we check for entries to prune. 
+ self._prune_timer = WheelTimer(bucket_size=self.timeout / 10) + self.clock.looping_call(self._prune_inflight_entries, self.timeout / 10) + def size(self) -> int: return len(self._result_cache) @@ -172,6 +209,7 @@ class ResponseCache(Generic[KV]): context: ResponseCacheContext[KV], deferred: "defer.Deferred[RV]", opentracing_span_context: "opentracing.SpanContext | None", + cancellable: bool, ) -> ResponseCacheEntry: """Set the entry for the given key to the given deferred. @@ -183,13 +221,16 @@ class ResponseCache(Generic[KV]): context: Information about the cache miss deferred: The deferred which resolves to the result. opentracing_span_context: An opentracing span wrapping the calculation + cancellable: Whether the deferred is safe to be cancelled Returns: The cache entry object. """ result = ObservableDeferred(deferred, consumeErrors=True) key = context.cache_key - entry = ResponseCacheEntry(result, opentracing_span_context) + entry = ResponseCacheEntry( + result, opentracing_span_context, cancellable=cancellable + ) self._result_cache[key] = entry def on_complete(r: RV) -> RV: @@ -233,6 +274,7 @@ class ResponseCache(Generic[KV]): self._metrics.inc_evictions(EvictionReason.time) self._result_cache.pop(key, None) + @cancellable async def wrap( self, key: KV, @@ -301,8 +343,44 @@ class ResponseCache(Generic[KV]): return await callback(*args, **kwargs) d = run_in_background(cb) - entry = self._set(context, d, span_context) - return await make_deferred_yieldable(entry.result.observe()) + entry = self._set( + context, d, span_context, cancellable=is_function_cancellable(callback) + ) + try: + return await make_deferred_yieldable(entry.result.observe()) + except defer.CancelledError: + pass + + # We've been cancelled. + # + # Since we've kicked off the background operation, we can't just + # give up and return here and need to wait for the background + # operation to stop. 
We don't want to stop the background process + # immediately, so that retries have a chance to come in and wait for + # the result. + # + # Instead, we temporarily swallow the cancellation and mark the + # cache key as one to potentially time out. + + # Update the `last_observer_removed_time_ms` so that the pruning + # mechanism can kick in if needed. + now = self.clock.time_msec() + entry.last_observer_removed_time_ms = now + if self._prune_timer is not None and self.timeout: + self._prune_timer.insert(now, key, now + self.timeout.as_millis()) + + # Wait on the original deferred, which will continue to run in the + # background until it completes. We don't want to add an observer as + # this would prevent the entry from being pruned. + # + # Note that this deferred has been consumed by the + # ObservableDeferred, so we don't know what it will return. That + # doesn't matter as we just want to throw a CancelledError once it completes anyway. + try: + await make_deferred_yieldable(delay_cancellation(d)) + except Exception: + pass + raise defer.CancelledError() result = entry.result.observe() if self._enable_logging: @@ -320,4 +398,60 @@ class ResponseCache(Generic[KV]): f"ResponseCache[{self._name}].wait", contexts=(span_context,) if span_context else (), ): - return await make_deferred_yieldable(result) + try: + return await make_deferred_yieldable(result) + except defer.CancelledError: + # If we're cancelled then we update the + # `last_observer_removed_time_ms` so that the pruning mechanism + # can kick in if needed. 
+ now = self.clock.time_msec() + entry.last_observer_removed_time_ms = now + if self._prune_timer is not None and self.timeout: + self._prune_timer.insert(now, key, now + self.timeout.as_millis()) + raise + + def _prune_inflight_entries(self) -> None: + """Prune entries which have been in the cache for too long without + observers""" + assert self._prune_timer is not None + assert self.timeout is not None + + now = self.clock.time_msec() + keys_to_check = self._prune_timer.fetch(now) + + # Loop through the keys and check if they should be evicted. We evict + # entries which have no active observers, and which have been in the + # cache for longer than the timeout since the last observer was removed. + for key in keys_to_check: + entry = self._result_cache.get(key) + if not entry: + continue + + if not entry.cancellable: + # this entry is not cancellable, so we should keep it in the cache until it completes. + continue + + if entry.result.has_called(): + # this entry has already completed, so we should have scheduled it for + # removal at the right time. We can just skip it here and wait for the + # scheduled call to remove it. + continue + + if entry.result.has_observers(): + # this entry has observers, so we should keep it in the cache for now. + continue + + if entry.last_observer_removed_time_ms is None: + # this should never happen, but just in case, we should keep the entry + # in the cache until we have a valid last_observer_removed_time_ms to + # compare against. + continue + + if now - entry.last_observer_removed_time_ms > self.timeout.as_millis(): + self._metrics.inc_evictions(EvictionReason.time) + self._result_cache.pop(key, None) + try: + entry.result.cancel() + except Exception: + # we ignore exceptions from cancel, as it is best effort anyway. 
+ pass diff --git a/synapse/util/clock.py b/synapse/util/clock.py index a3872d6f93..7232a1331c 100644 --- a/synapse/util/clock.py +++ b/synapse/util/clock.py @@ -62,6 +62,16 @@ this setting won't inherit the log level from the parent logger. logging.setLoggerClass(original_logger_class) +def _try_wakeup_deferred(d: Deferred) -> None: + """Try to wake up a deferred, but ignore any exceptions raised by the + callback. This is useful when we want to wake up a deferred that may have + already been cancelled, and we don't care about the result.""" + try: + d.callback(None) + except Exception: + pass + + class Clock: """ A Clock wraps a Twisted reactor and provides utilities on top of it. @@ -114,7 +124,11 @@ class Clock: with context.PreserveLoggingContext(): # We can ignore the lint here since this class is the one location callLater should # be called. - self._reactor.callLater(duration.as_secs(), d.callback, duration.as_secs()) # type: ignore[call-later-not-tracked] + self._reactor.callLater( + duration.as_secs(), + lambda _: _try_wakeup_deferred(d), + duration.as_secs(), + ) # type: ignore[call-later-not-tracked] await d def time(self) -> float: diff --git a/synapse/util/wheel_timer.py b/synapse/util/wheel_timer.py index c63faa96df..fe4622fe99 100644 --- a/synapse/util/wheel_timer.py +++ b/synapse/util/wheel_timer.py @@ -23,6 +23,8 @@ from typing import Generic, Hashable, TypeVar import attr +from synapse.util.duration import Duration + logger = logging.getLogger(__name__) T = TypeVar("T", bound=Hashable) @@ -39,13 +41,13 @@ class WheelTimer(Generic[T]): expired. """ - def __init__(self, bucket_size: int = 5000) -> None: + def __init__(self, bucket_size: Duration = Duration(seconds=5)) -> None: """ Args: bucket_size: Size of buckets in ms. Corresponds roughly to the accuracy of the timer. 
""" - self.bucket_size: int = bucket_size + self.bucket_size = bucket_size self.entries: list[_Entry[T]] = [] def insert(self, now: int, obj: T, then: int) -> None: @@ -56,8 +58,8 @@ class WheelTimer(Generic[T]): obj: Object to be inserted then: When to return the object strictly after. """ - then_key = int(then / self.bucket_size) + 1 - now_key = int(now / self.bucket_size) + then_key = int(then / self.bucket_size.as_millis()) + 1 + now_key = int(now / self.bucket_size.as_millis()) if self.entries: min_key = self.entries[0].end_key @@ -100,7 +102,7 @@ class WheelTimer(Generic[T]): Returns: List of objects that have timed out """ - now_key = int(now / self.bucket_size) + now_key = int(now / self.bucket_size.as_millis()) ret: list[T] = [] while self.entries and self.entries[0].end_key <= now_key: diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py index fcbf3fd53c..e6ada1adb2 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py @@ -41,6 +41,7 @@ from tests import unittest from tests.federation.transport.test_knocking import ( KnockingStrippedStateEventHelperMixin, ) +from tests.rest.client.test_rooms import make_request_with_cancellation_test from tests.server import TimedOutException logger = logging.getLogger(__name__) @@ -1145,3 +1146,65 @@ class ExcludeRoomTestCase(unittest.HomeserverTestCase): self.assertNotIn(self.excluded_room_id, channel.json_body["rooms"]["join"]) self.assertIn(self.included_room_id, channel.json_body["rooms"]["join"]) + + +class SyncCancellationTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + sync.register_servlets, + room.register_servlets, + ] + + def test_initial_sync(self) -> None: + """Tests that an initial sync request can be cancelled.""" + user_id = self.register_user("user", "password") + tok = self.login("user", "password") + + # Populate the account with a few rooms + for _ in range(5): + room_id = 
self.helper.create_room_as(user_id, tok=tok) + self.helper.send(room_id, tok=tok) + + channel = make_request_with_cancellation_test( + "test_initial_sync", + self.reactor, + self.site, + "GET", + "/_matrix/client/v3/sync", + token=tok, + ) + + self.assertEqual(200, channel.code, msg=channel.result["body"]) + + def test_incremental_sync(self) -> None: + """Tests that an incremental sync request can be cancelled.""" + user_id = self.register_user("user", "password") + tok = self.login("user", "password") + + # Populate the account with a few rooms + room_ids = [] + for _ in range(5): + room_id = self.helper.create_room_as(user_id, tok=tok) + self.helper.send(room_id, tok=tok) + room_ids.append(room_id) + + # Do an initial sync to get a since token. + channel = self.make_request("GET", "/sync", access_token=tok) + self.assertEqual(200, channel.code, msg=channel.result) + since = channel.json_body["next_batch"] + + # Send some more messages to generate activity in the rooms. + for room_id in room_ids: + self.helper.send(room_id, tok=tok) + + channel = make_request_with_cancellation_test( + "test_incremental_sync", + self.reactor, + self.site, + "GET", + f"/_matrix/client/v3/sync?since={since}&timeout=10000", + token=tok, + ) + + self.assertEqual(200, channel.code, msg=channel.result["body"]) diff --git a/tests/util/caches/test_response_cache.py b/tests/util/caches/test_response_cache.py index def5c817db..af42769890 100644 --- a/tests/util/caches/test_response_cache.py +++ b/tests/util/caches/test_response_cache.py @@ -19,6 +19,7 @@ # # +from functools import wraps from unittest.mock import Mock from parameterized import parameterized @@ -26,6 +27,7 @@ from parameterized import parameterized from twisted.internet import defer from synapse.util.caches.response_cache import ResponseCache, ResponseCacheContext +from synapse.util.cancellation import cancellable from synapse.util.duration import Duration from tests.server import get_clock @@ -48,15 +50,23 @@ class 
ResponseCacheTestCase(TestCase): def with_cache(self, name: str, ms: int = 0) -> ResponseCache: return ResponseCache( - clock=self.clock, name=name, server_name="test_server", timeout_ms=ms + clock=self.clock, + name=name, + server_name="test_server", + timeout=Duration(milliseconds=ms), ) @staticmethod async def instant_return(o: str) -> str: return o - async def delayed_return(self, o: str) -> str: - await self.clock.sleep(Duration(seconds=1)) + @cancellable + async def delayed_return( + self, + o: str, + duration: Duration = Duration(seconds=1), # noqa + ) -> str: + await self.clock.sleep(duration) return o def test_cache_hit(self) -> None: @@ -223,3 +233,332 @@ class ResponseCacheTestCase(TestCase): self.assertCountEqual( [], cache.keys(), "cache should not have the result now" ) + + def test_cache_func_errors(self) -> None: + """If the callback raises an error, the error should be raised to all + callers and the result should not be cached""" + cache = self.with_cache("error_cache", ms=3000) + + expected_error = Exception("oh no") + + async def erring(o: str) -> str: + await self.clock.sleep(Duration(seconds=1)) + raise expected_error + + wrap_d = defer.ensureDeferred(cache.wrap(0, erring, "ignored")) + self.assertNoResult(wrap_d) + + # a second call should also return a pending deferred + wrap2_d = defer.ensureDeferred(cache.wrap(0, erring, "ignored")) + self.assertNoResult(wrap2_d) + + # let the call complete + self.reactor.advance(1) + + # both results should have completed with the error + self.assertFailure(wrap_d, Exception) + self.assertFailure(wrap2_d, Exception) + + def test_cache_cancel_first_wait(self) -> None: + """Test that cancellation of the deferred returned by wrap() on the + first call does not immediately cause a cancellation error to be raised + when it's cancelled and the wrapped function continues execution (unless + it times out). 
+ """ + cache = self.with_cache("cancel_cache", ms=3000) + + expected_result = "howdy" + + wrap_d = defer.ensureDeferred( + cache.wrap(0, self.delayed_return, expected_result) + ) + + # cancel the deferred before it has a chance to return + wrap_d.cancel() + + # The cancel should be ignored for now, and the inner function should + # still be running. + self.assertNoResult(wrap_d) + + # Advance the clock until the inner function should have returned, but + # not long enough for the cache entry to have expired. + self.reactor.advance(2) + + # The deferred we're waiting on should now return a cancelled error. + self.assertFailure(wrap_d, defer.CancelledError) + + # However future callers should get the result. + wrap_d2 = defer.ensureDeferred( + cache.wrap(0, self.delayed_return, expected_result) + ) + self.assertEqual(expected_result, self.successResultOf(wrap_d2)) + + def test_cache_cancel_first_wait_expire(self) -> None: + """Test that cancellation of the deferred returned by wrap() and the + entry expiring before the wrapped function returns. + + The wrapped function should be cancelled. + """ + cache = self.with_cache("cancel_expire_cache", ms=300) + + expected_result = "howdy" + + # Wrap the function so that we can keep track of when it completes or + # errors. + completed = False + cancelled = False + + @wraps(self.delayed_return) + async def wrapped(o: str) -> str: + nonlocal completed, cancelled + + try: + return await self.delayed_return(o) + except defer.CancelledError: + cancelled = True + raise + finally: + completed = True + + wrap_d = defer.ensureDeferred(cache.wrap(0, wrapped, expected_result)) + + # cancel the deferred before it has a chance to return + wrap_d.cancel() + + # The cancel should be ignored for now, and the inner function should + # still be running. 
+ self.assertNoResult(wrap_d) + self.assertFalse(completed, "wrapped function should not have completed yet") + + # Advance the clock until the cache entry should have expired, but not + # long enough for the inner function to have returned. + self.reactor.advance(0.7) + + # The deferred we're waiting on should now return a cancelled error. + self.assertFailure(wrap_d, defer.CancelledError) + self.assertTrue(completed, "wrapped function should have completed") + self.assertTrue(cancelled, "wrapped function should have been cancelled") + + def test_cache_cancel_first_wait_other_observers(self) -> None: + """Test that cancellation of the deferred returned by wrap() does not + cause a cancellation error to be raised if there are other observers + still waiting on the result. + """ + cache = self.with_cache("cancel_other_cache", ms=300) + + expected_result = "howdy" + + # Wrap the function so that we can keep track of when it completes or + # errors. + completed = False + cancelled = False + + @wraps(self.delayed_return) + async def wrapped(o: str) -> str: + nonlocal completed, cancelled + + try: + return await self.delayed_return(o) + except defer.CancelledError: + cancelled = True + raise + finally: + completed = True + + wrap_d1 = defer.ensureDeferred(cache.wrap(0, wrapped, expected_result)) + wrap_d2 = defer.ensureDeferred(cache.wrap(0, wrapped, expected_result)) + + # cancel the first deferred before it has a chance to return + wrap_d1.cancel() + + # The cancel should be ignored for now, and the inner function should + # still be running. + self.assertNoResult(wrap_d1) + self.assertNoResult(wrap_d2) + self.assertFalse(completed, "wrapped function should not have completed yet") + + # Advance the clock until the cache entry should have expired, but not + # long enough for the inner function to have returned. + self.reactor.advance(0.7) + + # Neither deferred should have returned yet, since the inner function + # should still be running. 
+ self.assertNoResult(wrap_d1) + self.assertNoResult(wrap_d2) + self.assertFalse(completed, "wrapped function should not have completed yet") + + # Now advance the clock until the inner function should have returned. + self.reactor.advance(2.5) + + # The wrapped function should have completed without cancellation. + self.assertTrue(completed, "wrapped function should have completed") + self.assertFalse(cancelled, "wrapped function should not have been cancelled") + + # The first deferred we're waiting on should now return a cancelled error. + self.assertFailure(wrap_d1, defer.CancelledError) + + # The second deferred should return the result. + self.assertEqual(expected_result, self.successResultOf(wrap_d2)) + + def test_cache_add_and_cancel(self) -> None: + """Test that waiting on the cache and cancelling repeatedly keeps the + cache entry alive. + """ + cache = self.with_cache("cancel_add_cache", ms=300) + + expected_result = "howdy" + + # Wrap the function so that we can keep track of when it completes or + # errors. + completed = False + cancelled = False + + @wraps(self.delayed_return) + async def wrapped(o: str) -> str: + nonlocal completed, cancelled + + try: + return await self.delayed_return(o) + except defer.CancelledError: + cancelled = True + raise + finally: + completed = True + + # Repeatedly await for the result and cancel it, which should keep the + # cache entry alive even though the total time exceeds the cache + # timeout. + deferreds = [] + for _ in range(8): + # Await the deferred. + wrap_d = defer.ensureDeferred(cache.wrap(0, wrapped, expected_result)) + + # cancel the deferred before it has a chance to return + self.reactor.advance(0.05) + wrap_d.cancel() + deferreds.append(wrap_d) + + # The cancel should not cause the inner function to be cancelled + # yet. 
+ self.assertFalse( + completed, "wrapped function should not have completed yet" + ) + self.assertFalse( + cancelled, "wrapped function should not have been cancelled yet" + ) + + # Advance the clock by less than the cache timeout, and not long + # enough for the inner function to have returned. + self.reactor.advance(0.05) + + # Now advance the clock until the inner function should have returned. + self.reactor.advance(0.2) + + # All the deferreds we're waiting on should now return a cancelled error. + for wrap_d in deferreds: + self.assertFailure(wrap_d, defer.CancelledError) + + # The wrapped function should have completed without cancellation. + self.assertTrue(completed, "wrapped function should have completed") + self.assertFalse(cancelled, "wrapped function should not have been cancelled") + + # Querying the cache should return the completed result + wrap_d = defer.ensureDeferred(cache.wrap(0, wrapped, expected_result)) + self.assertEqual(expected_result, self.successResultOf(wrap_d)) + + def test_cache_cancel_non_cancellable(self) -> None: + """Test that cancellation of the deferred returned by wrap() on a + non-cancellable entry does not cause a cancellation error to be raised + when it's cancelled and the wrapped function continues execution. + """ + cache = self.with_cache("cancel_non_cancellable_cache", ms=300) + + expected_result = "howdy" + + # Wrap the function so that we can keep track of when it completes or + # errors. + completed = False + cancelled = False + + async def wrapped(o: str) -> str: + nonlocal completed, cancelled + + try: + return await self.delayed_return(o) + except defer.CancelledError: + cancelled = True + raise + finally: + completed = True + + wrap_d = defer.ensureDeferred(cache.wrap(0, wrapped, expected_result)) + + # cancel the deferred before it has a chance to return + wrap_d.cancel() + + # The cancel should be ignored for now, and the inner function should + # still be running.
+ self.assertNoResult(wrap_d) + self.assertFalse(completed, "wrapped function should not have completed yet") + + # Advance the clock until the inner function should have returned, but + # not long enough for the cache entry to have expired. + self.reactor.advance(2) + + # The deferred we're waiting on should be cancelled, but a new call to + # the cache should return the result. + self.assertFailure(wrap_d, defer.CancelledError) + wrap_d2 = defer.ensureDeferred(cache.wrap(0, wrapped, expected_result)) + self.assertEqual(expected_result, self.successResultOf(wrap_d2)) + + def test_cache_cancel_then_error(self) -> None: + """Test that cancellation of the deferred returned by wrap() that then + subsequently errors is correctly propagated to a second caller. + """ + + cache = self.with_cache("cancel_then_error_cache", ms=3000) + + expected_error = Exception("oh no") + + # Wrap the function so that we can keep track of when it completes or + # errors. + completed = False + cancelled = False + + @wraps(self.delayed_return) + async def wrapped(o: str) -> str: + nonlocal completed, cancelled + + try: + await self.delayed_return(o) + raise expected_error + except defer.CancelledError: + cancelled = True + raise + finally: + completed = True + + wrap_d1 = defer.ensureDeferred(cache.wrap(0, wrapped, "ignored")) + wrap_d2 = defer.ensureDeferred(cache.wrap(0, wrapped, "ignored")) + + # cancel the first deferred before it has a chance to return + wrap_d1.cancel() + + # The cancel should be ignored for now, and the inner function should + # still be running. + self.assertNoResult(wrap_d1) + self.assertNoResult(wrap_d2) + self.assertFalse(completed, "wrapped function should not have completed yet") + + # Advance the clock until the inner function should have returned. + self.reactor.advance(2) + + # The wrapped function should have completed with an error without cancellation. 
+ self.assertTrue(completed, "wrapped function should have completed") + self.assertFalse(cancelled, "wrapped function should not have been cancelled") + + # The first deferred we're waiting on should now return a cancelled error. + self.assertFailure(wrap_d1, defer.CancelledError) + + # The second deferred should return the error. + self.assertFailure(wrap_d2, Exception) diff --git a/tests/util/test_async_helpers.py b/tests/util/test_async_helpers.py index a6b7ddf485..e4b1bb2b23 100644 --- a/tests/util/test_async_helpers.py +++ b/tests/util/test_async_helpers.py @@ -120,7 +120,7 @@ class ObservableDeferredTest(TestCase): assert results[1] is not None self.assertEqual(str(results[1].value), "gah!", "observer 2 errback result") - def test_cancellation(self) -> None: + def test_cancellation_observer(self) -> None: """Test that cancelling an observer does not affect other observers.""" origin_d: "Deferred[int]" = Deferred() observable = ObservableDeferred(origin_d, consumeErrors=True) @@ -138,6 +138,10 @@ class ObservableDeferredTest(TestCase): self.assertFalse(observer1.called) self.failureResultOf(observer2, CancelledError) self.assertFalse(observer3.called) + # check that we remove the cancelled observer from the list of observers + # as a clean up. 
+ self.assertEqual(len(observable.observers()), 2) + self.assertNotIn(observer2, observable.observers()) # other observers resolve as normal origin_d.callback(123) @@ -148,6 +152,22 @@ class ObservableDeferredTest(TestCase): observer4 = observable.observe() self.assertEqual(observer4.result, 123, "observer 4 callback result") + def test_cancellation_observee(self) -> None: + """Test that cancelling the original deferred cancels all observers.""" + origin_d: "Deferred[int]" = Deferred() + observable = ObservableDeferred(origin_d, consumeErrors=True) + + observer1 = observable.observe() + observer2 = observable.observe() + + self.assertFalse(observer1.called) + self.assertFalse(observer2.called) + + # cancel the original deferred + origin_d.cancel() + self.failureResultOf(observer1, CancelledError) + self.failureResultOf(observer2, CancelledError) + class TimeoutDeferredTest(TestCase): def setUp(self) -> None: diff --git a/tests/util/test_wheel_timer.py b/tests/util/test_wheel_timer.py index 6fa575a18e..3dd9a9891f 100644 --- a/tests/util/test_wheel_timer.py +++ b/tests/util/test_wheel_timer.py @@ -19,6 +19,7 @@ # # +from synapse.util.duration import Duration from synapse.util.wheel_timer import WheelTimer from .. import unittest @@ -26,7 +27,7 @@ from .. 
import unittest class WheelTimerTestCase(unittest.TestCase): def test_single_insert_fetch(self) -> None: - wheel: WheelTimer[object] = WheelTimer(bucket_size=5) + wheel: WheelTimer[object] = WheelTimer(bucket_size=Duration(milliseconds=5)) wheel.insert(100, "1", 150) @@ -39,7 +40,7 @@ class WheelTimerTestCase(unittest.TestCase): self.assertListEqual(wheel.fetch(170), []) def test_multi_insert(self) -> None: - wheel: WheelTimer[object] = WheelTimer(bucket_size=5) + wheel: WheelTimer[object] = WheelTimer(bucket_size=Duration(milliseconds=5)) wheel.insert(100, "1", 150) wheel.insert(105, "2", 130) @@ -54,13 +55,13 @@ class WheelTimerTestCase(unittest.TestCase): self.assertListEqual(wheel.fetch(210), []) def test_insert_past(self) -> None: - wheel: WheelTimer[object] = WheelTimer(bucket_size=5) + wheel: WheelTimer[object] = WheelTimer(bucket_size=Duration(milliseconds=5)) wheel.insert(100, "1", 50) self.assertListEqual(wheel.fetch(120), ["1"]) def test_insert_past_multi(self) -> None: - wheel: WheelTimer[object] = WheelTimer(bucket_size=5) + wheel: WheelTimer[object] = WheelTimer(bucket_size=Duration(milliseconds=5)) wheel.insert(100, "1", 150) wheel.insert(100, "2", 140) @@ -72,7 +73,7 @@ class WheelTimerTestCase(unittest.TestCase): self.assertListEqual(wheel.fetch(240), []) def test_multi_insert_then_past(self) -> None: - wheel: WheelTimer[object] = WheelTimer(bucket_size=5) + wheel: WheelTimer[object] = WheelTimer(bucket_size=Duration(milliseconds=5)) wheel.insert(100, "1", 150) wheel.insert(100, "2", 160)