Compare commits

...

10 Commits

Author SHA1 Message Date
Patrick Cloke
9796706a96 Fix references to get_events. 2023-05-17 14:26:01 -04:00
Patrick Cloke
86b836f2b5 Properly invalidate caches over replication. 2023-05-17 14:26:01 -04:00
Patrick Cloke
edb91e78b2 Fix-up other users of RelationsWorkerStore. 2023-05-17 14:26:01 -04:00
Patrick Cloke
0e67a3f703 Stop passing in the base-store. 2023-05-17 14:26:01 -04:00
Patrick Cloke
1630199e4d Update changelog. 2023-05-17 14:26:01 -04:00
Patrick Cloke
a3e154910a Lint 2023-05-17 14:26:01 -04:00
Patrick Cloke
b3e0354c98 Be explicit about datastores. 2023-05-17 14:26:01 -04:00
Patrick Cloke
e55ad9b6cf Clean-up more references to relations store. 2023-05-17 14:26:01 -04:00
Patrick Cloke
4b180db298 Newsfragment 2023-05-17 14:26:01 -04:00
Patrick Cloke
c25ec34d73 Magic 2023-05-17 14:26:01 -04:00
16 changed files with 157 additions and 60 deletions

1
changelog.d/14678.misc Normal file
View File

@@ -0,0 +1 @@
Improve type hints in datastores.

View File

@@ -24,6 +24,7 @@ import time
import traceback
from types import TracebackType
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
@@ -53,7 +54,12 @@ from synapse.logging.context import (
run_in_background,
)
from synapse.notifier import ReplicationNotifier
from synapse.storage.database import DatabasePool, LoggingTransaction, make_conn
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
make_conn,
)
from synapse.storage.databases.main import FilteringWorkerStore, PushRuleStore
from synapse.storage.databases.main.account_data import AccountDataWorkerStore
from synapse.storage.databases.main.client_ips import ClientIpBackgroundUpdateStore
@@ -94,6 +100,9 @@ from synapse.storage.prepare_database import prepare_database
from synapse.types import ISynapseReactor
from synapse.util import SYNAPSE_VERSION, Clock
if TYPE_CHECKING:
from synapse.server import HomeServer
# Cast safety: Twisted does some naughty magic which replaces the
# twisted.internet.reactor module with a Reactor instance at runtime.
reactor = cast(ISynapseReactor, reactor_)
@@ -238,8 +247,18 @@ class Store(
PusherBackgroundUpdatesStore,
PresenceBackgroundUpdateStore,
ReceiptsBackgroundUpdateStore,
RelationsWorkerStore,
):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
# This is a bit repetitive, but avoids dynamically setting attributes.
self.relations = RelationsWorkerStore(database, db_conn, hs)
def execute(self, f: Callable[..., R], *args: Any, **kwargs: Any) -> Awaitable[R]:
return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs)

View File

