Add experimental and incomplete support for MSC4306: Thread Subscriptions. (#18674)
Implements: [MSC4306](https://github.com/matrix-org/matrix-spec-proposals/blob/rei/msc_thread_subscriptions/proposals/4306-thread-subscriptions.md) (partially) What's missing: - Changes to push rules Signed-off-by: Olivier 'reivilibre <oliverw@matrix.org>
This commit is contained in:
1
changelog.d/18674.feature
Normal file
1
changelog.d/18674.feature
Normal file
@@ -0,0 +1 @@
|
|||||||
|
Add experimental and incomplete support for [MSC4306: Thread Subscriptions](https://github.com/matrix-org/matrix-spec-proposals/blob/rei/msc_thread_subscriptions/proposals/4306-thread-subscriptions.md).
|
||||||
@@ -327,6 +327,15 @@ WORKERS_CONFIG: Dict[str, Dict[str, Any]] = {
|
|||||||
"shared_extra_conf": {},
|
"shared_extra_conf": {},
|
||||||
"worker_extra_conf": "",
|
"worker_extra_conf": "",
|
||||||
},
|
},
|
||||||
|
"thread_subscriptions": {
|
||||||
|
"app": "synapse.app.generic_worker",
|
||||||
|
"listener_resources": ["client", "replication"],
|
||||||
|
"endpoint_patterns": [
|
||||||
|
"^/_matrix/client/unstable/io.element.msc4306/.*",
|
||||||
|
],
|
||||||
|
"shared_extra_conf": {},
|
||||||
|
"worker_extra_conf": "",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
# Templates for sections that may be inserted multiple times in config files
|
# Templates for sections that may be inserted multiple times in config files
|
||||||
@@ -427,6 +436,7 @@ def add_worker_roles_to_shared_config(
|
|||||||
"to_device",
|
"to_device",
|
||||||
"typing",
|
"typing",
|
||||||
"push_rules",
|
"push_rules",
|
||||||
|
"thread_subscriptions",
|
||||||
}
|
}
|
||||||
|
|
||||||
# Worker-type specific sharding config. Now a single worker can fulfill multiple
|
# Worker-type specific sharding config. Now a single worker can fulfill multiple
|
||||||
|
|||||||
@@ -136,6 +136,7 @@ BOOLEAN_COLUMNS = {
|
|||||||
"has_known_state",
|
"has_known_state",
|
||||||
"is_encrypted",
|
"is_encrypted",
|
||||||
],
|
],
|
||||||
|
"thread_subscriptions": ["subscribed", "automatic"],
|
||||||
"users": ["shadow_banned", "approved", "locked", "suspended"],
|
"users": ["shadow_banned", "approved", "locked", "suspended"],
|
||||||
"un_partial_stated_event_stream": ["rejection_status_changed"],
|
"un_partial_stated_event_stream": ["rejection_status_changed"],
|
||||||
"users_who_share_rooms": ["share_private"],
|
"users_who_share_rooms": ["share_private"],
|
||||||
|
|||||||
@@ -104,6 +104,9 @@ from synapse.storage.databases.main.stats import StatsStore
|
|||||||
from synapse.storage.databases.main.stream import StreamWorkerStore
|
from synapse.storage.databases.main.stream import StreamWorkerStore
|
||||||
from synapse.storage.databases.main.tags import TagsWorkerStore
|
from synapse.storage.databases.main.tags import TagsWorkerStore
|
||||||
from synapse.storage.databases.main.task_scheduler import TaskSchedulerWorkerStore
|
from synapse.storage.databases.main.task_scheduler import TaskSchedulerWorkerStore
|
||||||
|
from synapse.storage.databases.main.thread_subscriptions import (
|
||||||
|
ThreadSubscriptionsWorkerStore,
|
||||||
|
)
|
||||||
from synapse.storage.databases.main.transactions import TransactionWorkerStore
|
from synapse.storage.databases.main.transactions import TransactionWorkerStore
|
||||||
from synapse.storage.databases.main.ui_auth import UIAuthWorkerStore
|
from synapse.storage.databases.main.ui_auth import UIAuthWorkerStore
|
||||||
from synapse.storage.databases.main.user_directory import UserDirectoryStore
|
from synapse.storage.databases.main.user_directory import UserDirectoryStore
|
||||||
@@ -132,6 +135,7 @@ class GenericWorkerStore(
|
|||||||
KeyStore,
|
KeyStore,
|
||||||
RoomWorkerStore,
|
RoomWorkerStore,
|
||||||
DirectoryWorkerStore,
|
DirectoryWorkerStore,
|
||||||
|
ThreadSubscriptionsWorkerStore,
|
||||||
PushRulesWorkerStore,
|
PushRulesWorkerStore,
|
||||||
ApplicationServiceTransactionWorkerStore,
|
ApplicationServiceTransactionWorkerStore,
|
||||||
ApplicationServiceWorkerStore,
|
ApplicationServiceWorkerStore,
|
||||||
|
|||||||
@@ -581,3 +581,7 @@ class ExperimentalConfig(Config):
|
|||||||
|
|
||||||
# MSC4155: Invite filtering
|
# MSC4155: Invite filtering
|
||||||
self.msc4155_enabled: bool = experimental.get("msc4155_enabled", False)
|
self.msc4155_enabled: bool = experimental.get("msc4155_enabled", False)
|
||||||
|
|
||||||
|
# MSC4306: Thread Subscriptions
|
||||||
|
# (and MSC4308: sliding sync extension for thread subscriptions)
|
||||||
|
self.msc4306_enabled: bool = experimental.get("msc4306_enabled", False)
|
||||||
|
|||||||
@@ -174,6 +174,10 @@ class WriterLocations:
|
|||||||
default=[MAIN_PROCESS_INSTANCE_NAME],
|
default=[MAIN_PROCESS_INSTANCE_NAME],
|
||||||
converter=_instance_to_list_converter,
|
converter=_instance_to_list_converter,
|
||||||
)
|
)
|
||||||
|
thread_subscriptions: List[str] = attr.ib(
|
||||||
|
default=["master"],
|
||||||
|
converter=_instance_to_list_converter,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@attr.s(auto_attribs=True)
|
@attr.s(auto_attribs=True)
|
||||||
|
|||||||
@@ -187,6 +187,9 @@ class DeactivateAccountHandler:
|
|||||||
# Remove account data (including ignored users and push rules).
|
# Remove account data (including ignored users and push rules).
|
||||||
await self.store.purge_account_data_for_user(user_id)
|
await self.store.purge_account_data_for_user(user_id)
|
||||||
|
|
||||||
|
# Remove thread subscriptions for the user
|
||||||
|
await self.store.purge_thread_subscription_settings_for_user(user_id)
|
||||||
|
|
||||||
# Delete any server-side backup keys
|
# Delete any server-side backup keys
|
||||||
await self.store.bulk_delete_backup_keys_and_versions_for_user(user_id)
|
await self.store.bulk_delete_backup_keys_and_versions_for_user(user_id)
|
||||||
|
|
||||||
|
|||||||
126
synapse/handlers/thread_subscriptions.py
Normal file
126
synapse/handlers/thread_subscriptions.py
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
import logging
|
||||||
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
|
from synapse.api.errors import AuthError, NotFoundError
|
||||||
|
from synapse.storage.databases.main.thread_subscriptions import ThreadSubscription
|
||||||
|
from synapse.types import UserID
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadSubscriptionsHandler:
|
||||||
|
def __init__(self, hs: "HomeServer"):
|
||||||
|
self.store = hs.get_datastores().main
|
||||||
|
self.event_handler = hs.get_event_handler()
|
||||||
|
self.auth = hs.get_auth()
|
||||||
|
|
||||||
|
async def get_thread_subscription_settings(
|
||||||
|
self,
|
||||||
|
user_id: UserID,
|
||||||
|
room_id: str,
|
||||||
|
thread_root_event_id: str,
|
||||||
|
) -> Optional[ThreadSubscription]:
|
||||||
|
"""Get thread subscription settings for a specific thread and user.
|
||||||
|
Checks that the thread root is both a real event and also that it is visible
|
||||||
|
to the user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The ID of the user
|
||||||
|
thread_root_event_id: The event ID of the thread root
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A `ThreadSubscription` containing the active subscription settings or None if not set
|
||||||
|
"""
|
||||||
|
# First check that the user can access the thread root event
|
||||||
|
# and that it exists
|
||||||
|
try:
|
||||||
|
event = await self.event_handler.get_event(
|
||||||
|
user_id, room_id, thread_root_event_id
|
||||||
|
)
|
||||||
|
if event is None:
|
||||||
|
raise NotFoundError("No such thread root")
|
||||||
|
except AuthError:
|
||||||
|
raise NotFoundError("No such thread root")
|
||||||
|
|
||||||
|
return await self.store.get_subscription_for_thread(
|
||||||
|
user_id.to_string(), event.room_id, thread_root_event_id
|
||||||
|
)
|
||||||
|
|
||||||
|
async def subscribe_user_to_thread(
|
||||||
|
self,
|
||||||
|
user_id: UserID,
|
||||||
|
room_id: str,
|
||||||
|
thread_root_event_id: str,
|
||||||
|
*,
|
||||||
|
automatic: bool,
|
||||||
|
) -> Optional[int]:
|
||||||
|
"""Sets or updates a user's subscription settings for a specific thread root.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
requester_user_id: The ID of the user whose settings are being updated.
|
||||||
|
thread_root_event_id: The event ID of the thread root.
|
||||||
|
automatic: whether the user was subscribed by an automatic decision by
|
||||||
|
their client.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The stream ID for this update, if the update isn't no-opped.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotFoundError if the user cannot access the thread root event, or it isn't
|
||||||
|
known to this homeserver.
|
||||||
|
"""
|
||||||
|
# First check that the user can access the thread root event
|
||||||
|
# and that it exists
|
||||||
|
try:
|
||||||
|
event = await self.event_handler.get_event(
|
||||||
|
user_id, room_id, thread_root_event_id
|
||||||
|
)
|
||||||
|
if event is None:
|
||||||
|
raise NotFoundError("No such thread root")
|
||||||
|
except AuthError:
|
||||||
|
logger.info("rejecting thread subscriptions change (thread not accessible)")
|
||||||
|
raise NotFoundError("No such thread root")
|
||||||
|
|
||||||
|
return await self.store.subscribe_user_to_thread(
|
||||||
|
user_id.to_string(),
|
||||||
|
event.room_id,
|
||||||
|
thread_root_event_id,
|
||||||
|
automatic=automatic,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def unsubscribe_user_from_thread(
|
||||||
|
self, user_id: UserID, room_id: str, thread_root_event_id: str
|
||||||
|
) -> Optional[int]:
|
||||||
|
"""Clears a user's subscription settings for a specific thread root.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
requester_user_id: The ID of the user whose settings are being updated.
|
||||||
|
thread_root_event_id: The event ID of the thread root.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The stream ID for this update, if the update isn't no-opped.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
NotFoundError if the user cannot access the thread root event, or it isn't
|
||||||
|
known to this homeserver.
|
||||||
|
"""
|
||||||
|
# First check that the user can access the thread root event
|
||||||
|
# and that it exists
|
||||||
|
try:
|
||||||
|
event = await self.event_handler.get_event(
|
||||||
|
user_id, room_id, thread_root_event_id
|
||||||
|
)
|
||||||
|
if event is None:
|
||||||
|
raise NotFoundError("No such thread root")
|
||||||
|
except AuthError:
|
||||||
|
logger.info("rejecting thread subscriptions change (thread not accessible)")
|
||||||
|
raise NotFoundError("No such thread root")
|
||||||
|
|
||||||
|
return await self.store.unsubscribe_user_from_thread(
|
||||||
|
user_id.to_string(),
|
||||||
|
event.room_id,
|
||||||
|
thread_root_event_id,
|
||||||
|
)
|
||||||
@@ -72,7 +72,10 @@ from synapse.replication.tcp.streams import (
|
|||||||
ToDeviceStream,
|
ToDeviceStream,
|
||||||
TypingStream,
|
TypingStream,
|
||||||
)
|
)
|
||||||
from synapse.replication.tcp.streams._base import DeviceListsStream
|
from synapse.replication.tcp.streams._base import (
|
||||||
|
DeviceListsStream,
|
||||||
|
ThreadSubscriptionsStream,
|
||||||
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
@@ -186,6 +189,15 @@ class ReplicationCommandHandler:
|
|||||||
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
if isinstance(stream, ThreadSubscriptionsStream):
|
||||||
|
if (
|
||||||
|
hs.get_instance_name()
|
||||||
|
in hs.config.worker.writers.thread_subscriptions
|
||||||
|
):
|
||||||
|
self._streams_to_replicate.append(stream)
|
||||||
|
|
||||||
|
continue
|
||||||
|
|
||||||
if isinstance(stream, DeviceListsStream):
|
if isinstance(stream, DeviceListsStream):
|
||||||
if hs.get_instance_name() in hs.config.worker.writers.device_lists:
|
if hs.get_instance_name() in hs.config.worker.writers.device_lists:
|
||||||
self._streams_to_replicate.append(stream)
|
self._streams_to_replicate.append(stream)
|
||||||
|
|||||||
@@ -41,6 +41,7 @@ from synapse.replication.tcp.streams._base import (
|
|||||||
PushRulesStream,
|
PushRulesStream,
|
||||||
ReceiptsStream,
|
ReceiptsStream,
|
||||||
Stream,
|
Stream,
|
||||||
|
ThreadSubscriptionsStream,
|
||||||
ToDeviceStream,
|
ToDeviceStream,
|
||||||
TypingStream,
|
TypingStream,
|
||||||
)
|
)
|
||||||
@@ -67,6 +68,7 @@ STREAMS_MAP = {
|
|||||||
ToDeviceStream,
|
ToDeviceStream,
|
||||||
FederationStream,
|
FederationStream,
|
||||||
AccountDataStream,
|
AccountDataStream,
|
||||||
|
ThreadSubscriptionsStream,
|
||||||
UnPartialStatedRoomStream,
|
UnPartialStatedRoomStream,
|
||||||
UnPartialStatedEventStream,
|
UnPartialStatedEventStream,
|
||||||
)
|
)
|
||||||
@@ -86,6 +88,7 @@ __all__ = [
|
|||||||
"DeviceListsStream",
|
"DeviceListsStream",
|
||||||
"ToDeviceStream",
|
"ToDeviceStream",
|
||||||
"AccountDataStream",
|
"AccountDataStream",
|
||||||
|
"ThreadSubscriptionsStream",
|
||||||
"UnPartialStatedRoomStream",
|
"UnPartialStatedRoomStream",
|
||||||
"UnPartialStatedEventStream",
|
"UnPartialStatedEventStream",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -723,3 +723,47 @@ class AccountDataStream(_StreamFromIdGen):
|
|||||||
heapq.merge(room_rows, global_rows, tag_rows, key=lambda row: row[0])
|
heapq.merge(room_rows, global_rows, tag_rows, key=lambda row: row[0])
|
||||||
)
|
)
|
||||||
return updates, to_token, limited
|
return updates, to_token, limited
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadSubscriptionsStream(_StreamFromIdGen):
|
||||||
|
"""A thread subscription was changed."""
|
||||||
|
|
||||||
|
@attr.s(slots=True, auto_attribs=True)
|
||||||
|
class ThreadSubscriptionsStreamRow:
|
||||||
|
"""Stream to inform workers about changes to thread subscriptions."""
|
||||||
|
|
||||||
|
user_id: str
|
||||||
|
room_id: str
|
||||||
|
event_id: str # The event ID of the thread root
|
||||||
|
|
||||||
|
NAME = "thread_subscriptions"
|
||||||
|
ROW_TYPE = ThreadSubscriptionsStreamRow
|
||||||
|
|
||||||
|
def __init__(self, hs: Any):
|
||||||
|
self.store = hs.get_datastores().main
|
||||||
|
super().__init__(
|
||||||
|
hs.get_instance_name(),
|
||||||
|
self._update_function,
|
||||||
|
self.store._thread_subscriptions_id_gen,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _update_function(
|
||||||
|
self, instance_name: str, from_token: int, to_token: int, limit: int
|
||||||
|
) -> StreamUpdateResult:
|
||||||
|
updates = await self.store.get_updated_thread_subscriptions(
|
||||||
|
from_token, to_token, limit
|
||||||
|
)
|
||||||
|
rows = [
|
||||||
|
(
|
||||||
|
stream_id,
|
||||||
|
# These are the args to `ThreadSubscriptionsStreamRow`
|
||||||
|
(user_id, room_id, event_id),
|
||||||
|
)
|
||||||
|
for stream_id, user_id, room_id, event_id in updates
|
||||||
|
]
|
||||||
|
|
||||||
|
logger.error("TS %d->%d %r", from_token, to_token, rows)
|
||||||
|
if not rows:
|
||||||
|
return [], to_token, False
|
||||||
|
|
||||||
|
return rows, rows[-1][0], len(updates) == limit
|
||||||
|
|||||||
98
synapse/rest/client/thread_subscriptions.py
Normal file
98
synapse/rest/client/thread_subscriptions.py
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
from http import HTTPStatus
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
from synapse._pydantic_compat import StrictBool
|
||||||
|
from synapse.api.errors import Codes, NotFoundError, SynapseError
|
||||||
|
from synapse.http.server import HttpServer
|
||||||
|
from synapse.http.servlet import (
|
||||||
|
RestServlet,
|
||||||
|
parse_and_validate_json_object_from_request,
|
||||||
|
)
|
||||||
|
from synapse.http.site import SynapseRequest
|
||||||
|
from synapse.rest.client._base import client_patterns
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
from synapse.types import JsonDict, RoomID
|
||||||
|
from synapse.types.rest import RequestBodyModel
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadSubscriptionsRestServlet(RestServlet):
|
||||||
|
PATTERNS = client_patterns(
|
||||||
|
"/io.element.msc4306/rooms/(?P<room_id>[^/]*)/thread/(?P<thread_root_id>[^/]*)/subscription$",
|
||||||
|
unstable=True,
|
||||||
|
releases=(),
|
||||||
|
)
|
||||||
|
CATEGORY = "Thread Subscriptions requests (unstable)"
|
||||||
|
|
||||||
|
def __init__(self, hs: "HomeServer"):
|
||||||
|
self.auth = hs.get_auth()
|
||||||
|
self.is_mine = hs.is_mine
|
||||||
|
self.store = hs.get_datastores().main
|
||||||
|
self.handler = hs.get_thread_subscriptions_handler()
|
||||||
|
|
||||||
|
class PutBody(RequestBodyModel):
|
||||||
|
automatic: StrictBool
|
||||||
|
|
||||||
|
async def on_GET(
|
||||||
|
self, request: SynapseRequest, room_id: str, thread_root_id: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
|
RoomID.from_string(room_id)
|
||||||
|
if not thread_root_id.startswith("$"):
|
||||||
|
raise SynapseError(
|
||||||
|
HTTPStatus.BAD_REQUEST, "Invalid event ID", errcode=Codes.INVALID_PARAM
|
||||||
|
)
|
||||||
|
requester = await self.auth.get_user_by_req(request)
|
||||||
|
|
||||||
|
subscription = await self.handler.get_thread_subscription_settings(
|
||||||
|
requester.user,
|
||||||
|
room_id,
|
||||||
|
thread_root_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if subscription is None:
|
||||||
|
raise NotFoundError("Not subscribed.")
|
||||||
|
|
||||||
|
return HTTPStatus.OK, {"automatic": subscription.automatic}
|
||||||
|
|
||||||
|
async def on_PUT(
|
||||||
|
self, request: SynapseRequest, room_id: str, thread_root_id: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
|
RoomID.from_string(room_id)
|
||||||
|
if not thread_root_id.startswith("$"):
|
||||||
|
raise SynapseError(
|
||||||
|
HTTPStatus.BAD_REQUEST, "Invalid event ID", errcode=Codes.INVALID_PARAM
|
||||||
|
)
|
||||||
|
requester = await self.auth.get_user_by_req(request)
|
||||||
|
|
||||||
|
body = parse_and_validate_json_object_from_request(request, self.PutBody)
|
||||||
|
|
||||||
|
await self.handler.subscribe_user_to_thread(
|
||||||
|
requester.user,
|
||||||
|
room_id,
|
||||||
|
thread_root_id,
|
||||||
|
automatic=body.automatic,
|
||||||
|
)
|
||||||
|
|
||||||
|
return HTTPStatus.OK, {}
|
||||||
|
|
||||||
|
async def on_DELETE(
|
||||||
|
self, request: SynapseRequest, room_id: str, thread_root_id: str
|
||||||
|
) -> Tuple[int, JsonDict]:
|
||||||
|
RoomID.from_string(room_id)
|
||||||
|
if not thread_root_id.startswith("$"):
|
||||||
|
raise SynapseError(
|
||||||
|
HTTPStatus.BAD_REQUEST, "Invalid event ID", errcode=Codes.INVALID_PARAM
|
||||||
|
)
|
||||||
|
requester = await self.auth.get_user_by_req(request)
|
||||||
|
|
||||||
|
await self.handler.unsubscribe_user_from_thread(
|
||||||
|
requester.user,
|
||||||
|
room_id,
|
||||||
|
thread_root_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
return HTTPStatus.OK, {}
|
||||||
|
|
||||||
|
|
||||||
|
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||||
|
if hs.config.experimental.msc4306_enabled:
|
||||||
|
ThreadSubscriptionsRestServlet(hs).register(http_server)
|
||||||
@@ -117,6 +117,7 @@ from synapse.handlers.sliding_sync import SlidingSyncHandler
|
|||||||
from synapse.handlers.sso import SsoHandler
|
from synapse.handlers.sso import SsoHandler
|
||||||
from synapse.handlers.stats import StatsHandler
|
from synapse.handlers.stats import StatsHandler
|
||||||
from synapse.handlers.sync import SyncHandler
|
from synapse.handlers.sync import SyncHandler
|
||||||
|
from synapse.handlers.thread_subscriptions import ThreadSubscriptionsHandler
|
||||||
from synapse.handlers.typing import FollowerTypingHandler, TypingWriterHandler
|
from synapse.handlers.typing import FollowerTypingHandler, TypingWriterHandler
|
||||||
from synapse.handlers.user_directory import UserDirectoryHandler
|
from synapse.handlers.user_directory import UserDirectoryHandler
|
||||||
from synapse.handlers.worker_lock import WorkerLocksHandler
|
from synapse.handlers.worker_lock import WorkerLocksHandler
|
||||||
@@ -789,6 +790,10 @@ class HomeServer(metaclass=abc.ABCMeta):
|
|||||||
def get_timestamp_lookup_handler(self) -> TimestampLookupHandler:
|
def get_timestamp_lookup_handler(self) -> TimestampLookupHandler:
|
||||||
return TimestampLookupHandler(self)
|
return TimestampLookupHandler(self)
|
||||||
|
|
||||||
|
@cache_in_self
|
||||||
|
def get_thread_subscriptions_handler(self) -> ThreadSubscriptionsHandler:
|
||||||
|
return ThreadSubscriptionsHandler(self)
|
||||||
|
|
||||||
@cache_in_self
|
@cache_in_self
|
||||||
def get_registration_handler(self) -> RegistrationHandler:
|
def get_registration_handler(self) -> RegistrationHandler:
|
||||||
return RegistrationHandler(self)
|
return RegistrationHandler(self)
|
||||||
|
|||||||
@@ -19,7 +19,6 @@
|
|||||||
# [This file includes modifications made by New Vector Limited]
|
# [This file includes modifications made by New Vector Limited]
|
||||||
#
|
#
|
||||||
#
|
#
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, List, Optional, Tuple, Union, cast
|
from typing import TYPE_CHECKING, List, Optional, Tuple, Union, cast
|
||||||
|
|
||||||
@@ -35,6 +34,9 @@ from synapse.storage.database import (
|
|||||||
)
|
)
|
||||||
from synapse.storage.databases.main.sliding_sync import SlidingSyncStore
|
from synapse.storage.databases.main.sliding_sync import SlidingSyncStore
|
||||||
from synapse.storage.databases.main.stats import UserSortOrder
|
from synapse.storage.databases.main.stats import UserSortOrder
|
||||||
|
from synapse.storage.databases.main.thread_subscriptions import (
|
||||||
|
ThreadSubscriptionsWorkerStore,
|
||||||
|
)
|
||||||
from synapse.storage.engines import BaseDatabaseEngine
|
from synapse.storage.engines import BaseDatabaseEngine
|
||||||
from synapse.storage.types import Cursor
|
from synapse.storage.types import Cursor
|
||||||
from synapse.types import get_domain_from_id
|
from synapse.types import get_domain_from_id
|
||||||
@@ -141,6 +143,7 @@ class DataStore(
|
|||||||
SearchStore,
|
SearchStore,
|
||||||
TagsStore,
|
TagsStore,
|
||||||
AccountDataStore,
|
AccountDataStore,
|
||||||
|
ThreadSubscriptionsWorkerStore,
|
||||||
PushRulesWorkerStore,
|
PushRulesWorkerStore,
|
||||||
StreamWorkerStore,
|
StreamWorkerStore,
|
||||||
OpenIdStore,
|
OpenIdStore,
|
||||||
|
|||||||
@@ -2986,6 +2986,10 @@ class PersistEventsStore:
|
|||||||
# Upsert into the threads table, but only overwrite the value if the
|
# Upsert into the threads table, but only overwrite the value if the
|
||||||
# new event is of a later topological order OR if the topological
|
# new event is of a later topological order OR if the topological
|
||||||
# ordering is equal, but the stream ordering is later.
|
# ordering is equal, but the stream ordering is later.
|
||||||
|
# (Note by definition that the stream ordering will always be later
|
||||||
|
# unless this is a backfilled event [= negative stream ordering]
|
||||||
|
# because we are only persisting this event now and stream_orderings
|
||||||
|
# are strictly monotonically increasing)
|
||||||
sql = """
|
sql = """
|
||||||
INSERT INTO threads (room_id, thread_id, latest_event_id, topological_ordering, stream_ordering)
|
INSERT INTO threads (room_id, thread_id, latest_event_id, topological_ordering, stream_ordering)
|
||||||
VALUES (?, ?, ?, ?, ?)
|
VALUES (?, ?, ?, ?, ?)
|
||||||
|
|||||||
382
synapse/storage/databases/main/thread_subscriptions.py
Normal file
382
synapse/storage/databases/main/thread_subscriptions.py
Normal file
@@ -0,0 +1,382 @@
|
|||||||
|
#
|
||||||
|
# This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||||
|
#
|
||||||
|
# Copyright (C) 2025 New Vector, Ltd
|
||||||
|
#
|
||||||
|
# This program is free software: you can redistribute it and/or modify
|
||||||
|
# it under the terms of the GNU Affero General Public License as
|
||||||
|
# published by the Free Software Foundation, either version 3 of the
|
||||||
|
# License, or (at your option) any later version.
|
||||||
|
#
|
||||||
|
# See the GNU Affero General Public License for more details:
|
||||||
|
# <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||||
|
import logging
|
||||||
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
Dict,
|
||||||
|
Iterable,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Tuple,
|
||||||
|
Union,
|
||||||
|
cast,
|
||||||
|
)
|
||||||
|
|
||||||
|
import attr
|
||||||
|
|
||||||
|
from synapse.replication.tcp.streams._base import ThreadSubscriptionsStream
|
||||||
|
from synapse.storage.database import (
|
||||||
|
DatabasePool,
|
||||||
|
LoggingDatabaseConnection,
|
||||||
|
LoggingTransaction,
|
||||||
|
)
|
||||||
|
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
|
||||||
|
from synapse.storage.util.id_generators import MultiWriterIdGenerator
|
||||||
|
from synapse.util.caches.descriptors import cached
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||||
|
class ThreadSubscription:
|
||||||
|
automatic: bool
|
||||||
|
"""
|
||||||
|
whether the subscription was made automatically (as opposed to by manual
|
||||||
|
action from the user)
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
database: DatabasePool,
|
||||||
|
db_conn: LoggingDatabaseConnection,
|
||||||
|
hs: "HomeServer",
|
||||||
|
):
|
||||||
|
super().__init__(database, db_conn, hs)
|
||||||
|
|
||||||
|
self._can_write_to_thread_subscriptions = (
|
||||||
|
self._instance_name in hs.config.worker.writers.thread_subscriptions
|
||||||
|
)
|
||||||
|
|
||||||
|
self._thread_subscriptions_id_gen: MultiWriterIdGenerator = (
|
||||||
|
MultiWriterIdGenerator(
|
||||||
|
db_conn=db_conn,
|
||||||
|
db=database,
|
||||||
|
notifier=hs.get_replication_notifier(),
|
||||||
|
stream_name="thread_subscriptions",
|
||||||
|
instance_name=self._instance_name,
|
||||||
|
tables=[
|
||||||
|
("thread_subscriptions", "instance_name", "stream_id"),
|
||||||
|
],
|
||||||
|
sequence_name="thread_subscriptions_sequence",
|
||||||
|
writers=hs.config.worker.writers.thread_subscriptions,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def process_replication_rows(
|
||||||
|
self,
|
||||||
|
stream_name: str,
|
||||||
|
instance_name: str,
|
||||||
|
token: int,
|
||||||
|
rows: Iterable[Any],
|
||||||
|
) -> None:
|
||||||
|
if stream_name == ThreadSubscriptionsStream.NAME:
|
||||||
|
for row in rows:
|
||||||
|
self.get_subscription_for_thread.invalidate(
|
||||||
|
(row.user_id, row.room_id, row.event_id)
|
||||||
|
)
|
||||||
|
|
||||||
|
super().process_replication_rows(stream_name, instance_name, token, rows)
|
||||||
|
|
||||||
|
def process_replication_position(
|
||||||
|
self, stream_name: str, instance_name: str, token: int
|
||||||
|
) -> None:
|
||||||
|
if stream_name == ThreadSubscriptionsStream.NAME:
|
||||||
|
self._thread_subscriptions_id_gen.advance(instance_name, token)
|
||||||
|
super().process_replication_position(stream_name, instance_name, token)
|
||||||
|
|
||||||
|
async def subscribe_user_to_thread(
|
||||||
|
self, user_id: str, room_id: str, thread_root_event_id: str, *, automatic: bool
|
||||||
|
) -> Optional[int]:
|
||||||
|
"""Updates a user's subscription settings for a specific thread root.
|
||||||
|
|
||||||
|
If no change would be made to the subscription, does not produce any database change.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The ID of the user whose settings are being updated.
|
||||||
|
room_id: The ID of the room the thread root belongs to.
|
||||||
|
thread_root_event_id: The event ID of the thread root.
|
||||||
|
automatic: Whether the subscription was performed automatically by the user's client.
|
||||||
|
Only `False` will overwrite an existing value of automatic for a subscription row.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The stream ID for this update, if the update isn't no-opped.
|
||||||
|
"""
|
||||||
|
assert self._can_write_to_thread_subscriptions
|
||||||
|
|
||||||
|
def _subscribe_user_to_thread_txn(txn: LoggingTransaction) -> Optional[int]:
|
||||||
|
already_automatic = self.db_pool.simple_select_one_onecol_txn(
|
||||||
|
txn,
|
||||||
|
table="thread_subscriptions",
|
||||||
|
keyvalues={
|
||||||
|
"user_id": user_id,
|
||||||
|
"event_id": thread_root_event_id,
|
||||||
|
"room_id": room_id,
|
||||||
|
"subscribed": True,
|
||||||
|
},
|
||||||
|
retcol="automatic",
|
||||||
|
allow_none=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if already_automatic is None:
|
||||||
|
already_subscribed = False
|
||||||
|
already_automatic = True
|
||||||
|
else:
|
||||||
|
already_subscribed = True
|
||||||
|
# convert int (SQLite bool) to Python bool
|
||||||
|
already_automatic = bool(already_automatic)
|
||||||
|
|
||||||
|
if already_subscribed and already_automatic == automatic:
|
||||||
|
# there is nothing we need to do here
|
||||||
|
return None
|
||||||
|
|
||||||
|
stream_id = self._thread_subscriptions_id_gen.get_next_txn(txn)
|
||||||
|
|
||||||
|
values: Dict[str, Optional[Union[bool, int, str]]] = {
|
||||||
|
"subscribed": True,
|
||||||
|
"stream_id": stream_id,
|
||||||
|
"instance_name": self._instance_name,
|
||||||
|
"automatic": already_automatic and automatic,
|
||||||
|
}
|
||||||
|
|
||||||
|
self.db_pool.simple_upsert_txn(
|
||||||
|
txn,
|
||||||
|
table="thread_subscriptions",
|
||||||
|
keyvalues={
|
||||||
|
"user_id": user_id,
|
||||||
|
"event_id": thread_root_event_id,
|
||||||
|
"room_id": room_id,
|
||||||
|
},
|
||||||
|
values=values,
|
||||||
|
)
|
||||||
|
|
||||||
|
txn.call_after(
|
||||||
|
self.get_subscription_for_thread.invalidate,
|
||||||
|
(user_id, room_id, thread_root_event_id),
|
||||||
|
)
|
||||||
|
|
||||||
|
return stream_id
|
||||||
|
|
||||||
|
return await self.db_pool.runInteraction(
|
||||||
|
"subscribe_user_to_thread", _subscribe_user_to_thread_txn
|
||||||
|
)
|
||||||
|
|
||||||
|
async def unsubscribe_user_from_thread(
|
||||||
|
self, user_id: str, room_id: str, thread_root_event_id: str
|
||||||
|
) -> Optional[int]:
|
||||||
|
"""Unsubscribes a user from a thread.
|
||||||
|
|
||||||
|
If no change would be made to the subscription, does not produce any database change.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The ID of the user whose settings are being updated.
|
||||||
|
room_id: The ID of the room the thread root belongs to.
|
||||||
|
thread_root_event_id: The event ID of the thread root.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The stream ID for this update, if the update isn't no-opped.
|
||||||
|
"""
|
||||||
|
|
||||||
|
assert self._can_write_to_thread_subscriptions
|
||||||
|
|
||||||
|
def _unsubscribe_user_from_thread_txn(txn: LoggingTransaction) -> Optional[int]:
|
||||||
|
already_subscribed = self.db_pool.simple_select_one_onecol_txn(
|
||||||
|
txn,
|
||||||
|
table="thread_subscriptions",
|
||||||
|
keyvalues={
|
||||||
|
"user_id": user_id,
|
||||||
|
"event_id": thread_root_event_id,
|
||||||
|
"room_id": room_id,
|
||||||
|
},
|
||||||
|
retcol="subscribed",
|
||||||
|
allow_none=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if already_subscribed is None or already_subscribed is False:
|
||||||
|
# there is nothing we need to do here
|
||||||
|
return None
|
||||||
|
|
||||||
|
stream_id = self._thread_subscriptions_id_gen.get_next_txn(txn)
|
||||||
|
|
||||||
|
self.db_pool.simple_update_txn(
|
||||||
|
txn,
|
||||||
|
table="thread_subscriptions",
|
||||||
|
keyvalues={
|
||||||
|
"user_id": user_id,
|
||||||
|
"event_id": thread_root_event_id,
|
||||||
|
"room_id": room_id,
|
||||||
|
"subscribed": True,
|
||||||
|
},
|
||||||
|
updatevalues={
|
||||||
|
"subscribed": False,
|
||||||
|
"stream_id": stream_id,
|
||||||
|
"instance_name": self._instance_name,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
txn.call_after(
|
||||||
|
self.get_subscription_for_thread.invalidate,
|
||||||
|
(user_id, room_id, thread_root_event_id),
|
||||||
|
)
|
||||||
|
|
||||||
|
return stream_id
|
||||||
|
|
||||||
|
return await self.db_pool.runInteraction(
|
||||||
|
"unsubscribe_user_from_thread", _unsubscribe_user_from_thread_txn
|
||||||
|
)
|
||||||
|
|
||||||
|
async def purge_thread_subscription_settings_for_user(self, user_id: str) -> None:
|
||||||
|
"""
|
||||||
|
Purge all subscriptions for the user.
|
||||||
|
The fact that subscriptions have been purged will not be streamed;
|
||||||
|
all stream rows for the user will in fact be removed.
|
||||||
|
This is intended only for dealing with user deactivation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _purge_thread_subscription_settings_for_user_txn(
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
) -> None:
|
||||||
|
self.db_pool.simple_delete_txn(
|
||||||
|
txn,
|
||||||
|
table="thread_subscriptions",
|
||||||
|
keyvalues={"user_id": user_id},
|
||||||
|
)
|
||||||
|
self._invalidate_cache_and_stream(
|
||||||
|
txn, self.get_subscription_for_thread, (user_id,)
|
||||||
|
)
|
||||||
|
|
||||||
|
await self.db_pool.runInteraction(
|
||||||
|
desc="purge_thread_subscription_settings_for_user",
|
||||||
|
func=_purge_thread_subscription_settings_for_user_txn,
|
||||||
|
)
|
||||||
|
|
||||||
|
@cached(tree=True)
|
||||||
|
async def get_subscription_for_thread(
|
||||||
|
self, user_id: str, room_id: str, thread_root_event_id: str
|
||||||
|
) -> Optional[ThreadSubscription]:
|
||||||
|
"""Get the thread subscription for a specific thread and user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The ID of the user
|
||||||
|
room_id: The ID of the room
|
||||||
|
thread_root_event_id: The event ID of the thread root
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A `ThreadSubscription` dataclass if there is a subscription,
|
||||||
|
or `None` if there is no subscription.
|
||||||
|
|
||||||
|
If there is a row in the table but `subscribed` is `False`,
|
||||||
|
behaves the same as if there was no row at all and returns `None`.
|
||||||
|
"""
|
||||||
|
row = await self.db_pool.simple_select_one(
|
||||||
|
table="thread_subscriptions",
|
||||||
|
keyvalues={
|
||||||
|
"user_id": user_id,
|
||||||
|
"room_id": room_id,
|
||||||
|
"event_id": thread_root_event_id,
|
||||||
|
"subscribed": True,
|
||||||
|
},
|
||||||
|
retcols=("automatic",),
|
||||||
|
allow_none=True,
|
||||||
|
desc="get_subscription_for_thread",
|
||||||
|
)
|
||||||
|
|
||||||
|
if row is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
(automatic_rawbool,) = row
|
||||||
|
|
||||||
|
# convert SQLite integer booleans into real booleans
|
||||||
|
automatic = bool(automatic_rawbool)
|
||||||
|
|
||||||
|
return ThreadSubscription(automatic=automatic)
|
||||||
|
|
||||||
|
def get_max_thread_subscriptions_stream_id(self) -> int:
|
||||||
|
"""Get the current maximum stream_id for thread subscriptions.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The maximum stream_id
|
||||||
|
"""
|
||||||
|
return self._thread_subscriptions_id_gen.get_current_token()
|
||||||
|
|
||||||
|
async def get_updated_thread_subscriptions(
|
||||||
|
self, from_id: int, to_id: int, limit: int
|
||||||
|
) -> List[Tuple[int, str, str, str]]:
|
||||||
|
"""Get updates to thread subscriptions between two stream IDs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
from_id: The starting stream ID (exclusive)
|
||||||
|
to_id: The ending stream ID (inclusive)
|
||||||
|
limit: The maximum number of rows to return
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list of (stream_id, user_id, room_id, thread_root_id) tuples
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_updated_thread_subscriptions_txn(
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
) -> List[Tuple[int, str, str, str]]:
|
||||||
|
sql = """
|
||||||
|
SELECT stream_id, user_id, room_id, event_id
|
||||||
|
FROM thread_subscriptions
|
||||||
|
WHERE ? < stream_id AND stream_id <= ?
|
||||||
|
ORDER BY stream_id ASC
|
||||||
|
LIMIT ?
|
||||||
|
"""
|
||||||
|
|
||||||
|
txn.execute(sql, (from_id, to_id, limit))
|
||||||
|
return cast(List[Tuple[int, str, str, str]], txn.fetchall())
|
||||||
|
|
||||||
|
return await self.db_pool.runInteraction(
|
||||||
|
"get_updated_thread_subscriptions",
|
||||||
|
get_updated_thread_subscriptions_txn,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_updated_thread_subscriptions_for_user(
|
||||||
|
self, user_id: str, from_id: int, to_id: int, limit: int
|
||||||
|
) -> List[Tuple[int, str, str]]:
|
||||||
|
"""Get updates to thread subscriptions for a specific user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The ID of the user
|
||||||
|
from_id: The starting stream ID (exclusive)
|
||||||
|
to_id: The ending stream ID (inclusive)
|
||||||
|
limit: The maximum number of rows to return
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of (stream_id, room_id, thread_root_event_id) tuples.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_updated_thread_subscriptions_for_user_txn(
|
||||||
|
txn: LoggingTransaction,
|
||||||
|
) -> List[Tuple[int, str, str]]:
|
||||||
|
sql = """
|
||||||
|
SELECT stream_id, room_id, event_id
|
||||||
|
FROM thread_subscriptions
|
||||||
|
WHERE user_id = ? AND ? < stream_id AND stream_id <= ?
|
||||||
|
ORDER BY stream_id ASC
|
||||||
|
LIMIT ?
|
||||||
|
"""
|
||||||
|
|
||||||
|
txn.execute(sql, (user_id, from_id, to_id, limit))
|
||||||
|
return [(row[0], row[1], row[2]) for row in txn]
|
||||||
|
|
||||||
|
return await self.db_pool.runInteraction(
|
||||||
|
"get_updated_thread_subscriptions_for_user",
|
||||||
|
get_updated_thread_subscriptions_for_user_txn,
|
||||||
|
)
|
||||||
@@ -0,0 +1,59 @@
|
|||||||
|
--
|
||||||
|
-- This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||||
|
--
|
||||||
|
-- Copyright (C) 2025 New Vector, Ltd
|
||||||
|
--
|
||||||
|
-- This program is free software: you can redistribute it and/or modify
|
||||||
|
-- it under the terms of the GNU Affero General Public License as
|
||||||
|
-- published by the Free Software Foundation, either version 3 of the
|
||||||
|
-- License, or (at your option) any later version.
|
||||||
|
--
|
||||||
|
-- See the GNU Affero General Public License for more details:
|
||||||
|
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||||
|
|
||||||
|
-- Introduce a table for tracking users' subscriptions to threads.
|
||||||
|
CREATE TABLE thread_subscriptions (
|
||||||
|
stream_id INTEGER NOT NULL PRIMARY KEY,
|
||||||
|
instance_name TEXT NOT NULL,
|
||||||
|
|
||||||
|
room_id TEXT NOT NULL,
|
||||||
|
event_id TEXT NOT NULL,
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
|
||||||
|
subscribed BOOLEAN NOT NULL,
|
||||||
|
automatic BOOLEAN NOT NULL,
|
||||||
|
|
||||||
|
CONSTRAINT thread_subscriptions_fk_users
|
||||||
|
FOREIGN KEY (user_id)
|
||||||
|
REFERENCES users(name),
|
||||||
|
|
||||||
|
CONSTRAINT thread_subscriptions_fk_rooms
|
||||||
|
FOREIGN KEY (room_id)
|
||||||
|
-- When we delete a room, we should already have deleted all the events in that room
|
||||||
|
-- and so there shouldn't be any subscriptions left in that room.
|
||||||
|
-- So the `ON DELETE CASCADE` should be optional, but included anyway for good measure.
|
||||||
|
REFERENCES rooms(room_id) ON DELETE CASCADE,
|
||||||
|
|
||||||
|
CONSTRAINT thread_subscriptions_fk_events
|
||||||
|
FOREIGN KEY (event_id)
|
||||||
|
REFERENCES events(event_id) ON DELETE CASCADE,
|
||||||
|
|
||||||
|
-- This order provides a useful index for:
|
||||||
|
-- 1. foreign key constraint on (room_id)
|
||||||
|
-- 2. foreign key constraint on (room_id, event_id)
|
||||||
|
-- 3. finding the user's settings for a specific thread (as well as enforcing uniqueness)
|
||||||
|
UNIQUE (room_id, event_id, user_id)
|
||||||
|
);
|
||||||
|
|
||||||
|
-- this provides a useful index for finding a user's own rules,
|
||||||
|
-- potentially scoped to a single room
|
||||||
|
CREATE INDEX thread_subscriptions_user_room ON thread_subscriptions (user_id, room_id);
|
||||||
|
|
||||||
|
-- this provides a useful way for clients to efficiently find new changes to
|
||||||
|
-- their subscriptions.
|
||||||
|
-- (This is necessary to sync subscriptions between multiple devices.)
|
||||||
|
CREATE INDEX thread_subscriptions_by_user ON thread_subscriptions (user_id, stream_id);
|
||||||
|
|
||||||
|
-- this provides a useful index for deleting the subscriptions when the underlying
|
||||||
|
-- events are removed. This also covers the foreign key constraint on `events`.
|
||||||
|
CREATE INDEX thread_subscriptions_by_event ON thread_subscriptions (event_id);
|
||||||
@@ -0,0 +1,19 @@
|
|||||||
|
--
|
||||||
|
-- This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||||
|
--
|
||||||
|
-- Copyright (C) 2025 New Vector, Ltd
|
||||||
|
--
|
||||||
|
-- This program is free software: you can redistribute it and/or modify
|
||||||
|
-- it under the terms of the GNU Affero General Public License as
|
||||||
|
-- published by the Free Software Foundation, either version 3 of the
|
||||||
|
-- License, or (at your option) any later version.
|
||||||
|
--
|
||||||
|
-- See the GNU Affero General Public License for more details:
|
||||||
|
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||||
|
|
||||||
|
CREATE SEQUENCE thread_subscriptions_sequence
|
||||||
|
-- Synapse streams start at 2, because the default position is 1
|
||||||
|
-- so any item inserted at position 1 is ignored.
|
||||||
|
-- This is also what existing streams do, except they use `setval(..., 1)`
|
||||||
|
-- which is semantically the same except less obvious.
|
||||||
|
START WITH 2;
|
||||||
@@ -0,0 +1,18 @@
|
|||||||
|
--
|
||||||
|
-- This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||||
|
--
|
||||||
|
-- Copyright (C) 2025 New Vector, Ltd
|
||||||
|
--
|
||||||
|
-- This program is free software: you can redistribute it and/or modify
|
||||||
|
-- it under the terms of the GNU Affero General Public License as
|
||||||
|
-- published by the Free Software Foundation, either version 3 of the
|
||||||
|
-- License, or (at your option) any later version.
|
||||||
|
--
|
||||||
|
-- See the GNU Affero General Public License for more details:
|
||||||
|
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||||
|
|
||||||
|
COMMENT ON TABLE thread_subscriptions IS 'Tracks local users that subscribe to threads';
|
||||||
|
|
||||||
|
COMMENT ON COLUMN thread_subscriptions.subscribed IS 'Whether the user is subscribed to the thread or not. We track unsubscribed threads because we need to stream the subscription change to the client.';
|
||||||
|
|
||||||
|
COMMENT ON COLUMN thread_subscriptions.automatic IS 'True if the user was subscribed to the thread automatically by their client, or false if the client manually requested the subscription.';
|
||||||
@@ -0,0 +1,24 @@
|
|||||||
|
--
|
||||||
|
-- This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||||
|
--
|
||||||
|
-- Copyright (C) 2025 New Vector, Ltd
|
||||||
|
--
|
||||||
|
-- This program is free software: you can redistribute it and/or modify
|
||||||
|
-- it under the terms of the GNU Affero General Public License as
|
||||||
|
-- published by the Free Software Foundation, either version 3 of the
|
||||||
|
-- License, or (at your option) any later version.
|
||||||
|
--
|
||||||
|
-- See the GNU Affero General Public License for more details:
|
||||||
|
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||||
|
|
||||||
|
COMMENT ON COLUMN threads.latest_event_id IS
|
||||||
|
'the ID of the event that is latest, ordered by (topological_ordering, stream_ordering)';
|
||||||
|
|
||||||
|
COMMENT ON COLUMN threads.topological_ordering IS
|
||||||
|
$$the topological ordering of the thread''s LATEST event.
|
||||||
|
Used as the primary way of ordering threads by recency in a room.$$;
|
||||||
|
|
||||||
|
COMMENT ON COLUMN threads.stream_ordering IS
|
||||||
|
$$the stream ordering of the thread's LATEST event.
|
||||||
|
Used as a tie-breaker for ordering threads by recency in a room, when the topological order is a tie.
|
||||||
|
Also used for recency ordering in sliding sync.$$;
|
||||||
@@ -184,6 +184,12 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
|
|||||||
|
|
||||||
Note: Only works with Postgres.
|
Note: Only works with Postgres.
|
||||||
|
|
||||||
|
Warning: Streams using this generator start at ID 2, because ID 1 is always assumed
|
||||||
|
to have been 'seen as persisted'.
|
||||||
|
Unclear if this extant behaviour is desirable for some reason.
|
||||||
|
When creating a new sequence for a new stream,
|
||||||
|
it will be necessary to use `START WITH 2`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
db_conn
|
db_conn
|
||||||
db
|
db
|
||||||
@@ -269,6 +275,9 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
|
|||||||
self._known_persisted_positions: List[int] = []
|
self._known_persisted_positions: List[int] = []
|
||||||
|
|
||||||
# The maximum stream ID that we have seen been allocated across any writer.
|
# The maximum stream ID that we have seen been allocated across any writer.
|
||||||
|
# Since this defaults to 1, this means that ID 1 is assumed to have already
|
||||||
|
# been 'seen'. In other words, multi-writer streams start at 2.
|
||||||
|
# Unclear if this is desirable behaviour.
|
||||||
self._max_seen_allocated_stream_id = 1
|
self._max_seen_allocated_stream_id = 1
|
||||||
|
|
||||||
# The maximum position of the local instance. This can be higher than
|
# The maximum position of the local instance. This can be higher than
|
||||||
|
|||||||
@@ -362,7 +362,8 @@ class RoomID(DomainSpecificString):
|
|||||||
|
|
||||||
@attr.s(slots=True, frozen=True, repr=False)
|
@attr.s(slots=True, frozen=True, repr=False)
|
||||||
class EventID(DomainSpecificString):
|
class EventID(DomainSpecificString):
|
||||||
"""Structure representing an event id."""
|
"""Structure representing an event ID which is namespaced to a homeserver.
|
||||||
|
Room versions 3 and above are not supported by this grammar."""
|
||||||
|
|
||||||
SIGIL = "$"
|
SIGIL = "$"
|
||||||
|
|
||||||
|
|||||||
157
tests/replication/tcp/streams/test_thread_subscriptions.py
Normal file
157
tests/replication/tcp/streams/test_thread_subscriptions.py
Normal file
@@ -0,0 +1,157 @@
|
|||||||
|
#
|
||||||
|
# This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||||
|
#
|
||||||
|
# Copyright (C) 2025 New Vector, Ltd
|
||||||
|
#
|
||||||
|
# This program is free software: you can redistribute it and/or modify
|
||||||
|
# it under the terms of the GNU Affero General Public License as
|
||||||
|
# published by the Free Software Foundation, either version 3 of the
|
||||||
|
# License, or (at your option) any later version.
|
||||||
|
#
|
||||||
|
# See the GNU Affero General Public License for more details:
|
||||||
|
# <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||||
|
#
|
||||||
|
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
|
from synapse.replication.tcp.streams._base import (
|
||||||
|
_STREAM_UPDATE_TARGET_ROW_COUNT,
|
||||||
|
ThreadSubscriptionsStream,
|
||||||
|
)
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
from synapse.storage.database import LoggingTransaction
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
|
from tests.replication._base import BaseStreamTestCase
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadSubscriptionsStreamTestCase(BaseStreamTestCase):
|
||||||
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
|
super().prepare(reactor, clock, hs)
|
||||||
|
|
||||||
|
# Postgres
|
||||||
|
def f(txn: LoggingTransaction) -> None:
|
||||||
|
txn.execute(
|
||||||
|
"""
|
||||||
|
ALTER TABLE thread_subscriptions
|
||||||
|
DROP CONSTRAINT thread_subscriptions_fk_users,
|
||||||
|
DROP CONSTRAINT thread_subscriptions_fk_rooms,
|
||||||
|
DROP CONSTRAINT thread_subscriptions_fk_events;
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.get_success(
|
||||||
|
self.hs.get_datastores().main.db_pool.runInteraction(
|
||||||
|
"disable_foreign_keys", f
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_thread_subscription_updates(self) -> None:
|
||||||
|
"""Test replication with thread subscription updates"""
|
||||||
|
store = self.hs.get_datastores().main
|
||||||
|
|
||||||
|
# Create thread subscription updates
|
||||||
|
updates = []
|
||||||
|
room_id = "!test_room:example.com"
|
||||||
|
|
||||||
|
# Generate several thread subscription updates
|
||||||
|
for i in range(_STREAM_UPDATE_TARGET_ROW_COUNT + 5):
|
||||||
|
thread_root_id = f"$thread_{i}:example.com"
|
||||||
|
self.get_success(
|
||||||
|
store.subscribe_user_to_thread(
|
||||||
|
"@test_user:example.org",
|
||||||
|
room_id,
|
||||||
|
thread_root_id,
|
||||||
|
automatic=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
updates.append(thread_root_id)
|
||||||
|
|
||||||
|
# Also add one in a different room
|
||||||
|
other_room_id = "!other_room:example.com"
|
||||||
|
other_thread_root_id = "$other_thread:example.com"
|
||||||
|
self.get_success(
|
||||||
|
store.subscribe_user_to_thread(
|
||||||
|
"@test_user:example.org",
|
||||||
|
other_room_id,
|
||||||
|
other_thread_root_id,
|
||||||
|
automatic=False,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Not yet connected: no rows should yet have been received
|
||||||
|
self.assertEqual([], self.test_handler.received_rdata_rows)
|
||||||
|
|
||||||
|
# Now reconnect to pull the updates
|
||||||
|
self.reconnect()
|
||||||
|
self.replicate()
|
||||||
|
|
||||||
|
# We should have received all the expected rows in the right order
|
||||||
|
# Filter the updates to only include thread subscription changes
|
||||||
|
received_rows = [
|
||||||
|
upd
|
||||||
|
for upd in self.test_handler.received_rdata_rows
|
||||||
|
if upd[0] == ThreadSubscriptionsStream.NAME
|
||||||
|
]
|
||||||
|
|
||||||
|
# Verify all the thread subscription updates
|
||||||
|
for thread_id in updates:
|
||||||
|
(stream_name, token, row) = received_rows.pop(0)
|
||||||
|
self.assertEqual(stream_name, ThreadSubscriptionsStream.NAME)
|
||||||
|
self.assertIsInstance(row, ThreadSubscriptionsStream.ROW_TYPE)
|
||||||
|
self.assertEqual(row.user_id, "@test_user:example.org")
|
||||||
|
self.assertEqual(row.room_id, room_id)
|
||||||
|
self.assertEqual(row.event_id, thread_id)
|
||||||
|
|
||||||
|
# Verify the last update in the different room
|
||||||
|
(stream_name, token, row) = received_rows.pop(0)
|
||||||
|
self.assertEqual(stream_name, ThreadSubscriptionsStream.NAME)
|
||||||
|
self.assertIsInstance(row, ThreadSubscriptionsStream.ROW_TYPE)
|
||||||
|
self.assertEqual(row.user_id, "@test_user:example.org")
|
||||||
|
self.assertEqual(row.room_id, other_room_id)
|
||||||
|
self.assertEqual(row.event_id, other_thread_root_id)
|
||||||
|
|
||||||
|
self.assertEqual([], received_rows)
|
||||||
|
|
||||||
|
def test_multiple_users_thread_subscription_updates(self) -> None:
|
||||||
|
"""Test replication with thread subscription updates for multiple users"""
|
||||||
|
store = self.hs.get_datastores().main
|
||||||
|
room_id = "!test_room:example.com"
|
||||||
|
thread_root_id = "$thread_root:example.com"
|
||||||
|
|
||||||
|
# Create updates for multiple users
|
||||||
|
users = ["@user1:example.com", "@user2:example.com", "@user3:example.com"]
|
||||||
|
for user_id in users:
|
||||||
|
self.get_success(
|
||||||
|
store.subscribe_user_to_thread(
|
||||||
|
user_id, room_id, thread_root_id, automatic=True
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check no rows have been received yet
|
||||||
|
self.replicate()
|
||||||
|
self.assertEqual([], self.test_handler.received_rdata_rows)
|
||||||
|
|
||||||
|
# Not yet connected: no rows should yet have been received
|
||||||
|
self.reconnect()
|
||||||
|
self.replicate()
|
||||||
|
|
||||||
|
# We should have received all the expected rows
|
||||||
|
# Filter the updates to only include thread subscription changes
|
||||||
|
received_rows = [
|
||||||
|
upd
|
||||||
|
for upd in self.test_handler.received_rdata_rows
|
||||||
|
if upd[0] == ThreadSubscriptionsStream.NAME
|
||||||
|
]
|
||||||
|
|
||||||
|
# Should have one update per user
|
||||||
|
self.assertEqual(len(received_rows), len(users))
|
||||||
|
|
||||||
|
# Verify all updates
|
||||||
|
for i, user_id in enumerate(users):
|
||||||
|
(stream_name, token, row) = received_rows[i]
|
||||||
|
self.assertEqual(stream_name, ThreadSubscriptionsStream.NAME)
|
||||||
|
self.assertIsInstance(row, ThreadSubscriptionsStream.ROW_TYPE)
|
||||||
|
self.assertEqual(row.user_id, user_id)
|
||||||
|
self.assertEqual(row.room_id, room_id)
|
||||||
|
self.assertEqual(row.event_id, thread_root_id)
|
||||||
256
tests/rest/client/test_thread_subscriptions.py
Normal file
256
tests/rest/client/test_thread_subscriptions.py
Normal file
@@ -0,0 +1,256 @@
|
|||||||
|
#
|
||||||
|
# This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||||
|
#
|
||||||
|
# Copyright (C) 2025 New Vector, Ltd
|
||||||
|
#
|
||||||
|
# This program is free software: you can redistribute it and/or modify
|
||||||
|
# it under the terms of the GNU Affero General Public License as
|
||||||
|
# published by the Free Software Foundation, either version 3 of the
|
||||||
|
# License, or (at your option) any later version.
|
||||||
|
#
|
||||||
|
# See the GNU Affero General Public License for more details:
|
||||||
|
# <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||||
|
|
||||||
|
from http import HTTPStatus
|
||||||
|
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
|
from synapse.rest import admin
|
||||||
|
from synapse.rest.client import login, profile, room, thread_subscriptions
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
from synapse.types import JsonDict
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
|
from tests import unittest
|
||||||
|
|
||||||
|
PREFIX = "/_matrix/client/unstable/io.element.msc4306/rooms"
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
|
||||||
|
servlets = [
|
||||||
|
admin.register_servlets_for_client_rest_resource,
|
||||||
|
login.register_servlets,
|
||||||
|
profile.register_servlets,
|
||||||
|
room.register_servlets,
|
||||||
|
thread_subscriptions.register_servlets,
|
||||||
|
]
|
||||||
|
|
||||||
|
def default_config(self) -> JsonDict:
|
||||||
|
config = super().default_config()
|
||||||
|
config["experimental_features"] = {"msc4306_enabled": True}
|
||||||
|
return config
|
||||||
|
|
||||||
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
|
self.user_id = self.register_user("user", "password")
|
||||||
|
self.token = self.login("user", "password")
|
||||||
|
self.other_user_id = self.register_user("other_user", "password")
|
||||||
|
self.other_token = self.login("other_user", "password")
|
||||||
|
|
||||||
|
# Create a room and send a message to use as a thread root
|
||||||
|
self.room_id = self.helper.create_room_as(self.user_id, tok=self.token)
|
||||||
|
self.helper.join(self.room_id, self.other_user_id, tok=self.other_token)
|
||||||
|
response = self.helper.send(self.room_id, body="Root message", tok=self.token)
|
||||||
|
self.root_event_id = response["event_id"]
|
||||||
|
|
||||||
|
# Send a message in the thread
|
||||||
|
self.helper.send_event(
|
||||||
|
room_id=self.room_id,
|
||||||
|
type="m.room.message",
|
||||||
|
content={
|
||||||
|
"body": "Thread message",
|
||||||
|
"msgtype": "m.text",
|
||||||
|
"m.relates_to": {
|
||||||
|
"rel_type": "m.thread",
|
||||||
|
"event_id": self.root_event_id,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
tok=self.token,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_get_thread_subscription_unsubscribed(self) -> None:
|
||||||
|
"""Test retrieving thread subscription when not subscribed."""
|
||||||
|
channel = self.make_request(
|
||||||
|
"GET",
|
||||||
|
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
|
||||||
|
access_token=self.token,
|
||||||
|
)
|
||||||
|
self.assertEqual(channel.code, HTTPStatus.NOT_FOUND)
|
||||||
|
self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
|
||||||
|
|
||||||
|
def test_get_thread_subscription_nonexistent_thread(self) -> None:
|
||||||
|
"""Test retrieving subscription settings for a nonexistent thread."""
|
||||||
|
channel = self.make_request(
|
||||||
|
"GET",
|
||||||
|
f"{PREFIX}/{self.room_id}/thread/$nonexistent:example.org/subscription",
|
||||||
|
access_token=self.token,
|
||||||
|
)
|
||||||
|
self.assertEqual(channel.code, HTTPStatus.NOT_FOUND)
|
||||||
|
self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
|
||||||
|
|
||||||
|
def test_get_thread_subscription_no_access(self) -> None:
|
||||||
|
"""Test that a user can't get thread subscription for a thread they can't access."""
|
||||||
|
self.register_user("no_access", "password")
|
||||||
|
no_access_token = self.login("no_access", "password")
|
||||||
|
|
||||||
|
channel = self.make_request(
|
||||||
|
"GET",
|
||||||
|
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
|
||||||
|
access_token=no_access_token,
|
||||||
|
)
|
||||||
|
self.assertEqual(channel.code, HTTPStatus.NOT_FOUND)
|
||||||
|
self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
|
||||||
|
|
||||||
|
def test_subscribe_manual_then_automatic(self) -> None:
|
||||||
|
"""Test subscribing to a thread, first a manual subscription then an automatic subscription.
|
||||||
|
The manual subscription wins over the automatic one."""
|
||||||
|
channel = self.make_request(
|
||||||
|
"PUT",
|
||||||
|
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
|
||||||
|
{
|
||||||
|
"automatic": False,
|
||||||
|
},
|
||||||
|
access_token=self.token,
|
||||||
|
)
|
||||||
|
self.assertEqual(channel.code, HTTPStatus.OK)
|
||||||
|
|
||||||
|
# Assert the subscription was saved
|
||||||
|
channel = self.make_request(
|
||||||
|
"GET",
|
||||||
|
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
|
||||||
|
access_token=self.token,
|
||||||
|
)
|
||||||
|
self.assertEqual(channel.code, HTTPStatus.OK)
|
||||||
|
self.assertEqual(channel.json_body, {"automatic": False})
|
||||||
|
|
||||||
|
# Now also register an automatic subscription; it should not
|
||||||
|
# override the manual subscription
|
||||||
|
channel = self.make_request(
|
||||||
|
"PUT",
|
||||||
|
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
|
||||||
|
{"automatic": True},
|
||||||
|
access_token=self.token,
|
||||||
|
)
|
||||||
|
self.assertEqual(channel.code, HTTPStatus.OK)
|
||||||
|
|
||||||
|
# Assert the manual subscription was not overridden
|
||||||
|
channel = self.make_request(
|
||||||
|
"GET",
|
||||||
|
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
|
||||||
|
access_token=self.token,
|
||||||
|
)
|
||||||
|
self.assertEqual(channel.code, HTTPStatus.OK)
|
||||||
|
self.assertEqual(channel.json_body, {"automatic": False})
|
||||||
|
|
||||||
|
def test_subscribe_automatic_then_manual(self) -> None:
|
||||||
|
"""Test subscribing to a thread, first an automatic subscription then a manual subscription.
|
||||||
|
The manual subscription wins over the automatic one."""
|
||||||
|
channel = self.make_request(
|
||||||
|
"PUT",
|
||||||
|
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
|
||||||
|
{
|
||||||
|
"automatic": True,
|
||||||
|
},
|
||||||
|
access_token=self.token,
|
||||||
|
)
|
||||||
|
self.assertEqual(channel.code, HTTPStatus.OK)
|
||||||
|
|
||||||
|
# Assert the subscription was saved
|
||||||
|
channel = self.make_request(
|
||||||
|
"GET",
|
||||||
|
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
|
||||||
|
access_token=self.token,
|
||||||
|
)
|
||||||
|
self.assertEqual(channel.code, HTTPStatus.OK)
|
||||||
|
self.assertEqual(channel.json_body, {"automatic": True})
|
||||||
|
|
||||||
|
# Now also register a manual subscription
|
||||||
|
channel = self.make_request(
|
||||||
|
"PUT",
|
||||||
|
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
|
||||||
|
{"automatic": False},
|
||||||
|
access_token=self.token,
|
||||||
|
)
|
||||||
|
self.assertEqual(channel.code, HTTPStatus.OK)
|
||||||
|
|
||||||
|
# Assert the manual subscription was not overridden
|
||||||
|
channel = self.make_request(
|
||||||
|
"GET",
|
||||||
|
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
|
||||||
|
access_token=self.token,
|
||||||
|
)
|
||||||
|
self.assertEqual(channel.code, HTTPStatus.OK)
|
||||||
|
self.assertEqual(channel.json_body, {"automatic": False})
|
||||||
|
|
||||||
|
def test_unsubscribe(self) -> None:
|
||||||
|
"""Test subscribing to a thread, then unsubscribing."""
|
||||||
|
channel = self.make_request(
|
||||||
|
"PUT",
|
||||||
|
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
|
||||||
|
{
|
||||||
|
"automatic": True,
|
||||||
|
},
|
||||||
|
access_token=self.token,
|
||||||
|
)
|
||||||
|
self.assertEqual(channel.code, HTTPStatus.OK)
|
||||||
|
|
||||||
|
# Assert the subscription was saved
|
||||||
|
channel = self.make_request(
|
||||||
|
"GET",
|
||||||
|
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
|
||||||
|
access_token=self.token,
|
||||||
|
)
|
||||||
|
self.assertEqual(channel.code, HTTPStatus.OK)
|
||||||
|
self.assertEqual(channel.json_body, {"automatic": True})
|
||||||
|
|
||||||
|
# Now also register a manual subscription
|
||||||
|
channel = self.make_request(
|
||||||
|
"DELETE",
|
||||||
|
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
|
||||||
|
access_token=self.token,
|
||||||
|
)
|
||||||
|
self.assertEqual(channel.code, HTTPStatus.OK)
|
||||||
|
|
||||||
|
# Assert the manual subscription was not overridden
|
||||||
|
channel = self.make_request(
|
||||||
|
"GET",
|
||||||
|
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
|
||||||
|
access_token=self.token,
|
||||||
|
)
|
||||||
|
self.assertEqual(channel.code, HTTPStatus.NOT_FOUND)
|
||||||
|
self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
|
||||||
|
|
||||||
|
def test_set_thread_subscription_nonexistent_thread(self) -> None:
|
||||||
|
"""Test setting subscription settings for a nonexistent thread."""
|
||||||
|
channel = self.make_request(
|
||||||
|
"PUT",
|
||||||
|
f"{PREFIX}/{self.room_id}/thread/$nonexistent:example.org/subscription",
|
||||||
|
{"automatic": True},
|
||||||
|
access_token=self.token,
|
||||||
|
)
|
||||||
|
self.assertEqual(channel.code, HTTPStatus.NOT_FOUND)
|
||||||
|
self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
|
||||||
|
|
||||||
|
def test_set_thread_subscription_no_access(self) -> None:
|
||||||
|
"""Test that a user can't set thread subscription for a thread they can't access."""
|
||||||
|
self.register_user("no_access2", "password")
|
||||||
|
no_access_token = self.login("no_access2", "password")
|
||||||
|
|
||||||
|
channel = self.make_request(
|
||||||
|
"PUT",
|
||||||
|
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
|
||||||
|
{"automatic": True},
|
||||||
|
access_token=no_access_token,
|
||||||
|
)
|
||||||
|
self.assertEqual(channel.code, HTTPStatus.NOT_FOUND)
|
||||||
|
self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
|
||||||
|
|
||||||
|
def test_invalid_body(self) -> None:
|
||||||
|
"""Test that sending invalid subscription settings is rejected."""
|
||||||
|
channel = self.make_request(
|
||||||
|
"PUT",
|
||||||
|
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
|
||||||
|
# non-boolean `automatic`
|
||||||
|
{"automatic": "true"},
|
||||||
|
access_token=self.token,
|
||||||
|
)
|
||||||
|
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST)
|
||||||
272
tests/storage/test_thread_subscriptions.py
Normal file
272
tests/storage/test_thread_subscriptions.py
Normal file
@@ -0,0 +1,272 @@
|
|||||||
|
#
|
||||||
|
# This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||||
|
#
|
||||||
|
# Copyright (C) 2025 New Vector, Ltd
|
||||||
|
#
|
||||||
|
# This program is free software: you can redistribute it and/or modify
|
||||||
|
# it under the terms of the GNU Affero General Public License as
|
||||||
|
# published by the Free Software Foundation, either version 3 of the
|
||||||
|
# License, or (at your option) any later version.
|
||||||
|
#
|
||||||
|
# See the GNU Affero General Public License for more details:
|
||||||
|
# <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||||
|
#
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
from synapse.storage.database import LoggingTransaction
|
||||||
|
from synapse.storage.engines.sqlite import Sqlite3Engine
|
||||||
|
from synapse.util import Clock
|
||||||
|
|
||||||
|
from tests import unittest
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
|
||||||
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
|
self.store = self.hs.get_datastores().main
|
||||||
|
self.user_id = "@user:test"
|
||||||
|
self.room_id = "!room:test"
|
||||||
|
self.thread_root_id = "$thread_root:test"
|
||||||
|
self.other_thread_root_id = "$other_thread_root:test"
|
||||||
|
|
||||||
|
# Disable foreign key checks for testing
|
||||||
|
# This allows us to insert test data without having to create actual events
|
||||||
|
db_pool = self.store.db_pool
|
||||||
|
if isinstance(db_pool.engine, Sqlite3Engine):
|
||||||
|
self.get_success(
|
||||||
|
db_pool.execute("disable_foreign_keys", "PRAGMA foreign_keys = OFF;")
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Postgres
|
||||||
|
def f(txn: LoggingTransaction) -> None:
|
||||||
|
txn.execute(
|
||||||
|
"""
|
||||||
|
ALTER TABLE thread_subscriptions
|
||||||
|
DROP CONSTRAINT thread_subscriptions_fk_users,
|
||||||
|
DROP CONSTRAINT thread_subscriptions_fk_rooms,
|
||||||
|
DROP CONSTRAINT thread_subscriptions_fk_events;
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
self.get_success(db_pool.runInteraction("disable_foreign_keys", f))
|
||||||
|
|
||||||
|
# Create rooms and events in the db to satisfy foreign key constraints
|
||||||
|
self.get_success(db_pool.simple_insert("rooms", {"room_id": self.room_id}))
|
||||||
|
|
||||||
|
self.get_success(
|
||||||
|
db_pool.simple_insert(
|
||||||
|
"events",
|
||||||
|
{
|
||||||
|
"event_id": self.thread_root_id,
|
||||||
|
"room_id": self.room_id,
|
||||||
|
"topological_ordering": 1,
|
||||||
|
"stream_ordering": 1,
|
||||||
|
"type": "m.room.message",
|
||||||
|
"depth": 1,
|
||||||
|
"processed": True,
|
||||||
|
"outlier": False,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.get_success(
|
||||||
|
db_pool.simple_insert(
|
||||||
|
"events",
|
||||||
|
{
|
||||||
|
"event_id": self.other_thread_root_id,
|
||||||
|
"room_id": self.room_id,
|
||||||
|
"topological_ordering": 2,
|
||||||
|
"stream_ordering": 2,
|
||||||
|
"type": "m.room.message",
|
||||||
|
"depth": 2,
|
||||||
|
"processed": True,
|
||||||
|
"outlier": False,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create the user
|
||||||
|
self.get_success(
|
||||||
|
db_pool.simple_insert("users", {"name": self.user_id, "is_guest": 0})
|
||||||
|
)
|
||||||
|
|
||||||
|
def _subscribe(
|
||||||
|
self,
|
||||||
|
thread_root_id: str,
|
||||||
|
*,
|
||||||
|
automatic: bool,
|
||||||
|
room_id: Optional[str] = None,
|
||||||
|
user_id: Optional[str] = None,
|
||||||
|
) -> Optional[int]:
|
||||||
|
if user_id is None:
|
||||||
|
user_id = self.user_id
|
||||||
|
|
||||||
|
if room_id is None:
|
||||||
|
room_id = self.room_id
|
||||||
|
|
||||||
|
return self.get_success(
|
||||||
|
self.store.subscribe_user_to_thread(
|
||||||
|
user_id,
|
||||||
|
room_id,
|
||||||
|
thread_root_id,
|
||||||
|
automatic=automatic,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def _unsubscribe(
|
||||||
|
self,
|
||||||
|
thread_root_id: str,
|
||||||
|
room_id: Optional[str] = None,
|
||||||
|
user_id: Optional[str] = None,
|
||||||
|
) -> Optional[int]:
|
||||||
|
if user_id is None:
|
||||||
|
user_id = self.user_id
|
||||||
|
|
||||||
|
if room_id is None:
|
||||||
|
room_id = self.room_id
|
||||||
|
|
||||||
|
return self.get_success(
|
||||||
|
self.store.unsubscribe_user_from_thread(
|
||||||
|
user_id,
|
||||||
|
room_id,
|
||||||
|
thread_root_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_set_and_get_thread_subscription(self) -> None:
|
||||||
|
"""Test basic setting and getting of thread subscriptions."""
|
||||||
|
# Initial state: no subscription
|
||||||
|
subscription = self.get_success(
|
||||||
|
self.store.get_subscription_for_thread(
|
||||||
|
self.user_id, self.room_id, self.thread_root_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertIsNone(subscription)
|
||||||
|
|
||||||
|
# Subscribe
|
||||||
|
self._subscribe(
|
||||||
|
self.thread_root_id,
|
||||||
|
automatic=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert subscription went through
|
||||||
|
subscription = self.get_success(
|
||||||
|
self.store.get_subscription_for_thread(
|
||||||
|
self.user_id, self.room_id, self.thread_root_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertIsNotNone(subscription)
|
||||||
|
self.assertTrue(subscription.automatic) # type: ignore
|
||||||
|
|
||||||
|
# Now make it a manual subscription
|
||||||
|
self._subscribe(
|
||||||
|
self.thread_root_id,
|
||||||
|
automatic=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert the manual subscription overrode the automatic one
|
||||||
|
subscription = self.get_success(
|
||||||
|
self.store.get_subscription_for_thread(
|
||||||
|
self.user_id, self.room_id, self.thread_root_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertFalse(subscription.automatic) # type: ignore
|
||||||
|
|
||||||
|
def test_purge_thread_subscriptions_for_user(self) -> None:
|
||||||
|
"""Test purging all thread subscription settings for a user."""
|
||||||
|
# Set subscription settings for multiple threads
|
||||||
|
self._subscribe(self.thread_root_id, automatic=True)
|
||||||
|
self._subscribe(self.other_thread_root_id, automatic=False)
|
||||||
|
|
||||||
|
subscriptions = self.get_success(
|
||||||
|
self.store.get_updated_thread_subscriptions_for_user(
|
||||||
|
self.user_id,
|
||||||
|
from_id=0,
|
||||||
|
to_id=50,
|
||||||
|
limit=50,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
min_id = min(id for (id, _, _) in subscriptions)
|
||||||
|
self.assertEqual(
|
||||||
|
subscriptions,
|
||||||
|
[
|
||||||
|
(min_id, self.room_id, self.thread_root_id),
|
||||||
|
(min_id + 1, self.room_id, self.other_thread_root_id),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Purge all settings for the user
|
||||||
|
self.get_success(
|
||||||
|
self.store.purge_thread_subscription_settings_for_user(self.user_id)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check user has no subscriptions
|
||||||
|
subscriptions = self.get_success(
|
||||||
|
self.store.get_updated_thread_subscriptions_for_user(
|
||||||
|
self.user_id,
|
||||||
|
from_id=0,
|
||||||
|
to_id=50,
|
||||||
|
limit=50,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertEqual(subscriptions, [])
|
||||||
|
|
||||||
|
def test_get_updated_thread_subscriptions(self) -> None:
|
||||||
|
"""Test getting updated thread subscriptions since a stream ID."""
|
||||||
|
|
||||||
|
stream_id1 = self._subscribe(self.thread_root_id, automatic=False)
|
||||||
|
stream_id2 = self._subscribe(self.other_thread_root_id, automatic=True)
|
||||||
|
assert stream_id1 is not None
|
||||||
|
assert stream_id2 is not None
|
||||||
|
|
||||||
|
# Get updates since initial ID (should include both changes)
|
||||||
|
updates = self.get_success(
|
||||||
|
self.store.get_updated_thread_subscriptions(0, stream_id2, 10)
|
||||||
|
)
|
||||||
|
self.assertEqual(len(updates), 2)
|
||||||
|
|
||||||
|
# Get updates since first change (should include only the second change)
|
||||||
|
updates = self.get_success(
|
||||||
|
self.store.get_updated_thread_subscriptions(stream_id1, stream_id2, 10)
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
updates,
|
||||||
|
[(stream_id2, self.user_id, self.room_id, self.other_thread_root_id)],
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_get_updated_thread_subscriptions_for_user(self) -> None:
|
||||||
|
"""Test getting updated thread subscriptions for a specific user."""
|
||||||
|
other_user_id = "@other_user:test"
|
||||||
|
|
||||||
|
# Set thread subscription for main user
|
||||||
|
stream_id1 = self._subscribe(self.thread_root_id, automatic=True)
|
||||||
|
assert stream_id1 is not None
|
||||||
|
|
||||||
|
# Set thread subscription for other user
|
||||||
|
stream_id2 = self._subscribe(
|
||||||
|
self.other_thread_root_id,
|
||||||
|
automatic=True,
|
||||||
|
user_id=other_user_id,
|
||||||
|
)
|
||||||
|
assert stream_id2 is not None
|
||||||
|
|
||||||
|
# Get updates for main user
|
||||||
|
updates = self.get_success(
|
||||||
|
self.store.get_updated_thread_subscriptions_for_user(
|
||||||
|
self.user_id, 0, stream_id2, 10
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertEqual(updates, [(stream_id1, self.room_id, self.thread_root_id)])
|
||||||
|
|
||||||
|
# Get updates for other user
|
||||||
|
updates = self.get_success(
|
||||||
|
self.store.get_updated_thread_subscriptions_for_user(
|
||||||
|
other_user_id, 0, max(stream_id1, stream_id2), 10
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
updates, [(stream_id2, self.room_id, self.other_thread_root_id)]
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user