From e19dfa15a4fdf29f6727bc15f353eab69df05aff Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Fri, 20 Oct 2023 16:17:39 +0100 Subject: [PATCH] CHECKPOINT --- Cargo.lock | 16 + rust/Cargo.toml | 1 + rust/src/db.rs | 596 ++++++++++++++++++ rust/src/lib.rs | 2 + stubs/synapse/synapse_rust/db.py | 20 + .../databases/main/event_federation.py | 295 +-------- 6 files changed, 637 insertions(+), 293 deletions(-) create mode 100644 rust/src/db.rs create mode 100644 stubs/synapse/synapse_rust/db.py diff --git a/Cargo.lock b/Cargo.lock index 5acf47cea8..66b34f8402 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -80,6 +80,12 @@ dependencies = [ "subtle", ] +[[package]] +name = "either" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" + [[package]] name = "generic-array" version = "0.14.6" @@ -102,6 +108,15 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "adab1eaa3408fb7f0c777a73e7465fd5656136fc93b670eb6df3c88c2c1344e3" +[[package]] +name = "itertools" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1c173a5686ce8bfa551b3563d0c2170bf24ca44da99c7ca4bfdab5418c3fe57" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.4" @@ -402,6 +417,7 @@ dependencies = [ "anyhow", "blake2", "hex", + "itertools", "lazy_static", "log", "pyo3", diff --git a/rust/Cargo.toml b/rust/Cargo.toml index f62da35a6f..32ba32d32a 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -23,6 +23,7 @@ name = "synapse.synapse_rust" [dependencies] anyhow = "1.0.63" +itertools = "0.11.0" lazy_static = "1.4.0" log = "0.4.17" pyo3 = { version = "0.19.2", features = [ diff --git a/rust/src/db.rs b/rust/src/db.rs new file mode 100644 index 0000000000..f0d1362864 --- /dev/null +++ b/rust/src/db.rs @@ -0,0 +1,596 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::{ + cmp::max, + collections::{BTreeSet, HashMap, HashSet}, +}; + +use itertools::Itertools; +use log::info; +use pyo3::{ + intern, types::PyModule, wrap_pyfunction, FromPyObject, IntoPy, PyAny, PyObject, PyResult, + Python, +}; + +use pyo3::prelude::*; +pub trait ValidDatabaseFieldType {} +pub trait ValidDatabaseReturnType {} +impl ValidDatabaseFieldType for String {} +impl ValidDatabaseFieldType for usize {} +impl ValidDatabaseFieldType for Option {} +impl ValidDatabaseReturnType for (T0,) {} +impl ValidDatabaseReturnType for (T0, T1) {} +impl + ValidDatabaseReturnType for (T0, T1, T2) +{ +} +impl< + T0: ValidDatabaseFieldType, + T1: ValidDatabaseFieldType, + T2: ValidDatabaseFieldType, + T3: ValidDatabaseFieldType, + > ValidDatabaseReturnType for (T0, T1, T2, T3) +{ +} + +/// Helper struct that accepts a boolean from the database, but also accepts an integer for SQLite compatibility. +pub struct DbBool(bool); +impl<'py> FromPyObject<'py> for DbBool { + fn extract(ob: &'py PyAny) -> PyResult { + #[derive(FromPyObject)] + enum EitherIntOrBool { + Int(i64), + Bool(bool), + } + + Ok(DbBool(match EitherIntOrBool::extract(ob)? { + EitherIntOrBool::Int(int) => int != 0, + EitherIntOrBool::Bool(bool) => bool, + })) + } +} +impl ValidDatabaseFieldType for DbBool {} + +/// Called when registering modules with python. 
+pub fn register_module(py: Python<'_>, m: &PyModule) -> PyResult<()> { + let child_module = PyModule::new(py, "db")?; + + child_module.add_function(wrap_pyfunction!( + _get_auth_chain_difference_using_cover_index_txn, + m + )?)?; + + m.add_submodule(child_module)?; + + // We need to manually add the module to sys.modules to make `from + // synapse.synapse_rust import db` work. + py.import("sys")? + .getattr("modules")? + .set_item("synapse.synapse_rust.db", child_module)?; + + Ok(()) +} + +/// Wrapper for a `LoggingTransaction` from the Python side of Synapse. +pub struct LoggingTransactionWrapper<'py> { + /// The underlying LoggingTransaction + raw: &'py PyAny, + + database_engine: DatabaseEngine, +} + +impl<'source> FromPyObject<'source> for LoggingTransactionWrapper<'source> { + fn extract(ob: &'source PyAny) -> PyResult { + let database_engine = match ob + .getattr("database_engine")? + .get_type() + .name() + .expect("DB engine should have a type name") + { + "PostgresEngine" => DatabaseEngine::Postgres, + "Sqlite3Engine" => DatabaseEngine::Sqlite, + other => panic!("Unknown engine {other:?}"), + }; + Ok(Self { + raw: ob, + database_engine, + }) + } +} + +impl<'py> LoggingTransactionWrapper<'py> { + pub fn execute>(&mut self, sql: &str, args: T) -> PyResult<()> { + let execute_fn = self.raw.getattr(intern!(self.raw.py(), "execute"))?; + execute_fn.call1((sql, args))?; + Ok(()) + } + + pub fn execute_values, R: FromPyObject<'py> + ValidDatabaseReturnType>( + &mut self, + sql: &str, + args: T, + ) -> PyResult> { + let execute_fn = self.raw.getattr(intern!(self.raw.py(), "execute_values"))?; + Ok(execute_fn.call1((sql, args))?.extract()?) + } + + pub fn fetchall + ValidDatabaseReturnType>( + &mut self, + ) -> anyhow::Result> { + let fetch_fn = self.raw.getattr(intern!(self.raw.py(), "fetchall"))?; + Ok(fetch_fn.call0()?.extract()?) 
+ } +} + +#[derive(Copy, Clone, Debug)] +pub enum DatabaseEngine { + Sqlite, + Postgres, +} + +impl DatabaseEngine { + //[inline] + pub fn supports_using_any_list(&self) -> bool { + match self { + DatabaseEngine::Sqlite => false, + DatabaseEngine::Postgres => true, + } + } +} + +/// Reimplementation of the equivalent in Synapse's Python code. +pub fn make_in_list_sql_clause>( + database_engine: DatabaseEngine, + column_name: &str, + values: impl Iterator, + python: Python<'_>, +) -> (String, Vec) { + let list_of_values: Vec = values.map(|val| val.into_py(python)).collect(); + if database_engine.supports_using_any_list() { + let list_of_values_py: PyObject = list_of_values.into_py(python); + (format!("{column_name} = ANY(?)"), vec![list_of_values_py]) + } else { + let mut s = String::new(); + s.push_str(column_name); + s.push_str(" IN ("); + if list_of_values.len() > 1 { + for _ in 1..list_of_values.len() { + s.push_str("?,"); + } + } + if list_of_values.len() > 0 { + s.push('?'); + } + s.push(')'); + (s, list_of_values) + } +} + +/// Calculates the auth chain difference using the chain index. +/// +/// See docs/auth_chain_difference_algorithm.md for details +#[pyfunction] +fn _get_auth_chain_difference_using_cover_index_txn( + mut txn: LoggingTransactionWrapper, + room_id: &str, + mut state_sets: Vec>, +) -> PyResult> { + // First we look up the chain ID/sequence numbers for all the events, and + // work out the chain/sequence numbers reachable from each state set. + + let initial_events: HashSet = state_sets + .iter() + .flat_map(|set| set.iter()) + .cloned() + .collect(); + + // Map from event_id -> (chain ID, seq no) + let mut chain_info: HashMap = HashMap::new(); + + // Map from chain ID -> seq no -> event Id + let mut chain_to_event: HashMap> = HashMap::new(); + + // All the chains that we've found that are reachable from the state + // sets. 
+ let mut seen_chains: HashSet = HashSet::new(); + + // Fetch the chain cover index for the initial set of events we're + // considering. + fn fetch_chain_info<'a>( + txn: &mut LoggingTransactionWrapper, + events_to_fetch: impl IntoIterator, + chain_info: &mut HashMap, + seen_chains: &mut HashSet, + chain_to_event: &mut HashMap>, + ) -> anyhow::Result<()> { + let sql = r#" + SELECT event_id, chain_id, sequence_number + FROM event_auth_chains + WHERE + "#; + for batch in &events_to_fetch.into_iter().chunks(1000) { + let rows = Python::with_gil(|py| -> anyhow::Result<_> { + let (clause, args) = + make_in_list_sql_clause(txn.database_engine, "event_id", batch, py); + txn.execute(&format!("{sql}{clause}"), args)?; + txn.fetchall::<(String, usize, usize)>() + })?; + + for (event_id, chain_id, sequence_number) in rows { + // TODO would be nice to not clone the event IDs + chain_info.insert(event_id.clone(), (chain_id, sequence_number)); + seen_chains.insert(chain_id); + chain_to_event + .entry(chain_id) + .or_default() + .insert(sequence_number, event_id); + } + } + Ok(()) + } + + fetch_chain_info( + &mut txn, + initial_events.iter(), + &mut chain_info, + &mut seen_chains, + &mut chain_to_event, + )?; + + // Check that we actually have a chain ID for all the events. + let events_missing_chain_info: HashSet<&String> = initial_events + .iter() + .filter(|elem| !chain_info.contains_key(elem as &str)) + .collect(); + + // The result set to return, i.e. the auth chain difference. + let mut result: HashSet = HashSet::new(); + + if !events_missing_chain_info.is_empty() { + // For some reason we have events we haven't calculated the chain + // index for, so we need to handle those separately. This should only + // happen for older rooms where the server doesn't have all the auth + // events. + match _fixup_auth_chain_difference_sets( + &mut txn, + room_id, + &mut state_sets, + &events_missing_chain_info, + &chain_info, + )? 
{ + Some(fixup) => { + result = fixup; + } + None => { + // No chain cover index! + let exception = Python::with_gil(|py| -> anyhow::Result { + let module = py.import("synapse.storage.databases.main.event_federation")?; + let exception_type = module.getattr("_NoChainCoverIndex")?; + let exception = exception_type.call1((room_id,))?; + Ok(PyErr::from_value(exception)) + })?; + return Err(exception); + } + } + + // We now need to refetch any events that we have added to the state + // sets. + let new_events_to_fetch = state_sets + .iter() + .flat_map(|state_set| state_set.iter()) + .filter(|event_id| !initial_events.contains(*event_id)); + + fetch_chain_info( + &mut txn, + new_events_to_fetch, + &mut chain_info, + &mut seen_chains, + &mut chain_to_event, + )?; + } + + // Corresponds to `state_sets`, except as a map from chain ID to max + // sequence number reachable from the state set. + let mut set_to_chain: Vec> = Vec::new(); + for state_set in state_sets { + let mut chains: HashMap = HashMap::new(); + + for state_id in state_set { + let (chain_id, seq_no) = chain_info[&state_id as &str]; + + let chain = chains.entry(chain_id).or_insert(0); + *chain = (*chain).max(seq_no); + } + + set_to_chain.push(chains); + } + + // Now we look up all links for the chains we have, adding chains to + // set_to_chain that are reachable from each set. 
+ let sql = r#" + SELECT + origin_chain_id, origin_sequence_number, + target_chain_id, target_sequence_number + FROM event_auth_chain_links + WHERE + "#; + + // (We need to take a copy of `seen_chains` as we want to mutate it in + // the loop) + // TODO might be wise to avoid this clone + for batch2 in &seen_chains.clone().into_iter().chunks(1000) { + Python::with_gil(|py| -> anyhow::Result<()> { + let (clause, args) = + make_in_list_sql_clause(txn.database_engine, "origin_chain_id", batch2, py); + txn.execute(&format!("{sql}{clause}"), args)?; + Ok(()) + })?; + + for (origin_chain_id, origin_sequence_number, target_chain_id, target_sequence_number) in + txn.fetchall::<(usize, usize, usize, usize)>()? + { + for chains in &mut set_to_chain { + // chains are only reachable if the origin sequence number of + // the link is less than the max sequence number in the + // origin chain. + if origin_sequence_number <= chains.get(&origin_chain_id).copied().unwrap_or(0) { + chains.insert( + target_chain_id, + max( + target_sequence_number, + chains.get(&target_chain_id).copied().unwrap_or(0), + ), + ); + } + } + + seen_chains.insert(target_chain_id); + } + } + + // Now for each chain we figure out the maximum sequence number reachable + // from *any* state set and the minimum sequence number reachable from + // *all* state sets. Events in that range are in the auth chain + // difference. + + // Mapping from chain ID to the range of sequence numbers that should be + // pulled from the database. + let mut chain_to_gap: HashMap = HashMap::new(); + + for chain_id in seen_chains { + let (min_seq_no, max_seq_no) = set_to_chain + .iter() + .map(|chains| chains.get(&chain_id).copied().unwrap_or(0)) + .minmax() + .into_option() + .expect("this should not be empty"); + + if min_seq_no < max_seq_no { + // We have a non empty gap, try and fill it from the events that + // we have, otherwise add them to the list of gaps to pull out + // from the DB. 
+ for seq_no in (min_seq_no + 1)..(max_seq_no + 1) { + let event_id = chain_to_event.get(&chain_id).and_then(|x| x.get(&seq_no)); + if let Some(event_id) = event_id { + // TODO we might like to see whether we can avoid this clone + result.insert((*event_id).to_owned()); + } else { + chain_to_gap.insert(chain_id, (min_seq_no, max_seq_no)); + break; + } + } + } + } + + if chain_to_gap.is_empty() { + // If there are no gaps to fetch, we're done! + return Ok(result); + } + + match txn.database_engine { + DatabaseEngine::Postgres => { + // We can use `execute_values` to efficiently fetch the gaps when + // using postgres. + let sql = r#" + SELECT event_id + FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, min_seq, max_seq) + WHERE + c.chain_id = l.chain_id + AND min_seq < sequence_number AND sequence_number <= max_seq + "#; + + // let args = [ + // (chain_id, min_no, max_no) + // for chain_id, (min_no, max_no) in chain_to_gap.items() + // ]; + let args: Vec<(usize, usize, usize)> = chain_to_gap + .iter() + .map(|(&chain_id, &(min_no, max_no))| (chain_id, min_no, max_no)) + .collect(); + + let rows: Vec<(String,)> = txn.execute_values(sql, args)?; + result.extend(rows.into_iter().map(|(r,)| r)); + } + DatabaseEngine::Sqlite => { + // For SQLite we just fall back to doing a noddy for loop. + let sql = r#" + SELECT event_id FROM event_auth_chains + WHERE chain_id = ? AND ? < sequence_number AND sequence_number <= ? + "#; + for (chain_id, (min_no, max_no)) in chain_to_gap { + txn.execute(sql, (chain_id, min_no, max_no))?; + result.extend(txn.fetchall::<(String,)>()?.into_iter().map(|(r,)| r)); + } + } + } + + Ok(result) +} + +/// Helper for `_get_auth_chain_difference_using_cover_index_txn` to +/// handle the case where we haven't calculated the chain cover index for +/// all events. +/// +/// This modifies `state_sets` so that they only include events that have a +/// chain cover index, and returns a set of event IDs that are part of the +/// auth difference. 
+/// +/// Returns None if there is no usable chain cover index. +fn _fixup_auth_chain_difference_sets( + txn: &mut LoggingTransactionWrapper, + room_id: &str, + state_sets: &mut Vec>, + events_missing_chain_info: &HashSet<&String>, + events_that_have_chain_index: &HashMap, +) -> anyhow::Result>> { + // This works similarly to the handling of unpersisted events in + // `synapse.state.v2_get_auth_chain_difference`. We use the observation + // that if you can split the set of events into two classes X and Y, + // where no events in Y have events in X in their auth chain, then we can + // calculate the auth difference by considering X and Y separately. + // + // We do this in three steps: + // 1. Compute the set of events without chain cover index belonging to + // the auth difference. + // 2. Replace the un-indexed events in the state_sets with their auth + // events, recursively, until the state_sets contain only indexed + // events. We can then calculate the auth difference of those state + // sets using the chain cover index. + // 3. Add the results of 1 and 2 together. + + // By construction we know that all events that we haven't persisted the + // chain cover index for are contained in + // `event_auth_chain_to_calculate`, so we pull out the events from those + // rather than doing recursive queries to walk the auth chain. + // + // We pull out those events with their auth events, which gives us enough + // information to construct the auth chain of an event up to auth events + // that have the chain cover index. + let sql = r#" + SELECT tc.event_id, ea.auth_id, eac.chain_id IS NOT NULL + FROM event_auth_chain_to_calculate AS tc + LEFT JOIN event_auth AS ea USING (event_id) + LEFT JOIN event_auth_chains AS eac ON (ea.auth_id = eac.event_id) + WHERE tc.room_id = ? 
+ "#; + txn.execute(sql, (room_id,))?; + let mut event_to_auth_ids: HashMap> = HashMap::new(); + let mut events_that_have_chain_index: HashSet = + events_that_have_chain_index.keys().cloned().collect(); + for (event_id, auth_id, DbBool(auth_id_has_chain)) in + txn.fetchall::<(String, Option, DbBool)>()? + { + let s = event_to_auth_ids.entry(event_id).or_default(); + if let Some(auth_id) = auth_id { + s.insert(auth_id.clone()); + // (this is intended to be a BOOL but we must accept ints too?) + if auth_id_has_chain { + events_that_have_chain_index.insert(auth_id); + } + } + } + + let events_still_without_chain_ids: HashSet<&str> = events_missing_chain_info + .iter() + .filter(|event_id| !event_to_auth_ids.contains_key(event_id.as_str())) + .map(|s| s.as_str()) + .collect(); + if !events_still_without_chain_ids.is_empty() { + // Uh oh, we somehow haven't correctly done the chain cover index, + // bail and fall back to the old method. + info!( + "Unexpectedly found that events don't have chain IDs in room {}: {:?}", + room_id, events_still_without_chain_ids, + ); + + return Ok(None); + } + + // Create a map from event IDs we care about to their partial auth chain. + let mut event_id_to_partial_auth_chain: HashMap<&str, HashSet> = HashMap::new(); + for (event_id, auth_ids) in &event_to_auth_ids { + if !state_sets + .iter() + .any(|state_set| state_set.contains(event_id.as_str())) + { + continue; + } + + // TODO can we avoid the clone? + let mut processing: BTreeSet = auth_ids.iter().cloned().collect(); + let mut to_add: HashSet = HashSet::new(); + while let Some(auth_id) = processing.pop_last() { + to_add.insert(auth_id.clone()); + + let sub_auth_ids = event_to_auth_ids.get(&auth_id); + if let Some(sub_auth_ids) = sub_auth_ids { + processing.extend(sub_auth_ids - &to_add); + } + } + + event_id_to_partial_auth_chain.insert(event_id, to_add); + } + + // Now we do two things: + // 1. Update the state sets to only include indexed events; and + // 2. 
Create a new list containing the auth chains of the un-indexed + // events + let mut unindexed_state_sets: Vec> = Vec::new(); + for state_set in state_sets { + let mut unindexed_state_set: HashSet<&str> = HashSet::new(); + for (event_id, auth_chain) in &event_id_to_partial_auth_chain { + if !state_set.contains(*event_id) { + continue; + } + + unindexed_state_set.insert(event_id); + + state_set.remove(*event_id); + // TODO is there a more efficient way to do this + for elem in auth_chain { + state_set.remove(elem.as_str()); + } + for auth_id in auth_chain { + if events_that_have_chain_index.contains(auth_id.as_str()) { + // TODO undesirable clone + state_set.insert(auth_id.clone()); + } else { + unindexed_state_set.insert(auth_id); + } + } + } + + unindexed_state_sets.push(unindexed_state_set); + } + + // Calculate and return the auth difference of the un-indexed events. + // We want to return the union - intersection here. + let mut set_presence_count: HashMap<&str, usize> = HashMap::new(); + let intersection_count = unindexed_state_sets.len(); + for elem in unindexed_state_sets + .into_iter() + .flat_map(|set| set.into_iter()) + { + *set_presence_count.entry(elem).or_insert(0) += 1; + } + + let union_minus_intersection: HashSet = set_presence_count + .into_iter() + .filter(|(_, count)| *count < intersection_count) + .map(|(k, _)| k.to_owned()) + .collect(); + + Ok(Some(union_minus_intersection)) +} diff --git a/rust/src/lib.rs b/rust/src/lib.rs index c44c09bda7..83d9203b1f 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -3,6 +3,7 @@ use pyo3::prelude::*; use pyo3_log::ResetHandle; pub mod acl; +pub mod db; pub mod push; lazy_static! 
{ @@ -41,6 +42,7 @@ fn synapse_rust(py: Python<'_>, m: &PyModule) -> PyResult<()> { acl::register_module(py, m)?; push::register_module(py, m)?; + db::register_module(py, m)?; Ok(()) } diff --git a/stubs/synapse/synapse_rust/db.py b/stubs/synapse/synapse_rust/db.py new file mode 100644 index 0000000000..a103dd3ea9 --- /dev/null +++ b/stubs/synapse/synapse_rust/db.py @@ -0,0 +1,20 @@ +# Copyright 2023 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List, Set + +from synapse.storage.database import LoggingTransaction + + +def _get_auth_chain_difference_using_cover_index_txn(txn: LoggingTransaction, room_id: str, state_sets: List[Set[str]]) -> Set[str]: + ... 
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 4f80ce75cc..1db88cb81a 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -31,6 +31,7 @@ from typing import ( import attr from prometheus_client import Counter, Gauge +from synapse import synapse_rust from synapse.api.constants import MAX_DEPTH from synapse.api.errors import StoreError @@ -433,299 +434,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas def _get_auth_chain_difference_using_cover_index_txn( self, txn: LoggingTransaction, room_id: str, state_sets: List[Set[str]] ) -> Set[str]: - """Calculates the auth chain difference using the chain index. - - See docs/auth_chain_difference_algorithm.md for details - """ - - # First we look up the chain ID/sequence numbers for all the events, and - # work out the chain/sequence numbers reachable from each state set. - - initial_events = set(state_sets[0]).union(*state_sets[1:]) - - # Map from event_id -> (chain ID, seq no) - chain_info: Dict[str, Tuple[int, int]] = {} - - # Map from chain ID -> seq no -> event Id - chain_to_event: Dict[int, Dict[int, str]] = {} - - # All the chains that we've found that are reachable from the state - # sets. - seen_chains: Set[int] = set() - - # Fetch the chain cover index for the initial set of events we're - # considering. 
- def fetch_chain_info(events_to_fetch: Collection[str]) -> None: - sql = """ - SELECT event_id, chain_id, sequence_number - FROM event_auth_chains - WHERE %s - """ - for batch in batch_iter(events_to_fetch, 1000): - clause, args = make_in_list_sql_clause( - txn.database_engine, "event_id", batch - ) - txn.execute(sql % (clause,), args) - - for event_id, chain_id, sequence_number in txn: - chain_info[event_id] = (chain_id, sequence_number) - seen_chains.add(chain_id) - chain_to_event.setdefault(chain_id, {})[sequence_number] = event_id - - fetch_chain_info(initial_events) - - # Check that we actually have a chain ID for all the events. - events_missing_chain_info = initial_events.difference(chain_info) - - # The result set to return, i.e. the auth chain difference. - result: Set[str] = set() - - if events_missing_chain_info: - # For some reason we have events we haven't calculated the chain - # index for, so we need to handle those separately. This should only - # happen for older rooms where the server doesn't have all the auth - # events. - result = self._fixup_auth_chain_difference_sets( - txn, - room_id, - state_sets=state_sets, - events_missing_chain_info=events_missing_chain_info, - events_that_have_chain_index=chain_info, - ) - - # We now need to refetch any events that we have added to the state - # sets. - new_events_to_fetch = { - event_id - for state_set in state_sets - for event_id in state_set - if event_id not in initial_events - } - - fetch_chain_info(new_events_to_fetch) - - # Corresponds to `state_sets`, except as a map from chain ID to max - # sequence number reachable from the state set. 
- set_to_chain: List[Dict[int, int]] = [] - for state_set in state_sets: - chains: Dict[int, int] = {} - set_to_chain.append(chains) - - for state_id in state_set: - chain_id, seq_no = chain_info[state_id] - - chains[chain_id] = max(seq_no, chains.get(chain_id, 0)) - - # Now we look up all links for the chains we have, adding chains to - # set_to_chain that are reachable from each set. - sql = """ - SELECT - origin_chain_id, origin_sequence_number, - target_chain_id, target_sequence_number - FROM event_auth_chain_links - WHERE %s - """ - - # (We need to take a copy of `seen_chains` as we want to mutate it in - # the loop) - for batch2 in batch_iter(set(seen_chains), 1000): - clause, args = make_in_list_sql_clause( - txn.database_engine, "origin_chain_id", batch2 - ) - txn.execute(sql % (clause,), args) - - for ( - origin_chain_id, - origin_sequence_number, - target_chain_id, - target_sequence_number, - ) in txn: - for chains in set_to_chain: - # chains are only reachable if the origin sequence number of - # the link is less than the max sequence number in the - # origin chain. - if origin_sequence_number <= chains.get(origin_chain_id, 0): - chains[target_chain_id] = max( - target_sequence_number, - chains.get(target_chain_id, 0), - ) - - seen_chains.add(target_chain_id) - - # Now for each chain we figure out the maximum sequence number reachable - # from *any* state set and the minimum sequence number reachable from - # *all* state sets. Events in that range are in the auth chain - # difference. - - # Mapping from chain ID to the range of sequence numbers that should be - # pulled from the database. 
- chain_to_gap: Dict[int, Tuple[int, int]] = {} - - for chain_id in seen_chains: - min_seq_no = min(chains.get(chain_id, 0) for chains in set_to_chain) - max_seq_no = max(chains.get(chain_id, 0) for chains in set_to_chain) - - if min_seq_no < max_seq_no: - # We have a non empty gap, try and fill it from the events that - # we have, otherwise add them to the list of gaps to pull out - # from the DB. - for seq_no in range(min_seq_no + 1, max_seq_no + 1): - event_id = chain_to_event.get(chain_id, {}).get(seq_no) - if event_id: - result.add(event_id) - else: - chain_to_gap[chain_id] = (min_seq_no, max_seq_no) - break - - if not chain_to_gap: - # If there are no gaps to fetch, we're done! - return result - - if isinstance(self.database_engine, PostgresEngine): - # We can use `execute_values` to efficiently fetch the gaps when - # using postgres. - sql = """ - SELECT event_id - FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, min_seq, max_seq) - WHERE - c.chain_id = l.chain_id - AND min_seq < sequence_number AND sequence_number <= max_seq - """ - - args = [ - (chain_id, min_no, max_no) - for chain_id, (min_no, max_no) in chain_to_gap.items() - ] - - rows = txn.execute_values(sql, args) - result.update(r for r, in rows) - else: - # For SQLite we just fall back to doing a noddy for loop. - sql = """ - SELECT event_id FROM event_auth_chains - WHERE chain_id = ? AND ? < sequence_number AND sequence_number <= ? - """ - for chain_id, (min_no, max_no) in chain_to_gap.items(): - txn.execute(sql, (chain_id, min_no, max_no)) - result.update(r for r, in txn) - - return result - - def _fixup_auth_chain_difference_sets( - self, - txn: LoggingTransaction, - room_id: str, - state_sets: List[Set[str]], - events_missing_chain_info: Set[str], - events_that_have_chain_index: Collection[str], - ) -> Set[str]: - """Helper for `_get_auth_chain_difference_using_cover_index_txn` to - handle the case where we haven't calculated the chain cover index for - all events. 
- - This modifies `state_sets` so that they only include events that have a - chain cover index, and returns a set of event IDs that are part of the - auth difference. - """ - - # This works similarly to the handling of unpersisted events in - # `synapse.state.v2_get_auth_chain_difference`. We uses the observation - # that if you can split the set of events into two classes X and Y, - # where no events in Y have events in X in their auth chain, then we can - # calculate the auth difference by considering X and Y separately. - # - # We do this in three steps: - # 1. Compute the set of events without chain cover index belonging to - # the auth difference. - # 2. Replacing the un-indexed events in the state_sets with their auth - # events, recursively, until the state_sets contain only indexed - # events. We can then calculate the auth difference of those state - # sets using the chain cover index. - # 3. Add the results of 1 and 2 together. - - # By construction we know that all events that we haven't persisted the - # chain cover index for are contained in - # `event_auth_chain_to_calculate`, so we pull out the events from those - # rather than doing recursive queries to walk the auth chain. - # - # We pull out those events with their auth events, which gives us enough - # information to construct the auth chain of an event up to auth events - # that have the chain cover index. - sql = """ - SELECT tc.event_id, ea.auth_id, eac.chain_id IS NOT NULL - FROM event_auth_chain_to_calculate AS tc - LEFT JOIN event_auth AS ea USING (event_id) - LEFT JOIN event_auth_chains AS eac ON (ea.auth_id = eac.event_id) - WHERE tc.room_id = ? 
- """ - txn.execute(sql, (room_id,)) - event_to_auth_ids: Dict[str, Set[str]] = {} - events_that_have_chain_index = set(events_that_have_chain_index) - for event_id, auth_id, auth_id_has_chain in txn: - s = event_to_auth_ids.setdefault(event_id, set()) - if auth_id is not None: - s.add(auth_id) - if auth_id_has_chain: - events_that_have_chain_index.add(auth_id) - - if events_missing_chain_info - event_to_auth_ids.keys(): - # Uh oh, we somehow haven't correctly done the chain cover index, - # bail and fall back to the old method. - logger.info( - "Unexpectedly found that events don't have chain IDs in room %s: %s", - room_id, - events_missing_chain_info - event_to_auth_ids.keys(), - ) - raise _NoChainCoverIndex(room_id) - - # Create a map from event IDs we care about to their partial auth chain. - event_id_to_partial_auth_chain: Dict[str, Set[str]] = {} - for event_id, auth_ids in event_to_auth_ids.items(): - if not any(event_id in state_set for state_set in state_sets): - continue - - processing = set(auth_ids) - to_add = set() - while processing: - auth_id = processing.pop() - to_add.add(auth_id) - - sub_auth_ids = event_to_auth_ids.get(auth_id) - if sub_auth_ids is None: - continue - - processing.update(sub_auth_ids - to_add) - - event_id_to_partial_auth_chain[event_id] = to_add - - # Now we do two things: - # 1. Update the state sets to only include indexed events; and - # 2. 
Create a new list containing the auth chains of the un-indexed - # events - unindexed_state_sets: List[Set[str]] = [] - for state_set in state_sets: - unindexed_state_set = set() - for event_id, auth_chain in event_id_to_partial_auth_chain.items(): - if event_id not in state_set: - continue - - unindexed_state_set.add(event_id) - - state_set.discard(event_id) - state_set.difference_update(auth_chain) - for auth_id in auth_chain: - if auth_id in events_that_have_chain_index: - state_set.add(auth_id) - else: - unindexed_state_set.add(auth_id) - - unindexed_state_sets.append(unindexed_state_set) - - # Calculate and return the auth difference of the un-indexed events. - union = unindexed_state_sets[0].union(*unindexed_state_sets[1:]) - intersection = unindexed_state_sets[0].intersection(*unindexed_state_sets[1:]) - - return union - intersection + return synapse_rust.db._get_auth_chain_difference_using_cover_index_txn(txn, room_id, state_sets) def _get_auth_chain_difference_txn( self, txn: LoggingTransaction, state_sets: List[Set[str]]