@@ -507,7 +507,7 @@ class Filter:
# The event IDs to check, mypy doesn't understand the isinstance check.
event_ids = [event.event_id for event in events if isinstance(event, EventBase)] # type: ignore[attr-defined]
event_ids_to_keep = set(
await self._store.events_have_relations(
await self._store.relations.events_have_relations(
event_ids, self.related_by_senders, self.related_by_rel_types
)
)

View File

@@ -75,7 +75,6 @@ class AdminCmdStore(
ApplicationServiceTransactionWorkerStore,
ApplicationServiceWorkerStore,
RoomMemberWorkerStore,
RelationsWorkerStore,
EventFederationWorkerStore,
EventPushActionsWorkerStore,
StateGroupWorkerStore,
@@ -101,6 +100,9 @@ class AdminCmdStore(
# should refactor it to take a `Clock` directly.
self.clock = hs.get_clock()
# This is a bit repetitive, but avoids dynamically setting attributes.
self.relations = RelationsWorkerStore(database, db_conn, hs)
class AdminCmdServer(HomeServer):
DATASTORE_CLASS = AdminCmdStore # type: ignore

View File

@@ -51,6 +51,7 @@ from synapse.rest.key.v2 import KeyResource
from synapse.rest.synapse.client import build_synapse_client_resource_tree
from synapse.rest.well_known import well_known_resource
from synapse.server import HomeServer
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.account_data import AccountDataWorkerStore
from synapse.storage.databases.main.appservice import (
ApplicationServiceTransactionWorkerStore,
@@ -132,7 +133,6 @@ class GenericWorkerStore(
ServerMetricsStore,
PusherWorkerStore,
RoomMemberWorkerStore,
RelationsWorkerStore,
EventFederationWorkerStore,
EventPushActionsWorkerStore,
StateGroupWorkerStore,
@@ -152,6 +152,17 @@ class GenericWorkerStore(
server_name: str
config: HomeServerConfig
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
# This is a bit repetitive, but avoids dynamically setting attributes.
self.relations = RelationsWorkerStore(database, db_conn, hs)
class GenericWorkerServer(HomeServer):
DATASTORE_CLASS = GenericWorkerStore # type: ignore

View File

@@ -1361,7 +1361,9 @@ class EventCreationHandler:
else:
# There must be some reason that the client knows the event exists,
# see if there are existing relations. If so, assume everything is fine.
if not await self.store.event_is_target_of_relation(relation.parent_id):
if not await self.store.relations.event_is_target_of_relation(
relation.parent_id
):
# Otherwise, the client can't know about the parent event!
raise SynapseError(400, "Can't send relation to unknown event")
@@ -1377,7 +1379,7 @@ class EventCreationHandler:
if len(aggregation_key) > 500:
raise SynapseError(400, "Aggregation key is too long")
already_exists = await self.store.has_user_annotated_event(
already_exists = await self.store.relations.has_user_annotated_event(
relation.parent_id, event.type, aggregation_key, event.sender
)
if already_exists:
@@ -1389,7 +1391,7 @@ class EventCreationHandler:
# Don't attempt to start a thread if the parent event is a relation.
elif relation.rel_type == RelationTypes.THREAD:
if await self.store.event_includes_relation(relation.parent_id):
if await self.store.relations.event_includes_relation(relation.parent_id):
raise SynapseError(
400, "Cannot start threads from an event with a relation"
)

View File

@@ -124,7 +124,10 @@ class RelationsHandler:
# Note that ignored users are not passed into get_relations_for_event
# below. Ignored users are handled in filter_events_for_client (and by
# not passing them in here we should get a better cache hit rate).
related_events, next_token = await self._main_store.get_relations_for_event(
(
related_events,
next_token,
) = await self._main_store.relations.get_relations_for_event(
event_id=event_id,
event=event,
room_id=room_id,
@@ -211,7 +214,7 @@ class RelationsHandler:
ShadowBanError if the requester is shadow-banned
"""
related_event_ids = (
await self._main_store.get_all_relations_for_event_with_types(
await self._main_store.relations.get_all_relations_for_event_with_types(
event_id, relation_types
)
)
@@ -250,7 +253,9 @@ class RelationsHandler:
A map of event IDs to a list related events.
"""
related_events = await self._main_store.get_references_for_events(event_ids)
related_events = await self._main_store.relations.get_references_for_events(
event_ids
)
# Avoid additional logic if there are no ignored users.
if not ignored_users:
@@ -304,7 +309,7 @@ class RelationsHandler:
event_ids = [eid for eid in events_by_id.keys() if eid not in relations_by_id]
# Fetch thread summaries.
summaries = await self._main_store.get_thread_summaries(event_ids)
summaries = await self._main_store.relations.get_thread_summaries(event_ids)
# Limit fetching whether the requester has participated in a thread to
# events which are thread roots.
@@ -320,7 +325,7 @@ class RelationsHandler:
# For events the requester did not send, check the database for whether
# the requester sent a threaded reply.
participated.update(
await self._main_store.get_threads_participated(
await self._main_store.relations.get_threads_participated(
[
event_id
for event_id in thread_event_ids
@@ -331,8 +336,10 @@ class RelationsHandler:
)
# Then subtract off the results for any ignored users.
ignored_results = await self._main_store.get_threaded_messages_per_user(
thread_event_ids, ignored_users
ignored_results = (
await self._main_store.relations.get_threaded_messages_per_user(
thread_event_ids, ignored_users
)
)
# A map of event ID to the thread aggregation.
@@ -361,7 +368,10 @@ class RelationsHandler:
continue
# Attempt to find another event to use as the latest event.
potential_events, _ = await self._main_store.get_relations_for_event(
(
potential_events,
_,
) = await self._main_store.relations.get_relations_for_event(
event_id,
event,
room_id,
@@ -498,7 +508,7 @@ class RelationsHandler:
Note that there is no use in limiting edits by ignored users since the
parent event should be ignored in the first place if the user is ignored.
"""
edits = await self._main_store.get_applicable_edits(
edits = await self._main_store.relations.get_applicable_edits(
[
event_id
for event_id, event in events_by_id.items()
@@ -553,7 +563,7 @@ class RelationsHandler:
# Note that ignored users are not passed into get_threads
# below. Ignored users are handled in filter_events_for_client (and by
# not passing them in here we should get a better cache hit rate).
thread_roots, next_batch = await self._main_store.get_threads(
thread_roots, next_batch = await self._main_store.relations.get_threads(
room_id=room_id, limit=limit, from_token=from_token
)
@@ -565,7 +575,7 @@ class RelationsHandler:
# For events the requester did not send, check the database for whether
# the requester sent a threaded reply.
participated.update(
await self._main_store.get_threads_participated(
await self._main_store.relations.get_threads_participated(
[eid for eid, p in participated.items() if not p],
user_id,
)

View File

@@ -368,7 +368,7 @@ class BulkPushRuleEvaluator:
else:
# Since the event has not yet been persisted we check whether
# the parent is part of a thread.
thread_id = await self.store.get_thread_id(relation.parent_id)
thread_id = await self.store.relations.get_thread_id(relation.parent_id)
related_events = await self._related_events(event)

View File

@@ -147,11 +147,15 @@ class ReceiptRestServlet(RestServlet):
# If the receipt is on the main timeline, it is enough to check whether
# the event is directly related to a thread.
if thread_id == MAIN_TIMELINE:
return MAIN_TIMELINE == await self._main_store.get_thread_id(event_id)
return MAIN_TIMELINE == await self._main_store.relations.get_thread_id(
event_id
)
# Otherwise, check if the event is directly part of a thread, or is the
# root message (or related to the root message) of a thread.
return thread_id == await self._main_store.get_thread_id_for_receipts(event_id)
return thread_id == await self._main_store.relations.get_thread_id_for_receipts(
event_id
)
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:

View File

@@ -118,7 +118,10 @@ class SQLBaseStore(metaclass=ABCMeta):
self._attempt_to_invalidate_cache("get_partial_current_state_ids", (room_id,))
def _attempt_to_invalidate_cache(
self, cache_name: str, key: Optional[Collection[Any]]
self,
cache_name: str,
key: Optional[Collection[Any]],
store_name: Optional[str] = None,
) -> bool:
"""Attempts to invalidate the cache of the given name, ignoring if the
cache doesn't exist. Mainly used for invalidating caches on workers,
@@ -132,10 +135,21 @@ class SQLBaseStore(metaclass=ABCMeta):
cache_name
key: Entry to invalidate. If None then invalidates the entire
cache.
store_name: The name of the store, leave as None for stores which
have not yet been split out.
"""
# First get the store.
store = self
if store_name is not None:
try:
store = getattr(self, store_name)
except AttributeError:
pass
# Then attempt to find the cache on that store.
try:
cache = getattr(self, cache_name)
cache = getattr(store, cache_name)
except AttributeError:
# Check if an externally defined module cache has been registered
cache = self.external_cached_functions.get(cache_name)

View File

@@ -121,7 +121,6 @@ class DataStore(
UserErasureStore,
MonthlyActiveUsersWorkerStore,
StatsStore,
RelationsStore,
CensorEventsStore,
UIAuthStore,
EventForwardExtremitiesStore,
@@ -141,6 +140,9 @@ class DataStore(
super().__init__(database, db_conn, hs)
# This is a bit repetitive, but avoids dynamically setting attributes.
self.relations = RelationsStore(database, db_conn, hs)
async def get_users(self) -> List[JsonDict]:
"""Function to retrieve a list of users in users table.

View File

@@ -248,10 +248,16 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self._invalidate_local_get_event_cache(redacts) # type: ignore[attr-defined]
# Caches which might leak edits must be invalidated for the event being
# redacted.
self._attempt_to_invalidate_cache("get_relations_for_event", (redacts,))
self._attempt_to_invalidate_cache("get_applicable_edit", (redacts,))
self._attempt_to_invalidate_cache("get_thread_id", (redacts,))
self._attempt_to_invalidate_cache("get_thread_id_for_receipts", (redacts,))
self._attempt_to_invalidate_cache(
"get_relations_for_event", (redacts,), "relations"
)
self._attempt_to_invalidate_cache(
"get_applicable_edit", (redacts,), "relations"
)
self._attempt_to_invalidate_cache("get_thread_id", (redacts,), "relations")
self._attempt_to_invalidate_cache(
"get_thread_id_for_receipts", (redacts,), "relations"
)
if etype == EventTypes.Member:
self._membership_stream_cache.entity_has_changed(state_key, stream_ordering) # type: ignore[attr-defined]
@@ -264,12 +270,22 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self._attempt_to_invalidate_cache("get_rooms_for_user", (state_key,))
if relates_to:
self._attempt_to_invalidate_cache("get_relations_for_event", (relates_to,))
self._attempt_to_invalidate_cache("get_references_for_event", (relates_to,))
self._attempt_to_invalidate_cache("get_applicable_edit", (relates_to,))
self._attempt_to_invalidate_cache("get_thread_summary", (relates_to,))
self._attempt_to_invalidate_cache("get_thread_participated", (relates_to,))
self._attempt_to_invalidate_cache("get_threads", (room_id,))
self._attempt_to_invalidate_cache(
"get_relations_for_event", (relates_to,), "relations"
)
self._attempt_to_invalidate_cache(
"get_references_for_event", (relates_to,), "relations"
)
self._attempt_to_invalidate_cache(
"get_applicable_edit", (relates_to,), "relations"
)
self._attempt_to_invalidate_cache(
"get_thread_summary", (relates_to,), "relations"
)
self._attempt_to_invalidate_cache(
"get_thread_participated", (relates_to,), "relations"
)
self._attempt_to_invalidate_cache("get_threads", (room_id,), "relations")
async def invalidate_cache_and_stream(
self, cache_name: str, keys: Tuple[Any, ...]

View File

@@ -2072,25 +2072,29 @@ class PersistEventsStore:
# Any relation information for the related event must be cleared.
self.store._invalidate_cache_and_stream(
txn, self.store.get_relations_for_event, (redacted_relates_to,)
txn, self.store.relations.get_relations_for_event, (redacted_relates_to,)
)
if rel_type == RelationTypes.REFERENCE:
self.store._invalidate_cache_and_stream(
txn, self.store.get_references_for_event, (redacted_relates_to,)
txn,
self.store.relations.get_references_for_event,
(redacted_relates_to,),
)
if rel_type == RelationTypes.REPLACE:
self.store._invalidate_cache_and_stream(
txn, self.store.get_applicable_edit, (redacted_relates_to,)
txn, self.store.relations.get_applicable_edit, (redacted_relates_to,)
)
if rel_type == RelationTypes.THREAD:
self.store._invalidate_cache_and_stream(
txn, self.store.get_thread_summary, (redacted_relates_to,)
txn, self.store.relations.get_thread_summary, (redacted_relates_to,)
)
self.store._invalidate_cache_and_stream(
txn, self.store.get_thread_participated, (redacted_relates_to,)
txn,
self.store.relations.get_thread_participated,
(redacted_relates_to,),
)
self.store._invalidate_cache_and_stream(
txn, self.store.get_threads, (room_id,)
txn, self.store.relations.get_threads, (room_id,)
)
# Find the new latest event in the thread.

View File

@@ -1217,10 +1217,10 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
for parent_id in {r[1] for r in relations_to_insert}:
cache_tuple = (parent_id,)
self._invalidate_cache_and_stream( # type: ignore[attr-defined]
txn, self.get_relations_for_event, cache_tuple # type: ignore[attr-defined]
txn, self.relations.get_relations_for_event, cache_tuple # type: ignore[attr-defined]
)
self._invalidate_cache_and_stream( # type: ignore[attr-defined]
txn, self.get_thread_summary, cache_tuple # type: ignore[attr-defined]
txn, self.relations.get_thread_summary, cache_tuple # type: ignore[attr-defined]
)
if results:

View File

@@ -557,14 +557,14 @@ class RelationsWorkerStore(SQLBaseStore):
"get_applicable_edits", _get_applicable_edits_txn
)
edits = await self.get_events(edit_ids.values()) # type: ignore[attr-defined]
edits = await self.hs.get_datastores().main.get_events(edit_ids.values())
# Map to the original event IDs to the edit events.
#
# There might not be an edit event due to there being no edits or
# due to the event not being known, either case is treated the same.
return {
original_event_id: edits.get(edit_ids.get(original_event_id))
original_event_id: edits.get(edit_ids.get(original_event_id)) # type: ignore[arg-type]
for original_event_id in event_ids
}
@@ -671,7 +671,9 @@ class RelationsWorkerStore(SQLBaseStore):
"get_thread_summaries", _get_thread_summaries_txn
)
latest_events = await self.get_events(latest_event_ids.values()) # type: ignore[attr-defined]
latest_events = await self.hs.get_datastores().main.get_events(
latest_event_ids.values()
)
# Map to the event IDs to the thread summary.
#

View File

@@ -58,28 +58,28 @@ class RelationsStoreTestCase(unittest.HomeserverTestCase):
Ensure that get_thread_id only searches up the tree for threads.
"""
# The thread itself and children of it return the thread.
thread_id = self.get_success(self._main_store.get_thread_id("B"))
thread_id = self.get_success(self._main_store.relations.get_thread_id("B"))
self.assertEqual("A", thread_id)
thread_id = self.get_success(self._main_store.get_thread_id("C"))
thread_id = self.get_success(self._main_store.relations.get_thread_id("C"))
self.assertEqual("A", thread_id)
# But the root and events related to the root do not.
thread_id = self.get_success(self._main_store.get_thread_id("A"))
thread_id = self.get_success(self._main_store.relations.get_thread_id("A"))
self.assertEqual(MAIN_TIMELINE, thread_id)
thread_id = self.get_success(self._main_store.get_thread_id("D"))
thread_id = self.get_success(self._main_store.relations.get_thread_id("D"))
self.assertEqual(MAIN_TIMELINE, thread_id)
thread_id = self.get_success(self._main_store.get_thread_id("E"))
thread_id = self.get_success(self._main_store.relations.get_thread_id("E"))
self.assertEqual(MAIN_TIMELINE, thread_id)
# Events which are not related to a thread at all should return the
# main timeline.
thread_id = self.get_success(self._main_store.get_thread_id("F"))
thread_id = self.get_success(self._main_store.relations.get_thread_id("F"))
self.assertEqual(MAIN_TIMELINE, thread_id)
thread_id = self.get_success(self._main_store.get_thread_id("G"))
thread_id = self.get_success(self._main_store.relations.get_thread_id("G"))
self.assertEqual(MAIN_TIMELINE, thread_id)
def test_get_thread_id_for_receipts(self) -> None:
@@ -87,25 +87,35 @@ class RelationsStoreTestCase(unittest.HomeserverTestCase):
Ensure that get_thread_id_for_receipts searches up and down the tree for a thread.
"""
# All of the events are considered related to this thread.
thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("A"))
thread_id = self.get_success(
self._main_store.relations.get_thread_id_for_receipts("A")
)
self.assertEqual("A", thread_id)
thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("B"))
thread_id = self.get_success(
self._main_store.relations.get_thread_id_for_receipts("B")
)
self.assertEqual("A", thread_id)
thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("C"))
thread_id = self.get_success(
self._main_store.relations.get_thread_id_for_receipts("C")
)
self.assertEqual("A", thread_id)
thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("D"))
thread_id = self.get_success(
self._main_store.relations.get_thread_id_for_receipts("D")
)
self.assertEqual("A", thread_id)
thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("E"))
thread_id = self.get_success(
self._main_store.relations.get_thread_id_for_receipts("E")
)
self.assertEqual("A", thread_id)
# Events which are not related to a thread at all should return the
# main timeline.
thread_id = self.get_success(self._main_store.get_thread_id("F"))
thread_id = self.get_success(self._main_store.relations.get_thread_id("F"))
self.assertEqual(MAIN_TIMELINE, thread_id)
thread_id = self.get_success(self._main_store.get_thread_id("G"))
thread_id = self.get_success(self._main_store.relations.get_thread_id("G"))
self.assertEqual(MAIN_TIMELINE, thread_id)