@@ -32,7 +32,6 @@ from prometheus_client.core import Gauge
|
||||
|
||||
from synapse.config.cache import add_resizable_cache
|
||||
from synapse.util.metrics import DynamicCollectorRegistry
|
||||
from synapse.util.caches.deferred_cache import DeferredCache
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
@@ -225,26 +224,6 @@ class CacheManager:
|
||||
def __init__(self, hs: "HomeServer") -> None:
|
||||
self._cache_metrics = CacheMetrics(hs.metrics_collector_registry)
|
||||
|
||||
self._deferred_cache_map: Dict[str, DeferredCache[CacheKey, Any]] = {}
|
||||
|
||||
def get_deferred_cache(
|
||||
self,
|
||||
name: str,
|
||||
max_entries: int = 1000,
|
||||
tree: bool = False,
|
||||
iterable: bool = False,
|
||||
apply_cache_factor_from_config: bool = True,
|
||||
prune_unread_entries: bool = True,
|
||||
) -> DeferredCache[CacheKey, Any]:
|
||||
cache: DeferredCache[CacheKey, Any] = DeferredCache(
|
||||
name=self.name,
|
||||
cache_manager=self,
|
||||
max_entries=self.max_entries,
|
||||
tree=self.tree,
|
||||
iterable=self.iterable,
|
||||
prune_unread_entries=self.prune_unread_entries,
|
||||
)
|
||||
|
||||
def register_cache(
|
||||
self,
|
||||
cache_type: str,
|
||||
|
||||
@@ -60,16 +60,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
CacheKey = Union[Tuple, Any]
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
class HasCacheManager(Protocol):
|
||||
# Used to handle registering the caches
|
||||
cache_manager: CacheManager
|
||||
|
||||
|
||||
F = TypeVar("F", bound=Callable[Concatenate[HasCacheManager, P], Any])
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
|
||||
class CachedFunction(Generic[F]):
|
||||
@@ -89,7 +80,7 @@ class CachedFunction(Generic[F]):
|
||||
class _CacheDescriptorBase:
|
||||
def __init__(
|
||||
self,
|
||||
orig: Callable[Concatenate[HasCacheManager, P], Any],
|
||||
orig: Callable[..., Any],
|
||||
num_args: Optional[int],
|
||||
uncached_args: Optional[Collection[str]] = None,
|
||||
cache_context: bool = False,
|
||||
@@ -214,8 +205,9 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
orig: Callable[Concatenate[HasCacheManager, P], Any],
|
||||
orig: Callable[..., Any],
|
||||
*,
|
||||
cache_manager: CacheManager,
|
||||
max_entries: int = 1000,
|
||||
num_args: Optional[int] = None,
|
||||
uncached_args: Optional[Collection[str]] = None,
|
||||
@@ -232,6 +224,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
|
||||
cache_context=cache_context,
|
||||
name=name,
|
||||
)
|
||||
self.cache_manager = cache_manager
|
||||
|
||||
if tree and self.num_args < 2:
|
||||
raise RuntimeError(
|
||||
@@ -246,31 +239,19 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
|
||||
def __get__(
|
||||
self, obj: Optional[Any], owner: Optional[Type]
|
||||
) -> Callable[..., "defer.Deferred[Any]"]:
|
||||
logger.info("asdf %s %s", self.orig.__name__, owner)
|
||||
cache: DeferredCache[CacheKey, Any] = DeferredCache(
|
||||
name=self.name,
|
||||
cache_manager=self.cache_manager,
|
||||
max_entries=self.max_entries,
|
||||
tree=self.tree,
|
||||
iterable=self.iterable,
|
||||
prune_unread_entries=self.prune_unread_entries,
|
||||
)
|
||||
|
||||
get_cache_key = self.cache_key_builder
|
||||
|
||||
@functools.wraps(self.orig)
|
||||
def _wrapped(
|
||||
wrapped_self: HasCacheManager, *args: Any, **kwargs: Any
|
||||
) -> "defer.Deferred[Any]":
|
||||
# cache: DeferredCache[CacheKey, Any] = DeferredCache(
|
||||
# name=self.name,
|
||||
# cache_manager=wrapped_self.cache_manager,
|
||||
# max_entries=self.max_entries,
|
||||
# tree=self.tree,
|
||||
# iterable=self.iterable,
|
||||
# prune_unread_entries=self.prune_unread_entries,
|
||||
# )
|
||||
cache = wrapped_self.cache_manager.get_deferred_cache(
|
||||
self.name,
|
||||
cache_manager=wrapped_self.cache_manager,
|
||||
max_entries=self.max_entries,
|
||||
tree=self.tree,
|
||||
iterable=self.iterable,
|
||||
prune_unread_entries=self.prune_unread_entries,
|
||||
)
|
||||
|
||||
get_cache_key = self.cache_key_builder
|
||||
|
||||
def _wrapped(*args: Any, **kwargs: Any) -> "defer.Deferred[Any]":
|
||||
# If we're passed a cache_context then we'll want to call its invalidate()
|
||||
# whenever we are invalidated
|
||||
invalidate_callback = kwargs.pop("on_invalidate", None)
|
||||
@@ -515,10 +496,12 @@ class _CachedFunctionDescriptor:
|
||||
iterable: bool
|
||||
prune_unread_entries: bool
|
||||
name: Optional[str]
|
||||
cache_manager: CacheManager
|
||||
|
||||
def __call__(self, orig: F) -> CachedFunction[F]:
|
||||
d = DeferredCacheDescriptor(
|
||||
orig,
|
||||
cache_manager=self.cache_manager,
|
||||
max_entries=self.max_entries,
|
||||
num_args=self.num_args,
|
||||
uncached_args=self.uncached_args,
|
||||
@@ -531,6 +514,15 @@ class _CachedFunctionDescriptor:
|
||||
return cast(CachedFunction[F], d)
|
||||
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
class HasCacheManager(Protocol):
|
||||
# Used to handle registering the caches
|
||||
cache_manager: CacheManager
|
||||
|
||||
|
||||
def cached(
|
||||
*,
|
||||
max_entries: int = 1000,
|
||||
@@ -541,21 +533,55 @@ def cached(
|
||||
iterable: bool = False,
|
||||
prune_unread_entries: bool = True,
|
||||
name: Optional[str] = None,
|
||||
) -> _CachedFunctionDescriptor:
|
||||
) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]:
|
||||
"""Decorate an async method with a `Measure` context manager.
|
||||
|
||||
The Measure is created using `self.cache_manager`; it should only be used to decorate
|
||||
methods in classes defining an instance-level `clock` attribute.
|
||||
|
||||
Usage:
|
||||
|
||||
@measure_func()
|
||||
async def foo(...):
|
||||
...
|
||||
|
||||
Which is analogous to:
|
||||
|
||||
async def foo(...):
|
||||
with Measure(...):
|
||||
...
|
||||
|
||||
"""
|
||||
The cache is created using `self.cache_manager`; it should only be used to decorate
|
||||
methods in classes defining an instance-level `cache_manager` attribute.
|
||||
"""
|
||||
return _CachedFunctionDescriptor(
|
||||
max_entries=max_entries,
|
||||
num_args=num_args,
|
||||
uncached_args=uncached_args,
|
||||
tree=tree,
|
||||
cache_context=cache_context,
|
||||
iterable=iterable,
|
||||
prune_unread_entries=prune_unread_entries,
|
||||
name=name,
|
||||
)
|
||||
|
||||
def wrapper(
|
||||
func: Callable[Concatenate[HasCacheManager, P], Awaitable[R]],
|
||||
) -> Callable[P, Awaitable[R]]:
|
||||
# block_name = func.__name__ if name is None else name
|
||||
|
||||
@functools.wraps(func)
|
||||
async def cached_func(
|
||||
self: HasCacheManager, *args: P.args, **kwargs: P.kwargs
|
||||
) -> R:
|
||||
return _CachedFunctionDescriptor(
|
||||
max_entries=max_entries,
|
||||
num_args=num_args,
|
||||
uncached_args=uncached_args,
|
||||
tree=tree,
|
||||
cache_context=cache_context,
|
||||
iterable=iterable,
|
||||
prune_unread_entries=prune_unread_entries,
|
||||
name=name,
|
||||
# Grab this attribute from the instance
|
||||
cache_manager=self.cache_manager,
|
||||
)
|
||||
|
||||
# There are some shenanigans here, because we're decorating a method but
|
||||
# explicitly making use of the `self` parameter. The key thing here is that the
|
||||
# return type within the return type for `measure_func` itself describes how the
|
||||
# decorated function will be called.
|
||||
return cached_func # type: ignore[return-value]
|
||||
|
||||
return wrapper # type: ignore[return-value]
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True, slots=True, frozen=True)
|
||||
|
||||
Reference in New Issue
Block a user