diff --git a/changelog.d/18674.feature b/changelog.d/18674.feature new file mode 100644 index 0000000000..b1a1aa11f1 --- /dev/null +++ b/changelog.d/18674.feature @@ -0,0 +1 @@ +Add experimental and incomplete support for [MSC4306: Thread Subscriptions](https://github.com/matrix-org/matrix-spec-proposals/blob/rei/msc_thread_subscriptions/proposals/4306-thread-subscriptions.md). \ No newline at end of file diff --git a/docker/configure_workers_and_start.py b/docker/configure_workers_and_start.py index 7909b9d932..6212a94042 100755 --- a/docker/configure_workers_and_start.py +++ b/docker/configure_workers_and_start.py @@ -327,6 +327,15 @@ WORKERS_CONFIG: Dict[str, Dict[str, Any]] = { "shared_extra_conf": {}, "worker_extra_conf": "", }, + "thread_subscriptions": { + "app": "synapse.app.generic_worker", + "listener_resources": ["client", "replication"], + "endpoint_patterns": [ + "^/_matrix/client/unstable/io.element.msc4306/.*", + ], + "shared_extra_conf": {}, + "worker_extra_conf": "", + }, } # Templates for sections that may be inserted multiple times in config files @@ -427,6 +436,7 @@ def add_worker_roles_to_shared_config( "to_device", "typing", "push_rules", + "thread_subscriptions", } # Worker-type specific sharding config. Now a single worker can fulfill multiple diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py index 6c3e380355..9a0b459e65 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py @@ -136,6 +136,7 @@ BOOLEAN_COLUMNS = { "has_known_state", "is_encrypted", ], + "thread_subscriptions": ["subscribed", "automatic"], "users": ["shadow_banned", "approved", "locked", "suspended"], "un_partial_stated_event_stream": ["rejection_status_changed"], "users_who_share_rooms": ["share_private"], diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index d0924c413b..4f5bea6bd6 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -104,6 +104,9 @@ from synapse.storage.databases.main.stats import StatsStore from synapse.storage.databases.main.stream import StreamWorkerStore from synapse.storage.databases.main.tags import TagsWorkerStore from synapse.storage.databases.main.task_scheduler import TaskSchedulerWorkerStore +from synapse.storage.databases.main.thread_subscriptions import ( + ThreadSubscriptionsWorkerStore, +) from synapse.storage.databases.main.transactions import TransactionWorkerStore from synapse.storage.databases.main.ui_auth import UIAuthWorkerStore from synapse.storage.databases.main.user_directory import UserDirectoryStore @@ -132,6 +135,7 @@ class GenericWorkerStore( KeyStore, RoomWorkerStore, DirectoryWorkerStore, + ThreadSubscriptionsWorkerStore, PushRulesWorkerStore, ApplicationServiceTransactionWorkerStore, ApplicationServiceWorkerStore, diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index b14bc97ae7..1b7474034f 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -581,3 +581,7 @@ class ExperimentalConfig(Config): # MSC4155: Invite filtering self.msc4155_enabled: bool = experimental.get("msc4155_enabled", False) + + # MSC4306: Thread Subscriptions + # (and MSC4308: sliding sync extension for thread subscriptions) + self.msc4306_enabled: bool = experimental.get("msc4306_enabled", False) diff --git a/synapse/config/workers.py b/synapse/config/workers.py index c0c8a13861..825ba78482 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -174,6 +174,10 @@ class 
WriterLocations: default=[MAIN_PROCESS_INSTANCE_NAME], converter=_instance_to_list_converter, ) + thread_subscriptions: List[str] = attr.ib( + default=["master"], + converter=_instance_to_list_converter, + ) @attr.s(auto_attribs=True) diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index 8d4d84bed1..305363892f 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -187,6 +187,9 @@ class DeactivateAccountHandler: # Remove account data (including ignored users and push rules). await self.store.purge_account_data_for_user(user_id) + # Remove thread subscriptions for the user + await self.store.purge_thread_subscription_settings_for_user(user_id) + # Delete any server-side backup keys await self.store.bulk_delete_backup_keys_and_versions_for_user(user_id) diff --git a/synapse/handlers/thread_subscriptions.py b/synapse/handlers/thread_subscriptions.py new file mode 100644 index 0000000000..79e4d6040d --- /dev/null +++ b/synapse/handlers/thread_subscriptions.py @@ -0,0 +1,126 @@ +import logging +from typing import TYPE_CHECKING, Optional + +from synapse.api.errors import AuthError, NotFoundError +from synapse.storage.databases.main.thread_subscriptions import ThreadSubscription +from synapse.types import UserID + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class ThreadSubscriptionsHandler: + def __init__(self, hs: "HomeServer"): + self.store = hs.get_datastores().main + self.event_handler = hs.get_event_handler() + self.auth = hs.get_auth() + + async def get_thread_subscription_settings( + self, + user_id: UserID, + room_id: str, + thread_root_event_id: str, + ) -> Optional[ThreadSubscription]: + """Get thread subscription settings for a specific thread and user. + Checks that the thread root is both a real event and also that it is visible + to the user. + + Args: + user_id: The ID of the user + thread_root_event_id: The event ID of the thread root + + Returns: + A `ThreadSubscription` containing the active subscription settings or None if not set + """ + # First check that the user can access the thread root event + # and that it exists + try: + event = await self.event_handler.get_event( + user_id, room_id, thread_root_event_id + ) + if event is None: + raise NotFoundError("No such thread root") + except AuthError: + raise NotFoundError("No such thread root") + + return await self.store.get_subscription_for_thread( + user_id.to_string(), event.room_id, thread_root_event_id + ) + + async def subscribe_user_to_thread( + self, + user_id: UserID, + room_id: str, + thread_root_event_id: str, + *, + automatic: bool, + ) -> Optional[int]: + """Sets or updates a user's subscription settings for a specific thread root. + + Args: + requester_user_id: The ID of the user whose settings are being updated. + thread_root_event_id: The event ID of the thread root. + automatic: whether the user was subscribed by an automatic decision by + their client. + + Returns: + The stream ID for this update, if the update isn't no-opped. + + Raises: + NotFoundError if the user cannot access the thread root event, or it isn't + known to this homeserver. 
+ """ + # First check that the user can access the thread root event + # and that it exists + try: + event = await self.event_handler.get_event( + user_id, room_id, thread_root_event_id + ) + if event is None: + raise NotFoundError("No such thread root") + except AuthError: + logger.info("rejecting thread subscriptions change (thread not accessible)") + raise NotFoundError("No such thread root") + + return await self.store.subscribe_user_to_thread( + user_id.to_string(), + event.room_id, + thread_root_event_id, + automatic=automatic, + ) + + async def unsubscribe_user_from_thread( + self, user_id: UserID, room_id: str, thread_root_event_id: str + ) -> Optional[int]: + """Clears a user's subscription settings for a specific thread root. + + Args: + requester_user_id: The ID of the user whose settings are being updated. + thread_root_event_id: The event ID of the thread root. + + Returns: + The stream ID for this update, if the update isn't no-opped. + + Raises: + NotFoundError if the user cannot access the thread root event, or it isn't + known to this homeserver. + """ + # First check that the user can access the thread root event + # and that it exists + try: + event = await self.event_handler.get_event( + user_id, room_id, thread_root_event_id + ) + if event is None: + raise NotFoundError("No such thread root") + except AuthError: + logger.info("rejecting thread subscriptions change (thread not accessible)") + raise NotFoundError("No such thread root") + + return await self.store.unsubscribe_user_from_thread( + user_id.to_string(), + event.room_id, + thread_root_event_id, + ) diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index e434bed3e5..3611c678c2 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -72,7 +72,10 @@ from synapse.replication.tcp.streams import ( ToDeviceStream, TypingStream, ) -from synapse.replication.tcp.streams._base import DeviceListsStream +from synapse.replication.tcp.streams._base import ( + DeviceListsStream, + ThreadSubscriptionsStream, +) if TYPE_CHECKING: from synapse.server import HomeServer @@ -186,6 +189,15 @@ class ReplicationCommandHandler: continue + if isinstance(stream, ThreadSubscriptionsStream): + if ( + hs.get_instance_name() + in hs.config.worker.writers.thread_subscriptions + ): + self._streams_to_replicate.append(stream) + + continue + if isinstance(stream, DeviceListsStream): if hs.get_instance_name() in hs.config.worker.writers.device_lists: self._streams_to_replicate.append(stream) diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py index 677dcb7b40..25c15e5d48 100644 --- a/synapse/replication/tcp/streams/__init__.py +++ b/synapse/replication/tcp/streams/__init__.py @@ -41,6 +41,7 @@ from synapse.replication.tcp.streams._base import ( PushRulesStream, ReceiptsStream, Stream, + ThreadSubscriptionsStream, ToDeviceStream, TypingStream, ) @@ -67,6 +68,7 @@ STREAMS_MAP = { ToDeviceStream, FederationStream, AccountDataStream, + ThreadSubscriptionsStream, UnPartialStatedRoomStream, UnPartialStatedEventStream, ) @@ -86,6 +88,7 @@ __all__ = [ "DeviceListsStream", "ToDeviceStream", "AccountDataStream", + "ThreadSubscriptionsStream", "UnPartialStatedRoomStream", "UnPartialStatedEventStream", ] diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index ebf5964d29..3ef86486e6 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -723,3 
+723,47 @@ class AccountDataStream(_StreamFromIdGen): heapq.merge(room_rows, global_rows, tag_rows, key=lambda row: row[0]) ) return updates, to_token, limited + + +class ThreadSubscriptionsStream(_StreamFromIdGen): + """A thread subscription was changed.""" + + @attr.s(slots=True, auto_attribs=True) + class ThreadSubscriptionsStreamRow: + """Stream to inform workers about changes to thread subscriptions.""" + + user_id: str + room_id: str + event_id: str # The event ID of the thread root + + NAME = "thread_subscriptions" + ROW_TYPE = ThreadSubscriptionsStreamRow + + def __init__(self, hs: Any): + self.store = hs.get_datastores().main + super().__init__( + hs.get_instance_name(), + self._update_function, + self.store._thread_subscriptions_id_gen, + ) + + async def _update_function( + self, instance_name: str, from_token: int, to_token: int, limit: int + ) -> StreamUpdateResult: + updates = await self.store.get_updated_thread_subscriptions( + from_token, to_token, limit + ) + rows = [ + ( + stream_id, + # These are the args to `ThreadSubscriptionsStreamRow` + (user_id, room_id, event_id), + ) + for stream_id, user_id, room_id, event_id in updates + ] + + logger.error("TS %d->%d %r", from_token, to_token, rows) + if not rows: + return [], to_token, False + + return rows, rows[-1][0], len(updates) == limit diff --git a/synapse/rest/client/thread_subscriptions.py b/synapse/rest/client/thread_subscriptions.py new file mode 100644 index 0000000000..5307132ec3 --- /dev/null +++ b/synapse/rest/client/thread_subscriptions.py @@ -0,0 +1,98 @@ +from http import HTTPStatus +from typing import Tuple + +from synapse._pydantic_compat import StrictBool +from synapse.api.errors import Codes, NotFoundError, SynapseError +from synapse.http.server import HttpServer +from synapse.http.servlet import ( + RestServlet, + parse_and_validate_json_object_from_request, +) +from synapse.http.site import SynapseRequest +from synapse.rest.client._base import client_patterns +from synapse.server import HomeServer +from synapse.types import JsonDict, RoomID +from synapse.types.rest import RequestBodyModel + + +class ThreadSubscriptionsRestServlet(RestServlet): + PATTERNS = client_patterns( + "/io.element.msc4306/rooms/(?P[^/]*)/thread/(?P[^/]*)/subscription$", + unstable=True, + releases=(), + ) + CATEGORY = "Thread Subscriptions requests (unstable)" + + def __init__(self, hs: "HomeServer"): + self.auth = hs.get_auth() + self.is_mine = hs.is_mine + self.store = hs.get_datastores().main + self.handler = hs.get_thread_subscriptions_handler() + + class PutBody(RequestBodyModel): + automatic: StrictBool + + async def on_GET( + self, request: SynapseRequest, room_id: str, thread_root_id: str + ) -> Tuple[int, JsonDict]: + RoomID.from_string(room_id) + if not thread_root_id.startswith("$"): + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Invalid event ID", errcode=Codes.INVALID_PARAM + ) + requester = await self.auth.get_user_by_req(request) + + subscription = await self.handler.get_thread_subscription_settings( + requester.user, + room_id, + thread_root_id, + ) + + if subscription is None: + raise NotFoundError("Not subscribed.") + + return HTTPStatus.OK, {"automatic": subscription.automatic} + + async def on_PUT( + self, request: SynapseRequest, room_id: str, thread_root_id: str + ) -> Tuple[int, JsonDict]: + RoomID.from_string(room_id) + if not thread_root_id.startswith("$"): + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Invalid event ID", errcode=Codes.INVALID_PARAM + ) + requester = await 
self.auth.get_user_by_req(request) + + body = parse_and_validate_json_object_from_request(request, self.PutBody) + + await self.handler.subscribe_user_to_thread( + requester.user, + room_id, + thread_root_id, + automatic=body.automatic, + ) + + return HTTPStatus.OK, {} + + async def on_DELETE( + self, request: SynapseRequest, room_id: str, thread_root_id: str + ) -> Tuple[int, JsonDict]: + RoomID.from_string(room_id) + if not thread_root_id.startswith("$"): + raise SynapseError( + HTTPStatus.BAD_REQUEST, "Invalid event ID", errcode=Codes.INVALID_PARAM + ) + requester = await self.auth.get_user_by_req(request) + + await self.handler.unsubscribe_user_from_thread( + requester.user, + room_id, + thread_root_id, + ) + + return HTTPStatus.OK, {} + + +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: + if hs.config.experimental.msc4306_enabled: + ThreadSubscriptionsRestServlet(hs).register(http_server) diff --git a/synapse/server.py b/synapse/server.py index 5270f7792d..231bd14907 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -117,6 +117,7 @@ from synapse.handlers.sliding_sync import SlidingSyncHandler from synapse.handlers.sso import SsoHandler from synapse.handlers.stats import StatsHandler from synapse.handlers.sync import SyncHandler +from synapse.handlers.thread_subscriptions import ThreadSubscriptionsHandler from synapse.handlers.typing import FollowerTypingHandler, TypingWriterHandler from synapse.handlers.user_directory import UserDirectoryHandler from synapse.handlers.worker_lock import WorkerLocksHandler @@ -789,6 +790,10 @@ class HomeServer(metaclass=abc.ABCMeta): def get_timestamp_lookup_handler(self) -> TimestampLookupHandler: return TimestampLookupHandler(self) + @cache_in_self + def get_thread_subscriptions_handler(self) -> ThreadSubscriptionsHandler: + return ThreadSubscriptionsHandler(self) + @cache_in_self def get_registration_handler(self) -> RegistrationHandler: return RegistrationHandler(self) diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 86431f6e40..de55c452ae 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -19,7 +19,6 @@ # [This file includes modifications made by New Vector Limited] # # - import logging from typing import TYPE_CHECKING, List, Optional, Tuple, Union, cast @@ -35,6 +34,9 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.sliding_sync import SlidingSyncStore from synapse.storage.databases.main.stats import UserSortOrder +from synapse.storage.databases.main.thread_subscriptions import ( + ThreadSubscriptionsWorkerStore, +) from synapse.storage.engines import BaseDatabaseEngine from synapse.storage.types import Cursor from synapse.types import get_domain_from_id @@ -141,6 +143,7 @@ class DataStore( SearchStore, TagsStore, AccountDataStore, + ThreadSubscriptionsWorkerStore, PushRulesWorkerStore, StreamWorkerStore, OpenIdStore, diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index b7fbfdc0ca..741146417f 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -2986,6 +2986,10 @@ class PersistEventsStore: # Upsert into the threads table, but only overwrite the value if the # new event is of a later topological order OR if the topological # ordering is equal, but the stream ordering is later. 
+ # (Note by definition that the stream ordering will always be later + # unless this is a backfilled event [= negative stream ordering] + # because we are only persisting this event now and stream_orderings + # are strictly monotonically increasing) sql = """ INSERT INTO threads (room_id, thread_id, latest_event_id, topological_ordering, stream_ordering) VALUES (?, ?, ?, ?, ?) diff --git a/synapse/storage/databases/main/thread_subscriptions.py b/synapse/storage/databases/main/thread_subscriptions.py new file mode 100644 index 0000000000..e04e692e6a --- /dev/null +++ b/synapse/storage/databases/main/thread_subscriptions.py @@ -0,0 +1,382 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2025 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# . +import logging +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Iterable, + List, + Optional, + Tuple, + Union, + cast, +) + +import attr + +from synapse.replication.tcp.streams._base import ThreadSubscriptionsStream +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) +from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore +from synapse.storage.util.id_generators import MultiWriterIdGenerator +from synapse.util.caches.descriptors import cached + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class ThreadSubscription: + automatic: bool + """ + whether the subscription was made automatically (as opposed to by manual + action from the user) + """ + + +class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): + super().__init__(database, db_conn, hs) + + self._can_write_to_thread_subscriptions = ( + self._instance_name in hs.config.worker.writers.thread_subscriptions + ) + + self._thread_subscriptions_id_gen: MultiWriterIdGenerator = ( + MultiWriterIdGenerator( + db_conn=db_conn, + db=database, + notifier=hs.get_replication_notifier(), + stream_name="thread_subscriptions", + instance_name=self._instance_name, + tables=[ + ("thread_subscriptions", "instance_name", "stream_id"), + ], + sequence_name="thread_subscriptions_sequence", + writers=hs.config.worker.writers.thread_subscriptions, + ) + ) + + def process_replication_rows( + self, + stream_name: str, + instance_name: str, + token: int, + rows: Iterable[Any], + ) -> None: + if stream_name == ThreadSubscriptionsStream.NAME: + for row in rows: + self.get_subscription_for_thread.invalidate( + (row.user_id, row.room_id, row.event_id) + ) + + super().process_replication_rows(stream_name, instance_name, token, rows) + + def process_replication_position( + self, stream_name: str, instance_name: str, token: int + ) -> None: + if stream_name == ThreadSubscriptionsStream.NAME: + self._thread_subscriptions_id_gen.advance(instance_name, token) + super().process_replication_position(stream_name, instance_name, token) + + async def subscribe_user_to_thread( + self, user_id: str, room_id: str, thread_root_event_id: str, *, automatic: bool + ) -> 
Optional[int]: + """Updates a user's subscription settings for a specific thread root. + + If no change would be made to the subscription, does not produce any database change. + + Args: + user_id: The ID of the user whose settings are being updated. + room_id: The ID of the room the thread root belongs to. + thread_root_event_id: The event ID of the thread root. + automatic: Whether the subscription was performed automatically by the user's client. + Only `False` will overwrite an existing value of automatic for a subscription row. + + Returns: + The stream ID for this update, if the update isn't no-opped. + """ + assert self._can_write_to_thread_subscriptions + + def _subscribe_user_to_thread_txn(txn: LoggingTransaction) -> Optional[int]: + already_automatic = self.db_pool.simple_select_one_onecol_txn( + txn, + table="thread_subscriptions", + keyvalues={ + "user_id": user_id, + "event_id": thread_root_event_id, + "room_id": room_id, + "subscribed": True, + }, + retcol="automatic", + allow_none=True, + ) + + if already_automatic is None: + already_subscribed = False + already_automatic = True + else: + already_subscribed = True + # convert int (SQLite bool) to Python bool + already_automatic = bool(already_automatic) + + if already_subscribed and already_automatic == automatic: + # there is nothing we need to do here + return None + + stream_id = self._thread_subscriptions_id_gen.get_next_txn(txn) + + values: Dict[str, Optional[Union[bool, int, str]]] = { + "subscribed": True, + "stream_id": stream_id, + "instance_name": self._instance_name, + "automatic": already_automatic and automatic, + } + + self.db_pool.simple_upsert_txn( + txn, + table="thread_subscriptions", + keyvalues={ + "user_id": user_id, + "event_id": thread_root_event_id, + "room_id": room_id, + }, + values=values, + ) + + txn.call_after( + self.get_subscription_for_thread.invalidate, + (user_id, room_id, thread_root_event_id), + ) + + return stream_id + + return await self.db_pool.runInteraction( + "subscribe_user_to_thread", _subscribe_user_to_thread_txn + ) + + async def unsubscribe_user_from_thread( + self, user_id: str, room_id: str, thread_root_event_id: str + ) -> Optional[int]: + """Unsubscribes a user from a thread. + + If no change would be made to the subscription, does not produce any database change. + + Args: + user_id: The ID of the user whose settings are being updated. + room_id: The ID of the room the thread root belongs to. + thread_root_event_id: The event ID of the thread root. + + Returns: + The stream ID for this update, if the update isn't no-opped. 
+ """ + + assert self._can_write_to_thread_subscriptions + + def _unsubscribe_user_from_thread_txn(txn: LoggingTransaction) -> Optional[int]: + already_subscribed = self.db_pool.simple_select_one_onecol_txn( + txn, + table="thread_subscriptions", + keyvalues={ + "user_id": user_id, + "event_id": thread_root_event_id, + "room_id": room_id, + }, + retcol="subscribed", + allow_none=True, + ) + + if already_subscribed is None or already_subscribed is False: + # there is nothing we need to do here + return None + + stream_id = self._thread_subscriptions_id_gen.get_next_txn(txn) + + self.db_pool.simple_update_txn( + txn, + table="thread_subscriptions", + keyvalues={ + "user_id": user_id, + "event_id": thread_root_event_id, + "room_id": room_id, + "subscribed": True, + }, + updatevalues={ + "subscribed": False, + "stream_id": stream_id, + "instance_name": self._instance_name, + }, + ) + + txn.call_after( + self.get_subscription_for_thread.invalidate, + (user_id, room_id, thread_root_event_id), + ) + + return stream_id + + return await self.db_pool.runInteraction( + "unsubscribe_user_from_thread", _unsubscribe_user_from_thread_txn + ) + + async def purge_thread_subscription_settings_for_user(self, user_id: str) -> None: + """ + Purge all subscriptions for the user. + The fact that subscriptions have been purged will not be streamed; + all stream rows for the user will in fact be removed. + This is intended only for dealing with user deactivation. + """ + + def _purge_thread_subscription_settings_for_user_txn( + txn: LoggingTransaction, + ) -> None: + self.db_pool.simple_delete_txn( + txn, + table="thread_subscriptions", + keyvalues={"user_id": user_id}, + ) + self._invalidate_cache_and_stream( + txn, self.get_subscription_for_thread, (user_id,) + ) + + await self.db_pool.runInteraction( + desc="purge_thread_subscription_settings_for_user", + func=_purge_thread_subscription_settings_for_user_txn, + ) + + @cached(tree=True) + async def get_subscription_for_thread( + self, user_id: str, room_id: str, thread_root_event_id: str + ) -> Optional[ThreadSubscription]: + """Get the thread subscription for a specific thread and user. + + Args: + user_id: The ID of the user + room_id: The ID of the room + thread_root_event_id: The event ID of the thread root + + Returns: + A `ThreadSubscription` dataclass if there is a subscription, + or `None` if there is no subscription. + + If there is a row in the table but `subscribed` is `False`, + behaves the same as if there was no row at all and returns `None`. + """ + row = await self.db_pool.simple_select_one( + table="thread_subscriptions", + keyvalues={ + "user_id": user_id, + "room_id": room_id, + "event_id": thread_root_event_id, + "subscribed": True, + }, + retcols=("automatic",), + allow_none=True, + desc="get_subscription_for_thread", + ) + + if row is None: + return None + + (automatic_rawbool,) = row + + # convert SQLite integer booleans into real booleans + automatic = bool(automatic_rawbool) + + return ThreadSubscription(automatic=automatic) + + def get_max_thread_subscriptions_stream_id(self) -> int: + """Get the current maximum stream_id for thread subscriptions. + + Returns: + The maximum stream_id + """ + return self._thread_subscriptions_id_gen.get_current_token() + + async def get_updated_thread_subscriptions( + self, from_id: int, to_id: int, limit: int + ) -> List[Tuple[int, str, str, str]]: + """Get updates to thread subscriptions between two stream IDs. 
+ + Args: + from_id: The starting stream ID (exclusive) + to_id: The ending stream ID (inclusive) + limit: The maximum number of rows to return + + Returns: + list of (stream_id, user_id, room_id, thread_root_id) tuples + """ + + def get_updated_thread_subscriptions_txn( + txn: LoggingTransaction, + ) -> List[Tuple[int, str, str, str]]: + sql = """ + SELECT stream_id, user_id, room_id, event_id + FROM thread_subscriptions + WHERE ? < stream_id AND stream_id <= ? + ORDER BY stream_id ASC + LIMIT ? + """ + + txn.execute(sql, (from_id, to_id, limit)) + return cast(List[Tuple[int, str, str, str]], txn.fetchall()) + + return await self.db_pool.runInteraction( + "get_updated_thread_subscriptions", + get_updated_thread_subscriptions_txn, + ) + + async def get_updated_thread_subscriptions_for_user( + self, user_id: str, from_id: int, to_id: int, limit: int + ) -> List[Tuple[int, str, str]]: + """Get updates to thread subscriptions for a specific user. + + Args: + user_id: The ID of the user + from_id: The starting stream ID (exclusive) + to_id: The ending stream ID (inclusive) + limit: The maximum number of rows to return + + Returns: + A list of (stream_id, room_id, thread_root_event_id) tuples. + """ + + def get_updated_thread_subscriptions_for_user_txn( + txn: LoggingTransaction, + ) -> List[Tuple[int, str, str]]: + sql = """ + SELECT stream_id, room_id, event_id + FROM thread_subscriptions + WHERE user_id = ? AND ? < stream_id AND stream_id <= ? + ORDER BY stream_id ASC + LIMIT ? + """ + + txn.execute(sql, (user_id, from_id, to_id, limit)) + return [(row[0], row[1], row[2]) for row in txn] + + return await self.db_pool.runInteraction( + "get_updated_thread_subscriptions_for_user", + get_updated_thread_subscriptions_for_user_txn, + ) diff --git a/synapse/storage/schema/main/delta/92/04_thread_subscriptions.sql b/synapse/storage/schema/main/delta/92/04_thread_subscriptions.sql new file mode 100644 index 0000000000..d19dd7a46d --- /dev/null +++ b/synapse/storage/schema/main/delta/92/04_thread_subscriptions.sql @@ -0,0 +1,59 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3. +-- +-- Copyright (C) 2025 New Vector, Ltd +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- . + +-- Introduce a table for tracking users' subscriptions to threads. +CREATE TABLE thread_subscriptions ( + stream_id INTEGER NOT NULL PRIMARY KEY, + instance_name TEXT NOT NULL, + + room_id TEXT NOT NULL, + event_id TEXT NOT NULL, + user_id TEXT NOT NULL, + + subscribed BOOLEAN NOT NULL, + automatic BOOLEAN NOT NULL, + + CONSTRAINT thread_subscriptions_fk_users + FOREIGN KEY (user_id) + REFERENCES users(name), + + CONSTRAINT thread_subscriptions_fk_rooms + FOREIGN KEY (room_id) + -- When we delete a room, we should already have deleted all the events in that room + -- and so there shouldn't be any subscriptions left in that room. + -- So the `ON DELETE CASCADE` should be optional, but included anyway for good measure. + REFERENCES rooms(room_id) ON DELETE CASCADE, + + CONSTRAINT thread_subscriptions_fk_events + FOREIGN KEY (event_id) + REFERENCES events(event_id) ON DELETE CASCADE, + + -- This order provides a useful index for: + -- 1. foreign key constraint on (room_id) + -- 2. 
foreign key constraint on (room_id, event_id) + -- 3. finding the user's settings for a specific thread (as well as enforcing uniqueness) + UNIQUE (room_id, event_id, user_id) +); + +-- this provides a useful index for finding a user's own rules, +-- potentially scoped to a single room +CREATE INDEX thread_subscriptions_user_room ON thread_subscriptions (user_id, room_id); + +-- this provides a useful way for clients to efficiently find new changes to +-- their subscriptions. +-- (This is necessary to sync subscriptions between multiple devices.) +CREATE INDEX thread_subscriptions_by_user ON thread_subscriptions (user_id, stream_id); + +-- this provides a useful index for deleting the subscriptions when the underlying +-- events are removed. This also covers the foreign key constraint on `events`. +CREATE INDEX thread_subscriptions_by_event ON thread_subscriptions (event_id); diff --git a/synapse/storage/schema/main/delta/92/04_thread_subscriptions_seq.sql.postgres b/synapse/storage/schema/main/delta/92/04_thread_subscriptions_seq.sql.postgres new file mode 100644 index 0000000000..8d53691747 --- /dev/null +++ b/synapse/storage/schema/main/delta/92/04_thread_subscriptions_seq.sql.postgres @@ -0,0 +1,19 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3. +-- +-- Copyright (C) 2025 New Vector, Ltd +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- . + +CREATE SEQUENCE thread_subscriptions_sequence + -- Synapse streams start at 2, because the default position is 1 + -- so any item inserted at position 1 is ignored. + -- This is also what existing streams do, except they use `setval(..., 1)` + -- which is semantically the same except less obvious. + START WITH 2; diff --git a/synapse/storage/schema/main/delta/92/05_thread_subscriptions_comments.sql.postgres b/synapse/storage/schema/main/delta/92/05_thread_subscriptions_comments.sql.postgres new file mode 100644 index 0000000000..b0729894c0 --- /dev/null +++ b/synapse/storage/schema/main/delta/92/05_thread_subscriptions_comments.sql.postgres @@ -0,0 +1,18 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3. +-- +-- Copyright (C) 2025 New Vector, Ltd +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- . + +COMMENT ON TABLE thread_subscriptions IS 'Tracks local users that subscribe to threads'; + +COMMENT ON COLUMN thread_subscriptions.subscribed IS 'Whether the user is subscribed to the thread or not. 
We track unsubscribed threads because we need to stream the subscription change to the client.'; + +COMMENT ON COLUMN thread_subscriptions.automatic IS 'True if the user was subscribed to the thread automatically by their client, or false if the client manually requested the subscription.'; diff --git a/synapse/storage/schema/main/delta/92/06_threads_last_sent_stream_ordering_comments.sql.postgres b/synapse/storage/schema/main/delta/92/06_threads_last_sent_stream_ordering_comments.sql.postgres new file mode 100644 index 0000000000..3fc7e4b11e --- /dev/null +++ b/synapse/storage/schema/main/delta/92/06_threads_last_sent_stream_ordering_comments.sql.postgres @@ -0,0 +1,24 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3. +-- +-- Copyright (C) 2025 New Vector, Ltd +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- . + +COMMENT ON COLUMN threads.latest_event_id IS + 'the ID of the event that is latest, ordered by (topological_ordering, stream_ordering)'; + +COMMENT ON COLUMN threads.topological_ordering IS + $$the topological ordering of the thread''s LATEST event. +Used as the primary way of ordering threads by recency in a room.$$; + +COMMENT ON COLUMN threads.stream_ordering IS + $$the stream ordering of the thread's LATEST event. +Used as a tie-breaker for ordering threads by recency in a room, when the topological order is a tie. +Also used for recency ordering in sliding sync.$$; diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index fe6b6579e6..026a0517d2 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -184,6 +184,12 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator): Note: Only works with Postgres. + Warning: Streams using this generator start at ID 2, because ID 1 is always assumed + to have been 'seen as persisted'. + Unclear if this extant behaviour is desirable for some reason. + When creating a new sequence for a new stream, + it will be necessary to use `START WITH 2`. + Args: db_conn db @@ -269,6 +275,9 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator): self._known_persisted_positions: List[int] = [] # The maximum stream ID that we have seen been allocated across any writer. + # Since this defaults to 1, this means that ID 1 is assumed to have already + # been 'seen'. In other words, multi-writer streams start at 2. + # Unclear if this is desirable behaviour. self._max_seen_allocated_stream_id = 1 # The maximum position of the local instance. This can be higher than diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py index d09fd30e81..3b516fce3d 100644 --- a/synapse/types/__init__.py +++ b/synapse/types/__init__.py @@ -362,7 +362,8 @@ class RoomID(DomainSpecificString): @attr.s(slots=True, frozen=True, repr=False) class EventID(DomainSpecificString): - """Structure representing an event id.""" + """Structure representing an event ID which is namespaced to a homeserver. 
+ Room versions 3 and above are not supported by this grammar.""" SIGIL = "$" diff --git a/tests/replication/tcp/streams/test_thread_subscriptions.py b/tests/replication/tcp/streams/test_thread_subscriptions.py new file mode 100644 index 0000000000..30c3415ad4 --- /dev/null +++ b/tests/replication/tcp/streams/test_thread_subscriptions.py @@ -0,0 +1,157 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2025 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# . +# + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.replication.tcp.streams._base import ( + _STREAM_UPDATE_TARGET_ROW_COUNT, + ThreadSubscriptionsStream, +) +from synapse.server import HomeServer +from synapse.storage.database import LoggingTransaction +from synapse.util import Clock + +from tests.replication._base import BaseStreamTestCase + + +class ThreadSubscriptionsStreamTestCase(BaseStreamTestCase): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + super().prepare(reactor, clock, hs) + + # Postgres + def f(txn: LoggingTransaction) -> None: + txn.execute( + """ + ALTER TABLE thread_subscriptions + DROP CONSTRAINT thread_subscriptions_fk_users, + DROP CONSTRAINT thread_subscriptions_fk_rooms, + DROP CONSTRAINT thread_subscriptions_fk_events; + """, + ) + + self.get_success( + self.hs.get_datastores().main.db_pool.runInteraction( + "disable_foreign_keys", f + ) + ) + + def test_thread_subscription_updates(self) -> None: + """Test replication with thread subscription updates""" + store = self.hs.get_datastores().main + + # Create thread subscription updates + updates = [] + room_id = "!test_room:example.com" + + # Generate several thread subscription updates + for i in range(_STREAM_UPDATE_TARGET_ROW_COUNT + 5): + thread_root_id = f"$thread_{i}:example.com" + self.get_success( + store.subscribe_user_to_thread( + "@test_user:example.org", + room_id, + thread_root_id, + automatic=True, + ) + ) + updates.append(thread_root_id) + + # Also add one in a different room + other_room_id = "!other_room:example.com" + other_thread_root_id = "$other_thread:example.com" + self.get_success( + store.subscribe_user_to_thread( + "@test_user:example.org", + other_room_id, + other_thread_root_id, + automatic=False, + ) + ) + + # Not yet connected: no rows should yet have been received + self.assertEqual([], self.test_handler.received_rdata_rows) + + # Now reconnect to pull the updates + self.reconnect() + self.replicate() + + # We should have received all the expected rows in the right order + # Filter the updates to only include thread subscription changes + received_rows = [ + upd + for upd in self.test_handler.received_rdata_rows + if upd[0] == ThreadSubscriptionsStream.NAME + ] + + # Verify all the thread subscription updates + for thread_id in updates: + (stream_name, token, row) = received_rows.pop(0) + self.assertEqual(stream_name, ThreadSubscriptionsStream.NAME) + self.assertIsInstance(row, ThreadSubscriptionsStream.ROW_TYPE) + self.assertEqual(row.user_id, "@test_user:example.org") + self.assertEqual(row.room_id, room_id) + self.assertEqual(row.event_id, thread_id) + + # Verify the last update in the different room + (stream_name, token, 
row) = received_rows.pop(0) + self.assertEqual(stream_name, ThreadSubscriptionsStream.NAME) + self.assertIsInstance(row, ThreadSubscriptionsStream.ROW_TYPE) + self.assertEqual(row.user_id, "@test_user:example.org") + self.assertEqual(row.room_id, other_room_id) + self.assertEqual(row.event_id, other_thread_root_id) + + self.assertEqual([], received_rows) + + def test_multiple_users_thread_subscription_updates(self) -> None: + """Test replication with thread subscription updates for multiple users""" + store = self.hs.get_datastores().main + room_id = "!test_room:example.com" + thread_root_id = "$thread_root:example.com" + + # Create updates for multiple users + users = ["@user1:example.com", "@user2:example.com", "@user3:example.com"] + for user_id in users: + self.get_success( + store.subscribe_user_to_thread( + user_id, room_id, thread_root_id, automatic=True + ) + ) + + # Check no rows have been received yet + self.replicate() + self.assertEqual([], self.test_handler.received_rdata_rows) + + # Not yet connected: no rows should yet have been received + self.reconnect() + self.replicate() + + # We should have received all the expected rows + # Filter the updates to only include thread subscription changes + received_rows = [ + upd + for upd in self.test_handler.received_rdata_rows + if upd[0] == ThreadSubscriptionsStream.NAME + ] + + # Should have one update per user + self.assertEqual(len(received_rows), len(users)) + + # Verify all updates + for i, user_id in enumerate(users): + (stream_name, token, row) = received_rows[i] + self.assertEqual(stream_name, ThreadSubscriptionsStream.NAME) + self.assertIsInstance(row, ThreadSubscriptionsStream.ROW_TYPE) + self.assertEqual(row.user_id, user_id) + self.assertEqual(row.room_id, room_id) + self.assertEqual(row.event_id, thread_root_id) diff --git a/tests/rest/client/test_thread_subscriptions.py b/tests/rest/client/test_thread_subscriptions.py new file mode 100644 index 0000000000..a5c38753cb --- /dev/null +++ b/tests/rest/client/test_thread_subscriptions.py @@ -0,0 +1,256 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2025 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# . 
+ +from http import HTTPStatus + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.rest import admin +from synapse.rest.client import login, profile, room, thread_subscriptions +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock + +from tests import unittest + +PREFIX = "/_matrix/client/unstable/io.element.msc4306/rooms" + + +class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets_for_client_rest_resource, + login.register_servlets, + profile.register_servlets, + room.register_servlets, + thread_subscriptions.register_servlets, + ] + + def default_config(self) -> JsonDict: + config = super().default_config() + config["experimental_features"] = {"msc4306_enabled": True} + return config + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.user_id = self.register_user("user", "password") + self.token = self.login("user", "password") + self.other_user_id = self.register_user("other_user", "password") + self.other_token = self.login("other_user", "password") + + # Create a room and send a message to use as a thread root + self.room_id = self.helper.create_room_as(self.user_id, tok=self.token) + self.helper.join(self.room_id, self.other_user_id, tok=self.other_token) + response = self.helper.send(self.room_id, body="Root message", tok=self.token) + self.root_event_id = response["event_id"] + + # Send a message in the thread + self.helper.send_event( + room_id=self.room_id, + type="m.room.message", + content={ + "body": "Thread message", + "msgtype": "m.text", + "m.relates_to": { + "rel_type": "m.thread", + "event_id": self.root_event_id, + }, + }, + tok=self.token, + ) + + def test_get_thread_subscription_unsubscribed(self) -> None: + """Test retrieving thread subscription when not subscribed.""" + channel = self.make_request( + "GET", + f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription", + access_token=self.token, + ) + self.assertEqual(channel.code, HTTPStatus.NOT_FOUND) + self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND") + + def test_get_thread_subscription_nonexistent_thread(self) -> None: + """Test retrieving subscription settings for a nonexistent thread.""" + channel = self.make_request( + "GET", + f"{PREFIX}/{self.room_id}/thread/$nonexistent:example.org/subscription", + access_token=self.token, + ) + self.assertEqual(channel.code, HTTPStatus.NOT_FOUND) + self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND") + + def test_get_thread_subscription_no_access(self) -> None: + """Test that a user can't get thread subscription for a thread they can't access.""" + self.register_user("no_access", "password") + no_access_token = self.login("no_access", "password") + + channel = self.make_request( + "GET", + f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription", + access_token=no_access_token, + ) + self.assertEqual(channel.code, HTTPStatus.NOT_FOUND) + self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND") + + def test_subscribe_manual_then_automatic(self) -> None: + """Test subscribing to a thread, first a manual subscription then an automatic subscription. 
+ The manual subscription wins over the automatic one.""" + channel = self.make_request( + "PUT", + f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription", + { + "automatic": False, + }, + access_token=self.token, + ) + self.assertEqual(channel.code, HTTPStatus.OK) + + # Assert the subscription was saved + channel = self.make_request( + "GET", + f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription", + access_token=self.token, + ) + self.assertEqual(channel.code, HTTPStatus.OK) + self.assertEqual(channel.json_body, {"automatic": False}) + + # Now also register an automatic subscription; it should not + # override the manual subscription + channel = self.make_request( + "PUT", + f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription", + {"automatic": True}, + access_token=self.token, + ) + self.assertEqual(channel.code, HTTPStatus.OK) + + # Assert the manual subscription was not overridden + channel = self.make_request( + "GET", + f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription", + access_token=self.token, + ) + self.assertEqual(channel.code, HTTPStatus.OK) + self.assertEqual(channel.json_body, {"automatic": False}) + + def test_subscribe_automatic_then_manual(self) -> None: + """Test subscribing to a thread, first an automatic subscription then a manual subscription. + The manual subscription wins over the automatic one.""" + channel = self.make_request( + "PUT", + f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription", + { + "automatic": True, + }, + access_token=self.token, + ) + self.assertEqual(channel.code, HTTPStatus.OK) + + # Assert the subscription was saved + channel = self.make_request( + "GET", + f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription", + access_token=self.token, + ) + self.assertEqual(channel.code, HTTPStatus.OK) + self.assertEqual(channel.json_body, {"automatic": True}) + + # Now also register a manual subscription + channel = self.make_request( + "PUT", + f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription", + {"automatic": False}, + access_token=self.token, + ) + self.assertEqual(channel.code, HTTPStatus.OK) + + # Assert the manual subscription was not overridden + channel = self.make_request( + "GET", + f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription", + access_token=self.token, + ) + self.assertEqual(channel.code, HTTPStatus.OK) + self.assertEqual(channel.json_body, {"automatic": False}) + + def test_unsubscribe(self) -> None: + """Test subscribing to a thread, then unsubscribing.""" + channel = self.make_request( + "PUT", + f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription", + { + "automatic": True, + }, + access_token=self.token, + ) + self.assertEqual(channel.code, HTTPStatus.OK) + + # Assert the subscription was saved + channel = self.make_request( + "GET", + f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription", + access_token=self.token, + ) + self.assertEqual(channel.code, HTTPStatus.OK) + self.assertEqual(channel.json_body, {"automatic": True}) + + # Now also register a manual subscription + channel = self.make_request( + "DELETE", + f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription", + access_token=self.token, + ) + self.assertEqual(channel.code, HTTPStatus.OK) + + # Assert the manual subscription was not overridden + channel = self.make_request( + "GET", + f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription", + access_token=self.token, + ) + self.assertEqual(channel.code, 
HTTPStatus.NOT_FOUND) + self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND") + + def test_set_thread_subscription_nonexistent_thread(self) -> None: + """Test setting subscription settings for a nonexistent thread.""" + channel = self.make_request( + "PUT", + f"{PREFIX}/{self.room_id}/thread/$nonexistent:example.org/subscription", + {"automatic": True}, + access_token=self.token, + ) + self.assertEqual(channel.code, HTTPStatus.NOT_FOUND) + self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND") + + def test_set_thread_subscription_no_access(self) -> None: + """Test that a user can't set thread subscription for a thread they can't access.""" + self.register_user("no_access2", "password") + no_access_token = self.login("no_access2", "password") + + channel = self.make_request( + "PUT", + f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription", + {"automatic": True}, + access_token=no_access_token, + ) + self.assertEqual(channel.code, HTTPStatus.NOT_FOUND) + self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND") + + def test_invalid_body(self) -> None: + """Test that sending invalid subscription settings is rejected.""" + channel = self.make_request( + "PUT", + f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription", + # non-boolean `automatic` + {"automatic": "true"}, + access_token=self.token, + ) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST) diff --git a/tests/storage/test_thread_subscriptions.py b/tests/storage/test_thread_subscriptions.py new file mode 100644 index 0000000000..dd0b804f1f --- /dev/null +++ b/tests/storage/test_thread_subscriptions.py @@ -0,0 +1,272 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2025 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# . 
+# + +from typing import Optional + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.server import HomeServer +from synapse.storage.database import LoggingTransaction +from synapse.storage.engines.sqlite import Sqlite3Engine +from synapse.util import Clock + +from tests import unittest + + +class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = self.hs.get_datastores().main + self.user_id = "@user:test" + self.room_id = "!room:test" + self.thread_root_id = "$thread_root:test" + self.other_thread_root_id = "$other_thread_root:test" + + # Disable foreign key checks for testing + # This allows us to insert test data without having to create actual events + db_pool = self.store.db_pool + if isinstance(db_pool.engine, Sqlite3Engine): + self.get_success( + db_pool.execute("disable_foreign_keys", "PRAGMA foreign_keys = OFF;") + ) + else: + # Postgres + def f(txn: LoggingTransaction) -> None: + txn.execute( + """ + ALTER TABLE thread_subscriptions + DROP CONSTRAINT thread_subscriptions_fk_users, + DROP CONSTRAINT thread_subscriptions_fk_rooms, + DROP CONSTRAINT thread_subscriptions_fk_events; + """, + ) + + self.get_success(db_pool.runInteraction("disable_foreign_keys", f)) + + # Create rooms and events in the db to satisfy foreign key constraints + self.get_success(db_pool.simple_insert("rooms", {"room_id": self.room_id})) + + self.get_success( + db_pool.simple_insert( + "events", + { + "event_id": self.thread_root_id, + "room_id": self.room_id, + "topological_ordering": 1, + "stream_ordering": 1, + "type": "m.room.message", + "depth": 1, + "processed": True, + "outlier": False, + }, + ) + ) + + self.get_success( + db_pool.simple_insert( + "events", + { + "event_id": self.other_thread_root_id, + "room_id": self.room_id, + "topological_ordering": 2, + "stream_ordering": 2, + "type": "m.room.message", + "depth": 2, + "processed": True, + "outlier": False, + }, + ) + ) + + # Create the user + self.get_success( + db_pool.simple_insert("users", {"name": self.user_id, "is_guest": 0}) + ) + + def _subscribe( + self, + thread_root_id: str, + *, + automatic: bool, + room_id: Optional[str] = None, + user_id: Optional[str] = None, + ) -> Optional[int]: + if user_id is None: + user_id = self.user_id + + if room_id is None: + room_id = self.room_id + + return self.get_success( + self.store.subscribe_user_to_thread( + user_id, + room_id, + thread_root_id, + automatic=automatic, + ) + ) + + def _unsubscribe( + self, + thread_root_id: str, + room_id: Optional[str] = None, + user_id: Optional[str] = None, + ) -> Optional[int]: + if user_id is None: + user_id = self.user_id + + if room_id is None: + room_id = self.room_id + + return self.get_success( + self.store.unsubscribe_user_from_thread( + user_id, + room_id, + thread_root_id, + ) + ) + + def test_set_and_get_thread_subscription(self) -> None: + """Test basic setting and getting of thread subscriptions.""" + # Initial state: no subscription + subscription = self.get_success( + self.store.get_subscription_for_thread( + self.user_id, self.room_id, self.thread_root_id + ) + ) + self.assertIsNone(subscription) + + # Subscribe + self._subscribe( + self.thread_root_id, + automatic=True, + ) + + # Assert subscription went through + subscription = self.get_success( + self.store.get_subscription_for_thread( + self.user_id, self.room_id, self.thread_root_id + ) + ) + self.assertIsNotNone(subscription) + self.assertTrue(subscription.automatic) 
# type: ignore + + # Now make it a manual subscription + self._subscribe( + self.thread_root_id, + automatic=False, + ) + + # Assert the manual subscription overrode the automatic one + subscription = self.get_success( + self.store.get_subscription_for_thread( + self.user_id, self.room_id, self.thread_root_id + ) + ) + self.assertFalse(subscription.automatic) # type: ignore + + def test_purge_thread_subscriptions_for_user(self) -> None: + """Test purging all thread subscription settings for a user.""" + # Set subscription settings for multiple threads + self._subscribe(self.thread_root_id, automatic=True) + self._subscribe(self.other_thread_root_id, automatic=False) + + subscriptions = self.get_success( + self.store.get_updated_thread_subscriptions_for_user( + self.user_id, + from_id=0, + to_id=50, + limit=50, + ) + ) + min_id = min(id for (id, _, _) in subscriptions) + self.assertEqual( + subscriptions, + [ + (min_id, self.room_id, self.thread_root_id), + (min_id + 1, self.room_id, self.other_thread_root_id), + ], + ) + + # Purge all settings for the user + self.get_success( + self.store.purge_thread_subscription_settings_for_user(self.user_id) + ) + + # Check user has no subscriptions + subscriptions = self.get_success( + self.store.get_updated_thread_subscriptions_for_user( + self.user_id, + from_id=0, + to_id=50, + limit=50, + ) + ) + self.assertEqual(subscriptions, []) + + def test_get_updated_thread_subscriptions(self) -> None: + """Test getting updated thread subscriptions since a stream ID.""" + + stream_id1 = self._subscribe(self.thread_root_id, automatic=False) + stream_id2 = self._subscribe(self.other_thread_root_id, automatic=True) + assert stream_id1 is not None + assert stream_id2 is not None + + # Get updates since initial ID (should include both changes) + updates = self.get_success( + self.store.get_updated_thread_subscriptions(0, stream_id2, 10) + ) + self.assertEqual(len(updates), 2) + + # Get updates since first change (should include only the second change) + updates = self.get_success( + self.store.get_updated_thread_subscriptions(stream_id1, stream_id2, 10) + ) + self.assertEqual( + updates, + [(stream_id2, self.user_id, self.room_id, self.other_thread_root_id)], + ) + + def test_get_updated_thread_subscriptions_for_user(self) -> None: + """Test getting updated thread subscriptions for a specific user.""" + other_user_id = "@other_user:test" + + # Set thread subscription for main user + stream_id1 = self._subscribe(self.thread_root_id, automatic=True) + assert stream_id1 is not None + + # Set thread subscription for other user + stream_id2 = self._subscribe( + self.other_thread_root_id, + automatic=True, + user_id=other_user_id, + ) + assert stream_id2 is not None + + # Get updates for main user + updates = self.get_success( + self.store.get_updated_thread_subscriptions_for_user( + self.user_id, 0, stream_id2, 10 + ) + ) + self.assertEqual(updates, [(stream_id1, self.room_id, self.thread_root_id)]) + + # Get updates for other user + updates = self.get_success( + self.store.get_updated_thread_subscriptions_for_user( + other_user_id, 0, max(stream_id1, stream_id2), 10 + ) + ) + self.assertEqual( + updates, [(stream_id2, self.room_id, self.other_thread_root_id)] + )
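
To exercise the new endpoints, the feature must be switched on via the experimental flag read in `synapse/config/experimental.py`; in worker mode, writes can additionally be routed to a single writer via the new `WriterLocations.thread_subscriptions` entry. A minimal sketch of the relevant `homeserver.yaml` — the `stream_writers` key name is assumed from how the other writer locations in `synapse/config/workers.py` are configured, and the worker name is a placeholder:

    experimental_features:
      # MSC4306: Thread Subscriptions (and the MSC4308 sliding sync extension)
      msc4306_enabled: true

    # Worker deployments only (optional): route thread-subscription writes to one
    # worker. The docker worker template added in this diff also routes
    # ^/_matrix/client/unstable/io.element.msc4306/.* to that worker.
    stream_writers:
      thread_subscriptions: thread_subscriptions_worker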
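
The client-facing surface added by `synapse/rest/client/thread_subscriptions.py` is a single per-thread subscription resource. A sketch of the expected interaction, using placeholder room ID, thread root event ID, and access token, matching the behaviour exercised in `tests/rest/client/test_thread_subscriptions.py`:

    # Subscribe to a thread; a manual subscription (automatic=false) is never
    # overwritten by a later automatic one:
    PUT /_matrix/client/unstable/io.element.msc4306/rooms/!room:example.org/thread/$threadroot/subscription
    Authorization: Bearer <access_token>
    {"automatic": false}
    -> 200 {}

    # Read the current subscription; 404 M_NOT_FOUND if not subscribed, or if the
    # thread root does not exist or is not visible to the user:
    GET /_matrix/client/unstable/io.element.msc4306/rooms/!room:example.org/thread/$threadroot/subscription
    Authorization: Bearer <access_token>
    -> 200 {"automatic": false}

    # Unsubscribe (subsequent GETs return 404 M_NOT_FOUND):
    DELETE /_matrix/client/unstable/io.element.msc4306/rooms/!room:example.org/thread/$threadroot/subscription
    Authorization: Bearer <access_token>
    -> 200 {}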