Update implementation of MSC4306: Thread Subscriptions to include automatic subscription conflict prevention as introduced in later drafts. (#18756)

Follows: #18674

Implements new drafts of MSC4306

---------

Signed-off-by: Olivier 'reivilibre <oliverw@matrix.org>
Co-authored-by: Eric Eastwood <erice@element.io>
This commit is contained in:
reivilibre
2025-08-05 19:22:53 +01:00
committed by GitHub
parent 076db0ab49
commit 8306cee06a
14 changed files with 586 additions and 98 deletions

1
changelog.d/18756.misc Normal file
View File

@@ -0,0 +1 @@
Update implementation of [MSC4306: Thread Subscriptions](https://github.com/matrix-org/matrix-doc/issues/4306) to include automatic subscription conflict prevention as introduced in later drafts.

View File

@@ -140,6 +140,12 @@ class Codes(str, Enum):
# Part of MSC4155
INVITE_BLOCKED = "ORG.MATRIX.MSC4155.M_INVITE_BLOCKED"
# Part of MSC4306: Thread Subscriptions
MSC4306_CONFLICTING_UNSUBSCRIPTION = (
"IO.ELEMENT.MSC4306.M_CONFLICTING_UNSUBSCRIPTION"
)
MSC4306_NOT_IN_THREAD = "IO.ELEMENT.MSC4306.M_NOT_IN_THREAD"
class CodeMessageException(RuntimeError):
"""An exception with integer code, a message string attributes and optional headers.

View File

@@ -1,9 +1,15 @@
import logging
from http import HTTPStatus
from typing import TYPE_CHECKING, Optional
from synapse.api.errors import AuthError, NotFoundError
from synapse.storage.databases.main.thread_subscriptions import ThreadSubscription
from synapse.types import UserID
from synapse.api.constants import RelationTypes
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
from synapse.events import relation_from_event
from synapse.storage.databases.main.thread_subscriptions import (
AutomaticSubscriptionConflicted,
ThreadSubscription,
)
from synapse.types import EventOrderings, UserID
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -55,42 +61,79 @@ class ThreadSubscriptionsHandler:
room_id: str,
thread_root_event_id: str,
*,
automatic: bool,
automatic_event_id: Optional[str],
) -> Optional[int]:
"""Sets or updates a user's subscription settings for a specific thread root.
Args:
requester_user_id: The ID of the user whose settings are being updated.
thread_root_event_id: The event ID of the thread root.
automatic: whether the user was subscribed by an automatic decision by
their client.
automatic_event_id: if the user was subscribed by an automatic decision by
their client, the event ID that caused this.
Returns:
The stream ID for this update, if the update isn't no-opped.
Raises:
NotFoundError if the user cannot access the thread root event, or it isn't
known to this homeserver.
known to this homeserver. Ditto for the automatic cause event if supplied.
SynapseError(400, M_NOT_IN_THREAD): if client supplied an automatic cause event
but user cannot access the event.
SynapseError(409, M_SKIPPED): if client requested an automatic subscription
but it was skipped because the cause event is logically later than an unsubscription.
"""
# First check that the user can access the thread root event
# and that it exists
try:
event = await self.event_handler.get_event(
thread_root_event = await self.event_handler.get_event(
user_id, room_id, thread_root_event_id
)
if event is None:
if thread_root_event is None:
raise NotFoundError("No such thread root")
except AuthError:
logger.info("rejecting thread subscriptions change (thread not accessible)")
raise NotFoundError("No such thread root")
return await self.store.subscribe_user_to_thread(
if automatic_event_id:
autosub_cause_event = await self.event_handler.get_event(
user_id, room_id, automatic_event_id
)
if autosub_cause_event is None:
raise NotFoundError("Automatic subscription event not found")
relation = relation_from_event(autosub_cause_event)
if (
relation is None
or relation.rel_type != RelationTypes.THREAD
or relation.parent_id != thread_root_event_id
):
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Automatic subscription must use an event in the thread",
errcode=Codes.MSC4306_NOT_IN_THREAD,
)
automatic_event_orderings = EventOrderings.from_event(autosub_cause_event)
else:
automatic_event_orderings = None
outcome = await self.store.subscribe_user_to_thread(
user_id.to_string(),
event.room_id,
room_id,
thread_root_event_id,
automatic=automatic,
automatic_event_orderings=automatic_event_orderings,
)
if isinstance(outcome, AutomaticSubscriptionConflicted):
raise SynapseError(
HTTPStatus.CONFLICT,
"Automatic subscription obsoleted by an unsubscription request.",
errcode=Codes.MSC4306_CONFLICTING_UNSUBSCRIPTION,
)
return outcome
async def unsubscribe_user_from_thread(
self, user_id: UserID, room_id: str, thread_root_event_id: str
) -> Optional[int]:

View File

@@ -739,7 +739,7 @@ class ThreadSubscriptionsStream(_StreamFromIdGen):
NAME = "thread_subscriptions"
ROW_TYPE = ThreadSubscriptionsStreamRow
def __init__(self, hs: Any):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
super().__init__(
hs.get_instance_name(),
@@ -751,7 +751,7 @@ class ThreadSubscriptionsStream(_StreamFromIdGen):
self, instance_name: str, from_token: int, to_token: int, limit: int
) -> StreamUpdateResult:
updates = await self.store.get_updated_thread_subscriptions(
from_token, to_token, limit
from_id=from_token, to_id=to_token, limit=limit
)
rows = [
(

View File

@@ -1,7 +1,6 @@
from http import HTTPStatus
from typing import TYPE_CHECKING, Tuple
from typing import TYPE_CHECKING, Optional, Tuple
from synapse._pydantic_compat import StrictBool
from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import (
@@ -12,6 +11,7 @@ from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
from synapse.types import JsonDict, RoomID
from synapse.types.rest import RequestBodyModel
from synapse.util.pydantic_models import AnyEventId
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -32,7 +32,12 @@ class ThreadSubscriptionsRestServlet(RestServlet):
self.handler = hs.get_thread_subscriptions_handler()
class PutBody(RequestBodyModel):
automatic: StrictBool
automatic: Optional[AnyEventId]
"""
If supplied, the event ID of an event giving rise to this automatic subscription.
If omitted, this subscription is a manual subscription.
"""
async def on_GET(
self, request: SynapseRequest, room_id: str, thread_root_id: str
@@ -63,15 +68,15 @@ class ThreadSubscriptionsRestServlet(RestServlet):
raise SynapseError(
HTTPStatus.BAD_REQUEST, "Invalid event ID", errcode=Codes.INVALID_PARAM
)
requester = await self.auth.get_user_by_req(request)
body = parse_and_validate_json_object_from_request(request, self.PutBody)
requester = await self.auth.get_user_by_req(request)
await self.handler.subscribe_user_to_thread(
requester.user,
room_id,
thread_root_id,
automatic=body.automatic,
automatic_event_id=body.automatic,
)
return HTTPStatus.OK, {}

View File

@@ -14,7 +14,6 @@ import logging
from typing import (
TYPE_CHECKING,
Any,
Dict,
Iterable,
List,
Optional,
@@ -33,6 +32,7 @@ from synapse.storage.database import (
)
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.types import EventOrderings
from synapse.util.caches.descriptors import cached
if TYPE_CHECKING:
@@ -50,6 +50,14 @@ class ThreadSubscription:
"""
class AutomaticSubscriptionConflicted:
"""
Marker return value to signal that an automatic subscription was skipped,
because it conflicted with an unsubscription that we consider to have
been made later than the event causing the automatic subscription.
"""
class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore):
def __init__(
self,
@@ -101,61 +109,172 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore):
self._thread_subscriptions_id_gen.advance(instance_name, token)
super().process_replication_position(stream_name, instance_name, token)
@staticmethod
def _should_skip_autosubscription_after_unsubscription(
*,
autosub: EventOrderings,
unsubscribed_at: EventOrderings,
) -> bool:
"""
Returns whether an automatic subscription occurring *after* an unsubscription
should be skipped, because the unsubscription already 'acknowledges' the event
causing the automatic subscription (the cause event).
To determine *after*, we use `stream_ordering` unless the event is backfilled
(negative `stream_ordering`) and fallback to topological ordering.
Args:
autosub: the stream_ordering and topological_ordering of the cause event
unsubscribed_at:
the maximum stream ordering and the maximum topological ordering at the time of unsubscription
Returns:
True if the automatic subscription should be skipped
"""
# For normal rooms, these two orderings should be positive, because
# they don't refer to a specific event but rather the maximum at the
# time of unsubscription.
#
# However, for rooms that have never been joined and that are being peeked at,
# we might not have a single non-backfilled event and therefore the stream
# ordering might be negative, so we don't assert this case.
assert unsubscribed_at.topological > 0
unsubscribed_at_backfilled = unsubscribed_at.stream < 0
if (
not unsubscribed_at_backfilled
and unsubscribed_at.stream >= autosub.stream > 0
):
# non-backfilled events: the unsubscription is later according to
# the stream
return True
if autosub.stream < 0:
# the auto-subscription cause event was backfilled, so fall back to
# topological ordering
if unsubscribed_at.topological >= autosub.topological:
return True
return False
async def subscribe_user_to_thread(
self, user_id: str, room_id: str, thread_root_event_id: str, *, automatic: bool
) -> Optional[int]:
self,
user_id: str,
room_id: str,
thread_root_event_id: str,
*,
automatic_event_orderings: Optional[EventOrderings],
) -> Optional[Union[int, AutomaticSubscriptionConflicted]]:
"""Updates a user's subscription settings for a specific thread root.
If no change would be made to the subscription, does not produce any database change.
Case-by-case:
- if we already have an automatic subscription:
- new automatic subscriptions will be no-ops (no database write),
- new manual subscriptions will overwrite the automatic subscription
- if we already have a manual subscription:
we don't update (no database write) in either case, because:
- the existing manual subscription wins over a new automatic subscription request
- there would be no need to write a manual subscription because we already have one
Args:
user_id: The ID of the user whose settings are being updated.
room_id: The ID of the room the thread root belongs to.
thread_root_event_id: The event ID of the thread root.
automatic: Whether the subscription was performed automatically by the user's client.
Only `False` will overwrite an existing value of automatic for a subscription row.
automatic_event_orderings:
Value depends on whether the subscription was performed automatically by the user's client.
For manual subscriptions: None.
For automatic subscriptions: the orderings of the event.
Returns:
The stream ID for this update, if the update isn't no-opped.
If a subscription is made: (int) the stream ID for this update.
If a subscription already exists and did not need to be updated: None
If an automatic subscription conflicted with an unsubscription: AutomaticSubscriptionConflicted
"""
assert self._can_write_to_thread_subscriptions
def _subscribe_user_to_thread_txn(txn: LoggingTransaction) -> Optional[int]:
already_automatic = self.db_pool.simple_select_one_onecol_txn(
def _subscribe_user_to_thread_txn(
txn: LoggingTransaction,
) -> Optional[Union[int, AutomaticSubscriptionConflicted]]:
requested_automatic = automatic_event_orderings is not None
row = self.db_pool.simple_select_one_txn(
txn,
table="thread_subscriptions",
keyvalues={
"user_id": user_id,
"event_id": thread_root_event_id,
"room_id": room_id,
"subscribed": True,
},
retcol="automatic",
retcols=(
"subscribed",
"automatic",
"unsubscribed_at_stream_ordering",
"unsubscribed_at_topological_ordering",
),
allow_none=True,
)
if already_automatic is None:
already_subscribed = False
already_automatic = True
else:
already_subscribed = True
# convert int (SQLite bool) to Python bool
already_automatic = bool(already_automatic)
if row is None:
# We have never subscribed before, simply insert the row and finish
stream_id = self._thread_subscriptions_id_gen.get_next_txn(txn)
self.db_pool.simple_insert_txn(
txn,
table="thread_subscriptions",
values={
"user_id": user_id,
"event_id": thread_root_event_id,
"room_id": room_id,
"subscribed": True,
"stream_id": stream_id,
"instance_name": self._instance_name,
"automatic": requested_automatic,
"unsubscribed_at_stream_ordering": None,
"unsubscribed_at_topological_ordering": None,
},
)
txn.call_after(
self.get_subscription_for_thread.invalidate,
(user_id, room_id, thread_root_event_id),
)
return stream_id
if already_subscribed and already_automatic == automatic:
# there is nothing we need to do here
# we already have either a subscription or a prior unsubscription here
(
subscribed,
already_automatic,
unsubscribed_at_stream_ordering,
unsubscribed_at_topological_ordering,
) = row
if subscribed and (not already_automatic or requested_automatic):
# we are already subscribed and the current subscription state
# is good enough (either we already have a manual subscription,
# or we requested an automatic subscription)
# In that case, nothing to change here.
# (See docstring for case-by-case explanation)
return None
if not subscribed and requested_automatic:
assert automatic_event_orderings is not None
# we previously unsubscribed and we are now automatically subscribing
# Check whether the new autosubscription should be skipped
if ThreadSubscriptionsWorkerStore._should_skip_autosubscription_after_unsubscription(
autosub=automatic_event_orderings,
unsubscribed_at=EventOrderings(
unsubscribed_at_stream_ordering,
unsubscribed_at_topological_ordering,
),
):
# skip the subscription
return AutomaticSubscriptionConflicted()
# At this point: we have now finished checking that we need to make
# a subscription, updating the current row.
stream_id = self._thread_subscriptions_id_gen.get_next_txn(txn)
values: Dict[str, Optional[Union[bool, int, str]]] = {
"subscribed": True,
"stream_id": stream_id,
"instance_name": self._instance_name,
"automatic": already_automatic and automatic,
}
self.db_pool.simple_upsert_txn(
self.db_pool.simple_update_txn(
txn,
table="thread_subscriptions",
keyvalues={
@@ -163,9 +282,15 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore):
"event_id": thread_root_event_id,
"room_id": room_id,
},
values=values,
updatevalues={
"subscribed": True,
"stream_id": stream_id,
"instance_name": self._instance_name,
"automatic": requested_automatic,
"unsubscribed_at_stream_ordering": None,
"unsubscribed_at_topological_ordering": None,
},
)
txn.call_after(
self.get_subscription_for_thread.invalidate,
(user_id, room_id, thread_root_event_id),
@@ -214,6 +339,21 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore):
stream_id = self._thread_subscriptions_id_gen.get_next_txn(txn)
# Find the maximum stream ordering and topological ordering of the room,
# which we then store against this unsubscription so we can skip future
# automatic subscriptions that are caused by an event logically earlier
# than this unsubscription.
txn.execute(
"""
SELECT MAX(stream_ordering) AS mso, MAX(topological_ordering) AS mto FROM events
WHERE room_id = ?
""",
(room_id,),
)
ord_row = txn.fetchone()
assert ord_row is not None
max_stream_ordering, max_topological_ordering = ord_row
self.db_pool.simple_update_txn(
txn,
table="thread_subscriptions",
@@ -227,6 +367,8 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore):
"subscribed": False,
"stream_id": stream_id,
"instance_name": self._instance_name,
"unsubscribed_at_stream_ordering": max_stream_ordering,
"unsubscribed_at_topological_ordering": max_topological_ordering,
},
)
@@ -316,7 +458,7 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore):
return self._thread_subscriptions_id_gen.get_current_token()
async def get_updated_thread_subscriptions(
self, from_id: int, to_id: int, limit: int
self, *, from_id: int, to_id: int, limit: int
) -> List[Tuple[int, str, str, str]]:
"""Get updates to thread subscriptions between two stream IDs.
@@ -349,7 +491,7 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore):
)
async def get_updated_thread_subscriptions_for_user(
self, user_id: str, from_id: int, to_id: int, limit: int
self, user_id: str, *, from_id: int, to_id: int, limit: int
) -> List[Tuple[int, str, str]]:
"""Get updates to thread subscriptions for a specific user.

