checkpoint: mostly what I wanted here
This commit is contained in:
@@ -14,9 +14,14 @@
|
||||
|
||||
import logging
|
||||
from collections import namedtuple
|
||||
from typing import Dict, Iterable, List, Optional, Set, Tuple
|
||||
from typing import Any, Dict, FrozenSet, Iterable, List, Optional, Set, Tuple, Union
|
||||
|
||||
import attr
|
||||
|
||||
from twisted.internet.defer import Deferred
|
||||
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.logging.context import make_deferred_yieldable
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
from synapse.storage.database import DatabasePool
|
||||
from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStore
|
||||
@@ -26,9 +31,12 @@ from synapse.storage.util.sequence import build_sequence_generator
|
||||
from synapse.types import MutableStateMap, StateMap
|
||||
from synapse.util.caches.descriptors import cached
|
||||
from synapse.util.caches.dictionary_cache import DictionaryCache
|
||||
from synapse.util.caches.multi_key_response_cache import MultiKeyResponseCache
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# XXX
|
||||
UNKNOWN = Any # TODO
|
||||
|
||||
MAX_STATE_DELTA_HOPS = 100
|
||||
|
||||
@@ -91,6 +99,18 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
||||
500000,
|
||||
)
|
||||
|
||||
# XXX ADD TYPE
|
||||
self._state_group_inflight_cache: MultiKeyResponseCache[
|
||||
...
|
||||
] = MultiKeyResponseCache(
|
||||
self.hs.get_clock(),
|
||||
"*stateGroupInflightCache*",
|
||||
# As the results from this transaction immediately go into the
|
||||
# immediate caches _state_group_cache and _state_group_members_cache,
|
||||
# we do not keep them in the in-flight cache when done.
|
||||
timeout_ms=0,
|
||||
)
|
||||
|
||||
def get_max_state_group_txn(txn: Cursor):
|
||||
txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
|
||||
return txn.fetchone()[0]
|
||||
@@ -168,13 +188,18 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
||||
|
||||
return results
|
||||
|
||||
def _get_state_for_group_using_cache(self, cache, group, state_filter):
|
||||
def _get_state_for_group_using_cache(
|
||||
self,
|
||||
cache: DictionaryCache[int, UNKNOWN],
|
||||
group: int,
|
||||
state_filter: StateFilter,
|
||||
) -> Tuple[MutableStateMap[UNKNOWN], bool]:
|
||||
"""Checks if group is in cache. See `_get_state_for_groups`
|
||||
|
||||
Args:
|
||||
cache(DictionaryCache): the state group cache to use
|
||||
group(int): The state group to lookup
|
||||
state_filter (StateFilter): The state filter used to fetch state
|
||||
cache: the state group cache to use
|
||||
group: The state group to lookup
|
||||
state_filter: The state filter used to fetch state
|
||||
from the database.
|
||||
|
||||
Returns 2-tuple (`state_dict`, `got_all`).
|
||||
@@ -212,7 +237,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
||||
self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
|
||||
) -> Dict[int, MutableStateMap[str]]:
|
||||
"""Gets the state at each of a list of state groups, optionally
|
||||
filtering by type/state_key
|
||||
filtering by type/state_key.
|
||||
|
||||
Args:
|
||||
groups: list of state groups for which we want
|
||||
@@ -221,11 +246,38 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
||||
from the database.
|
||||
Returns:
|
||||
Dict of state group to state map.
|
||||
"""
|
||||
state_filter = state_filter or StateFilter.all()
|
||||
|
||||
|
||||
The flow for this function looks as follows:
|
||||
|
||||
* Query the immediate caches (self._state_group_cache,
|
||||
| self._state_group_members_cache).
|
||||
NONSTOP |
|
||||
|
|
||||
* Query the in-flight cache (self._state_group_inflight_cache)
|
||||
| for immediate-cache misses.
|
||||
NONSTOP |
|
||||
|
|
||||
* Service cache misses:
|
||||
| - Expand the state filter (to help cache hit ratio).
|
||||
| - Start a new transaction to fetch outstanding groups.
|
||||
| - Register entries in the in-flight cache for this transaction.
|
||||
| - (When the transaction is finished) Register entries in
|
||||
| the immediate caches.
|
||||
|
|
||||
* Wait for in-flight requests to finish...
|
||||
|
|
||||
* Assemble everything together and filter out anything we didn't
|
||||
ask for.
|
||||
|
||||
The sections marked NONSTOP must not contain any `await`s, otherwise
|
||||
race conditions could occur and the cache could be made less effective.
|
||||
"""
|
||||
|
||||
state_filter = state_filter or StateFilter.all()
|
||||
member_filter, non_member_filter = state_filter.get_member_split()
|
||||
|
||||
# QUERY THE IMMEDIATE CACHES
|
||||
# Now we look them up in the member and non-member caches
|
||||
(
|
||||
non_member_state,
|
||||
@@ -242,43 +294,147 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
|
||||
for group in groups:
|
||||
state[group].update(member_state[group])
|
||||
|
||||
# Now fetch any missing groups from the database
|
||||
|
||||
incomplete_groups = incomplete_groups_m | incomplete_groups_nm
|
||||
|
||||
if not incomplete_groups:
|
||||
return state
|
||||
|
||||
cache_sequence_nm = self._state_group_cache.sequence
|
||||
cache_sequence_m = self._state_group_members_cache.sequence
|
||||
# QUERY THE IN-FLIGHT CACHE
|
||||
# list (group ID -> Deferred that will contain a result for that group)
|
||||
inflight_requests: List[Tuple[int, Deferred[Dict[int, StateMap[str]]]]] = []
|
||||
inflight_cache_misses: List[int] = []
|
||||
|
||||
# Help the cache hit ratio by expanding the filter a bit
|
||||
# When we get around to requesting state from the database, we help the
|
||||
# cache hit ratio by expanding the filter a bit.
|
||||
# However, we need to know this now so that we can properly query the
|
||||
# in-flight cache where include_others is concerned.
|
||||
db_state_filter = state_filter.return_expanded()
|
||||
|
||||
group_to_state_dict = await self._get_state_groups_from_groups(
|
||||
list(incomplete_groups), state_filter=db_state_filter
|
||||
)
|
||||
for group in incomplete_groups:
|
||||
event_type: str
|
||||
state_keys: Optional[FrozenSet[str]]
|
||||
|
||||
# Now lets update the caches
|
||||
self._insert_into_cache(
|
||||
group_to_state_dict,
|
||||
db_state_filter,
|
||||
cache_seq_num_members=cache_sequence_m,
|
||||
cache_seq_num_non_members=cache_sequence_nm,
|
||||
)
|
||||
# First check if our exact state filter is being looked up.
|
||||
result = self._state_group_inflight_cache.get((group, db_state_filter))
|
||||
if result is not None:
|
||||
inflight_requests.append((group, make_deferred_yieldable(result)))
|
||||
continue
|
||||
|
||||
# Then check if the universal state filter is being looked up.
|
||||
result = self._state_group_inflight_cache.get((group, StateFilter.all()))
|
||||
if result is not None:
|
||||
inflight_requests.append((group, make_deferred_yieldable(result)))
|
||||
continue
|
||||
|
||||
if state_filter.include_others:
|
||||
# if the state filter includes others, we only match against the
|
||||
# state filter directly, so we give up here.
|
||||
# This is because it's too complex to cache this case properly.
|
||||
inflight_cache_misses.append(group)
|
||||
continue
|
||||
elif not db_state_filter.include_others:
|
||||
# TODO IS THIS USEFUL
|
||||
# Try looking to see if the same filter but with include_others
|
||||
# is being looked up.
|
||||
result = self._state_group_inflight_cache.get(
|
||||
(group, attr.evolve(db_state_filter, include_others=True))
|
||||
)
|
||||
if result is not None:
|
||||
inflight_requests.append((group, make_deferred_yieldable(result)))
|
||||
continue
|
||||
|
||||
for event_type, state_keys in state_filter.types.items():
|
||||
result = self._state_group_inflight_cache.get((group, event_type, None))
|
||||
if result is not None:
|
||||
inflight_requests.append((group, make_deferred_yieldable(result)))
|
||||
continue
|
||||
|
||||
if state_keys is not None:
|
||||
got_all_state_keys = False
|
||||
for state_key in state_keys:
|
||||
result = self._state_group_inflight_cache.get(
|
||||
(group, event_type, state_key)
|
||||
)
|
||||
if result is not None:
|
||||
inflight_requests.append(
|
||||
(group, make_deferred_yieldable(result))
|
||||
)
|
||||
else:
|
||||
break
|
||||
else:
|
||||
got_all_state_keys = True
|
||||
|
||||
if not got_all_state_keys:
|
||||
# we still have to request against this group.
|
||||
inflight_cache_misses.append(group)
|
||||
break
|
||||
|
||||
# SERVICE CACHE MISSES
|
||||
if inflight_cache_misses:
|
||||
cache_sequence_nm = self._state_group_cache.sequence
|
||||
cache_sequence_m = self._state_group_members_cache.sequence
|
||||
|
||||
async def get_state_groups_from_groups_then_add_to_cache() -> Dict[
|
||||
int, StateMap[str]
|
||||
]:
|
||||
groups_to_state_dict = await self._get_state_groups_from_groups(
|
||||
list(inflight_cache_misses), state_filter=db_state_filter
|
||||
)
|
||||
|
||||
# Now let's update the caches.
|
||||
self._insert_into_cache(
|
||||
groups_to_state_dict,
|
||||
db_state_filter,
|
||||
cache_seq_num_members=cache_sequence_m,
|
||||
cache_seq_num_non_members=cache_sequence_nm,
|
||||
)
|
||||
|
||||
return groups_to_state_dict
|
||||
|
||||
# make a list of keys for us to store in the in-flight cache
|
||||
# this should list all the keys that the request will pick up from
|
||||
# the database.
|
||||
keys: List[
|
||||
Union[Tuple[int, StateFilter], Tuple[int, str, Optional[str]]]
|
||||
] = []
|
||||
for group in inflight_cache_misses:
|
||||
if db_state_filter.include_others:
|
||||
# we can't intelligently cache include_others under any other keys
|
||||
# because we don't know what keys are included.
|
||||
keys.append((group, db_state_filter))
|
||||
continue
|
||||
|
||||
for event_type, state_keys in db_state_filter.types.items():
|
||||
if state_keys is None:
|
||||
keys.append((group, event_type, None))
|
||||
else:
|
||||
for state_key in state_keys:
|
||||
keys.append((group, event_type, state_key))
|
||||
|
||||
spawned_request = self._state_group_inflight_cache.set_and_compute(
|
||||
tuple(keys), get_state_groups_from_groups_then_add_to_cache
|
||||
)
|
||||
for group in inflight_cache_misses:
|
||||
inflight_requests.append((group, spawned_request))
|
||||
|
||||
# WAIT FOR IN-FLIGHT REQUESTS TO FINISH
|
||||
for group, inflight_request in inflight_requests:
|
||||
request_result = await inflight_request
|
||||
state[group].update(request_result[group])
|
||||
|
||||
# ASSEMBLE
|
||||
# And finally update the result dict, by filtering out any extra
|
||||
# stuff we pulled out of the database.
|
||||
for group, group_state_dict in group_to_state_dict.items():
|
||||
for group in groups:
|
||||
# We just replace any existing entries, as we will have loaded
|
||||
# everything we need from the database anyway.
|
||||
state[group] = state_filter.filter_state(group_state_dict)
|
||||
state[group] = state_filter.filter_state(state[group])
|
||||
|
||||
return state
|
||||
|
||||
def _get_state_for_groups_using_cache(
|
||||
self, groups: Iterable[int], cache: DictionaryCache, state_filter: StateFilter
|
||||
) -> Tuple[Dict[int, StateMap[str]], Set[int]]:
|
||||
) -> Tuple[Dict[int, MutableStateMap[str]], Set[int]]:
|
||||
"""Gets the state at each of a list of state groups, optionally
|
||||
filtering by type/state_key, querying from a specific cache.
|
||||
|
||||
|
||||
210
synapse/util/caches/multi_key_response_cache.py
Normal file
210
synapse/util/caches/multi_key_response_cache.py
Normal file
@@ -0,0 +1,210 @@
|
||||
# Copyright 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from typing import Any, Awaitable, Callable, Dict, Generic, Optional, Tuple, TypeVar
|
||||
|
||||
import attr
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
||||
from synapse.util import Clock
|
||||
from synapse.util.async_helpers import ObservableDeferred
|
||||
from synapse.util.caches import register_cache
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# the type of the key in the cache
|
||||
KV = TypeVar("KV")
|
||||
|
||||
# the type of the result from the operation
|
||||
RV = TypeVar("RV")
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class MultiKeyResponseCacheContext(Generic[KV]):
|
||||
"""Information about a missed MultiKeyResponseCache hit
|
||||
|
||||
This object can be passed into the callback for additional feedback
|
||||
"""
|
||||
|
||||
cache_keys: Tuple[KV, ...]
|
||||
"""The cache key that caused the cache miss
|
||||
|
||||
This should be considered read-only.
|
||||
|
||||
TODO: in attrs 20.1, make it frozen with an on_setattr.
|
||||
"""
|
||||
|
||||
should_cache: bool = True
|
||||
"""Whether the result should be cached once the request completes.
|
||||
|
||||
This can be modified by the callback if it decides its result should not be cached.
|
||||
"""
|
||||
|
||||
|
||||
class MultiKeyResponseCache(Generic[KV]):
|
||||
"""
|
||||
This caches a deferred response. Until the deferred completes it will be
|
||||
returned from the cache. This means that if the client retries the request
|
||||
while the response is still being computed, that original response will be
|
||||
used rather than trying to compute a new response.
|
||||
|
||||
Unlike the plain ResponseCache, this cache admits multiple keys to the
|
||||
deferred response.
|
||||
"""
|
||||
|
||||
def __init__(self, clock: Clock, name: str, timeout_ms: float = 0):
|
||||
# This is poorly-named: it includes both complete and incomplete results.
|
||||
# We keep complete results rather than switching to absolute values because
|
||||
# that makes it easier to cache Failure results.
|
||||
self.pending_result_cache: Dict[KV, ObservableDeferred] = {}
|
||||
|
||||
self.clock = clock
|
||||
self.timeout_sec = timeout_ms / 1000.0
|
||||
|
||||
self._name = name
|
||||
self._metrics = register_cache(
|
||||
"multikey_response_cache", name, self, resizable=False
|
||||
)
|
||||
|
||||
def size(self) -> int:
|
||||
return len(self.pending_result_cache)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return self.size()
|
||||
|
||||
def get(self, key: KV) -> Optional[defer.Deferred]:
|
||||
"""Look up the given key.
|
||||
|
||||
Returns a new Deferred (which also doesn't follow the synapse
|
||||
logcontext rules). You will probably want to make_deferred_yieldable the result.
|
||||
|
||||
If there is no entry for the key, returns None.
|
||||
|
||||
Args:
|
||||
key: key to get/set in the cache
|
||||
|
||||
Returns:
|
||||
None if there is no entry for this key; otherwise a deferred which
|
||||
resolves to the result.
|
||||
"""
|
||||
result = self.pending_result_cache.get(key)
|
||||
if result is not None:
|
||||
self._metrics.inc_hits()
|
||||
return result.observe()
|
||||
else:
|
||||
self._metrics.inc_misses()
|
||||
return None
|
||||
|
||||
def _set(
|
||||
self, context: MultiKeyResponseCacheContext[KV], deferred: defer.Deferred
|
||||
) -> defer.Deferred:
|
||||
"""Set the entry for the given key to the given deferred.
|
||||
|
||||
*deferred* should run its callbacks in the sentinel logcontext (ie,
|
||||
you should wrap normal synapse deferreds with
|
||||
synapse.logging.context.run_in_background).
|
||||
|
||||
Returns a new Deferred (which also doesn't follow the synapse logcontext rules).
|
||||
You will probably want to make_deferred_yieldable the result.
|
||||
|
||||
Args:
|
||||
context: Information about the cache miss
|
||||
deferred: The deferred which resolves to the result.
|
||||
|
||||
Returns:
|
||||
A new deferred which resolves to the actual result.
|
||||
"""
|
||||
result = ObservableDeferred(deferred, consumeErrors=True)
|
||||
keys = context.cache_keys
|
||||
for key in keys:
|
||||
if key not in self.pending_result_cache:
|
||||
# we only add the key if it's not already there, since we assume
|
||||
# that we won't overtake prior entries.
|
||||
self.pending_result_cache[key] = result
|
||||
|
||||
def on_complete(r):
|
||||
# if this cache has a non-zero timeout, and the callback has not cleared
|
||||
# the should_cache bit, we leave it in the cache for now and schedule
|
||||
# its removal later.
|
||||
if self.timeout_sec and context.should_cache:
|
||||
for key in keys:
|
||||
# TODO sketch, should do this in only one call_later.
|
||||
self.clock.call_later(
|
||||
self.timeout_sec, self.pending_result_cache.pop, key, None
|
||||
)
|
||||
else:
|
||||
for key in keys:
|
||||
# otherwise, remove the result immediately.
|
||||
self.pending_result_cache.pop(key, None)
|
||||
return r
|
||||
|
||||
# make sure we do this *after* adding the entry to pending_result_cache,
|
||||
# in case the result is already complete (in which case flipping the order would
|
||||
# leave us with a stuck entry in the cache).
|
||||
result.addBoth(on_complete)
|
||||
return result.observe()
|
||||
|
||||
def set_and_compute(
|
||||
self,
|
||||
keys: Tuple[KV, ...],
|
||||
callback: Callable[..., Awaitable[RV]],
|
||||
*args: Any,
|
||||
cache_context: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> defer.Deferred[RV]:
|
||||
"""Perform a *set* call, taking care of logcontexts
|
||||
|
||||
Makes a call to *callback(*args, **kwargs)*, which should
|
||||
follow the synapse logcontext rules, and adds the result to the cache.
|
||||
|
||||
Example usage:
|
||||
|
||||
async def handle_request(request):
|
||||
# etc
|
||||
return result
|
||||
|
||||
result = await response_cache.wrap(
|
||||
key,
|
||||
handle_request,
|
||||
request,
|
||||
)
|
||||
|
||||
Args:
|
||||
keys: keys to get/set in the cache
|
||||
|
||||
callback: function to call
|
||||
|
||||
*args: positional parameters to pass to the callback, if it is used
|
||||
|
||||
cache_context: if set, the callback will be given a `cache_context` kw arg,
|
||||
which will be a ResponseCacheContext object.
|
||||
|
||||
**kwargs: named parameters to pass to the callback, if it is used
|
||||
|
||||
Returns:
|
||||
The result of the callback (from the cache, or otherwise)
|
||||
"""
|
||||
|
||||
# TODO sketch logger.debug(
|
||||
# "[%s]: no cached result for [%s], calculating new one", self._name, key
|
||||
# )
|
||||
context = MultiKeyResponseCacheContext(cache_keys=keys)
|
||||
if cache_context:
|
||||
kwargs["cache_context"] = context
|
||||
d = run_in_background(callback, *args, **kwargs)
|
||||
result = self._set(context, d)
|
||||
|
||||
return make_deferred_yieldable(result)
|
||||
Reference in New Issue
Block a user