1
0

Magic

This commit is contained in:
Patrick Cloke
2022-12-14 12:08:21 -05:00
parent 41b9def9f2
commit c25ec34d73
5 changed files with 75 additions and 28 deletions

View File

@@ -124,7 +124,10 @@ class RelationsHandler:
# Note that ignored users are not passed into get_relations_for_event
# below. Ignored users are handled in filter_events_for_client (and by
# not passing them in here we should get a better cache hit rate).
related_events, next_token = await self._main_store.get_relations_for_event(
(
related_events,
next_token,
) = await self._main_store.relations.get_relations_for_event(
event_id=event_id,
event=event,
room_id=room_id,
@@ -211,7 +214,7 @@ class RelationsHandler:
ShadowBanError if the requester is shadow-banned
"""
related_event_ids = (
await self._main_store.get_all_relations_for_event_with_types(
await self._main_store.relations.get_all_relations_for_event_with_types(
event_id, relation_types
)
)
@@ -250,7 +253,9 @@ class RelationsHandler:
A map of event IDs to a list related events.
"""
related_events = await self._main_store.get_references_for_events(event_ids)
related_events = await self._main_store.relations.get_references_for_events(
event_ids
)
# Avoid additional logic if there are no ignored users.
if not ignored_users:
@@ -304,7 +309,7 @@ class RelationsHandler:
event_ids = [eid for eid in events_by_id.keys() if eid not in relations_by_id]
# Fetch thread summaries.
summaries = await self._main_store.get_thread_summaries(event_ids)
summaries = await self._main_store.relations.get_thread_summaries(event_ids)
# Limit fetching whether the requester has participated in a thread to
# events which are thread roots.
@@ -320,7 +325,7 @@ class RelationsHandler:
# For events the requester did not send, check the database for whether
# the requester sent a threaded reply.
participated.update(
await self._main_store.get_threads_participated(
await self._main_store.relations.get_threads_participated(
[
event_id
for event_id in thread_event_ids
@@ -331,8 +336,10 @@ class RelationsHandler:
)
# Then subtract off the results for any ignored users.
ignored_results = await self._main_store.get_threaded_messages_per_user(
thread_event_ids, ignored_users
ignored_results = (
await self._main_store.relations.get_threaded_messages_per_user(
thread_event_ids, ignored_users
)
)
# A map of event ID to the thread aggregation.
@@ -361,7 +368,10 @@ class RelationsHandler:
continue
# Attempt to find another event to use as the latest event.
potential_events, _ = await self._main_store.get_relations_for_event(
(
potential_events,
_,
) = await self._main_store.relations.get_relations_for_event(
event_id,
event,
room_id,
@@ -498,7 +508,7 @@ class RelationsHandler:
Note that there is no use in limiting edits by ignored users since the
parent event should be ignored in the first place if the user is ignored.
"""
edits = await self._main_store.get_applicable_edits(
edits = await self._main_store.relations.get_applicable_edits(
[
event_id
for event_id, event in events_by_id.items()
@@ -553,7 +563,7 @@ class RelationsHandler:
# Note that ignored users are not passed into get_threads
# below. Ignored users are handled in filter_events_for_client (and by
# not passing them in here we should get a better cache hit rate).
thread_roots, next_batch = await self._main_store.get_threads(
thread_roots, next_batch = await self._main_store.relations.get_threads(
room_id=room_id, limit=limit, from_token=from_token
)
@@ -565,7 +575,7 @@ class RelationsHandler:
# For events the requester did not send, check the database for whether
# the requester sent a threaded reply.
participated.update(
await self._main_store.get_threads_participated(
await self._main_store.relations.get_threads_participated(
[eid for eid, p in participated.items() if not p],
user_id,
)

View File

@@ -25,6 +25,7 @@ from synapse.util.caches.descriptors import CachedFunction
if TYPE_CHECKING:
from synapse.server import HomeServer
from synapse.storage.databases import DataStore
logger = logging.getLogger(__name__)
@@ -44,11 +45,14 @@ class SQLBaseStore(metaclass=ABCMeta):
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
datastore: Optional["DataStore"] = None,
):
self.hs = hs
self._clock = hs.get_clock()
self.database_engine = database.engine
self.db_pool = database
# A reference back to the root datastore.
self.datastore = datastore
self.external_cached_functions: Dict[str, CachedFunction] = {}

View File

@@ -15,7 +15,8 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, List, Optional, Tuple, cast
import re
from typing import TYPE_CHECKING, Any, List, Match, Optional, Tuple, Type, cast
from synapse.api.constants import Direction
from synapse.config.homeserver import HomeServerConfig
@@ -24,6 +25,7 @@ from synapse.storage.database import (
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage._base import SQLBaseStore
from synapse.storage.databases.main.stats import UserSortOrder
from synapse.storage.engines import BaseDatabaseEngine
from synapse.storage.types import Cursor
@@ -121,7 +123,6 @@ class DataStore(
UserErasureStore,
MonthlyActiveUsersWorkerStore,
StatsStore,
RelationsStore,
CensorEventsStore,
UIAuthStore,
EventForwardExtremitiesStore,
@@ -129,6 +130,13 @@ class DataStore(
LockStore,
SessionStore,
):
DATASTORE_CLASSES: List[Type[SQLBaseStore]] = [
RelationsStore,
]
# XXX So mypy knows about dynamic properties.
relations: RelationsStore
def __init__(
self,
database: DatabasePool,
@@ -141,6 +149,19 @@ class DataStore(
super().__init__(database, db_conn, hs)
def repl(match: Match[str]) -> str:
return "_" + match.group(0).lower()
for datastore_class in self.DATASTORE_CLASSES:
name = datastore_class.__name__
if name.endswith("Store"):
name = name[: -len("Store")]
name = re.sub(r"[A-Z]", repl, name)[1:]
store = datastore_class(database, db_conn, hs, self)
setattr(self, name, store)
async def get_users(self) -> List[JsonDict]:
"""Function to retrieve a list of users in users table.

View File

@@ -52,6 +52,7 @@ from synapse.util.caches.descriptors import cached, cachedList
if TYPE_CHECKING:
from synapse.server import HomeServer
from synapse.storage.databases.main import DataStore
logger = logging.getLogger(__name__)
@@ -95,8 +96,9 @@ class RelationsWorkerStore(SQLBaseStore):
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
datastore: "DataStore",
):
super().__init__(database, db_conn, hs)
super().__init__(database, db_conn, hs, datastore)
self.db_pool.updates.register_background_update_handler(
"threads_backfill", self._backfill_threads

View File

@@ -58,28 +58,28 @@ class RelationsStoreTestCase(unittest.HomeserverTestCase):
Ensure that get_thread_id only searches up the tree for threads.
"""
# The thread itself and children of it return the thread.
thread_id = self.get_success(self._main_store.get_thread_id("B"))
thread_id = self.get_success(self._main_store.relations.get_thread_id("B"))
self.assertEqual("A", thread_id)
thread_id = self.get_success(self._main_store.get_thread_id("C"))
thread_id = self.get_success(self._main_store.relations.get_thread_id("C"))
self.assertEqual("A", thread_id)
# But the root and events related to the root do not.
thread_id = self.get_success(self._main_store.get_thread_id("A"))
thread_id = self.get_success(self._main_store.relations.get_thread_id("A"))
self.assertEqual(MAIN_TIMELINE, thread_id)
thread_id = self.get_success(self._main_store.get_thread_id("D"))
thread_id = self.get_success(self._main_store.relations.get_thread_id("D"))
self.assertEqual(MAIN_TIMELINE, thread_id)
thread_id = self.get_success(self._main_store.get_thread_id("E"))
thread_id = self.get_success(self._main_store.relations.get_thread_id("E"))
self.assertEqual(MAIN_TIMELINE, thread_id)
# Events which are not related to a thread at all should return the
# main timeline.
thread_id = self.get_success(self._main_store.get_thread_id("F"))
thread_id = self.get_success(self._main_store.relations.get_thread_id("F"))
self.assertEqual(MAIN_TIMELINE, thread_id)
thread_id = self.get_success(self._main_store.get_thread_id("G"))
thread_id = self.get_success(self._main_store.relations.get_thread_id("G"))
self.assertEqual(MAIN_TIMELINE, thread_id)
def test_get_thread_id_for_receipts(self) -> None:
@@ -87,25 +87,35 @@ class RelationsStoreTestCase(unittest.HomeserverTestCase):
Ensure that get_thread_id_for_receipts searches up and down the tree for a thread.
"""
# All of the events are considered related to this thread.
thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("A"))
thread_id = self.get_success(
self._main_store.relations.get_thread_id_for_receipts("A")
)
self.assertEqual("A", thread_id)
thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("B"))
thread_id = self.get_success(
self._main_store.relations.get_thread_id_for_receipts("B")
)
self.assertEqual("A", thread_id)
thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("C"))
thread_id = self.get_success(
self._main_store.relations.get_thread_id_for_receipts("C")
)
self.assertEqual("A", thread_id)
thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("D"))
thread_id = self.get_success(
self._main_store.relations.get_thread_id_for_receipts("D")
)
self.assertEqual("A", thread_id)
thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("E"))
thread_id = self.get_success(
self._main_store.relations.get_thread_id_for_receipts("E")
)
self.assertEqual("A", thread_id)
# Events which are not related to a thread at all should return the
# main timeline.
thread_id = self.get_success(self._main_store.get_thread_id("F"))
thread_id = self.get_success(self._main_store.relations.get_thread_id("F"))
self.assertEqual(MAIN_TIMELINE, thread_id)
thread_id = self.get_success(self._main_store.get_thread_id("G"))
thread_id = self.get_success(self._main_store.relations.get_thread_id("G"))
self.assertEqual(MAIN_TIMELINE, thread_id)