1
0

Compare commits

...

22 Commits

Author SHA1 Message Date
Olivier 'reivilibre
e3fa8416cb Merge branch 'rei/msc4306_postcontent' into rei/threadsubs_all 2025-09-03 14:57:30 +01:00
Olivier 'reivilibre
fa8e3b62c7 Simplify if 2025-09-03 13:53:02 +01:00
Olivier 'reivilibre
924c1bfd0e Use copy_and_replace in get_current_token_for_pagination 2025-09-03 13:48:46 +01:00
Olivier 'reivilibre
0cf178a601 Add notifier hooks for sliding sync 2025-09-02 18:30:46 +01:00
reivilibre
1374895a73 Update synapse/handlers/sliding_sync/extensions.py
Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com>
2025-09-02 10:30:47 +01:00
reivilibre
168b67b43a Update synapse/handlers/sliding_sync/extensions.py
Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com>
2025-09-02 10:30:38 +01:00
reivilibre
921cd53a69 Update tests/rest/client/sliding_sync/test_extension_thread_subscriptions.py
Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com>
2025-09-02 10:23:45 +01:00
Olivier 'reivilibre
f4cd180c6c Newsfile
Signed-off-by: Olivier 'reivilibre <oliverw@matrix.org>
2025-08-21 17:26:51 +01:00
Olivier 'reivilibre
e72d6cdb5f Add companion endpoint for backpagination of thread subscriptions 2025-08-21 17:26:51 +01:00
Olivier 'reivilibre
4a34641af8 Implement sliding sync extension part of MSC4308
Put MSC4308 behind the MSC4306 feature flag
2025-08-21 17:26:51 +01:00
Olivier 'reivilibre
18881b11f2 Add overload for parse_integer_from_args 2025-08-21 12:39:43 +01:00
Olivier 'reivilibre
0c310b9ef7 Add comment to MultiWriterIdGenerator about cursed sequence semantics 2025-08-21 12:39:43 +01:00
Olivier 'reivilibre
0ce5dce42b Fix thread_subscriptions stream sequence
Works around https://github.com/element-hq/synapse/issues/18712
2025-08-21 12:39:43 +01:00
Olivier 'reivilibre
4dcd12b8d1 Add subscribed and automatic to get_updated_thread_subscriptions_for_user 2025-08-21 12:39:42 +01:00
Olivier 'reivilibre
748316c14a Add thread subscriptions position to StreamToken 2025-08-21 12:39:30 +01:00
Olivier 'reivilibre
f1f56570d1 Add overload for gather_optional_coroutines/6 2025-08-21 12:39:30 +01:00
Olivier 'reivilibre
09f86339c1 Add models for Thread Subscriptions extension to Sliding Sync 2025-08-21 12:39:30 +01:00
Olivier 'reivilibre
875dbf70c8 spelling 2025-08-21 12:39:30 +01:00
Olivier 'reivilibre
2961006785 Update simplified sliding sync docstring 2025-08-21 12:39:30 +01:00
Olivier 'reivilibre
3de8c2146d Newsfile
Signed-off-by: Olivier 'reivilibre <oliverw@matrix.org>
2025-08-20 16:28:21 +01:00
Olivier 'reivilibre
c7c398eef8 Prevent users from creating user-defined postcontent rules 2025-08-20 16:28:21 +01:00
Olivier 'reivilibre
947d9127d8 Move the MSC4306 push rules to a new kind postcontent 2025-08-20 16:28:21 +01:00
32 changed files with 1067 additions and 69 deletions

View File

