From 507cafc2c3d752a45f1fc498055aa009ca4c83dc Mon Sep 17 00:00:00 2001 From: "Olivier Wilkinson (reivilibre)" Date: Mon, 2 Aug 2021 14:19:20 +0100 Subject: [PATCH] Make a single-group transaction function --- synapse/storage/databases/state/bg_updates.py | 111 +++++++++++++++++- 1 file changed, 108 insertions(+), 3 deletions(-) diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py index c2891cb07f..7608bb67a5 100644 --- a/synapse/storage/databases/state/bg_updates.py +++ b/synapse/storage/databases/state/bg_updates.py @@ -13,13 +13,17 @@ # limitations under the License. import logging -from typing import Optional +import typing +from typing import Dict, List, Optional from synapse.storage._base import SQLBaseStore from synapse.storage.database import DatabasePool from synapse.storage.engines import PostgresEngine from synapse.storage.state import StateFilter +if typing.TYPE_CHECKING: + from synapse.types import StateMap + logger = logging.getLogger(__name__) @@ -73,8 +77,8 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore): return count def _get_state_groups_from_groups_txn( - self, txn, groups, state_filter: Optional[StateFilter] = None - ): + self, txn, groups: List[int], state_filter: Optional[StateFilter] = None + ) -> Dict[int, "StateMap[str]"]: state_filter = state_filter or StateFilter.all() results = {group: {} for group in groups} @@ -175,6 +179,107 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore): return results + def _get_state_groups_from_group_txn( + self, txn, group: int, state_filter: Optional[StateFilter] = None + ) -> "StateMap[str]": + state_filter = state_filter or StateFilter.all() + + result = {} + + where_clause, where_args = state_filter.make_sql_filter_clause() + + # Unless the filter clause is empty, we're going to append it after an + # existing where clause + if where_clause: + where_clause = " AND (%s)" % (where_clause,) + + if isinstance(self.database_engine, PostgresEngine): + # Temporarily disable sequential scans in this transaction. This is + # a temporary hack until we can add the right indices in + txn.execute("SET LOCAL enable_seqscan=off") + + # The below query walks the state_group tree so that the "state" + # table includes all state_groups in the tree. It then joins + # against `state_groups_state` to fetch the latest state. + # It assumes that previous state groups are always numerically + # lesser. + # The PARTITION is used to get the event_id in the greatest state + # group for the given type, state_key. + # This may return multiple rows per (type, state_key), but last_value + # should be the same. + sql = """ + WITH RECURSIVE state(state_group) AS ( + VALUES(?::bigint) + UNION ALL + SELECT prev_state_group FROM state_group_edges e, state s + WHERE s.state_group = e.state_group + ) + SELECT DISTINCT ON (type, state_key) + type, state_key, event_id + FROM state_groups_state + WHERE state_group IN ( + SELECT state_group FROM state + ) %s + ORDER BY type, state_key, state_group DESC + """ + + args = [group] + args.extend(where_args) + + txn.execute(sql % (where_clause,), args) + for row in txn: + typ, state_key, event_id = row + key = (typ, state_key) + result[key] = event_id + else: + max_entries_returned = state_filter.max_entries_returned() + + # We don't use WITH RECURSIVE on sqlite3 as there are distributions + # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) + next_group = group + + while next_group: + # We did this before by getting the list of group ids, and + # then passing that list to sqlite to get latest event for + # each (type, state_key). However, that was terribly slow + # without the right indices (which we can't add until + # after we finish deduping state, which requires this func) + args = [next_group] + args.extend(where_args) + + txn.execute( + "SELECT type, state_key, event_id FROM state_groups_state" + " WHERE state_group = ? " + where_clause, + args, + ) + result.update( + ((typ, state_key), event_id) + for typ, state_key, event_id in txn + if (typ, state_key) not in result + ) + + # If the number of entries in the (type,state_key)->event_id dict + # matches the number of (type,state_keys) types we were searching + # for, then we must have found them all, so no need to go walk + # further down the tree... UNLESS our types filter contained + # wildcards (i.e. Nones) in which case we have to do an exhaustive + # search + if ( + max_entries_returned is not None + and len(result) == max_entries_returned + ): + break + + next_group = self.db_pool.simple_select_one_onecol_txn( + txn, + table="state_group_edges", + keyvalues={"state_group": next_group}, + retcol="prev_state_group", + allow_none=True, + ) + + return result + class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):