Merge branch 'develop' into release-v1.136
This commit is contained in:
1
changelog.d/18756.misc
Normal file
1
changelog.d/18756.misc
Normal file
@@ -0,0 +1 @@
|
||||
Update implementation of [MSC4306: Thread Subscriptions](https://github.com/matrix-org/matrix-doc/issues/4306) to include automatic subscription conflict prevention as introduced in later drafts.
|
||||
1
changelog.d/18762.feature
Normal file
1
changelog.d/18762.feature
Normal file
@@ -0,0 +1 @@
|
||||
Implement the push rules for experimental [MSC4306: Thread Subscriptions](https://github.com/matrix-org/matrix-doc/issues/4306).
|
||||
@@ -61,6 +61,7 @@ fn bench_match_exact(b: &mut Bencher) {
|
||||
vec![],
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
@@ -71,10 +72,10 @@ fn bench_match_exact(b: &mut Bencher) {
|
||||
},
|
||||
));
|
||||
|
||||
let matched = eval.match_condition(&condition, None, None).unwrap();
|
||||
let matched = eval.match_condition(&condition, None, None, None).unwrap();
|
||||
assert!(matched, "Didn't match");
|
||||
|
||||
b.iter(|| eval.match_condition(&condition, None, None).unwrap());
|
||||
b.iter(|| eval.match_condition(&condition, None, None, None).unwrap());
|
||||
}
|
||||
|
||||
#[bench]
|
||||
@@ -107,6 +108,7 @@ fn bench_match_word(b: &mut Bencher) {
|
||||
vec![],
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
@@ -117,10 +119,10 @@ fn bench_match_word(b: &mut Bencher) {
|
||||
},
|
||||
));
|
||||
|
||||
let matched = eval.match_condition(&condition, None, None).unwrap();
|
||||
let matched = eval.match_condition(&condition, None, None, None).unwrap();
|
||||
assert!(matched, "Didn't match");
|
||||
|
||||
b.iter(|| eval.match_condition(&condition, None, None).unwrap());
|
||||
b.iter(|| eval.match_condition(&condition, None, None, None).unwrap());
|
||||
}
|
||||
|
||||
#[bench]
|
||||
@@ -153,6 +155,7 @@ fn bench_match_word_miss(b: &mut Bencher) {
|
||||
vec![],
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
@@ -163,10 +166,10 @@ fn bench_match_word_miss(b: &mut Bencher) {
|
||||
},
|
||||
));
|
||||
|
||||
let matched = eval.match_condition(&condition, None, None).unwrap();
|
||||
let matched = eval.match_condition(&condition, None, None, None).unwrap();
|
||||
assert!(!matched, "Didn't match");
|
||||
|
||||
b.iter(|| eval.match_condition(&condition, None, None).unwrap());
|
||||
b.iter(|| eval.match_condition(&condition, None, None, None).unwrap());
|
||||
}
|
||||
|
||||
#[bench]
|
||||
@@ -199,6 +202,7 @@ fn bench_eval_message(b: &mut Bencher) {
|
||||
vec![],
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
@@ -210,7 +214,8 @@ fn bench_eval_message(b: &mut Bencher) {
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
);
|
||||
|
||||
b.iter(|| eval.run(&rules, Some("bob"), Some("person")));
|
||||
b.iter(|| eval.run(&rules, Some("bob"), Some("person"), None));
|
||||
}
|
||||
|
||||
@@ -290,6 +290,26 @@ pub const BASE_APPEND_CONTENT_RULES: &[PushRule] = &[PushRule {
|
||||
}];
|
||||
|
||||
pub const BASE_APPEND_UNDERRIDE_RULES: &[PushRule] = &[
|
||||
PushRule {
|
||||
rule_id: Cow::Borrowed("global/content/.io.element.msc4306.rule.unsubscribed_thread"),
|
||||
priority_class: 1,
|
||||
conditions: Cow::Borrowed(&[Condition::Known(
|
||||
KnownCondition::Msc4306ThreadSubscription { subscribed: false },
|
||||
)]),
|
||||
actions: Cow::Borrowed(&[]),
|
||||
default: true,
|
||||
default_enabled: true,
|
||||
},
|
||||
PushRule {
|
||||
rule_id: Cow::Borrowed("global/content/.io.element.msc4306.rule.subscribed_thread"),
|
||||
priority_class: 1,
|
||||
conditions: Cow::Borrowed(&[Condition::Known(
|
||||
KnownCondition::Msc4306ThreadSubscription { subscribed: true },
|
||||
)]),
|
||||
actions: Cow::Borrowed(&[Action::Notify, SOUND_ACTION]),
|
||||
default: true,
|
||||
default_enabled: true,
|
||||
},
|
||||
PushRule {
|
||||
rule_id: Cow::Borrowed("global/underride/.m.rule.call"),
|
||||
priority_class: 1,
|
||||
|
||||
@@ -106,8 +106,11 @@ pub struct PushRuleEvaluator {
|
||||
/// flag as MSC1767 (extensible events core).
|
||||
msc3931_enabled: bool,
|
||||
|
||||
// If MSC4210 (remove legacy mentions) is enabled.
|
||||
/// If MSC4210 (remove legacy mentions) is enabled.
|
||||
msc4210_enabled: bool,
|
||||
|
||||
/// If MSC4306 (thread subscriptions) is enabled.
|
||||
msc4306_enabled: bool,
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
@@ -126,6 +129,7 @@ impl PushRuleEvaluator {
|
||||
room_version_feature_flags,
|
||||
msc3931_enabled,
|
||||
msc4210_enabled,
|
||||
msc4306_enabled,
|
||||
))]
|
||||
pub fn py_new(
|
||||
flattened_keys: BTreeMap<String, JsonValue>,
|
||||
@@ -138,6 +142,7 @@ impl PushRuleEvaluator {
|
||||
room_version_feature_flags: Vec<String>,
|
||||
msc3931_enabled: bool,
|
||||
msc4210_enabled: bool,
|
||||
msc4306_enabled: bool,
|
||||
) -> Result<Self, Error> {
|
||||
let body = match flattened_keys.get("content.body") {
|
||||
Some(JsonValue::Value(SimpleJsonValue::Str(s))) => s.clone().into_owned(),
|
||||
@@ -156,6 +161,7 @@ impl PushRuleEvaluator {
|
||||
room_version_feature_flags,
|
||||
msc3931_enabled,
|
||||
msc4210_enabled,
|
||||
msc4306_enabled,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -167,12 +173,19 @@ impl PushRuleEvaluator {
|
||||
///
|
||||
/// Returns the set of actions, if any, that match (filtering out any
|
||||
/// `dont_notify` and `coalesce` actions).
|
||||
#[pyo3(signature = (push_rules, user_id=None, display_name=None))]
|
||||
///
|
||||
/// msc4306_thread_subscription_state: (Only populated if MSC4306 is enabled)
|
||||
/// The thread subscription state corresponding to the thread containing this event.
|
||||
/// - `None` if the event is not in a thread, or if MSC4306 is disabled.
|
||||
/// - `Some(true)` if the event is in a thread and the user has a subscription for that thread
|
||||
/// - `Some(false)` if the event is in a thread and the user does NOT have a subscription for that thread
|
||||
#[pyo3(signature = (push_rules, user_id=None, display_name=None, msc4306_thread_subscription_state=None))]
|
||||
pub fn run(
|
||||
&self,
|
||||
push_rules: &FilteredPushRules,
|
||||
user_id: Option<&str>,
|
||||
display_name: Option<&str>,
|
||||
msc4306_thread_subscription_state: Option<bool>,
|
||||
) -> Vec<Action> {
|
||||
'outer: for (push_rule, enabled) in push_rules.iter() {
|
||||
if !enabled {
|
||||
@@ -204,7 +217,12 @@ impl PushRuleEvaluator {
|
||||
Condition::Known(KnownCondition::RoomVersionSupports { feature: _ }),
|
||||
);
|
||||
|
||||
match self.match_condition(condition, user_id, display_name) {
|
||||
match self.match_condition(
|
||||
condition,
|
||||
user_id,
|
||||
display_name,
|
||||
msc4306_thread_subscription_state,
|
||||
) {
|
||||
Ok(true) => {}
|
||||
Ok(false) => continue 'outer,
|
||||
Err(err) => {
|
||||
@@ -237,14 +255,20 @@ impl PushRuleEvaluator {
|
||||
}
|
||||
|
||||
/// Check if the given condition matches.
|
||||
#[pyo3(signature = (condition, user_id=None, display_name=None))]
|
||||
#[pyo3(signature = (condition, user_id=None, display_name=None, msc4306_thread_subscription_state=None))]
|
||||
fn matches(
|
||||
&self,
|
||||
condition: Condition,
|
||||
user_id: Option<&str>,
|
||||
display_name: Option<&str>,
|
||||
msc4306_thread_subscription_state: Option<bool>,
|
||||
) -> bool {
|
||||
match self.match_condition(&condition, user_id, display_name) {
|
||||
match self.match_condition(
|
||||
&condition,
|
||||
user_id,
|
||||
display_name,
|
||||
msc4306_thread_subscription_state,
|
||||
) {
|
||||
Ok(true) => true,
|
||||
Ok(false) => false,
|
||||
Err(err) => {
|
||||
@@ -262,6 +286,7 @@ impl PushRuleEvaluator {
|
||||
condition: &Condition,
|
||||
user_id: Option<&str>,
|
||||
display_name: Option<&str>,
|
||||
msc4306_thread_subscription_state: Option<bool>,
|
||||
) -> Result<bool, Error> {
|
||||
let known_condition = match condition {
|
||||
Condition::Known(known) => known,
|
||||
@@ -393,6 +418,13 @@ impl PushRuleEvaluator {
|
||||
&& self.room_version_feature_flags.contains(&flag)
|
||||
}
|
||||
}
|
||||
KnownCondition::Msc4306ThreadSubscription { subscribed } => {
|
||||
if !self.msc4306_enabled {
|
||||
false
|
||||
} else {
|
||||
msc4306_thread_subscription_state == Some(*subscribed)
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Ok(result)
|
||||
@@ -536,10 +568,11 @@ fn push_rule_evaluator() {
|
||||
vec![],
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let result = evaluator.run(&FilteredPushRules::default(), None, Some("bob"));
|
||||
let result = evaluator.run(&FilteredPushRules::default(), None, Some("bob"), None);
|
||||
assert_eq!(result.len(), 3);
|
||||
}
|
||||
|
||||
@@ -566,6 +599,7 @@ fn test_requires_room_version_supports_condition() {
|
||||
flags,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
@@ -575,6 +609,7 @@ fn test_requires_room_version_supports_condition() {
|
||||
&FilteredPushRules::default(),
|
||||
Some("@bob:example.org"),
|
||||
None,
|
||||
None,
|
||||
);
|
||||
assert_eq!(result.len(), 3);
|
||||
|
||||
@@ -593,7 +628,17 @@ fn test_requires_room_version_supports_condition() {
|
||||
};
|
||||
let rules = PushRules::new(vec![custom_rule]);
|
||||
result = evaluator.run(
|
||||
&FilteredPushRules::py_new(rules, BTreeMap::new(), true, false, true, false, false),
|
||||
&FilteredPushRules::py_new(
|
||||
rules,
|
||||
BTreeMap::new(),
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
);
|
||||
|
||||
@@ -369,6 +369,10 @@ pub enum KnownCondition {
|
||||
RoomVersionSupports {
|
||||
feature: Cow<'static, str>,
|
||||
},
|
||||
#[serde(rename = "io.element.msc4306.thread_subscription")]
|
||||
Msc4306ThreadSubscription {
|
||||
subscribed: bool,
|
||||
},
|
||||
}
|
||||
|
||||
impl<'source> IntoPyObject<'source> for Condition {
|
||||
@@ -547,11 +551,13 @@ pub struct FilteredPushRules {
|
||||
msc3664_enabled: bool,
|
||||
msc4028_push_encrypted_events: bool,
|
||||
msc4210_enabled: bool,
|
||||
msc4306_enabled: bool,
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl FilteredPushRules {
|
||||
#[new]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn py_new(
|
||||
push_rules: PushRules,
|
||||
enabled_map: BTreeMap<String, bool>,
|
||||
@@ -560,6 +566,7 @@ impl FilteredPushRules {
|
||||
msc3664_enabled: bool,
|
||||
msc4028_push_encrypted_events: bool,
|
||||
msc4210_enabled: bool,
|
||||
msc4306_enabled: bool,
|
||||
) -> Self {
|
||||
Self {
|
||||
push_rules,
|
||||
@@ -569,6 +576,7 @@ impl FilteredPushRules {
|
||||
msc3664_enabled,
|
||||
msc4028_push_encrypted_events,
|
||||
msc4210_enabled,
|
||||
msc4306_enabled,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -619,6 +627,10 @@ impl FilteredPushRules {
|
||||
return false;
|
||||
}
|
||||
|
||||
if !self.msc4306_enabled && rule.rule_id.contains("/.io.element.msc4306.rule.") {
|
||||
return false;
|
||||
}
|
||||
|
||||
true
|
||||
})
|
||||
.map(|r| {
|
||||
|
||||
@@ -140,6 +140,12 @@ class Codes(str, Enum):
|
||||
# Part of MSC4155
|
||||
INVITE_BLOCKED = "ORG.MATRIX.MSC4155.M_INVITE_BLOCKED"
|
||||
|
||||
# Part of MSC4306: Thread Subscriptions
|
||||
MSC4306_CONFLICTING_UNSUBSCRIPTION = (
|
||||
"IO.ELEMENT.MSC4306.M_CONFLICTING_UNSUBSCRIPTION"
|
||||
)
|
||||
MSC4306_NOT_IN_THREAD = "IO.ELEMENT.MSC4306.M_NOT_IN_THREAD"
|
||||
|
||||
|
||||
class CodeMessageException(RuntimeError):
|
||||
"""An exception with integer code, a message string attributes and optional headers.
|
||||
|
||||
@@ -37,7 +37,6 @@ Events are replicated via a separate events stream.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Dict,
|
||||
@@ -68,25 +67,6 @@ if TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class QueueNames(str, Enum):
|
||||
PRESENCE_MAP = "presence_map"
|
||||
KEYED_EDU = "keyed_edu"
|
||||
KEYED_EDU_CHANGED = "keyed_edu_changed"
|
||||
EDUS = "edus"
|
||||
POS_TIME = "pos_time"
|
||||
PRESENCE_DESTINATIONS = "presence_destinations"
|
||||
|
||||
|
||||
queue_name_to_gauge_map: Dict[QueueNames, LaterGauge] = {}
|
||||
|
||||
for queue_name in QueueNames:
|
||||
queue_name_to_gauge_map[queue_name] = LaterGauge(
|
||||
name=f"synapse_federation_send_queue_{queue_name.value}_size",
|
||||
desc="",
|
||||
labelnames=[SERVER_NAME_LABEL],
|
||||
)
|
||||
|
||||
|
||||
class FederationRemoteSendQueue(AbstractFederationSender):
|
||||
"""A drop in replacement for FederationSender"""
|
||||
|
||||
@@ -131,15 +111,23 @@ class FederationRemoteSendQueue(AbstractFederationSender):
|
||||
# we make a new function, so we need to make a new function so the inner
|
||||
# lambda binds to the queue rather than to the name of the queue which
|
||||
# changes. ARGH.
|
||||
def register(queue_name: QueueNames, queue: Sized) -> None:
|
||||
queue_name_to_gauge_map[queue_name].register_hook(
|
||||
lambda: {(self.server_name,): len(queue)}
|
||||
def register(name: str, queue: Sized) -> None:
|
||||
LaterGauge(
|
||||
name="synapse_federation_send_queue_%s_size" % (queue_name,),
|
||||
desc="",
|
||||
labelnames=[SERVER_NAME_LABEL],
|
||||
caller=lambda: {(self.server_name,): len(queue)},
|
||||
)
|
||||
|
||||
for queue_name in QueueNames:
|
||||
queue = getattr(self, queue_name.value)
|
||||
assert isinstance(queue, Sized)
|
||||
register(queue_name, queue=queue)
|
||||
for queue_name in [
|
||||
"presence_map",
|
||||
"keyed_edu",
|
||||
"keyed_edu_changed",
|
||||
"edus",
|
||||
"pos_time",
|
||||
"presence_destinations",
|
||||
]:
|
||||
register(queue_name, getattr(self, queue_name))
|
||||
|
||||
self.clock.looping_call(self._clear_queue, 30 * 1000)
|
||||
|
||||
|
||||
@@ -199,24 +199,6 @@ sent_pdus_destination_dist_total = Counter(
|
||||
labelnames=[SERVER_NAME_LABEL],
|
||||
)
|
||||
|
||||
transaction_queue_pending_destinations_gauge = LaterGauge(
|
||||
name="synapse_federation_transaction_queue_pending_destinations",
|
||||
desc="",
|
||||
labelnames=[SERVER_NAME_LABEL],
|
||||
)
|
||||
|
||||
transaction_queue_pending_pdus_gauge = LaterGauge(
|
||||
name="synapse_federation_transaction_queue_pending_pdus",
|
||||
desc="",
|
||||
labelnames=[SERVER_NAME_LABEL],
|
||||
)
|
||||
|
||||
transaction_queue_pending_edus_gauge = LaterGauge(
|
||||
name="synapse_federation_transaction_queue_pending_edus",
|
||||
desc="",
|
||||
labelnames=[SERVER_NAME_LABEL],
|
||||
)
|
||||
|
||||
# Time (in s) to wait before trying to wake up destinations that have
|
||||
# catch-up outstanding.
|
||||
# Please note that rate limiting still applies, so while the loop is
|
||||
@@ -416,28 +398,38 @@ class FederationSender(AbstractFederationSender):
|
||||
# map from destination to PerDestinationQueue
|
||||
self._per_destination_queues: Dict[str, PerDestinationQueue] = {}
|
||||
|
||||
transaction_queue_pending_destinations_gauge.register_hook(
|
||||
lambda: {
|
||||
LaterGauge(
|
||||
name="synapse_federation_transaction_queue_pending_destinations",
|
||||
desc="",
|
||||
labelnames=[SERVER_NAME_LABEL],
|
||||
caller=lambda: {
|
||||
(self.server_name,): sum(
|
||||
1
|
||||
for d in self._per_destination_queues.values()
|
||||
if d.transmission_loop_running
|
||||
)
|
||||
}
|
||||
},
|
||||
)
|
||||
transaction_queue_pending_pdus_gauge.register_hook(
|
||||
lambda: {
|
||||
|
||||
LaterGauge(
|
||||
name="synapse_federation_transaction_queue_pending_pdus",
|
||||
desc="",
|
||||
labelnames=[SERVER_NAME_LABEL],
|
||||
caller=lambda: {
|
||||
(self.server_name,): sum(
|
||||
d.pending_pdu_count() for d in self._per_destination_queues.values()
|
||||
)
|
||||
}
|
||||
},
|
||||
)
|
||||
transaction_queue_pending_edus_gauge.register_hook(
|
||||
lambda: {
|
||||
LaterGauge(
|
||||
name="synapse_federation_transaction_queue_pending_edus",
|
||||
desc="",
|
||||
labelnames=[SERVER_NAME_LABEL],
|
||||
caller=lambda: {
|
||||
(self.server_name,): sum(
|
||||
d.pending_edu_count() for d in self._per_destination_queues.values()
|
||||
)
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
self._is_processing = False
|
||||
|
||||
@@ -173,18 +173,6 @@ state_transition_counter = Counter(
|
||||
labelnames=["locality", "from", "to", SERVER_NAME_LABEL],
|
||||
)
|
||||
|
||||
presence_user_to_current_state_size_gauge = LaterGauge(
|
||||
name="synapse_handlers_presence_user_to_current_state_size",
|
||||
desc="",
|
||||
labelnames=[SERVER_NAME_LABEL],
|
||||
)
|
||||
|
||||
presence_wheel_timer_size_gauge = LaterGauge(
|
||||
name="synapse_handlers_presence_wheel_timer_size",
|
||||
desc="",
|
||||
labelnames=[SERVER_NAME_LABEL],
|
||||
)
|
||||
|
||||
# If a user was last active in the last LAST_ACTIVE_GRANULARITY, consider them
|
||||
# "currently_active"
|
||||
LAST_ACTIVE_GRANULARITY = 60 * 1000
|
||||
@@ -791,8 +779,11 @@ class PresenceHandler(BasePresenceHandler):
|
||||
EduTypes.PRESENCE, self.incoming_presence
|
||||
)
|
||||
|
||||
presence_user_to_current_state_size_gauge.register_hook(
|
||||
lambda: {(self.server_name,): len(self.user_to_current_state)}
|
||||
LaterGauge(
|
||||
name="synapse_handlers_presence_user_to_current_state_size",
|
||||
desc="",
|
||||
labelnames=[SERVER_NAME_LABEL],
|
||||
caller=lambda: {(self.server_name,): len(self.user_to_current_state)},
|
||||
)
|
||||
|
||||
# The per-device presence state, maps user to devices to per-device presence state.
|
||||
@@ -891,8 +882,11 @@ class PresenceHandler(BasePresenceHandler):
|
||||
60 * 1000,
|
||||
)
|
||||
|
||||
presence_wheel_timer_size_gauge.register_hook(
|
||||
lambda: {(self.server_name,): len(self.wheel_timer)}
|
||||
LaterGauge(
|
||||
name="synapse_handlers_presence_wheel_timer_size",
|
||||
desc="",
|
||||
labelnames=[SERVER_NAME_LABEL],
|
||||
caller=lambda: {(self.server_name,): len(self.wheel_timer)},
|
||||
)
|
||||
|
||||
# Used to handle sending of presence to newly joined users/servers
|
||||
|
||||
@@ -1,9 +1,15 @@
|
||||
import logging
|
||||
from http import HTTPStatus
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from synapse.api.errors import AuthError, NotFoundError
|
||||
from synapse.storage.databases.main.thread_subscriptions import ThreadSubscription
|
||||
from synapse.types import UserID
|
||||
from synapse.api.constants import RelationTypes
|
||||
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
|
||||
from synapse.events import relation_from_event
|
||||
from synapse.storage.databases.main.thread_subscriptions import (
|
||||
AutomaticSubscriptionConflicted,
|
||||
ThreadSubscription,
|
||||
)
|
||||
from synapse.types import EventOrderings, UserID
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
@@ -55,42 +61,79 @@ class ThreadSubscriptionsHandler:
|
||||
room_id: str,
|
||||
thread_root_event_id: str,
|
||||
*,
|
||||
automatic: bool,
|
||||
automatic_event_id: Optional[str],
|
||||
) -> Optional[int]:
|
||||
"""Sets or updates a user's subscription settings for a specific thread root.
|
||||
|
||||
Args:
|
||||
requester_user_id: The ID of the user whose settings are being updated.
|
||||
thread_root_event_id: The event ID of the thread root.
|
||||
automatic: whether the user was subscribed by an automatic decision by
|
||||
their client.
|
||||
automatic_event_id: if the user was subscribed by an automatic decision by
|
||||
their client, the event ID that caused this.
|
||||
|
||||
Returns:
|
||||
The stream ID for this update, if the update isn't no-opped.
|
||||
|
||||
Raises:
|
||||
NotFoundError if the user cannot access the thread root event, or it isn't
|
||||
known to this homeserver.
|
||||
known to this homeserver. Ditto for the automatic cause event if supplied.
|
||||
|
||||
SynapseError(400, M_NOT_IN_THREAD): if client supplied an automatic cause event
|
||||
but user cannot access the event.
|
||||
|
||||
SynapseError(409, M_SKIPPED): if client requested an automatic subscription
|
||||
but it was skipped because the cause event is logically later than an unsubscription.
|
||||
"""
|
||||
# First check that the user can access the thread root event
|
||||
# and that it exists
|
||||
try:
|
||||
event = await self.event_handler.get_event(
|
||||
thread_root_event = await self.event_handler.get_event(
|
||||
user_id, room_id, thread_root_event_id
|
||||
)
|
||||
if event is None:
|
||||
if thread_root_event is None:
|
||||
raise NotFoundError("No such thread root")
|
||||
except AuthError:
|
||||
logger.info("rejecting thread subscriptions change (thread not accessible)")
|
||||
raise NotFoundError("No such thread root")
|
||||
|
||||
return await self.store.subscribe_user_to_thread(
|
||||
if automatic_event_id:
|
||||
autosub_cause_event = await self.event_handler.get_event(
|
||||
user_id, room_id, automatic_event_id
|
||||
)
|
||||
if autosub_cause_event is None:
|
||||
raise NotFoundError("Automatic subscription event not found")
|
||||
relation = relation_from_event(autosub_cause_event)
|
||||
if (
|
||||
relation is None
|
||||
or relation.rel_type != RelationTypes.THREAD
|
||||
or relation.parent_id != thread_root_event_id
|
||||
):
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
"Automatic subscription must use an event in the thread",
|
||||
errcode=Codes.MSC4306_NOT_IN_THREAD,
|
||||
)
|
||||
|
||||
automatic_event_orderings = EventOrderings.from_event(autosub_cause_event)
|
||||
else:
|
||||
automatic_event_orderings = None
|
||||
|
||||
outcome = await self.store.subscribe_user_to_thread(
|
||||
user_id.to_string(),
|
||||
event.room_id,
|
||||
room_id,
|
||||
thread_root_event_id,
|
||||
automatic=automatic,
|
||||
automatic_event_orderings=automatic_event_orderings,
|
||||
)
|
||||
|
||||
if isinstance(outcome, AutomaticSubscriptionConflicted):
|
||||
raise SynapseError(
|
||||
HTTPStatus.CONFLICT,
|
||||
"Automatic subscription obsoleted by an unsubscription request.",
|
||||
errcode=Codes.MSC4306_CONFLICTING_UNSUBSCRIPTION,
|
||||
)
|
||||
|
||||
return outcome
|
||||
|
||||
async def unsubscribe_user_from_thread(
|
||||
self, user_id: UserID, room_id: str, thread_root_event_id: str
|
||||
) -> Optional[int]:
|
||||
|
||||
@@ -164,12 +164,12 @@ def _get_in_flight_counts() -> Mapping[Tuple[str, ...], int]:
|
||||
return counts
|
||||
|
||||
|
||||
in_flight_requests = LaterGauge(
|
||||
LaterGauge(
|
||||
name="synapse_http_server_in_flight_requests_count",
|
||||
desc="",
|
||||
labelnames=["method", "servlet", SERVER_NAME_LABEL],
|
||||
caller=_get_in_flight_counts,
|
||||
)
|
||||
in_flight_requests.register_hook(_get_in_flight_counts)
|
||||
|
||||
|
||||
class RequestMetrics:
|
||||
|
||||
@@ -31,7 +31,6 @@ from typing import (
|
||||
Dict,
|
||||
Generic,
|
||||
Iterable,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
@@ -74,6 +73,8 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
METRICS_PREFIX = "/_synapse/metrics"
|
||||
|
||||
all_gauges: Dict[str, Collector] = {}
|
||||
|
||||
HAVE_PROC_SELF_STAT = os.path.exists("/proc/self/stat")
|
||||
|
||||
SERVER_NAME_LABEL = "server_name"
|
||||
@@ -162,47 +163,42 @@ class LaterGauge(Collector):
|
||||
name: str
|
||||
desc: str
|
||||
labelnames: Optional[StrSequence] = attr.ib(hash=False)
|
||||
# List of callbacks: each callback should either return a value (if there are no
|
||||
# labels for this metric), or dict mapping from a label tuple to a value
|
||||
_hooks: List[
|
||||
Callable[
|
||||
[], Union[Mapping[Tuple[str, ...], Union[int, float]], Union[int, float]]
|
||||
]
|
||||
] = attr.ib(factory=list, hash=False)
|
||||
# callback: should either return a value (if there are no labels for this metric),
|
||||
# or dict mapping from a label tuple to a value
|
||||
caller: Callable[
|
||||
[], Union[Mapping[Tuple[str, ...], Union[int, float]], Union[int, float]]
|
||||
]
|
||||
|
||||
def collect(self) -> Iterable[Metric]:
|
||||
# The decision to add `SERVER_NAME_LABEL` is from the `LaterGauge` usage itself
|
||||
# (we don't enforce it here, one level up).
|
||||
g = GaugeMetricFamily(self.name, self.desc, labels=self.labelnames) # type: ignore[missing-server-name-label]
|
||||
|
||||
for hook in self._hooks:
|
||||
try:
|
||||
hook_result = hook()
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Exception running callback for LaterGauge(%s)", self.name
|
||||
)
|
||||
yield g
|
||||
return
|
||||
|
||||
if isinstance(hook_result, (int, float)):
|
||||
g.add_metric([], hook_result)
|
||||
else:
|
||||
for k, v in hook_result.items():
|
||||
g.add_metric(k, v)
|
||||
|
||||
try:
|
||||
calls = self.caller()
|
||||
except Exception:
|
||||
logger.exception("Exception running callback for LaterGauge(%s)", self.name)
|
||||
yield g
|
||||
return
|
||||
|
||||
def register_hook(
|
||||
self,
|
||||
hook: Callable[
|
||||
[], Union[Mapping[Tuple[str, ...], Union[int, float]], Union[int, float]]
|
||||
],
|
||||
) -> None:
|
||||
self._hooks.append(hook)
|
||||
if isinstance(calls, (int, float)):
|
||||
g.add_metric([], calls)
|
||||
else:
|
||||
for k, v in calls.items():
|
||||
g.add_metric(k, v)
|
||||
|
||||
yield g
|
||||
|
||||
def __attrs_post_init__(self) -> None:
|
||||
self._register()
|
||||
|
||||
def _register(self) -> None:
|
||||
if self.name in all_gauges.keys():
|
||||
logger.warning("%s already registered, reregistering", self.name)
|
||||
REGISTRY.unregister(all_gauges.pop(self.name))
|
||||
|
||||
REGISTRY.register(self)
|
||||
all_gauges[self.name] = self
|
||||
|
||||
|
||||
# `MetricsEntry` only makes sense when it is a `Protocol`,
|
||||
@@ -254,7 +250,7 @@ class InFlightGauge(Generic[MetricsEntry], Collector):
|
||||
# Protects access to _registrations
|
||||
self._lock = threading.Lock()
|
||||
|
||||
REGISTRY.register(self)
|
||||
self._register_with_collector()
|
||||
|
||||
def register(
|
||||
self,
|
||||
@@ -345,6 +341,14 @@ class InFlightGauge(Generic[MetricsEntry], Collector):
|
||||
gauge.add_metric(labels=key, value=getattr(metrics, name))
|
||||
yield gauge
|
||||
|
||||
def _register_with_collector(self) -> None:
|
||||
if self.name in all_gauges.keys():
|
||||
logger.warning("%s already registered, reregistering", self.name)
|
||||
REGISTRY.unregister(all_gauges.pop(self.name))
|
||||
|
||||
REGISTRY.register(self)
|
||||
all_gauges[self.name] = self
|
||||
|
||||
|
||||
class GaugeHistogramMetricFamilyWithLabels(GaugeHistogramMetricFamily):
|
||||
"""
|
||||
|
||||
@@ -86,24 +86,6 @@ users_woken_by_stream_counter = Counter(
|
||||
labelnames=["stream", SERVER_NAME_LABEL],
|
||||
)
|
||||
|
||||
|
||||
notifier_listeners_gauge = LaterGauge(
|
||||
name="synapse_notifier_listeners",
|
||||
desc="",
|
||||
labelnames=[SERVER_NAME_LABEL],
|
||||
)
|
||||
|
||||
notifier_rooms_gauge = LaterGauge(
|
||||
name="synapse_notifier_rooms",
|
||||
desc="",
|
||||
labelnames=[SERVER_NAME_LABEL],
|
||||
)
|
||||
notifier_users_gauge = LaterGauge(
|
||||
name="synapse_notifier_users",
|
||||
desc="",
|
||||
labelnames=[SERVER_NAME_LABEL],
|
||||
)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@@ -299,16 +281,28 @@ class Notifier:
|
||||
)
|
||||
}
|
||||
|
||||
notifier_listeners_gauge.register_hook(count_listeners)
|
||||
notifier_rooms_gauge.register_hook(
|
||||
lambda: {
|
||||
LaterGauge(
|
||||
name="synapse_notifier_listeners",
|
||||
desc="",
|
||||
labelnames=[SERVER_NAME_LABEL],
|
||||
caller=count_listeners,
|
||||
)
|
||||
|
||||
LaterGauge(
|
||||
name="synapse_notifier_rooms",
|
||||
desc="",
|
||||
labelnames=[SERVER_NAME_LABEL],
|
||||
caller=lambda: {
|
||||
(self.server_name,): count(
|
||||
bool, list(self.room_to_user_streams.values())
|
||||
)
|
||||
}
|
||||
},
|
||||
)
|
||||
notifier_users_gauge.register_hook(
|
||||
lambda: {(self.server_name,): len(self.user_to_user_stream)}
|
||||
LaterGauge(
|
||||
name="synapse_notifier_users",
|
||||
desc="",
|
||||
labelnames=[SERVER_NAME_LABEL],
|
||||
caller=lambda: {(self.server_name,): len(self.user_to_user_stream)},
|
||||
)
|
||||
|
||||
def add_replication_callback(self, cb: Callable[[], None]) -> None:
|
||||
|
||||
@@ -25,6 +25,7 @@ from typing import (
|
||||
Any,
|
||||
Collection,
|
||||
Dict,
|
||||
FrozenSet,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
@@ -477,8 +478,18 @@ class BulkPushRuleEvaluator:
|
||||
event.room_version.msc3931_push_features,
|
||||
self.hs.config.experimental.msc1767_enabled, # MSC3931 flag
|
||||
self.hs.config.experimental.msc4210_enabled,
|
||||
self.hs.config.experimental.msc4306_enabled,
|
||||
)
|
||||
|
||||
msc4306_thread_subscribers: Optional[FrozenSet[str]] = None
|
||||
if self.hs.config.experimental.msc4306_enabled and thread_id != MAIN_TIMELINE:
|
||||
# pull out, in batch, all local subscribers to this thread
|
||||
# (in the common case, they will all be getting processed for push
|
||||
# rules right now)
|
||||
msc4306_thread_subscribers = await self.store.get_subscribers_to_thread(
|
||||
event.room_id, thread_id
|
||||
)
|
||||
|
||||
for uid, rules in rules_by_user.items():
|
||||
if event.sender == uid:
|
||||
continue
|
||||
@@ -503,7 +514,13 @@ class BulkPushRuleEvaluator:
|
||||
# current user, it'll be added to the dict later.
|
||||
actions_by_user[uid] = []
|
||||
|
||||
actions = evaluator.run(rules, uid, display_name)
|
||||
msc4306_thread_subscription_state: Optional[bool] = None
|
||||
if msc4306_thread_subscribers is not None:
|
||||
msc4306_thread_subscription_state = uid in msc4306_thread_subscribers
|
||||
|
||||
actions = evaluator.run(
|
||||
rules, uid, display_name, msc4306_thread_subscription_state
|
||||
)
|
||||
if "notify" in actions:
|
||||
# Push rules say we should notify the user of this event
|
||||
actions_by_user[uid] = actions
|
||||
|
||||
@@ -106,18 +106,6 @@ user_ip_cache_counter = Counter(
|
||||
"synapse_replication_tcp_resource_user_ip_cache", "", labelnames=[SERVER_NAME_LABEL]
|
||||
)
|
||||
|
||||
tcp_resource_total_connections_gauge = LaterGauge(
|
||||
name="synapse_replication_tcp_resource_total_connections",
|
||||
desc="",
|
||||
labelnames=[SERVER_NAME_LABEL],
|
||||
)
|
||||
|
||||
tcp_command_queue_gauge = LaterGauge(
|
||||
name="synapse_replication_tcp_command_queue",
|
||||
desc="Number of inbound RDATA/POSITION commands queued for processing",
|
||||
labelnames=["stream_name", SERVER_NAME_LABEL],
|
||||
)
|
||||
|
||||
|
||||
# the type of the entries in _command_queues_by_stream
|
||||
_StreamCommandQueue = Deque[
|
||||
@@ -255,8 +243,11 @@ class ReplicationCommandHandler:
|
||||
# outgoing replication commands to.)
|
||||
self._connections: List[IReplicationConnection] = []
|
||||
|
||||
tcp_resource_total_connections_gauge.register_hook(
|
||||
lambda: {(self.server_name,): len(self._connections)}
|
||||
LaterGauge(
|
||||
name="synapse_replication_tcp_resource_total_connections",
|
||||
desc="",
|
||||
labelnames=[SERVER_NAME_LABEL],
|
||||
caller=lambda: {(self.server_name,): len(self._connections)},
|
||||
)
|
||||
|
||||
# When POSITION or RDATA commands arrive, we stick them in a queue and process
|
||||
@@ -275,11 +266,14 @@ class ReplicationCommandHandler:
|
||||
# from that connection.
|
||||
self._streams_by_connection: Dict[IReplicationConnection, Set[str]] = {}
|
||||
|
||||
tcp_command_queue_gauge.register_hook(
|
||||
lambda: {
|
||||
LaterGauge(
|
||||
name="synapse_replication_tcp_command_queue",
|
||||
desc="Number of inbound RDATA/POSITION commands queued for processing",
|
||||
labelnames=["stream_name", SERVER_NAME_LABEL],
|
||||
caller=lambda: {
|
||||
(stream_name, self.server_name): len(queue)
|
||||
for stream_name, queue in self._command_queues_by_stream.items()
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
self._is_master = hs.config.worker.worker_app is None
|
||||
|
||||
@@ -527,11 +527,9 @@ pending_commands = LaterGauge(
|
||||
name="synapse_replication_tcp_protocol_pending_commands",
|
||||
desc="",
|
||||
labelnames=["name", SERVER_NAME_LABEL],
|
||||
)
|
||||
pending_commands.register_hook(
|
||||
lambda: {
|
||||
caller=lambda: {
|
||||
(p.name, p.server_name): len(p.pending_commands) for p in connected_connections
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -546,11 +544,9 @@ transport_send_buffer = LaterGauge(
|
||||
name="synapse_replication_tcp_protocol_transport_send_buffer",
|
||||
desc="",
|
||||
labelnames=["name", SERVER_NAME_LABEL],
|
||||
)
|
||||
transport_send_buffer.register_hook(
|
||||
lambda: {
|
||||
caller=lambda: {
|
||||
(p.name, p.server_name): transport_buffer_size(p) for p in connected_connections
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -575,12 +571,10 @@ tcp_transport_kernel_send_buffer = LaterGauge(
|
||||
name="synapse_replication_tcp_protocol_transport_kernel_send_buffer",
|
||||
desc="",
|
||||
labelnames=["name", SERVER_NAME_LABEL],
|
||||
)
|
||||
tcp_transport_kernel_send_buffer.register_hook(
|
||||
lambda: {
|
||||
caller=lambda: {
|
||||
(p.name, p.server_name): transport_kernel_read_buffer_size(p, False)
|
||||
for p in connected_connections
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -588,10 +582,8 @@ tcp_transport_kernel_read_buffer = LaterGauge(
|
||||
name="synapse_replication_tcp_protocol_transport_kernel_read_buffer",
|
||||
desc="",
|
||||
labelnames=["name", SERVER_NAME_LABEL],
|
||||
)
|
||||
tcp_transport_kernel_read_buffer.register_hook(
|
||||
lambda: {
|
||||
caller=lambda: {
|
||||
(p.name, p.server_name): transport_kernel_read_buffer_size(p, True)
|
||||
for p in connected_connections
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
@@ -739,7 +739,7 @@ class ThreadSubscriptionsStream(_StreamFromIdGen):
|
||||
NAME = "thread_subscriptions"
|
||||
ROW_TYPE = ThreadSubscriptionsStreamRow
|
||||
|
||||
def __init__(self, hs: Any):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.store = hs.get_datastores().main
|
||||
super().__init__(
|
||||
hs.get_instance_name(),
|
||||
@@ -751,7 +751,7 @@ class ThreadSubscriptionsStream(_StreamFromIdGen):
|
||||
self, instance_name: str, from_token: int, to_token: int, limit: int
|
||||
) -> StreamUpdateResult:
|
||||
updates = await self.store.get_updated_thread_subscriptions(
|
||||
from_token, to_token, limit
|
||||
from_id=from_token, to_id=to_token, limit=limit
|
||||
)
|
||||
rows = [
|
||||
(
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from http import HTTPStatus
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
from typing import TYPE_CHECKING, Optional, Tuple
|
||||
|
||||
from synapse._pydantic_compat import StrictBool
|
||||
from synapse.api.errors import Codes, NotFoundError, SynapseError
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import (
|
||||
@@ -12,6 +11,7 @@ from synapse.http.site import SynapseRequest
|
||||
from synapse.rest.client._base import client_patterns
|
||||
from synapse.types import JsonDict, RoomID
|
||||
from synapse.types.rest import RequestBodyModel
|
||||
from synapse.util.pydantic_models import AnyEventId
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
@@ -32,7 +32,12 @@ class ThreadSubscriptionsRestServlet(RestServlet):
|
||||
self.handler = hs.get_thread_subscriptions_handler()
|
||||
|
||||
class PutBody(RequestBodyModel):
|
||||
automatic: StrictBool
|
||||
automatic: Optional[AnyEventId]
|
||||
"""
|
||||
If supplied, the event ID of an event giving rise to this automatic subscription.
|
||||
|
||||
If omitted, this subscription is a manual subscription.
|
||||
"""
|
||||
|
||||
async def on_GET(
|
||||
self, request: SynapseRequest, room_id: str, thread_root_id: str
|
||||
@@ -63,15 +68,15 @@ class ThreadSubscriptionsRestServlet(RestServlet):
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_REQUEST, "Invalid event ID", errcode=Codes.INVALID_PARAM
|
||||
)
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
|
||||
body = parse_and_validate_json_object_from_request(request, self.PutBody)
|
||||
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
|
||||
await self.handler.subscribe_user_to_thread(
|
||||
requester.user,
|
||||
room_id,
|
||||
thread_root_id,
|
||||
automatic=body.automatic,
|
||||
automatic_event_id=body.automatic,
|
||||
)
|
||||
|
||||
return HTTPStatus.OK, {}
|
||||
|
||||
@@ -100,12 +100,6 @@ sql_txn_duration = Counter(
|
||||
labelnames=["desc", SERVER_NAME_LABEL],
|
||||
)
|
||||
|
||||
background_update_status = LaterGauge(
|
||||
name="synapse_background_update_status",
|
||||
desc="Background update status",
|
||||
labelnames=[SERVER_NAME_LABEL],
|
||||
)
|
||||
|
||||
|
||||
# Unique indexes which have been added in background updates. Maps from table name
|
||||
# to the name of the background update which added the unique index to that table.
|
||||
@@ -617,8 +611,11 @@ class DatabasePool:
|
||||
)
|
||||
|
||||
self.updates = BackgroundUpdater(hs, self)
|
||||
background_update_status.register_hook(
|
||||
lambda: {(self.server_name,): self.updates.get_status()},
|
||||
LaterGauge(
|
||||
name="synapse_background_update_status",
|
||||
desc="Background update status",
|
||||
labelnames=[SERVER_NAME_LABEL],
|
||||
caller=lambda: {(self.server_name,): self.updates.get_status()},
|
||||
)
|
||||
|
||||
self._previous_txn_total_time = 0.0
|
||||
|
||||
@@ -110,6 +110,7 @@ def _load_rules(
|
||||
msc3381_polls_enabled=experimental_config.msc3381_polls_enabled,
|
||||
msc4028_push_encrypted_events=experimental_config.msc4028_push_encrypted_events,
|
||||
msc4210_enabled=experimental_config.msc4210_enabled,
|
||||
msc4306_enabled=experimental_config.msc4306_enabled,
|
||||
)
|
||||
|
||||
return filtered_rules
|
||||
|
||||
@@ -84,13 +84,6 @@ _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership"
|
||||
_POPULATE_PARTICIPANT_BG_UPDATE_BATCH_SIZE = 1000
|
||||
|
||||
|
||||
federation_known_servers_gauge = LaterGauge(
|
||||
name="synapse_federation_known_servers",
|
||||
desc="",
|
||||
labelnames=[SERVER_NAME_LABEL],
|
||||
)
|
||||
|
||||
|
||||
@attr.s(frozen=True, slots=True, auto_attribs=True)
|
||||
class EventIdMembership:
|
||||
"""Returned by `get_membership_from_event_ids`"""
|
||||
@@ -123,8 +116,11 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
|
||||
1,
|
||||
self._count_known_servers,
|
||||
)
|
||||
federation_known_servers_gauge.register_hook(
|
||||
lambda: {(self.server_name,): self._known_servers_count}
|
||||
LaterGauge(
|
||||
name="synapse_federation_known_servers",
|
||||
desc="",
|
||||
labelnames=[SERVER_NAME_LABEL],
|
||||
caller=lambda: {(self.server_name,): self._known_servers_count},
|
||||
)
|
||||
|
||||
@wrap_as_background_process("_count_known_servers")
|
||||
|
||||
@@ -14,7 +14,7 @@ import logging
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Dict,
|
||||
FrozenSet,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
@@ -33,6 +33,7 @@ from synapse.storage.database import (
|
||||
)
|
||||
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
|
||||
from synapse.storage.util.id_generators import MultiWriterIdGenerator
|
||||
from synapse.types import EventOrderings
|
||||
from synapse.util.caches.descriptors import cached
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -50,6 +51,14 @@ class ThreadSubscription:
|
||||
"""
|
||||
|
||||
|
||||
class AutomaticSubscriptionConflicted:
|
||||
"""
|
||||
Marker return value to signal that an automatic subscription was skipped,
|
||||
because it conflicted with an unsubscription that we consider to have
|
||||
been made later than the event causing the automatic subscription.
|
||||
"""
|
||||
|
||||
|
||||
class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -91,6 +100,7 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore):
|
||||
self.get_subscription_for_thread.invalidate(
|
||||
(row.user_id, row.room_id, row.event_id)
|
||||
)
|
||||
self.get_subscribers_to_thread.invalidate((row.room_id, row.event_id))
|
||||
|
||||
super().process_replication_rows(stream_name, instance_name, token, rows)
|
||||
|
||||
@@ -101,75 +111,196 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore):
|
||||
self._thread_subscriptions_id_gen.advance(instance_name, token)
|
||||
super().process_replication_position(stream_name, instance_name, token)
|
||||
|
||||
@staticmethod
|
||||
def _should_skip_autosubscription_after_unsubscription(
|
||||
*,
|
||||
autosub: EventOrderings,
|
||||
unsubscribed_at: EventOrderings,
|
||||
) -> bool:
|
||||
"""
|
||||
Returns whether an automatic subscription occurring *after* an unsubscription
|
||||
should be skipped, because the unsubscription already 'acknowledges' the event
|
||||
causing the automatic subscription (the cause event).
|
||||
|
||||
To determine *after*, we use `stream_ordering` unless the event is backfilled
|
||||
(negative `stream_ordering`) and fallback to topological ordering.
|
||||
|
||||
Args:
|
||||
autosub: the stream_ordering and topological_ordering of the cause event
|
||||
unsubscribed_at:
|
||||
the maximum stream ordering and the maximum topological ordering at the time of unsubscription
|
||||
|
||||
Returns:
|
||||
True if the automatic subscription should be skipped
|
||||
"""
|
||||
# For normal rooms, these two orderings should be positive, because
|
||||
# they don't refer to a specific event but rather the maximum at the
|
||||
# time of unsubscription.
|
||||
#
|
||||
# However, for rooms that have never been joined and that are being peeked at,
|
||||
# we might not have a single non-backfilled event and therefore the stream
|
||||
# ordering might be negative, so we don't assert this case.
|
||||
assert unsubscribed_at.topological > 0
|
||||
|
||||
unsubscribed_at_backfilled = unsubscribed_at.stream < 0
|
||||
if (
|
||||
not unsubscribed_at_backfilled
|
||||
and unsubscribed_at.stream >= autosub.stream > 0
|
||||
):
|
||||
# non-backfilled events: the unsubscription is later according to
|
||||
# the stream
|
||||
return True
|
||||
|
||||
if autosub.stream < 0:
|
||||
# the auto-subscription cause event was backfilled, so fall back to
|
||||
# topological ordering
|
||||
if unsubscribed_at.topological >= autosub.topological:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def subscribe_user_to_thread(
|
||||
self, user_id: str, room_id: str, thread_root_event_id: str, *, automatic: bool
|
||||
) -> Optional[int]:
|
||||
self,
|
||||
user_id: str,
|
||||
room_id: str,
|
||||
thread_root_event_id: str,
|
||||
*,
|
||||
automatic_event_orderings: Optional[EventOrderings],
|
||||
) -> Optional[Union[int, AutomaticSubscriptionConflicted]]:
|
||||
"""Updates a user's subscription settings for a specific thread root.
|
||||
|
||||
If no change would be made to the subscription, does not produce any database change.
|
||||
|
||||
Case-by-case:
|
||||
- if we already have an automatic subscription:
|
||||
- new automatic subscriptions will be no-ops (no database write),
|
||||
- new manual subscriptions will overwrite the automatic subscription
|
||||
- if we already have a manual subscription:
|
||||
we don't update (no database write) in either case, because:
|
||||
- the existing manual subscription wins over a new automatic subscription request
|
||||
- there would be no need to write a manual subscription because we already have one
|
||||
|
||||
Args:
|
||||
user_id: The ID of the user whose settings are being updated.
|
||||
room_id: The ID of the room the thread root belongs to.
|
||||
thread_root_event_id: The event ID of the thread root.
|
||||
automatic: Whether the subscription was performed automatically by the user's client.
|
||||
Only `False` will overwrite an existing value of automatic for a subscription row.
|
||||
automatic_event_orderings:
|
||||
Value depends on whether the subscription was performed automatically by the user's client.
|
||||
For manual subscriptions: None.
|
||||
For automatic subscriptions: the orderings of the event.
|
||||
|
||||
Returns:
|
||||
The stream ID for this update, if the update isn't no-opped.
|
||||
If a subscription is made: (int) the stream ID for this update.
|
||||
If a subscription already exists and did not need to be updated: None
|
||||
If an automatic subscription conflicted with an unsubscription: AutomaticSubscriptionConflicted
|
||||
"""
|
||||
assert self._can_write_to_thread_subscriptions
|
||||
|
||||
def _subscribe_user_to_thread_txn(txn: LoggingTransaction) -> Optional[int]:
|
||||
already_automatic = self.db_pool.simple_select_one_onecol_txn(
|
||||
txn,
|
||||
table="thread_subscriptions",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
"event_id": thread_root_event_id,
|
||||
"room_id": room_id,
|
||||
"subscribed": True,
|
||||
},
|
||||
retcol="automatic",
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
if already_automatic is None:
|
||||
already_subscribed = False
|
||||
already_automatic = True
|
||||
else:
|
||||
already_subscribed = True
|
||||
# convert int (SQLite bool) to Python bool
|
||||
already_automatic = bool(already_automatic)
|
||||
|
||||
if already_subscribed and already_automatic == automatic:
|
||||
# there is nothing we need to do here
|
||||
return None
|
||||
|
||||
stream_id = self._thread_subscriptions_id_gen.get_next_txn(txn)
|
||||
|
||||
values: Dict[str, Optional[Union[bool, int, str]]] = {
|
||||
"subscribed": True,
|
||||
"stream_id": stream_id,
|
||||
"instance_name": self._instance_name,
|
||||
"automatic": already_automatic and automatic,
|
||||
}
|
||||
|
||||
self.db_pool.simple_upsert_txn(
|
||||
txn,
|
||||
table="thread_subscriptions",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
"event_id": thread_root_event_id,
|
||||
"room_id": room_id,
|
||||
},
|
||||
values=values,
|
||||
)
|
||||
|
||||
def _invalidate_subscription_caches(txn: LoggingTransaction) -> None:
|
||||
txn.call_after(
|
||||
self.get_subscription_for_thread.invalidate,
|
||||
(user_id, room_id, thread_root_event_id),
|
||||
)
|
||||
txn.call_after(
|
||||
self.get_subscribers_to_thread.invalidate,
|
||||
(room_id, thread_root_event_id),
|
||||
)
|
||||
|
||||
def _subscribe_user_to_thread_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> Optional[Union[int, AutomaticSubscriptionConflicted]]:
|
||||
requested_automatic = automatic_event_orderings is not None
|
||||
|
||||
row = self.db_pool.simple_select_one_txn(
|
||||
txn,
|
||||
table="thread_subscriptions",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
"event_id": thread_root_event_id,
|
||||
"room_id": room_id,
|
||||
},
|
||||
retcols=(
|
||||
"subscribed",
|
||||
"automatic",
|
||||
"unsubscribed_at_stream_ordering",
|
||||
"unsubscribed_at_topological_ordering",
|
||||
),
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
if row is None:
|
||||
# We have never subscribed before, simply insert the row and finish
|
||||
stream_id = self._thread_subscriptions_id_gen.get_next_txn(txn)
|
||||
self.db_pool.simple_insert_txn(
|
||||
txn,
|
||||
table="thread_subscriptions",
|
||||
values={
|
||||
"user_id": user_id,
|
||||
"event_id": thread_root_event_id,
|
||||
"room_id": room_id,
|
||||
"subscribed": True,
|
||||
"stream_id": stream_id,
|
||||
"instance_name": self._instance_name,
|
||||
"automatic": requested_automatic,
|
||||
"unsubscribed_at_stream_ordering": None,
|
||||
"unsubscribed_at_topological_ordering": None,
|
||||
},
|
||||
)
|
||||
_invalidate_subscription_caches(txn)
|
||||
return stream_id
|
||||
|
||||
# we already have either a subscription or a prior unsubscription here
|
||||
(
|
||||
subscribed,
|
||||
already_automatic,
|
||||
unsubscribed_at_stream_ordering,
|
||||
unsubscribed_at_topological_ordering,
|
||||
) = row
|
||||
|
||||
if subscribed and (not already_automatic or requested_automatic):
|
||||
# we are already subscribed and the current subscription state
|
||||
# is good enough (either we already have a manual subscription,
|
||||
# or we requested an automatic subscription)
|
||||
# In that case, nothing to change here.
|
||||
# (See docstring for case-by-case explanation)
|
||||
return None
|
||||
|
||||
if not subscribed and requested_automatic:
|
||||
assert automatic_event_orderings is not None
|
||||
# we previously unsubscribed and we are now automatically subscribing
|
||||
# Check whether the new autosubscription should be skipped
|
||||
if ThreadSubscriptionsWorkerStore._should_skip_autosubscription_after_unsubscription(
|
||||
autosub=automatic_event_orderings,
|
||||
unsubscribed_at=EventOrderings(
|
||||
unsubscribed_at_stream_ordering,
|
||||
unsubscribed_at_topological_ordering,
|
||||
),
|
||||
):
|
||||
# skip the subscription
|
||||
return AutomaticSubscriptionConflicted()
|
||||
|
||||
# At this point: we have now finished checking that we need to make
|
||||
# a subscription, updating the current row.
|
||||
|
||||
stream_id = self._thread_subscriptions_id_gen.get_next_txn(txn)
|
||||
self.db_pool.simple_update_txn(
|
||||
txn,
|
||||
table="thread_subscriptions",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
"event_id": thread_root_event_id,
|
||||
"room_id": room_id,
|
||||
},
|
||||
updatevalues={
|
||||
"subscribed": True,
|
||||
"stream_id": stream_id,
|
||||
"instance_name": self._instance_name,
|
||||
"automatic": requested_automatic,
|
||||
"unsubscribed_at_stream_ordering": None,
|
||||
"unsubscribed_at_topological_ordering": None,
|
||||
},
|
||||
)
|
||||
_invalidate_subscription_caches(txn)
|
||||
|
||||
return stream_id
|
||||
|
||||
@@ -214,6 +345,21 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore):
|
||||
|
||||
stream_id = self._thread_subscriptions_id_gen.get_next_txn(txn)
|
||||
|
||||
# Find the maximum stream ordering and topological ordering of the room,
|
||||
# which we then store against this unsubscription so we can skip future
|
||||
# automatic subscriptions that are caused by an event logically earlier
|
||||
# than this unsubscription.
|
||||
txn.execute(
|
||||
"""
|
||||
SELECT MAX(stream_ordering) AS mso, MAX(topological_ordering) AS mto FROM events
|
||||
WHERE room_id = ?
|
||||
""",
|
||||
(room_id,),
|
||||
)
|
||||
ord_row = txn.fetchone()
|
||||
assert ord_row is not None
|
||||
max_stream_ordering, max_topological_ordering = ord_row
|
||||
|
||||
self.db_pool.simple_update_txn(
|
||||
txn,
|
||||
table="thread_subscriptions",
|
||||
@@ -227,6 +373,8 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore):
|
||||
"subscribed": False,
|
||||
"stream_id": stream_id,
|
||||
"instance_name": self._instance_name,
|
||||
"unsubscribed_at_stream_ordering": max_stream_ordering,
|
||||
"unsubscribed_at_topological_ordering": max_topological_ordering,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -234,6 +382,10 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore):
|
||||
self.get_subscription_for_thread.invalidate,
|
||||
(user_id, room_id, thread_root_event_id),
|
||||
)
|
||||
txn.call_after(
|
||||
self.get_subscribers_to_thread.invalidate,
|
||||
(room_id, thread_root_event_id),
|
||||
)
|
||||
|
||||
return stream_id
|
||||
|
||||
@@ -246,7 +398,9 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore):
|
||||
Purge all subscriptions for the user.
|
||||
The fact that subscriptions have been purged will not be streamed;
|
||||
all stream rows for the user will in fact be removed.
|
||||
This is intended only for dealing with user deactivation.
|
||||
|
||||
This must only be used for user deactivation,
|
||||
because it does not invalidate the `subscribers_to_thread` cache.
|
||||
"""
|
||||
|
||||
def _purge_thread_subscription_settings_for_user_txn(
|
||||
@@ -307,6 +461,42 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore):
|
||||
|
||||
return ThreadSubscription(automatic=automatic)
|
||||
|
||||
# max_entries=100 rationale:
|
||||
# this returns a potentially large datastructure
|
||||
# (since each entry contains a set which contains a potentially large number of user IDs),
|
||||
# whereas the default of 10'000 entries for @cached feels more
|
||||
# suitable for very small cache entries.
|
||||
#
|
||||
# Overall, when bearing in mind the usual profile of a small community-server or company-server
|
||||
# (where cache tuning hasn't been done, so we're in out-of-box configuration), it is very
|
||||
# unlikely we would benefit from keeping hot the subscribers for as many as 100 threads,
|
||||
# since it's unlikely that so many threads will be active in a short span of time on a small homeserver.
|
||||
# It feels that medium servers will probably also not exhaust this limit.
|
||||
# Larger homeservers are more likely to be carefully tuned, either with a larger global cache factor
|
||||
# or carefully following the usage patterns & cache metrics.
|
||||
# Finally, the query is not so intensive that computing it every time is a huge deal, but given people
|
||||
# often send messages back-to-back in the same thread it seems like it would offer a mild benefit.
|
||||
@cached(max_entries=100)
|
||||
async def get_subscribers_to_thread(
|
||||
self, room_id: str, thread_root_event_id: str
|
||||
) -> FrozenSet[str]:
|
||||
"""
|
||||
Returns:
|
||||
the set of user_ids for local users who are subscribed to the given thread.
|
||||
"""
|
||||
return frozenset(
|
||||
await self.db_pool.simple_select_onecol(
|
||||
table="thread_subscriptions",
|
||||
keyvalues={
|
||||
"room_id": room_id,
|
||||
"event_id": thread_root_event_id,
|
||||
"subscribed": True,
|
||||
},
|
||||
retcol="user_id",
|
||||
desc="get_subscribers_to_thread",
|
||||
)
|
||||
)
|
||||
|
||||
def get_max_thread_subscriptions_stream_id(self) -> int:
|
||||
"""Get the current maximum stream_id for thread subscriptions.
|
||||
|
||||
@@ -316,7 +506,7 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore):
|
||||
return self._thread_subscriptions_id_gen.get_current_token()
|
||||
|
||||
async def get_updated_thread_subscriptions(
|
||||
self, from_id: int, to_id: int, limit: int
|
||||
self, *, from_id: int, to_id: int, limit: int
|
||||
) -> List[Tuple[int, str, str, str]]:
|
||||
"""Get updates to thread subscriptions between two stream IDs.
|
||||
|
||||
@@ -349,7 +539,7 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore):
|
||||
)
|
||||
|
||||
async def get_updated_thread_subscriptions_for_user(
|
||||
self, user_id: str, from_id: int, to_id: int, limit: int
|
||||
self, user_id: str, *, from_id: int, to_id: int, limit: int
|
||||
) -> List[Tuple[int, str, str]]:
|
||||
"""Get updates to thread subscriptions for a specific user.
|
||||
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
--
|
||||
-- This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
--
|
||||
-- Copyright (C) 2025 New Vector, Ltd
|
||||
--
|
||||
-- This program is free software: you can redistribute it and/or modify
|
||||
-- it under the terms of the GNU Affero General Public License as
|
||||
-- published by the Free Software Foundation, either version 3 of the
|
||||
-- License, or (at your option) any later version.
|
||||
--
|
||||
-- See the GNU Affero General Public License for more details:
|
||||
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||
|
||||
-- The maximum stream_ordering in the room when the unsubscription was made.
|
||||
ALTER TABLE thread_subscriptions
|
||||
ADD COLUMN unsubscribed_at_stream_ordering BIGINT;
|
||||
|
||||
-- The maximum topological_ordering in the room when the unsubscription was made.
|
||||
ALTER TABLE thread_subscriptions
|
||||
ADD COLUMN unsubscribed_at_topological_ordering BIGINT;
|
||||
@@ -0,0 +1,18 @@
|
||||
--
|
||||
-- This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
--
|
||||
-- Copyright (C) 2025 New Vector, Ltd
|
||||
--
|
||||
-- This program is free software: you can redistribute it and/or modify
|
||||
-- it under the terms of the GNU Affero General Public License as
|
||||
-- published by the Free Software Foundation, either version 3 of the
|
||||
-- License, or (at your option) any later version.
|
||||
--
|
||||
-- See the GNU Affero General Public License for more details:
|
||||
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||
|
||||
COMMENT ON COLUMN thread_subscriptions.unsubscribed_at_stream_ordering IS
|
||||
$$The maximum stream_ordering in the room when the unsubscription was made.$$;
|
||||
|
||||
COMMENT ON COLUMN thread_subscriptions.unsubscribed_at_topological_ordering IS
|
||||
$$The maximum topological_ordering in the room when the unsubscription was made.$$;
|
||||
@@ -49,6 +49,7 @@ class FilteredPushRules:
|
||||
msc3664_enabled: bool,
|
||||
msc4028_push_encrypted_events: bool,
|
||||
msc4210_enabled: bool,
|
||||
msc4306_enabled: bool,
|
||||
): ...
|
||||
def rules(self) -> Collection[Tuple[PushRule, bool]]: ...
|
||||
|
||||
@@ -67,13 +68,19 @@ class PushRuleEvaluator:
|
||||
room_version_feature_flags: Tuple[str, ...],
|
||||
msc3931_enabled: bool,
|
||||
msc4210_enabled: bool,
|
||||
msc4306_enabled: bool,
|
||||
): ...
|
||||
def run(
|
||||
self,
|
||||
push_rules: FilteredPushRules,
|
||||
user_id: Optional[str],
|
||||
display_name: Optional[str],
|
||||
msc4306_thread_subscription_state: Optional[bool],
|
||||
) -> Collection[Union[Mapping, str]]: ...
|
||||
def matches(
|
||||
self, condition: JsonDict, user_id: Optional[str], display_name: Optional[str]
|
||||
self,
|
||||
condition: JsonDict,
|
||||
user_id: Optional[str],
|
||||
display_name: Optional[str],
|
||||
msc4306_thread_subscription_state: Optional[bool] = None,
|
||||
) -> bool: ...
|
||||
|
||||
@@ -73,6 +73,7 @@ if TYPE_CHECKING:
|
||||
from typing_extensions import Self
|
||||
|
||||
from synapse.appservice.api import ApplicationService
|
||||
from synapse.events import EventBase
|
||||
from synapse.storage.databases.main import DataStore, PurgeEventsStore
|
||||
from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore
|
||||
from synapse.storage.util.id_generators import MultiWriterIdGenerator
|
||||
@@ -1464,3 +1465,31 @@ class ScheduledTask:
|
||||
result: Optional[JsonMapping]
|
||||
# Optional error that should be assigned a value when the status is FAILED
|
||||
error: Optional[str]
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True, frozen=True, slots=True)
|
||||
class EventOrderings:
|
||||
stream: int
|
||||
"""
|
||||
The stream_ordering of the event.
|
||||
Negative numbers mean the event was backfilled.
|
||||
"""
|
||||
|
||||
topological: int
|
||||
"""
|
||||
The topological_ordering of the event.
|
||||
Currently this is equivalent to the `depth` attributes of
|
||||
the PDU.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def from_event(event: "EventBase") -> "EventOrderings":
|
||||
"""
|
||||
Get the orderings from an event.
|
||||
|
||||
Preconditions:
|
||||
- the event must have been persisted (otherwise it won't have a stream ordering)
|
||||
"""
|
||||
stream = event.internal_metadata.stream_ordering
|
||||
assert stream is not None
|
||||
return EventOrderings(stream, event.depth)
|
||||
|
||||
@@ -13,7 +13,11 @@
|
||||
#
|
||||
#
|
||||
|
||||
from synapse._pydantic_compat import BaseModel, Extra
|
||||
import re
|
||||
from typing import Any, Callable, Generator
|
||||
|
||||
from synapse._pydantic_compat import BaseModel, Extra, StrictStr
|
||||
from synapse.types import EventID
|
||||
|
||||
|
||||
class ParseModel(BaseModel):
|
||||
@@ -37,3 +41,43 @@ class ParseModel(BaseModel):
|
||||
extra = Extra.ignore
|
||||
# By default, don't allow fields to be reassigned after parsing.
|
||||
allow_mutation = False
|
||||
|
||||
|
||||
class AnyEventId(StrictStr):
|
||||
"""
|
||||
A validator for strings that need to be an Event ID.
|
||||
|
||||
Accepts any valid grammar of Event ID from any room version.
|
||||
"""
|
||||
|
||||
EVENT_ID_HASH_ROOM_VERSION_3_PLUS = re.compile(
|
||||
r"^([a-zA-Z0-9-_]{43}|[a-zA-Z0-9+/]{43})$"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def __get_validators__(cls) -> Generator[Callable[..., Any], Any, Any]:
|
||||
yield from super().__get_validators__() # type: ignore
|
||||
yield cls.validate_event_id
|
||||
|
||||
@classmethod
|
||||
def validate_event_id(cls, value: str) -> str:
|
||||
if not value.startswith("$"):
|
||||
raise ValueError("Event ID must start with `$`")
|
||||
|
||||
if ":" in value:
|
||||
# Room versions 1 and 2
|
||||
EventID.from_string(value) # throws on fail
|
||||
else:
|
||||
# Room versions 3+: event ID is $ + a base64 sha256 hash
|
||||
# Room version 3 is base64, 4+ are base64Url
|
||||
# In both cases, the base64 is unpadded.
|
||||
# refs:
|
||||
# - https://spec.matrix.org/v1.15/rooms/v3/ e.g. $acR1l0raoZnm60CBwAVgqbZqoO/mYU81xysh1u7XcJk
|
||||
# - https://spec.matrix.org/v1.15/rooms/v4/ e.g. $Rqnc-F-dvnEYJTyHq_iKxU2bZ1CI92-kuZq3a5lr5Zg
|
||||
b64_hash = value[1:]
|
||||
if cls.EVENT_ID_HASH_ROOM_VERSION_3_PLUS.fullmatch(b64_hash) is None:
|
||||
raise ValueError(
|
||||
"Event ID must either have a domain part or be a valid hash"
|
||||
)
|
||||
|
||||
return value
|
||||
|
||||
@@ -131,31 +131,27 @@ def _get_counts_from_rate_limiter_instance(
|
||||
# We track the number of affected hosts per time-period so we can
|
||||
# differentiate one really noisy homeserver from a general
|
||||
# ratelimit tuning problem across the federation.
|
||||
sleep_affected_hosts_gauge = LaterGauge(
|
||||
LaterGauge(
|
||||
name="synapse_rate_limit_sleep_affected_hosts",
|
||||
desc="Number of hosts that had requests put to sleep",
|
||||
labelnames=["rate_limiter_name", SERVER_NAME_LABEL],
|
||||
)
|
||||
sleep_affected_hosts_gauge.register_hook(
|
||||
lambda: _get_counts_from_rate_limiter_instance(
|
||||
caller=lambda: _get_counts_from_rate_limiter_instance(
|
||||
lambda rate_limiter_instance: sum(
|
||||
ratelimiter.should_sleep()
|
||||
for ratelimiter in rate_limiter_instance.ratelimiters.values()
|
||||
)
|
||||
)
|
||||
),
|
||||
)
|
||||
reject_affected_hosts_gauge = LaterGauge(
|
||||
LaterGauge(
|
||||
name="synapse_rate_limit_reject_affected_hosts",
|
||||
desc="Number of hosts that had requests rejected",
|
||||
labelnames=["rate_limiter_name", SERVER_NAME_LABEL],
|
||||
)
|
||||
reject_affected_hosts_gauge.register_hook(
|
||||
lambda: _get_counts_from_rate_limiter_instance(
|
||||
caller=lambda: _get_counts_from_rate_limiter_instance(
|
||||
lambda rate_limiter_instance: sum(
|
||||
ratelimiter.should_reject()
|
||||
for ratelimiter in rate_limiter_instance.ratelimiters.values()
|
||||
)
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -44,13 +44,6 @@ if TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
running_tasks_gauge = LaterGauge(
|
||||
name="synapse_scheduler_running_tasks",
|
||||
desc="The number of concurrent running tasks handled by the TaskScheduler",
|
||||
labelnames=[SERVER_NAME_LABEL],
|
||||
)
|
||||
|
||||
|
||||
class TaskScheduler:
|
||||
"""
|
||||
This is a simple task scheduler designed for resumable tasks. Normally,
|
||||
@@ -137,8 +130,11 @@ class TaskScheduler:
|
||||
TaskScheduler.SCHEDULE_INTERVAL_MS,
|
||||
)
|
||||
|
||||
running_tasks_gauge.register_hook(
|
||||
lambda: {(self.server_name,): len(self._running_tasks)}
|
||||
LaterGauge(
|
||||
name="synapse_scheduler_running_tasks",
|
||||
desc="The number of concurrent running tasks handled by the TaskScheduler",
|
||||
labelnames=[SERVER_NAME_LABEL],
|
||||
caller=lambda: {(self.server_name,): len(self._running_tasks)},
|
||||
)
|
||||
|
||||
def register_action(
|
||||
|
||||
@@ -22,13 +22,7 @@ from typing import Dict, Protocol, Tuple
|
||||
|
||||
from prometheus_client.core import Sample
|
||||
|
||||
from synapse.metrics import (
|
||||
REGISTRY,
|
||||
SERVER_NAME_LABEL,
|
||||
InFlightGauge,
|
||||
LaterGauge,
|
||||
generate_latest,
|
||||
)
|
||||
from synapse.metrics import REGISTRY, InFlightGauge, generate_latest
|
||||
from synapse.util.caches.deferred_cache import DeferredCache
|
||||
|
||||
from tests import unittest
|
||||
@@ -291,42 +285,6 @@ class CacheMetricsTests(unittest.HomeserverTestCase):
|
||||
self.assertEqual(hs2_cache_max_size_metric_value, "777.0")
|
||||
|
||||
|
||||
class LaterGaugeTests(unittest.HomeserverTestCase):
|
||||
def test_later_gauge_multiple_servers(self) -> None:
|
||||
"""
|
||||
Test that LaterGauge metrics are reported correctly across multiple servers. We
|
||||
will have an metrics entry for each homeserver that is labeled with the
|
||||
`server_name` label.
|
||||
"""
|
||||
later_gauge = LaterGauge(
|
||||
name="foo",
|
||||
desc="",
|
||||
labelnames=[SERVER_NAME_LABEL],
|
||||
)
|
||||
later_gauge.register_hook(lambda: {("hs1",): 1})
|
||||
later_gauge.register_hook(lambda: {("hs2",): 2})
|
||||
|
||||
metrics_map = get_latest_metrics()
|
||||
|
||||
# Find the metrics for the caches from both homeservers
|
||||
hs1_metric = 'foo{server_name="hs1"}'
|
||||
hs1_metric_value = metrics_map.get(hs1_metric)
|
||||
self.assertIsNotNone(
|
||||
hs1_metric_value,
|
||||
f"Missing metric {hs1_metric} in cache metrics {metrics_map}",
|
||||
)
|
||||
hs2_metric = 'foo{server_name="hs2"}'
|
||||
hs2_metric_value = metrics_map.get(hs2_metric)
|
||||
self.assertIsNotNone(
|
||||
hs2_metric_value,
|
||||
f"Missing metric {hs2_metric} in cache metrics {metrics_map}",
|
||||
)
|
||||
|
||||
# Sanity check the metric values
|
||||
self.assertEqual(hs1_metric_value, "1.0")
|
||||
self.assertEqual(hs2_metric_value, "2.0")
|
||||
|
||||
|
||||
def get_latest_metrics() -> Dict[str, str]:
|
||||
"""
|
||||
Collect the latest metrics from the registry and parse them into an easy to use map.
|
||||
|
||||
@@ -26,7 +26,7 @@ from parameterized import parameterized
|
||||
|
||||
from twisted.internet.testing import MemoryReactor
|
||||
|
||||
from synapse.api.constants import EventContentFields, RelationTypes
|
||||
from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
|
||||
from synapse.api.room_versions import RoomVersions
|
||||
from synapse.push.bulk_push_rule_evaluator import BulkPushRuleEvaluator
|
||||
from synapse.rest import admin
|
||||
@@ -206,7 +206,10 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
|
||||
bulk_evaluator._action_for_event_by_user.assert_not_called()
|
||||
|
||||
def _create_and_process(
|
||||
self, bulk_evaluator: BulkPushRuleEvaluator, content: Optional[JsonDict] = None
|
||||
self,
|
||||
bulk_evaluator: BulkPushRuleEvaluator,
|
||||
content: Optional[JsonDict] = None,
|
||||
type: str = "test",
|
||||
) -> bool:
|
||||
"""Returns true iff the `mentions` trigger an event push action."""
|
||||
# Create a new message event which should cause a notification.
|
||||
@@ -214,7 +217,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
|
||||
self.event_creation_handler.create_event(
|
||||
self.requester,
|
||||
{
|
||||
"type": "test",
|
||||
"type": type,
|
||||
"room_id": self.room_id,
|
||||
"content": content or {},
|
||||
"sender": f"@bob:{self.hs.hostname}",
|
||||
@@ -446,3 +449,73 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
@override_config({"experimental_features": {"msc4306_enabled": True}})
|
||||
def test_thread_subscriptions(self) -> None:
|
||||
bulk_evaluator = BulkPushRuleEvaluator(self.hs)
|
||||
(thread_root_id,) = self.helper.send_messages(self.room_id, 1, tok=self.token)
|
||||
|
||||
self.assertFalse(
|
||||
self._create_and_process(
|
||||
bulk_evaluator,
|
||||
{
|
||||
"msgtype": "m.text",
|
||||
"body": "test message before subscription",
|
||||
"m.relates_to": {
|
||||
"rel_type": RelationTypes.THREAD,
|
||||
"event_id": thread_root_id,
|
||||
},
|
||||
},
|
||||
type=EventTypes.Message,
|
||||
)
|
||||
)
|
||||
|
||||
self.get_success(
|
||||
self.hs.get_datastores().main.subscribe_user_to_thread(
|
||||
self.alice,
|
||||
self.room_id,
|
||||
thread_root_id,
|
||||
automatic_event_orderings=None,
|
||||
)
|
||||
)
|
||||
|
||||
self.assertTrue(
|
||||
self._create_and_process(
|
||||
bulk_evaluator,
|
||||
{
|
||||
"msgtype": "m.text",
|
||||
"body": "test message after subscription",
|
||||
"m.relates_to": {
|
||||
"rel_type": RelationTypes.THREAD,
|
||||
"event_id": thread_root_id,
|
||||
},
|
||||
},
|
||||
type="m.room.message",
|
||||
)
|
||||
)
|
||||
|
||||
def test_with_disabled_thread_subscriptions(self) -> None:
|
||||
"""
|
||||
Test what happens with threaded events when MSC4306 is disabled.
|
||||
|
||||
FUTURE: If MSC4306 becomes enabled-by-default/accepted, this test is to be removed.
|
||||
"""
|
||||
bulk_evaluator = BulkPushRuleEvaluator(self.hs)
|
||||
(thread_root_id,) = self.helper.send_messages(self.room_id, 1, tok=self.token)
|
||||
|
||||
# When MSC4306 is not enabled, a threaded message generates a notification
|
||||
# by default.
|
||||
self.assertTrue(
|
||||
self._create_and_process(
|
||||
bulk_evaluator,
|
||||
{
|
||||
"msgtype": "m.text",
|
||||
"body": "test message before subscription",
|
||||
"m.relates_to": {
|
||||
"rel_type": RelationTypes.THREAD,
|
||||
"event_id": thread_root_id,
|
||||
},
|
||||
},
|
||||
type="m.room.message",
|
||||
)
|
||||
)
|
||||
|
||||
@@ -150,6 +150,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
|
||||
*,
|
||||
related_events: Optional[JsonDict] = None,
|
||||
msc4210: bool = False,
|
||||
msc4306: bool = False,
|
||||
) -> PushRuleEvaluator:
|
||||
event = FrozenEvent(
|
||||
{
|
||||
@@ -176,6 +177,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
|
||||
room_version_feature_flags=event.room_version.msc3931_push_features,
|
||||
msc3931_enabled=True,
|
||||
msc4210_enabled=msc4210,
|
||||
msc4306_enabled=msc4306,
|
||||
)
|
||||
|
||||
def test_display_name(self) -> None:
|
||||
@@ -806,6 +808,112 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
|
||||
)
|
||||
)
|
||||
|
||||
def test_thread_subscription_subscribed(self) -> None:
|
||||
"""
|
||||
Test MSC4306 thread subscription push rules against an event in a subscribed thread.
|
||||
"""
|
||||
evaluator = self._get_evaluator(
|
||||
{
|
||||
"msgtype": "m.text",
|
||||
"body": "Squawk",
|
||||
"m.relates_to": {
|
||||
"event_id": "$threadroot",
|
||||
"rel_type": "m.thread",
|
||||
},
|
||||
},
|
||||
msc4306=True,
|
||||
)
|
||||
self.assertTrue(
|
||||
evaluator.matches(
|
||||
{
|
||||
"kind": "io.element.msc4306.thread_subscription",
|
||||
"subscribed": True,
|
||||
},
|
||||
None,
|
||||
None,
|
||||
msc4306_thread_subscription_state=True,
|
||||
)
|
||||
)
|
||||
self.assertFalse(
|
||||
evaluator.matches(
|
||||
{
|
||||
"kind": "io.element.msc4306.thread_subscription",
|
||||
"subscribed": False,
|
||||
},
|
||||
None,
|
||||
None,
|
||||
msc4306_thread_subscription_state=True,
|
||||
)
|
||||
)
|
||||
|
||||
def test_thread_subscription_unsubscribed(self) -> None:
|
||||
"""
|
||||
Test MSC4306 thread subscription push rules against an event in an unsubscribed thread.
|
||||
"""
|
||||
evaluator = self._get_evaluator(
|
||||
{
|
||||
"msgtype": "m.text",
|
||||
"body": "Squawk",
|
||||
"m.relates_to": {
|
||||
"event_id": "$threadroot",
|
||||
"rel_type": "m.thread",
|
||||
},
|
||||
},
|
||||
msc4306=True,
|
||||
)
|
||||
self.assertFalse(
|
||||
evaluator.matches(
|
||||
{
|
||||
"kind": "io.element.msc4306.thread_subscription",
|
||||
"subscribed": True,
|
||||
},
|
||||
None,
|
||||
None,
|
||||
msc4306_thread_subscription_state=False,
|
||||
)
|
||||
)
|
||||
self.assertTrue(
|
||||
evaluator.matches(
|
||||
{
|
||||
"kind": "io.element.msc4306.thread_subscription",
|
||||
"subscribed": False,
|
||||
},
|
||||
None,
|
||||
None,
|
||||
msc4306_thread_subscription_state=False,
|
||||
)
|
||||
)
|
||||
|
||||
def test_thread_subscription_unthreaded(self) -> None:
|
||||
"""
|
||||
Test MSC4306 thread subscription push rules against an unthreaded event.
|
||||
"""
|
||||
evaluator = self._get_evaluator(
|
||||
{"msgtype": "m.text", "body": "Squawk"}, msc4306=True
|
||||
)
|
||||
self.assertFalse(
|
||||
evaluator.matches(
|
||||
{
|
||||
"kind": "io.element.msc4306.thread_subscription",
|
||||
"subscribed": True,
|
||||
},
|
||||
None,
|
||||
None,
|
||||
msc4306_thread_subscription_state=None,
|
||||
)
|
||||
)
|
||||
self.assertFalse(
|
||||
evaluator.matches(
|
||||
{
|
||||
"kind": "io.element.msc4306.thread_subscription",
|
||||
"subscribed": False,
|
||||
},
|
||||
None,
|
||||
None,
|
||||
msc4306_thread_subscription_state=None,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase):
|
||||
"""Tests for the bulk push rule evaluator"""
|
||||
|
||||
@@ -62,7 +62,7 @@ class ThreadSubscriptionsStreamTestCase(BaseStreamTestCase):
|
||||
"@test_user:example.org",
|
||||
room_id,
|
||||
thread_root_id,
|
||||
automatic=True,
|
||||
automatic_event_orderings=None,
|
||||
)
|
||||
)
|
||||
updates.append(thread_root_id)
|
||||
@@ -75,7 +75,7 @@ class ThreadSubscriptionsStreamTestCase(BaseStreamTestCase):
|
||||
"@test_user:example.org",
|
||||
other_room_id,
|
||||
other_thread_root_id,
|
||||
automatic=False,
|
||||
automatic_event_orderings=None,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -124,7 +124,7 @@ class ThreadSubscriptionsStreamTestCase(BaseStreamTestCase):
|
||||
for user_id in users:
|
||||
self.get_success(
|
||||
store.subscribe_user_to_thread(
|
||||
user_id, room_id, thread_root_id, automatic=True
|
||||
user_id, room_id, thread_root_id, automatic_event_orderings=None
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ from http import HTTPStatus
|
||||
|
||||
from twisted.internet.testing import MemoryReactor
|
||||
|
||||
from synapse.api.errors import Codes
|
||||
from synapse.rest import admin
|
||||
from synapse.rest.client import login, profile, room, thread_subscriptions
|
||||
from synapse.server import HomeServer
|
||||
@@ -49,15 +50,16 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
|
||||
# Create a room and send a message to use as a thread root
|
||||
self.room_id = self.helper.create_room_as(self.user_id, tok=self.token)
|
||||
self.helper.join(self.room_id, self.other_user_id, tok=self.other_token)
|
||||
response = self.helper.send(self.room_id, body="Root message", tok=self.token)
|
||||
self.root_event_id = response["event_id"]
|
||||
(self.root_event_id,) = self.helper.send_messages(
|
||||
self.room_id, 1, tok=self.token
|
||||
)
|
||||
|
||||
# Send a message in the thread
|
||||
self.helper.send_event(
|
||||
room_id=self.room_id,
|
||||
type="m.room.message",
|
||||
content={
|
||||
"body": "Thread message",
|
||||
self.threaded_events = self.helper.send_messages(
|
||||
self.room_id,
|
||||
2,
|
||||
content_fn=lambda idx: {
|
||||
"body": f"Thread message {idx}",
|
||||
"msgtype": "m.text",
|
||||
"m.relates_to": {
|
||||
"rel_type": "m.thread",
|
||||
@@ -106,9 +108,7 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
|
||||
{
|
||||
"automatic": False,
|
||||
},
|
||||
{},
|
||||
access_token=self.token,
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.OK)
|
||||
@@ -127,7 +127,7 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
|
||||
{"automatic": True},
|
||||
{"automatic": self.threaded_events[0]},
|
||||
access_token=self.token,
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.OK)
|
||||
@@ -148,11 +148,11 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
|
||||
"PUT",
|
||||
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
|
||||
{
|
||||
"automatic": True,
|
||||
"automatic": self.threaded_events[0],
|
||||
},
|
||||
access_token=self.token,
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.OK)
|
||||
self.assertEqual(channel.code, HTTPStatus.OK, channel.text_body)
|
||||
|
||||
# Assert the subscription was saved
|
||||
channel = self.make_request(
|
||||
@@ -167,7 +167,7 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
|
||||
{"automatic": False},
|
||||
{},
|
||||
access_token=self.token,
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.OK)
|
||||
@@ -187,7 +187,7 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
|
||||
"PUT",
|
||||
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
|
||||
{
|
||||
"automatic": True,
|
||||
"automatic": self.threaded_events[0],
|
||||
},
|
||||
access_token=self.token,
|
||||
)
|
||||
@@ -202,7 +202,6 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEqual(channel.code, HTTPStatus.OK)
|
||||
self.assertEqual(channel.json_body, {"automatic": True})
|
||||
|
||||
# Now also register a manual subscription
|
||||
channel = self.make_request(
|
||||
"DELETE",
|
||||
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
|
||||
@@ -210,7 +209,6 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.OK)
|
||||
|
||||
# Assert the manual subscription was not overridden
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
|
||||
@@ -224,7 +222,7 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
f"{PREFIX}/{self.room_id}/thread/$nonexistent:example.org/subscription",
|
||||
{"automatic": True},
|
||||
{},
|
||||
access_token=self.token,
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.NOT_FOUND)
|
||||
@@ -238,7 +236,7 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
|
||||
{"automatic": True},
|
||||
{},
|
||||
access_token=no_access_token,
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.NOT_FOUND)
|
||||
@@ -249,8 +247,105 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
|
||||
# non-boolean `automatic`
|
||||
{"automatic": "true"},
|
||||
# non-Event ID `automatic`
|
||||
{"automatic": True},
|
||||
access_token=self.token,
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST)
|
||||
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
|
||||
# non-Event ID `automatic`
|
||||
{"automatic": "$malformedEventId"},
|
||||
access_token=self.token,
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST)
|
||||
|
||||
def test_auto_subscribe_cause_event_not_in_thread(self) -> None:
|
||||
"""
|
||||
Test making an automatic subscription, where the cause event is not
|
||||
actually in the thread.
|
||||
This is an error.
|
||||
"""
|
||||
(unrelated_event_id,) = self.helper.send_messages(
|
||||
self.room_id, 1, tok=self.token
|
||||
)
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
|
||||
{"automatic": unrelated_event_id},
|
||||
access_token=self.token,
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.text_body)
|
||||
self.assertEqual(channel.json_body["errcode"], Codes.MSC4306_NOT_IN_THREAD)
|
||||
|
||||
def test_auto_resubscription_conflict(self) -> None:
|
||||
"""
|
||||
Test that an automatic subscription that conflicts with an unsubscription
|
||||
is skipped.
|
||||
"""
|
||||
# Reuse the test that subscribes and unsubscribes
|
||||
self.test_unsubscribe()
|
||||
|
||||
# Now no matter which event we present as the cause of an automatic subscription,
|
||||
# the automatic subscription is skipped.
|
||||
# This is because the unsubscription happened after all of the events.
|
||||
for event in self.threaded_events:
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
|
||||
{
|
||||
"automatic": event,
|
||||
},
|
||||
access_token=self.token,
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.CONFLICT, channel.text_body)
|
||||
self.assertEqual(
|
||||
channel.json_body["errcode"],
|
||||
Codes.MSC4306_CONFLICTING_UNSUBSCRIPTION,
|
||||
channel.text_body,
|
||||
)
|
||||
|
||||
# Check the subscription was not made
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
|
||||
access_token=self.token,
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.NOT_FOUND)
|
||||
|
||||
# But if a new event is sent after the unsubscription took place,
|
||||
# that one can be used for an automatic subscription
|
||||
(later_event_id,) = self.helper.send_messages(
|
||||
self.room_id,
|
||||
1,
|
||||
content_fn=lambda _: {
|
||||
"body": "Thread message after unsubscription",
|
||||
"msgtype": "m.text",
|
||||
"m.relates_to": {
|
||||
"rel_type": "m.thread",
|
||||
"event_id": self.root_event_id,
|
||||
},
|
||||
},
|
||||
tok=self.token,
|
||||
)
|
||||
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
|
||||
{
|
||||
"automatic": later_event_id,
|
||||
},
|
||||
access_token=self.token,
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.OK, channel.text_body)
|
||||
|
||||
# Check the subscription was made
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
|
||||
access_token=self.token,
|
||||
)
|
||||
self.assertEqual(channel.code, HTTPStatus.OK)
|
||||
self.assertEqual(channel.json_body, {"automatic": True})
|
||||
|
||||
@@ -29,12 +29,14 @@ from http import HTTPStatus
|
||||
from typing import (
|
||||
Any,
|
||||
AnyStr,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
Literal,
|
||||
Mapping,
|
||||
MutableMapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
overload,
|
||||
)
|
||||
@@ -45,7 +47,7 @@ import attr
|
||||
from twisted.internet.testing import MemoryReactorClock
|
||||
from twisted.web.server import Site
|
||||
|
||||
from synapse.api.constants import Membership, ReceiptTypes
|
||||
from synapse.api.constants import EventTypes, Membership, ReceiptTypes
|
||||
from synapse.api.errors import Codes
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import JsonDict
|
||||
@@ -394,6 +396,32 @@ class RestHelper:
|
||||
custom_headers=custom_headers,
|
||||
)
|
||||
|
||||
def send_messages(
|
||||
self,
|
||||
room_id: str,
|
||||
num_events: int,
|
||||
content_fn: Callable[[int], JsonDict] = lambda idx: {
|
||||
"msgtype": "m.text",
|
||||
"body": f"Test event {idx}",
|
||||
},
|
||||
tok: Optional[str] = None,
|
||||
) -> Sequence[str]:
|
||||
"""
|
||||
Helper to send a handful of sequential events and return their event IDs as a sequence.
|
||||
"""
|
||||
event_ids = []
|
||||
|
||||
for event_index in range(num_events):
|
||||
response = self.send_event(
|
||||
room_id,
|
||||
EventTypes.Message,
|
||||
content_fn(event_index),
|
||||
tok=tok,
|
||||
)
|
||||
event_ids.append(response["event_id"])
|
||||
|
||||
return event_ids
|
||||
|
||||
def send_event(
|
||||
self,
|
||||
room_id: str,
|
||||
|
||||
@@ -12,13 +12,18 @@
|
||||
# <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||
#
|
||||
|
||||
from typing import Optional
|
||||
from typing import Optional, Union
|
||||
|
||||
from twisted.internet.testing import MemoryReactor
|
||||
|
||||
from synapse.server import HomeServer
|
||||
from synapse.storage.database import LoggingTransaction
|
||||
from synapse.storage.databases.main.thread_subscriptions import (
|
||||
AutomaticSubscriptionConflicted,
|
||||
ThreadSubscriptionsWorkerStore,
|
||||
)
|
||||
from synapse.storage.engines.sqlite import Sqlite3Engine
|
||||
from synapse.types import EventOrderings
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
@@ -97,10 +102,10 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
|
||||
self,
|
||||
thread_root_id: str,
|
||||
*,
|
||||
automatic: bool,
|
||||
automatic_event_orderings: Optional[EventOrderings],
|
||||
room_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
) -> Optional[int]:
|
||||
) -> Optional[Union[int, AutomaticSubscriptionConflicted]]:
|
||||
if user_id is None:
|
||||
user_id = self.user_id
|
||||
|
||||
@@ -112,7 +117,7 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
|
||||
user_id,
|
||||
room_id,
|
||||
thread_root_id,
|
||||
automatic=automatic,
|
||||
automatic_event_orderings=automatic_event_orderings,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -149,7 +154,7 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
|
||||
# Subscribe
|
||||
self._subscribe(
|
||||
self.thread_root_id,
|
||||
automatic=True,
|
||||
automatic_event_orderings=EventOrderings(1, 1),
|
||||
)
|
||||
|
||||
# Assert subscription went through
|
||||
@@ -164,7 +169,7 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
|
||||
# Now make it a manual subscription
|
||||
self._subscribe(
|
||||
self.thread_root_id,
|
||||
automatic=False,
|
||||
automatic_event_orderings=None,
|
||||
)
|
||||
|
||||
# Assert the manual subscription overrode the automatic one
|
||||
@@ -178,8 +183,10 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
|
||||
def test_purge_thread_subscriptions_for_user(self) -> None:
|
||||
"""Test purging all thread subscription settings for a user."""
|
||||
# Set subscription settings for multiple threads
|
||||
self._subscribe(self.thread_root_id, automatic=True)
|
||||
self._subscribe(self.other_thread_root_id, automatic=False)
|
||||
self._subscribe(
|
||||
self.thread_root_id, automatic_event_orderings=EventOrderings(1, 1)
|
||||
)
|
||||
self._subscribe(self.other_thread_root_id, automatic_event_orderings=None)
|
||||
|
||||
subscriptions = self.get_success(
|
||||
self.store.get_updated_thread_subscriptions_for_user(
|
||||
@@ -217,20 +224,32 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
|
||||
def test_get_updated_thread_subscriptions(self) -> None:
|
||||
"""Test getting updated thread subscriptions since a stream ID."""
|
||||
|
||||
stream_id1 = self._subscribe(self.thread_root_id, automatic=False)
|
||||
stream_id2 = self._subscribe(self.other_thread_root_id, automatic=True)
|
||||
assert stream_id1 is not None
|
||||
assert stream_id2 is not None
|
||||
stream_id1 = self._subscribe(
|
||||
self.thread_root_id, automatic_event_orderings=EventOrderings(1, 1)
|
||||
)
|
||||
stream_id2 = self._subscribe(
|
||||
self.other_thread_root_id, automatic_event_orderings=EventOrderings(2, 2)
|
||||
)
|
||||
assert stream_id1 is not None and not isinstance(
|
||||
stream_id1, AutomaticSubscriptionConflicted
|
||||
)
|
||||
assert stream_id2 is not None and not isinstance(
|
||||
stream_id2, AutomaticSubscriptionConflicted
|
||||
)
|
||||
|
||||
# Get updates since initial ID (should include both changes)
|
||||
updates = self.get_success(
|
||||
self.store.get_updated_thread_subscriptions(0, stream_id2, 10)
|
||||
self.store.get_updated_thread_subscriptions(
|
||||
from_id=0, to_id=stream_id2, limit=10
|
||||
)
|
||||
)
|
||||
self.assertEqual(len(updates), 2)
|
||||
|
||||
# Get updates since first change (should include only the second change)
|
||||
updates = self.get_success(
|
||||
self.store.get_updated_thread_subscriptions(stream_id1, stream_id2, 10)
|
||||
self.store.get_updated_thread_subscriptions(
|
||||
from_id=stream_id1, to_id=stream_id2, limit=10
|
||||
)
|
||||
)
|
||||
self.assertEqual(
|
||||
updates,
|
||||
@@ -242,21 +261,27 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
|
||||
other_user_id = "@other_user:test"
|
||||
|
||||
# Set thread subscription for main user
|
||||
stream_id1 = self._subscribe(self.thread_root_id, automatic=True)
|
||||
assert stream_id1 is not None
|
||||
stream_id1 = self._subscribe(
|
||||
self.thread_root_id, automatic_event_orderings=EventOrderings(1, 1)
|
||||
)
|
||||
assert stream_id1 is not None and not isinstance(
|
||||
stream_id1, AutomaticSubscriptionConflicted
|
||||
)
|
||||
|
||||
# Set thread subscription for other user
|
||||
stream_id2 = self._subscribe(
|
||||
self.other_thread_root_id,
|
||||
automatic=True,
|
||||
automatic_event_orderings=EventOrderings(1, 1),
|
||||
user_id=other_user_id,
|
||||
)
|
||||
assert stream_id2 is not None
|
||||
assert stream_id2 is not None and not isinstance(
|
||||
stream_id2, AutomaticSubscriptionConflicted
|
||||
)
|
||||
|
||||
# Get updates for main user
|
||||
updates = self.get_success(
|
||||
self.store.get_updated_thread_subscriptions_for_user(
|
||||
self.user_id, 0, stream_id2, 10
|
||||
self.user_id, from_id=0, to_id=stream_id2, limit=10
|
||||
)
|
||||
)
|
||||
self.assertEqual(updates, [(stream_id1, self.room_id, self.thread_root_id)])
|
||||
@@ -264,9 +289,80 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
|
||||
# Get updates for other user
|
||||
updates = self.get_success(
|
||||
self.store.get_updated_thread_subscriptions_for_user(
|
||||
other_user_id, 0, max(stream_id1, stream_id2), 10
|
||||
other_user_id, from_id=0, to_id=max(stream_id1, stream_id2), limit=10
|
||||
)
|
||||
)
|
||||
self.assertEqual(
|
||||
updates, [(stream_id2, self.room_id, self.other_thread_root_id)]
|
||||
)
|
||||
|
||||
def test_should_skip_autosubscription_after_unsubscription(self) -> None:
|
||||
"""
|
||||
Tests the comparison logic for whether an autoscription should be skipped
|
||||
due to a chronologically earlier but logically later unsubscription.
|
||||
"""
|
||||
|
||||
func = ThreadSubscriptionsWorkerStore._should_skip_autosubscription_after_unsubscription
|
||||
|
||||
# Order of arguments:
|
||||
# automatic cause event: stream order, then topological order
|
||||
# unsubscribe maximums: stream order, then tological order
|
||||
|
||||
# both orderings agree that the unsub is after the cause event
|
||||
self.assertTrue(
|
||||
func(autosub=EventOrderings(1, 1), unsubscribed_at=EventOrderings(2, 2))
|
||||
)
|
||||
|
||||
# topological ordering is inconsistent with stream ordering,
|
||||
# in that case favour stream ordering because it's what /sync uses
|
||||
self.assertTrue(
|
||||
func(autosub=EventOrderings(1, 2), unsubscribed_at=EventOrderings(2, 1))
|
||||
)
|
||||
|
||||
# the automatic subscription is caused by a backfilled event here
|
||||
# unfortunately we must fall back to topological ordering here
|
||||
self.assertTrue(
|
||||
func(autosub=EventOrderings(-50, 2), unsubscribed_at=EventOrderings(2, 3))
|
||||
)
|
||||
self.assertFalse(
|
||||
func(autosub=EventOrderings(-50, 2), unsubscribed_at=EventOrderings(2, 1))
|
||||
)
|
||||
|
||||
def test_get_subscribers_to_thread(self) -> None:
|
||||
"""
|
||||
Test getting all subscribers to a thread at once.
|
||||
|
||||
To check cache invalidations are correct, we do multiple
|
||||
step-by-step rounds of subscription changes and assertions.
|
||||
"""
|
||||
other_user_id = "@other_user:test"
|
||||
|
||||
subscribers = self.get_success(
|
||||
self.store.get_subscribers_to_thread(self.room_id, self.thread_root_id)
|
||||
)
|
||||
self.assertEqual(subscribers, frozenset())
|
||||
|
||||
self._subscribe(
|
||||
self.thread_root_id, automatic_event_orderings=None, user_id=self.user_id
|
||||
)
|
||||
|
||||
subscribers = self.get_success(
|
||||
self.store.get_subscribers_to_thread(self.room_id, self.thread_root_id)
|
||||
)
|
||||
self.assertEqual(subscribers, frozenset((self.user_id,)))
|
||||
|
||||
self._subscribe(
|
||||
self.thread_root_id, automatic_event_orderings=None, user_id=other_user_id
|
||||
)
|
||||
|
||||
subscribers = self.get_success(
|
||||
self.store.get_subscribers_to_thread(self.room_id, self.thread_root_id)
|
||||
)
|
||||
self.assertEqual(subscribers, frozenset((self.user_id, other_user_id)))
|
||||
|
||||
self._unsubscribe(self.thread_root_id, user_id=self.user_id)
|
||||
|
||||
subscribers = self.get_success(
|
||||
self.store.get_subscribers_to_thread(self.room_id, self.thread_root_id)
|
||||
)
|
||||
self.assertEqual(subscribers, frozenset((other_user_id,)))
|
||||
|
||||
Reference in New Issue
Block a user