diff --git a/rust/src/lib.rs b/rust/src/lib.rs
index 6522148fa1..77b3bc8391 100644
--- a/rust/src/lib.rs
+++ b/rust/src/lib.rs
@@ -14,6 +14,7 @@ pub mod matrix_const;
 pub mod push;
 pub mod rendezvous;
 pub mod segmenter;
+pub mod tmp_cachetrace;
 
 lazy_static! {
     static ref LOGGING_HANDLE: ResetHandle = pyo3_log::init();
@@ -55,6 +56,7 @@ fn synapse_rust(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
     http_client::register_module(py, m)?;
     rendezvous::register_module(py, m)?;
     segmenter::register_module(py, m)?;
+    tmp_cachetrace::register_module(py, m)?;
 
     Ok(())
 }
diff --git a/rust/src/tmp_cachetrace.rs b/rust/src/tmp_cachetrace.rs
new file mode 100644
index 0000000000..987b128ebc
--- /dev/null
+++ b/rust/src/tmp_cachetrace.rs
@@ -0,0 +1,249 @@
+use std::{
+    collections::BTreeMap,
+    fs::File,
+    io::{BufWriter, Write},
+    sync::{
+        atomic::{AtomicBool, Ordering},
+        mpsc::{self, Receiver, SyncSender},
+        Arc, OnceLock,
+    },
+    time::{SystemTime, UNIX_EPOCH},
+};
+
+use anyhow::{bail, Context};
+use pyo3::{
+    pyclass, pymethods,
+    types::{PyAnyMethods, PyModule, PyModuleMethods},
+    Bound, Py, PyAny, PyResult, Python,
+};
+
+struct Row {
+    cache: u16,
+    time_ms: i64,
+    hash: u64,
+    op: Op,
+}
+
+enum Op {
+    New { key_size: u64, value_size: u64 },
+    Request,
+    Invalidate,
+    Evict,
+}
+
+#[pyclass]
+pub struct CacheTracer {
+    tx: SyncSender<Row>,
+    error_flag: Arc<AtomicBool>,
+    cache_names: BTreeMap<String, u16>,
+}
+
+#[pymethods]
+impl CacheTracer {
+    #[new]
+    #[pyo3(signature = ())]
+    pub fn py_new() -> Self {
+        let (tx, rx) = mpsc::sync_channel(2048);
+        let error_flag = Arc::new(AtomicBool::new(false));
+        std::thread::spawn({
+            let error_flag = Arc::clone(&error_flag);
+            move || {
+                if let Err(err) = receive_and_log_traces(rx, error_flag) {
+                    eprintln!("error in cache tracer: {err}");
+                }
+            }
+        });
+        CacheTracer {
+            tx,
+            error_flag,
+            cache_names: BTreeMap::new(),
+        }
+    }
+
+    #[pyo3(signature = (cache, key, value))]
+    pub fn on_new(
+        &mut self,
+        py: Python<'_>,
+        cache: &str,
+        key: Bound<'_, PyAny>,
+        value: Bound<'_, PyAny>,
+    ) {
+        let key_hash = key.hash().unwrap() as u64;
+        let key_size = get_size_of(py, &key);
+        let value_size = get_size_of(py, &value);
+
+        let cache_id = if let Some(cache_id) = self.cache_names.get(cache) {
+            *cache_id
+        } else {
+            let new = self.cache_names.len() as u16;
+            self.cache_names.insert(cache.to_owned(), new);
+            new
+        };
+
+        if let Err(_e) = self.tx.try_send(Row {
+            cache: cache_id,
+            op: Op::New {
+                key_size,
+                value_size,
+            },
+            hash: key_hash,
+            time_ms: SystemTime::now()
+                .duration_since(UNIX_EPOCH)
+                .unwrap()
+                .as_millis() as i64,
+        }) {
+            self.error_flag.store(true, Ordering::Relaxed);
+        }
+    }
+
+    #[pyo3(signature = (cache, key))]
+    pub fn on_request(&mut self, _py: Python<'_>, cache: &str, key: Bound<'_, PyAny>) {
+        let key_hash = key.hash().unwrap() as u64;
+
+        let cache_id = if let Some(cache_id) = self.cache_names.get(cache) {
+            *cache_id
+        } else {
+            let new = self.cache_names.len() as u16;
+            self.cache_names.insert(cache.to_owned(), new);
+            new
+        };
+
+        if let Err(_e) = self.tx.try_send(Row {
+            cache: cache_id,
+            op: Op::Request,
+            hash: key_hash,
+            time_ms: SystemTime::now()
+                .duration_since(UNIX_EPOCH)
+                .unwrap()
+                .as_millis() as i64,
+        }) {
+            self.error_flag.store(true, Ordering::Relaxed);
+        }
+    }
+
+    #[pyo3(signature = (cache, key))]
+    pub fn on_invalidate(&mut self, _py: Python<'_>, cache: &str, key: Bound<'_, PyAny>) {
+        let key_hash = key.hash().unwrap() as u64;
+
+        let cache_id = if let Some(cache_id) = self.cache_names.get(cache) {
+            *cache_id
+        } else {
+            let new = self.cache_names.len() as u16;
+            self.cache_names.insert(cache.to_owned(), new);
+            new
+        };
+
+        if let Err(_e) = self.tx.try_send(Row {
+            cache: cache_id,
+            op: Op::Invalidate,
+            hash: key_hash,
+            time_ms: SystemTime::now()
+                .duration_since(UNIX_EPOCH)
+                .unwrap()
+                .as_millis() as i64,
+        }) {
+            self.error_flag.store(true, Ordering::Relaxed);
+        }
+    }
+
+    #[pyo3(signature = (cache, key))]
+    pub fn on_evict(&mut self, _py: Python<'_>, cache: &str, key: Bound<'_, PyAny>) {
+        let key_hash = key.hash().unwrap() as u64;
+
+        let cache_id = if let Some(cache_id) = self.cache_names.get(cache) {
+            *cache_id
+        } else {
+            let new = self.cache_names.len() as u16;
+            self.cache_names.insert(cache.to_owned(), new);
+            new
+        };
+
+        if let Err(_e) = self.tx.try_send(Row {
+            cache: cache_id,
+            op: Op::Evict,
+            hash: key_hash,
+            time_ms: SystemTime::now()
+                .duration_since(UNIX_EPOCH)
+                .unwrap()
+                .as_millis() as i64,
+        }) {
+            self.error_flag.store(true, Ordering::Relaxed);
+        }
+    }
+}
+
+static GETSIZEOF: OnceLock<Py<PyAny>> = OnceLock::new();
+
+fn get_size_of(py: Python<'_>, obj: &Bound<'_, PyAny>) -> u64 {
+    let getsizeof = GETSIZEOF.get_or_init(|| {
+        let sys = PyModule::import(py, "synapse.util.caches.lrucache").unwrap();
+        let func = sys.getattr("_get_size_of").unwrap().unbind();
+        func
+    });
+
+    let size: u64 = getsizeof.call1(py, (obj,)).unwrap().extract(py).unwrap();
+    size
+}
+
+fn receive_and_log_traces(rx: Receiver<Row>, error_flag: Arc<AtomicBool>) -> anyhow::Result<()> {
+    let pid = std::process::id();
+    let f = File::create_new(format!("/tmp/syncachetrace-{pid}"))
+        .context("failed to start cache tracer")?;
+    let mut bw = BufWriter::new(f);
+
+    let mut last_time = SystemTime::now()
+        .duration_since(UNIX_EPOCH)
+        .unwrap()
+        .as_millis() as i64;
+
+    while let Ok(row) = rx.recv() {
+        if error_flag.load(Ordering::Relaxed) {
+            bw.write_all(b"DEADBEEF")?;
+            bw.flush()?;
+            bail!("error flagged");
+        }
+
+        let time_delta = row.time_ms.saturating_sub(last_time);
+        last_time = row.time_ms;
+        bw.write_all(&(time_delta as i16).to_be_bytes())?;
+        bw.write_all(&row.cache.to_be_bytes())?;
+        bw.write_all(&row.hash.to_be_bytes())?;
+
+        match row.op {
+            Op::New {
+                key_size,
+                value_size,
+            } => {
+                let key_size = key_size.min(u32::MAX as u64) as u32;
+                let value_size = value_size.min(u32::MAX as u64) as u32;
+                bw.write_all(b"N")?;
+                bw.write_all(&key_size.to_be_bytes())?;
+                bw.write_all(&value_size.to_be_bytes())?;
+            }
+            Op::Request => {
+                bw.write_all(b"R")?;
+            }
+            Op::Invalidate => {
+                bw.write_all(b"I")?;
+            }
+            Op::Evict => {
+                bw.write_all(b"E")?;
+            }
+        }
+    }
+
+    Ok(())
+}
+
+pub fn register_module(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
+    let child_module = PyModule::new(py, "tmp_cachetrace")?;
+    child_module.add_class::<CacheTracer>()?;
+
+    m.add_submodule(&child_module)?;
+
+    py.import("sys")?
+        .getattr("modules")?
+        .set_item("synapse.synapse_rust.tmp_cachetrace", child_module)?;
+
+    Ok(())
+}
diff --git a/synapse/synapse_rust/tmp_cachetrace.pyi b/synapse/synapse_rust/tmp_cachetrace.pyi
new file mode 100644
index 0000000000..f44f67e579
--- /dev/null
+++ b/synapse/synapse_rust/tmp_cachetrace.pyi
@@ -0,0 +1,6 @@
+class CacheTracer:
+    def __init__(self) -> None: ...
+    def on_new(self, cache: str, key: object, value: object) -> None: ...
+    def on_request(self, cache: str, key: object) -> None: ...
+    def on_invalidate(self, cache: str, key: object) -> None: ...
+    def on_evict(self, cache: str, key: object) -> None: ...
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 927162700a..0f8d3b632d 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -21,6 +21,7 @@
 
 import logging
 import math
+import os
 import threading
 import weakref
 from enum import Enum
@@ -64,6 +65,7 @@ from synapse.util.linked_list import ListNode
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
+    from synapse.synapse_rust.tmp_cachetrace import CacheTracer
 
 logger = logging.getLogger(__name__)
 
@@ -102,6 +104,24 @@ VT = TypeVar("VT")
 # a general type var, distinct from either KT or VT
 T = TypeVar("T")
 
+_tracer: Optional["CacheTracer"] = None
+_should_trace = "SYNTRACE" in os.environ
+
+
+def get_tracer() -> Optional["CacheTracer"]:
+    from synapse.synapse_rust.tmp_cachetrace import CacheTracer
+
+    global _tracer
+
+    if _tracer:
+        return _tracer
+
+    if _should_trace:
+        _tracer = CacheTracer()
+        return _tracer
+
+    return None
+
 
 class _TimedListNode(ListNode[T]):
     """A `ListNode` that tracks last access time."""
@@ -493,6 +513,7 @@ class LruCache(Generic[KT, VT]):
 
             Note: The new key does not have to be unique.
         """
+
        # Default `clock` to something sensible. Note that we rename it to
        # `real_clock` so that mypy doesn't think its still `Optional`.
        if clock is None:
@@ -504,6 +525,11 @@ class LruCache(Generic[KT, VT]):
         self.cache = cache  # Used for introspection.
         self.apply_cache_factor_from_config = apply_cache_factor_from_config
 
+        if not isinstance(cache, TreeCache):
+            self._tracer = get_tracer()
+        else:
+            self._tracer = None
+
         # Save the original max size, and apply the default size factor.
         self._original_max_size = max_size
         # We previously didn't apply the cache factor here, and as such some caches were
@@ -542,6 +568,8 @@ class LruCache(Generic[KT, VT]):
 
         extra_index: Dict[KT, Set[KT]] = {}
 
+        self._cache_name = cache_name or str(id(self))
+
         def evict() -> None:
             while cache_len() > self.max_size:
                 # Get the last node in the list (i.e. the oldest node).
@@ -559,6 +587,10 @@ class LruCache(Generic[KT, VT]):
 
                 evicted_len = delete_node(node)
                 cache.pop(node.key, None)
+
+                if self._tracer:
+                    self._tracer.on_evict(self._cache_name, node.key)
+
                 if metrics:
                     metrics.inc_evictions(EvictionReason.size, evicted_len)
 
@@ -675,6 +707,10 @@ class LruCache(Generic[KT, VT]):
                 to False if this fetch should *not* prevent a node from
                 being expired.
             """
+
+            if self._tracer:
+                self._tracer.on_request(self._cache_name, key)
+
             node = cache.get(key, None)
             if node is not None:
                 if update_last_access:
@@ -750,6 +786,10 @@ class LruCache(Generic[KT, VT]):
             key: KT, value: VT, callbacks: Collection[Callable[[], None]] = ()
         ) -> None:
             node = cache.get(key, None)