View File

@@ -0,0 +1,20 @@
--
-- This file is licensed under the Affero General Public License (AGPL) version 3.
--
-- Copyright (C) 2025 New Vector, Ltd
--
-- This program is free software: you can redistribute it and/or modify
-- it under the terms of the GNU Affero General Public License as
-- published by the Free Software Foundation, either version 3 of the
-- License, or (at your option) any later version.
--
-- See the GNU Affero General Public License for more details:
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
-- The maximum stream_ordering in the room when the unsubscription was made.
ALTER TABLE thread_subscriptions
ADD COLUMN unsubscribed_at_stream_ordering BIGINT;
-- The maximum topological_ordering in the room when the unsubscription was made.
ALTER TABLE thread_subscriptions
ADD COLUMN unsubscribed_at_topological_ordering BIGINT;

View File

@@ -0,0 +1,18 @@
--
-- This file is licensed under the Affero General Public License (AGPL) version 3.
--
-- Copyright (C) 2025 New Vector, Ltd
--
-- This program is free software: you can redistribute it and/or modify
-- it under the terms of the GNU Affero General Public License as
-- published by the Free Software Foundation, either version 3 of the
-- License, or (at your option) any later version.
--
-- See the GNU Affero General Public License for more details:
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
COMMENT ON COLUMN thread_subscriptions.unsubscribed_at_stream_ordering IS
$$The maximum stream_ordering in the room when the unsubscription was made.$$;
COMMENT ON COLUMN thread_subscriptions.unsubscribed_at_topological_ordering IS
$$The maximum topological_ordering in the room when the unsubscription was made.$$;

