1
0

Allow long lived syncs to be cancelled if client has gone away (#19499)

This commit is contained in:
Erik Johnston
2026-02-26 22:41:06 +01:00
committed by GitHub
parent f78d011df1
commit 2c73e8daef
21 changed files with 703 additions and 39 deletions

1
changelog.d/19499.misc Normal file
View File

@@ -0,0 +1 @@
Cancel long-running sync requests if the client has gone away.

View File

@@ -411,6 +411,12 @@ indent-style = "space"
skip-magic-trailing-comma = false
line-ending = "auto"
[tool.ruff.lint.flake8-bugbear]
extend-immutable-calls = [
# Durations are immutable
"synapse.util.duration.Duration",
]
[tool.maturin]
manifest-path = "rust/Cargo.toml"
module-name = "synapse.synapse_rust"

View File

@@ -45,6 +45,7 @@ from synapse.synapse_rust.http_client import HttpClient
from synapse.types import JsonDict, Requester, UserID, create_requester
from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
from synapse.util.caches.response_cache import ResponseCache, ResponseCacheContext
from synapse.util.duration import Duration
from synapse.util.json import json_decoder
from . import introspection_response_timer
@@ -139,7 +140,7 @@ class MasDelegatedAuth(BaseAuth):
clock=self._clock,
name="mas_token_introspection",
server_name=self.server_name,
timeout_ms=120_000,
timeout=Duration(minutes=2),
# don't log because the keys are access tokens
enable_logging=False,
)

View File

@@ -49,6 +49,7 @@ from synapse.synapse_rust.http_client import HttpClient
from synapse.types import Requester, UserID, create_requester
from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
from synapse.util.caches.response_cache import ResponseCache, ResponseCacheContext
from synapse.util.duration import Duration
from synapse.util.json import json_decoder
from . import introspection_response_timer
@@ -205,7 +206,7 @@ class MSC3861DelegatedAuth(BaseAuth):
clock=self._clock,
name="token_introspection",
server_name=self.server_name,
timeout_ms=120_000,
timeout=Duration(minutes=2),
# don't log because the keys are access tokens
enable_logging=False,
)

View File

@@ -46,6 +46,7 @@ from synapse.logging import opentracing
from synapse.metrics import SERVER_NAME_LABEL
from synapse.types import DeviceListUpdates, JsonDict, JsonMapping, ThirdPartyInstanceID
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.duration import Duration
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -132,7 +133,7 @@ class ApplicationServiceApi(SimpleHttpClient):
clock=hs.get_clock(),
name="as_protocol_meta",
server_name=self.server_name,
timeout_ms=HOUR_IN_MS,
timeout=Duration(hours=1),
)
def _get_headers(self, service: "ApplicationService") -> dict[bytes, list[bytes]]:

View File

