✨ Magic ✨
This commit is contained in:
@@ -124,7 +124,10 @@ class RelationsHandler:
|
||||
# Note that ignored users are not passed into get_relations_for_event
|
||||
# below. Ignored users are handled in filter_events_for_client (and by
|
||||
# not passing them in here we should get a better cache hit rate).
|
||||
related_events, next_token = await self._main_store.get_relations_for_event(
|
||||
(
|
||||
related_events,
|
||||
next_token,
|
||||
) = await self._main_store.relations.get_relations_for_event(
|
||||
event_id=event_id,
|
||||
event=event,
|
||||
room_id=room_id,
|
||||
@@ -211,7 +214,7 @@ class RelationsHandler:
|
||||
ShadowBanError if the requester is shadow-banned
|
||||
"""
|
||||
related_event_ids = (
|
||||
await self._main_store.get_all_relations_for_event_with_types(
|
||||
await self._main_store.relations.get_all_relations_for_event_with_types(
|
||||
event_id, relation_types
|
||||
)
|
||||
)
|
||||
@@ -250,7 +253,9 @@ class RelationsHandler:
|
||||
A map of event IDs to a list related events.
|
||||
"""
|
||||
|
||||
related_events = await self._main_store.get_references_for_events(event_ids)
|
||||
related_events = await self._main_store.relations.get_references_for_events(
|
||||
event_ids
|
||||
)
|
||||
|
||||
# Avoid additional logic if there are no ignored users.
|
||||
if not ignored_users:
|
||||
@@ -304,7 +309,7 @@ class RelationsHandler:
|
||||
event_ids = [eid for eid in events_by_id.keys() if eid not in relations_by_id]
|
||||
|
||||
# Fetch thread summaries.
|
||||
summaries = await self._main_store.get_thread_summaries(event_ids)
|
||||
summaries = await self._main_store.relations.get_thread_summaries(event_ids)
|
||||
|
||||
# Limit fetching whether the requester has participated in a thread to
|
||||
# events which are thread roots.
|
||||
@@ -320,7 +325,7 @@ class RelationsHandler:
|
||||
# For events the requester did not send, check the database for whether
|
||||
# the requester sent a threaded reply.
|
||||
participated.update(
|
||||
await self._main_store.get_threads_participated(
|
||||
await self._main_store.relations.get_threads_participated(
|
||||
[
|
||||
event_id
|
||||
for event_id in thread_event_ids
|
||||
@@ -331,8 +336,10 @@ class RelationsHandler:
|
||||
)
|
||||
|
||||
# Then subtract off the results for any ignored users.
|
||||
ignored_results = await self._main_store.get_threaded_messages_per_user(
|
||||
thread_event_ids, ignored_users
|
||||
ignored_results = (
|
||||
await self._main_store.relations.get_threaded_messages_per_user(
|
||||
thread_event_ids, ignored_users
|
||||
)
|
||||
)
|
||||
|
||||
# A map of event ID to the thread aggregation.
|
||||
@@ -361,7 +368,10 @@ class RelationsHandler:
|
||||
continue
|
||||
|
||||
# Attempt to find another event to use as the latest event.
|
||||
potential_events, _ = await self._main_store.get_relations_for_event(
|
||||
(
|
||||
potential_events,
|
||||
_,
|
||||
) = await self._main_store.relations.get_relations_for_event(
|
||||
event_id,
|
||||
event,
|
||||
room_id,
|
||||
@@ -498,7 +508,7 @@ class RelationsHandler:
|
||||
Note that there is no use in limiting edits by ignored users since the
|
||||
parent event should be ignored in the first place if the user is ignored.
|
||||
"""
|
||||
edits = await self._main_store.get_applicable_edits(
|
||||
edits = await self._main_store.relations.get_applicable_edits(
|
||||
[
|
||||
event_id
|
||||
for event_id, event in events_by_id.items()
|
||||
@@ -553,7 +563,7 @@ class RelationsHandler:
|
||||
# Note that ignored users are not passed into get_threads
|
||||
# below. Ignored users are handled in filter_events_for_client (and by
|
||||
# not passing them in here we should get a better cache hit rate).
|
||||
thread_roots, next_batch = await self._main_store.get_threads(
|
||||
thread_roots, next_batch = await self._main_store.relations.get_threads(
|
||||
room_id=room_id, limit=limit, from_token=from_token
|
||||
)
|
||||
|
||||
@@ -565,7 +575,7 @@ class RelationsHandler:
|
||||
# For events the requester did not send, check the database for whether
|
||||
# the requester sent a threaded reply.
|
||||
participated.update(
|
||||
await self._main_store.get_threads_participated(
|
||||
await self._main_store.relations.get_threads_participated(
|
||||
[eid for eid, p in participated.items() if not p],
|
||||
user_id,
|
||||
)
|
||||
|
||||
@@ -25,6 +25,7 @@ from synapse.util.caches.descriptors import CachedFunction
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
from synapse.storage.databases import DataStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -44,11 +45,14 @@ class SQLBaseStore(metaclass=ABCMeta):
|
||||
database: DatabasePool,
|
||||
db_conn: LoggingDatabaseConnection,
|
||||
hs: "HomeServer",
|
||||
datastore: Optional["DataStore"] = None,
|
||||
):
|
||||
self.hs = hs
|
||||
self._clock = hs.get_clock()
|
||||
self.database_engine = database.engine
|
||||
self.db_pool = database
|
||||
# A reference back to the root datastore.
|
||||
self.datastore = datastore
|
||||
|
||||
self.external_cached_functions: Dict[str, CachedFunction] = {}
|
||||
|
||||
|
||||
@@ -15,7 +15,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple, cast
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Any, List, Match, Optional, Tuple, Type, cast
|
||||
|
||||
from synapse.api.constants import Direction
|
||||
from synapse.config.homeserver import HomeServerConfig
|
||||
@@ -24,6 +25,7 @@ from synapse.storage.database import (
|
||||
LoggingDatabaseConnection,
|
||||
LoggingTransaction,
|
||||
)
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
from synapse.storage.databases.main.stats import UserSortOrder
|
||||
from synapse.storage.engines import BaseDatabaseEngine
|
||||
from synapse.storage.types import Cursor
|
||||
@@ -121,7 +123,6 @@ class DataStore(
|
||||
UserErasureStore,
|
||||
MonthlyActiveUsersWorkerStore,
|
||||
StatsStore,
|
||||
RelationsStore,
|
||||
CensorEventsStore,
|
||||
UIAuthStore,
|
||||
EventForwardExtremitiesStore,
|
||||
@@ -129,6 +130,13 @@ class DataStore(
|
||||
LockStore,
|
||||
SessionStore,
|
||||
):
|
||||
DATASTORE_CLASSES: List[Type[SQLBaseStore]] = [
|
||||
RelationsStore,
|
||||
]
|
||||
|
||||
# XXX So mypy knows about dynamic properties.
|
||||
relations: RelationsStore
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
database: DatabasePool,
|
||||
@@ -141,6 +149,19 @@ class DataStore(
|
||||
|
||||
super().__init__(database, db_conn, hs)
|
||||
|
||||
def repl(match: Match[str]) -> str:
|
||||
return "_" + match.group(0).lower()
|
||||
|
||||
for datastore_class in self.DATASTORE_CLASSES:
|
||||
name = datastore_class.__name__
|
||||
if name.endswith("Store"):
|
||||
name = name[: -len("Store")]
|
||||
|
||||
name = re.sub(r"[A-Z]", repl, name)[1:]
|
||||
|
||||
store = datastore_class(database, db_conn, hs, self)
|
||||
setattr(self, name, store)
|
||||
|
||||
async def get_users(self) -> List[JsonDict]:
|
||||
"""Function to retrieve a list of users in users table.
|
||||
|
||||
|
||||
@@ -52,6 +52,7 @@ from synapse.util.caches.descriptors import cached, cachedList
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
from synapse.storage.databases.main import DataStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -95,8 +96,9 @@ class RelationsWorkerStore(SQLBaseStore):
|
||||
database: DatabasePool,
|
||||
db_conn: LoggingDatabaseConnection,
|
||||
hs: "HomeServer",
|
||||
datastore: "DataStore",
|
||||
):
|
||||
super().__init__(database, db_conn, hs)
|
||||
super().__init__(database, db_conn, hs, datastore)
|
||||
|
||||
self.db_pool.updates.register_background_update_handler(
|
||||
"threads_backfill", self._backfill_threads
|
||||
|
||||
@@ -58,28 +58,28 @@ class RelationsStoreTestCase(unittest.HomeserverTestCase):
|
||||
Ensure that get_thread_id only searches up the tree for threads.
|
||||
"""
|
||||
# The thread itself and children of it return the thread.
|
||||
thread_id = self.get_success(self._main_store.get_thread_id("B"))
|
||||
thread_id = self.get_success(self._main_store.relations.get_thread_id("B"))
|
||||
self.assertEqual("A", thread_id)
|
||||
|
||||
thread_id = self.get_success(self._main_store.get_thread_id("C"))
|
||||
thread_id = self.get_success(self._main_store.relations.get_thread_id("C"))
|
||||
self.assertEqual("A", thread_id)
|
||||
|
||||
# But the root and events related to the root do not.
|
||||
thread_id = self.get_success(self._main_store.get_thread_id("A"))
|
||||
thread_id = self.get_success(self._main_store.relations.get_thread_id("A"))
|
||||
self.assertEqual(MAIN_TIMELINE, thread_id)
|
||||
|
||||
thread_id = self.get_success(self._main_store.get_thread_id("D"))
|
||||
thread_id = self.get_success(self._main_store.relations.get_thread_id("D"))
|
||||
self.assertEqual(MAIN_TIMELINE, thread_id)
|
||||
|
||||
thread_id = self.get_success(self._main_store.get_thread_id("E"))
|
||||
thread_id = self.get_success(self._main_store.relations.get_thread_id("E"))
|
||||
self.assertEqual(MAIN_TIMELINE, thread_id)
|
||||
|
||||
# Events which are not related to a thread at all should return the
|
||||
# main timeline.
|
||||
thread_id = self.get_success(self._main_store.get_thread_id("F"))
|
||||
thread_id = self.get_success(self._main_store.relations.get_thread_id("F"))
|
||||
self.assertEqual(MAIN_TIMELINE, thread_id)
|
||||
|
||||
thread_id = self.get_success(self._main_store.get_thread_id("G"))
|
||||
thread_id = self.get_success(self._main_store.relations.get_thread_id("G"))
|
||||
self.assertEqual(MAIN_TIMELINE, thread_id)
|
||||
|
||||
def test_get_thread_id_for_receipts(self) -> None:
|
||||
@@ -87,25 +87,35 @@ class RelationsStoreTestCase(unittest.HomeserverTestCase):
|
||||
Ensure that get_thread_id_for_receipts searches up and down the tree for a thread.
|
||||
"""
|
||||
# All of the events are considered related to this thread.
|
||||
thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("A"))
|
||||
thread_id = self.get_success(
|
||||
self._main_store.relations.get_thread_id_for_receipts("A")
|
||||
)
|
||||
self.assertEqual("A", thread_id)
|
||||
|
||||
thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("B"))
|
||||
thread_id = self.get_success(
|
||||
self._main_store.relations.get_thread_id_for_receipts("B")
|
||||
)
|
||||
self.assertEqual("A", thread_id)
|
||||
|
||||
thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("C"))
|
||||
thread_id = self.get_success(
|
||||
self._main_store.relations.get_thread_id_for_receipts("C")
|
||||
)
|
||||
self.assertEqual("A", thread_id)
|
||||
|
||||
thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("D"))
|
||||
thread_id = self.get_success(
|
||||
self._main_store.relations.get_thread_id_for_receipts("D")
|
||||
)
|
||||
self.assertEqual("A", thread_id)
|
||||
|
||||
thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("E"))
|
||||
thread_id = self.get_success(
|
||||
self._main_store.relations.get_thread_id_for_receipts("E")
|
||||
)
|
||||
self.assertEqual("A", thread_id)
|
||||
|
||||
# Events which are not related to a thread at all should return the
|
||||
# main timeline.
|
||||
thread_id = self.get_success(self._main_store.get_thread_id("F"))
|
||||
thread_id = self.get_success(self._main_store.relations.get_thread_id("F"))
|
||||
self.assertEqual(MAIN_TIMELINE, thread_id)
|
||||
|
||||
thread_id = self.get_success(self._main_store.get_thread_id("G"))
|
||||
thread_id = self.get_success(self._main_store.relations.get_thread_id("G"))
|
||||
self.assertEqual(MAIN_TIMELINE, thread_id)
|
||||
|
||||
Reference in New Issue
Block a user