View File

@@ -73,6 +73,7 @@ if TYPE_CHECKING:
from typing_extensions import Self
from synapse.appservice.api import ApplicationService
from synapse.events import EventBase
from synapse.storage.databases.main import DataStore, PurgeEventsStore
from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore
from synapse.storage.util.id_generators import MultiWriterIdGenerator
@@ -1464,3 +1465,31 @@ class ScheduledTask:
result: Optional[JsonMapping]
# Optional error that should be assigned a value when the status is FAILED
error: Optional[str]
@attr.s(auto_attribs=True, frozen=True, slots=True)
class EventOrderings:
stream: int
"""
The stream_ordering of the event.
Negative numbers mean the event was backfilled.
"""
topological: int
"""
The topological_ordering of the event.
Currently this is equivalent to the `depth` attributes of
the PDU.
"""
@staticmethod
def from_event(event: "EventBase") -> "EventOrderings":
"""
Get the orderings from an event.
Preconditions:
- the event must have been persisted (otherwise it won't have a stream ordering)
"""
stream = event.internal_metadata.stream_ordering
assert stream is not None
return EventOrderings(stream, event.depth)

View File

@@ -13,7 +13,11 @@
#
#
from synapse._pydantic_compat import BaseModel, Extra
import re
from typing import Any, Callable, Generator
from synapse._pydantic_compat import BaseModel, Extra, StrictStr
from synapse.types import EventID
class ParseModel(BaseModel):
@@ -37,3 +41,43 @@ class ParseModel(BaseModel):
extra = Extra.ignore
# By default, don't allow fields to be reassigned after parsing.
allow_mutation = False
class AnyEventId(StrictStr):
"""
A validator for strings that need to be an Event ID.
Accepts any valid grammar of Event ID from any room version.
"""
EVENT_ID_HASH_ROOM_VERSION_3_PLUS = re.compile(
r"^([a-zA-Z0-9-_]{43}|[a-zA-Z0-9+/]{43})$"
)
@classmethod
def __get_validators__(cls) -> Generator[Callable[..., Any], Any, Any]:
yield from super().__get_validators__() # type: ignore
yield cls.validate_event_id
@classmethod
def validate_event_id(cls, value: str) -> str:
if not value.startswith("$"):
raise ValueError("Event ID must start with `$`")
if ":" in value:
# Room versions 1 and 2
EventID.from_string(value) # throws on fail
else:
# Room versions 3+: event ID is $ + a base64 sha256 hash
# Room version 3 is base64, 4+ are base64Url
# In both cases, the base64 is unpadded.
# refs:
# - https://spec.matrix.org/v1.15/rooms/v3/ e.g. $acR1l0raoZnm60CBwAVgqbZqoO/mYU81xysh1u7XcJk
# - https://spec.matrix.org/v1.15/rooms/v4/ e.g. $Rqnc-F-dvnEYJTyHq_iKxU2bZ1CI92-kuZq3a5lr5Zg
b64_hash = value[1:]
if cls.EVENT_ID_HASH_ROOM_VERSION_3_PLUS.fullmatch(b64_hash) is None:
raise ValueError(
"Event ID must either have a domain part or be a valid hash"
)
return value

