
Compare commits

...

22 Commits

Author SHA1 Message Date
Devon Hudson
c9f212ab44 Revert "Fix LaterGauge metrics to collect from all servers (#18751)"
This reverts commit 076db0ab49.
2025-08-06 15:49:26 -06:00
Devon Hudson
85e3adba86 Revert "Temporarily disable problem test"
This reverts commit 4333eff1d5.
2025-08-06 15:49:03 -06:00
Devon Hudson
d3bdf8b091 Revert "Temporarily disable all tests that call generate_latest"
This reverts commit d8ab5434d5.
2025-08-06 15:48:36 -06:00
Devon Hudson
d8ab5434d5 Temporarily disable all tests that call generate_latest 2025-08-06 15:21:42 -06:00
Devon Hudson
4333eff1d5 Temporarily disable problem test 2025-08-06 14:39:35 -06:00
Devon Hudson
c9f04f3484 Re-enable parallel 2025-08-06 14:39:18 -06:00
Olivier 'reivilibre
a387d6ecf8 better monitor 2025-08-06 18:03:31 +01:00
Olivier 'reivilibre
9e473d9e38 fail slow 2025-08-06 18:02:30 +01:00
Olivier 'reivilibre
d2ea7e32f5 Revert "choose a test"
This reverts commit a256423553.
2025-08-06 18:02:11 +01:00
Olivier 'reivilibre
2db0f1e49b don't wait for lint 2025-08-06 17:58:14 +01:00
Olivier 'reivilibre
a256423553 choose a test 2025-08-06 17:56:54 +01:00
Olivier 'reivilibre
e91aa4fd2f debug diskspace 2025-08-06 17:12:49 +01:00
Olivier 'reivilibre
499c1631de no parallel? 2025-08-06 16:29:03 +01:00
Olivier 'reivilibre
4ecd9aba95 Newsfile
Signed-off-by: Olivier 'reivilibre <oliverw@matrix.org>
2025-08-06 16:07:32 +01:00
Olivier 'reivilibre
4e947d05ab Remove stray @DEBUG 2025-08-06 16:07:07 +01:00
reivilibre
6514381b02 Implement the push rules for experimental MSC4306: Thread Subscriptions. (#18762)
Follows: #18756

Implements: MSC4306

---------

Signed-off-by: Olivier 'reivilibre <oliverw@matrix.org>
Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com>
2025-08-06 15:33:52 +01:00
reivilibre
8306cee06a Update implementation of MSC4306: Thread Subscriptions to include automatic subscription conflict prevention as introduced in later drafts. (#18756)
Follows: #18674

Implements new drafts of MSC4306

---------

Signed-off-by: Olivier 'reivilibre <oliverw@matrix.org>
Co-authored-by: Eric Eastwood <erice@element.io>
2025-08-05 18:22:53 +00:00
Eric Eastwood
076db0ab49 Fix LaterGauge metrics to collect from all servers (#18751)
Fix `LaterGauge` metrics to collect from all servers

Follow-up to https://github.com/element-hq/synapse/pull/18714

Previously, our `LaterGauge` metrics did include the `server_name` label
as expected, but in some cases only the last server was being reported:
any `LaterGauge` that we created multiple times reported only its last
instance.

This PR updates all `LaterGauge` instances to be created once; we then use
`LaterGauge.register_hook(...)` to add the metric callback as before.
This works because we now store a list of callbacks instead of just one.

I noticed this problem thanks to some [tests in the Synapse Pro for
Small Hosts](https://github.com/element-hq/synapse-small-hosts/pull/173)
repo that sanity-check all metrics to ensure that each metric includes
data from multiple servers.


### Testing strategy

1. This is only noticeable when you run multiple Synapse instances in
the same process.
 1. TODO

(see test that was added)

### Dev notes

Previous non-global `LaterGauge`:

```
synapse_federation_send_queue_xxx
synapse_federation_transaction_queue_pending_destinations
synapse_federation_transaction_queue_pending_pdus
synapse_federation_transaction_queue_pending_edus
synapse_handlers_presence_user_to_current_state_size
synapse_handlers_presence_wheel_timer_size
synapse_notifier_listeners
synapse_notifier_rooms
synapse_notifier_users
synapse_replication_tcp_resource_total_connections
synapse_replication_tcp_command_queue
synapse_background_update_status
synapse_federation_known_servers
synapse_scheduler_running_tasks
```
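The fix described above (create each gauge once, then attach one callback per server) can be sketched as a toy `LaterGauge` in Python. The class shape and `register_hook` signature here are simplified assumptions for illustration, not Synapse's actual implementation:

```python
from typing import Callable, Dict, List, Tuple


class LaterGauge:
    """Toy sketch: a gauge that collects from registered callbacks at scrape time."""

    def __init__(self, name: str) -> None:
        self.name = name
        # Storing a *list* of hooks is the key fix: with a single callback
        # slot, re-creating the gauge overwrote earlier servers' hooks, so
        # only the last server was reported.
        self._hooks: List[Tuple[str, Callable[[], float]]] = []

    def register_hook(self, server_name: str, hook: Callable[[], float]) -> None:
        self._hooks.append((server_name, hook))

    def collect(self) -> Dict[str, float]:
        # One sample per registered server, keyed by the server_name label.
        return {server_name: hook() for server_name, hook in self._hooks}


gauge = LaterGauge("synapse_notifier_users")
gauge.register_hook("hs1.example.org", lambda: 3.0)
gauge.register_hook("hs2.example.org", lambda: 5.0)
```

With this shape, registering a hook for each homeserver appends rather than overwrites, so a single scrape reports one sample per server.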



### Pull Request Checklist

<!-- Please read
https://element-hq.github.io/synapse/latest/development/contributing_guide.html
before submitting your pull request -->

* [x] Pull request is based on the develop branch
* [x] Pull request includes a [changelog
file](https://element-hq.github.io/synapse/latest/development/contributing_guide.html#changelog).
The entry should:
- Be a short description of your change which makes sense to users.
"Fixed a bug that prevented receiving messages from other servers."
instead of "Moved X method from `EventStore` to `EventWorkerStore`.".
  - Use markdown where necessary, mostly for `code blocks`.
  - End with either a period (.) or an exclamation mark (!).
  - Start with a capital letter.
- Feel free to credit yourself, by adding a sentence "Contributed by
@github_username." or "Contributed by [Your Name]." to the end of the
entry.
* [x] [Code
style](https://element-hq.github.io/synapse/latest/code_style.html) is
correct (run the
[linters](https://element-hq.github.io/synapse/latest/development/contributing_guide.html#run-the-linters))
2025-08-05 15:28:55 +00:00
Andrew Morgan
c7762cd55e Prevent "Move labelled issues to correct projects" GitHub Actions workflow from failing when an issue is already on the project board (#18755) 2025-08-05 12:03:25 +01:00
Andrew Morgan
357b749bf3 Bump minimum supported rust version to 1.82.0 (#18757) 2025-08-05 12:02:57 +01:00
Erik Johnston
20615115fb Make .sleep(..) return a coroutine (#18772)
This helps ensure that mypy can catch places where we don't await on it,
like in #18763.

---------

Co-authored-by: Eric Eastwood <erice@element.io>
2025-08-05 09:30:52 +01:00
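The motivation for #18772 can be illustrated with a minimal, hypothetical `Clock` (not Synapse's real class): mypy's default `unused-coroutine` error code flags a bare call to an `async def`, whereas a plain `def` returning an `Awaitable` is only caught by the opt-in `unused-awaitable` check.

```python
import asyncio
from typing import Awaitable


class OldClock:
    # Before: a plain function returning an Awaitable. Calling
    # `OldClock().sleep(10)` without `await` silently discards the
    # Awaitable; mypy only flags this with the opt-in
    # `unused-awaitable` error code.
    def sleep(self, duration_ms: float) -> "Awaitable[None]":
        return asyncio.sleep(duration_ms / 1000)


class NewClock:
    # After: a true coroutine function. A bare `NewClock().sleep(10)`
    # is flagged by mypy's default `unused-coroutine` error code.
    async def sleep(self, duration_ms: float) -> None:
        await asyncio.sleep(duration_ms / 1000)
```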
Eric Eastwood
ddbcd859aa Improve order of validation and ratelimiting in room creation (#18723)
Spawning from looking at this stuff while reviewing
https://github.com/element-hq/synapse/pull/18721
2025-08-04 11:08:02 -05:00
40 changed files with 1082 additions and 198 deletions


@@ -373,7 +373,7 @@ jobs:
calculate-test-jobs:
if: ${{ !cancelled() && !failure() }} # Allow previous steps to be skipped, but not fail
needs: linting-done
# needs: linting-done
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
@@ -393,6 +393,7 @@ jobs:
- changes
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
job: ${{ fromJson(needs.calculate-test-jobs.outputs.trial_test_matrix) }}
@@ -426,7 +427,24 @@ jobs:
if: ${{ matrix.job.postgres-version }}
timeout-minutes: 2
run: until pg_isready -h localhost; do sleep 1; done
- run: poetry run trial --jobs=6 tests
- run: |
(
while true; do
echo "......."
date
df -h | grep root
free -m
sleep 10
done
) &
MONITOR_PID=$!
poetry run trial --jobs=6 tests
STATUS=$?
kill $MONITOR_PID
exit $STATUS
env:
SYNAPSE_POSTGRES: ${{ matrix.job.database == 'postgres' || '' }}
SYNAPSE_POSTGRES_HOST: /var/run/postgresql


@@ -16,6 +16,10 @@ jobs:
with:
project-url: "https://github.com/orgs/matrix-org/projects/67"
github-token: ${{ secrets.ELEMENT_BOT_TOKEN }}
# This action will error if the issue already exists on the project. Which is
# common as `X-Needs-Info` will often be added to issues that are already in
# the triage queue. Prevent the whole job from failing in this case.
continue-on-error: true
- name: Set status
env:
GITHUB_TOKEN: ${{ secrets.ELEMENT_BOT_TOKEN }}

changelog.d/18723.misc

@@ -0,0 +1 @@
Improve order of validation and ratelimiting in room creation.

changelog.d/18755.misc

@@ -0,0 +1 @@
Prevent "Move labelled issues to correct projects" GitHub Actions workflow from failing when an issue is already on the project board.

changelog.d/18756.misc

@@ -0,0 +1 @@
Update implementation of [MSC4306: Thread Subscriptions](https://github.com/matrix-org/matrix-doc/issues/4306) to include automatic subscription conflict prevention as introduced in later drafts.

changelog.d/18757.misc

@@ -0,0 +1 @@
Bump minimum supported Rust version (MSRV) to 1.82.0. Missed in [#18553](https://github.com/element-hq/synapse/pull/18553) (released in Synapse 1.134.0).


@@ -0,0 +1 @@
Implement the push rules for experimental [MSC4306: Thread Subscriptions](https://github.com/matrix-org/matrix-doc/issues/4306).

changelog.d/18772.misc

@@ -0,0 +1 @@
Make `Clock.sleep(..)` return a coroutine, so that mypy can catch places where we don't await on it.

changelog.d/18787.misc

@@ -0,0 +1 @@
CI debugging.


@@ -7,7 +7,7 @@ name = "synapse"
version = "0.1.0"
edition = "2021"
rust-version = "1.81.0"
rust-version = "1.82.0"
[lib]
name = "synapse"


@@ -61,6 +61,7 @@ fn bench_match_exact(b: &mut Bencher) {
vec![],
false,
false,
false,
)
.unwrap();
@@ -71,10 +72,10 @@ fn bench_match_exact(b: &mut Bencher) {
},
));
let matched = eval.match_condition(&condition, None, None).unwrap();
let matched = eval.match_condition(&condition, None, None, None).unwrap();
assert!(matched, "Didn't match");
b.iter(|| eval.match_condition(&condition, None, None).unwrap());
b.iter(|| eval.match_condition(&condition, None, None, None).unwrap());
}
#[bench]
@@ -107,6 +108,7 @@ fn bench_match_word(b: &mut Bencher) {
vec![],
false,
false,
false,
)
.unwrap();
@@ -117,10 +119,10 @@ fn bench_match_word(b: &mut Bencher) {
},
));
let matched = eval.match_condition(&condition, None, None).unwrap();
let matched = eval.match_condition(&condition, None, None, None).unwrap();
assert!(matched, "Didn't match");
b.iter(|| eval.match_condition(&condition, None, None).unwrap());
b.iter(|| eval.match_condition(&condition, None, None, None).unwrap());
}
#[bench]
@@ -153,6 +155,7 @@ fn bench_match_word_miss(b: &mut Bencher) {
vec![],
false,
false,
false,
)
.unwrap();
@@ -163,10 +166,10 @@ fn bench_match_word_miss(b: &mut Bencher) {
},
));
let matched = eval.match_condition(&condition, None, None).unwrap();
let matched = eval.match_condition(&condition, None, None, None).unwrap();
assert!(!matched, "Didn't match");
b.iter(|| eval.match_condition(&condition, None, None).unwrap());
b.iter(|| eval.match_condition(&condition, None, None, None).unwrap());
}
#[bench]
@@ -199,6 +202,7 @@ fn bench_eval_message(b: &mut Bencher) {
vec![],
false,
false,
false,
)
.unwrap();
@@ -210,7 +214,8 @@ fn bench_eval_message(b: &mut Bencher) {
false,
false,
false,
false,
);
b.iter(|| eval.run(&rules, Some("bob"), Some("person")));
b.iter(|| eval.run(&rules, Some("bob"), Some("person"), None));
}


@@ -290,6 +290,26 @@ pub const BASE_APPEND_CONTENT_RULES: &[PushRule] = &[PushRule {
}];
pub const BASE_APPEND_UNDERRIDE_RULES: &[PushRule] = &[
PushRule {
rule_id: Cow::Borrowed("global/content/.io.element.msc4306.rule.unsubscribed_thread"),
priority_class: 1,
conditions: Cow::Borrowed(&[Condition::Known(
KnownCondition::Msc4306ThreadSubscription { subscribed: false },
)]),
actions: Cow::Borrowed(&[]),
default: true,
default_enabled: true,
},
PushRule {
rule_id: Cow::Borrowed("global/content/.io.element.msc4306.rule.subscribed_thread"),
priority_class: 1,
conditions: Cow::Borrowed(&[Condition::Known(
KnownCondition::Msc4306ThreadSubscription { subscribed: true },
)]),
actions: Cow::Borrowed(&[Action::Notify, SOUND_ACTION]),
default: true,
default_enabled: true,
},
PushRule {
rule_id: Cow::Borrowed("global/underride/.m.rule.call"),
priority_class: 1,


@@ -106,8 +106,11 @@ pub struct PushRuleEvaluator {
/// flag as MSC1767 (extensible events core).
msc3931_enabled: bool,
// If MSC4210 (remove legacy mentions) is enabled.
/// If MSC4210 (remove legacy mentions) is enabled.
msc4210_enabled: bool,
/// If MSC4306 (thread subscriptions) is enabled.
msc4306_enabled: bool,
}
#[pymethods]
@@ -126,6 +129,7 @@ impl PushRuleEvaluator {
room_version_feature_flags,
msc3931_enabled,
msc4210_enabled,
msc4306_enabled,
))]
pub fn py_new(
flattened_keys: BTreeMap<String, JsonValue>,
@@ -138,6 +142,7 @@ impl PushRuleEvaluator {
room_version_feature_flags: Vec<String>,
msc3931_enabled: bool,
msc4210_enabled: bool,
msc4306_enabled: bool,
) -> Result<Self, Error> {
let body = match flattened_keys.get("content.body") {
Some(JsonValue::Value(SimpleJsonValue::Str(s))) => s.clone().into_owned(),
@@ -156,6 +161,7 @@ impl PushRuleEvaluator {
room_version_feature_flags,
msc3931_enabled,
msc4210_enabled,
msc4306_enabled,
})
}
@@ -167,12 +173,19 @@ impl PushRuleEvaluator {
///
/// Returns the set of actions, if any, that match (filtering out any
/// `dont_notify` and `coalesce` actions).
#[pyo3(signature = (push_rules, user_id=None, display_name=None))]
///
/// msc4306_thread_subscription_state: (Only populated if MSC4306 is enabled)
/// The thread subscription state corresponding to the thread containing this event.
/// - `None` if the event is not in a thread, or if MSC4306 is disabled.
/// - `Some(true)` if the event is in a thread and the user has a subscription for that thread
/// - `Some(false)` if the event is in a thread and the user does NOT have a subscription for that thread
#[pyo3(signature = (push_rules, user_id=None, display_name=None, msc4306_thread_subscription_state=None))]
pub fn run(
&self,
push_rules: &FilteredPushRules,
user_id: Option<&str>,
display_name: Option<&str>,
msc4306_thread_subscription_state: Option<bool>,
) -> Vec<Action> {
'outer: for (push_rule, enabled) in push_rules.iter() {
if !enabled {
@@ -204,7 +217,12 @@ impl PushRuleEvaluator {
Condition::Known(KnownCondition::RoomVersionSupports { feature: _ }),
);
match self.match_condition(condition, user_id, display_name) {
match self.match_condition(
condition,
user_id,
display_name,
msc4306_thread_subscription_state,
) {
Ok(true) => {}
Ok(false) => continue 'outer,
Err(err) => {
@@ -237,14 +255,20 @@ impl PushRuleEvaluator {
}
/// Check if the given condition matches.
#[pyo3(signature = (condition, user_id=None, display_name=None))]
#[pyo3(signature = (condition, user_id=None, display_name=None, msc4306_thread_subscription_state=None))]
fn matches(
&self,
condition: Condition,
user_id: Option<&str>,
display_name: Option<&str>,
msc4306_thread_subscription_state: Option<bool>,
) -> bool {
match self.match_condition(&condition, user_id, display_name) {
match self.match_condition(
&condition,
user_id,
display_name,
msc4306_thread_subscription_state,
) {
Ok(true) => true,
Ok(false) => false,
Err(err) => {
@@ -262,6 +286,7 @@ impl PushRuleEvaluator {
condition: &Condition,
user_id: Option<&str>,
display_name: Option<&str>,
msc4306_thread_subscription_state: Option<bool>,
) -> Result<bool, Error> {
let known_condition = match condition {
Condition::Known(known) => known,
@@ -393,6 +418,13 @@ impl PushRuleEvaluator {
&& self.room_version_feature_flags.contains(&flag)
}
}
KnownCondition::Msc4306ThreadSubscription { subscribed } => {
if !self.msc4306_enabled {
false
} else {
msc4306_thread_subscription_state == Some(*subscribed)
}
}
};
Ok(result)
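The new `Msc4306ThreadSubscription` match arm can be mirrored in Python to show the matching behaviour. This standalone function is an illustrative re-implementation, not code from the PR:

```python
from typing import Optional


def msc4306_condition_matches(
    msc4306_enabled: bool,
    rule_subscribed: bool,
    thread_subscription_state: Optional[bool],
) -> bool:
    # Mirrors the Rust arm: the condition only matches when the feature
    # is enabled AND `msc4306_thread_subscription_state == Some(*subscribed)`.
    # A `None` state (event not in a thread, or state unknown) never
    # matches, just as `None != Some(false)` in Rust.
    if not msc4306_enabled:
        return False
    return thread_subscription_state == rule_subscribed
```

This is why the two underride rules are safe defaults: with MSC4306 disabled, or for events outside a thread, neither the subscribed nor the unsubscribed rule fires.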
@@ -536,10 +568,11 @@ fn push_rule_evaluator() {
vec![],
true,
false,
false,
)
.unwrap();
let result = evaluator.run(&FilteredPushRules::default(), None, Some("bob"));
let result = evaluator.run(&FilteredPushRules::default(), None, Some("bob"), None);
assert_eq!(result.len(), 3);
}
@@ -566,6 +599,7 @@ fn test_requires_room_version_supports_condition() {
flags,
true,
false,
false,
)
.unwrap();
@@ -575,6 +609,7 @@ fn test_requires_room_version_supports_condition() {
&FilteredPushRules::default(),
Some("@bob:example.org"),
None,
None,
);
assert_eq!(result.len(), 3);
@@ -593,7 +628,17 @@ fn test_requires_room_version_supports_condition() {
};
let rules = PushRules::new(vec![custom_rule]);
result = evaluator.run(
&FilteredPushRules::py_new(rules, BTreeMap::new(), true, false, true, false, false),
&FilteredPushRules::py_new(
rules,
BTreeMap::new(),
true,
false,
true,
false,
false,
false,
),
None,
None,
None,
);


@@ -369,6 +369,10 @@ pub enum KnownCondition {
RoomVersionSupports {
feature: Cow<'static, str>,
},
#[serde(rename = "io.element.msc4306.thread_subscription")]
Msc4306ThreadSubscription {
subscribed: bool,
},
}
impl<'source> IntoPyObject<'source> for Condition {
@@ -547,11 +551,13 @@ pub struct FilteredPushRules {
msc3664_enabled: bool,
msc4028_push_encrypted_events: bool,
msc4210_enabled: bool,
msc4306_enabled: bool,
}
#[pymethods]
impl FilteredPushRules {
#[new]
#[allow(clippy::too_many_arguments)]
pub fn py_new(
push_rules: PushRules,
enabled_map: BTreeMap<String, bool>,
@@ -560,6 +566,7 @@ impl FilteredPushRules {
msc3664_enabled: bool,
msc4028_push_encrypted_events: bool,
msc4210_enabled: bool,
msc4306_enabled: bool,
) -> Self {
Self {
push_rules,
@@ -569,6 +576,7 @@ impl FilteredPushRules {
msc3664_enabled,
msc4028_push_encrypted_events,
msc4210_enabled,
msc4306_enabled,
}
}
@@ -619,6 +627,10 @@ impl FilteredPushRules {
return false;
}
if !self.msc4306_enabled && rule.rule_id.contains("/.io.element.msc4306.rule.") {
return false;
}
true
})
.map(|r| {


@@ -140,6 +140,12 @@ class Codes(str, Enum):
# Part of MSC4155
INVITE_BLOCKED = "ORG.MATRIX.MSC4155.M_INVITE_BLOCKED"
# Part of MSC4306: Thread Subscriptions
MSC4306_CONFLICTING_UNSUBSCRIPTION = (
"IO.ELEMENT.MSC4306.M_CONFLICTING_UNSUBSCRIPTION"
)
MSC4306_NOT_IN_THREAD = "IO.ELEMENT.MSC4306.M_NOT_IN_THREAD"
class CodeMessageException(RuntimeError):
"""An exception with integer code, a message string attributes and optional headers.


@@ -779,6 +779,25 @@ class RoomCreationHandler:
await self.auth_blocking.check_auth_blocking(requester=requester)
if ratelimit:
# Limit the rate of room creations,
# using both the limiter specific to room creations as well
# as the general request ratelimiter.
#
# Note that we don't rate limit the individual
# events in the room — room creation isn't atomic and
# historically it was very janky if half the events in the
# initial state don't make it because of rate limiting.
# First check the room creation ratelimiter without updating it
# (this is so we don't consume a token if the other ratelimiter doesn't
# allow us to proceed)
await self.creation_ratelimiter.ratelimit(requester, update=False)
# then apply the ratelimits
await self.common_request_ratelimiter.ratelimit(requester)
await self.creation_ratelimiter.ratelimit(requester)
if (
self._server_notices_mxid is not None
and user_id == self._server_notices_mxid
@@ -810,37 +829,6 @@ class RoomCreationHandler:
Codes.MISSING_PARAM,
)
if not is_requester_admin:
spam_check = await self._spam_checker_module_callbacks.user_may_create_room(
user_id, config
)
if spam_check != self._spam_checker_module_callbacks.NOT_SPAM:
raise SynapseError(
403,
"You are not permitted to create rooms",
errcode=spam_check[0],
additional_fields=spam_check[1],
)
if ratelimit:
# Limit the rate of room creations,
# using both the limiter specific to room creations as well
# as the general request ratelimiter.
#
# Note that we don't rate limit the individual
# events in the room — room creation isn't atomic and
# historically it was very janky if half the events in the
# initial state don't make it because of rate limiting.
# First check the room creation ratelimiter without updating it
# (this is so we don't consume a token if the other ratelimiter doesn't
# allow us to proceed)
await self.creation_ratelimiter.ratelimit(requester, update=False)
# then apply the ratelimits
await self.common_request_ratelimiter.ratelimit(requester)
await self.creation_ratelimiter.ratelimit(requester)
room_version_id = config.get(
"room_version", self.config.server.default_room_version.identifier
)
@@ -932,6 +920,19 @@ class RoomCreationHandler:
self._validate_room_config(config, visibility)
# Run the spam checker after other validation
if not is_requester_admin:
spam_check = await self._spam_checker_module_callbacks.user_may_create_room(
user_id, config
)
if spam_check != self._spam_checker_module_callbacks.NOT_SPAM:
raise SynapseError(
403,
"You are not permitted to create rooms",
errcode=spam_check[0],
additional_fields=spam_check[1],
)
room_id = await self._generate_and_create_room_id(
creator_id=user_id,
is_public=is_public,
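The reordered ratelimiting above (check the creation limiter without consuming a token, then apply both limiters) can be sketched with a toy token bucket. The names here are illustrative and not Synapse's real `Ratelimiter` API:

```python
class Bucket:
    # Toy token bucket: ratelimit(update=False) only checks whether a
    # token is available; ratelimit() both checks and consumes one.
    def __init__(self, tokens: int) -> None:
        self.tokens = tokens

    def ratelimit(self, update: bool = True) -> None:
        if self.tokens <= 0:
            raise RuntimeError("LimitExceededError")
        if update:
            self.tokens -= 1


def create_room(creation: Bucket, common: Bucket) -> None:
    # First check the room-creation limiter WITHOUT consuming a token,
    # so we don't burn a creation token if the general request limiter
    # is going to reject us anyway.
    creation.ratelimit(update=False)
    # Then apply both limiters for real.
    common.ratelimit()
    creation.ratelimit()
```

If either limiter rejects the request, neither bucket loses a token that the caller would otherwise have to refund.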


@@ -1,9 +1,15 @@
import logging
from http import HTTPStatus
from typing import TYPE_CHECKING, Optional
from synapse.api.errors import AuthError, NotFoundError
from synapse.storage.databases.main.thread_subscriptions import ThreadSubscription
from synapse.types import UserID
from synapse.api.constants import RelationTypes
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
from synapse.events import relation_from_event
from synapse.storage.databases.main.thread_subscriptions import (
AutomaticSubscriptionConflicted,
ThreadSubscription,
)
from synapse.types import EventOrderings, UserID
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -55,42 +61,79 @@ class ThreadSubscriptionsHandler:
room_id: str,
thread_root_event_id: str,
*,
automatic: bool,
automatic_event_id: Optional[str],
) -> Optional[int]:
"""Sets or updates a user's subscription settings for a specific thread root.
Args:
requester_user_id: The ID of the user whose settings are being updated.
thread_root_event_id: The event ID of the thread root.
automatic: whether the user was subscribed by an automatic decision by
their client.
automatic_event_id: if the user was subscribed by an automatic decision by
their client, the event ID that caused this.
Returns:
The stream ID for this update, if the update isn't no-opped.
Raises:
NotFoundError if the user cannot access the thread root event, or it isn't
known to this homeserver.
known to this homeserver. Ditto for the automatic cause event if supplied.
SynapseError(400, M_NOT_IN_THREAD): if the client supplied an automatic cause event
that is not an event in the given thread.
SynapseError(409, M_CONFLICTING_UNSUBSCRIPTION): if the client requested an automatic subscription
but it was skipped because an unsubscription is considered later than the cause event.
"""
# First check that the user can access the thread root event
# and that it exists
try:
event = await self.event_handler.get_event(
thread_root_event = await self.event_handler.get_event(
user_id, room_id, thread_root_event_id
)
if event is None:
if thread_root_event is None:
raise NotFoundError("No such thread root")
except AuthError:
logger.info("rejecting thread subscriptions change (thread not accessible)")
raise NotFoundError("No such thread root")
return await self.store.subscribe_user_to_thread(
if automatic_event_id:
autosub_cause_event = await self.event_handler.get_event(
user_id, room_id, automatic_event_id
)
if autosub_cause_event is None:
raise NotFoundError("Automatic subscription event not found")
relation = relation_from_event(autosub_cause_event)
if (
relation is None
or relation.rel_type != RelationTypes.THREAD
or relation.parent_id != thread_root_event_id
):
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Automatic subscription must use an event in the thread",
errcode=Codes.MSC4306_NOT_IN_THREAD,
)
automatic_event_orderings = EventOrderings.from_event(autosub_cause_event)
else:
automatic_event_orderings = None
outcome = await self.store.subscribe_user_to_thread(
user_id.to_string(),
event.room_id,
room_id,
thread_root_event_id,
automatic=automatic,
automatic_event_orderings=automatic_event_orderings,
)
if isinstance(outcome, AutomaticSubscriptionConflicted):
raise SynapseError(
HTTPStatus.CONFLICT,
"Automatic subscription obsoleted by an unsubscription request.",
errcode=Codes.MSC4306_CONFLICTING_UNSUBSCRIPTION,
)
return outcome
async def unsubscribe_user_from_thread(
self, user_id: UserID, room_id: str, thread_root_event_id: str
) -> Optional[int]:


@@ -25,6 +25,7 @@ from typing import (
Any,
Collection,
Dict,
FrozenSet,
List,
Mapping,
Optional,
@@ -477,8 +478,18 @@ class BulkPushRuleEvaluator:
event.room_version.msc3931_push_features,
self.hs.config.experimental.msc1767_enabled, # MSC3931 flag
self.hs.config.experimental.msc4210_enabled,
self.hs.config.experimental.msc4306_enabled,
)
msc4306_thread_subscribers: Optional[FrozenSet[str]] = None
if self.hs.config.experimental.msc4306_enabled and thread_id != MAIN_TIMELINE:
# pull out, in batch, all local subscribers to this thread
# (in the common case, they will all be getting processed for push
# rules right now)
msc4306_thread_subscribers = await self.store.get_subscribers_to_thread(
event.room_id, thread_id
)
for uid, rules in rules_by_user.items():
if event.sender == uid:
continue
@@ -503,7 +514,13 @@ class BulkPushRuleEvaluator:
# current user, it'll be added to the dict later.
actions_by_user[uid] = []
actions = evaluator.run(rules, uid, display_name)
msc4306_thread_subscription_state: Optional[bool] = None
if msc4306_thread_subscribers is not None:
msc4306_thread_subscription_state = uid in msc4306_thread_subscribers
actions = evaluator.run(
rules, uid, display_name, msc4306_thread_subscription_state
)
if "notify" in actions:
# Push rules say we should notify the user of this event
actions_by_user[uid] = actions


@@ -739,7 +739,7 @@ class ThreadSubscriptionsStream(_StreamFromIdGen):
NAME = "thread_subscriptions"
ROW_TYPE = ThreadSubscriptionsStreamRow
def __init__(self, hs: Any):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
super().__init__(
hs.get_instance_name(),
@@ -751,7 +751,7 @@ class ThreadSubscriptionsStream(_StreamFromIdGen):
self, instance_name: str, from_token: int, to_token: int, limit: int
) -> StreamUpdateResult:
updates = await self.store.get_updated_thread_subscriptions(
from_token, to_token, limit
from_id=from_token, to_id=to_token, limit=limit
)
rows = [
(


@@ -1,7 +1,6 @@
from http import HTTPStatus
from typing import TYPE_CHECKING, Tuple
from typing import TYPE_CHECKING, Optional, Tuple
from synapse._pydantic_compat import StrictBool
from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import (
@@ -12,6 +11,7 @@ from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
from synapse.types import JsonDict, RoomID
from synapse.types.rest import RequestBodyModel
from synapse.util.pydantic_models import AnyEventId
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -32,7 +32,12 @@ class ThreadSubscriptionsRestServlet(RestServlet):
self.handler = hs.get_thread_subscriptions_handler()
class PutBody(RequestBodyModel):
automatic: StrictBool
automatic: Optional[AnyEventId]
"""
If supplied, the event ID of an event giving rise to this automatic subscription.
If omitted, this subscription is a manual subscription.
"""
async def on_GET(
self, request: SynapseRequest, room_id: str, thread_root_id: str
@@ -63,15 +68,15 @@ class ThreadSubscriptionsRestServlet(RestServlet):
raise SynapseError(
HTTPStatus.BAD_REQUEST, "Invalid event ID", errcode=Codes.INVALID_PARAM
)
requester = await self.auth.get_user_by_req(request)
body = parse_and_validate_json_object_from_request(request, self.PutBody)
requester = await self.auth.get_user_by_req(request)
await self.handler.subscribe_user_to_thread(
requester.user,
room_id,
thread_root_id,
automatic=body.automatic,
automatic_event_id=body.automatic,
)
return HTTPStatus.OK, {}


@@ -52,7 +52,7 @@ class Clock(Protocol):
# This is usually synapse.util.Clock, but it's replaced with a FakeClock in tests.
# We only ever sleep(0) though, so that other async functions can make forward
# progress without waiting for stateres to complete.
def sleep(self, duration_ms: float) -> Awaitable[None]: ...
async def sleep(self, duration_ms: float) -> None: ...
class StateResolutionStore(Protocol):


@@ -110,6 +110,7 @@ def _load_rules(
msc3381_polls_enabled=experimental_config.msc3381_polls_enabled,
msc4028_push_encrypted_events=experimental_config.msc4028_push_encrypted_events,
msc4210_enabled=experimental_config.msc4210_enabled,
msc4306_enabled=experimental_config.msc4306_enabled,
)
return filtered_rules


@@ -14,7 +14,7 @@ import logging
from typing import (
TYPE_CHECKING,
Any,
Dict,
FrozenSet,
Iterable,
List,
Optional,
@@ -33,6 +33,7 @@ from synapse.storage.database import (
)
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.types import EventOrderings
from synapse.util.caches.descriptors import cached
if TYPE_CHECKING:
@@ -50,6 +51,14 @@ class ThreadSubscription:
"""
class AutomaticSubscriptionConflicted:
"""
Marker return value to signal that an automatic subscription was skipped,
because it conflicted with an unsubscription that we consider to have
been made later than the event causing the automatic subscription.
"""
class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore):
def __init__(
self,
@@ -91,6 +100,7 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore):
self.get_subscription_for_thread.invalidate(
(row.user_id, row.room_id, row.event_id)
)
self.get_subscribers_to_thread.invalidate((row.room_id, row.event_id))
super().process_replication_rows(stream_name, instance_name, token, rows)
@@ -101,75 +111,196 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore):
self._thread_subscriptions_id_gen.advance(instance_name, token)
super().process_replication_position(stream_name, instance_name, token)
@staticmethod
def _should_skip_autosubscription_after_unsubscription(
*,
autosub: EventOrderings,
unsubscribed_at: EventOrderings,
) -> bool:
"""
Returns whether an automatic subscription occurring *after* an unsubscription
should be skipped, because the unsubscription already 'acknowledges' the event
causing the automatic subscription (the cause event).
To determine *after*, we use `stream_ordering` unless the event is backfilled
(negative `stream_ordering`), in which case we fall back to topological ordering.
Args:
autosub: the stream_ordering and topological_ordering of the cause event
unsubscribed_at:
the maximum stream ordering and the maximum topological ordering at the time of unsubscription
Returns:
True if the automatic subscription should be skipped
"""
# For normal rooms, these two orderings should be positive, because
# they don't refer to a specific event but rather the maximum at the
# time of unsubscription.
#
# However, for rooms that have never been joined and that are being peeked at,
# we might not have a single non-backfilled event and therefore the stream
# ordering might be negative, so we don't assert this case.
assert unsubscribed_at.topological > 0
unsubscribed_at_backfilled = unsubscribed_at.stream < 0
if (
not unsubscribed_at_backfilled
and unsubscribed_at.stream >= autosub.stream > 0
):
# non-backfilled events: the unsubscription is later according to
# the stream
return True
if autosub.stream < 0:
# the auto-subscription cause event was backfilled, so fall back to
# topological ordering
if unsubscribed_at.topological >= autosub.topological:
return True
return False
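The ordering rule above can be exercised in isolation. This sketch uses a stand-in `EventOrderings` dataclass (the real type lives in `synapse.types`) and re-implements the same decision:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class EventOrderings:
    # Stand-in for synapse.types.EventOrderings: backfilled events
    # have a negative stream ordering.
    stream: int
    topological: int


def should_skip(autosub: EventOrderings, unsubscribed_at: EventOrderings) -> bool:
    # Mirror of _should_skip_autosubscription_after_unsubscription:
    # prefer stream ordering; fall back to topological ordering when
    # the auto-subscription cause event was backfilled.
    unsubscribed_at_backfilled = unsubscribed_at.stream < 0
    if not unsubscribed_at_backfilled and unsubscribed_at.stream >= autosub.stream > 0:
        # Non-backfilled events: the unsubscription is later on the stream.
        return True
    if autosub.stream < 0 and unsubscribed_at.topological >= autosub.topological:
        # Backfilled cause event: compare topological orderings instead.
        return True
    return False
```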
async def subscribe_user_to_thread(
self, user_id: str, room_id: str, thread_root_event_id: str, *, automatic: bool
) -> Optional[int]:
self,
user_id: str,
room_id: str,
thread_root_event_id: str,
*,
automatic_event_orderings: Optional[EventOrderings],
) -> Optional[Union[int, AutomaticSubscriptionConflicted]]:
"""Updates a user's subscription settings for a specific thread root.
If no change would be made to the subscription, does not produce any database change.
Case-by-case:
- if we already have an automatic subscription:
- new automatic subscriptions will be no-ops (no database write),
- new manual subscriptions will overwrite the automatic subscription
- if we already have a manual subscription:
we don't update (no database write) in either case, because:
- the existing manual subscription wins over a new automatic subscription request
- there would be no need to write a manual subscription because we already have one
Args:
user_id: The ID of the user whose settings are being updated.
room_id: The ID of the room the thread root belongs to.
thread_root_event_id: The event ID of the thread root.
automatic: Whether the subscription was performed automatically by the user's client.
Only `False` will overwrite an existing value of automatic for a subscription row.
automatic_event_orderings:
Value depends on whether the subscription was performed automatically by the user's client.
For manual subscriptions: None.
For automatic subscriptions: the orderings of the event.
Returns:
The stream ID for this update, if the update isn't no-opped.
If a subscription is made: (int) the stream ID for this update.
If a subscription already exists and did not need to be updated: None
If an automatic subscription conflicted with an unsubscription: AutomaticSubscriptionConflicted
"""
assert self._can_write_to_thread_subscriptions
def _subscribe_user_to_thread_txn(txn: LoggingTransaction) -> Optional[int]:
already_automatic = self.db_pool.simple_select_one_onecol_txn(
txn,
table="thread_subscriptions",
keyvalues={
"user_id": user_id,
"event_id": thread_root_event_id,
"room_id": room_id,
"subscribed": True,
},
retcol="automatic",
allow_none=True,
)
if already_automatic is None:
already_subscribed = False
already_automatic = True
else:
already_subscribed = True
# convert int (SQLite bool) to Python bool
already_automatic = bool(already_automatic)
if already_subscribed and already_automatic == automatic:
# there is nothing we need to do here
return None
stream_id = self._thread_subscriptions_id_gen.get_next_txn(txn)
values: Dict[str, Optional[Union[bool, int, str]]] = {
"subscribed": True,
"stream_id": stream_id,
"instance_name": self._instance_name,
"automatic": already_automatic and automatic,
}
self.db_pool.simple_upsert_txn(
txn,
table="thread_subscriptions",
keyvalues={
"user_id": user_id,
"event_id": thread_root_event_id,
"room_id": room_id,
},
values=values,
)
def _invalidate_subscription_caches(txn: LoggingTransaction) -> None:
txn.call_after(
self.get_subscription_for_thread.invalidate,
(user_id, room_id, thread_root_event_id),
)
txn.call_after(
self.get_subscribers_to_thread.invalidate,
(room_id, thread_root_event_id),
)
def _subscribe_user_to_thread_txn(
txn: LoggingTransaction,
) -> Optional[Union[int, AutomaticSubscriptionConflicted]]:
requested_automatic = automatic_event_orderings is not None
row = self.db_pool.simple_select_one_txn(
txn,
table="thread_subscriptions",
keyvalues={
"user_id": user_id,
"event_id": thread_root_event_id,
"room_id": room_id,
},
retcols=(
"subscribed",
"automatic",
"unsubscribed_at_stream_ordering",
"unsubscribed_at_topological_ordering",
),
allow_none=True,
)
if row is None:
# We have never subscribed before, simply insert the row and finish
stream_id = self._thread_subscriptions_id_gen.get_next_txn(txn)
self.db_pool.simple_insert_txn(
txn,
table="thread_subscriptions",
values={
"user_id": user_id,
"event_id": thread_root_event_id,
"room_id": room_id,
"subscribed": True,
"stream_id": stream_id,
"instance_name": self._instance_name,
"automatic": requested_automatic,
"unsubscribed_at_stream_ordering": None,
"unsubscribed_at_topological_ordering": None,
},
)
_invalidate_subscription_caches(txn)
return stream_id
# we already have either a subscription or a prior unsubscription here
(
subscribed,
already_automatic,
unsubscribed_at_stream_ordering,
unsubscribed_at_topological_ordering,
) = row
if subscribed and (not already_automatic or requested_automatic):
# we are already subscribed and the current subscription state
# is good enough (either we already have a manual subscription,
# or we requested an automatic subscription)
# In that case, nothing to change here.
# (See docstring for case-by-case explanation)
return None
if not subscribed and requested_automatic:
assert automatic_event_orderings is not None
# we previously unsubscribed and we are now automatically subscribing
# Check whether the new autosubscription should be skipped
if ThreadSubscriptionsWorkerStore._should_skip_autosubscription_after_unsubscription(
autosub=automatic_event_orderings,
unsubscribed_at=EventOrderings(
unsubscribed_at_stream_ordering,
unsubscribed_at_topological_ordering,
),
):
# skip the subscription
return AutomaticSubscriptionConflicted()
# At this point we have finished checking whether we need to make
# a subscription, and can now update the current row.
stream_id = self._thread_subscriptions_id_gen.get_next_txn(txn)
self.db_pool.simple_update_txn(
txn,
table="thread_subscriptions",
keyvalues={
"user_id": user_id,
"event_id": thread_root_event_id,
"room_id": room_id,
},
updatevalues={
"subscribed": True,
"stream_id": stream_id,
"instance_name": self._instance_name,
"automatic": requested_automatic,
"unsubscribed_at_stream_ordering": None,
"unsubscribed_at_topological_ordering": None,
},
)
_invalidate_subscription_caches(txn)
return stream_id
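The branching in `_subscribe_user_to_thread_txn` amounts to a small decision table. A hypothetical pure-function rendering (the conflict check is elided and the return values are illustrative labels, not the real store's types):

```python
from typing import Optional, Tuple

def decide(
    existing: Optional[Tuple[bool, bool]],  # (subscribed, automatic), or None if no row
    requested_automatic: bool,
) -> str:
    """Mirror the branching of _subscribe_user_to_thread_txn."""
    if existing is None:
        return "insert"   # never subscribed before: insert a fresh row
    subscribed, already_automatic = existing
    if subscribed and (not already_automatic or requested_automatic):
        return "no-op"    # current subscription state is already good enough
    return "update"       # upgrade automatic -> manual, or re-subscribe

assert decide(None, True) == "insert"
assert decide((True, True), True) == "no-op"     # auto + auto request: nothing to do
assert decide((True, True), False) == "update"   # manual request overwrites automatic
assert decide((True, False), True) == "no-op"    # existing manual subscription wins
assert decide((False, False), False) == "update" # re-subscribe after unsubscription
```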
@@ -214,6 +345,21 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore):
stream_id = self._thread_subscriptions_id_gen.get_next_txn(txn)
# Find the maximum stream ordering and topological ordering of the room,
# which we then store against this unsubscription so we can skip future
# automatic subscriptions that are caused by an event logically earlier
# than this unsubscription.
txn.execute(
"""
SELECT MAX(stream_ordering) AS mso, MAX(topological_ordering) AS mto FROM events
WHERE room_id = ?
""",
(room_id,),
)
ord_row = txn.fetchone()
assert ord_row is not None
max_stream_ordering, max_topological_ordering = ord_row
self.db_pool.simple_update_txn(
txn,
table="thread_subscriptions",
@@ -227,6 +373,8 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore):
"subscribed": False,
"stream_id": stream_id,
"instance_name": self._instance_name,
"unsubscribed_at_stream_ordering": max_stream_ordering,
"unsubscribed_at_topological_ordering": max_topological_ordering,
},
)
@@ -234,6 +382,10 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore):
self.get_subscription_for_thread.invalidate,
(user_id, room_id, thread_root_event_id),
)
txn.call_after(
self.get_subscribers_to_thread.invalidate,
(room_id, thread_root_event_id),
)
return stream_id
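The `MAX(...)` lookup feeding the two `unsubscribed_at_*` columns can be tried against an in-memory SQLite database (a toy `events` table, not Synapse's real schema):

```python
import sqlite3

con = sqlite3.connect(":memory:")
con.execute(
    "CREATE TABLE events (room_id TEXT, stream_ordering BIGINT, topological_ordering BIGINT)"
)
con.executemany(
    "INSERT INTO events VALUES (?, ?, ?)",
    [("!room:a", 1, 1), ("!room:a", 7, 4), ("!room:a", -2, 2), ("!other:a", 99, 99)],
)
# Same shape of query as the unsubscription hunk above:
row = con.execute(
    "SELECT MAX(stream_ordering) AS mso, MAX(topological_ordering) AS mto "
    "FROM events WHERE room_id = ?",
    ("!room:a",),
).fetchone()
assert row == (7, 4)  # backfilled (negative) rows never win MAX(stream_ordering)
```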
@@ -246,7 +398,9 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore):
Purge all subscriptions for the user.
The fact that subscriptions have been purged will not be streamed;
all stream rows for the user will in fact be removed.
This is intended only for dealing with user deactivation.
This must only be used for user deactivation,
because it does not invalidate the `subscribers_to_thread` cache.
"""
def _purge_thread_subscription_settings_for_user_txn(
@@ -307,6 +461,42 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore):
return ThreadSubscription(automatic=automatic)
# max_entries=100 rationale:
# this returns a potentially large datastructure
# (since each entry contains a set which contains a potentially large number of user IDs),
# whereas the default of 10'000 entries for @cached feels more
# suitable for very small cache entries.
#
# Overall, when bearing in mind the usual profile of a small community-server or company-server
# (where cache tuning hasn't been done, so we're in out-of-box configuration), it is very
# unlikely we would benefit from keeping hot the subscribers for as many as 100 threads,
# since it's unlikely that so many threads will be active in a short span of time on a small homeserver.
# Medium-sized servers will probably not exhaust this limit either.
# Larger homeservers are more likely to be carefully tuned, either with a larger global cache factor
# or carefully following the usage patterns & cache metrics.
# Finally, the query is not so expensive that recomputing it every time would be a problem,
# but since people often send messages back-to-back in the same thread, caching should
# offer a mild benefit.
@cached(max_entries=100)
async def get_subscribers_to_thread(
self, room_id: str, thread_root_event_id: str
) -> FrozenSet[str]:
"""
Returns:
the set of user_ids for local users who are subscribed to the given thread.
"""
return frozenset(
await self.db_pool.simple_select_onecol(
table="thread_subscriptions",
keyvalues={
"room_id": room_id,
"event_id": thread_root_event_id,
"subscribed": True,
},
retcol="user_id",
desc="get_subscribers_to_thread",
)
)
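The `max_entries=100` trade-off behaves like a plain LRU bound. A rough stdlib analogue (not Synapse's `@cached` decorator, which also supports invalidation and cache-factor tuning):

```python
from functools import lru_cache
from typing import FrozenSet

calls = 0

@lru_cache(maxsize=100)  # analogous to @cached(max_entries=100)
def get_subscribers(room_id: str, thread_root_event_id: str) -> FrozenSet[str]:
    global calls
    calls += 1
    # stand-in for the simple_select_onecol query above
    return frozenset({"@alice:example.org"})

get_subscribers("!r:a", "$t1")
get_subscribers("!r:a", "$t1")  # served from cache, no second "query"
assert calls == 1
```

Returning a `frozenset` rather than a `set` keeps the cached value immutable, so callers cannot mutate the shared cache entry.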
def get_max_thread_subscriptions_stream_id(self) -> int:
"""Get the current maximum stream_id for thread subscriptions.
@@ -316,7 +506,7 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore):
return self._thread_subscriptions_id_gen.get_current_token()
async def get_updated_thread_subscriptions(
self, from_id: int, to_id: int, limit: int
self, *, from_id: int, to_id: int, limit: int
) -> List[Tuple[int, str, str, str]]:
"""Get updates to thread subscriptions between two stream IDs.
@@ -349,7 +539,7 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore):
)
async def get_updated_thread_subscriptions_for_user(
self, user_id: str, from_id: int, to_id: int, limit: int
self, user_id: str, *, from_id: int, to_id: int, limit: int
) -> List[Tuple[int, str, str]]:
"""Get updates to thread subscriptions for a specific user.

View File

@@ -0,0 +1,20 @@
--
-- This file is licensed under the Affero General Public License (AGPL) version 3.
--
-- Copyright (C) 2025 New Vector, Ltd
--
-- This program is free software: you can redistribute it and/or modify
-- it under the terms of the GNU Affero General Public License as
-- published by the Free Software Foundation, either version 3 of the
-- License, or (at your option) any later version.
--
-- See the GNU Affero General Public License for more details:
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
-- The maximum stream_ordering in the room when the unsubscription was made.
ALTER TABLE thread_subscriptions
ADD COLUMN unsubscribed_at_stream_ordering BIGINT;
-- The maximum topological_ordering in the room when the unsubscription was made.
ALTER TABLE thread_subscriptions
ADD COLUMN unsubscribed_at_topological_ordering BIGINT;

View File

@@ -0,0 +1,18 @@
--
-- This file is licensed under the Affero General Public License (AGPL) version 3.
--
-- Copyright (C) 2025 New Vector, Ltd
--
-- This program is free software: you can redistribute it and/or modify
-- it under the terms of the GNU Affero General Public License as
-- published by the Free Software Foundation, either version 3 of the
-- License, or (at your option) any later version.
--
-- See the GNU Affero General Public License for more details:
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
COMMENT ON COLUMN thread_subscriptions.unsubscribed_at_stream_ordering IS
$$The maximum stream_ordering in the room when the unsubscription was made.$$;
COMMENT ON COLUMN thread_subscriptions.unsubscribed_at_topological_ordering IS
$$The maximum topological_ordering in the room when the unsubscription was made.$$;

View File

@@ -49,6 +49,7 @@ class FilteredPushRules:
msc3664_enabled: bool,
msc4028_push_encrypted_events: bool,
msc4210_enabled: bool,
msc4306_enabled: bool,
): ...
def rules(self) -> Collection[Tuple[PushRule, bool]]: ...
@@ -67,13 +68,19 @@ class PushRuleEvaluator:
room_version_feature_flags: Tuple[str, ...],
msc3931_enabled: bool,
msc4210_enabled: bool,
msc4306_enabled: bool,
): ...
def run(
self,
push_rules: FilteredPushRules,
user_id: Optional[str],
display_name: Optional[str],
msc4306_thread_subscription_state: Optional[bool],
) -> Collection[Union[Mapping, str]]: ...
def matches(
self, condition: JsonDict, user_id: Optional[str], display_name: Optional[str]
self,
condition: JsonDict,
user_id: Optional[str],
display_name: Optional[str],
msc4306_thread_subscription_state: Optional[bool] = None,
) -> bool: ...

View File

@@ -73,6 +73,7 @@ if TYPE_CHECKING:
from typing_extensions import Self
from synapse.appservice.api import ApplicationService
from synapse.events import EventBase
from synapse.storage.databases.main import DataStore, PurgeEventsStore
from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore
from synapse.storage.util.id_generators import MultiWriterIdGenerator
@@ -1464,3 +1465,31 @@ class ScheduledTask:
result: Optional[JsonMapping]
# Optional error that should be assigned a value when the status is FAILED
error: Optional[str]
@attr.s(auto_attribs=True, frozen=True, slots=True)
class EventOrderings:
stream: int
"""
The stream_ordering of the event.
Negative numbers mean the event was backfilled.
"""
topological: int
"""
The topological_ordering of the event.
Currently this is equivalent to the `depth` attribute of
the PDU.
"""
@staticmethod
def from_event(event: "EventBase") -> "EventOrderings":
"""
Get the orderings from an event.
Preconditions:
- the event must have been persisted (otherwise it won't have a stream ordering)
"""
stream = event.internal_metadata.stream_ordering
assert stream is not None
return EventOrderings(stream, event.depth)

View File

@@ -27,7 +27,6 @@ from typing import (
Any,
Callable,
Dict,
Generator,
Iterator,
Mapping,
Optional,
@@ -42,7 +41,6 @@ from matrix_common.versionstring import get_distribution_version_string
from typing_extensions import ParamSpec
from twisted.internet import defer, task
from twisted.internet.defer import Deferred
from twisted.internet.interfaces import IDelayedCall, IReactorTime
from twisted.internet.task import LoopingCall
from twisted.python.failure import Failure
@@ -121,13 +119,11 @@ class Clock:
_reactor: IReactorTime = attr.ib()
@defer.inlineCallbacks
def sleep(self, seconds: float) -> "Generator[Deferred[float], Any, Any]":
async def sleep(self, seconds: float) -> None:
d: defer.Deferred[float] = defer.Deferred()
with context.PreserveLoggingContext():
self._reactor.callLater(seconds, d.callback, seconds)
res = yield d
return res
await d
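The `inlineCallbacks`-to-coroutine conversion above keeps the same shape: create a pending result, schedule its resolution, and await it. The analogous pattern in plain `asyncio` (no Twisted) looks like:

```python
import asyncio

async def sleep(seconds: float) -> None:
    # Equivalent of the rewritten Clock.sleep: create a future, schedule its
    # resolution on the event loop, then await it.
    loop = asyncio.get_running_loop()
    fut: asyncio.Future = loop.create_future()
    loop.call_later(seconds, fut.set_result, seconds)
    await fut

async def main() -> None:
    await sleep(0.01)

asyncio.run(main())
```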
def time(self) -> float:
"""Returns the current system time in seconds since epoch."""

View File

@@ -13,7 +13,11 @@
#
#
from synapse._pydantic_compat import BaseModel, Extra
import re
from typing import Any, Callable, Generator
from synapse._pydantic_compat import BaseModel, Extra, StrictStr
from synapse.types import EventID
class ParseModel(BaseModel):
@@ -37,3 +41,43 @@ class ParseModel(BaseModel):
extra = Extra.ignore
# By default, don't allow fields to be reassigned after parsing.
allow_mutation = False
class AnyEventId(StrictStr):
"""
A validator for strings that must be an event ID.
Accepts the event ID grammar of any room version.
"""
EVENT_ID_HASH_ROOM_VERSION_3_PLUS = re.compile(
r"^([a-zA-Z0-9-_]{43}|[a-zA-Z0-9+/]{43})$"
)
@classmethod
def __get_validators__(cls) -> Generator[Callable[..., Any], Any, Any]:
yield from super().__get_validators__() # type: ignore
yield cls.validate_event_id
@classmethod
def validate_event_id(cls, value: str) -> str:
if not value.startswith("$"):
raise ValueError("Event ID must start with `$`")
if ":" in value:
# Room versions 1 and 2
EventID.from_string(value) # throws on fail
else:
# Room versions 3+: event ID is $ + a base64 sha256 hash
# Room version 3 is base64, 4+ are base64Url
# In both cases, the base64 is unpadded.
# refs:
# - https://spec.matrix.org/v1.15/rooms/v3/ e.g. $acR1l0raoZnm60CBwAVgqbZqoO/mYU81xysh1u7XcJk
# - https://spec.matrix.org/v1.15/rooms/v4/ e.g. $Rqnc-F-dvnEYJTyHq_iKxU2bZ1CI92-kuZq3a5lr5Zg
b64_hash = value[1:]
if cls.EVENT_ID_HASH_ROOM_VERSION_3_PLUS.fullmatch(b64_hash) is None:
raise ValueError(
"Event ID must either have a domain part or be a valid hash"
)
return value
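The validator's two branches can be checked directly. A self-contained sketch of the same rule, without the Pydantic plumbing (the room version 1/2 branch, which the real code delegates to `EventID.from_string`, is only approximated here):

```python
import re

HASH_RE = re.compile(r"^([a-zA-Z0-9-_]{43}|[a-zA-Z0-9+/]{43})$")

def is_valid_event_id(value: str) -> bool:
    if not value.startswith("$"):
        return False
    if ":" in value:
        # Room versions 1/2: "$localpart:domain"
        # (crude approximation of EventID.from_string)
        return len(value.split(":", 1)[1]) > 0
    # Room versions 3+: unpadded base64 (v3) or base64url (v4+) sha256 hash
    return HASH_RE.fullmatch(value[1:]) is not None

# Examples from the room-version specs linked above:
assert is_valid_event_id("$Rqnc-F-dvnEYJTyHq_iKxU2bZ1CI92-kuZq3a5lr5Zg")  # v4+
assert is_valid_event_id("$acR1l0raoZnm60CBwAVgqbZqoO/mYU81xysh1u7XcJk")  # v3
assert is_valid_event_id("$someid:example.org")                           # v1/v2
assert not is_valid_event_id("no-dollar")
assert not is_valid_event_id("$tooShortHash")
```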

View File

@@ -22,7 +22,7 @@
from synapse.api.constants import EduTypes
from tests import unittest
from tests.unittest import DEBUG, override_config
from tests.unittest import override_config
class RoomDirectoryFederationTests(unittest.FederatingHomeserverTestCase):
@@ -48,7 +48,6 @@ class RoomDirectoryFederationTests(unittest.FederatingHomeserverTestCase):
)
self.assertEqual(200, channel.code)
@DEBUG
def test_edu_debugging_doesnt_explode(self) -> None:
"""Sanity check incoming federation succeeds with `synapse.debug_8631` enabled.

View File

@@ -26,7 +26,7 @@ from parameterized import parameterized
from twisted.internet.testing import MemoryReactor
from synapse.api.constants import EventContentFields, RelationTypes
from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
from synapse.api.room_versions import RoomVersions
from synapse.push.bulk_push_rule_evaluator import BulkPushRuleEvaluator
from synapse.rest import admin
@@ -206,7 +206,10 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
bulk_evaluator._action_for_event_by_user.assert_not_called()
def _create_and_process(
self, bulk_evaluator: BulkPushRuleEvaluator, content: Optional[JsonDict] = None
self,
bulk_evaluator: BulkPushRuleEvaluator,
content: Optional[JsonDict] = None,
type: str = "test",
) -> bool:
"""Returns true iff the `mentions` trigger an event push action."""
# Create a new message event which should cause a notification.
@@ -214,7 +217,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
self.event_creation_handler.create_event(
self.requester,
{
"type": "test",
"type": type,
"room_id": self.room_id,
"content": content or {},
"sender": f"@bob:{self.hs.hostname}",
@@ -446,3 +449,73 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
},
)
)
@override_config({"experimental_features": {"msc4306_enabled": True}})
def test_thread_subscriptions(self) -> None:
bulk_evaluator = BulkPushRuleEvaluator(self.hs)
(thread_root_id,) = self.helper.send_messages(self.room_id, 1, tok=self.token)
self.assertFalse(
self._create_and_process(
bulk_evaluator,
{
"msgtype": "m.text",
"body": "test message before subscription",
"m.relates_to": {
"rel_type": RelationTypes.THREAD,
"event_id": thread_root_id,
},
},
type=EventTypes.Message,
)
)
self.get_success(
self.hs.get_datastores().main.subscribe_user_to_thread(
self.alice,
self.room_id,
thread_root_id,
automatic_event_orderings=None,
)
)
self.assertTrue(
self._create_and_process(
bulk_evaluator,
{
"msgtype": "m.text",
"body": "test message after subscription",
"m.relates_to": {
"rel_type": RelationTypes.THREAD,
"event_id": thread_root_id,
},
},
type="m.room.message",
)
)
def test_with_disabled_thread_subscriptions(self) -> None:
"""
Test what happens with threaded events when MSC4306 is disabled.
FUTURE: If MSC4306 becomes enabled-by-default/accepted, this test should be removed.
"""
bulk_evaluator = BulkPushRuleEvaluator(self.hs)
(thread_root_id,) = self.helper.send_messages(self.room_id, 1, tok=self.token)
# When MSC4306 is not enabled, a threaded message generates a notification
# by default.
self.assertTrue(
self._create_and_process(
bulk_evaluator,
{
"msgtype": "m.text",
"body": "test message before subscription",
"m.relates_to": {
"rel_type": RelationTypes.THREAD,
"event_id": thread_root_id,
},
},
type="m.room.message",
)
)

View File

@@ -150,6 +150,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
*,
related_events: Optional[JsonDict] = None,
msc4210: bool = False,
msc4306: bool = False,
) -> PushRuleEvaluator:
event = FrozenEvent(
{
@@ -176,6 +177,7 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
room_version_feature_flags=event.room_version.msc3931_push_features,
msc3931_enabled=True,
msc4210_enabled=msc4210,
msc4306_enabled=msc4306,
)
def test_display_name(self) -> None:
@@ -806,6 +808,112 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
)
)
def test_thread_subscription_subscribed(self) -> None:
"""
Test MSC4306 thread subscription push rules against an event in a subscribed thread.
"""
evaluator = self._get_evaluator(
{
"msgtype": "m.text",
"body": "Squawk",
"m.relates_to": {
"event_id": "$threadroot",
"rel_type": "m.thread",
},
},
msc4306=True,
)
self.assertTrue(
evaluator.matches(
{
"kind": "io.element.msc4306.thread_subscription",
"subscribed": True,
},
None,
None,
msc4306_thread_subscription_state=True,
)
)
self.assertFalse(
evaluator.matches(
{
"kind": "io.element.msc4306.thread_subscription",
"subscribed": False,
},
None,
None,
msc4306_thread_subscription_state=True,
)
)
def test_thread_subscription_unsubscribed(self) -> None:
"""
Test MSC4306 thread subscription push rules against an event in an unsubscribed thread.
"""
evaluator = self._get_evaluator(
{
"msgtype": "m.text",
"body": "Squawk",
"m.relates_to": {
"event_id": "$threadroot",
"rel_type": "m.thread",
},
},
msc4306=True,
)
self.assertFalse(
evaluator.matches(
{
"kind": "io.element.msc4306.thread_subscription",
"subscribed": True,
},
None,
None,
msc4306_thread_subscription_state=False,
)
)
self.assertTrue(
evaluator.matches(
{
"kind": "io.element.msc4306.thread_subscription",
"subscribed": False,
},
None,
None,
msc4306_thread_subscription_state=False,
)
)
def test_thread_subscription_unthreaded(self) -> None:
"""
Test MSC4306 thread subscription push rules against an unthreaded event.
"""
evaluator = self._get_evaluator(
{"msgtype": "m.text", "body": "Squawk"}, msc4306=True
)
self.assertFalse(
evaluator.matches(
{
"kind": "io.element.msc4306.thread_subscription",
"subscribed": True,
},
None,
None,
msc4306_thread_subscription_state=None,
)
)
self.assertFalse(
evaluator.matches(
{
"kind": "io.element.msc4306.thread_subscription",
"subscribed": False,
},
None,
None,
msc4306_thread_subscription_state=None,
)
)
class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase):
"""Tests for the bulk push rule evaluator"""

View File

@@ -62,7 +62,7 @@ class ThreadSubscriptionsStreamTestCase(BaseStreamTestCase):
"@test_user:example.org",
room_id,
thread_root_id,
automatic=True,
automatic_event_orderings=None,
)
)
updates.append(thread_root_id)
@@ -75,7 +75,7 @@ class ThreadSubscriptionsStreamTestCase(BaseStreamTestCase):
"@test_user:example.org",
other_room_id,
other_thread_root_id,
automatic=False,
automatic_event_orderings=None,
)
)
@@ -124,7 +124,7 @@ class ThreadSubscriptionsStreamTestCase(BaseStreamTestCase):
for user_id in users:
self.get_success(
store.subscribe_user_to_thread(
user_id, room_id, thread_root_id, automatic=True
user_id, room_id, thread_root_id, automatic_event_orderings=None
)
)

View File

@@ -15,6 +15,7 @@ from http import HTTPStatus
from twisted.internet.testing import MemoryReactor
from synapse.api.errors import Codes
from synapse.rest import admin
from synapse.rest.client import login, profile, room, thread_subscriptions
from synapse.server import HomeServer
@@ -49,15 +50,16 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
# Create a room and send a message to use as a thread root
self.room_id = self.helper.create_room_as(self.user_id, tok=self.token)
self.helper.join(self.room_id, self.other_user_id, tok=self.other_token)
response = self.helper.send(self.room_id, body="Root message", tok=self.token)
self.root_event_id = response["event_id"]
(self.root_event_id,) = self.helper.send_messages(
self.room_id, 1, tok=self.token
)
# Send a message in the thread
self.helper.send_event(
room_id=self.room_id,
type="m.room.message",
content={
"body": "Thread message",
self.threaded_events = self.helper.send_messages(
self.room_id,
2,
content_fn=lambda idx: {
"body": f"Thread message {idx}",
"msgtype": "m.text",
"m.relates_to": {
"rel_type": "m.thread",
@@ -106,9 +108,7 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"PUT",
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
{
"automatic": False,
},
{},
access_token=self.token,
)
self.assertEqual(channel.code, HTTPStatus.OK)
@@ -127,7 +127,7 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"PUT",
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
{"automatic": True},
{"automatic": self.threaded_events[0]},
access_token=self.token,
)
self.assertEqual(channel.code, HTTPStatus.OK)
@@ -148,11 +148,11 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
"PUT",
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
{
"automatic": True,
"automatic": self.threaded_events[0],
},
access_token=self.token,
)
self.assertEqual(channel.code, HTTPStatus.OK)
self.assertEqual(channel.code, HTTPStatus.OK, channel.text_body)
# Assert the subscription was saved
channel = self.make_request(
@@ -167,7 +167,7 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"PUT",
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
{"automatic": False},
{},
access_token=self.token,
)
self.assertEqual(channel.code, HTTPStatus.OK)
@@ -187,7 +187,7 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
"PUT",
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
{
"automatic": True,
"automatic": self.threaded_events[0],
},
access_token=self.token,
)
@@ -202,7 +202,6 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, HTTPStatus.OK)
self.assertEqual(channel.json_body, {"automatic": True})
# Now also register a manual subscription
channel = self.make_request(
"DELETE",
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
@@ -210,7 +209,6 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, HTTPStatus.OK)
# Assert the manual subscription was not overridden
channel = self.make_request(
"GET",
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
@@ -224,7 +222,7 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"PUT",
f"{PREFIX}/{self.room_id}/thread/$nonexistent:example.org/subscription",
{"automatic": True},
{},
access_token=self.token,
)
self.assertEqual(channel.code, HTTPStatus.NOT_FOUND)
@@ -238,7 +236,7 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"PUT",
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
{"automatic": True},
{},
access_token=no_access_token,
)
self.assertEqual(channel.code, HTTPStatus.NOT_FOUND)
@@ -249,8 +247,105 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"PUT",
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
# non-boolean `automatic`
{"automatic": "true"},
# non-Event ID `automatic`
{"automatic": True},
access_token=self.token,
)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST)
channel = self.make_request(
"PUT",
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
# non-Event ID `automatic`
{"automatic": "$malformedEventId"},
access_token=self.token,
)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST)
def test_auto_subscribe_cause_event_not_in_thread(self) -> None:
"""
Test making an automatic subscription, where the cause event is not
actually in the thread.
This is an error.
"""
(unrelated_event_id,) = self.helper.send_messages(
self.room_id, 1, tok=self.token
)
channel = self.make_request(
"PUT",
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
{"automatic": unrelated_event_id},
access_token=self.token,
)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.text_body)
self.assertEqual(channel.json_body["errcode"], Codes.MSC4306_NOT_IN_THREAD)
def test_auto_resubscription_conflict(self) -> None:
"""
Test that an automatic subscription that conflicts with an unsubscription
is skipped.
"""
# Reuse the test that subscribes and unsubscribes
self.test_unsubscribe()
# Now no matter which event we present as the cause of an automatic subscription,
# the automatic subscription is skipped.
# This is because the unsubscription happened after all of the events.
for event in self.threaded_events:
channel = self.make_request(
"PUT",
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
{
"automatic": event,
},
access_token=self.token,
)
self.assertEqual(channel.code, HTTPStatus.CONFLICT, channel.text_body)
self.assertEqual(
channel.json_body["errcode"],
Codes.MSC4306_CONFLICTING_UNSUBSCRIPTION,
channel.text_body,
)
# Check the subscription was not made
channel = self.make_request(
"GET",
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
access_token=self.token,
)
self.assertEqual(channel.code, HTTPStatus.NOT_FOUND)
# But if a new event is sent after the unsubscription took place,
# that one can be used for an automatic subscription
(later_event_id,) = self.helper.send_messages(
self.room_id,
1,
content_fn=lambda _: {
"body": "Thread message after unsubscription",
"msgtype": "m.text",
"m.relates_to": {
"rel_type": "m.thread",
"event_id": self.root_event_id,
},
},
tok=self.token,
)
channel = self.make_request(
"PUT",
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
{
"automatic": later_event_id,
},
access_token=self.token,
)
self.assertEqual(channel.code, HTTPStatus.OK, channel.text_body)
# Check the subscription was made
channel = self.make_request(
"GET",
f"{PREFIX}/{self.room_id}/thread/{self.root_event_id}/subscription",
access_token=self.token,
)
self.assertEqual(channel.code, HTTPStatus.OK)
self.assertEqual(channel.json_body, {"automatic": True})

View File

@@ -90,7 +90,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
) -> Generator["defer.Deferred[Any]", object, None]:
@defer.inlineCallbacks
def cb() -> Generator["defer.Deferred[object]", object, Tuple[int, JsonDict]]:
yield Clock(reactor).sleep(0)
yield defer.ensureDeferred(Clock(reactor).sleep(0))
return 1, {}
@defer.inlineCallbacks

View File

@@ -29,12 +29,14 @@ from http import HTTPStatus
from typing import (
Any,
AnyStr,
Callable,
Dict,
Iterable,
Literal,
Mapping,
MutableMapping,
Optional,
Sequence,
Tuple,
overload,
)
@@ -45,7 +47,7 @@ import attr
from twisted.internet.testing import MemoryReactorClock
from twisted.web.server import Site
from synapse.api.constants import Membership, ReceiptTypes
from synapse.api.constants import EventTypes, Membership, ReceiptTypes
from synapse.api.errors import Codes
from synapse.server import HomeServer
from synapse.types import JsonDict
@@ -394,6 +396,32 @@ class RestHelper:
custom_headers=custom_headers,
)
def send_messages(
self,
room_id: str,
num_events: int,
content_fn: Callable[[int], JsonDict] = lambda idx: {
"msgtype": "m.text",
"body": f"Test event {idx}",
},
tok: Optional[str] = None,
) -> Sequence[str]:
"""
Helper to send a handful of sequential events and return their event IDs as a sequence.
"""
event_ids = []
for event_index in range(num_events):
response = self.send_event(
room_id,
EventTypes.Message,
content_fn(event_index),
tok=tok,
)
event_ids.append(response["event_id"])
return event_ids
def send_event(
self,
room_id: str,

View File

@@ -131,7 +131,7 @@ class ServerNoticesTests(unittest.HomeserverTestCase):
break
# Sleep and try again.
self.clock.sleep(0.1)
self.get_success(self.clock.sleep(0.1))
else:
self.fail(
f"Failed to join the server notices room. No 'join' field in sync_body['rooms']: {sync_body['rooms']}"

View File

@@ -65,8 +65,8 @@ ORIGIN_SERVER_TS = 0
class FakeClock:
def sleep(self, msec: float) -> "defer.Deferred[None]":
return defer.succeed(None)
async def sleep(self, msec: float) -> None:
return None
class FakeEvent:

View File

@@ -12,13 +12,18 @@
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
from typing import Optional
from typing import Optional, Union
from twisted.internet.testing import MemoryReactor
from synapse.server import HomeServer
from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main.thread_subscriptions import (
AutomaticSubscriptionConflicted,
ThreadSubscriptionsWorkerStore,
)
from synapse.storage.engines.sqlite import Sqlite3Engine
from synapse.types import EventOrderings
from synapse.util import Clock
from tests import unittest
@@ -97,10 +102,10 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
self,
thread_root_id: str,
*,
automatic: bool,
automatic_event_orderings: Optional[EventOrderings],
room_id: Optional[str] = None,
user_id: Optional[str] = None,
) -> Optional[int]:
) -> Optional[Union[int, AutomaticSubscriptionConflicted]]:
if user_id is None:
user_id = self.user_id
@@ -112,7 +117,7 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
user_id,
room_id,
thread_root_id,
automatic=automatic,
automatic_event_orderings=automatic_event_orderings,
)
)
@@ -149,7 +154,7 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
# Subscribe
self._subscribe(
self.thread_root_id,
automatic=True,
automatic_event_orderings=EventOrderings(1, 1),
)
# Assert subscription went through
@@ -164,7 +169,7 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
# Now make it a manual subscription
self._subscribe(
self.thread_root_id,
automatic=False,
automatic_event_orderings=None,
)
# Assert the manual subscription overrode the automatic one
@@ -178,8 +183,10 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
def test_purge_thread_subscriptions_for_user(self) -> None:
"""Test purging all thread subscription settings for a user."""
# Set subscription settings for multiple threads
self._subscribe(self.thread_root_id, automatic=True)
self._subscribe(self.other_thread_root_id, automatic=False)
self._subscribe(
self.thread_root_id, automatic_event_orderings=EventOrderings(1, 1)
)
self._subscribe(self.other_thread_root_id, automatic_event_orderings=None)
subscriptions = self.get_success(
self.store.get_updated_thread_subscriptions_for_user(
@@ -217,20 +224,32 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
def test_get_updated_thread_subscriptions(self) -> None:
"""Test getting updated thread subscriptions since a stream ID."""
stream_id1 = self._subscribe(self.thread_root_id, automatic=False)
stream_id2 = self._subscribe(self.other_thread_root_id, automatic=True)
assert stream_id1 is not None
assert stream_id2 is not None
stream_id1 = self._subscribe(
self.thread_root_id, automatic_event_orderings=EventOrderings(1, 1)
)
stream_id2 = self._subscribe(
self.other_thread_root_id, automatic_event_orderings=EventOrderings(2, 2)
)
assert stream_id1 is not None and not isinstance(
stream_id1, AutomaticSubscriptionConflicted
)
assert stream_id2 is not None and not isinstance(
stream_id2, AutomaticSubscriptionConflicted
)
# Get updates since initial ID (should include both changes)
updates = self.get_success(
self.store.get_updated_thread_subscriptions(0, stream_id2, 10)
self.store.get_updated_thread_subscriptions(
from_id=0, to_id=stream_id2, limit=10
)
)
self.assertEqual(len(updates), 2)
# Get updates since first change (should include only the second change)
updates = self.get_success(
self.store.get_updated_thread_subscriptions(stream_id1, stream_id2, 10)
self.store.get_updated_thread_subscriptions(
from_id=stream_id1, to_id=stream_id2, limit=10
)
)
self.assertEqual(
updates,
@@ -242,21 +261,27 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
other_user_id = "@other_user:test"
# Set thread subscription for main user
stream_id1 = self._subscribe(self.thread_root_id, automatic=True)
assert stream_id1 is not None
stream_id1 = self._subscribe(
self.thread_root_id, automatic_event_orderings=EventOrderings(1, 1)
)
assert stream_id1 is not None and not isinstance(
stream_id1, AutomaticSubscriptionConflicted
)
# Set thread subscription for other user
stream_id2 = self._subscribe(
self.other_thread_root_id,
automatic=True,
automatic_event_orderings=EventOrderings(1, 1),
user_id=other_user_id,
)
assert stream_id2 is not None
assert stream_id2 is not None and not isinstance(
stream_id2, AutomaticSubscriptionConflicted
)
# Get updates for main user
updates = self.get_success(
self.store.get_updated_thread_subscriptions_for_user(
self.user_id, 0, stream_id2, 10
self.user_id, from_id=0, to_id=stream_id2, limit=10
)
)
self.assertEqual(updates, [(stream_id1, self.room_id, self.thread_root_id)])
@@ -264,9 +289,80 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
# Get updates for other user
updates = self.get_success(
self.store.get_updated_thread_subscriptions_for_user(
other_user_id, 0, max(stream_id1, stream_id2), 10
other_user_id, from_id=0, to_id=max(stream_id1, stream_id2), limit=10
)
)
self.assertEqual(
updates, [(stream_id2, self.room_id, self.other_thread_root_id)]
)
def test_should_skip_autosubscription_after_unsubscription(self) -> None:
"""
Tests the comparison logic for whether an autosubscription should be skipped
due to a chronologically earlier but logically later unsubscription.
"""
func = ThreadSubscriptionsWorkerStore._should_skip_autosubscription_after_unsubscription
# Order of arguments:
# automatic cause event: stream order, then topological order
# unsubscribe maximums: stream order, then topological order
# both orderings agree that the unsub is after the cause event
self.assertTrue(
func(autosub=EventOrderings(1, 1), unsubscribed_at=EventOrderings(2, 2))
)
# topological ordering is inconsistent with stream ordering,
# in that case favour stream ordering because it's what /sync uses
self.assertTrue(
func(autosub=EventOrderings(1, 2), unsubscribed_at=EventOrderings(2, 1))
)
# the automatic subscription is caused by a backfilled event here;
# unfortunately we must fall back to topological ordering
self.assertTrue(
func(autosub=EventOrderings(-50, 2), unsubscribed_at=EventOrderings(2, 3))
)
self.assertFalse(
func(autosub=EventOrderings(-50, 2), unsubscribed_at=EventOrderings(2, 1))
)
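The comments above describe the comparison rule: prefer stream ordering (because it is what `/sync` uses), but fall back to topological ordering when the cause event was backfilled (backfilled events get negative stream orderings). A sketch consistent with the four assertions, assuming `EventOrderings` is a `(stream, topological)` pair — an illustration, not Synapse's actual implementation:

```python
from typing import NamedTuple

class EventOrderings(NamedTuple):
    stream: int
    topological: int

def should_skip(autosub: EventOrderings, unsubscribed_at: EventOrderings) -> bool:
    if autosub.stream < 0:
        # The cause event was backfilled: its stream ordering is negative
        # and not meaningfully comparable, so fall back to topological order.
        return autosub.topological < unsubscribed_at.topological
    # Otherwise favour stream ordering, because it is what /sync uses.
    return autosub.stream < unsubscribed_at.stream
```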
def test_get_subscribers_to_thread(self) -> None:
"""
Test getting all subscribers to a thread at once.
To check that cache invalidations are correct, we do multiple
step-by-step rounds of subscription changes and assertions.
"""
other_user_id = "@other_user:test"
subscribers = self.get_success(
self.store.get_subscribers_to_thread(self.room_id, self.thread_root_id)
)
self.assertEqual(subscribers, frozenset())
self._subscribe(
self.thread_root_id, automatic_event_orderings=None, user_id=self.user_id
)
subscribers = self.get_success(
self.store.get_subscribers_to_thread(self.room_id, self.thread_root_id)
)
self.assertEqual(subscribers, frozenset((self.user_id,)))
self._subscribe(
self.thread_root_id, automatic_event_orderings=None, user_id=other_user_id
)
subscribers = self.get_success(
self.store.get_subscribers_to_thread(self.room_id, self.thread_root_id)
)
self.assertEqual(subscribers, frozenset((self.user_id, other_user_id)))
self._unsubscribe(self.thread_root_id, user_id=self.user_id)
subscribers = self.get_success(
self.store.get_subscribers_to_thread(self.room_id, self.thread_root_id)
)
self.assertEqual(subscribers, frozenset((other_user_id,)))
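The step-by-step assertions above can be mirrored with a plain in-memory model of per-thread subscriber sets (purely illustrative — the real store is database-backed with a cache, which is what the repeated rounds exercise):

```python
from typing import Dict, FrozenSet, Set, Tuple

class SubscriberModel:
    """In-memory model: (room_id, thread_root_id) -> set of user IDs."""
    def __init__(self) -> None:
        self._subs: Dict[Tuple[str, str], Set[str]] = {}

    def subscribe(self, room_id: str, thread_root_id: str, user_id: str) -> None:
        self._subs.setdefault((room_id, thread_root_id), set()).add(user_id)

    def unsubscribe(self, room_id: str, thread_root_id: str, user_id: str) -> None:
        self._subs.get((room_id, thread_root_id), set()).discard(user_id)

    def subscribers(self, room_id: str, thread_root_id: str) -> FrozenSet[str]:
        # Return a frozenset, matching the store's read API in the test.
        return frozenset(self._subs.get((room_id, thread_root_id), set()))

model = SubscriberModel()
model.subscribe("!room:test", "$thread:test", "@alice:test")
model.subscribe("!room:test", "$thread:test", "@bob:test")
model.unsubscribe("!room:test", "$thread:test", "@alice:test")
```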


@@ -51,20 +51,18 @@ class LoggingContextTestCase(unittest.TestCase):
with LoggingContext("test"):
self._check_test_key("test")
@defer.inlineCallbacks
def test_sleep(self) -> Generator["defer.Deferred[object]", object, None]:
async def test_sleep(self) -> None:
clock = Clock(reactor)
@defer.inlineCallbacks
def competing_callback() -> Generator["defer.Deferred[object]", object, None]:
async def competing_callback() -> None:
with LoggingContext("competing"):
yield clock.sleep(0)
await clock.sleep(0)
self._check_test_key("competing")
reactor.callLater(0, competing_callback)
reactor.callLater(0, lambda: defer.ensureDeferred(competing_callback()))
with LoggingContext("one"):
yield clock.sleep(0)
await clock.sleep(0)
self._check_test_key("one")
def _test_run_in_background(self, function: Callable[[], object]) -> defer.Deferred:
@@ -108,9 +106,8 @@ class LoggingContextTestCase(unittest.TestCase):
return d2
def test_run_in_background_with_blocking_fn(self) -> defer.Deferred:
@defer.inlineCallbacks
def blocking_function() -> Generator["defer.Deferred[object]", object, None]:
yield Clock(reactor).sleep(0)
async def blocking_function() -> None:
await Clock(reactor).sleep(0)
return self._test_run_in_background(blocking_function)
@@ -133,7 +130,7 @@ class LoggingContextTestCase(unittest.TestCase):
def test_run_in_background_with_coroutine(self) -> defer.Deferred:
async def testfunc() -> None:
self._check_test_key("one")
d = Clock(reactor).sleep(0)
d = defer.ensureDeferred(Clock(reactor).sleep(0))
self.assertIs(current_context(), SENTINEL_CONTEXT)
await d
self._check_test_key("one")
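The `callLater` and `ensureDeferred` fixes above matter because a bare coroutine object is not a callable the reactor can run: it must first be wrapped into a Deferred (`defer.ensureDeferred` in Twisted). The asyncio analogue of the same pattern, shown as a sketch rather than the Twisted code under test:

```python
import asyncio

ran = []

async def competing_callback() -> None:
    await asyncio.sleep(0)
    ran.append("competing")

async def main() -> None:
    loop = asyncio.get_running_loop()
    # call_later takes a plain callable, not a coroutine: wrap the
    # coroutine into a Task, as defer.ensureDeferred does for Twisted.
    loop.call_later(0, lambda: asyncio.ensure_future(competing_callback()))
    await asyncio.sleep(0.05)

asyncio.run(main())
```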