Merge branch 'develop' into release-v1.136

This commit is contained in:
Devon Hudson
2025-08-06 16:44:12 -06:00
37 changed files with 1145 additions and 380 deletions

changelog.d/18756.misc Normal file
View File

@@ -0,0 +1 @@
Update implementation of [MSC4306: Thread Subscriptions](https://github.com/matrix-org/matrix-doc/issues/4306) to include automatic subscription conflict prevention as introduced in later drafts.

View File

@@ -0,0 +1 @@
Implement the push rules for experimental [MSC4306: Thread Subscriptions](https://github.com/matrix-org/matrix-doc/issues/4306).

View File

@@ -61,6 +61,7 @@ fn bench_match_exact(b: &mut Bencher) {
vec![],
false,
false,
false,
)
.unwrap();
@@ -71,10 +72,10 @@ fn bench_match_exact(b: &mut Bencher) {
},
));
let matched = eval.match_condition(&condition, None, None).unwrap();
let matched = eval.match_condition(&condition, None, None, None).unwrap();
assert!(matched, "Didn't match");
b.iter(|| eval.match_condition(&condition, None, None).unwrap());
b.iter(|| eval.match_condition(&condition, None, None, None).unwrap());
}
#[bench]
@@ -107,6 +108,7 @@ fn bench_match_word(b: &mut Bencher) {
vec![],
false,
false,
false,
)
.unwrap();
@@ -117,10 +119,10 @@ fn bench_match_word(b: &mut Bencher) {
},
));
let matched = eval.match_condition(&condition, None, None).unwrap();
let matched = eval.match_condition(&condition, None, None, None).unwrap();
assert!(matched, "Didn't match");
b.iter(|| eval.match_condition(&condition, None, None).unwrap());
b.iter(|| eval.match_condition(&condition, None, None, None).unwrap());
}
#[bench]
@@ -153,6 +155,7 @@ fn bench_match_word_miss(b: &mut Bencher) {
vec![],
false,
false,
false,
)
.unwrap();
@@ -163,10 +166,10 @@ fn bench_match_word_miss(b: &mut Bencher) {
},
));
let matched = eval.match_condition(&condition, None, None).unwrap();
let matched = eval.match_condition(&condition, None, None, None).unwrap();
assert!(!matched, "Didn't match");
b.iter(|| eval.match_condition(&condition, None, None).unwrap());
b.iter(|| eval.match_condition(&condition, None, None, None).unwrap());
}
#[bench]
@@ -199,6 +202,7 @@ fn bench_eval_message(b: &mut Bencher) {
vec![],
false,
false,
false,
)
.unwrap();
@@ -210,7 +214,8 @@ fn bench_eval_message(b: &mut Bencher) {
false,
false,
false,
false,
);
b.iter(|| eval.run(&rules, Some("bob"), Some("person")));
b.iter(|| eval.run(&rules, Some("bob"), Some("person"), None));
}

View File

@@ -290,6 +290,26 @@ pub const BASE_APPEND_CONTENT_RULES: &[PushRule] = &[PushRule {
}];
pub const BASE_APPEND_UNDERRIDE_RULES: &[PushRule] = &[
PushRule {
rule_id: Cow::Borrowed("global/content/.io.element.msc4306.rule.unsubscribed_thread"),
priority_class: 1,
conditions: Cow::Borrowed(&[Condition::Known(
KnownCondition::Msc4306ThreadSubscription { subscribed: false },
)]),
actions: Cow::Borrowed(&[]),
default: true,
default_enabled: true,
},
PushRule {
rule_id: Cow::Borrowed("global/content/.io.element.msc4306.rule.subscribed_thread"),
priority_class: 1,
conditions: Cow::Borrowed(&[Condition::Known(
KnownCondition::Msc4306ThreadSubscription { subscribed: true },
)]),
actions: Cow::Borrowed(&[Action::Notify, SOUND_ACTION]),
default: true,
default_enabled: true,
},
PushRule {
rule_id: Cow::Borrowed("global/underride/.m.rule.call"),
priority_class: 1,
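For reference, a sketch of how the new subscribed-thread underride rule would surface over the client push-rules API, which strips the `global/underride/` prefix from the rule ID (a Python literal standing in for JSON; the exact sound-tweak payload is an assumption based on the usual SOUND_ACTION convention):

subscribed_thread_rule = {
    "rule_id": ".io.element.msc4306.rule.subscribed_thread",
    "default": True,
    "enabled": True,
    "conditions": [
        {"kind": "io.element.msc4306.thread_subscription", "subscribed": True}
    ],
    # notify with a sound; the unsubscribed_thread twin has "actions": []
    "actions": ["notify", {"set_tweak": "sound", "value": "default"}],
}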

View File

@@ -106,8 +106,11 @@ pub struct PushRuleEvaluator {
/// flag as MSC1767 (extensible events core).
msc3931_enabled: bool,
// If MSC4210 (remove legacy mentions) is enabled.
/// If MSC4210 (remove legacy mentions) is enabled.
msc4210_enabled: bool,
/// If MSC4306 (thread subscriptions) is enabled.
msc4306_enabled: bool,
}
#[pymethods]
@@ -126,6 +129,7 @@ impl PushRuleEvaluator {
room_version_feature_flags,
msc3931_enabled,
msc4210_enabled,
msc4306_enabled,
))]
pub fn py_new(
flattened_keys: BTreeMap<String, JsonValue>,
@@ -138,6 +142,7 @@ impl PushRuleEvaluator {
room_version_feature_flags: Vec<String>,
msc3931_enabled: bool,
msc4210_enabled: bool,
msc4306_enabled: bool,
) -> Result<Self, Error> {
let body = match flattened_keys.get("content.body") {
Some(JsonValue::Value(SimpleJsonValue::Str(s))) => s.clone().into_owned(),
@@ -156,6 +161,7 @@ impl PushRuleEvaluator {
room_version_feature_flags,
msc3931_enabled,
msc4210_enabled,
msc4306_enabled,
})
}
@@ -167,12 +173,19 @@ impl PushRuleEvaluator {
///
/// Returns the set of actions, if any, that match (filtering out any
/// `dont_notify` and `coalesce` actions).
#[pyo3(signature = (push_rules, user_id=None, display_name=None))]
///
/// msc4306_thread_subscription_state: (Only populated if MSC4306 is enabled)
/// The thread subscription state corresponding to the thread containing this event.
/// - `None` if the event is not in a thread, or if MSC4306 is disabled.
/// - `Some(true)` if the event is in a thread and the user has a subscription for that thread.
/// - `Some(false)` if the event is in a thread and the user does NOT have a subscription for that thread.
#[pyo3(signature = (push_rules, user_id=None, display_name=None, msc4306_thread_subscription_state=None))]
pub fn run(
&self,
push_rules: &FilteredPushRules,
user_id: Option<&str>,
display_name: Option<&str>,
msc4306_thread_subscription_state: Option<bool>,
) -> Vec<Action> {
'outer: for (push_rule, enabled) in push_rules.iter() {
if !enabled {
@@ -204,7 +217,12 @@ impl PushRuleEvaluator {
Condition::Known(KnownCondition::RoomVersionSupports { feature: _ }),
);
match self.match_condition(condition, user_id, display_name) {
match self.match_condition(
condition,
user_id,
display_name,
msc4306_thread_subscription_state,
) {
Ok(true) => {}
Ok(false) => continue 'outer,
Err(err) => {
@@ -237,14 +255,20 @@ impl PushRuleEvaluator {
}
/// Check if the given condition matches.
#[pyo3(signature = (condition, user_id=None, display_name=None))]
#[pyo3(signature = (condition, user_id=None, display_name=None, msc4306_thread_subscription_state=None))]
fn matches(
&self,
condition: Condition,
user_id: Option<&str>,
display_name: Option<&str>,
msc4306_thread_subscription_state: Option<bool>,
) -> bool {
match self.match_condition(&condition, user_id, display_name) {
match self.match_condition(
&condition,
user_id,
display_name,
msc4306_thread_subscription_state,
) {
Ok(true) => true,
Ok(false) => false,
Err(err) => {
@@ -262,6 +286,7 @@ impl PushRuleEvaluator {
condition: &Condition,
user_id: Option<&str>,
display_name: Option<&str>,
msc4306_thread_subscription_state: Option<bool>,
) -> Result<bool, Error> {
let known_condition = match condition {
Condition::Known(known) => known,
@@ -393,6 +418,13 @@ impl PushRuleEvaluator {
&& self.room_version_feature_flags.contains(&flag)
}
}
KnownCondition::Msc4306ThreadSubscription { subscribed } => {
if !self.msc4306_enabled {
false
} else {
msc4306_thread_subscription_state == Some(*subscribed)
}
}
};
Ok(result)
@@ -536,10 +568,11 @@ fn push_rule_evaluator() {
vec![],
true,
false,
false,
)
.unwrap();
let result = evaluator.run(&FilteredPushRules::default(), None, Some("bob"));
let result = evaluator.run(&FilteredPushRules::default(), None, Some("bob"), None);
assert_eq!(result.len(), 3);
}
@@ -566,6 +599,7 @@ fn test_requires_room_version_supports_condition() {
flags,
true,
false,
false,
)
.unwrap();
@@ -575,6 +609,7 @@ fn test_requires_room_version_supports_condition() {
&FilteredPushRules::default(),
Some("@bob:example.org"),
None,
None,
);
assert_eq!(result.len(), 3);
@@ -593,7 +628,17 @@ fn test_requires_room_version_supports_condition() {
};
let rules = PushRules::new(vec![custom_rule]);
result = evaluator.run(
&FilteredPushRules::py_new(rules, BTreeMap::new(), true, false, true, false, false),
&FilteredPushRules::py_new(
rules,
BTreeMap::new(),
true,
false,
true,
false,
false,
false,
),
None,
None,
None,
);
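To make the tri-state concrete, a sketch of how the new condition evaluates, assuming an evaluator constructed with msc4306_enabled=True as in the unit tests further down:

condition = {"kind": "io.element.msc4306.thread_subscription", "subscribed": True}

evaluator.matches(condition, None, None, msc4306_thread_subscription_state=True)   # True
evaluator.matches(condition, None, None, msc4306_thread_subscription_state=False)  # False
evaluator.matches(condition, None, None, msc4306_thread_subscription_state=None)   # False: event not in a thread
# With msc4306_enabled=False, the condition never matches, whatever the tri-state.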

View File

@@ -369,6 +369,10 @@ pub enum KnownCondition {
RoomVersionSupports {
feature: Cow<'static, str>,
},
#[serde(rename = "io.element.msc4306.thread_subscription")]
Msc4306ThreadSubscription {
subscribed: bool,
},
}
impl<'source> IntoPyObject<'source> for Condition {
@@ -547,11 +551,13 @@ pub struct FilteredPushRules {
msc3664_enabled: bool,
msc4028_push_encrypted_events: bool,
msc4210_enabled: bool,
msc4306_enabled: bool,
}
#[pymethods]
impl FilteredPushRules {
#[new]
#[allow(clippy::too_many_arguments)]
pub fn py_new(
push_rules: PushRules,
enabled_map: BTreeMap<String, bool>,
@@ -560,6 +566,7 @@ impl FilteredPushRules {
msc3664_enabled: bool,
msc4028_push_encrypted_events: bool,
msc4210_enabled: bool,
msc4306_enabled: bool,
) -> Self {
Self {
push_rules,
@@ -569,6 +576,7 @@ impl FilteredPushRules {
msc3664_enabled,
msc4028_push_encrypted_events,
msc4210_enabled,
msc4306_enabled,
}
}
@@ -619,6 +627,10 @@ impl FilteredPushRules {
return false;
}
if !self.msc4306_enabled && rule.rule_id.contains("/.io.element.msc4306.rule.") {
return false;
}
true
})
.map(|r| {

View File

@@ -140,6 +140,12 @@ class Codes(str, Enum):
# Part of MSC4155
INVITE_BLOCKED = "ORG.MATRIX.MSC4155.M_INVITE_BLOCKED"
# Part of MSC4306: Thread Subscriptions
MSC4306_CONFLICTING_UNSUBSCRIPTION = (
"IO.ELEMENT.MSC4306.M_CONFLICTING_UNSUBSCRIPTION"
)
MSC4306_NOT_IN_THREAD = "IO.ELEMENT.MSC4306.M_NOT_IN_THREAD"
class CodeMessageException(RuntimeError):
"""An exception with integer code, a message string attributes and optional headers.

View File

@@ -37,7 +37,6 @@ Events are replicated via a separate events stream.
"""
import logging
from enum import Enum
from typing import (
TYPE_CHECKING,
Dict,
@@ -68,25 +67,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
class QueueNames(str, Enum):
PRESENCE_MAP = "presence_map"
KEYED_EDU = "keyed_edu"
KEYED_EDU_CHANGED = "keyed_edu_changed"
EDUS = "edus"
POS_TIME = "pos_time"
PRESENCE_DESTINATIONS = "presence_destinations"
queue_name_to_gauge_map: Dict[QueueNames, LaterGauge] = {}
for queue_name in QueueNames:
queue_name_to_gauge_map[queue_name] = LaterGauge(
name=f"synapse_federation_send_queue_{queue_name.value}_size",
desc="",
labelnames=[SERVER_NAME_LABEL],
)
class FederationRemoteSendQueue(AbstractFederationSender):
"""A drop in replacement for FederationSender"""
@@ -131,15 +111,23 @@ class FederationRemoteSendQueue(AbstractFederationSender):
# we need to make a new function so the inner
# lambda binds to the queue rather than to the name of the queue, which
# changes. ARGH.
def register(queue_name: QueueNames, queue: Sized) -> None:
queue_name_to_gauge_map[queue_name].register_hook(
lambda: {(self.server_name,): len(queue)}
def register(name: str, queue: Sized) -> None:
LaterGauge(
name="synapse_federation_send_queue_%s_size" % (queue_name,),
desc="",
labelnames=[SERVER_NAME_LABEL],
caller=lambda: {(self.server_name,): len(queue)},
)
for queue_name in QueueNames:
queue = getattr(self, queue_name.value)
assert isinstance(queue, Sized)
register(queue_name, queue=queue)
for queue_name in [
"presence_map",
"keyed_edu",
"keyed_edu_changed",
"edus",
"pos_time",
"presence_destinations",
]:
register(queue_name, getattr(self, queue_name))
self.clock.looping_call(self._clear_queue, 30 * 1000)
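The "ARGH" comment is about Python's late-binding closures; a self-contained sketch of the pitfall the helper function avoids (queue names and contents are illustrative):

queues = {"edus": [1, 2], "pos_time": [3]}

buggy_hooks = []
for queue_name, queue in queues.items():
    buggy_hooks.append(lambda: len(queue))  # every lambda shares the same `queue` variable

def make_hook(queue):
    return lambda: len(queue)  # fresh binding per call

good_hooks = [make_hook(q) for q in queues.values()]

print([hook() for hook in buggy_hooks])  # [1, 1] - both report the last queue
print([hook() for hook in good_hooks])   # [2, 1] - correct lengths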

View File

@@ -199,24 +199,6 @@ sent_pdus_destination_dist_total = Counter(
labelnames=[SERVER_NAME_LABEL],
)
transaction_queue_pending_destinations_gauge = LaterGauge(
name="synapse_federation_transaction_queue_pending_destinations",
desc="",
labelnames=[SERVER_NAME_LABEL],
)
transaction_queue_pending_pdus_gauge = LaterGauge(
name="synapse_federation_transaction_queue_pending_pdus",
desc="",
labelnames=[SERVER_NAME_LABEL],
)
transaction_queue_pending_edus_gauge = LaterGauge(
name="synapse_federation_transaction_queue_pending_edus",
desc="",
labelnames=[SERVER_NAME_LABEL],
)
# Time (in s) to wait before trying to wake up destinations that have
# catch-up outstanding.
# Please note that rate limiting still applies, so while the loop is
@@ -416,28 +398,38 @@ class FederationSender(AbstractFederationSender):
# map from destination to PerDestinationQueue
self._per_destination_queues: Dict[str, PerDestinationQueue] = {}
transaction_queue_pending_destinations_gauge.register_hook(
lambda: {
LaterGauge(
name="synapse_federation_transaction_queue_pending_destinations",
desc="",
labelnames=[SERVER_NAME_LABEL],
caller=lambda: {
(self.server_name,): sum(
1
for d in self._per_destination_queues.values()
if d.transmission_loop_running
)
}
},
)
transaction_queue_pending_pdus_gauge.register_hook(
lambda: {
LaterGauge(
name="synapse_federation_transaction_queue_pending_pdus",
desc="",
labelnames=[SERVER_NAME_LABEL],
caller=lambda: {
(self.server_name,): sum(
d.pending_pdu_count() for d in self._per_destination_queues.values()
)
}
},
)
transaction_queue_pending_edus_gauge.register_hook(
lambda: {
LaterGauge(
name="synapse_federation_transaction_queue_pending_edus",
desc="",
labelnames=[SERVER_NAME_LABEL],
caller=lambda: {
(self.server_name,): sum(
d.pending_edu_count() for d in self._per_destination_queues.values()
)
}
},
)
self._is_processing = False

View File

@@ -173,18 +173,6 @@ state_transition_counter = Counter(
labelnames=["locality", "from", "to", SERVER_NAME_LABEL],
)
presence_user_to_current_state_size_gauge = LaterGauge(
name="synapse_handlers_presence_user_to_current_state_size",
desc="",
labelnames=[SERVER_NAME_LABEL],
)
presence_wheel_timer_size_gauge = LaterGauge(
name="synapse_handlers_presence_wheel_timer_size",
desc="",
labelnames=[SERVER_NAME_LABEL],
)
# If a user was last active in the last LAST_ACTIVE_GRANULARITY, consider them
# "currently_active"
LAST_ACTIVE_GRANULARITY = 60 * 1000
@@ -791,8 +779,11 @@ class PresenceHandler(BasePresenceHandler):
EduTypes.PRESENCE, self.incoming_presence
)
presence_user_to_current_state_size_gauge.register_hook(
lambda: {(self.server_name,): len(self.user_to_current_state)}
LaterGauge(
name="synapse_handlers_presence_user_to_current_state_size",
desc="",
labelnames=[SERVER_NAME_LABEL],
caller=lambda: {(self.server_name,): len(self.user_to_current_state)},
)
# The per-device presence state, maps user to devices to per-device presence state.
@@ -891,8 +882,11 @@ class PresenceHandler(BasePresenceHandler):
60 * 1000,
)
presence_wheel_timer_size_gauge.register_hook(
lambda: {(self.server_name,): len(self.wheel_timer)}
LaterGauge(
name="synapse_handlers_presence_wheel_timer_size",
desc="",
labelnames=[SERVER_NAME_LABEL],
caller=lambda: {(self.server_name,): len(self.wheel_timer)},
)
# Used to handle sending of presence to newly joined users/servers

View File

@@ -1,9 +1,15 @@
import logging
from http import HTTPStatus
from typing import TYPE_CHECKING, Optional
from synapse.api.errors import AuthError, NotFoundError
from synapse.storage.databases.main.thread_subscriptions import ThreadSubscription
from synapse.types import UserID
from synapse.api.constants import RelationTypes
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
from synapse.events import relation_from_event
from synapse.storage.databases.main.thread_subscriptions import (
AutomaticSubscriptionConflicted,
ThreadSubscription,
)
from synapse.types import EventOrderings, UserID
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -55,42 +61,79 @@ class ThreadSubscriptionsHandler:
room_id: str,
thread_root_event_id: str,
*,
automatic: bool,
automatic_event_id: Optional[str],
) -> Optional[int]:
"""Sets or updates a user's subscription settings for a specific thread root.
Args:
user_id: The ID of the user whose settings are being updated.
room_id: The ID of the room the thread root belongs to.
thread_root_event_id: The event ID of the thread root.
automatic: whether the user was subscribed by an automatic decision by
their client.
automatic_event_id: if the user was subscribed by an automatic decision by
their client, the event ID that caused this.
Returns:
The stream ID for this update, if the update isn't no-opped.
Raises:
NotFoundError if the user cannot access the thread root event, or it isn't
known to this homeserver.
known to this homeserver. Ditto for the automatic cause event if supplied.
SynapseError(400, M_NOT_IN_THREAD): if the client supplied an automatic cause event
that is not in the thread rooted at the given thread root event.
SynapseError(409, M_CONFLICTING_UNSUBSCRIPTION): if the client requested an automatic subscription
but it was skipped because the cause event is logically earlier than an unsubscription.
"""
# First check that the user can access the thread root event
# and that it exists
try:
event = await self.event_handler.get_event(
thread_root_event = await self.event_handler.get_event(
user_id, room_id, thread_root_event_id
)
if event is None:
if thread_root_event is None:
raise NotFoundError("No such thread root")
except AuthError:
logger.info("rejecting thread subscriptions change (thread not accessible)")
raise NotFoundError("No such thread root")
return await self.store.subscribe_user_to_thread(
if automatic_event_id:
autosub_cause_event = await self.event_handler.get_event(
user_id, room_id, automatic_event_id
)
if autosub_cause_event is None:
raise NotFoundError("Automatic subscription event not found")
relation = relation_from_event(autosub_cause_event)
if (
relation is None
or relation.rel_type != RelationTypes.THREAD
or relation.parent_id != thread_root_event_id
):
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Automatic subscription must use an event in the thread",
errcode=Codes.MSC4306_NOT_IN_THREAD,
)
automatic_event_orderings = EventOrderings.from_event(autosub_cause_event)
else:
automatic_event_orderings = None
outcome = await self.store.subscribe_user_to_thread(
user_id.to_string(),
event.room_id,
room_id,
thread_root_event_id,
automatic=automatic,
automatic_event_orderings=automatic_event_orderings,
)
if isinstance(outcome, AutomaticSubscriptionConflicted):
raise SynapseError(
HTTPStatus.CONFLICT,
"Automatic subscription obsoleted by an unsubscription request.",
errcode=Codes.MSC4306_CONFLICTING_UNSUBSCRIPTION,
)
return outcome
async def unsubscribe_user_from_thread(
self, user_id: UserID, room_id: str, thread_root_event_id: str
) -> Optional[int]:

View File

@@ -164,12 +164,12 @@ def _get_in_flight_counts() -> Mapping[Tuple[str, ...], int]:
return counts
in_flight_requests = LaterGauge(
LaterGauge(
name="synapse_http_server_in_flight_requests_count",
desc="",
labelnames=["method", "servlet", SERVER_NAME_LABEL],
caller=_get_in_flight_counts,
)
in_flight_requests.register_hook(_get_in_flight_counts)
class RequestMetrics:

View File

@@ -31,7 +31,6 @@ from typing import (
Dict,
Generic,
Iterable,
List,
Mapping,
Optional,
Sequence,
@@ -74,6 +73,8 @@ logger = logging.getLogger(__name__)
METRICS_PREFIX = "/_synapse/metrics"
all_gauges: Dict[str, Collector] = {}
HAVE_PROC_SELF_STAT = os.path.exists("/proc/self/stat")
SERVER_NAME_LABEL = "server_name"
@@ -162,47 +163,42 @@ class LaterGauge(Collector):
name: str
desc: str
labelnames: Optional[StrSequence] = attr.ib(hash=False)
# List of callbacks: each callback should either return a value (if there are no
# labels for this metric), or dict mapping from a label tuple to a value
_hooks: List[
Callable[
[], Union[Mapping[Tuple[str, ...], Union[int, float]], Union[int, float]]
]
] = attr.ib(factory=list, hash=False)
# callback: should either return a value (if there are no labels for this metric),
# or dict mapping from a label tuple to a value
caller: Callable[
[], Union[Mapping[Tuple[str, ...], Union[int, float]], Union[int, float]]
]
def collect(self) -> Iterable[Metric]:
# The decision to add `SERVER_NAME_LABEL` comes from the `LaterGauge` usage itself
# (we don't enforce it here; that happens one level up).
g = GaugeMetricFamily(self.name, self.desc, labels=self.labelnames) # type: ignore[missing-server-name-label]
for hook in self._hooks:
try:
hook_result = hook()
except Exception:
logger.exception(
"Exception running callback for LaterGauge(%s)", self.name
)
yield g
return
if isinstance(hook_result, (int, float)):
g.add_metric([], hook_result)
else:
for k, v in hook_result.items():
g.add_metric(k, v)
try:
calls = self.caller()
except Exception:
logger.exception("Exception running callback for LaterGauge(%s)", self.name)
yield g
return
def register_hook(
self,
hook: Callable[
[], Union[Mapping[Tuple[str, ...], Union[int, float]], Union[int, float]]
],
) -> None:
self._hooks.append(hook)
if isinstance(calls, (int, float)):
g.add_metric([], calls)
else:
for k, v in calls.items():
g.add_metric(k, v)
yield g
def __attrs_post_init__(self) -> None:
self._register()
def _register(self) -> None:
if self.name in all_gauges.keys():
logger.warning("%s already registered, reregistering", self.name)
REGISTRY.unregister(all_gauges.pop(self.name))
REGISTRY.register(self)
all_gauges[self.name] = self
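With the `caller` attribute restored, wiring up a gauge is a single construction; a minimal usage sketch (the metric and server names are illustrative):

from synapse.metrics import SERVER_NAME_LABEL, LaterGauge

pending_items: list = []

# Constructing the gauge registers it via __attrs_post_init__ above;
# `caller` is invoked lazily on every collection.
LaterGauge(
    name="synapse_example_pending_items",
    desc="Number of pending items (illustrative)",
    labelnames=[SERVER_NAME_LABEL],
    caller=lambda: {("hs.example.org",): len(pending_items)},
)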
# `MetricsEntry` only makes sense when it is a `Protocol`,
@@ -254,7 +250,7 @@ class InFlightGauge(Generic[MetricsEntry], Collector):
# Protects access to _registrations
self._lock = threading.Lock()
REGISTRY.register(self)
self._register_with_collector()
def register(
self,
@@ -345,6 +341,14 @@ class InFlightGauge(Generic[MetricsEntry], Collector):
gauge.add_metric(labels=key, value=getattr(metrics, name))
yield gauge
def _register_with_collector(self) -> None:
if self.name in all_gauges.keys():
logger.warning("%s already registered, reregistering", self.name)
REGISTRY.unregister(all_gauges.pop(self.name))
REGISTRY.register(self)
all_gauges[self.name] = self
class GaugeHistogramMetricFamilyWithLabels(GaugeHistogramMetricFamily):
"""

View File

@@ -86,24 +86,6 @@ users_woken_by_stream_counter = Counter(
labelnames=["stream", SERVER_NAME_LABEL],
)
notifier_listeners_gauge = LaterGauge(
name="synapse_notifier_listeners",
desc="",
labelnames=[SERVER_NAME_LABEL],
)
notifier_rooms_gauge = LaterGauge(
name="synapse_notifier_rooms",
desc="",
labelnames=[SERVER_NAME_LABEL],
)
notifier_users_gauge = LaterGauge(
name="synapse_notifier_users",
desc="",
labelnames=[SERVER_NAME_LABEL],
)
T = TypeVar("T")
@@ -299,16 +281,28 @@ class Notifier:
)
}
notifier_listeners_gauge.register_hook(count_listeners)
notifier_rooms_gauge.register_hook(
lambda: {
LaterGauge(
name="synapse_notifier_listeners",
desc="",
labelnames=[SERVER_NAME_LABEL],
caller=count_listeners,
)
LaterGauge(
name="synapse_notifier_rooms",
desc="",
labelnames=[SERVER_NAME_LABEL],
caller=lambda: {
(self.server_name,): count(
bool, list(self.room_to_user_streams.values())
)
}
},
)
notifier_users_gauge.register_hook(
lambda: {(self.server_name,): len(self.user_to_user_stream)}
LaterGauge(
name="synapse_notifier_users",
desc="",
labelnames=[SERVER_NAME_LABEL],
caller=lambda: {(self.server_name,): len(self.user_to_user_stream)},
)
def add_replication_callback(self, cb: Callable[[], None]) -> None:

View File

@@ -25,6 +25,7 @@ from typing import (
Any,
Collection,
Dict,
FrozenSet,
List,
Mapping,
Optional,
@@ -477,8 +478,18 @@ class BulkPushRuleEvaluator:
event.room_version.msc3931_push_features,
self.hs.config.experimental.msc1767_enabled, # MSC3931 flag
self.hs.config.experimental.msc4210_enabled,
self.hs.config.experimental.msc4306_enabled,
)
msc4306_thread_subscribers: Optional[FrozenSet[str]] = None
if self.hs.config.experimental.msc4306_enabled and thread_id != MAIN_TIMELINE:
# pull out, in batch, all local subscribers to this thread
# (in the common case, they will all be getting processed for push
# rules right now)
msc4306_thread_subscribers = await self.store.get_subscribers_to_thread(
event.room_id, thread_id
)
for uid, rules in rules_by_user.items():
if event.sender == uid:
continue
@@ -503,7 +514,13 @@ class BulkPushRuleEvaluator:
# current user, it'll be added to the dict later.
actions_by_user[uid] = []
actions = evaluator.run(rules, uid, display_name)
msc4306_thread_subscription_state: Optional[bool] = None
if msc4306_thread_subscribers is not None:
msc4306_thread_subscription_state = uid in msc4306_thread_subscribers
actions = evaluator.run(
rules, uid, display_name, msc4306_thread_subscription_state
)
if "notify" in actions:
# Push rules say we should notify the user of this event
actions_by_user[uid] = actions
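A condensed sketch of the per-user derivation above (MAIN_TIMELINE is "main" in synapse.api.constants; the helper function itself is hypothetical):

from typing import FrozenSet, Optional

def thread_subscription_state(
    msc4306_enabled: bool, thread_id: str, uid: str, subscribers: FrozenSet[str]
) -> Optional[bool]:
    # None: not a threaded event (or MSC4306 disabled) - neither MSC4306 rule matches
    if not msc4306_enabled or thread_id == "main":
        return None
    # True triggers the subscribed_thread rule (notify);
    # False triggers the unsubscribed_thread rule (no actions, i.e. muted)
    return uid in subscribers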

View File

@@ -106,18 +106,6 @@ user_ip_cache_counter = Counter(
"synapse_replication_tcp_resource_user_ip_cache", "", labelnames=[SERVER_NAME_LABEL]
)
tcp_resource_total_connections_gauge = LaterGauge(
name="synapse_replication_tcp_resource_total_connections",
desc="",
labelnames=[SERVER_NAME_LABEL],
)
tcp_command_queue_gauge = LaterGauge(
name="synapse_replication_tcp_command_queue",
desc="Number of inbound RDATA/POSITION commands queued for processing",
labelnames=["stream_name", SERVER_NAME_LABEL],
)
# the type of the entries in _command_queues_by_stream
_StreamCommandQueue = Deque[
@@ -255,8 +243,11 @@ class ReplicationCommandHandler:
# outgoing replication commands to.)
self._connections: List[IReplicationConnection] = []
tcp_resource_total_connections_gauge.register_hook(
lambda: {(self.server_name,): len(self._connections)}
LaterGauge(
name="synapse_replication_tcp_resource_total_connections",
desc="",
labelnames=[SERVER_NAME_LABEL],
caller=lambda: {(self.server_name,): len(self._connections)},
)
# When POSITION or RDATA commands arrive, we stick them in a queue and process
@@ -275,11 +266,14 @@ class ReplicationCommandHandler:
# from that connection.
self._streams_by_connection: Dict[IReplicationConnection, Set[str]] = {}
tcp_command_queue_gauge.register_hook(
lambda: {
LaterGauge(
name="synapse_replication_tcp_command_queue",
desc="Number of inbound RDATA/POSITION commands queued for processing",
labelnames=["stream_name", SERVER_NAME_LABEL],
caller=lambda: {
(stream_name, self.server_name): len(queue)
for stream_name, queue in self._command_queues_by_stream.items()
}
},
)
self._is_master = hs.config.worker.worker_app is None

View File

@@ -527,11 +527,9 @@ pending_commands = LaterGauge(
name="synapse_replication_tcp_protocol_pending_commands",
desc="",
labelnames=["name", SERVER_NAME_LABEL],
)
pending_commands.register_hook(
lambda: {
caller=lambda: {
(p.name, p.server_name): len(p.pending_commands) for p in connected_connections
}
},
)
@@ -546,11 +544,9 @@ transport_send_buffer = LaterGauge(
name="synapse_replication_tcp_protocol_transport_send_buffer",
desc="",
labelnames=["name", SERVER_NAME_LABEL],
)
transport_send_buffer.register_hook(
lambda: {
caller=lambda: {
(p.name, p.server_name): transport_buffer_size(p) for p in connected_connections
}
},
)
@@ -575,12 +571,10 @@ tcp_transport_kernel_send_buffer = LaterGauge(
name="synapse_replication_tcp_protocol_transport_kernel_send_buffer",
desc="",
labelnames=["name", SERVER_NAME_LABEL],
)
tcp_transport_kernel_send_buffer.register_hook(
lambda: {
caller=lambda: {
(p.name, p.server_name): transport_kernel_read_buffer_size(p, False)
for p in connected_connections
}
},
)
@@ -588,10 +582,8 @@ tcp_transport_kernel_read_buffer = LaterGauge(
name="synapse_replication_tcp_protocol_transport_kernel_read_buffer",
desc="",
labelnames=["name", SERVER_NAME_LABEL],
)
tcp_transport_kernel_read_buffer.register_hook(
lambda: {
caller=lambda: {
(p.name, p.server_name): transport_kernel_read_buffer_size(p, True)
for p in connected_connections
}
},
)

View File

@@ -739,7 +739,7 @@ class ThreadSubscriptionsStream(_StreamFromIdGen):
NAME = "thread_subscriptions"
ROW_TYPE = ThreadSubscriptionsStreamRow
def __init__(self, hs: Any):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
super().__init__(
hs.get_instance_name(),
@@ -751,7 +751,7 @@ class ThreadSubscriptionsStream(_StreamFromIdGen):
self, instance_name: str, from_token: int, to_token: int, limit: int
) -> StreamUpdateResult:
updates = await self.store.get_updated_thread_subscriptions(
from_token, to_token, limit
from_id=from_token, to_id=to_token, limit=limit
)
rows = [
(

View File

@@ -1,7 +1,6 @@
from http import HTTPStatus
from typing import TYPE_CHECKING, Tuple
from typing import TYPE_CHECKING, Optional, Tuple
from synapse._pydantic_compat import StrictBool
from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import (
@@ -12,6 +11,7 @@ from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
from synapse.types import JsonDict, RoomID
from synapse.types.rest import RequestBodyModel
from synapse.util.pydantic_models import AnyEventId
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -32,7 +32,12 @@ class ThreadSubscriptionsRestServlet(RestServlet):
self.handler = hs.get_thread_subscriptions_handler()
class PutBody(RequestBodyModel):
automatic: StrictBool
automatic: Optional[AnyEventId]
"""
If supplied, the event ID of an event giving rise to this automatic subscription.
If omitted, this subscription is a manual subscription.
"""
async def on_GET(
self, request: SynapseRequest, room_id: str, thread_root_id: str
@@ -63,15 +68,15 @@ class ThreadSubscriptionsRestServlet(RestServlet):
raise SynapseError(
HTTPStatus.BAD_REQUEST, "Invalid event ID", errcode=Codes.INVALID_PARAM
)
requester = await self.auth.get_user_by_req(request)
body = parse_and_validate_json_object_from_request(request, self.PutBody)
requester = await self.auth.get_user_by_req(request)
await self.handler.subscribe_user_to_thread(
requester.user,
room_id,
thread_root_id,
automatic=body.automatic,
automatic_event_id=body.automatic,
)
return HTTPStatus.OK, {}
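A sketch of the two body shapes the servlet now accepts, assuming pydantic's v1-style parse_obj and its Optional-defaults-to-None behaviour (the event ID is illustrative):

PutBody = ThreadSubscriptionsRestServlet.PutBody

# automatic subscription, caused by the given in-thread event
PutBody.parse_obj({"automatic": "$Rqnc-F-dvnEYJTyHq_iKxU2bZ1CI92-kuZq3a5lr5Zg"})

# manual subscription (no cause event)
PutBody.parse_obj({})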

View File

@@ -100,12 +100,6 @@ sql_txn_duration = Counter(
labelnames=["desc", SERVER_NAME_LABEL],
)
background_update_status = LaterGauge(
name="synapse_background_update_status",
desc="Background update status",
labelnames=[SERVER_NAME_LABEL],
)
# Unique indexes which have been added in background updates. Maps from table name
# to the name of the background update which added the unique index to that table.
@@ -617,8 +611,11 @@ class DatabasePool:
)
self.updates = BackgroundUpdater(hs, self)
background_update_status.register_hook(
lambda: {(self.server_name,): self.updates.get_status()},
LaterGauge(
name="synapse_background_update_status",
desc="Background update status",
labelnames=[SERVER_NAME_LABEL],
caller=lambda: {(self.server_name,): self.updates.get_status()},
)
self._previous_txn_total_time = 0.0

View File

@@ -110,6 +110,7 @@ def _load_rules(
msc3381_polls_enabled=experimental_config.msc3381_polls_enabled,
msc4028_push_encrypted_events=experimental_config.msc4028_push_encrypted_events,
msc4210_enabled=experimental_config.msc4210_enabled,
msc4306_enabled=experimental_config.msc4306_enabled,
)
return filtered_rules

View File

@@ -84,13 +84,6 @@ _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership"
_POPULATE_PARTICIPANT_BG_UPDATE_BATCH_SIZE = 1000
federation_known_servers_gauge = LaterGauge(
name="synapse_federation_known_servers",
desc="",
labelnames=[SERVER_NAME_LABEL],
)
@attr.s(frozen=True, slots=True, auto_attribs=True)
class EventIdMembership:
"""Returned by `get_membership_from_event_ids`"""
@@ -123,8 +116,11 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
1,
self._count_known_servers,
)
federation_known_servers_gauge.register_hook(
lambda: {(self.server_name,): self._known_servers_count}
LaterGauge(
name="synapse_federation_known_servers",
desc="",
labelnames=[SERVER_NAME_LABEL],
caller=lambda: {(self.server_name,): self._known_servers_count},
)
@wrap_as_background_process("_count_known_servers")

View File

@@ -14,7 +14,7 @@ import logging
from typing import (
TYPE_CHECKING,
Any,
Dict,
FrozenSet,
Iterable,
List,
Optional,
@@ -33,6 +33,7 @@ from synapse.storage.database import (
)
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.types import EventOrderings
from synapse.util.caches.descriptors import cached
if TYPE_CHECKING:
@@ -50,6 +51,14 @@ class ThreadSubscription:
"""
class AutomaticSubscriptionConflicted:
"""
Marker return value to signal that an automatic subscription was skipped,
because it conflicted with an unsubscription that we consider to have
been made later than the event causing the automatic subscription.
"""
class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore):
def __init__(
self,
@@ -91,6 +100,7 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore):
self.get_subscription_for_thread.invalidate(
(row.user_id, row.room_id, row.event_id)
)
self.get_subscribers_to_thread.invalidate((row.room_id, row.event_id))
super().process_replication_rows(stream_name, instance_name, token, rows)
@@ -101,75 +111,196 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore):
self._thread_subscriptions_id_gen.advance(instance_name, token)
super().process_replication_position(stream_name, instance_name, token)
@staticmethod
def _should_skip_autosubscription_after_unsubscription(
*,
autosub: EventOrderings,
unsubscribed_at: EventOrderings,
) -> bool:
"""
Returns whether an automatic subscription occurring *after* an unsubscription
should be skipped, because the unsubscription already 'acknowledges' the event
causing the automatic subscription (the cause event).
To determine *after*, we use `stream_ordering`, unless the event is backfilled
(negative `stream_ordering`), in which case we fall back to topological ordering.
Args:
autosub: the stream_ordering and topological_ordering of the cause event
unsubscribed_at:
the maximum stream ordering and the maximum topological ordering at the time of unsubscription
Returns:
True if the automatic subscription should be skipped
"""
# For normal rooms, these two orderings should be positive, because
# they don't refer to a specific event but rather the maximum at the
# time of unsubscription.
#
# However, for rooms that have never been joined and that are being peeked at,
# we might not have a single non-backfilled event and therefore the stream
# ordering might be negative, so we don't assert this case.
assert unsubscribed_at.topological > 0
unsubscribed_at_backfilled = unsubscribed_at.stream < 0
if (
not unsubscribed_at_backfilled
and unsubscribed_at.stream >= autosub.stream > 0
):
# non-backfilled events: the unsubscription is later according to
# the stream
return True
if autosub.stream < 0:
# the auto-subscription cause event was backfilled, so fall back to
# topological ordering
if unsubscribed_at.topological >= autosub.topological:
return True
return False
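A worked sketch of the skip rule (values illustrative; `_should_skip` abbreviates the static method above):

unsub = EventOrderings(stream=120, topological=45)  # maxima recorded at unsubscription

_should_skip(autosub=EventOrderings(110, 44), unsubscribed_at=unsub)  # True: live cause event precedes the unsubscription
_should_skip(autosub=EventOrderings(130, 46), unsubscribed_at=unsub)  # False: live cause event follows it
_should_skip(autosub=EventOrderings(-5, 40), unsubscribed_at=unsub)   # True: backfilled, so topological ordering decides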
async def subscribe_user_to_thread(
self, user_id: str, room_id: str, thread_root_event_id: str, *, automatic: bool
) -> Optional[int]:
self,
user_id: str,
room_id: str,
thread_root_event_id: str,
*,
automatic_event_orderings: Optional[EventOrderings],
) -> Optional[Union[int, AutomaticSubscriptionConflicted]]:
"""Updates a user's subscription settings for a specific thread root.
If no change would be made to the subscription, does not produce any database change.
Case-by-case:
- if we already have an automatic subscription:
- new automatic subscriptions will be no-ops (no database write),
- new manual subscriptions will overwrite the automatic subscription
- if we already have a manual subscription:
we don't update (no database write) in either case, because:
- the existing manual subscription wins over a new automatic subscription request
- there would be no need to write a manual subscription because we already have one
Args:
user_id: The ID of the user whose settings are being updated.
room_id: The ID of the room the thread root belongs to.
thread_root_event_id: The event ID of the thread root.
automatic: Whether the subscription was performed automatically by the user's client.
Only `False` will overwrite an existing value of automatic for a subscription row.
automatic_event_orderings:
Value depends on whether the subscription was performed automatically by the user's client.
For manual subscriptions: None.
For automatic subscriptions: the orderings of the event.
Returns:
The stream ID for this update, if the update isn't no-opped.
If a subscription is made: (int) the stream ID for this update.
If a subscription already exists and did not need to be updated: None
If an automatic subscription conflicted with an unsubscription: AutomaticSubscriptionConflicted
"""
assert self._can_write_to_thread_subscriptions
def _subscribe_user_to_thread_txn(txn: LoggingTransaction) -> Optional[int]:
already_automatic = self.db_pool.simple_select_one_onecol_txn(
txn,
table="thread_subscriptions",
keyvalues={
"user_id": user_id,
"event_id": thread_root_event_id,
"room_id": room_id,
"subscribed": True,
},
retcol="automatic",
allow_none=True,
)
if already_automatic is None:
already_subscribed = False
already_automatic = True
else:
already_subscribed = True
# convert int (SQLite bool) to Python bool
already_automatic = bool(already_automatic)
if already_subscribed and already_automatic == automatic:
# there is nothing we need to do here
return None
stream_id = self._thread_subscriptions_id_gen.get_next_txn(txn)
values: Dict[str, Optional[Union[bool, int, str]]] = {
"subscribed": True,
"stream_id": stream_id,
"instance_name": self._instance_name,
"automatic": already_automatic and automatic,
}
self.db_pool.simple_upsert_txn(
txn,
table="thread_subscriptions",
keyvalues={
"user_id": user_id,
"event_id": thread_root_event_id,
"room_id": room_id,
},
values=values,
)
def _invalidate_subscription_caches(txn: LoggingTransaction) -> None:
txn.call_after(
self.get_subscription_for_thread.invalidate,
(user_id, room_id, thread_root_event_id),
)
txn.call_after(
self.get_subscribers_to_thread.invalidate,
(room_id, thread_root_event_id),
)
def _subscribe_user_to_thread_txn(
txn: LoggingTransaction,
) -> Optional[Union[int, AutomaticSubscriptionConflicted]]:
requested_automatic = automatic_event_orderings is not None
row = self.db_pool.simple_select_one_txn(
txn,
table="thread_subscriptions",
keyvalues={
"user_id": user_id,
"event_id": thread_root_event_id,
"room_id": room_id,
},
retcols=(
"subscribed",
"automatic",
"unsubscribed_at_stream_ordering",
"unsubscribed_at_topological_ordering",
),
allow_none=True,
)
if row is None:
# We have never subscribed before, simply insert the row and finish
stream_id = self._thread_subscriptions_id_gen.get_next_txn(txn)
self.db_pool.simple_insert_txn(
txn,
table="thread_subscriptions",
values={
"user_id": user_id,
"event_id": thread_root_event_id,
"room_id": room_id,
"subscribed": True,
"stream_id": stream_id,
"instance_name": self._instance_name,
"automatic": requested_automatic,
"unsubscribed_at_stream_ordering": None,
"unsubscribed_at_topological_ordering": None,
},
)
_invalidate_subscription_caches(txn)
return stream_id
# we already have either a subscription or a prior unsubscription here
(
subscribed,
already_automatic,
unsubscribed_at_stream_ordering,
unsubscribed_at_topological_ordering,
) = row
if subscribed and (not already_automatic or requested_automatic):
# we are already subscribed and the current subscription state
# is good enough (either we already have a manual subscription,
# or we requested an automatic subscription)
# In that case, nothing to change here.
# (See docstring for case-by-case explanation)
return None
if not subscribed and requested_automatic:
assert automatic_event_orderings is not None
# we previously unsubscribed and we are now automatically subscribing
# Check whether the new autosubscription should be skipped
if ThreadSubscriptionsWorkerStore._should_skip_autosubscription_after_unsubscription(
autosub=automatic_event_orderings,
unsubscribed_at=EventOrderings(
unsubscribed_at_stream_ordering,
unsubscribed_at_topological_ordering,
),
):
# skip the subscription
return AutomaticSubscriptionConflicted()
# At this point, we have determined that we do need to make
# a subscription, so update the current row.
stream_id = self._thread_subscriptions_id_gen.get_next_txn(txn)
self.db_pool.simple_update_txn(
txn,
table="thread_subscriptions",
keyvalues={
"user_id": user_id,
"event_id": thread_root_event_id,
"room_id": room_id,
},
updatevalues={
"subscribed": True,
"stream_id": stream_id,
"instance_name": self._instance_name,
"automatic": requested_automatic,
"unsubscribed_at_stream_ordering": None,
"unsubscribed_at_topological_ordering": None,
},
)
_invalidate_subscription_caches(txn)
return stream_id
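Putting the transaction's branches together, an outcome matrix (sketch):

#   existing row           request    result
#   ---------------------  ---------  -----------------------------------------
#   (none)                 either     insert -> stream ID
#   subscribed, automatic  automatic  no-op -> None
#   subscribed, automatic  manual     overwritten as manual -> stream ID
#   subscribed, manual     either     no-op -> None
#   unsubscribed           manual     re-subscribe -> stream ID
#   unsubscribed           automatic  re-subscribe -> stream ID, unless the cause
#                                     event is logically earlier than the
#                                     unsubscription -> AutomaticSubscriptionConflicted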
@@ -214,6 +345,21 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore):
stream_id = self._thread_subscriptions_id_gen.get_next_txn(txn)
# Find the maximum stream ordering and topological ordering of the room,
# which we then store against this unsubscription so we can skip future
# automatic subscriptions that are caused by an event logically earlier
# than this unsubscription.
txn.execute(
"""
SELECT MAX(stream_ordering) AS mso, MAX(topological_ordering) AS mto FROM events
WHERE room_id = ?
""",
(room_id,),
)
ord_row = txn.fetchone()
assert ord_row is not None
max_stream_ordering, max_topological_ordering = ord_row
self.db_pool.simple_update_txn(
txn,
table="thread_subscriptions",
@@ -227,6 +373,8 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore):
"subscribed": False,
"stream_id": stream_id,
"instance_name": self._instance_name,
"unsubscribed_at_stream_ordering": max_stream_ordering,
"unsubscribed_at_topological_ordering": max_topological_ordering,
},
)
@@ -234,6 +382,10 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore):
self.get_subscription_for_thread.invalidate,
(user_id, room_id, thread_root_event_id),
)
txn.call_after(
self.get_subscribers_to_thread.invalidate,
(room_id, thread_root_event_id),
)
return stream_id
@@ -246,7 +398,9 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore):
Purge all subscriptions for the user.
The fact that subscriptions have been purged will not be streamed;
all stream rows for the user will in fact be removed.
This is intended only for dealing with user deactivation.
This must only be used for user deactivation,
because it does not invalidate the `get_subscribers_to_thread` cache.
"""
def _purge_thread_subscription_settings_for_user_txn(
@@ -307,6 +461,42 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore):
return ThreadSubscription(automatic=automatic)
# max_entries=100 rationale:
# this returns a potentially large data structure
# (since each entry contains a set which contains a potentially large number of user IDs),
# whereas the default of 10,000 entries for @cached feels more
# suitable for very small cache entries.
#
# Overall, when bearing in mind the usual profile of a small community-server or company-server
# (where cache tuning hasn't been done, so we're in out-of-box configuration), it is very
# unlikely we would benefit from keeping hot the subscribers for as many as 100 threads,
# since it's unlikely that so many threads will be active in a short span of time on a small homeserver.
# It feels that medium servers will probably also not exhaust this limit.
# Larger homeservers are more likely to be carefully tuned, either with a larger global cache factor
# or carefully following the usage patterns & cache metrics.
# Finally, the query is not so intensive that computing it every time is a huge deal, but given people
# often send messages back-to-back in the same thread it seems like it would offer a mild benefit.
@cached(max_entries=100)
async def get_subscribers_to_thread(
self, room_id: str, thread_root_event_id: str
) -> FrozenSet[str]:
"""
Returns:
the set of user_ids for local users who are subscribed to the given thread.
"""
return frozenset(
await self.db_pool.simple_select_onecol(
table="thread_subscriptions",
keyvalues={
"room_id": room_id,
"event_id": thread_root_event_id,
"subscribed": True,
},
retcol="user_id",
desc="get_subscribers_to_thread",
)
)
def get_max_thread_subscriptions_stream_id(self) -> int:
"""Get the current maximum stream_id for thread subscriptions.
@@ -316,7 +506,7 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore):
return self._thread_subscriptions_id_gen.get_current_token()
async def get_updated_thread_subscriptions(
self, from_id: int, to_id: int, limit: int
self, *, from_id: int, to_id: int, limit: int
) -> List[Tuple[int, str, str, str]]:
"""Get updates to thread subscriptions between two stream IDs.
@@ -349,7 +539,7 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore):
)
async def get_updated_thread_subscriptions_for_user(
self, user_id: str, from_id: int, to_id: int, limit: int
self, user_id: str, *, from_id: int, to_id: int, limit: int
) -> List[Tuple[int, str, str]]:
"""Get updates to thread subscriptions for a specific user.

View File

@@ -0,0 +1,20 @@
--
-- This file is licensed under the Affero General Public License (AGPL) version 3.
--
-- Copyright (C) 2025 New Vector, Ltd
--
-- This program is free software: you can redistribute it and/or modify
-- it under the terms of the GNU Affero General Public License as
-- published by the Free Software Foundation, either version 3 of the
-- License, or (at your option) any later version.
--
-- See the GNU Affero General Public License for more details:
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
-- The maximum stream_ordering in the room when the unsubscription was made.
ALTER TABLE thread_subscriptions
ADD COLUMN unsubscribed_at_stream_ordering BIGINT;
-- The maximum topological_ordering in the room when the unsubscription was made.
ALTER TABLE thread_subscriptions
ADD COLUMN unsubscribed_at_topological_ordering BIGINT;

View File

@@ -0,0 +1,18 @@
--
-- This file is licensed under the Affero General Public License (AGPL) version 3.
--
-- Copyright (C) 2025 New Vector, Ltd
--
-- This program is free software: you can redistribute it and/or modify
-- it under the terms of the GNU Affero General Public License as
-- published by the Free Software Foundation, either version 3 of the
-- License, or (at your option) any later version.
--
-- See the GNU Affero General Public License for more details:
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
COMMENT ON COLUMN thread_subscriptions.unsubscribed_at_stream_ordering IS
$$The maximum stream_ordering in the room when the unsubscription was made.$$;
COMMENT ON COLUMN thread_subscriptions.unsubscribed_at_topological_ordering IS
$$The maximum topological_ordering in the room when the unsubscription was made.$$;

View File

@@ -49,6 +49,7 @@ class FilteredPushRules:
msc3664_enabled: bool,
msc4028_push_encrypted_events: bool,
msc4210_enabled: bool,
msc4306_enabled: bool,
): ...
def rules(self) -> Collection[Tuple[PushRule, bool]]: ...
@@ -67,13 +68,19 @@ class PushRuleEvaluator:
room_version_feature_flags: Tuple[str, ...],
msc3931_enabled: bool,
msc4210_enabled: bool,
msc4306_enabled: bool,
): ...
def run(
self,
push_rules: FilteredPushRules,
user_id: Optional[str],
display_name: Optional[str],
msc4306_thread_subscription_state: Optional[bool],
) -> Collection[Union[Mapping, str]]: ...
def matches(
self, condition: JsonDict, user_id: Optional[str], display_name: Optional[str]
self,
condition: JsonDict,
user_id: Optional[str],
display_name: Optional[str],
msc4306_thread_subscription_state: Optional[bool] = None,
) -> bool: ...

View File

@@ -73,6 +73,7 @@ if TYPE_CHECKING:
from typing_extensions import Self
from synapse.appservice.api import ApplicationService
from synapse.events import EventBase
from synapse.storage.databases.main import DataStore, PurgeEventsStore
from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore
from synapse.storage.util.id_generators import MultiWriterIdGenerator
@@ -1464,3 +1465,31 @@ class ScheduledTask:
result: Optional[JsonMapping]
# Optional error that should be assigned a value when the status is FAILED
error: Optional[str]
@attr.s(auto_attribs=True, frozen=True, slots=True)
class EventOrderings:
stream: int
"""
The stream_ordering of the event.
Negative numbers mean the event was backfilled.
"""
topological: int
"""
The topological_ordering of the event.
Currently this is equivalent to the `depth` attribute of
the PDU.
"""
@staticmethod
def from_event(event: "EventBase") -> "EventOrderings":
"""
Get the orderings from an event.
Preconditions:
- the event must have been persisted (otherwise it won't have a stream ordering)
"""
stream = event.internal_metadata.stream_ordering
assert stream is not None
return EventOrderings(stream, event.depth)

View File

@@ -13,7 +13,11 @@
#
#
from synapse._pydantic_compat import BaseModel, Extra
import re
from typing import Any, Callable, Generator
from synapse._pydantic_compat import BaseModel, Extra, StrictStr
from synapse.types import EventID
class ParseModel(BaseModel):
@@ -37,3 +41,43 @@ class ParseModel(BaseModel):
extra = Extra.ignore
# By default, don't allow fields to be reassigned after parsing.
allow_mutation = False
class AnyEventId(StrictStr):
"""
A validator for strings that need to be an Event ID.
Accepts any valid grammar of Event ID from any room version.
"""
EVENT_ID_HASH_ROOM_VERSION_3_PLUS = re.compile(
r"^([a-zA-Z0-9-_]{43}|[a-zA-Z0-9+/]{43})$"
)
@classmethod
def __get_validators__(cls) -> Generator[Callable[..., Any], Any, Any]:
yield from super().__get_validators__() # type: ignore
yield cls.validate_event_id
@classmethod
def validate_event_id(cls, value: str) -> str:
if not value.startswith("$"):
raise ValueError("Event ID must start with `$`")
if ":" in value:
# Room versions 1 and 2
EventID.from_string(value) # throws on fail
else:
# Room versions 3+: event ID is $ + a base64 sha256 hash
# Room version 3 is base64, 4+ are base64Url
# In both cases, the base64 is unpadded.
# refs:
# - https://spec.matrix.org/v1.15/rooms/v3/ e.g. $acR1l0raoZnm60CBwAVgqbZqoO/mYU81xysh1u7XcJk
# - https://spec.matrix.org/v1.15/rooms/v4/ e.g. $Rqnc-F-dvnEYJTyHq_iKxU2bZ1CI92-kuZq3a5lr5Zg
b64_hash = value[1:]
if cls.EVENT_ID_HASH_ROOM_VERSION_3_PLUS.fullmatch(b64_hash) is None:
raise ValueError(
"Event ID must either have a domain part or be a valid hash"
)
return value
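A quick sketch of the validator's behaviour (the v3/v4 IDs are the spec examples cited above; the v1-style ID is from the spec's event examples):

AnyEventId.validate_event_id("$Rqnc-F-dvnEYJTyHq_iKxU2bZ1CI92-kuZq3a5lr5Zg")  # ok: room v4+ (base64url hash)
AnyEventId.validate_event_id("$acR1l0raoZnm60CBwAVgqbZqoO/mYU81xysh1u7XcJk")  # ok: room v3 (base64 hash)
AnyEventId.validate_event_id("$143273582443PhrSn:example.org")               # ok: room v1/v2 (localpart:domain)
AnyEventId.validate_event_id("not-an-event-id")  # raises ValueError: no `$` sigil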

View File

@@ -131,31 +131,27 @@ def _get_counts_from_rate_limiter_instance(
# We track the number of affected hosts per time-period so we can
# differentiate one really noisy homeserver from a general
# ratelimit tuning problem across the federation.
sleep_affected_hosts_gauge = LaterGauge(
LaterGauge(
name="synapse_rate_limit_sleep_affected_hosts",
desc="Number of hosts that had requests put to sleep",
labelnames=["rate_limiter_name", SERVER_NAME_LABEL],
)
sleep_affected_hosts_gauge.register_hook(
lambda: _get_counts_from_rate_limiter_instance(
caller=lambda: _get_counts_from_rate_limiter_instance(
lambda rate_limiter_instance: sum(
ratelimiter.should_sleep()
for ratelimiter in rate_limiter_instance.ratelimiters.values()
)
)
),
)
reject_affected_hosts_gauge = LaterGauge(
LaterGauge(
name="synapse_rate_limit_reject_affected_hosts",
desc="Number of hosts that had requests rejected",
labelnames=["rate_limiter_name", SERVER_NAME_LABEL],
)
reject_affected_hosts_gauge.register_hook(
lambda: _get_counts_from_rate_limiter_instance(
caller=lambda: _get_counts_from_rate_limiter_instance(
lambda rate_limiter_instance: sum(
ratelimiter.should_reject()
for ratelimiter in rate_limiter_instance.ratelimiters.values()
)
)
),
)

View File

@@ -44,13 +44,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
running_tasks_gauge = LaterGauge(
name="synapse_scheduler_running_tasks",
desc="The number of concurrent running tasks handled by the TaskScheduler",
labelnames=[SERVER_NAME_LABEL],
)
class TaskScheduler:
"""
This is a simple task scheduler designed for resumable tasks. Normally,
@@ -137,8 +130,11 @@ class TaskScheduler:
TaskScheduler.SCHEDULE_INTERVAL_MS,
)
running_tasks_gauge.register_hook(
lambda: {(self.server_name,): len(self._running_tasks)}
LaterGauge(
name="synapse_scheduler_running_tasks",
desc="The number of concurrent running tasks handled by the TaskScheduler",
labelnames=[SERVER_NAME_LABEL],
caller=lambda: {(self.server_name,): len(self._running_tasks)},
)
def register_action(

View File

@@ -22,13 +22,7 @@ from typing import Dict, Protocol, Tuple
from prometheus_client.core import Sample
from synapse.metrics import (
REGISTRY,
SERVER_NAME_LABEL,
InFlightGauge,
LaterGauge,
generate_latest,
)
from synapse.metrics import REGISTRY, InFlightGauge, generate_latest
from synapse.util.caches.deferred_cache import DeferredCache
from tests import unittest
@@ -291,42 +285,6 @@ class CacheMetricsTests(unittest.HomeserverTestCase):
self.assertEqual(hs2_cache_max_size_metric_value, "777.0")
class LaterGaugeTests(unittest.HomeserverTestCase):
def test_later_gauge_multiple_servers(self) -> None:
"""
Test that LaterGauge metrics are reported correctly across multiple servers. We
will have a metrics entry for each homeserver that is labeled with the
`server_name` label.
"""
later_gauge = LaterGauge(
name="foo",
desc="",
labelnames=[SERVER_NAME_LABEL],
)
later_gauge.register_hook(lambda: {("hs1",): 1})
later_gauge.register_hook(lambda: {("hs2",): 2})
metrics_map = get_latest_metrics()
# Find the metrics from both homeservers
hs1_metric = 'foo{server_name="hs1"}'
hs1_metric_value = metrics_map.get(hs1_metric)
self.assertIsNotNone(
hs1_metric_value,
f"Missing metric {hs1_metric} in cache metrics {metrics_map}",
)
hs2_metric = 'foo{server_name="hs2"}'
hs2_metric_value = metrics_map.get(hs2_metric)
self.assertIsNotNone(
hs2_metric_value,
f"Missing metric {hs2_metric} in cache metrics {metrics_map}",
)
# Sanity check the metric values
self.assertEqual(hs1_metric_value, "1.0")
self.assertEqual(hs2_metric_value, "2.0")
def get_latest_metrics() -> Dict[str, str]:
"""
Collect the latest metrics from the registry and parse them into an easy to use map.

View File

@@ -26,7 +26,7 @@ from parameterized import parameterized
from twisted.internet.testing import MemoryReactor
from synapse.api.constants import EventContentFields, RelationTypes
from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
from synapse.api.room_versions import RoomVersions
from synapse.push.bulk_push_rule_evaluator import BulkPushRuleEvaluator
from synapse.rest import admin
@@ -206,7 +206,10 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
bulk_evaluator._action_for_event_by_user.assert_not_called()
def _create_and_process(
self, bulk_evaluator: BulkPushRuleEvaluator, content: Optional[JsonDict] = None
self,
bulk_evaluator: BulkPushRuleEvaluator,
content: Optional[JsonDict] = None,
type: str = "test",
) -> bool:
"""Returns true iff the `mentions` trigger an event push action."""
# Create a new message event which should cause a notification.
@@ -214,7 +217,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
self.event_creation_handler.create_event(
self.requester,
{
"type": "test",
"type": type,
"room_id": self.room_id,
"content": content or {},
"sender": f"@bob:{self.hs.hostname}",
@@ -446,3 +449,73 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
},
)
)
@override_config({"experimental_features": {"msc4306_enabled": True}})
def test_thread_subscriptions(self) -> None:
bulk_evaluator = BulkPushRuleEvaluator(self.hs)
(thread_root_id,) = self.helper.send_messages(self.room_id, 1, tok=self.token)
self.assertFalse(
self._create_and_process(
bulk_evaluator,
{
"msgtype": "m.text",
"body": "test message before subscription",
"m.relates_to": {
"rel_type": RelationTypes.THREAD,
"event_id": thread_root_id,
},
},
type=EventTypes.Message,
)
)
self.get_success(
self.hs.get_datastores().main.subscribe_user_to_thread(
self.alice,
self.room_id,
thread_root_id,
automatic_event_orderings=None,
)
)
self.assertTrue(
self._create_and_process(
bulk_evaluator,
{
"msgtype": "m.text",
"body": "test message after subscription",
"m.relates_to": {
"rel_type": RelationTypes.THREAD,
"event_id": thread_root_id,
},
},
type="m.room.message",
)
)
def test_with_disabled_thread_subscriptions(self) -> None:
"""
Test what happens with threaded events when MSC4306 is disabled.
FUTURE: If MSC4306 becomes enabled-by-default/accepted, this test is to be removed.
"""
bulk_evaluator = BulkPushRuleEvaluator(self.hs)
(thread_root_id,) = self.helper.send_messages(self.room_id, 1, tok=self.token)
# When MSC4306 is not enabled, a threaded message generates a notification
# by default.
self.assertTrue(
self._create_and_process(
bulk_evaluator,
{
"msgtype": "m.text",
"body": "test message before subscription",
"m.relates_to": {
"rel_type": RelationTypes.THREAD,
"event_id": thread_root_id,
},
},
type="m.room.message",
)
)

View File

@@ -150,6 +150,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
*,
related_events: Optional[JsonDict] = None,
msc4210: bool = False,
msc4306: bool = False,
) -> PushRuleEvaluator:
event = FrozenEvent(
{
@@ -176,6 +177,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
room_version_feature_flags=event.room_version.msc3931_push_features,
msc3931_enabled=True,
msc4210_enabled=msc4210,
msc4306_enabled=msc4306,
)
def test_display_name(self) -> None:
@@ -806,6 +808,112 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
)
)
def test_thread_subscription_subscribed(self) -> None:
"""
Test MSC4306 thread subscription push rules against an event in a subscribed thread.
"""
evaluator = self._get_evaluator(
{
"msgtype": "m.text",
"body": "Squawk",
"m.relates_to": {
"event_id": "$threadroot",
"rel_type": "m.thread",
},
},
msc4306=True,
)
self.assertTrue(
evaluator.matches(
{
"kind": "io.element.msc4306.thread_subscription",
"subscribed": True,
},
None,
None,
msc4306_thread_subscription_state=True,
)
)
self.assertFalse(
evaluator.matches(
{
"kind": "io.element.msc4306.thread_subscription",
"subscribed": False,
},
None,
None,
msc4306_thread_subscription_state=True,
)
)
def test_thread_subscription_unsubscribed(self) -> None:
"""
Test MSC4306 thread subscription push rules against an event in an unsubscribed thread.
"""
evaluator = self._get_evaluator(
{
"msgtype": "m.text",
"body": "Squawk",
"m.relates_to": {
"event_id": "$threadroot",
"rel_type": "m.thread",
},
},
msc4306=True,
)
self.assertFalse(
evaluator.matches(
{
"kind": "io.element.msc4306.thread_subscription",
"subscribed": True,
},
None,
None,
msc4306_thread_subscription_state=False,
)
)
self.assertTrue(
evaluator.matches(
{
"kind": "io.element.msc4306.thread_subscription",
"subscribed": False,
},
None,
None,
msc4306_thread_subscription_state=False,
)
)
def test_thread_subscription_unthreaded(self) -> None:
"""
Test MSC4306 thread subscription push rules against an unthreaded event.
"""
evaluator = self._get_evaluator(
{"msgtype": "m.text", "body": "Squawk"}, msc4306=True
)
self.assertFalse(
evaluator.matches(
{
"kind": "io.element.msc4306.thread_subscription",
"subscribed": True,
},
None,
None,
msc4306_thread_subscription_state=None,
)
)
self.assertFalse(
evaluator.matches(
{
"kind": "io.element.msc4306.thread_subscription",
"subscribed": False,
},
None,
None,
msc4306_thread_subscription_state=None,
)
)
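The three evaluator tests above pin down the matching semantics for the io.element.msc4306.thread_subscription condition: it matches only when a subscription state is known (not None) and equals the condition's `subscribed` flag. A standalone Python restatement of that predicate (the real implementation lives in the Rust evaluator; this is just the truth table the tests assert):

from typing import Optional

def matches_thread_subscription(
    condition_subscribed: bool,
    thread_subscription_state: Optional[bool],
) -> bool:
    # None means the event is unthreaded (or the state is unknown), so
    # neither the subscribed=True nor the subscribed=False variant matches.
    if thread_subscription_state is None:
        return False
    return thread_subscription_state == condition_subscribed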
class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase):
"""Tests for the bulk push rule evaluator"""

View File

@@ -62,7 +62,7 @@ class ThreadSubscriptionsStreamTestCase(BaseStreamTestCase):
"@test_user:example.org",
room_id,
thread_root_id,
automatic=True,
automatic_event_orderings=None,
)
)
updates.append(thread_root_id)
@@ -75,7 +75,7 @@ class ThreadSubscriptionsStreamTestCase(BaseStreamTestCase):
"@test_user:example.org",
other_room_id,
other_thread_root_id,
automatic=False,
automatic_event_orderings=None,
)
)
@@ -124,7 +124,7 @@ class ThreadSubscriptionsStreamTestCase(BaseStreamTestCase):
for user_id in users:
self.get_success(
store.subscribe_user_to_thread(
user_id, room_id, thread_root_id, automatic=True
user_id, room_id, thread_root_id, automatic_event_orderings=None
)
)

View File

@@ -15,6 +15,7 @@ from http import HTTPStatus
from twisted.internet.testing import MemoryReactor
from synapse.api.errors import Codes
from synapse.rest import admin
from synapse.rest.client import login, profile, room, thread_subscriptions
from synapse.server import HomeServer
@@ -49,15 +50,16 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
# Create a room and send a message to use as a thread root
self.room_id = self.helper.create_room_as(self.user_id, tok=self.token)
self.helper.join(self.room_id, self.other_user_id, tok=self.other_token)
response = self.helper.send(self.room_id, body="Root message", tok=self.token)
self.root_event_id = response["event_id"]
(self.root_event_id,) = self.helper.send_messages(
self.room_id, 1, tok=self.token
)
# Send a message in the thread
self.helper.send_event(
room_id=self.room_id,
type="m.room.message",
content={
"body": "Thread message",
self.threaded_events = self.helper.send_messages(
self.room_id,
2,
content_fn=lambda idx: {
"body": f"Thread message {idx}",
"msgtype": "m.text",
"m.relates_to": {
"rel_type": "m.thread",
@@ -106,9 +108,7 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"PUT",
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
{
"automatic": False,
},
{},
access_token=self.token,
)
self.assertEqual(channel.code, HTTPStatus.OK)
@@ -127,7 +127,7 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"PUT",
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
{"automatic": True},
{"automatic": self.threaded_events[0]},
access_token=self.token,
)
self.assertEqual(channel.code, HTTPStatus.OK)
@@ -148,11 +148,11 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
"PUT",
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
{
"automatic": True,
"automatic": self.threaded_events[0],
},
access_token=self.token,
)
self.assertEqual(channel.code, HTTPStatus.OK)
self.assertEqual(channel.code, HTTPStatus.OK, channel.text_body)
# Assert the subscription was saved
channel = self.make_request(
@@ -167,7 +167,7 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"PUT",
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
{"automatic": False},
{},
access_token=self.token,
)
self.assertEqual(channel.code, HTTPStatus.OK)
@@ -187,7 +187,7 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
"PUT",
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
{
"automatic": True,
"automatic": self.threaded_events[0],
},
access_token=self.token,
)
@@ -202,7 +202,6 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, HTTPStatus.OK)
self.assertEqual(channel.json_body, {"automatic": True})
# Now also register a manual subscription
channel = self.make_request(
"DELETE",
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
@@ -210,7 +209,6 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, HTTPStatus.OK)
# Assert the manual subscription was not overridden
channel = self.make_request(
"GET",
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
@@ -224,7 +222,7 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"PUT",
f"{PREFIX}/{self.room_id}/thread/$nonexistent:example.org/subscription",
{"automatic": True},
{},
access_token=self.token,
)
self.assertEqual(channel.code, HTTPStatus.NOT_FOUND)
@@ -238,7 +236,7 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"PUT",
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
{"automatic": True},
{},
access_token=no_access_token,
)
self.assertEqual(channel.code, HTTPStatus.NOT_FOUND)
@@ -249,8 +247,105 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"PUT",
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
# non-boolean `automatic`
{"automatic": "true"},
# non-Event ID `automatic`
{"automatic": True},
access_token=self.token,
)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST)
channel = self.make_request(
"PUT",
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
# malformed event ID `automatic`
{"automatic": "$malformedEventId"},
access_token=self.token,
)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST)
def test_auto_subscribe_cause_event_not_in_thread(self) -> None:
"""
Test making an automatic subscription, where the cause event is not
actually in the thread.
This is an error.
"""
(unrelated_event_id,) = self.helper.send_messages(
self.room_id, 1, tok=self.token
)
channel = self.make_request(
"PUT",
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
{"automatic": unrelated_event_id},
access_token=self.token,
)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.text_body)
self.assertEqual(channel.json_body["errcode"], Codes.MSC4306_NOT_IN_THREAD)
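The 400 response implies the handler verifies that the `automatic` cause event genuinely carries a thread relation to the requested root. A plausible sketch of that check (the helper name and error message are assumptions; Codes.MSC4306_NOT_IN_THREAD is taken from this diff):

from synapse.api.constants import RelationTypes
from synapse.api.errors import Codes, SynapseError

def check_cause_event_in_thread(cause_event, thread_root_id: str) -> None:
    # Sketch only: reject cause events whose m.relates_to does not point
    # at the given thread root with rel_type m.thread.
    relation = cause_event.content.get("m.relates_to") or {}
    if (
        relation.get("rel_type") != RelationTypes.THREAD
        or relation.get("event_id") != thread_root_id
    ):
        raise SynapseError(
            400,
            "`automatic` event is not in the thread",
            errcode=Codes.MSC4306_NOT_IN_THREAD,
        )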
def test_auto_resubscription_conflict(self) -> None:
"""
Test that an automatic subscription that conflicts with a logically
later unsubscription is rejected with HTTP 409 Conflict.
"""
# Reuse the test that subscribes and unsubscribes
self.test_unsubscribe()
# Now no matter which event we present as the cause of an automatic subscription,
# the automatic subscription is skipped.
# This is because the unsubscription happened after all of the events.
for event in self.threaded_events:
channel = self.make_request(
"PUT",
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
{
"automatic": event,
},
access_token=self.token,
)
self.assertEqual(channel.code, HTTPStatus.CONFLICT, channel.text_body)
self.assertEqual(
channel.json_body["errcode"],
Codes.MSC4306_CONFLICTING_UNSUBSCRIPTION,
channel.text_body,
)
# Check the subscription was not made
channel = self.make_request(
"GET",
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
access_token=self.token,
)
self.assertEqual(channel.code, HTTPStatus.NOT_FOUND)
# But if a new event is sent after the unsubscription took place,
# that one can be used for an automatic subscription
(later_event_id,) = self.helper.send_messages(
self.room_id,
1,
content_fn=lambda _: {
"body": "Thread message after unsubscription",
"msgtype": "m.text",
"m.relates_to": {
"rel_type": "m.thread",
"event_id": self.root_event_id,
},
},
tok=self.token,
)
channel = self.make_request(
"PUT",
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
{
"automatic": later_event_id,
},
access_token=self.token,
)
self.assertEqual(channel.code, HTTPStatus.OK, channel.text_body)
# Check the subscription was made
channel = self.make_request(
"GET",
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
access_token=self.token,
)
self.assertEqual(channel.code, HTTPStatus.OK)
self.assertEqual(channel.json_body, {"automatic": True})
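The OK/409 split in this test suggests how the REST layer presumably consumes the store result: subscribe_user_to_thread returns AutomaticSubscriptionConflicted when the cause event is ordered before a recorded unsubscription (both names appear elsewhere in this diff), and the handler maps that onto Codes.MSC4306_CONFLICTING_UNSUBSCRIPTION. A hedged sketch (handler shape and message text are assumptions):

from synapse.api.errors import Codes, SynapseError
from synapse.storage.databases.main.thread_subscriptions import (
    AutomaticSubscriptionConflicted,
)

async def put_thread_subscription(
    store, user_id, room_id, thread_root_id, orderings
):
    # `orderings` is None for a manual subscription, or the cause event's
    # EventOrderings for an automatic one (shape taken from this diff).
    result = await store.subscribe_user_to_thread(
        user_id, room_id, thread_root_id, automatic_event_orderings=orderings
    )
    if isinstance(result, AutomaticSubscriptionConflicted):
        raise SynapseError(
            409,
            "Automatic subscription conflicts with a later unsubscription",
            errcode=Codes.MSC4306_CONFLICTING_UNSUBSCRIPTION,
        )
    return {}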

View File

@@ -29,12 +29,14 @@ from http import HTTPStatus
from typing import (
Any,
AnyStr,
Callable,
Dict,
Iterable,
Literal,
Mapping,
MutableMapping,
Optional,
Sequence,
Tuple,
overload,
)
@@ -45,7 +47,7 @@ import attr
from twisted.internet.testing import MemoryReactorClock
from twisted.web.server import Site
from synapse.api.constants import Membership, ReceiptTypes
from synapse.api.constants import EventTypes, Membership, ReceiptTypes
from synapse.api.errors import Codes
from synapse.server import HomeServer
from synapse.types import JsonDict
@@ -394,6 +396,32 @@ class RestHelper:
custom_headers=custom_headers,
)
def send_messages(
self,
room_id: str,
num_events: int,
content_fn: Callable[[int], JsonDict] = lambda idx: {
"msgtype": "m.text",
"body": f"Test event {idx}",
},
tok: Optional[str] = None,
) -> Sequence[str]:
"""
Helper to send `num_events` sequential message events and return their event IDs in order.
"""
event_ids = []
for event_index in range(num_events):
response = self.send_event(
room_id,
EventTypes.Message,
content_fn(event_index),
tok=tok,
)
event_ids.append(response["event_id"])
return event_ids
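# Illustrative usage of the new helper (hypothetical snippet, not part of
# this diff): send three threaded replies and collect their event IDs in
# the order they were sent.
event_ids = self.helper.send_messages(
    self.room_id,
    3,
    content_fn=lambda idx: {
        "msgtype": "m.text",
        "body": f"Reply {idx}",
        "m.relates_to": {"rel_type": "m.thread", "event_id": thread_root_id},
    },
    tok=self.token,
)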
def send_event(
self,
room_id: str,

View File

@@ -12,13 +12,18 @@
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
from typing import Optional
from typing import Optional, Union
from twisted.internet.testing import MemoryReactor
from synapse.server import HomeServer
from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main.thread_subscriptions import (
AutomaticSubscriptionConflicted,
ThreadSubscriptionsWorkerStore,
)
from synapse.storage.engines.sqlite import Sqlite3Engine
from synapse.types import EventOrderings
from synapse.util import Clock
from tests import unittest
@@ -97,10 +102,10 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
self,
thread_root_id: str,
*,
automatic: bool,
automatic_event_orderings: Optional[EventOrderings],
room_id: Optional[str] = None,
user_id: Optional[str] = None,
) -> Optional[int]:
) -> Optional[Union[int, AutomaticSubscriptionConflicted]]:
if user_id is None:
user_id = self.user_id
@@ -112,7 +117,7 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
user_id,
room_id,
thread_root_id,
automatic=automatic,
automatic_event_orderings=automatic_event_orderings,
)
)
@@ -149,7 +154,7 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
# Subscribe
self._subscribe(
self.thread_root_id,
automatic=True,
automatic_event_orderings=EventOrderings(1, 1),
)
# Assert subscription went through
@@ -164,7 +169,7 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
# Now make it a manual subscription
self._subscribe(
self.thread_root_id,
automatic=False,
automatic_event_orderings=None,
)
# Assert the manual subscription overrode the automatic one
@@ -178,8 +183,10 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
def test_purge_thread_subscriptions_for_user(self) -> None:
"""Test purging all thread subscription settings for a user."""
# Set subscription settings for multiple threads
self._subscribe(self.thread_root_id, automatic=True)
self._subscribe(self.other_thread_root_id, automatic=False)
self._subscribe(
self.thread_root_id, automatic_event_orderings=EventOrderings(1, 1)
)
self._subscribe(self.other_thread_root_id, automatic_event_orderings=None)
subscriptions = self.get_success(
self.store.get_updated_thread_subscriptions_for_user(
@@ -217,20 +224,32 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
def test_get_updated_thread_subscriptions(self) -> None:
"""Test getting updated thread subscriptions since a stream ID."""
stream_id1 = self._subscribe(self.thread_root_id, automatic=False)
stream_id2 = self._subscribe(self.other_thread_root_id, automatic=True)
assert stream_id1 is not None
assert stream_id2 is not None
stream_id1 = self._subscribe(
self.thread_root_id, automatic_event_orderings=EventOrderings(1, 1)
)
stream_id2 = self._subscribe(
self.other_thread_root_id, automatic_event_orderings=EventOrderings(2, 2)
)
assert stream_id1 is not None and not isinstance(
stream_id1, AutomaticSubscriptionConflicted
)
assert stream_id2 is not None and not isinstance(
stream_id2, AutomaticSubscriptionConflicted
)
# Get updates since initial ID (should include both changes)
updates = self.get_success(
self.store.get_updated_thread_subscriptions(0, stream_id2, 10)
self.store.get_updated_thread_subscriptions(
from_id=0, to_id=stream_id2, limit=10
)
)
self.assertEqual(len(updates), 2)
# Get updates since first change (should include only the second change)
updates = self.get_success(
self.store.get_updated_thread_subscriptions(stream_id1, stream_id2, 10)
self.store.get_updated_thread_subscriptions(
from_id=stream_id1, to_id=stream_id2, limit=10
)
)
self.assertEqual(
updates,
@@ -242,21 +261,27 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
other_user_id = "@other_user:test"
# Set thread subscription for main user
stream_id1 = self._subscribe(self.thread_root_id, automatic=True)
assert stream_id1 is not None
stream_id1 = self._subscribe(
self.thread_root_id, automatic_event_orderings=EventOrderings(1, 1)
)
assert stream_id1 is not None and not isinstance(
stream_id1, AutomaticSubscriptionConflicted
)
# Set thread subscription for other user
stream_id2 = self._subscribe(
self.other_thread_root_id,
automatic=True,
automatic_event_orderings=EventOrderings(1, 1),
user_id=other_user_id,
)
assert stream_id2 is not None
assert stream_id2 is not None and not isinstance(
stream_id2, AutomaticSubscriptionConflicted
)
# Get updates for main user
updates = self.get_success(
self.store.get_updated_thread_subscriptions_for_user(
self.user_id, 0, stream_id2, 10
self.user_id, from_id=0, to_id=stream_id2, limit=10
)
)
self.assertEqual(updates, [(stream_id1, self.room_id, self.thread_root_id)])
@@ -264,9 +289,80 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
# Get updates for other user
updates = self.get_success(
self.store.get_updated_thread_subscriptions_for_user(
other_user_id, 0, max(stream_id1, stream_id2), 10
other_user_id, from_id=0, to_id=max(stream_id1, stream_id2), limit=10
)
)
self.assertEqual(
updates, [(stream_id2, self.room_id, self.other_thread_root_id)]
)
def test_should_skip_autosubscription_after_unsubscription(self) -> None:
"""
Tests the comparison logic for whether an autosubscription should be skipped
due to a chronologically earlier but logically later unsubscription.
"""
func = ThreadSubscriptionsWorkerStore._should_skip_autosubscription_after_unsubscription
# Order of arguments:
# automatic cause event: stream order, then topological order
# unsubscribe maximums: stream order, then topological order
# both orderings agree that the unsub is after the cause event
self.assertTrue(
func(autosub=EventOrderings(1, 1), unsubscribed_at=EventOrderings(2, 2))
)
# topological ordering is inconsistent with stream ordering,
# in that case favour stream ordering because it's what /sync uses
self.assertTrue(
func(autosub=EventOrderings(1, 2), unsubscribed_at=EventOrderings(2, 1))
)
# the automatic subscription is caused by a backfilled event here,
# so unfortunately we must fall back to topological ordering
self.assertTrue(
func(autosub=EventOrderings(-50, 2), unsubscribed_at=EventOrderings(2, 3))
)
self.assertFalse(
func(autosub=EventOrderings(-50, 2), unsubscribed_at=EventOrderings(2, 1))
)
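These four assertions fully constrain the comparison: prefer stream ordering when the cause event has a normal (positive) stream ordering, and fall back to topological ordering for backfilled events. A standalone sketch consistent with the cases above (the real method lives on ThreadSubscriptionsWorkerStore; the EventOrderings field names and the behaviour at exact equality are assumptions):

from synapse.types import EventOrderings

def should_skip(autosub: EventOrderings, unsubscribed_at: EventOrderings) -> bool:
    # Backfilled events get negative stream orderings, which cannot be
    # compared meaningfully with live stream orderings; fall back to
    # topological ordering in that case.
    if autosub.stream < 0:
        return autosub.topological <= unsubscribed_at.topological
    # Otherwise favour stream ordering, since that is what /sync uses.
    return autosub.stream <= unsubscribed_at.stream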
def test_get_subscribers_to_thread(self) -> None:
"""
Test getting all subscribers to a thread at once.
To check that cache invalidation is correct, we do multiple
rounds of subscription changes and assertions.
"""
other_user_id = "@other_user:test"
subscribers = self.get_success(
self.store.get_subscribers_to_thread(self.room_id, self.thread_root_id)
)
self.assertEqual(subscribers, frozenset())
self._subscribe(
self.thread_root_id, automatic_event_orderings=None, user_id=self.user_id
)
subscribers = self.get_success(
self.store.get_subscribers_to_thread(self.room_id, self.thread_root_id)
)
self.assertEqual(subscribers, frozenset((self.user_id,)))
self._subscribe(
self.thread_root_id, automatic_event_orderings=None, user_id=other_user_id
)
subscribers = self.get_success(
self.store.get_subscribers_to_thread(self.room_id, self.thread_root_id)
)
self.assertEqual(subscribers, frozenset((self.user_id, other_user_id)))
self._unsubscribe(self.thread_root_id, user_id=self.user_id)
subscribers = self.get_success(
self.store.get_subscribers_to_thread(self.room_id, self.thread_root_id)
)
self.assertEqual(subscribers, frozenset((other_user_id,)))