View File

@@ -62,7 +62,7 @@ class ThreadSubscriptionsStreamTestCase(BaseStreamTestCase):
"@test_user:example.org",
room_id,
thread_root_id,
automatic=True,
automatic_event_orderings=None,
)
)
updates.append(thread_root_id)
@@ -75,7 +75,7 @@ class ThreadSubscriptionsStreamTestCase(BaseStreamTestCase):
"@test_user:example.org",
other_room_id,
other_thread_root_id,
automatic=False,
automatic_event_orderings=None,
)
)
@@ -124,7 +124,7 @@ class ThreadSubscriptionsStreamTestCase(BaseStreamTestCase):
for user_id in users:
self.get_success(
store.subscribe_user_to_thread(
user_id, room_id, thread_root_id, automatic=True
user_id, room_id, thread_root_id, automatic_event_orderings=None
)
)

View File

@@ -15,6 +15,7 @@ from http import HTTPStatus
from twisted.internet.testing import MemoryReactor
from synapse.api.errors import Codes
from synapse.rest import admin
from synapse.rest.client import login, profile, room, thread_subscriptions
from synapse.server import HomeServer
@@ -49,15 +50,16 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
# Create a room and send a message to use as a thread root
self.room_id = self.helper.create_room_as(self.user_id, tok=self.token)
self.helper.join(self.room_id, self.other_user_id, tok=self.other_token)
response = self.helper.send(self.room_id, body="Root message", tok=self.token)
self.root_event_id = response["event_id"]
(self.root_event_id,) = self.helper.send_messages(
self.room_id, 1, tok=self.token
)
# Send a message in the thread
self.helper.send_event(
room_id=self.room_id,
type="m.room.message",
content={
"body": "Thread message",
self.threaded_events = self.helper.send_messages(
self.room_id,
2,
content_fn=lambda idx: {
"body": f"Thread message {idx}",
"msgtype": "m.text",
"m.relates_to": {
"rel_type": "m.thread",
@@ -106,9 +108,7 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"PUT",
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
{
"automatic": False,
},
{},
access_token=self.token,
)
self.assertEqual(channel.code, HTTPStatus.OK)
@@ -127,7 +127,7 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"PUT",
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
{"automatic": True},
{"automatic": self.threaded_events[0]},
access_token=self.token,
)
self.assertEqual(channel.code, HTTPStatus.OK)
@@ -148,11 +148,11 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
"PUT",
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
{
"automatic": True,
"automatic": self.threaded_events[0],
},
access_token=self.token,
)
self.assertEqual(channel.code, HTTPStatus.OK)
self.assertEqual(channel.code, HTTPStatus.OK, channel.text_body)
# Assert the subscription was saved
channel = self.make_request(
@@ -167,7 +167,7 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"PUT",
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
{"automatic": False},
{},
access_token=self.token,
)
self.assertEqual(channel.code, HTTPStatus.OK)
@@ -187,7 +187,7 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
"PUT",
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
{
"automatic": True,
"automatic": self.threaded_events[0],
},
access_token=self.token,
)
@@ -202,7 +202,6 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, HTTPStatus.OK)
self.assertEqual(channel.json_body, {"automatic": True})
# Now also register a manual subscription
channel = self.make_request(
"DELETE",
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
@@ -210,7 +209,6 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, HTTPStatus.OK)
# Assert the manual subscription was not overridden
channel = self.make_request(
"GET",
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
@@ -224,7 +222,7 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"PUT",
f"{PREFIX}/{self.room_id}/thread/$nonexistent:example.org/subscription",
{"automatic": True},
{},
access_token=self.token,
)
self.assertEqual(channel.code, HTTPStatus.NOT_FOUND)
@@ -238,7 +236,7 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"PUT",
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
{"automatic": True},
{},
access_token=no_access_token,
)
self.assertEqual(channel.code, HTTPStatus.NOT_FOUND)
@@ -249,8 +247,105 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"PUT",
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
# non-boolean `automatic`
{"automatic": "true"},
# non-Event ID `automatic`
{"automatic": True},
access_token=self.token,
)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST)
channel = self.make_request(
"PUT",
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
# non-Event ID `automatic`
{"automatic": "$malformedEventId"},
access_token=self.token,
)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST)
def test_auto_subscribe_cause_event_not_in_thread(self) -> None:
"""
Test making an automatic subscription, where the cause event is not
actually in the thread.
This is an error.
"""
(unrelated_event_id,) = self.helper.send_messages(
self.room_id, 1, tok=self.token
)
channel = self.make_request(
"PUT",
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
{"automatic": unrelated_event_id},
access_token=self.token,
)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.text_body)
self.assertEqual(channel.json_body["errcode"], Codes.MSC4306_NOT_IN_THREAD)
def test_auto_resubscription_conflict(self) -> None:
"""
Test that an automatic subscription that conflicts with an unsubscription
is skipped.
"""
# Reuse the test that subscribes and unsubscribes
self.test_unsubscribe()
# Now no matter which event we present as the cause of an automatic subscription,
# the automatic subscription is skipped.
# This is because the unsubscription happened after all of the events.
for event in self.threaded_events:
channel = self.make_request(
"PUT",
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
{
"automatic": event,
},
access_token=self.token,
)
self.assertEqual(channel.code, HTTPStatus.CONFLICT, channel.text_body)
self.assertEqual(
channel.json_body["errcode"],
Codes.MSC4306_CONFLICTING_UNSUBSCRIPTION,
channel.text_body,
)
# Check the subscription was not made
channel = self.make_request(
"GET",
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
access_token=self.token,
)
self.assertEqual(channel.code, HTTPStatus.NOT_FOUND)
# But if a new event is sent after the unsubscription took place,
# that one can be used for an automatic subscription
(later_event_id,) = self.helper.send_messages(
self.room_id,
1,
content_fn=lambda _: {
"body": "Thread message after unsubscription",
"msgtype": "m.text",
"m.relates_to": {
"rel_type": "m.thread",
"event_id": self.root_event_id,
},
},
tok=self.token,
)
channel = self.make_request(
"PUT",
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
{
"automatic": later_event_id,
},
access_token=self.token,
)
self.assertEqual(channel.code, HTTPStatus.OK, channel.text_body)
# Check the subscription was made
channel = self.make_request(
"GET",
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
access_token=self.token,
)
self.assertEqual(channel.code, HTTPStatus.OK)
self.assertEqual(channel.json_body, {"automatic": True})

View File

@@ -29,12 +29,14 @@ from http import HTTPStatus
from typing import (
Any,
AnyStr,
Callable,
Dict,
Iterable,
Literal,
Mapping,
MutableMapping,
Optional,
Sequence,
Tuple,
overload,
)
@@ -45,7 +47,7 @@ import attr
from twisted.internet.testing import MemoryReactorClock
from twisted.web.server import Site
from synapse.api.constants import Membership, ReceiptTypes
from synapse.api.constants import EventTypes, Membership, ReceiptTypes
from synapse.api.errors import Codes
from synapse.server import HomeServer
from synapse.types import JsonDict
@@ -394,6 +396,32 @@ class RestHelper:
custom_headers=custom_headers,
)
def send_messages(
self,
room_id: str,
num_events: int,
content_fn: Callable[[int], JsonDict] = lambda idx: {
"msgtype": "m.text",
"body": f"Test event {idx}",
},
tok: Optional[str] = None,
) -> Sequence[str]:
"""
Helper to send a handful of sequential events and return their event IDs as a sequence.
"""
event_ids = []
for event_index in range(num_events):
response = self.send_event(
room_id,
EventTypes.Message,
content_fn(event_index),
tok=tok,
)
event_ids.append(response["event_id"])
return event_ids
def send_event(
self,
room_id: str,

View File

@@ -12,13 +12,18 @@
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
from typing import Optional
from typing import Optional, Union
from twisted.internet.testing import MemoryReactor
from synapse.server import HomeServer
from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main.thread_subscriptions import (
AutomaticSubscriptionConflicted,
ThreadSubscriptionsWorkerStore,
)
from synapse.storage.engines.sqlite import Sqlite3Engine
from synapse.types import EventOrderings
from synapse.util import Clock
from tests import unittest
@@ -97,10 +102,10 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
self,
thread_root_id: str,
*,
automatic: bool,
automatic_event_orderings: Optional[EventOrderings],
room_id: Optional[str] = None,
user_id: Optional[str] = None,
) -> Optional[int]:
) -> Optional[Union[int, AutomaticSubscriptionConflicted]]:
if user_id is None:
user_id = self.user_id
@@ -112,7 +117,7 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
user_id,
room_id,
thread_root_id,
automatic=automatic,
automatic_event_orderings=automatic_event_orderings,
)
)
@@ -149,7 +154,7 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
# Subscribe
self._subscribe(
self.thread_root_id,
automatic=True,
automatic_event_orderings=EventOrderings(1, 1),
)
# Assert subscription went through
@@ -164,7 +169,7 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
# Now make it a manual subscription
self._subscribe(
self.thread_root_id,
automatic=False,
automatic_event_orderings=None,
)
# Assert the manual subscription overrode the automatic one
@@ -178,8 +183,10 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
def test_purge_thread_subscriptions_for_user(self) -> None:
"""Test purging all thread subscription settings for a user."""
# Set subscription settings for multiple threads
self._subscribe(self.thread_root_id, automatic=True)
self._subscribe(self.other_thread_root_id, automatic=False)
self._subscribe(
self.thread_root_id, automatic_event_orderings=EventOrderings(1, 1)
)
self._subscribe(self.other_thread_root_id, automatic_event_orderings=None)
subscriptions = self.get_success(
self.store.get_updated_thread_subscriptions_for_user(
@@ -217,20 +224,32 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
def test_get_updated_thread_subscriptions(self) -> None:
"""Test getting updated thread subscriptions since a stream ID."""
stream_id1 = self._subscribe(self.thread_root_id, automatic=False)
stream_id2 = self._subscribe(self.other_thread_root_id, automatic=True)
assert stream_id1 is not None
assert stream_id2 is not None
stream_id1 = self._subscribe(
self.thread_root_id, automatic_event_orderings=EventOrderings(1, 1)
)
stream_id2 = self._subscribe(
self.other_thread_root_id, automatic_event_orderings=EventOrderings(2, 2)
)
assert stream_id1 is not None and not isinstance(
stream_id1, AutomaticSubscriptionConflicted
)
assert stream_id2 is not None and not isinstance(
stream_id2, AutomaticSubscriptionConflicted
)
# Get updates since initial ID (should include both changes)
updates = self.get_success(
self.store.get_updated_thread_subscriptions(0, stream_id2, 10)
self.store.get_updated_thread_subscriptions(
from_id=0, to_id=stream_id2, limit=10
)
)
self.assertEqual(len(updates), 2)
# Get updates since first change (should include only the second change)
updates = self.get_success(
self.store.get_updated_thread_subscriptions(stream_id1, stream_id2, 10)
self.store.get_updated_thread_subscriptions(
from_id=stream_id1, to_id=stream_id2, limit=10
)
)
self.assertEqual(
updates,
@@ -242,21 +261,27 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
other_user_id = "@other_user:test"
# Set thread subscription for main user
stream_id1 = self._subscribe(self.thread_root_id, automatic=True)
assert stream_id1 is not None
stream_id1 = self._subscribe(
self.thread_root_id, automatic_event_orderings=EventOrderings(1, 1)
)
assert stream_id1 is not None and not isinstance(
stream_id1, AutomaticSubscriptionConflicted
)
# Set thread subscription for other user
stream_id2 = self._subscribe(
self.other_thread_root_id,
automatic=True,
automatic_event_orderings=EventOrderings(1, 1),
user_id=other_user_id,
)
assert stream_id2 is not None
assert stream_id2 is not None and not isinstance(
stream_id2, AutomaticSubscriptionConflicted
)
# Get updates for main user
updates = self.get_success(
self.store.get_updated_thread_subscriptions_for_user(
self.user_id, 0, stream_id2, 10
self.user_id, from_id=0, to_id=stream_id2, limit=10
)
)
self.assertEqual(updates, [(stream_id1, self.room_id, self.thread_root_id)])
@@ -264,9 +289,41 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
# Get updates for other user
updates = self.get_success(
self.store.get_updated_thread_subscriptions_for_user(
other_user_id, 0, max(stream_id1, stream_id2), 10
other_user_id, from_id=0, to_id=max(stream_id1, stream_id2), limit=10
)
)
self.assertEqual(
updates, [(stream_id2, self.room_id, self.other_thread_root_id)]
)
def test_should_skip_autosubscription_after_unsubscription(self) -> None:
"""
Tests the comparison logic for whether an autoscription should be skipped
due to a chronologically earlier but logically later unsubscription.
"""
func = ThreadSubscriptionsWorkerStore._should_skip_autosubscription_after_unsubscription
# Order of arguments:
# automatic cause event: stream order, then topological order
# unsubscribe maximums: stream order, then tological order
# both orderings agree that the unsub is after the cause event
self.assertTrue(
func(autosub=EventOrderings(1, 1), unsubscribed_at=EventOrderings(2, 2))
)
# topological ordering is inconsistent with stream ordering,
# in that case favour stream ordering because it's what /sync uses
self.assertTrue(
func(autosub=EventOrderings(1, 2), unsubscribed_at=EventOrderings(2, 1))
)
# the automatic subscription is caused by a backfilled event here
# unfortunately we must fall back to topological ordering here
self.assertTrue(
func(autosub=EventOrderings(-50, 2), unsubscribed_at=EventOrderings(2, 3))
)
self.assertFalse(
func(autosub=EventOrderings(-50, 2), unsubscribed_at=EventOrderings(2, 1))
)