diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index e569ca6673..0ad1e7e372 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -32,7 +32,6 @@ from prometheus_client.core import Gauge from synapse.config.cache import add_resizable_cache from synapse.util.metrics import DynamicCollectorRegistry -from synapse.util.caches.deferred_cache import DeferredCache if TYPE_CHECKING: from synapse.server import HomeServer @@ -225,26 +224,6 @@ class CacheManager: def __init__(self, hs: "HomeServer") -> None: self._cache_metrics = CacheMetrics(hs.metrics_collector_registry) - self._deferred_cache_map: Dict[str, DeferredCache[CacheKey, Any]] = {} - - def get_deferred_cache( - self, - name: str, - max_entries: int = 1000, - tree: bool = False, - iterable: bool = False, - apply_cache_factor_from_config: bool = True, - prune_unread_entries: bool = True, - ) -> DeferredCache[CacheKey, Any]: - cache: DeferredCache[CacheKey, Any] = DeferredCache( - name=self.name, - cache_manager=self, - max_entries=self.max_entries, - tree=self.tree, - iterable=self.iterable, - prune_unread_entries=self.prune_unread_entries, - ) - def register_cache( self, cache_type: str, diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 21080c8b37..e276c19934 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -60,16 +60,7 @@ logger = logging.getLogger(__name__) CacheKey = Union[Tuple, Any] -P = ParamSpec("P") -R = TypeVar("R") - - -class HasCacheManager(Protocol): - # Used to handle registering the caches - cache_manager: CacheManager - - -F = TypeVar("F", bound=Callable[Concatenate[HasCacheManager, P], Any]) +F = TypeVar("F", bound=Callable[..., Any]) class CachedFunction(Generic[F]): @@ -89,7 +80,7 @@ class CachedFunction(Generic[F]): class _CacheDescriptorBase: def __init__( self, - orig: Callable[Concatenate[HasCacheManager, P], Any], + orig: Callable[..., Any], 
num_args: Optional[int], uncached_args: Optional[Collection[str]] = None, cache_context: bool = False, @@ -214,8 +205,9 @@ class DeferredCacheDescriptor(_CacheDescriptorBase): def __init__( self, - orig: Callable[Concatenate[HasCacheManager, P], Any], + orig: Callable[..., Any], *, + cache_manager: CacheManager, max_entries: int = 1000, num_args: Optional[int] = None, uncached_args: Optional[Collection[str]] = None, @@ -232,6 +224,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase): cache_context=cache_context, name=name, ) + self.cache_manager = cache_manager if tree and self.num_args < 2: raise RuntimeError( @@ -246,31 +239,19 @@ class DeferredCacheDescriptor(_CacheDescriptorBase): def __get__( self, obj: Optional[Any], owner: Optional[Type] ) -> Callable[..., "defer.Deferred[Any]"]: - logger.info("asdf %s %s", self.orig.__name__, owner) + cache: DeferredCache[CacheKey, Any] = DeferredCache( + name=self.name, + cache_manager=self.cache_manager, + max_entries=self.max_entries, + tree=self.tree, + iterable=self.iterable, + prune_unread_entries=self.prune_unread_entries, + ) + + get_cache_key = self.cache_key_builder @functools.wraps(self.orig) - def _wrapped( - wrapped_self: HasCacheManager, *args: Any, **kwargs: Any - ) -> "defer.Deferred[Any]": - # cache: DeferredCache[CacheKey, Any] = DeferredCache( - # name=self.name, - # cache_manager=wrapped_self.cache_manager, - # max_entries=self.max_entries, - # tree=self.tree, - # iterable=self.iterable, - # prune_unread_entries=self.prune_unread_entries, - # ) - cache = wrapped_self.cache_manager.get_deferred_cache( - self.name, - cache_manager=wrapped_self.cache_manager, - max_entries=self.max_entries, - tree=self.tree, - iterable=self.iterable, - prune_unread_entries=self.prune_unread_entries, - ) - - get_cache_key = self.cache_key_builder - + def _wrapped(*args: Any, **kwargs: Any) -> "defer.Deferred[Any]": # If we're passed a cache_context then we'll want to call its invalidate() # whenever we are invalidated 
invalidate_callback = kwargs.pop("on_invalidate", None) @@ -515,10 +496,12 @@ class _CachedFunctionDescriptor: iterable: bool prune_unread_entries: bool name: Optional[str] + cache_manager: CacheManager def __call__(self, orig: F) -> CachedFunction[F]: d = DeferredCacheDescriptor( orig, + cache_manager=self.cache_manager, max_entries=self.max_entries, num_args=self.num_args, uncached_args=self.uncached_args, @@ -531,6 +514,15 @@ class _CachedFunctionDescriptor: return cast(CachedFunction[F], d) +P = ParamSpec("P") +R = TypeVar("R") + + +class HasCacheManager(Protocol): + # Used to handle registering the caches + cache_manager: CacheManager + + def cached( *, max_entries: int = 1000, @@ -541,21 +533,55 @@ def cached( iterable: bool = False, prune_unread_entries: bool = True, name: Optional[str] = None, -) -> _CachedFunctionDescriptor: +) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]: + """Decorate an async method so that its results are cached. + + The cache is created using `self.cache_manager`; it should only be used to decorate + methods in classes defining an instance-level `cache_manager` attribute. + + Usage: + + @cached() + async def foo(...): + ... + + Which memoises the result keyed on the call arguments, analogous to: + + async def foo(...): + # body runs only on cache misses + ... + """ - The cache is created using `self.cache_manager`; it should only be used to decorate - methods in classes defining an instance-level `cache_manager` attribute. 
- """ - return _CachedFunctionDescriptor( - max_entries=max_entries, - num_args=num_args, - uncached_args=uncached_args, - tree=tree, - cache_context=cache_context, - iterable=iterable, - prune_unread_entries=prune_unread_entries, - name=name, - ) + + def wrapper( + func: Callable[Concatenate[HasCacheManager, P], Awaitable[R]], + ) -> Callable[P, Awaitable[R]]: + # FIXME(review): `cached_func` builds a new descriptor on every call and never invokes `func` + + @functools.wraps(func) + async def cached_func( + self: HasCacheManager, *args: P.args, **kwargs: P.kwargs + ) -> R: + return _CachedFunctionDescriptor( + max_entries=max_entries, + num_args=num_args, + uncached_args=uncached_args, + tree=tree, + cache_context=cache_context, + iterable=iterable, + prune_unread_entries=prune_unread_entries, + name=name, + # Grab this attribute from the instance + cache_manager=self.cache_manager, + ) + + # There are some shenanigans here, because we're decorating a method but + # explicitly making use of the `self` parameter. The key thing is that the + # inner return type within the return type of `cached` itself describes how the + # decorated function will be called. + return cached_func # type: ignore[return-value] + + return wrapper # type: ignore[return-value] @attr.s(auto_attribs=True, slots=True, frozen=True)