@@ -0,0 +1 @@
Add experimental support for [MSC4308: Thread Subscriptions extension to Sliding Sync](https://github.com/matrix-org/matrix-spec-proposals/pull/4308) when [MSC4306: Thread Subscriptions](https://github.com/matrix-org/matrix-spec-proposals/pull/4306) and [MSC4186: Simplified Sliding Sync](https://github.com/matrix-org/matrix-spec-proposals/pull/4186) are enabled.

View File

@@ -0,0 +1 @@
Update push rules for experimental [MSC4306: Thread Subscriptions](https://github.com/matrix-org/matrix-doc/issues/4306) to follow newer draft.

View File

@@ -289,10 +289,10 @@ pub const BASE_APPEND_CONTENT_RULES: &[PushRule] = &[PushRule {
default_enabled: true,
}];
pub const BASE_APPEND_UNDERRIDE_RULES: &[PushRule] = &[
pub const BASE_APPEND_POSTCONTENT_RULES: &[PushRule] = &[
PushRule {
rule_id: Cow::Borrowed("global/content/.io.element.msc4306.rule.unsubscribed_thread"),
priority_class: 1,
rule_id: Cow::Borrowed("global/postcontent/.io.element.msc4306.rule.unsubscribed_thread"),
priority_class: 6,
conditions: Cow::Borrowed(&[Condition::Known(
KnownCondition::Msc4306ThreadSubscription { subscribed: false },
)]),
@@ -301,8 +301,8 @@ pub const BASE_APPEND_UNDERRIDE_RULES: &[PushRule] = &[
default_enabled: true,
},
PushRule {
rule_id: Cow::Borrowed("global/content/.io.element.msc4306.rule.subscribed_thread"),
priority_class: 1,
rule_id: Cow::Borrowed("global/postcontent/.io.element.msc4306.rule.subscribed_thread"),
priority_class: 6,
conditions: Cow::Borrowed(&[Condition::Known(
KnownCondition::Msc4306ThreadSubscription { subscribed: true },
)]),
@@ -310,6 +310,9 @@ pub const BASE_APPEND_UNDERRIDE_RULES: &[PushRule] = &[
default: true,
default_enabled: true,
},
];
pub const BASE_APPEND_UNDERRIDE_RULES: &[PushRule] = &[
PushRule {
rule_id: Cow::Borrowed("global/underride/.m.rule.call"),
priority_class: 1,
@@ -726,6 +729,7 @@ lazy_static! {
.iter()
.chain(BASE_APPEND_OVERRIDE_RULES.iter())
.chain(BASE_APPEND_CONTENT_RULES.iter())
.chain(BASE_APPEND_POSTCONTENT_RULES.iter())
.chain(BASE_APPEND_UNDERRIDE_RULES.iter())
.map(|rule| { (&*rule.rule_id, rule) })
.collect();

View File

@@ -527,6 +527,7 @@ impl PushRules {
.chain(base_rules::BASE_APPEND_OVERRIDE_RULES.iter())
.chain(self.content.iter())
.chain(base_rules::BASE_APPEND_CONTENT_RULES.iter())
.chain(base_rules::BASE_APPEND_POSTCONTENT_RULES.iter())
.chain(self.room.iter())
.chain(self.sender.iter())
.chain(self.underride.iter())

View File

@@ -590,5 +590,5 @@ class ExperimentalConfig(Config):
self.msc4293_enabled: bool = experimental.get("msc4293_enabled", False)
# MSC4306: Thread Subscriptions
# (and MSC4308: sliding sync extension for thread subscriptions)
# (and MSC4308: Thread Subscriptions extension to Sliding Sync)
self.msc4306_enabled: bool = experimental.get("msc4306_enabled", False)

View File

@@ -135,7 +135,7 @@ class PublicRoomList(BaseFederationServlet):
if not self.allow_access:
raise FederationDeniedError(origin)
limit = parse_integer_from_args(query, "limit", 0)
limit: Optional[int] = parse_integer_from_args(query, "limit", 0)
since_token = parse_string_from_args(query, "since", None)
include_all_networks = parse_boolean_from_args(
query, "include_all_networks", default=False

View File

@@ -211,7 +211,7 @@ class SlidingSyncHandler:
Args:
sync_config: Sync configuration
to_token: The point in the stream to sync up to.
to_token: The latest point in the stream to sync up to.
from_token: The point in the stream to sync from. Token of the end of the
previous batch. May be `None` if this is the initial sync request.
"""

View File

@@ -27,7 +27,7 @@ from typing import (
cast,
)
from typing_extensions import assert_never
from typing_extensions import TypeAlias, assert_never
from synapse.api.constants import AccountDataTypes, EduTypes
from synapse.handlers.receipts import ReceiptEventSource
@@ -40,6 +40,7 @@ from synapse.types import (
SlidingSyncStreamToken,
StrCollection,
StreamToken,
ThreadSubscriptionsToken,
)
from synapse.types.handlers.sliding_sync import (
HaveSentRoomFlag,
@@ -54,6 +55,13 @@ from synapse.util.async_helpers import (
gather_optional_coroutines,
)
_ThreadSubscription: TypeAlias = (
SlidingSyncResult.Extensions.ThreadSubscriptionsExtension.ThreadSubscription
)
_ThreadUnsubscription: TypeAlias = (
SlidingSyncResult.Extensions.ThreadSubscriptionsExtension.ThreadUnsubscription
)
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -68,6 +76,7 @@ class SlidingSyncExtensionHandler:
self.event_sources = hs.get_event_sources()
self.device_handler = hs.get_device_handler()
self.push_rules_handler = hs.get_push_rules_handler()
self._enable_thread_subscriptions = hs.config.experimental.msc4306_enabled
@trace
async def get_extensions_response(
@@ -93,7 +102,7 @@ class SlidingSyncExtensionHandler:
actual_room_ids: The actual room IDs in the Sliding Sync response.
actual_room_response_map: A map of room ID to room results in the
Sliding Sync response.
to_token: The point in the stream to sync up to.
to_token: The latest point in the stream to sync up to.
from_token: The point in the stream to sync from.
"""
@@ -156,18 +165,32 @@ class SlidingSyncExtensionHandler:
from_token=from_token,
)
thread_subs_coro = None
if (
sync_config.extensions.thread_subscriptions is not None
and self._enable_thread_subscriptions
):
thread_subs_coro = self.get_thread_subscriptions_extension_response(
sync_config=sync_config,
thread_subscriptions_request=sync_config.extensions.thread_subscriptions,
to_token=to_token,
from_token=from_token,
)
(
to_device_response,
e2ee_response,
account_data_response,
receipts_response,
typing_response,
thread_subs_response,
) = await gather_optional_coroutines(
to_device_coro,
e2ee_coro,
account_data_coro,
receipts_coro,
typing_coro,
thread_subs_coro,
)
return SlidingSyncResult.Extensions(
@@ -176,6 +199,7 @@ class SlidingSyncExtensionHandler:
account_data=account_data_response,
receipts=receipts_response,
typing=typing_response,
thread_subscriptions=thread_subs_response,
)
def find_relevant_room_ids_for_extension(
@@ -877,3 +901,72 @@ class SlidingSyncExtensionHandler:
return SlidingSyncResult.Extensions.TypingExtension(
room_id_to_typing_map=room_id_to_typing_map,
)
async def get_thread_subscriptions_extension_response(
self,
sync_config: SlidingSyncConfig,
thread_subscriptions_request: SlidingSyncConfig.Extensions.ThreadSubscriptionsExtension,
to_token: StreamToken,
from_token: Optional[SlidingSyncStreamToken],
) -> Optional[SlidingSyncResult.Extensions.ThreadSubscriptionsExtension]:
"""Handle Thread Subscriptions extension (MSC4308)
Args:
sync_config: Sync configuration
thread_subscriptions_request: The thread_subscriptions extension from the request
to_token: The point in the stream to sync up to.
from_token: The point in the stream to sync from.
Returns:
the response (None if empty or thread subscriptions are disabled)
"""
if not thread_subscriptions_request.enabled:
return None
limit = thread_subscriptions_request.limit
if from_token:
from_stream_id = from_token.stream_token.thread_subscriptions_key
else:
from_stream_id = StreamToken.START.thread_subscriptions_key
to_stream_id = to_token.thread_subscriptions_key
updates = await self.store.get_latest_updated_thread_subscriptions_for_user(
user_id=sync_config.user.to_string(),
from_id=from_stream_id,
to_id=to_stream_id,
limit=limit,
)
if len(updates) == 0:
return None
subscribed_threads: Dict[str, Dict[str, _ThreadSubscription]] = {}
unsubscribed_threads: Dict[str, Dict[str, _ThreadUnsubscription]] = {}
for stream_id, room_id, thread_root_id, subscribed, automatic in updates:
if subscribed:
subscribed_threads.setdefault(room_id, {})[thread_root_id] = (
_ThreadSubscription(
automatic=automatic,
bump_stamp=stream_id,
)
)
else:
unsubscribed_threads.setdefault(room_id, {})[thread_root_id] = (
_ThreadUnsubscription(bump_stamp=stream_id)
)
prev_batch = None
if len(updates) == limit:
# Tell the client about a potential gap where there may be more
# thread subscriptions for it to backpaginate.
# We subtract one because the 'later in the stream' bound is inclusive,
# and we already saw the element at index 0.
prev_batch = ThreadSubscriptionsToken(updates[0][0] - 1)
return SlidingSyncResult.Extensions.ThreadSubscriptionsExtension(
subscribed=subscribed_threads,
unsubscribed=unsubscribed_threads,
prev_batch=prev_batch,
)

View File

@@ -9,7 +9,7 @@ from synapse.storage.databases.main.thread_subscriptions import (
AutomaticSubscriptionConflicted,
ThreadSubscription,
)
from synapse.types import EventOrderings, UserID
from synapse.types import EventOrderings, StreamKeyType, UserID
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -22,6 +22,7 @@ class ThreadSubscriptionsHandler:
self.store = hs.get_datastores().main
self.event_handler = hs.get_event_handler()
self.auth = hs.get_auth()
self._notifier = hs.get_notifier()
async def get_thread_subscription_settings(
self,
@@ -132,6 +133,15 @@ class ThreadSubscriptionsHandler:
errcode=Codes.MSC4306_CONFLICTING_UNSUBSCRIPTION,
)
if outcome is not None:
# wake up user streams (e.g. sliding sync) on the same worker
self._notifier.on_new_event(
StreamKeyType.THREAD_SUBSCRIPTIONS,
# outcome is a stream_id
outcome,
users=[user_id.to_string()],
)
return outcome
async def unsubscribe_user_from_thread(
@@ -162,8 +172,19 @@ class ThreadSubscriptionsHandler:
logger.info("rejecting thread subscriptions change (thread not accessible)")
raise NotFoundError("No such thread root")
return await self.store.unsubscribe_user_from_thread(
outcome = await self.store.unsubscribe_user_from_thread(
user_id.to_string(),
event.room_id,
thread_root_event_id,
)
if outcome is not None:
# wake up user streams (e.g. sliding sync) on the same worker
self._notifier.on_new_event(
StreamKeyType.THREAD_SUBSCRIPTIONS,
# outcome is a stream_id
outcome,
users=[user_id.to_string()],
)
return outcome

View File

@@ -130,6 +130,16 @@ def parse_integer(
return parse_integer_from_args(args, name, default, required, negative)
@overload
def parse_integer_from_args(
args: Mapping[bytes, Sequence[bytes]],
name: str,
default: int,
required: Literal[False] = False,
negative: bool = False,
) -> int: ...
@overload
def parse_integer_from_args(
args: Mapping[bytes, Sequence[bytes]],

View File

@@ -522,6 +522,7 @@ class Notifier:
StreamKeyType.TO_DEVICE,
StreamKeyType.TYPING,
StreamKeyType.UN_PARTIAL_STATED_ROOMS,
StreamKeyType.THREAD_SUBSCRIPTIONS,
],
new_token: int,
users: Optional[Collection[Union[str, UserID]]] = None,

View File

@@ -91,7 +91,7 @@ def _rule_to_template(rule: PushRule) -> Optional[Dict[str, Any]]:
unscoped_rule_id = _rule_id_from_namespaced(rule.rule_id)
template_name = _priority_class_to_template_name(rule.priority_class)
if template_name in ["override", "underride"]:
if template_name in ["override", "underride", "postcontent"]:
templaterule = {"conditions": rule.conditions, "actions": rule.actions}
elif template_name in ["sender", "room"]:
templaterule = {"actions": rule.actions}

View File

@@ -19,10 +19,14 @@
#
#
# Integer literals for push rule `kind`s
# This is used to store them in the database.
PRIORITY_CLASS_MAP = {
"underride": 1,
"sender": 2,
"room": 3,
# MSC4306
"postcontent": 6,
"content": 4,
"override": 5,
}

View File

@@ -44,6 +44,7 @@ from synapse.replication.tcp.streams import (
UnPartialStatedEventStream,
UnPartialStatedRoomStream,
)
from synapse.replication.tcp.streams._base import ThreadSubscriptionsStream
from synapse.replication.tcp.streams.events import (
EventsStream,
EventsStreamEventRow,
@@ -255,6 +256,12 @@ class ReplicationDataHandler:
self._state_storage_controller.notify_event_un_partial_stated(
row.event_id
)
elif stream_name == ThreadSubscriptionsStream.NAME:
self.notifier.on_new_event(
StreamKeyType.THREAD_SUBSCRIPTIONS,
token,
users=[row.user_id for row in rows],
)
await self._presence_handler.process_replication_rows(
stream_name, instance_name, token, rows

View File

@@ -19,9 +19,11 @@
#
#
from http import HTTPStatus
from typing import TYPE_CHECKING, List, Tuple, Union
from synapse.api.errors import (
Codes,
NotFoundError,
StoreError,
SynapseError,
@@ -239,6 +241,15 @@ def _rule_spec_from_path(path: List[str]) -> RuleSpec:
def _rule_tuple_from_request_object(
rule_template: str, rule_id: str, req_obj: JsonDict
) -> Tuple[List[JsonDict], List[Union[str, JsonDict]]]:
if rule_template == "postcontent":
# postcontent is from MSC4306, which says that clients
# cannot create their own postcontent rules right now.
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"user-defined rules using `postcontent` are not accepted",
errcode=Codes.INVALID_PARAM,
)
if rule_template in ["override", "underride"]:
if "conditions" not in req_obj:
raise InvalidRuleException("Missing 'conditions'")

View File

@@ -23,6 +23,8 @@ import logging
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union
import attr
from synapse.api.constants import AccountDataTypes, EduTypes, Membership, PresenceState
from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.api.filtering import FilterCollection
@@ -805,12 +807,21 @@ class SlidingSyncE2eeRestServlet(RestServlet):
class SlidingSyncRestServlet(RestServlet):
"""
API endpoint for MSC3575 Sliding Sync `/sync`. Allows for clients to request a
API endpoint for MSC4186 Simplified Sliding Sync `/sync`, which was historically derived
from MSC3575 (Sliding Sync; now abandoned). Allows for clients to request a
subset (sliding window) of rooms, state, and timeline events (just what they need)
in order to bootstrap quickly and subscribe to only what the client cares about.
Because the client can specify what it cares about, we can respond quickly and skip
all of the work we would normally have to do with a sync v2 response.
Extensions of various features are defined in:
- to-device messaging (MSC3885)
- end-to-end encryption (MSC3884)
- typing notifications (MSC3961)
- receipts (MSC3960)
- account data (MSC3959)
- thread subscriptions (MSC4308)
Request query parameters:
timeout: How long to wait for new events in milliseconds.
pos: Stream position token when asking for incremental deltas.
@@ -1247,9 +1258,48 @@ class SlidingSyncRestServlet(RestServlet):
"rooms": extensions.typing.room_id_to_typing_map,
}
# excludes both None and falsy `thread_subscriptions`
if extensions.thread_subscriptions:
serialized_extensions["io.element.msc4308.thread_subscriptions"] = (
_serialise_thread_subscriptions(extensions.thread_subscriptions)
)
return serialized_extensions
def _serialise_thread_subscriptions(
thread_subscriptions: SlidingSyncResult.Extensions.ThreadSubscriptionsExtension,
) -> JsonDict:
out: JsonDict = {}
if thread_subscriptions.subscribed:
out["subscribed"] = {
room_id: {
thread_root_id: attr.asdict(
change, filter=lambda _attr, v: v is not None
)
for thread_root_id, change in room_threads.items()
}
for room_id, room_threads in thread_subscriptions.subscribed.items()
}
if thread_subscriptions.unsubscribed:
out["unsubscribed"] = {
room_id: {
thread_root_id: attr.asdict(
change, filter=lambda _attr, v: v is not None
)
for thread_root_id, change in room_threads.items()
}
for room_id, room_threads in thread_subscriptions.unsubscribed.items()
}
if thread_subscriptions.prev_batch:
out["prev_batch"] = thread_subscriptions.prev_batch.to_string()
return out
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
SyncRestServlet(hs).register(http_server)

View File

@@ -1,21 +1,39 @@
from http import HTTPStatus
from typing import TYPE_CHECKING, Optional, Tuple
from typing import TYPE_CHECKING, Dict, Optional, Tuple
import attr
from typing_extensions import TypeAlias
from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
parse_and_validate_json_object_from_request,
parse_integer,
parse_string,
)
from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
from synapse.types import JsonDict, RoomID
from synapse.types import (
JsonDict,
RoomID,
SlidingSyncStreamToken,
ThreadSubscriptionsToken,
)
from synapse.types.handlers.sliding_sync import SlidingSyncResult
from synapse.types.rest import RequestBodyModel
from synapse.util.pydantic_models import AnyEventId
if TYPE_CHECKING:
from synapse.server import HomeServer
_ThreadSubscription: TypeAlias = (
SlidingSyncResult.Extensions.ThreadSubscriptionsExtension.ThreadSubscription
)
_ThreadUnsubscription: TypeAlias = (
SlidingSyncResult.Extensions.ThreadSubscriptionsExtension.ThreadUnsubscription
)
class ThreadSubscriptionsRestServlet(RestServlet):
PATTERNS = client_patterns(
@@ -100,6 +118,129 @@ class ThreadSubscriptionsRestServlet(RestServlet):
return HTTPStatus.OK, {}
class ThreadSubscriptionsPaginationRestServlet(RestServlet):
PATTERNS = client_patterns(
"/io.element.msc4308/thread_subscriptions$",
unstable=True,
releases=(),
)
CATEGORY = "Thread Subscriptions requests (unstable)"
# Maximum number of thread subscriptions to return in one request.
MAX_LIMIT = 512
def __init__(self, hs: "HomeServer"):
self.auth = hs.get_auth()
self.is_mine = hs.is_mine
self.store = hs.get_datastores().main
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
limit = min(
parse_integer(request, "limit", default=100, negative=False),
ThreadSubscriptionsPaginationRestServlet.MAX_LIMIT,
)
from_end_opt = parse_string(request, "from", required=False)
to_start_opt = parse_string(request, "to", required=False)
_direction = parse_string(request, "dir", required=True, allowed_values=("b",))
if limit <= 0:
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"limit must be greater than 0",
errcode=Codes.INVALID_PARAM,
)
if from_end_opt is not None:
try:
# because of backwards pagination, the `from` token is actually the
# bound closest to the end of the stream
end_stream_id = ThreadSubscriptionsToken.from_string(
from_end_opt
).stream_id
except ValueError:
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"`from` is not a valid token",
errcode=Codes.INVALID_PARAM,
)
else:
end_stream_id = self.store.get_max_thread_subscriptions_stream_id()
if to_start_opt is not None:
# because of backwards pagination, the `to` token is actually the
# bound closest to the start of the stream
try:
start_stream_id = ThreadSubscriptionsToken.from_string(
to_start_opt
).stream_id
except ValueError:
# we also accept sliding sync `pos` tokens on this parameter
try:
sliding_sync_pos = await SlidingSyncStreamToken.from_string(
self.store, to_start_opt
)
start_stream_id = (
sliding_sync_pos.stream_token.thread_subscriptions_key
)
except ValueError:
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"`to` is not a valid token",
errcode=Codes.INVALID_PARAM,
)
else:
# the start of time is ID 1; the lower bound is exclusive though
start_stream_id = 0
subscriptions = (
await self.store.get_latest_updated_thread_subscriptions_for_user(
requester.user.to_string(),
from_id=start_stream_id,
to_id=end_stream_id,
limit=limit,
)
)
subscribed_threads: Dict[str, Dict[str, JsonDict]] = {}
unsubscribed_threads: Dict[str, Dict[str, JsonDict]] = {}
for stream_id, room_id, thread_root_id, subscribed, automatic in subscriptions:
if subscribed:
subscribed_threads.setdefault(room_id, {})[thread_root_id] = (
attr.asdict(
_ThreadSubscription(
automatic=automatic,
bump_stamp=stream_id,
)
)
)
else:
unsubscribed_threads.setdefault(room_id, {})[thread_root_id] = (
attr.asdict(_ThreadUnsubscription(bump_stamp=stream_id))
)
result: JsonDict = {}
if subscribed_threads:
result["subscribed"] = subscribed_threads
if unsubscribed_threads:
result["unsubscribed"] = unsubscribed_threads
if len(subscriptions) == limit:
# We hit the limit, so there might be more entries to return.
# Generate a new token that has moved backwards, ready for the next
# request.
min_returned_stream_id, _, _, _, _ = subscriptions[0]
result["end"] = ThreadSubscriptionsToken(
# We subtract one because the 'later in the stream' bound is inclusive,
# and we already saw the element at index 0.
stream_id=min_returned_stream_id - 1
).to_string()
return HTTPStatus.OK, result
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
if hs.config.experimental.msc4306_enabled:
ThreadSubscriptionsRestServlet(hs).register(http_server)
ThreadSubscriptionsPaginationRestServlet(hs).register(http_server)

View File

@@ -53,7 +53,7 @@ from synapse.storage.databases.main.stream import (
generate_pagination_where_clause,
)
from synapse.storage.engines import PostgresEngine
from synapse.types import JsonDict, MultiWriterStreamToken, StreamKeyType, StreamToken
from synapse.types import JsonDict, StreamKeyType, StreamToken
from synapse.util.caches.descriptors import cached, cachedList
if TYPE_CHECKING:
@@ -316,17 +316,8 @@ class RelationsWorkerStore(SQLBaseStore):
StreamKeyType.ROOM, next_key
)
else:
next_token = StreamToken(
room_key=next_key,
presence_key=0,
typing_key=0,
receipt_key=MultiWriterStreamToken(stream=0),
account_data_key=0,
push_rules_key=0,
to_device_key=0,
device_list_key=MultiWriterStreamToken(stream=0),
groups_key=0,
un_partial_stated_rooms_key=0,
next_token = StreamToken.START.copy_and_replace(
StreamKeyType.ROOM, next_key
)
return events[:limit], next_token

View File

@@ -492,7 +492,7 @@ class PerConnectionStateDB:
"""An equivalent to `PerConnectionState` that holds data in a format stored
in the DB.
The principle difference is that the tokens for the different streams are
The principal difference is that the tokens for the different streams are
serialized to strings.
When persisting this *only* contains updates to the state.

View File

@@ -505,6 +505,9 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore):
"""
return self._thread_subscriptions_id_gen.get_current_token()
def get_thread_subscriptions_stream_id_generator(self) -> MultiWriterIdGenerator:
return self._thread_subscriptions_id_gen
async def get_updated_thread_subscriptions(
self, *, from_id: int, to_id: int, limit: int
) -> List[Tuple[int, str, str, str]]:
@@ -538,34 +541,52 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore):
get_updated_thread_subscriptions_txn,
)
async def get_updated_thread_subscriptions_for_user(
async def get_latest_updated_thread_subscriptions_for_user(
self, user_id: str, *, from_id: int, to_id: int, limit: int
) -> List[Tuple[int, str, str]]:
"""Get updates to thread subscriptions for a specific user.
) -> List[Tuple[int, str, str, bool, Optional[bool]]]:
"""Get the latest updates to thread subscriptions for a specific user.
Args:
user_id: The ID of the user
from_id: The starting stream ID (exclusive)
to_id: The ending stream ID (inclusive)
limit: The maximum number of rows to return
If there are too many rows to return, rows from the start (closer to `from_id`)
will be omitted.
Returns:
A list of (stream_id, room_id, thread_root_event_id) tuples.
A list of (stream_id, room_id, thread_root_event_id, subscribed, automatic) tuples.
The row with lowest `stream_id` is the first row.
"""
def get_updated_thread_subscriptions_for_user_txn(
txn: LoggingTransaction,
) -> List[Tuple[int, str, str]]:
) -> List[Tuple[int, str, str, bool, Optional[bool]]]:
sql = """
SELECT stream_id, room_id, event_id
FROM thread_subscriptions
WHERE user_id = ? AND ? < stream_id AND stream_id <= ?
WITH the_updates AS (
SELECT stream_id, room_id, event_id, subscribed, automatic
FROM thread_subscriptions
WHERE user_id = ? AND ? < stream_id AND stream_id <= ?
ORDER BY stream_id DESC
LIMIT ?
)
SELECT stream_id, room_id, event_id, subscribed, automatic
FROM the_updates
ORDER BY stream_id ASC
LIMIT ?
"""
txn.execute(sql, (user_id, from_id, to_id, limit))
return [(row[0], row[1], row[2]) for row in txn]
return [
(
stream_id,
room_id,
event_id,
# SQLite integer to boolean conversions
bool(subscribed),
bool(automatic) if subscribed else None,
)
for (stream_id, room_id, event_id, subscribed, automatic) in txn
]
return await self.db_pool.runInteraction(
"get_updated_thread_subscriptions_for_user",

View File

@@ -0,0 +1,19 @@
--
-- This file is licensed under the Affero General Public License (AGPL) version 3.
--
-- Copyright (C) 2025 New Vector, Ltd
--
-- This program is free software: you can redistribute it and/or modify
-- it under the terms of the GNU Affero General Public License as
-- published by the Free Software Foundation, either version 3 of the
-- License, or (at your option) any later version.
--
-- See the GNU Affero General Public License for more details:
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
-- Work around https://github.com/element-hq/synapse/issues/18712 by advancing the
-- stream sequence.
-- This makes last_value of the sequence point to a position that will not get later
-- returned by nextval.
-- (For blank thread subscription streams, this means last_value = 2, nextval() = 3 after this line.)
SELECT nextval('thread_subscriptions_sequence');

View File

@@ -187,8 +187,12 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
Warning: Streams using this generator start at ID 2, because ID 1 is always assumed
to have been 'seen as persisted'.
Unclear if this extant behaviour is desirable for some reason.
When creating a new sequence for a new stream,
it will be necessary to use `START WITH 2`.
When creating a new sequence for a new stream, it will be necessary to advance it
so that position 1 is consumed.
DO NOT USE `START WITH 2` FOR THIS PURPOSE:
see https://github.com/element-hq/synapse/issues/18712
Instead, use `SELECT nextval('sequence_name');` immediately after the
`CREATE SEQUENCE` statement.
Args:
db_conn

View File

@@ -33,7 +33,6 @@ from synapse.logging.opentracing import trace
from synapse.streams import EventSource
from synapse.types import (
AbstractMultiWriterStreamToken,
MultiWriterStreamToken,
StreamKeyType,
StreamToken,
)
@@ -84,6 +83,7 @@ class EventSources:
un_partial_stated_rooms_key = self.store.get_un_partial_stated_rooms_token(
self._instance_name
)
thread_subscriptions_key = self.store.get_max_thread_subscriptions_stream_id()
token = StreamToken(
room_key=self.sources.room.get_current_key(),
@@ -97,6 +97,7 @@ class EventSources:
# Groups key is unused.
groups_key=0,
un_partial_stated_rooms_key=un_partial_stated_rooms_key,
thread_subscriptions_key=thread_subscriptions_key,
)
return token
@@ -123,6 +124,7 @@ class EventSources:
StreamKeyType.TO_DEVICE: self.store.get_to_device_id_generator(),
StreamKeyType.DEVICE_LIST: self.store.get_device_stream_id_generator(),
StreamKeyType.UN_PARTIAL_STATED_ROOMS: self.store.get_un_partial_stated_rooms_id_generator(),
StreamKeyType.THREAD_SUBSCRIPTIONS: self.store.get_thread_subscriptions_stream_id_generator(),
}
for _, key in StreamKeyType.__members__.items():
@@ -195,16 +197,7 @@ class EventSources:
Returns:
The current token for pagination.
"""
token = StreamToken(
room_key=await self.sources.room.get_current_key_for_room(room_id),
presence_key=0,
typing_key=0,
receipt_key=MultiWriterStreamToken(stream=0),
account_data_key=0,
push_rules_key=0,
to_device_key=0,
device_list_key=MultiWriterStreamToken(stream=0),
groups_key=0,
un_partial_stated_rooms_key=0,
return StreamToken.START.copy_and_replace(
StreamKeyType.ROOM,
await self.sources.room.get_current_key_for_room(room_id),
)
return token

View File

@@ -996,6 +996,7 @@ class StreamKeyType(Enum):
TO_DEVICE = "to_device_key"
DEVICE_LIST = "device_list_key"
UN_PARTIAL_STATED_ROOMS = "un_partial_stated_rooms_key"
THREAD_SUBSCRIPTIONS = "thread_subscriptions_key"
@attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -1003,7 +1004,7 @@ class StreamToken:
"""A collection of keys joined together by underscores in the following
order and which represent the position in their respective streams.
ex. `s2633508_17_338_6732159_1082514_541479_274711_265584_1_379`
ex. `s2633508_17_338_6732159_1082514_541479_274711_265584_1_379_4242`
1. `room_key`: `s2633508` which is a `RoomStreamToken`
- `RoomStreamToken`'s can also look like `t426-2633508` or `m56~2.58~3.59`
- See the docstring for `RoomStreamToken` for more details.
@@ -1016,6 +1017,7 @@ class StreamToken:
8. `device_list_key`: `265584`
9. `groups_key`: `1` (note that this key is now unused)
10. `un_partial_stated_rooms_key`: `379`
11. `thread_subscriptions_key`: `4242`
You can see how many of these keys correspond to the various
fields in a "/sync" response:
@@ -1074,6 +1076,7 @@ class StreamToken:
# Note that the groups key is no longer used and may have bogus values.
groups_key: int
un_partial_stated_rooms_key: int
thread_subscriptions_key: int
_SEPARATOR = "_"
START: ClassVar["StreamToken"]
@@ -1101,6 +1104,7 @@ class StreamToken:
device_list_key,
groups_key,
un_partial_stated_rooms_key,
thread_subscriptions_key,
) = keys
return cls(
@@ -1116,6 +1120,7 @@ class StreamToken:
),
groups_key=int(groups_key),
un_partial_stated_rooms_key=int(un_partial_stated_rooms_key),
thread_subscriptions_key=int(thread_subscriptions_key),
)
except CancelledError:
raise
@@ -1138,6 +1143,7 @@ class StreamToken:
# if additional tokens are added.
str(self.groups_key),
str(self.un_partial_stated_rooms_key),
str(self.thread_subscriptions_key),
]
)
@@ -1202,6 +1208,7 @@ class StreamToken:
StreamKeyType.TO_DEVICE,
StreamKeyType.TYPING,
StreamKeyType.UN_PARTIAL_STATED_ROOMS,
StreamKeyType.THREAD_SUBSCRIPTIONS,
],
) -> int: ...
@@ -1257,7 +1264,8 @@ class StreamToken:
f"typing: {self.typing_key}, receipt: {self.receipt_key}, "
f"account_data: {self.account_data_key}, push_rules: {self.push_rules_key}, "
f"to_device: {self.to_device_key}, device_list: {self.device_list_key}, "
f"groups: {self.groups_key}, un_partial_stated_rooms: {self.un_partial_stated_rooms_key})"
f"groups: {self.groups_key}, un_partial_stated_rooms: {self.un_partial_stated_rooms_key},"
f"thread_subscriptions: {self.thread_subscriptions_key})"
)
@@ -1272,6 +1280,7 @@ StreamToken.START = StreamToken(
device_list_key=MultiWriterStreamToken(stream=0),
groups_key=0,
un_partial_stated_rooms_key=0,
thread_subscriptions_key=0,
)
@@ -1318,6 +1327,27 @@ class SlidingSyncStreamToken:
return f"{self.connection_position}/{stream_token_str}"
@attr.s(slots=True, frozen=True, auto_attribs=True)
class ThreadSubscriptionsToken:
    """
    A position in the thread subscriptions stream.

    Serialised wire format: `ts<stream_id>`, e.g. `ts42`.
    """

    # Position in the thread subscriptions stream.
    stream_id: int

    @staticmethod
    def from_string(s: str) -> "ThreadSubscriptionsToken":
        """Parse a `ts<stream_id>` string; raises ValueError if malformed."""
        prefix, digits = s[:2], s[2:]
        if prefix != "ts":
            raise ValueError("thread subscription token must start with `ts`")
        return ThreadSubscriptionsToken(stream_id=int(digits))

    def to_string(self) -> str:
        """Serialise back to the `ts<stream_id>` wire format."""
        return "ts" + str(self.stream_id)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class PersistedPosition:
"""Position of a newly persisted row with instance that persisted it."""

View File

@@ -50,6 +50,7 @@ from synapse.types import (
SlidingSyncStreamToken,
StrCollection,
StreamToken,
ThreadSubscriptionsToken,
UserID,
)
from synapse.types.rest.client import SlidingSyncBody
@@ -357,11 +358,50 @@ class SlidingSyncResult:
def __bool__(self) -> bool:
return bool(self.room_id_to_typing_map)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class ThreadSubscriptionsExtension:
    """The Thread Subscriptions extension (MSC4308).

    Attributes:
        subscribed: map (room_id -> thread_root_id -> info) of new or changed subscriptions
        unsubscribed: map (room_id -> thread_root_id -> info) of new unsubscriptions
        prev_batch: if present, there is a gap and the client can use this token to backpaginate
    """

    @attr.s(slots=True, frozen=True, auto_attribs=True)
    class ThreadSubscription:
        # always present when `subscribed`
        automatic: Optional[bool]
        # mirrors our stream_id so clients can resolve race conditions
        # between subscription changes locally
        bump_stamp: int

    @attr.s(slots=True, frozen=True, auto_attribs=True)
    class ThreadUnsubscription:
        # mirrors our stream_id so clients can resolve race conditions
        # between subscription changes locally
        bump_stamp: int

    # room_id -> event_id (of thread root) -> the subscription change
    subscribed: Optional[Mapping[str, Mapping[str, ThreadSubscription]]]
    # room_id -> event_id (of thread root) -> the unsubscription
    unsubscribed: Optional[Mapping[str, Mapping[str, ThreadUnsubscription]]]
    prev_batch: Optional[ThreadSubscriptionsToken]

    def __bool__(self) -> bool:
        # Truthy iff there is anything at all to tell the client.
        return bool(self.subscribed or self.unsubscribed or self.prev_batch)
to_device: Optional[ToDeviceExtension] = None
e2ee: Optional[E2eeExtension] = None
account_data: Optional[AccountDataExtension] = None
receipts: Optional[ReceiptsExtension] = None
typing: Optional[TypingExtension] = None
thread_subscriptions: Optional[ThreadSubscriptionsExtension] = None
def __bool__(self) -> bool:
return bool(
@@ -370,6 +410,7 @@ class SlidingSyncResult:
or self.account_data
or self.receipts
or self.typing
or self.thread_subscriptions
)
next_pos: SlidingSyncStreamToken

View File

@@ -22,6 +22,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from synapse._pydantic_compat import (
Extra,
Field,
StrictBool,
StrictInt,
StrictStr,
@@ -364,11 +365,25 @@ class SlidingSyncBody(RequestBodyModel):
# Process all room subscriptions defined in the Room Subscription API. (This is the default.)
rooms: Optional[List[StrictStr]] = ["*"]
class ThreadSubscriptionsExtension(RequestBodyModel):
    """The Thread Subscriptions extension (MSC4308), request-body side.

    Attributes:
        enabled: whether the client has enabled this extension for the request
        limit: maximum number of subscription changes to return (default 100)
    """

    enabled: Optional[StrictBool] = False
    limit: StrictInt = 100
to_device: Optional[ToDeviceExtension] = None
e2ee: Optional[E2eeExtension] = None
account_data: Optional[AccountDataExtension] = None
receipts: Optional[ReceiptsExtension] = None
typing: Optional[TypingExtension] = None
thread_subscriptions: Optional[ThreadSubscriptionsExtension] = Field(
alias="io.element.msc4308.thread_subscriptions"
)
conn_id: Optional[StrictStr]

View File

@@ -347,6 +347,7 @@ T2 = TypeVar("T2")
T3 = TypeVar("T3")
T4 = TypeVar("T4")
T5 = TypeVar("T5")
T6 = TypeVar("T6")
@overload
@@ -461,6 +462,23 @@ async def gather_optional_coroutines(
) -> Tuple[Optional[T1], Optional[T2], Optional[T3], Optional[T4], Optional[T5]]: ...
# Overload for exactly six (possibly-None) coroutines; each slot's result is
# None when the corresponding coroutine was None.
@overload
async def gather_optional_coroutines(
    *coroutines: Unpack[
        Tuple[
            Optional[Coroutine[Any, Any, T1]],
            Optional[Coroutine[Any, Any, T2]],
            Optional[Coroutine[Any, Any, T3]],
            Optional[Coroutine[Any, Any, T4]],
            Optional[Coroutine[Any, Any, T5]],
            Optional[Coroutine[Any, Any, T6]],
        ]
    ],
) -> Tuple[
    Optional[T1], Optional[T2], Optional[T3], Optional[T4], Optional[T5], Optional[T6]
]: ...
async def gather_optional_coroutines(
*coroutines: Unpack[Tuple[Optional[Coroutine[Any, Any, T1]], ...]],
) -> Tuple[Optional[T1], ...]:

View File

@@ -2244,7 +2244,7 @@ class RoomMessagesTestCase(unittest.HomeserverTestCase):
def test_topo_token_is_accepted(self) -> None:
"""Test Topo Token is accepted."""
token = "t1-0_0_0_0_0_0_0_0_0_0"
token = "t1-0_0_0_0_0_0_0_0_0_0_0"
channel = self.make_request(
"GET",
"/_synapse/admin/v1/rooms/%s/messages?from=%s" % (self.room_id, token),
@@ -2258,7 +2258,7 @@ class RoomMessagesTestCase(unittest.HomeserverTestCase):
def test_stream_token_is_accepted_for_fwd_pagianation(self) -> None:
"""Test that stream token is accepted for forward pagination."""
token = "s0_0_0_0_0_0_0_0_0_0"
token = "s0_0_0_0_0_0_0_0_0_0_0"
channel = self.make_request(
"GET",
"/_synapse/admin/v1/rooms/%s/messages?from=%s" % (self.room_id, token),

View File

@@ -0,0 +1,497 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright (C) 2025 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
import logging
from http import HTTPStatus
from typing import List, Optional, Tuple, cast
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.rest.client import login, room, sync, thread_subscriptions
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
from tests.rest.client.sliding_sync.test_sliding_sync import SlidingSyncBase
logger = logging.getLogger(__name__)
# The name of the extension. Currently unstable-prefixed.
EXT_NAME = "io.element.msc4308.thread_subscriptions"
class SlidingSyncThreadSubscriptionsExtensionTestCase(SlidingSyncBase):
    """
    Test the thread subscriptions extension in the Sliding Sync API.
    """

    maxDiff = None

    servlets = [
        synapse.rest.admin.register_servlets,
        login.register_servlets,
        room.register_servlets,
        sync.register_servlets,
        thread_subscriptions.register_servlets,
    ]

    def default_config(self) -> JsonDict:
        # MSC4308 is gated behind the MSC4306 feature flag, so enable that.
        config = super().default_config()
        config["experimental_features"] = {"msc4306_enabled": True}
        return config

    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
        # Keep a handle on the datastore so tests can write thread
        # subscriptions directly and read the current stream position.
        self.store = hs.get_datastores().main
        self.storage_controllers = hs.get_storage_controllers()
        super().prepare(reactor, clock, hs)
def test_no_data_initial_sync(self) -> None:
    """
    Test enabling thread subscriptions extension during initial sync with no data.
    """
    user1_id = self.register_user("user1", "pass")
    user1_tok = self.login(user1_id, "pass")

    sync_body = {
        "lists": {},
        "extensions": {
            EXT_NAME: {
                "enabled": True,
            }
        },
    }

    # Sync
    response_body, _ = self.do_sync(sync_body, tok=user1_tok)

    # Assert: with no subscriptions at all, the extension is omitted from the
    # response entirely rather than returned empty.
    self.assertNotIn(EXT_NAME, response_body["extensions"])
def test_no_data_incremental_sync(self) -> None:
    """
    Test enabling thread subscriptions extension during incremental sync with no data.
    """
    user1_id = self.register_user("user1", "pass")
    user1_tok = self.login(user1_id, "pass")

    initial_sync_body: JsonDict = {
        "lists": {},
    }

    # Initial sync (extension not yet enabled)
    response_body, sync_pos = self.do_sync(initial_sync_body, tok=user1_tok)

    # Incremental sync with extension enabled
    sync_body = {
        "lists": {},
        "extensions": {
            EXT_NAME: {
                "enabled": True,
            }
        },
    }
    response_body, _ = self.do_sync(sync_body, tok=user1_tok, since=sync_pos)

    # Assert: nothing changed since the last sync, so the extension is
    # omitted from the response entirely.
    self.assertNotIn(
        EXT_NAME,
        response_body["extensions"],
        response_body,
    )
def test_thread_subscription_initial_sync(self) -> None:
    """
    Test thread subscriptions appear in initial sync response.
    """
    user1_id = self.register_user("user1", "pass")
    user1_tok = self.login(user1_id, "pass")
    room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
    thread_root_resp = self.helper.send(room_id, body="Thread root", tok=user1_tok)
    thread_root_id = thread_root_resp["event_id"]

    # get the baseline stream_id of the thread_subscriptions stream
    # before we write any data.
    # Required because the initial value differs between SQLite and Postgres.
    base = self.store.get_max_thread_subscriptions_stream_id()

    self._subscribe_to_thread(user1_id, room_id, thread_root_id)

    sync_body = {
        "lists": {},
        "extensions": {
            EXT_NAME: {
                "enabled": True,
            }
        },
    }

    # Sync
    response_body, _ = self.do_sync(sync_body, tok=user1_tok)

    # Assert: the subscription is reported, with `bump_stamp` reflecting its
    # position in the thread subscriptions stream.
    self.assertEqual(
        response_body["extensions"][EXT_NAME],
        {
            "subscribed": {
                room_id: {
                    thread_root_id: {
                        "automatic": False,
                        "bump_stamp": base + 1,
                    }
                }
            }
        },
    )
def test_thread_subscription_incremental_sync(self) -> None:
    """
    Test new thread subscriptions appear in incremental sync response.
    """
    user1_id = self.register_user("user1", "pass")
    user1_tok = self.login(user1_id, "pass")
    room_id = self.helper.create_room_as(user1_id, tok=user1_tok)

    sync_body = {
        "lists": {},
        "extensions": {
            EXT_NAME: {
                "enabled": True,
            }
        },
    }

    thread_root_resp = self.helper.send(room_id, body="Thread root", tok=user1_tok)
    thread_root_id = thread_root_resp["event_id"]

    # get the baseline stream_id of the thread_subscriptions stream
    # before we write any data.
    # Required because the initial value differs between SQLite and Postgres.
    base = self.store.get_max_thread_subscriptions_stream_id()

    # Initial sync
    _, sync_pos = self.do_sync(sync_body, tok=user1_tok)
    logger.info("Synced to: %r, now subscribing to thread", sync_pos)

    # Subscribe
    self._subscribe_to_thread(user1_id, room_id, thread_root_id)

    # Incremental sync
    response_body, sync_pos = self.do_sync(sync_body, tok=user1_tok, since=sync_pos)
    logger.info("Synced to: %r", sync_pos)

    # Assert: only the subscription made since the previous sync is reported.
    self.assertEqual(
        response_body["extensions"][EXT_NAME],
        {
            "subscribed": {
                room_id: {
                    thread_root_id: {
                        "automatic": False,
                        "bump_stamp": base + 1,
                    }
                }
            }
        },
    )
def test_unsubscribe_from_thread(self) -> None:
    """
    Test unsubscribing from a thread.
    """
    user1_id = self.register_user("user1", "pass")
    user1_tok = self.login(user1_id, "pass")
    room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
    thread_root_resp = self.helper.send(room_id, body="Thread root", tok=user1_tok)
    thread_root_id = thread_root_resp["event_id"]

    # get the baseline stream_id of the thread_subscriptions stream
    # before we write any data.
    # Required because the initial value differs between SQLite and Postgres.
    base = self.store.get_max_thread_subscriptions_stream_id()

    self._subscribe_to_thread(user1_id, room_id, thread_root_id)

    sync_body = {
        "lists": {},
        "extensions": {
            EXT_NAME: {
                "enabled": True,
            }
        },
    }
    response_body, sync_pos = self.do_sync(sync_body, tok=user1_tok)

    # Assert: Subscription present
    self.assertIn(EXT_NAME, response_body["extensions"])
    self.assertEqual(
        response_body["extensions"][EXT_NAME],
        {
            "subscribed": {
                room_id: {
                    thread_root_id: {"automatic": False, "bump_stamp": base + 1}
                }
            }
        },
    )

    # Unsubscribe
    self._unsubscribe_from_thread(user1_id, room_id, thread_root_id)

    # Incremental sync
    response_body, sync_pos = self.do_sync(sync_body, tok=user1_tok, since=sync_pos)

    # Assert: Unsubscription present; the unsubscription consumes the next
    # stream position, hence `base + 2`.
    self.assertEqual(
        response_body["extensions"][EXT_NAME],
        {"unsubscribed": {room_id: {thread_root_id: {"bump_stamp": base + 2}}}},
    )
def test_multiple_thread_subscriptions(self) -> None:
    """
    Test handling of multiple thread subscriptions.
    """
    user1_id = self.register_user("user1", "pass")
    user1_tok = self.login(user1_id, "pass")
    room_id = self.helper.create_room_as(user1_id, tok=user1_tok)

    # Create thread roots
    thread_root_resp1 = self.helper.send(
        room_id, body="Thread root 1", tok=user1_tok
    )
    thread_root_id1 = thread_root_resp1["event_id"]
    thread_root_resp2 = self.helper.send(
        room_id, body="Thread root 2", tok=user1_tok
    )
    thread_root_id2 = thread_root_resp2["event_id"]
    thread_root_resp3 = self.helper.send(
        room_id, body="Thread root 3", tok=user1_tok
    )
    thread_root_id3 = thread_root_resp3["event_id"]

    # get the baseline stream_id of the thread_subscriptions stream
    # before we write any data.
    # Required because the initial value differs between SQLite and Postgres.
    base = self.store.get_max_thread_subscriptions_stream_id()

    # Subscribe to threads
    self._subscribe_to_thread(user1_id, room_id, thread_root_id1)
    self._subscribe_to_thread(user1_id, room_id, thread_root_id2)
    self._subscribe_to_thread(user1_id, room_id, thread_root_id3)

    sync_body = {
        "lists": {},
        "extensions": {
            EXT_NAME: {
                "enabled": True,
            }
        },
    }

    # Sync
    response_body, _ = self.do_sync(sync_body, tok=user1_tok)

    # Assert: all three subscriptions are present, with `bump_stamp`s in
    # subscription order.
    self.assertEqual(
        response_body["extensions"][EXT_NAME],
        {
            "subscribed": {
                room_id: {
                    thread_root_id1: {
                        "automatic": False,
                        "bump_stamp": base + 1,
                    },
                    thread_root_id2: {
                        "automatic": False,
                        "bump_stamp": base + 2,
                    },
                    thread_root_id3: {
                        "automatic": False,
                        "bump_stamp": base + 3,
                    },
                }
            }
        },
    )
def test_limit_parameter(self) -> None:
    """
    Test limit parameter in thread subscriptions extension.
    """
    user1_id = self.register_user("user1", "pass")
    user1_tok = self.login(user1_id, "pass")
    room_id = self.helper.create_room_as(user1_id, tok=user1_tok)

    # Create 5 thread roots and subscribe to each
    thread_root_ids = []
    for i in range(5):
        thread_root_resp = self.helper.send(
            room_id, body=f"Thread root {i}", tok=user1_tok
        )
        thread_root_ids.append(thread_root_resp["event_id"])
        self._subscribe_to_thread(user1_id, room_id, thread_root_ids[-1])

    sync_body = {
        "lists": {},
        "extensions": {EXT_NAME: {"enabled": True, "limit": 3}},
    }

    # Sync
    response_body, _ = self.do_sync(sync_body, tok=user1_tok)

    # Assert: only `limit` subscriptions are returned despite 5 existing.
    thread_subscriptions = response_body["extensions"][EXT_NAME]
    self.assertEqual(
        len(thread_subscriptions["subscribed"][room_id]), 3, thread_subscriptions
    )
def test_limit_and_companion_backpagination(self) -> None:
    """
    Create 1 thread subscription, do a sync, create 5 more,
    then sync with a limit of 2 and fill in the gap
    using the companion /thread_subscriptions endpoint.
    """
    thread_root_ids: List[str] = []

    def make_subscription() -> None:
        # Sends a new thread root and subscribes to it; relies on room_id and
        # user1_tok being assigned below, before the first call.
        thread_root_resp = self.helper.send(
            room_id, body="Some thread root", tok=user1_tok
        )
        thread_root_ids.append(thread_root_resp["event_id"])
        self._subscribe_to_thread(user1_id, room_id, thread_root_ids[-1])

    user1_id = self.register_user("user1", "pass")
    user1_tok = self.login(user1_id, "pass")
    room_id = self.helper.create_room_as(user1_id, tok=user1_tok)

    # get the baseline stream_id of the thread_subscriptions stream
    # before we write any data.
    # Required because the initial value differs between SQLite and Postgres.
    base = self.store.get_max_thread_subscriptions_stream_id()

    # Make our first subscription
    make_subscription()

    # Sync for the first time
    sync_body = {
        "lists": {},
        "extensions": {EXT_NAME: {"enabled": True, "limit": 2}},
    }
    sync_resp, first_sync_pos = self.do_sync(sync_body, tok=user1_tok)
    thread_subscriptions = sync_resp["extensions"][EXT_NAME]
    self.assertEqual(
        thread_subscriptions["subscribed"],
        {
            room_id: {
                thread_root_ids[0]: {"automatic": False, "bump_stamp": base + 1},
            }
        },
    )

    # Get our pos for the next sync
    # (NOTE(review): presumably the same value `do_sync` already returned —
    # confirm and deduplicate)
    first_sync_pos = sync_resp["pos"]

    # Create 5 more thread subscriptions and subscribe to each
    for _ in range(5):
        make_subscription()

    # Now sync again. Our limit is 2,
    # so we should get the latest 2 subscriptions,
    # with a gap of 3 more subscriptions in the middle
    sync_resp, _pos = self.do_sync(sync_body, tok=user1_tok, since=first_sync_pos)
    thread_subscriptions = sync_resp["extensions"][EXT_NAME]
    self.assertEqual(
        thread_subscriptions["subscribed"],
        {
            room_id: {
                thread_root_ids[4]: {"automatic": False, "bump_stamp": base + 5},
                thread_root_ids[5]: {"automatic": False, "bump_stamp": base + 6},
            }
        },
    )

    # 1st backpagination: expecting a page with 2 subscriptions
    page, end_tok = self._do_backpaginate(
        from_tok=thread_subscriptions["prev_batch"],
        to_tok=first_sync_pos,
        limit=2,
        access_token=user1_tok,
    )
    self.assertIsNotNone(end_tok, "backpagination should continue")
    self.assertEqual(
        page["subscribed"],
        {
            room_id: {
                thread_root_ids[2]: {"automatic": False, "bump_stamp": base + 3},
                thread_root_ids[3]: {"automatic": False, "bump_stamp": base + 4},
            }
        },
    )

    # 2nd backpagination: expecting a page with only 1 subscription
    # and no other token for further backpagination
    assert end_tok is not None
    page, end_tok = self._do_backpaginate(
        from_tok=end_tok, to_tok=first_sync_pos, limit=2, access_token=user1_tok
    )
    self.assertIsNone(end_tok, "backpagination should have finished")
    self.assertEqual(
        page["subscribed"],
        {
            room_id: {
                thread_root_ids[1]: {"automatic": False, "bump_stamp": base + 2},
            }
        },
    )
def _do_backpaginate(
    self, *, from_tok: str, to_tok: str, limit: int, access_token: str
) -> Tuple[JsonDict, Optional[str]]:
    """
    Call the companion MSC4308 backpagination endpoint (backwards direction).

    Returns:
        tuple of (response body, `end` token, or None if pagination finished)
    """
    channel = self.make_request(
        "GET",
        "/_matrix/client/unstable/io.element.msc4308/thread_subscriptions"
        f"?from={from_tok}&to={to_tok}&limit={limit}&dir=b",
        access_token=access_token,
    )
    self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
    body = channel.json_body
    return body, cast(Optional[str], body.get("end"))
def _subscribe_to_thread(
    self, user_id: str, room_id: str, thread_root_id: str
) -> None:
    """
    Helper method to subscribe a user to a thread.

    Writes directly to the datastore (bypassing the REST API); the `None`
    orderings mark this as a manual (non-automatic) subscription.
    """
    self.get_success(
        self.store.subscribe_user_to_thread(
            user_id=user_id,
            room_id=room_id,
            thread_root_event_id=thread_root_id,
            automatic_event_orderings=None,
        )
    )
def _unsubscribe_from_thread(
    self, user_id: str, room_id: str, thread_root_id: str
) -> None:
    """
    Helper method to unsubscribe a user from a thread.

    Writes directly to the datastore (bypassing the REST API).
    """
    self.get_success(
        self.store.unsubscribe_user_from_thread(
            user_id=user_id,
            room_id=room_id,
            thread_root_event_id=thread_root_id,
        )
    )

View File

@@ -18,6 +18,8 @@
# [This file includes modifications made by New Vector Limited]
#
#
from http import HTTPStatus
import synapse
from synapse.api.errors import Codes
from synapse.rest.client import login, push_rule, room
@@ -486,3 +488,23 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
},
channel.json_body,
)
def test_no_user_defined_postcontent_rules(self) -> None:
    """
    Tests that clients are not permitted to create MSC4306 `postcontent` rules.
    """
    self.register_user("bob", "pass")
    token = self.login("bob", "pass")

    # Attempt to create a user-defined rule in the `postcontent` kind;
    # the server must reject it with M_INVALID_PARAM.
    channel = self.make_request(
        "PUT",
        "/pushrules/global/postcontent/some.user.rule",
        {},
        access_token=token,
    )
    self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST)
    self.assertEqual(
        Codes.INVALID_PARAM,
        channel.json_body["errcode"],
    )

View File

@@ -2245,7 +2245,7 @@ class RoomMessageListTestCase(RoomBase):
self.room_id = self.helper.create_room_as(self.user_id)
def test_topo_token_is_accepted(self) -> None:
token = "t1-0_0_0_0_0_0_0_0_0_0"
token = "t1-0_0_0_0_0_0_0_0_0_0_0"
channel = self.make_request(
"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
)
@@ -2256,7 +2256,7 @@ class RoomMessageListTestCase(RoomBase):
self.assertTrue("end" in channel.json_body)
def test_stream_token_is_accepted_for_fwd_pagianation(self) -> None:
token = "s0_0_0_0_0_0_0_0_0_0"
token = "s0_0_0_0_0_0_0_0_0_0_0"
channel = self.make_request(
"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
)

View File

@@ -189,19 +189,19 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
self._subscribe(self.other_thread_root_id, automatic_event_orderings=None)
subscriptions = self.get_success(
self.store.get_updated_thread_subscriptions_for_user(
self.store.get_latest_updated_thread_subscriptions_for_user(
self.user_id,
from_id=0,
to_id=50,
limit=50,
)
)
min_id = min(id for (id, _, _) in subscriptions)
min_id = min(id for (id, _, _, _, _) in subscriptions)
self.assertEqual(
subscriptions,
[
(min_id, self.room_id, self.thread_root_id),
(min_id + 1, self.room_id, self.other_thread_root_id),
(min_id, self.room_id, self.thread_root_id, True, True),
(min_id + 1, self.room_id, self.other_thread_root_id, True, False),
],
)
@@ -212,7 +212,7 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
# Check user has no subscriptions
subscriptions = self.get_success(
self.store.get_updated_thread_subscriptions_for_user(
self.store.get_latest_updated_thread_subscriptions_for_user(
self.user_id,
from_id=0,
to_id=50,
@@ -280,20 +280,22 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
# Get updates for main user
updates = self.get_success(
self.store.get_updated_thread_subscriptions_for_user(
self.store.get_latest_updated_thread_subscriptions_for_user(
self.user_id, from_id=0, to_id=stream_id2, limit=10
)
)
self.assertEqual(updates, [(stream_id1, self.room_id, self.thread_root_id)])
self.assertEqual(
updates, [(stream_id1, self.room_id, self.thread_root_id, True, True)]
)
# Get updates for other user
updates = self.get_success(
self.store.get_updated_thread_subscriptions_for_user(
self.store.get_latest_updated_thread_subscriptions_for_user(
other_user_id, from_id=0, to_id=max(stream_id1, stream_id2), limit=10
)
)
self.assertEqual(
updates, [(stream_id2, self.room_id, self.other_thread_root_id)]
updates, [(stream_id2, self.room_id, self.other_thread_root_id, True, True)]
)
def test_should_skip_autosubscription_after_unsubscription(self) -> None: