Implement the push rules for experimental MSC4306: Thread Subscriptions. (#18762)

Follows: #18756

Implements: MSC4306

---------

Signed-off-by: Olivier 'reivilibre <oliverw@matrix.org>
Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com>
This commit is contained in:
reivilibre
2025-08-06 15:33:52 +01:00
committed by GitHub
parent 8306cee06a
commit 6514381b02
12 changed files with 404 additions and 28 deletions

View File

@@ -0,0 +1 @@
Implement the push rules for experimental [MSC4306: Thread Subscriptions](https://github.com/matrix-org/matrix-doc/issues/4306).

View File

@@ -61,6 +61,7 @@ fn bench_match_exact(b: &mut Bencher) {
vec![], vec![],
false, false,
false, false,
false,
) )
.unwrap(); .unwrap();
@@ -71,10 +72,10 @@ fn bench_match_exact(b: &mut Bencher) {
}, },
)); ));
let matched = eval.match_condition(&condition, None, None).unwrap(); let matched = eval.match_condition(&condition, None, None, None).unwrap();
assert!(matched, "Didn't match"); assert!(matched, "Didn't match");
b.iter(|| eval.match_condition(&condition, None, None).unwrap()); b.iter(|| eval.match_condition(&condition, None, None, None).unwrap());
} }
#[bench] #[bench]
@@ -107,6 +108,7 @@ fn bench_match_word(b: &mut Bencher) {
vec![], vec![],
false, false,
false, false,
false,
) )
.unwrap(); .unwrap();
@@ -117,10 +119,10 @@ fn bench_match_word(b: &mut Bencher) {
}, },
)); ));
let matched = eval.match_condition(&condition, None, None).unwrap(); let matched = eval.match_condition(&condition, None, None, None).unwrap();
assert!(matched, "Didn't match"); assert!(matched, "Didn't match");
b.iter(|| eval.match_condition(&condition, None, None).unwrap()); b.iter(|| eval.match_condition(&condition, None, None, None).unwrap());
} }
#[bench] #[bench]
@@ -153,6 +155,7 @@ fn bench_match_word_miss(b: &mut Bencher) {
vec![], vec![],
false, false,
false, false,
false,
) )
.unwrap(); .unwrap();
@@ -163,10 +166,10 @@ fn bench_match_word_miss(b: &mut Bencher) {
}, },
)); ));
let matched = eval.match_condition(&condition, None, None).unwrap(); let matched = eval.match_condition(&condition, None, None, None).unwrap();
assert!(!matched, "Didn't match"); assert!(!matched, "Didn't match");
b.iter(|| eval.match_condition(&condition, None, None).unwrap()); b.iter(|| eval.match_condition(&condition, None, None, None).unwrap());
} }
#[bench] #[bench]
@@ -199,6 +202,7 @@ fn bench_eval_message(b: &mut Bencher) {
vec![], vec![],
false, false,
false, false,
false,
) )
.unwrap(); .unwrap();
@@ -210,7 +214,8 @@ fn bench_eval_message(b: &mut Bencher) {
false, false,
false, false,
false, false,
false,
); );
b.iter(|| eval.run(&rules, Some("bob"), Some("person"))); b.iter(|| eval.run(&rules, Some("bob"), Some("person"), None));
} }

View File

@@ -290,6 +290,26 @@ pub const BASE_APPEND_CONTENT_RULES: &[PushRule] = &[PushRule {
}]; }];
pub const BASE_APPEND_UNDERRIDE_RULES: &[PushRule] = &[ pub const BASE_APPEND_UNDERRIDE_RULES: &[PushRule] = &[
PushRule {
rule_id: Cow::Borrowed("global/content/.io.element.msc4306.rule.unsubscribed_thread"),
priority_class: 1,
conditions: Cow::Borrowed(&[Condition::Known(
KnownCondition::Msc4306ThreadSubscription { subscribed: false },
)]),
actions: Cow::Borrowed(&[]),
default: true,
default_enabled: true,
},
PushRule {
rule_id: Cow::Borrowed("global/content/.io.element.msc4306.rule.subscribed_thread"),
priority_class: 1,
conditions: Cow::Borrowed(&[Condition::Known(
KnownCondition::Msc4306ThreadSubscription { subscribed: true },
)]),
actions: Cow::Borrowed(&[Action::Notify, SOUND_ACTION]),
default: true,
default_enabled: true,
},
PushRule { PushRule {
rule_id: Cow::Borrowed("global/underride/.m.rule.call"), rule_id: Cow::Borrowed("global/underride/.m.rule.call"),
priority_class: 1, priority_class: 1,

View File

@@ -106,8 +106,11 @@ pub struct PushRuleEvaluator {
/// flag as MSC1767 (extensible events core). /// flag as MSC1767 (extensible events core).
msc3931_enabled: bool, msc3931_enabled: bool,
// If MSC4210 (remove legacy mentions) is enabled. /// If MSC4210 (remove legacy mentions) is enabled.
msc4210_enabled: bool, msc4210_enabled: bool,
/// If MSC4306 (thread subscriptions) is enabled.
msc4306_enabled: bool,
} }
#[pymethods] #[pymethods]
@@ -126,6 +129,7 @@ impl PushRuleEvaluator {
room_version_feature_flags, room_version_feature_flags,
msc3931_enabled, msc3931_enabled,
msc4210_enabled, msc4210_enabled,
msc4306_enabled,
))] ))]
pub fn py_new( pub fn py_new(
flattened_keys: BTreeMap<String, JsonValue>, flattened_keys: BTreeMap<String, JsonValue>,
@@ -138,6 +142,7 @@ impl PushRuleEvaluator {
room_version_feature_flags: Vec<String>, room_version_feature_flags: Vec<String>,
msc3931_enabled: bool, msc3931_enabled: bool,
msc4210_enabled: bool, msc4210_enabled: bool,
msc4306_enabled: bool,
) -> Result<Self, Error> { ) -> Result<Self, Error> {
let body = match flattened_keys.get("content.body") { let body = match flattened_keys.get("content.body") {
Some(JsonValue::Value(SimpleJsonValue::Str(s))) => s.clone().into_owned(), Some(JsonValue::Value(SimpleJsonValue::Str(s))) => s.clone().into_owned(),
@@ -156,6 +161,7 @@ impl PushRuleEvaluator {
room_version_feature_flags, room_version_feature_flags,
msc3931_enabled, msc3931_enabled,
msc4210_enabled, msc4210_enabled,
msc4306_enabled,
}) })
} }
@@ -167,12 +173,19 @@ impl PushRuleEvaluator {
/// ///
/// Returns the set of actions, if any, that match (filtering out any /// Returns the set of actions, if any, that match (filtering out any
/// `dont_notify` and `coalesce` actions). /// `dont_notify` and `coalesce` actions).
#[pyo3(signature = (push_rules, user_id=None, display_name=None))] ///
/// msc4306_thread_subscription_state: (Only populated if MSC4306 is enabled)
/// The thread subscription state corresponding to the thread containing this event.
/// - `None` if the event is not in a thread, or if MSC4306 is disabled.
/// - `Some(true)` if the event is in a thread and the user has a subscription for that thread
/// - `Some(false)` if the event is in a thread and the user does NOT have a subscription for that thread
#[pyo3(signature = (push_rules, user_id=None, display_name=None, msc4306_thread_subscription_state=None))]
pub fn run( pub fn run(
&self, &self,
push_rules: &FilteredPushRules, push_rules: &FilteredPushRules,
user_id: Option<&str>, user_id: Option<&str>,
display_name: Option<&str>, display_name: Option<&str>,
msc4306_thread_subscription_state: Option<bool>,
) -> Vec<Action> { ) -> Vec<Action> {
'outer: for (push_rule, enabled) in push_rules.iter() { 'outer: for (push_rule, enabled) in push_rules.iter() {
if !enabled { if !enabled {
@@ -204,7 +217,12 @@ impl PushRuleEvaluator {
Condition::Known(KnownCondition::RoomVersionSupports { feature: _ }), Condition::Known(KnownCondition::RoomVersionSupports { feature: _ }),
); );
match self.match_condition(condition, user_id, display_name) { match self.match_condition(
condition,
user_id,
display_name,
msc4306_thread_subscription_state,
) {
Ok(true) => {} Ok(true) => {}
Ok(false) => continue 'outer, Ok(false) => continue 'outer,
Err(err) => { Err(err) => {
@@ -237,14 +255,20 @@ impl PushRuleEvaluator {
} }
/// Check if the given condition matches. /// Check if the given condition matches.
#[pyo3(signature = (condition, user_id=None, display_name=None))] #[pyo3(signature = (condition, user_id=None, display_name=None, msc4306_thread_subscription_state=None))]
fn matches( fn matches(
&self, &self,
condition: Condition, condition: Condition,
user_id: Option<&str>, user_id: Option<&str>,
display_name: Option<&str>, display_name: Option<&str>,
msc4306_thread_subscription_state: Option<bool>,
) -> bool { ) -> bool {
match self.match_condition(&condition, user_id, display_name) { match self.match_condition(
&condition,
user_id,
display_name,
msc4306_thread_subscription_state,
) {
Ok(true) => true, Ok(true) => true,
Ok(false) => false, Ok(false) => false,
Err(err) => { Err(err) => {
@@ -262,6 +286,7 @@ impl PushRuleEvaluator {
condition: &Condition, condition: &Condition,
user_id: Option<&str>, user_id: Option<&str>,
display_name: Option<&str>, display_name: Option<&str>,
msc4306_thread_subscription_state: Option<bool>,
) -> Result<bool, Error> { ) -> Result<bool, Error> {
let known_condition = match condition { let known_condition = match condition {
Condition::Known(known) => known, Condition::Known(known) => known,
@@ -393,6 +418,13 @@ impl PushRuleEvaluator {
&& self.room_version_feature_flags.contains(&flag) && self.room_version_feature_flags.contains(&flag)
} }
} }
KnownCondition::Msc4306ThreadSubscription { subscribed } => {
if !self.msc4306_enabled {
false
} else {
msc4306_thread_subscription_state == Some(*subscribed)
}
}
}; };
Ok(result) Ok(result)
@@ -536,10 +568,11 @@ fn push_rule_evaluator() {
vec![], vec![],
true, true,
false, false,
false,
) )
.unwrap(); .unwrap();
let result = evaluator.run(&FilteredPushRules::default(), None, Some("bob")); let result = evaluator.run(&FilteredPushRules::default(), None, Some("bob"), None);
assert_eq!(result.len(), 3); assert_eq!(result.len(), 3);
} }
@@ -566,6 +599,7 @@ fn test_requires_room_version_supports_condition() {
flags, flags,
true, true,
false, false,
false,
) )
.unwrap(); .unwrap();
@@ -575,6 +609,7 @@ fn test_requires_room_version_supports_condition() {
&FilteredPushRules::default(), &FilteredPushRules::default(),
Some("@bob:example.org"), Some("@bob:example.org"),
None, None,
None,
); );
assert_eq!(result.len(), 3); assert_eq!(result.len(), 3);
@@ -593,7 +628,17 @@ fn test_requires_room_version_supports_condition() {
}; };
let rules = PushRules::new(vec![custom_rule]); let rules = PushRules::new(vec![custom_rule]);
result = evaluator.run( result = evaluator.run(
&FilteredPushRules::py_new(rules, BTreeMap::new(), true, false, true, false, false), &FilteredPushRules::py_new(
rules,
BTreeMap::new(),
true,
false,
true,
false,
false,
false,
),
None,
None, None,
None, None,
); );

View File

@@ -369,6 +369,10 @@ pub enum KnownCondition {
RoomVersionSupports { RoomVersionSupports {
feature: Cow<'static, str>, feature: Cow<'static, str>,
}, },
#[serde(rename = "io.element.msc4306.thread_subscription")]
Msc4306ThreadSubscription {
subscribed: bool,
},
} }
impl<'source> IntoPyObject<'source> for Condition { impl<'source> IntoPyObject<'source> for Condition {
@@ -547,11 +551,13 @@ pub struct FilteredPushRules {
msc3664_enabled: bool, msc3664_enabled: bool,
msc4028_push_encrypted_events: bool, msc4028_push_encrypted_events: bool,
msc4210_enabled: bool, msc4210_enabled: bool,
msc4306_enabled: bool,
} }
#[pymethods] #[pymethods]
impl FilteredPushRules { impl FilteredPushRules {
#[new] #[new]
#[allow(clippy::too_many_arguments)]
pub fn py_new( pub fn py_new(
push_rules: PushRules, push_rules: PushRules,
enabled_map: BTreeMap<String, bool>, enabled_map: BTreeMap<String, bool>,
@@ -560,6 +566,7 @@ impl FilteredPushRules {
msc3664_enabled: bool, msc3664_enabled: bool,
msc4028_push_encrypted_events: bool, msc4028_push_encrypted_events: bool,
msc4210_enabled: bool, msc4210_enabled: bool,
msc4306_enabled: bool,
) -> Self { ) -> Self {
Self { Self {
push_rules, push_rules,
@@ -569,6 +576,7 @@ impl FilteredPushRules {
msc3664_enabled, msc3664_enabled,
msc4028_push_encrypted_events, msc4028_push_encrypted_events,
msc4210_enabled, msc4210_enabled,
msc4306_enabled,
} }
} }
@@ -619,6 +627,10 @@ impl FilteredPushRules {
return false; return false;
} }
if !self.msc4306_enabled && rule.rule_id.contains("/.io.element.msc4306.rule.") {
return false;
}
true true
}) })
.map(|r| { .map(|r| {

View File

@@ -25,6 +25,7 @@ from typing import (
Any, Any,
Collection, Collection,
Dict, Dict,
FrozenSet,
List, List,
Mapping, Mapping,
Optional, Optional,
@@ -477,8 +478,18 @@ class BulkPushRuleEvaluator:
event.room_version.msc3931_push_features, event.room_version.msc3931_push_features,
self.hs.config.experimental.msc1767_enabled, # MSC3931 flag self.hs.config.experimental.msc1767_enabled, # MSC3931 flag
self.hs.config.experimental.msc4210_enabled, self.hs.config.experimental.msc4210_enabled,
self.hs.config.experimental.msc4306_enabled,
) )
msc4306_thread_subscribers: Optional[FrozenSet[str]] = None
if self.hs.config.experimental.msc4306_enabled and thread_id != MAIN_TIMELINE:
# pull out, in batch, all local subscribers to this thread
# (in the common case, they will all be getting processed for push
# rules right now)
msc4306_thread_subscribers = await self.store.get_subscribers_to_thread(
event.room_id, thread_id
)
for uid, rules in rules_by_user.items(): for uid, rules in rules_by_user.items():
if event.sender == uid: if event.sender == uid:
continue continue
@@ -503,7 +514,13 @@ class BulkPushRuleEvaluator:
# current user, it'll be added to the dict later. # current user, it'll be added to the dict later.
actions_by_user[uid] = [] actions_by_user[uid] = []
actions = evaluator.run(rules, uid, display_name) msc4306_thread_subscription_state: Optional[bool] = None
if msc4306_thread_subscribers is not None:
msc4306_thread_subscription_state = uid in msc4306_thread_subscribers
actions = evaluator.run(
rules, uid, display_name, msc4306_thread_subscription_state
)
if "notify" in actions: if "notify" in actions:
# Push rules say we should notify the user of this event # Push rules say we should notify the user of this event
actions_by_user[uid] = actions actions_by_user[uid] = actions

View File

@@ -110,6 +110,7 @@ def _load_rules(
msc3381_polls_enabled=experimental_config.msc3381_polls_enabled, msc3381_polls_enabled=experimental_config.msc3381_polls_enabled,
msc4028_push_encrypted_events=experimental_config.msc4028_push_encrypted_events, msc4028_push_encrypted_events=experimental_config.msc4028_push_encrypted_events,
msc4210_enabled=experimental_config.msc4210_enabled, msc4210_enabled=experimental_config.msc4210_enabled,
msc4306_enabled=experimental_config.msc4306_enabled,
) )
return filtered_rules return filtered_rules

View File

@@ -14,6 +14,7 @@ import logging
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
FrozenSet,
Iterable, Iterable,
List, List,
Optional, Optional,
@@ -99,6 +100,7 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore):
self.get_subscription_for_thread.invalidate( self.get_subscription_for_thread.invalidate(
(row.user_id, row.room_id, row.event_id) (row.user_id, row.room_id, row.event_id)
) )
self.get_subscribers_to_thread.invalidate((row.room_id, row.event_id))
super().process_replication_rows(stream_name, instance_name, token, rows) super().process_replication_rows(stream_name, instance_name, token, rows)
@@ -194,6 +196,16 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore):
""" """
assert self._can_write_to_thread_subscriptions assert self._can_write_to_thread_subscriptions
def _invalidate_subscription_caches(txn: LoggingTransaction) -> None:
txn.call_after(
self.get_subscription_for_thread.invalidate,
(user_id, room_id, thread_root_event_id),
)
txn.call_after(
self.get_subscribers_to_thread.invalidate,
(room_id, thread_root_event_id),
)
def _subscribe_user_to_thread_txn( def _subscribe_user_to_thread_txn(
txn: LoggingTransaction, txn: LoggingTransaction,
) -> Optional[Union[int, AutomaticSubscriptionConflicted]]: ) -> Optional[Union[int, AutomaticSubscriptionConflicted]]:
@@ -234,10 +246,7 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore):
"unsubscribed_at_topological_ordering": None, "unsubscribed_at_topological_ordering": None,
}, },
) )
txn.call_after( _invalidate_subscription_caches(txn)
self.get_subscription_for_thread.invalidate,
(user_id, room_id, thread_root_event_id),
)
return stream_id return stream_id
# we already have either a subscription or a prior unsubscription here # we already have either a subscription or a prior unsubscription here
@@ -291,10 +300,7 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore):
"unsubscribed_at_topological_ordering": None, "unsubscribed_at_topological_ordering": None,
}, },
) )
txn.call_after( _invalidate_subscription_caches(txn)
self.get_subscription_for_thread.invalidate,
(user_id, room_id, thread_root_event_id),
)
return stream_id return stream_id
@@ -376,6 +382,10 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore):
self.get_subscription_for_thread.invalidate, self.get_subscription_for_thread.invalidate,
(user_id, room_id, thread_root_event_id), (user_id, room_id, thread_root_event_id),
) )
txn.call_after(
self.get_subscribers_to_thread.invalidate,
(room_id, thread_root_event_id),
)
return stream_id return stream_id
@@ -388,7 +398,9 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore):
Purge all subscriptions for the user. Purge all subscriptions for the user.
The fact that subscriptions have been purged will not be streamed; The fact that subscriptions have been purged will not be streamed;
all stream rows for the user will in fact be removed. all stream rows for the user will in fact be removed.
This is intended only for dealing with user deactivation.
This must only be used for user deactivation,
because it does not invalidate the `subscribers_to_thread` cache.
""" """
def _purge_thread_subscription_settings_for_user_txn( def _purge_thread_subscription_settings_for_user_txn(
@@ -449,6 +461,42 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore):
return ThreadSubscription(automatic=automatic) return ThreadSubscription(automatic=automatic)
# max_entries=100 rationale:
# this returns a potentially large datastructure
# (since each entry contains a set which contains a potentially large number of user IDs),
# whereas the default of 10'000 entries for @cached feels more
# suitable for very small cache entries.
#
# Overall, when bearing in mind the usual profile of a small community-server or company-server
# (where cache tuning hasn't been done, so we're in out-of-box configuration), it is very
# unlikely we would benefit from keeping hot the subscribers for as many as 100 threads,
# since it's unlikely that so many threads will be active in a short span of time on a small homeserver.
# It feels that medium servers will probably also not exhaust this limit.
# Larger homeservers are more likely to be carefully tuned, either with a larger global cache factor
# or carefully following the usage patterns & cache metrics.
# Finally, the query is not so intensive that computing it every time is a huge deal, but given people
# often send messages back-to-back in the same thread it seems like it would offer a mild benefit.
@cached(max_entries=100)
async def get_subscribers_to_thread(
self, room_id: str, thread_root_event_id: str
) -> FrozenSet[str]:
"""
Returns:
the set of user_ids for local users who are subscribed to the given thread.
"""
return frozenset(
await self.db_pool.simple_select_onecol(
table="thread_subscriptions",
keyvalues={
"room_id": room_id,
"event_id": thread_root_event_id,
"subscribed": True,
},
retcol="user_id",
desc="get_subscribers_to_thread",
)
)
def get_max_thread_subscriptions_stream_id(self) -> int: def get_max_thread_subscriptions_stream_id(self) -> int:
"""Get the current maximum stream_id for thread subscriptions. """Get the current maximum stream_id for thread subscriptions.

View File

@@ -49,6 +49,7 @@ class FilteredPushRules:
msc3664_enabled: bool, msc3664_enabled: bool,
msc4028_push_encrypted_events: bool, msc4028_push_encrypted_events: bool,
msc4210_enabled: bool, msc4210_enabled: bool,
msc4306_enabled: bool,
): ... ): ...
def rules(self) -> Collection[Tuple[PushRule, bool]]: ... def rules(self) -> Collection[Tuple[PushRule, bool]]: ...
@@ -67,13 +68,19 @@ class PushRuleEvaluator:
room_version_feature_flags: Tuple[str, ...], room_version_feature_flags: Tuple[str, ...],
msc3931_enabled: bool, msc3931_enabled: bool,
msc4210_enabled: bool, msc4210_enabled: bool,
msc4306_enabled: bool,
): ... ): ...
def run( def run(
self, self,
push_rules: FilteredPushRules, push_rules: FilteredPushRules,
user_id: Optional[str], user_id: Optional[str],
display_name: Optional[str], display_name: Optional[str],
msc4306_thread_subscription_state: Optional[bool],
) -> Collection[Union[Mapping, str]]: ... ) -> Collection[Union[Mapping, str]]: ...
def matches( def matches(
self, condition: JsonDict, user_id: Optional[str], display_name: Optional[str] self,
condition: JsonDict,
user_id: Optional[str],
display_name: Optional[str],
msc4306_thread_subscription_state: Optional[bool] = None,
) -> bool: ... ) -> bool: ...

View File

@@ -26,7 +26,7 @@ from parameterized import parameterized
from twisted.internet.testing import MemoryReactor from twisted.internet.testing import MemoryReactor
from synapse.api.constants import EventContentFields, RelationTypes from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
from synapse.api.room_versions import RoomVersions from synapse.api.room_versions import RoomVersions
from synapse.push.bulk_push_rule_evaluator import BulkPushRuleEvaluator from synapse.push.bulk_push_rule_evaluator import BulkPushRuleEvaluator
from synapse.rest import admin from synapse.rest import admin
@@ -206,7 +206,10 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
bulk_evaluator._action_for_event_by_user.assert_not_called() bulk_evaluator._action_for_event_by_user.assert_not_called()
def _create_and_process( def _create_and_process(
self, bulk_evaluator: BulkPushRuleEvaluator, content: Optional[JsonDict] = None self,
bulk_evaluator: BulkPushRuleEvaluator,
content: Optional[JsonDict] = None,
type: str = "test",
) -> bool: ) -> bool:
"""Returns true iff the `mentions` trigger an event push action.""" """Returns true iff the `mentions` trigger an event push action."""
# Create a new message event which should cause a notification. # Create a new message event which should cause a notification.
@@ -214,7 +217,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
self.event_creation_handler.create_event( self.event_creation_handler.create_event(
self.requester, self.requester,
{ {
"type": "test", "type": type,
"room_id": self.room_id, "room_id": self.room_id,
"content": content or {}, "content": content or {},
"sender": f"@bob:{self.hs.hostname}", "sender": f"@bob:{self.hs.hostname}",
@@ -446,3 +449,73 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
}, },
) )
) )
@override_config({"experimental_features": {"msc4306_enabled": True}})
def test_thread_subscriptions(self) -> None:
bulk_evaluator = BulkPushRuleEvaluator(self.hs)
(thread_root_id,) = self.helper.send_messages(self.room_id, 1, tok=self.token)
self.assertFalse(
self._create_and_process(
bulk_evaluator,
{
"msgtype": "m.text",
"body": "test message before subscription",
"m.relates_to": {
"rel_type": RelationTypes.THREAD,
"event_id": thread_root_id,
},
},
type=EventTypes.Message,
)
)
self.get_success(
self.hs.get_datastores().main.subscribe_user_to_thread(
self.alice,
self.room_id,
thread_root_id,
automatic_event_orderings=None,
)
)
self.assertTrue(
self._create_and_process(
bulk_evaluator,
{
"msgtype": "m.text",
"body": "test message after subscription",
"m.relates_to": {
"rel_type": RelationTypes.THREAD,
"event_id": thread_root_id,
},
},
type="m.room.message",
)
)
def test_with_disabled_thread_subscriptions(self) -> None:
"""
Test what happens with threaded events when MSC4306 is disabled.
FUTURE: If MSC4306 becomes enabled-by-default/accepted, this test is to be removed.
"""
bulk_evaluator = BulkPushRuleEvaluator(self.hs)
(thread_root_id,) = self.helper.send_messages(self.room_id, 1, tok=self.token)
# When MSC4306 is not enabled, a threaded message generates a notification
# by default.
self.assertTrue(
self._create_and_process(
bulk_evaluator,
{
"msgtype": "m.text",
"body": "test message before subscription",
"m.relates_to": {
"rel_type": RelationTypes.THREAD,
"event_id": thread_root_id,
},
},
type="m.room.message",
)
)

View File

@@ -150,6 +150,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
*, *,
related_events: Optional[JsonDict] = None, related_events: Optional[JsonDict] = None,
msc4210: bool = False, msc4210: bool = False,
msc4306: bool = False,
) -> PushRuleEvaluator: ) -> PushRuleEvaluator:
event = FrozenEvent( event = FrozenEvent(
{ {
@@ -176,6 +177,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
room_version_feature_flags=event.room_version.msc3931_push_features, room_version_feature_flags=event.room_version.msc3931_push_features,
msc3931_enabled=True, msc3931_enabled=True,
msc4210_enabled=msc4210, msc4210_enabled=msc4210,
msc4306_enabled=msc4306,
) )
def test_display_name(self) -> None: def test_display_name(self) -> None:
@@ -806,6 +808,112 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
) )
) )
def test_thread_subscription_subscribed(self) -> None:
"""
Test MSC4306 thread subscription push rules against an event in a subscribed thread.
"""
evaluator = self._get_evaluator(
{
"msgtype": "m.text",
"body": "Squawk",
"m.relates_to": {
"event_id": "$threadroot",
"rel_type": "m.thread",
},
},
msc4306=True,
)
self.assertTrue(
evaluator.matches(
{
"kind": "io.element.msc4306.thread_subscription",
"subscribed": True,
},
None,
None,
msc4306_thread_subscription_state=True,
)
)
self.assertFalse(
evaluator.matches(
{
"kind": "io.element.msc4306.thread_subscription",
"subscribed": False,
},
None,
None,
msc4306_thread_subscription_state=True,
)
)
def test_thread_subscription_unsubscribed(self) -> None:
"""
Test MSC4306 thread subscription push rules against an event in an unsubscribed thread.
"""
evaluator = self._get_evaluator(
{
"msgtype": "m.text",
"body": "Squawk",
"m.relates_to": {
"event_id": "$threadroot",
"rel_type": "m.thread",
},
},
msc4306=True,
)
self.assertFalse(
evaluator.matches(
{
"kind": "io.element.msc4306.thread_subscription",
"subscribed": True,
},
None,
None,
msc4306_thread_subscription_state=False,
)
)
self.assertTrue(
evaluator.matches(
{
"kind": "io.element.msc4306.thread_subscription",
"subscribed": False,
},
None,
None,
msc4306_thread_subscription_state=False,
)
)
def test_thread_subscription_unthreaded(self) -> None:
"""
Test MSC4306 thread subscription push rules against an unthreaded event.
"""
evaluator = self._get_evaluator(
{"msgtype": "m.text", "body": "Squawk"}, msc4306=True
)
self.assertFalse(
evaluator.matches(
{
"kind": "io.element.msc4306.thread_subscription",
"subscribed": True,
},
None,
None,
msc4306_thread_subscription_state=None,
)
)
self.assertFalse(
evaluator.matches(
{
"kind": "io.element.msc4306.thread_subscription",
"subscribed": False,
},
None,
None,
msc4306_thread_subscription_state=None,
)
)
class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase): class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase):
"""Tests for the bulk push rule evaluator""" """Tests for the bulk push rule evaluator"""

View File

@@ -327,3 +327,42 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
self.assertFalse( self.assertFalse(
func(autosub=EventOrderings(-50, 2), unsubscribed_at=EventOrderings(2, 1)) func(autosub=EventOrderings(-50, 2), unsubscribed_at=EventOrderings(2, 1))
) )
def test_get_subscribers_to_thread(self) -> None:
"""
Test getting all subscribers to a thread at once.
To check cache invalidations are correct, we do multiple
step-by-step rounds of subscription changes and assertions.
"""
other_user_id = "@other_user:test"
subscribers = self.get_success(
self.store.get_subscribers_to_thread(self.room_id, self.thread_root_id)
)
self.assertEqual(subscribers, frozenset())
self._subscribe(
self.thread_root_id, automatic_event_orderings=None, user_id=self.user_id
)
subscribers = self.get_success(
self.store.get_subscribers_to_thread(self.room_id, self.thread_root_id)
)
self.assertEqual(subscribers, frozenset((self.user_id,)))
self._subscribe(
self.thread_root_id, automatic_event_orderings=None, user_id=other_user_id
)
subscribers = self.get_success(
self.store.get_subscribers_to_thread(self.room_id, self.thread_root_id)
)
self.assertEqual(subscribers, frozenset((self.user_id, other_user_id)))
self._unsubscribe(self.thread_root_id, user_id=self.user_id)
subscribers = self.get_success(
self.store.get_subscribers_to_thread(self.room_id, self.thread_root_id)
)
self.assertEqual(subscribers, frozenset((other_user_id,)))