@@ -29,6 +29,7 @@ import attr
from synapse.types import JsonDict
from synapse.util.check_dependencies import check_requirements
from synapse.util.duration import Duration
from ._base import Config, ConfigError
@@ -108,7 +109,7 @@ class CacheConfig(Config):
global_factor: float
track_memory_usage: bool
expiry_time_msec: int | None
sync_response_cache_duration: int
sync_response_cache_duration: Duration
@staticmethod
def reset() -> None:
@@ -207,10 +208,14 @@ class CacheConfig(Config):
min_cache_ttl = self.cache_autotuning.get("min_cache_ttl")
self.cache_autotuning["min_cache_ttl"] = self.parse_duration(min_cache_ttl)
self.sync_response_cache_duration = self.parse_duration(
sync_response_cache_duration_ms = self.parse_duration(
cache_config.get("sync_response_cache_duration", "2m")
)
self.sync_response_cache_duration = Duration(
milliseconds=sync_response_cache_duration_ms
)
def resize_all_caches(self) -> None:
"""Ensure all cache sizes are up-to-date.

View File

@@ -166,7 +166,7 @@ class FederationServer(FederationBase):
clock=hs.get_clock(),
name="fed_txn_handler",
server_name=self.server_name,
timeout_ms=30000,
timeout=Duration(seconds=30),
)
self.transaction_actions = TransactionActions(self.store)
@@ -179,13 +179,13 @@ class FederationServer(FederationBase):
clock=hs.get_clock(),
name="state_resp",
server_name=self.server_name,
timeout_ms=30000,
timeout=Duration(seconds=30),
)
self._state_ids_resp_cache: ResponseCache[tuple[str, str]] = ResponseCache(
clock=hs.get_clock(),
name="state_ids_resp",
server_name=self.server_name,
timeout_ms=30000,
timeout=Duration(seconds=30),
)
self._federation_metrics_domains = (

View File

@@ -185,7 +185,7 @@ class RoomCreationHandler:
clock=hs.get_clock(),
name="room_upgrade",
server_name=self.server_name,
timeout_ms=FIVE_MINUTES_IN_MS,
timeout=Duration(minutes=5),
)
self._server_notices_mxid = hs.config.servernotices.server_notices_mxid

View File

@@ -44,6 +44,7 @@ from synapse.storage.databases.main.room import LargestRoomStats
from synapse.types import JsonDict, JsonMapping, ThirdPartyInstanceID
from synapse.util.caches.descriptors import _CacheContext, cached
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.duration import Duration
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -79,7 +80,7 @@ class RoomListHandler:
clock=hs.get_clock(),
name="remote_room_list",
server_name=self.server_name,
timeout_ms=30 * 1000,
timeout=Duration(seconds=30),
)
async def get_local_public_room_list(

View File

@@ -77,6 +77,7 @@ from synapse.util.async_helpers import concurrently_execute
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.response_cache import ResponseCache, ResponseCacheContext
from synapse.util.cancellation import cancellable
from synapse.util.metrics import Measure
from synapse.visibility import filter_and_transform_events_for_client
@@ -307,7 +308,7 @@ class SyncHandler:
clock=hs.get_clock(),
name="sync",
server_name=self.server_name,
timeout_ms=hs.config.caches.sync_response_cache_duration,
timeout=hs.config.caches.sync_response_cache_duration,
)
# ExpiringCache((User, Device)) -> LruCache(user_id => event_id)
@@ -367,6 +368,10 @@ class SyncHandler:
logger.debug("Returning sync response for %s", user_id)
return res
# TODO: We mark this as cancellable, and we have tests for it, but we
# haven't gone through and exhaustively checked that all the code paths in
# this method are actually cancellable.
@cancellable
async def _wait_for_sync_for_user(
self,
sync_config: SyncConfig,

View File

@@ -102,7 +102,7 @@ class FollowerTypingHandler:
self._room_typing: dict[str, set[str]] = {}
self._member_last_federation_poke: dict[RoomMember, int] = {}
self.wheel_timer: WheelTimer[RoomMember] = WheelTimer(bucket_size=5000)
self.wheel_timer: WheelTimer[RoomMember] = WheelTimer()
self._latest_room_serial = 0
self._rooms_updated: set[str] = set()
@@ -120,7 +120,7 @@ class FollowerTypingHandler:
self._rooms_updated = set()
self._member_last_federation_poke = {}
self.wheel_timer = WheelTimer(bucket_size=5000)
self.wheel_timer = WheelTimer()
@wrap_as_background_process("typing._handle_timeouts")
async def _handle_timeouts(self) -> None:

View File

@@ -130,7 +130,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
clock=hs.get_clock(),
name="repl." + self.NAME,
server_name=self.server_name,
timeout_ms=30 * 60 * 1000,
timeout=Duration(minutes=30),
)
# We reserve `instance_name` as a parameter to sending requests, so we

View File

@@ -59,6 +59,7 @@ from synapse.rest.admin.experimental_features import ExperimentalFeature
from synapse.types import JsonDict, Requester, SlidingSyncStreamToken, StreamToken
from synapse.types.rest.client import SlidingSyncBody
from synapse.util.caches.lrucache import LruCache
from synapse.util.cancellation import cancellable
from synapse.util.json import json_decoder
from ._base import client_patterns, set_timeline_upper_limit
@@ -138,6 +139,7 @@ class SyncRestServlet(RestServlet):
cfg=hs.config.ratelimiting.rc_presence_per_user,
)
@cancellable
async def on_GET(self, request: SynapseRequest) -> tuple[int, JsonDict]:
# This will always be set by the time Twisted calls us.
assert request.args is not None

View File

@@ -57,6 +57,7 @@ from synapse.logging.context import (
run_coroutine_in_background,
run_in_background,
)
from synapse.util.cancellation import cancellable
from synapse.util.clock import Clock
from synapse.util.duration import Duration
@@ -83,6 +84,13 @@ class AbstractObservableDeferred(Generic[_T], metaclass=abc.ABCMeta):
"""
...
@abc.abstractmethod
def has_observers(self) -> bool:
"""Returns True if there are any observers currently observing this
ObservableDeferred.
"""
...
class ObservableDeferred(Generic[_T], AbstractObservableDeferred[_T]):
"""Wraps a deferred object so that we can add observer deferreds. These
@@ -122,6 +130,11 @@ class ObservableDeferred(Generic[_T], AbstractObservableDeferred[_T]):
for observer in observers:
try:
observer.callback(r)
except defer.CancelledError:
# We do not want to propagate cancellations to the original
# deferred, or to other observers, so we can just ignore
# this.
pass
except Exception as e:
logger.exception(
"%r threw an exception on .callback(%r), ignoring...",
@@ -145,6 +158,11 @@ class ObservableDeferred(Generic[_T], AbstractObservableDeferred[_T]):
f.value.__failure__ = f
try:
observer.errback(f)
except defer.CancelledError:
# We do not want to propagate cancellations to the original
# deferred, or to other observers, so we can just ignore
# this.
pass
except Exception as e:
logger.exception(
"%r threw an exception on .errback(%r), ignoring...",
@@ -160,6 +178,7 @@ class ObservableDeferred(Generic[_T], AbstractObservableDeferred[_T]):
deferred.addCallbacks(callback, errback)
@cancellable
def observe(self) -> "defer.Deferred[_T]":
"""Observe the underlying deferred.
@@ -169,7 +188,7 @@ class ObservableDeferred(Generic[_T], AbstractObservableDeferred[_T]):
"""
if not self._result:
assert isinstance(self._observers, list)
d: "defer.Deferred[_T]" = defer.Deferred()
d: "defer.Deferred[_T]" = defer.Deferred(canceller=self._remove_observer)
self._observers.append(d)
return d
elif self._result[0]:
@@ -180,6 +199,12 @@ class ObservableDeferred(Generic[_T], AbstractObservableDeferred[_T]):
def observers(self) -> "Collection[defer.Deferred[_T]]":
return self._observers
def has_observers(self) -> bool:
    """Report whether any observer deferreds are currently attached.

    Returns False once the observer list is empty (or has been torn
    down after the underlying deferred resolved).
    """
    return len(self._observers) > 0
def has_called(self) -> bool:
    """Return True once the underlying deferred has resolved, whether with
    a success or a failure (i.e. once a result has been recorded)."""
    return self._result is not None
@@ -204,6 +229,28 @@ class ObservableDeferred(Generic[_T], AbstractObservableDeferred[_T]):
self._deferred,
)
def _remove_observer(self, observer: "defer.Deferred[_T]") -> None:
    """Detach a single observer deferred from this ObservableDeferred.

    Installed as the canceller on each observer deferred handed out by
    `observe`, so that cancelling an observer stops it from being
    notified (and from keeping this entry's observer list populated).
    """
    if self._result is not None:
        # The underlying deferred has already resolved, so every observer
        # has already been fired and removed. Nothing to do.
        return

    assert isinstance(self._observers, list)
    if observer in self._observers:
        self._observers.remove(observer)
    # If the observer is absent from the list, the underlying deferred
    # resolved at around the same moment and the callback/errback path
    # already consumed the entry. That race is benign, so we simply do
    # nothing in that case.
T = TypeVar("T")
@@ -962,6 +1009,27 @@ def delay_cancellation(awaitable: Awaitable[T]) -> Awaitable[T]:
return new_deferred
def observe_deferred(d: "defer.Deferred[T]") -> "defer.Deferred[T]":
    """Mirror the outcome of `d` onto a fresh `Deferred`.

    The returned deferred fires with the same result (or failure) as `d`,
    but is deliberately not chained to it: adding callbacks to, or
    cancelling, the returned deferred does not affect `d` in any way.
    """
    mirror: "defer.Deferred[T]" = defer.Deferred()

    def _on_success(value: T) -> T:
        # Propagate the value to the mirror, then hand it back unchanged
        # so that later callbacks on `d` still see it.
        mirror.callback(value)
        return value

    def _on_failure(failure: Failure) -> Failure:
        # Likewise for failures: forward, then re-surface on `d`.
        mirror.errback(failure)
        return failure

    d.addCallbacks(_on_success, _on_failure)
    return mirror
class AwakenableSleeper:
"""Allows explicitly waking up deferreds related to an entity that are
currently sleeping.

View File

@@ -39,10 +39,15 @@ from synapse.logging.opentracing import (
start_active_span,
start_active_span_follows_from,
)
from synapse.util.async_helpers import AbstractObservableDeferred, ObservableDeferred
from synapse.util.async_helpers import (
ObservableDeferred,
delay_cancellation,
)
from synapse.util.caches import EvictionReason, register_cache
from synapse.util.cancellation import cancellable, is_function_cancellable
from synapse.util.clock import Clock
from synapse.util.duration import Duration
from synapse.util.wheel_timer import WheelTimer
logger = logging.getLogger(__name__)
@@ -79,8 +84,8 @@ class ResponseCacheContext(Generic[KV]):
@attr.s(auto_attribs=True)
class ResponseCacheEntry:
result: AbstractObservableDeferred
class ResponseCacheEntry(Generic[KV]):
result: ObservableDeferred[KV]
"""The (possibly incomplete) result of the operation.
Note that we continue to store an ObservableDeferred even after the operation
@@ -91,6 +96,15 @@ class ResponseCacheEntry:
opentracing_span_context: "opentracing.SpanContext | None"
"""The opentracing span which generated/is generating the result"""
cancellable: bool
"""Whether the deferred is safe to be cancelled."""
last_observer_removed_time_ms: int | None = None
"""The last time that an observer was removed from this entry.
Used to determine when to evict the entry if it has no observers.
"""
class ResponseCache(Generic[KV]):
"""
@@ -98,6 +112,22 @@ class ResponseCache(Generic[KV]):
returned from the cache. This means that if the client retries the request
while the response is still being computed, that original response will be
used rather than trying to compute a new response.
If a timeout is not specified then the cache entry will be kept while the
wrapped function is still running, and will be removed immediately once it
completes.
If a timeout is specified then the cache entry will be kept for the duration
of the timeout after the wrapped function completes. If the wrapped function
is cancellable and during processing nothing waits on the result for longer
than the timeout then the wrapped function will be cancelled and the cache
entry will be removed.
This behaviour is useful for caching responses to requests which are
expensive to compute, but which may be retried by clients if they time out.
For example, /sync requests which may take a long time to compute, and which
clients will retry. However, if the client stops retrying for a while then
we want to stop processing the request and free up the resources.
"""
def __init__(
@@ -106,7 +136,7 @@ class ResponseCache(Generic[KV]):
clock: Clock,
name: str,
server_name: str,
timeout_ms: float = 0,
timeout: Duration | None = None,
enable_logging: bool = True,
):
"""
@@ -121,7 +151,7 @@ class ResponseCache(Generic[KV]):
self._result_cache: dict[KV, ResponseCacheEntry] = {}
self.clock = clock
self.timeout = Duration(milliseconds=timeout_ms)
self.timeout = timeout
self._name = name
self._metrics = register_cache(
@@ -133,6 +163,13 @@ class ResponseCache(Generic[KV]):
)
self._enable_logging = enable_logging
self._prune_timer: WheelTimer[KV] | None = None
if self.timeout:
# Set up the timers for pruning inflight entries. The times here are
# how often we check for entries to prune.
self._prune_timer = WheelTimer(bucket_size=self.timeout / 10)
self.clock.looping_call(self._prune_inflight_entries, self.timeout / 10)
def size(self) -> int:
return len(self._result_cache)
@@ -172,6 +209,7 @@ class ResponseCache(Generic[KV]):
context: ResponseCacheContext[KV],
deferred: "defer.Deferred[RV]",
opentracing_span_context: "opentracing.SpanContext | None",
cancellable: bool,
) -> ResponseCacheEntry:
"""Set the entry for the given key to the given deferred.
@@ -183,13 +221,16 @@ class ResponseCache(Generic[KV]):
context: Information about the cache miss
deferred: The deferred which resolves to the result.
opentracing_span_context: An opentracing span wrapping the calculation
cancellable: Whether the deferred is safe to be cancelled
Returns:
The cache entry object.
"""
result = ObservableDeferred(deferred, consumeErrors=True)
key = context.cache_key
entry = ResponseCacheEntry(result, opentracing_span_context)
entry = ResponseCacheEntry(
result, opentracing_span_context, cancellable=cancellable
)
self._result_cache[key] = entry
def on_complete(r: RV) -> RV:
@@ -233,6 +274,7 @@ class ResponseCache(Generic[KV]):
self._metrics.inc_evictions(EvictionReason.time)
self._result_cache.pop(key, None)
@cancellable
async def wrap(
self,
key: KV,
@@ -301,8 +343,44 @@ class ResponseCache(Generic[KV]):
return await callback(*args, **kwargs)
d = run_in_background(cb)
entry = self._set(context, d, span_context)
return await make_deferred_yieldable(entry.result.observe())
entry = self._set(
context, d, span_context, cancellable=is_function_cancellable(callback)
)
try:
return await make_deferred_yieldable(entry.result.observe())
except defer.CancelledError:
pass
# We've been cancelled.
#
# Since we've kicked off the background operation, we can't just
# give up and return here and need to wait for the background
# operation to stop. We don't want to stop the background process
# immediately to give a chance for retries to come in and wait for
# the result.
#
# Instead, we temporarily swallow the cancellation and mark the
# cache key as one to potentially timeout.
# Update the `last_observer_removed_time_ms` so that the pruning
# mechanism can kick in if needed.
now = self.clock.time_msec()
entry.last_observer_removed_time_ms = now
if self._prune_timer is not None and self.timeout:
self._prune_timer.insert(now, key, now + self.timeout.as_millis())
# Wait on the original deferred, which will continue to run in the
# background until it completes. We don't want to add an observer as
# this would prevent the entry from being pruned.
#
# Note that this deferred has been consumed by the
# ObservableDeferred, so we don't know what it will return. That
# doesn't matter as we just want to throw a CancelledError once it completes anyway.
try:
await make_deferred_yieldable(delay_cancellation(d))
except Exception:
pass
raise defer.CancelledError()
result = entry.result.observe()
if self._enable_logging:
@@ -320,4 +398,60 @@ class ResponseCache(Generic[KV]):
f"ResponseCache[{self._name}].wait",
contexts=(span_context,) if span_context else (),
):
return await make_deferred_yieldable(result)
try:
return await make_deferred_yieldable(result)
except defer.CancelledError:
# If we're cancelled then we update the
# `last_observer_removed_time_ms` so that the pruning mechanism
# can kick in if needed.
now = self.clock.time_msec()
entry.last_observer_removed_time_ms = now
if self._prune_timer is not None and self.timeout:
self._prune_timer.insert(now, key, now + self.timeout.as_millis())
raise
def _prune_inflight_entries(self) -> None:
    """Evict in-flight cache entries that nobody is waiting on any more.

    Run periodically via the looping call configured in `__init__` (only
    when a timeout is set). An entry is cancelled and dropped once it has
    spent longer than the timeout with no observers attached.
    """
    assert self._prune_timer is not None
    assert self.timeout is not None

    now = self.clock.time_msec()
    candidates = self._prune_timer.fetch(now)

    for key in candidates:
        entry = self._result_cache.get(key)
        if entry is None:
            # Already evicted (or completed and expired); nothing to do.
            continue

        if not entry.cancellable:
            # Not safe to cancel: keep it cached until the wrapped
            # function finishes of its own accord.
            continue

        if entry.result.has_called():
            # Already completed: removal has been scheduled for the right
            # time elsewhere, so leave it for that scheduled call.
            continue

        if entry.result.has_observers():
            # Someone is still waiting on the result; keep it for now.
            continue

        if entry.last_observer_removed_time_ms is None:
            # Shouldn't happen for an unobserved in-flight entry, but be
            # conservative: keep it until we have a valid timestamp to
            # compare against.
            continue

        if now - entry.last_observer_removed_time_ms > self.timeout.as_millis():
            self._metrics.inc_evictions(EvictionReason.time)
            self._result_cache.pop(key, None)
            try:
                entry.result.cancel()
            except Exception:
                # Cancellation is best effort only; swallow any error
                # it raises.
                pass

View File

@@ -62,6 +62,16 @@ this setting won't inherit the log level from the parent logger.
logging.setLoggerClass(original_logger_class)
def _try_wakeup_deferred(d: Deferred) -> None:
    """Try to wake up a deferred, but ignore any exceptions raised by the
    callback.

    Firing a deferred that has already been cancelled (or otherwise
    resolved) raises; callers here only care that a sleeping waiter is
    woken *if* it is still waiting, so such errors are deliberately
    swallowed and the result is discarded.
    """
    try:
        d.callback(None)
    except Exception:
        # Best effort: the deferred may already have fired or been cancelled.
        pass
class Clock:
"""
A Clock wraps a Twisted reactor and provides utilities on top of it.
@@ -114,7 +124,11 @@ class Clock:
with context.PreserveLoggingContext():
# We can ignore the lint here since this class is the one location callLater should
# be called.
self._reactor.callLater(duration.as_secs(), d.callback, duration.as_secs()) # type: ignore[call-later-not-tracked]
self._reactor.callLater(
duration.as_secs(),
lambda _: _try_wakeup_deferred(d),
duration.as_secs(),
) # type: ignore[call-later-not-tracked]
await d
def time(self) -> float:

View File

@@ -23,6 +23,8 @@ from typing import Generic, Hashable, TypeVar
import attr
from synapse.util.duration import Duration
logger = logging.getLogger(__name__)
T = TypeVar("T", bound=Hashable)
@@ -39,13 +41,13 @@ class WheelTimer(Generic[T]):
expired.
"""
def __init__(self, bucket_size: Duration = Duration(seconds=5)) -> None:
    """
    Args:
        bucket_size: Width of each bucket as a Duration. Corresponds
            roughly to the accuracy of the timer: entries may fire up to
            one bucket late. (Using a Duration instance as a default is
            safe here because Durations are immutable.)
    """
    # Docstring previously described this as an int of milliseconds; the
    # parameter is now a Duration, so document it as such.
    self.bucket_size = bucket_size
    self.entries: list[_Entry[T]] = []
def insert(self, now: int, obj: T, then: int) -> None:
@@ -56,8 +58,8 @@ class WheelTimer(Generic[T]):
obj: Object to be inserted
then: When to return the object strictly after.
"""
then_key = int(then / self.bucket_size) + 1
now_key = int(now / self.bucket_size)
then_key = int(then / self.bucket_size.as_millis()) + 1
now_key = int(now / self.bucket_size.as_millis())
if self.entries:
min_key = self.entries[0].end_key
@@ -100,7 +102,7 @@ class WheelTimer(Generic[T]):
Returns:
List of objects that have timed out
"""
now_key = int(now / self.bucket_size)
now_key = int(now / self.bucket_size.as_millis())
ret: list[T] = []
while self.entries and self.entries[0].end_key <= now_key:

View File

@@ -41,6 +41,7 @@ from tests import unittest
from tests.federation.transport.test_knocking import (
KnockingStrippedStateEventHelperMixin,
)
from tests.rest.client.test_rooms import make_request_with_cancellation_test
from tests.server import TimedOutException
logger = logging.getLogger(__name__)
@@ -1145,3 +1146,65 @@ class ExcludeRoomTestCase(unittest.HomeserverTestCase):
self.assertNotIn(self.excluded_room_id, channel.json_body["rooms"]["join"])
self.assertIn(self.included_room_id, channel.json_body["rooms"]["join"])
class SyncCancellationTestCase(unittest.HomeserverTestCase):
    """Tests that /sync requests survive the client going away mid-request.

    Uses `make_request_with_cancellation_test`, which (per its name)
    exercises the request under simulated client disconnection — see its
    definition in tests.rest.client.test_rooms for the exact mechanics.
    """

    servlets = [
        synapse.rest.admin.register_servlets,
        login.register_servlets,
        sync.register_servlets,
        room.register_servlets,
    ]

    def test_initial_sync(self) -> None:
        """Tests that an initial sync request can be cancelled."""
        user_id = self.register_user("user", "password")
        tok = self.login("user", "password")

        # Populate the account with a few rooms (with one event each) so the
        # initial sync has real work to do.
        for _ in range(5):
            room_id = self.helper.create_room_as(user_id, tok=tok)
            self.helper.send(room_id, tok=tok)

        channel = make_request_with_cancellation_test(
            "test_initial_sync",
            self.reactor,
            self.site,
            "GET",
            "/_matrix/client/v3/sync",
            token=tok,
        )

        # The request should still complete successfully despite cancellation.
        self.assertEqual(200, channel.code, msg=channel.result["body"])

    def test_incremental_sync(self) -> None:
        """Tests that an incremental sync request can be cancelled."""
        user_id = self.register_user("user", "password")
        tok = self.login("user", "password")

        # Populate the account with a few rooms
        room_ids = []
        for _ in range(5):
            room_id = self.helper.create_room_as(user_id, tok=tok)
            self.helper.send(room_id, tok=tok)
            room_ids.append(room_id)

        # Do an initial sync to get a since token.
        channel = self.make_request("GET", "/sync", access_token=tok)
        self.assertEqual(200, channel.code, msg=channel.result)
        since = channel.json_body["next_batch"]

        # Send some more messages to generate activity in the rooms.
        for room_id in room_ids:
            self.helper.send(room_id, tok=tok)

        channel = make_request_with_cancellation_test(
            "test_incremental_sync",
            self.reactor,
            self.site,
            "GET",
            f"/_matrix/client/v3/sync?since={since}&timeout=10000",
            token=tok,
        )

        # The request should still complete successfully despite cancellation.
        self.assertEqual(200, channel.code, msg=channel.result["body"])

View File

@@ -19,6 +19,7 @@
#
#
from functools import wraps
from unittest.mock import Mock
from parameterized import parameterized
@@ -26,6 +27,7 @@ from parameterized import parameterized
from twisted.internet import defer
from synapse.util.caches.response_cache import ResponseCache, ResponseCacheContext
from synapse.util.cancellation import cancellable
from synapse.util.duration import Duration
from tests.server import get_clock
@@ -48,15 +50,23 @@ class ResponseCacheTestCase(TestCase):
def with_cache(self, name: str, ms: int = 0) -> ResponseCache:
    """Build a ResponseCache for a test.

    Args:
        name: cache name (used for metrics registration).
        ms: cache timeout in milliseconds. NOTE(review): ms=0 produces
            Duration(milliseconds=0), which is presumably treated as "no
            timeout" by ResponseCache — confirm Duration(0) is falsy.
    """
    # The dump contained both the old `timeout_ms=ms` call and the new
    # keyword form; only the new Duration-based form is kept.
    return ResponseCache(
        clock=self.clock,
        name=name,
        server_name="test_server",
        timeout=Duration(milliseconds=ms),
    )
@staticmethod
async def instant_return(o: str) -> str:
    """Trivial wrapped function: resolves immediately with its argument."""
    return o
@cancellable
async def delayed_return(
    self,
    o: str,
    duration: Duration = Duration(seconds=1),  # noqa
) -> str:
    """Wrapped function that sleeps for `duration` before resolving with `o`.

    Marked @cancellable so that ResponseCache is permitted to cancel it
    mid-sleep in the cancellation tests. (The Duration default is safe as
    a default argument because Durations are immutable; the `# noqa`
    silences the mutable-default lint regardless.)
    """
    # The dump contained both the old fixed-1s definition and this new
    # parameterised one; only the new definition is kept.
    await self.clock.sleep(duration)
    return o
def test_cache_hit(self) -> None:
@@ -223,3 +233,332 @@ class ResponseCacheTestCase(TestCase):
self.assertCountEqual(
[], cache.keys(), "cache should not have the result now"
)
def test_cache_func_errors(self) -> None:
    """If the callback raises an error, the error should be raised to all
    callers and the result should not be cached."""
    cache = self.with_cache("error_cache", ms=3000)

    expected_error = Exception("oh no")

    # Wrapped function that fails after a 1s delay.
    async def erring(o: str) -> str:
        await self.clock.sleep(Duration(seconds=1))
        raise expected_error

    wrap_d = defer.ensureDeferred(cache.wrap(0, erring, "ignored"))
    self.assertNoResult(wrap_d)

    # a second call should also return a pending deferred
    wrap2_d = defer.ensureDeferred(cache.wrap(0, erring, "ignored"))
    self.assertNoResult(wrap2_d)

    # let the call complete
    self.reactor.advance(1)

    # both results should have completed with the error
    self.assertFailure(wrap_d, Exception)
    self.assertFailure(wrap2_d, Exception)
def test_cache_cancel_first_wait(self) -> None:
    """Test that cancellation of the deferred returned by wrap() on the
    first call does not immediately cause a cancellation error to be raised
    when it's cancelled and the wrapped function continues execution (unless
    it times out).

    Timings: the wrapped function sleeps 1s, the cache timeout is 3s.
    """
    cache = self.with_cache("cancel_cache", ms=3000)

    expected_result = "howdy"

    wrap_d = defer.ensureDeferred(
        cache.wrap(0, self.delayed_return, expected_result)
    )

    # cancel the deferred before it has a chance to return
    wrap_d.cancel()

    # The cancel should be ignored for now, and the inner function should
    # still be running.
    self.assertNoResult(wrap_d)

    # Advance the clock until the inner function should have returned, but
    # not long enough for the cache entry to have expired.
    self.reactor.advance(2)

    # The deferred we're waiting on should now return a cancelled error.
    self.assertFailure(wrap_d, defer.CancelledError)

    # However future callers should get the result.
    wrap_d2 = defer.ensureDeferred(
        cache.wrap(0, self.delayed_return, expected_result)
    )
    self.assertEqual(expected_result, self.successResultOf(wrap_d2))
def test_cache_cancel_first_wait_expire(self) -> None:
    """Test that cancellation of the deferred returned by wrap() and the
    entry expiring before the wrapped function returns.

    The wrapped function should be cancelled.

    Timings: the wrapped function sleeps 1s, the cache timeout is 300ms.
    """
    cache = self.with_cache("cancel_expire_cache", ms=300)

    expected_result = "howdy"

    # Wrap the function so that we can keep track of when it completes or
    # errors.
    completed = False
    cancelled = False

    @wraps(self.delayed_return)
    async def wrapped(o: str) -> str:
        nonlocal completed, cancelled
        try:
            return await self.delayed_return(o)
        except defer.CancelledError:
            cancelled = True
            raise
        finally:
            completed = True

    wrap_d = defer.ensureDeferred(cache.wrap(0, wrapped, expected_result))

    # cancel the deferred before it has a chance to return
    wrap_d.cancel()

    # The cancel should be ignored for now, and the inner function should
    # still be running.
    self.assertNoResult(wrap_d)
    self.assertFalse(completed, "wrapped function should not have completed yet")

    # Advance the clock until the cache entry should have expired, but not
    # long enough for the inner function to have returned.
    self.reactor.advance(0.7)

    # The deferred we're waiting on should now return a cancelled error.
    self.assertFailure(wrap_d, defer.CancelledError)
    self.assertTrue(completed, "wrapped function should have completed")
    self.assertTrue(cancelled, "wrapped function should have been cancelled")
def test_cache_cancel_first_wait_other_observers(self) -> None:
    """Test that cancellation of the deferred returned by wrap() does not
    cause a cancellation error to be raised if there are other observers
    still waiting on the result.

    Timings: the wrapped function sleeps 1s, the cache timeout is 300ms.
    """
    cache = self.with_cache("cancel_other_cache", ms=300)

    expected_result = "howdy"

    # Wrap the function so that we can keep track of when it completes or
    # errors.
    completed = False
    cancelled = False

    @wraps(self.delayed_return)
    async def wrapped(o: str) -> str:
        nonlocal completed, cancelled
        try:
            return await self.delayed_return(o)
        except defer.CancelledError:
            cancelled = True
            raise
        finally:
            completed = True

    wrap_d1 = defer.ensureDeferred(cache.wrap(0, wrapped, expected_result))
    wrap_d2 = defer.ensureDeferred(cache.wrap(0, wrapped, expected_result))

    # cancel the first deferred before it has a chance to return
    wrap_d1.cancel()

    # The cancel should be ignored for now, and the inner function should
    # still be running.
    self.assertNoResult(wrap_d1)
    self.assertNoResult(wrap_d2)
    self.assertFalse(completed, "wrapped function should not have completed yet")

    # Advance the clock until the cache entry should have expired, but not
    # long enough for the inner function to have returned.
    self.reactor.advance(0.7)

    # Neither deferred should have returned yet, since the inner function
    # should still be running (wrap_d2 is still observing, so no pruning).
    self.assertNoResult(wrap_d1)
    self.assertNoResult(wrap_d2)
    self.assertFalse(completed, "wrapped function should not have completed yet")

    # Now advance the clock until the inner function should have returned.
    self.reactor.advance(2.5)

    # The wrapped function should have completed without cancellation.
    self.assertTrue(completed, "wrapped function should have completed")
    self.assertFalse(cancelled, "wrapped function should not have been cancelled")

    # The first deferred we're waiting on should now return a cancelled error.
    self.assertFailure(wrap_d1, defer.CancelledError)

    # The second deferred should return the result.
    self.assertEqual(expected_result, self.successResultOf(wrap_d2))
def test_cache_add_and_cancel(self) -> None:
    """Test that waiting on the cache and cancelling repeatedly keeps the
    cache entry alive.

    Each re-observe + cancel refreshes `last_observer_removed_time_ms`, so
    even though the accumulated time exceeds the 300ms cache timeout, the
    entry is never pruned and the wrapped function runs to completion.
    """
    cache = self.with_cache("cancel_add_cache", ms=300)

    expected_result = "howdy"

    # Wrap the function so that we can keep track of when it completes or
    # errors.
    completed = False
    cancelled = False

    @wraps(self.delayed_return)
    async def wrapped(o: str) -> str:
        nonlocal completed, cancelled
        try:
            return await self.delayed_return(o)
        except defer.CancelledError:
            cancelled = True
            raise
        finally:
            completed = True

    # Repeatedly await for the result and cancel it, which should keep the
    # cache entry alive even though the total time exceeds the cache
    # timeout.
    deferreds = []
    for _ in range(8):
        # Await the deferred.
        wrap_d = defer.ensureDeferred(cache.wrap(0, wrapped, expected_result))

        # cancel the deferred before it has a chance to return
        self.reactor.advance(0.05)
        wrap_d.cancel()
        deferreds.append(wrap_d)

        # The cancel should not cause the inner function to be cancelled
        # yet.
        self.assertFalse(
            completed, "wrapped function should not have completed yet"
        )
        self.assertFalse(
            cancelled, "wrapped function should not have been cancelled yet"
        )

    # Advance the clock until the cache entry should have expired, but not
    # long enough for the inner function to have returned.
    self.reactor.advance(0.05)

    # Now advance the clock until the inner function should have returned.
    self.reactor.advance(0.2)

    # All the deferreds we're waiting on should now return a cancelled error.
    for wrap_d in deferreds:
        self.assertFailure(wrap_d, defer.CancelledError)

    # The wrapped function should have completed without cancellation.
    self.assertTrue(completed, "wrapped function should have completed")
    self.assertFalse(cancelled, "wrapped function should not have been cancelled")

    # Querying the cache should return the completed result
    wrap_d = defer.ensureDeferred(cache.wrap(0, wrapped, expected_result))
    self.assertEqual(expected_result, self.successResultOf(wrap_d))
def test_cache_cancel_non_cancellable(self) -> None:
    """Test that cancellation of the deferred returned by wrap() on a
    non-cancellable entry does not cause a cancellation error to be raised
    when it's cancelled and the wrapped function continues execution.
    """
    cache = self.with_cache("cancel_non_cancellable_cache", ms=300)
    expected_result = "howdy"

    # Wrap the function so that we can keep track of when it completes or
    # errors.
    completed = False
    cancelled = False

    # Use @wraps for consistency with the sibling cancellation tests, so
    # the wrapper carries the same metadata as `delayed_return`.
    @wraps(self.delayed_return)
    async def wrapped(o: str) -> str:
        nonlocal completed, cancelled
        try:
            return await self.delayed_return(o)
        except defer.CancelledError:
            cancelled = True
            raise
        finally:
            completed = True

    wrap_d = defer.ensureDeferred(cache.wrap(0, wrapped, expected_result))
    # cancel the deferred before it has a chance to return
    wrap_d.cancel()

    # The cancel should be ignored for now, and the inner function should
    # still be running.
    self.assertNoResult(wrap_d)
    self.assertFalse(completed, "wrapped function should not have completed yet")

    # Advance the clock until the inner function should have returned, but
    # not long enough for the cache entry to have expired.
    # (NOTE(review): assumes the cache timeout only starts counting once
    # the result arrives — the sibling tests behave consistently with
    # that.)
    self.reactor.advance(2)

    # The inner function should have run to completion without ever
    # observing the cancellation. (Previously these flags were tracked but
    # never asserted, leaving the docstring's claim untested.)
    self.assertTrue(completed, "wrapped function should have completed")
    self.assertFalse(cancelled, "wrapped function should not have been cancelled")

    # The deferred we're waiting on should be cancelled, but a new call to
    # the cache should return the result.
    self.assertFailure(wrap_d, defer.CancelledError)
    wrap_d2 = defer.ensureDeferred(cache.wrap(0, wrapped, expected_result))
    self.assertEqual(expected_result, self.successResultOf(wrap_d2))
def test_cache_cancel_then_error(self) -> None:
    """Test that cancellation of the deferred returned by wrap() that then
    subsequently errors is correctly propagated to a second caller.
    """
    cache = self.with_cache("cancel_then_error_cache", ms=3000)
    expected_error = Exception("oh no")

    # Instrument the inner function so we can tell whether it ran to its
    # (failing) end and whether it ever observed a cancellation.
    has_finished = False
    got_cancelled = False

    @wraps(self.delayed_return)
    async def failing(o: str) -> str:
        nonlocal has_finished, got_cancelled
        try:
            await self.delayed_return(o)
            raise expected_error
        except defer.CancelledError:
            got_cancelled = True
            raise
        finally:
            has_finished = True

    first_d = defer.ensureDeferred(cache.wrap(0, failing, "ignored"))
    second_d = defer.ensureDeferred(cache.wrap(0, failing, "ignored"))

    # Cancel the first observer straight away; the shared inner call must
    # carry on regardless.
    first_d.cancel()
    self.assertNoResult(first_d)
    self.assertNoResult(second_d)
    self.assertFalse(has_finished, "wrapped function should not have completed yet")

    # Let the inner function run to completion (where it raises).
    self.reactor.advance(2)
    self.assertTrue(has_finished, "wrapped function should have completed")
    self.assertFalse(got_cancelled, "wrapped function should not have been cancelled")

    # The cancelled observer fails with CancelledError, while the second
    # observer receives the error raised by the inner function.
    self.assertFailure(first_d, defer.CancelledError)
    self.assertFailure(second_d, Exception)

View File

@@ -120,7 +120,7 @@ class ObservableDeferredTest(TestCase):
assert results[1] is not None
self.assertEqual(str(results[1].value), "gah!", "observer 2 errback result")
def test_cancellation(self) -> None:
def test_cancellation_observer(self) -> None:
"""Test that cancelling an observer does not affect other observers."""
origin_d: "Deferred[int]" = Deferred()
observable = ObservableDeferred(origin_d, consumeErrors=True)
@@ -138,6 +138,10 @@ class ObservableDeferredTest(TestCase):
self.assertFalse(observer1.called)
self.failureResultOf(observer2, CancelledError)
self.assertFalse(observer3.called)
# check that we remove the cancelled observer from the list of observers
# as a clean up.
self.assertEqual(len(observable.observers()), 2)
self.assertNotIn(observer2, observable.observers())
# other observers resolve as normal
origin_d.callback(123)
@@ -148,6 +152,22 @@ class ObservableDeferredTest(TestCase):
observer4 = observable.observe()
self.assertEqual(observer4.result, 123, "observer 4 callback result")
def test_cancellation_observee(self) -> None:
"""Test that cancelling the original deferred cancels all observers."""
origin_d: "Deferred[int]" = Deferred()
observable = ObservableDeferred(origin_d, consumeErrors=True)
observer1 = observable.observe()
observer2 = observable.observe()
self.assertFalse(observer1.called)
self.assertFalse(observer2.called)
# cancel the original deferred
origin_d.cancel()
self.failureResultOf(observer1, CancelledError)
self.failureResultOf(observer2, CancelledError)
class TimeoutDeferredTest(TestCase):
def setUp(self) -> None:

View File

@@ -19,6 +19,7 @@
#
#
from synapse.util.duration import Duration
from synapse.util.wheel_timer import WheelTimer
from .. import unittest
@@ -26,7 +27,7 @@ from .. import unittest
class WheelTimerTestCase(unittest.TestCase):
def test_single_insert_fetch(self) -> None:
wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
wheel: WheelTimer[object] = WheelTimer(bucket_size=Duration(milliseconds=5))
wheel.insert(100, "1", 150)
@@ -39,7 +40,7 @@ class WheelTimerTestCase(unittest.TestCase):
self.assertListEqual(wheel.fetch(170), [])
def test_multi_insert(self) -> None:
wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
wheel: WheelTimer[object] = WheelTimer(bucket_size=Duration(milliseconds=5))
wheel.insert(100, "1", 150)
wheel.insert(105, "2", 130)
@@ -54,13 +55,13 @@ class WheelTimerTestCase(unittest.TestCase):
self.assertListEqual(wheel.fetch(210), [])
def test_insert_past(self) -> None:
wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
wheel: WheelTimer[object] = WheelTimer(bucket_size=Duration(milliseconds=5))
wheel.insert(100, "1", 50)
self.assertListEqual(wheel.fetch(120), ["1"])
def test_insert_past_multi(self) -> None:
wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
wheel: WheelTimer[object] = WheelTimer(bucket_size=Duration(milliseconds=5))
wheel.insert(100, "1", 150)
wheel.insert(100, "2", 140)
@@ -72,7 +73,7 @@ class WheelTimerTestCase(unittest.TestCase):
self.assertListEqual(wheel.fetch(240), [])
def test_multi_insert_then_past(self) -> None:
wheel: WheelTimer[object] = WheelTimer(bucket_size=5)
wheel: WheelTimer[object] = WheelTimer(bucket_size=Duration(milliseconds=5))
wheel.insert(100, "1", 150)
wheel.insert(100, "2", 160)