+
+            if self._tracer:
+                self._tracer.on_new(self._cache_name, key, value)
+
             if node is not None:
                 # We sometimes store large objects, e.g. dicts, which cause
                 # the inequality check to take a long time. So let's only do
@@ -792,6 +832,8 @@ class LruCache(Generic[KT, VT]):
 
         @synchronized
         def cache_pop(key: KT, default: Optional[T] = None) -> Union[None, T, VT]:
+            if self._tracer:
+                self._tracer.on_invalidate(self._cache_name, key)
             node = cache.get(key, None)
             if node:
                 evicted_len = delete_node(node)
@@ -813,6 +855,8 @@ class LruCache(Generic[KT, VT]):
             may be of lower cardinality than the TreeCache - in which case the whole
             subtree is deleted.
             """
+            if self._tracer:
+                self._tracer.on_invalidate(self._cache_name, key)
             popped = cache.pop(key, None)
             if popped is None:
                 return
@@ -824,6 +868,8 @@ class LruCache(Generic[KT, VT]):
         @synchronized
         def cache_clear() -> None:
             for node in cache.values():
+                if self._tracer:
+                    self._tracer.on_invalidate(self._cache_name, node.key)
                 node.run_and_clear_callbacks()
                 node.drop_from_lists()
 
@@ -841,6 +887,8 @@ class LruCache(Generic[KT, VT]):
 
         @synchronized
         def cache_contains(key: KT) -> bool:
+            if self._tracer:
+                self._tracer.on_request(self._cache_name, key)
             return key in cache
 
         @synchronized
@@ -857,6 +905,8 @@ class LruCache(Generic[KT, VT]):
                 return
 
             for key in keys:
+                if self._tracer:
+                    self._tracer.on_invalidate(self._cache_name, key)
                 node = cache.pop(key, None)
                 if not node:
                     continue