1
0

Compare commits

...

12 Commits

Author SHA1 Message Date
Kegan Dougal b338f886d6 Use ed25519:1 as the key ID 2025-09-10 15:35:37 +01:00
Kegan Dougal c0ffe61adb Add account_keys table to store all key<->name mappings 2025-09-09 15:19:57 +01:00
Kegan Dougal aefeb3cb58 Add MSC4243 flag, verify event signatures from the user ID 2025-09-09 13:11:34 +01:00
Andrew Morgan 39e4f27347 Merge branch 'master' into develop 2025-09-09 12:30:12 +01:00
reivilibre 6fe8137a4a Configure Synapse to run MSC4306: Thread Subscriptions Complement tests. (#18819)
Pairs with: https://github.com/matrix-org/complement/pull/795

Signed-off-by: Olivier 'reivilibre <oliverw@matrix.org>
2025-09-09 11:40:10 +01:00
David Baker d48e69ad4c Fix prefixed support for MSC4133 (#18875)
This fixes two bugs that affect the availability of MSC4133 until the
next spec release.

1. The servlet didn't recognise the unstable endpoint even when the
homeserver advertised it
 2. The HS didn't advertise support for the stable prefixed version

Would only have been a problem until the next spec release but it's nice
to have it work before then.
2025-09-09 09:53:08 +01:00
Amin Farjadi 74fdbc7b75 Fix typo in structured_logging.md for file handler config (#18872) 2025-09-09 09:51:36 +01:00
Jason Little 4d55f2f301 fix: Use the Enum's value for the dictionary key when responding to an admin request for experimental features (#18874)
While exploring bring up of using `orjson`, exposed an interesting flaw.
The stdlib `json` encoder seems to be ok with coercing a `str` from an
`Enum`(specifically, a `Class[str, Enum]`). The `orjson` encoder does
not like that this is a class and not a proper `str` per spec. Using the
`.value` of the enum as the key for the dict produced while answering a
`GET` admin request for experimental features seems to fix this.
2025-09-09 09:50:09 +01:00
reivilibre dfccde9f60 Remove obsolete and experimental /sync/e2ee endpoint. (#18583)
Introduced in: https://github.com/element-hq/synapse/pull/17167

The endpoint was part of experiments for MSC3575 but does not feature in
that MSC.

Signed-off-by: Olivier 'reivilibre <oliverw@matrix.org>
2025-09-09 09:28:45 +01:00
Erik Johnston 4b43e6fe02 Handle rescinding invites over federation (#18823)
We should send events that rescind invites over federation.

Similarly, we should handle receiving such events. Unfortunately, the
protocol doesn't make it possible to fully auth such events, and so we
can only handle the case where the original inviter rescinded the invite
(rather than a room admin).

Complement test: https://github.com/matrix-org/complement/pull/797
2025-09-08 10:55:48 +01:00
Eric Eastwood b2997a8f20 Suppress "Applying schema" log noise bulk when running Complement tests (#18878)
If Synapse is under test (`SYNAPSE_LOG_TESTING` is set), we don't care
about seeing the "Applying schema" log lines at the INFO level every
time we run the tests (it's 100 lines of bulk for each homeserver).

```
synapse_main | 2025-08-29 22:34:03,453 - synapse.storage.prepare_database - 433 - INFO - main - Applying schema deltas for v73
synapse_main | 2025-08-29 22:34:03,454 - synapse.storage.prepare_database - 541 - INFO - main - Applying schema 73/01event_failed_pull_attempts.sql
synapse_main | 2025-08-29 22:34:03,463 - synapse.storage.prepare_database - 541 - INFO - main - Applying schema 73/02add_pusher_enabled.sql
synapse_main | 2025-08-29 22:34:03,473 - synapse.storage.prepare_database - 541 - INFO - main - Applying schema 73/02room_id_indexes_for_purging.sql
synapse_main | 2025-08-29 22:34:03,482 - synapse.storage.prepare_database - 541 - INFO - main - Applying schema 73/03pusher_device_id.sql
synapse_main | 2025-08-29 22:34:03,492 - synapse.storage.prepare_database - 541 - INFO - main - Applying schema 73/03users_approved_column.sql
synapse_main | 2025-08-29 22:34:03,502 - synapse.storage.prepare_database - 541 - INFO - main - Applying schema 73/04partial_join_details.sql
synapse_main | 2025-08-29 22:34:03,513 - synapse.storage.prepare_database - 541 - INFO - main - Applying schema 73/04pending_device_list_updates.sql
...
```


The Synapse logs are visible when a Complement test fails or you use
`COMPLEMENT_ALWAYS_PRINT_SERVER_LOGS=1`. This is spawning from a
Complement test with three homeservers and wanting less log noise to
scroll through.
2025-09-02 13:34:47 -05:00
Eric Eastwood bff4a11b3f Re-introduce: Fix LaterGauge metrics to collect from all servers (#18791)
Re-introduce: https://github.com/element-hq/synapse/pull/18751 that was
reverted in https://github.com/element-hq/synapse/pull/18789 (explains
why the PR was reverted in the first place).

- Adds a `cleanup` pattern that cleans up metrics from each homeserver
in the tests. Previously, the list of hooks built up until our CI
machines couldn't operate properly, see
https://github.com/element-hq/synapse/pull/18789
- Fix long-standing issue with `synapse_background_update_status`
metrics only tracking the last database listed in the config (see
https://github.com/element-hq/synapse/pull/18791#discussion_r2261706749)
2025-09-02 12:14:27 -05:00
57 changed files with 1049 additions and 757 deletions
+1
View File
@@ -0,0 +1 @@
Remove obsolete and experimental `/sync/e2ee` endpoint.
+1
View File
@@ -0,0 +1 @@
Fix `LaterGauge` metrics to collect from all servers.
+1
View File
@@ -0,0 +1 @@
Configure Synapse to run MSC4306: Thread Subscriptions Complement tests.
+1
View File
@@ -0,0 +1 @@
Fix bug where we did not send invite revocations over federation.
+1
View File
@@ -0,0 +1 @@
Update push rules for experimental [MSC4306: Thread Subscriptions](https://github.com/matrix-org/matrix-doc/issues/4306) to follow newer draft.
+1
View File
@@ -0,0 +1 @@
Use the `Enum`'s value for the dictionary key when responding to an admin request for experimental features.
+1
View File
@@ -0,0 +1 @@
Fix prefixed support for MSC4133.
+1
View File
@@ -0,0 +1 @@
Suppress "Applying schema" log noise bulk when `SYNAPSE_LOG_TESTING` is set.
@@ -133,6 +133,8 @@ experimental_features:
msc3984_appservice_key_query: true
# Invite filtering
msc4155_enabled: true
# Thread Subscriptions
msc4306_enabled: true
server_notices:
system_mxid_localpart: _server
+7
View File
@@ -77,6 +77,13 @@ loggers:
#}
synapse.visibility.filtered_event_debug:
level: DEBUG
{#
If Synapse is under test, we don't care about seeing the "Applying schema" log
lines at the INFO level every time we run the tests (it's 100 lines of bulk)
#}
synapse.storage.prepare_database:
level: WARN
{% endif %}
root:
+1 -1
View File
@@ -35,7 +35,7 @@ handlers:
loggers:
synapse:
level: INFO
handlers: [remote]
handlers: [file]
synapse.storage.SQL:
level: WARNING
```
+9 -5
View File
@@ -289,10 +289,10 @@ pub const BASE_APPEND_CONTENT_RULES: &[PushRule] = &[PushRule {
default_enabled: true,
}];
pub const BASE_APPEND_UNDERRIDE_RULES: &[PushRule] = &[
pub const BASE_APPEND_POSTCONTENT_RULES: &[PushRule] = &[
PushRule {
rule_id: Cow::Borrowed("global/content/.io.element.msc4306.rule.unsubscribed_thread"),
priority_class: 1,
rule_id: Cow::Borrowed("global/postcontent/.io.element.msc4306.rule.unsubscribed_thread"),
priority_class: 6,
conditions: Cow::Borrowed(&[Condition::Known(
KnownCondition::Msc4306ThreadSubscription { subscribed: false },
)]),
@@ -301,8 +301,8 @@ pub const BASE_APPEND_UNDERRIDE_RULES: &[PushRule] = &[
default_enabled: true,
},
PushRule {
rule_id: Cow::Borrowed("global/content/.io.element.msc4306.rule.subscribed_thread"),
priority_class: 1,
rule_id: Cow::Borrowed("global/postcontent/.io.element.msc4306.rule.subscribed_thread"),
priority_class: 6,
conditions: Cow::Borrowed(&[Condition::Known(
KnownCondition::Msc4306ThreadSubscription { subscribed: true },
)]),
@@ -310,6 +310,9 @@ pub const BASE_APPEND_UNDERRIDE_RULES: &[PushRule] = &[
default: true,
default_enabled: true,
},
];
pub const BASE_APPEND_UNDERRIDE_RULES: &[PushRule] = &[
PushRule {
rule_id: Cow::Borrowed("global/underride/.m.rule.call"),
priority_class: 1,
@@ -726,6 +729,7 @@ lazy_static! {
.iter()
.chain(BASE_APPEND_OVERRIDE_RULES.iter())
.chain(BASE_APPEND_CONTENT_RULES.iter())
.chain(BASE_APPEND_POSTCONTENT_RULES.iter())
.chain(BASE_APPEND_UNDERRIDE_RULES.iter())
.map(|rule| { (&*rule.rule_id, rule) })
.collect();
+1
View File
@@ -527,6 +527,7 @@ impl PushRules {
.chain(base_rules::BASE_APPEND_OVERRIDE_RULES.iter())
.chain(self.content.iter())
.chain(base_rules::BASE_APPEND_CONTENT_RULES.iter())
.chain(base_rules::BASE_APPEND_POSTCONTENT_RULES.iter())
.chain(self.room.iter())
.chain(self.sender.iter())
.chain(self.underride.iter())
+1
View File
@@ -230,6 +230,7 @@ test_packages=(
./tests/msc3967
./tests/msc4140
./tests/msc4155
./tests/msc4306
)
# Enable dirty runs, so tests will reuse the same container where possible.
+5 -1
View File
@@ -153,9 +153,13 @@ def get_registered_paths_for_default(
"""
hs = MockHomeserver(base_config, worker_app)
# TODO We only do this to avoid an error, but don't need the database etc
hs.setup()
return get_registered_paths_for_hs(hs)
registered_paths = get_registered_paths_for_hs(hs)
hs.cleanup()
return registered_paths
def elide_http_methods_if_unconflicting(
+14 -1
View File
@@ -99,6 +99,7 @@ from synapse.storage.engines import create_engine
from synapse.storage.prepare_database import prepare_database
from synapse.types import ISynapseReactor
from synapse.util import SYNAPSE_VERSION, Clock
from synapse.util.stringutils import random_string
# Cast safety: Twisted does some naughty magic which replaces the
# twisted.internet.reactor module with a Reactor instance at runtime.
@@ -323,6 +324,7 @@ class MockHomeserver:
self.config = config
self.hostname = config.server.server_name
self.version_string = SYNAPSE_VERSION
self.instance_id = random_string(5)
def get_clock(self) -> Clock:
return self.clock
@@ -330,6 +332,9 @@ class MockHomeserver:
def get_reactor(self) -> ISynapseReactor:
return reactor
def get_instance_id(self) -> str:
return self.instance_id
def get_instance_name(self) -> str:
return "master"
@@ -685,7 +690,15 @@ class Porter:
)
prepare_database(db_conn, engine, config=self.hs_config)
# Type safety: ignore that we're using Mock homeservers here.
store = Store(DatabasePool(hs, db_config, engine), db_conn, hs) # type: ignore[arg-type]
store = Store(
DatabasePool(
hs, # type: ignore[arg-type]
db_config,
engine,
),
db_conn,
hs, # type: ignore[arg-type]
)
db_conn.commit()
return store
+42
View File
@@ -116,6 +116,8 @@ class RoomVersion:
msc4289_creator_power_enabled: bool
# MSC4291: Room IDs as hashes of the create event
msc4291_room_ids_as_hashes: bool
# MSC4243: User ID localparts as Account Keys
msc4243_account_keys: bool
class RoomVersions:
@@ -140,6 +142,7 @@ class RoomVersions:
msc3757_enabled=False,
msc4289_creator_power_enabled=False,
msc4291_room_ids_as_hashes=False,
msc4243_account_keys=False,
)
V2 = RoomVersion(
"2",
@@ -162,6 +165,7 @@ class RoomVersions:
msc3757_enabled=False,
msc4289_creator_power_enabled=False,
msc4291_room_ids_as_hashes=False,
msc4243_account_keys=False,
)
V3 = RoomVersion(
"3",
@@ -184,6 +188,7 @@ class RoomVersions:
msc3757_enabled=False,
msc4289_creator_power_enabled=False,
msc4291_room_ids_as_hashes=False,
msc4243_account_keys=False,
)
V4 = RoomVersion(
"4",
@@ -206,6 +211,7 @@ class RoomVersions:
msc3757_enabled=False,
msc4289_creator_power_enabled=False,
msc4291_room_ids_as_hashes=False,
msc4243_account_keys=False,
)
V5 = RoomVersion(
"5",
@@ -228,6 +234,7 @@ class RoomVersions:
msc3757_enabled=False,
msc4289_creator_power_enabled=False,
msc4291_room_ids_as_hashes=False,
msc4243_account_keys=False,
)
V6 = RoomVersion(
"6",
@@ -250,6 +257,7 @@ class RoomVersions:
msc3757_enabled=False,
msc4289_creator_power_enabled=False,
msc4291_room_ids_as_hashes=False,
msc4243_account_keys=False,
)
V7 = RoomVersion(
"7",
@@ -272,6 +280,7 @@ class RoomVersions:
msc3757_enabled=False,
msc4289_creator_power_enabled=False,
msc4291_room_ids_as_hashes=False,
msc4243_account_keys=False,
)
V8 = RoomVersion(
"8",
@@ -294,6 +303,7 @@ class RoomVersions:
msc3757_enabled=False,
msc4289_creator_power_enabled=False,
msc4291_room_ids_as_hashes=False,
msc4243_account_keys=False,
)
V9 = RoomVersion(
"9",
@@ -316,6 +326,7 @@ class RoomVersions:
msc3757_enabled=False,
msc4289_creator_power_enabled=False,
msc4291_room_ids_as_hashes=False,
msc4243_account_keys=False,
)
V10 = RoomVersion(
"10",
@@ -338,6 +349,7 @@ class RoomVersions:
msc3757_enabled=False,
msc4289_creator_power_enabled=False,
msc4291_room_ids_as_hashes=False,
msc4243_account_keys=False,
)
MSC1767v10 = RoomVersion(
# MSC1767 (Extensible Events) based on room version "10"
@@ -361,6 +373,7 @@ class RoomVersions:
msc3757_enabled=False,
msc4289_creator_power_enabled=False,
msc4291_room_ids_as_hashes=False,
msc4243_account_keys=False,
)
MSC3757v10 = RoomVersion(
# MSC3757 (Restricting who can overwrite a state event) based on room version "10"
@@ -384,6 +397,7 @@ class RoomVersions:
msc3757_enabled=True,
msc4289_creator_power_enabled=False,
msc4291_room_ids_as_hashes=False,
msc4243_account_keys=False,
)
V11 = RoomVersion(
"11",
@@ -406,6 +420,7 @@ class RoomVersions:
msc3757_enabled=False,
msc4289_creator_power_enabled=False,
msc4291_room_ids_as_hashes=False,
msc4243_account_keys=False,
)
MSC3757v11 = RoomVersion(
# MSC3757 (Restricting who can overwrite a state event) based on room version "11"
@@ -429,6 +444,7 @@ class RoomVersions:
msc3757_enabled=True,
msc4289_creator_power_enabled=False,
msc4291_room_ids_as_hashes=False,
msc4243_account_keys=False,
)
HydraV11 = RoomVersion(
"org.matrix.hydra.11",
@@ -451,6 +467,7 @@ class RoomVersions:
msc3757_enabled=False,
msc4289_creator_power_enabled=True, # Changed from v11
msc4291_room_ids_as_hashes=True, # Changed from v11
msc4243_account_keys=False,
)
V12 = RoomVersion(
"12",
@@ -473,6 +490,30 @@ class RoomVersions:
msc3757_enabled=False,
msc4289_creator_power_enabled=True, # Changed from v11
msc4291_room_ids_as_hashes=True, # Changed from v11
msc4243_account_keys=False,
)
MSC4243v12 = RoomVersion(
"org.matrix.12.4243",
RoomDisposition.STABLE,
EventFormatVersions.ROOM_V11_HYDRA_PLUS,
StateResolutionVersions.V2_1, # Changed from v11
enforce_key_validity=False, # No longer enforce key validity.
special_case_aliases_auth=False,
strict_canonicaljson=True,
limit_notifications_power_levels=True,
implicit_room_creator=True, # Used by MSC3820
updated_redaction_rules=True, # Used by MSC3820
restricted_join_rule=True,
restricted_join_rule_fix=True,
knock_join_rule=True,
msc3389_relation_redactions=False,
knock_restricted_join_rule=True,
enforce_int_power_levels=True,
msc3931_push_features=(),
msc3757_enabled=False,
msc4289_creator_power_enabled=True,
msc4291_room_ids_as_hashes=True,
msc4243_account_keys=True,
)
@@ -494,6 +535,7 @@ KNOWN_ROOM_VERSIONS: Dict[str, RoomVersion] = {
RoomVersions.MSC3757v10,
RoomVersions.MSC3757v11,
RoomVersions.HydraV11,
RoomVersions.MSC4243v12,
)
}
+44 -1
View File
@@ -47,7 +47,7 @@ from synapse.events import EventBase
from synapse.events.utils import prune_event_dict
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.storage.keys import FetchKeyResult
from synapse.types import JsonDict
from synapse.types import JsonDict, UserID
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import yieldable_gather_results
from synapse.util.batching_queue import BatchingQueue
@@ -83,6 +83,7 @@ class VerifyJsonRequest:
get_json_object: Callable[[], JsonDict]
minimum_valid_until_ts: int
key_ids: List[str]
key_ids_are_public_keys: bool = False
@staticmethod
def from_json_object(
@@ -118,6 +119,7 @@ class VerifyJsonRequest:
lambda: prune_event_dict(event.room_version, event.get_pdu_json()),
minimum_valid_until_ms,
key_ids=key_ids,
key_ids_are_public_keys=False,
)
@@ -265,6 +267,38 @@ class Keyring:
)
)
async def verify_event_for_account_key(
self,
user_id_str: str,
event: EventBase,
) -> None:
"""Verify that the given event has been signed by the provided account key, as determined by
the user_id provided.
Args:
user_id_str: The MSC4243 user ID, consisting of an account key localpart and a domain.
event: The PDU that should be signed with the account key. Room version must support MSC4243.
"""
assert event.room_version.msc4243_account_keys
user_id = UserID.from_string(user_id_str)
key_ids = list(event.signatures.get(user_id.domain, []))
# only keep the key ID that matches the desired user ID we want to verify as.
# Events can be signed by multiple parties e.g invites, restricted joins
expected_key_id = "ed25519:" + user_id.localpart
key_ids = [key_id for key_id in key_ids if key_id == expected_key_id]
assert len(key_ids) == 1 # the user must have signed the event.
await self.process_request(
VerifyJsonRequest(
user_id.domain,
# We defer creating the redacted json object, as it uses a lot more
# memory than the Event object itself.
lambda: prune_event_dict(event.room_version, event.get_pdu_json()),
0, # No validity times
key_ids=key_ids,
key_ids_are_public_keys=True,
)
)
async def process_request(self, verify_request: VerifyJsonRequest) -> None:
"""Processes the `VerifyJsonRequest`. Raises if the object is not signed
by the server, the signatures don't match or we failed to fetch the
@@ -278,6 +312,15 @@ class Keyring:
Codes.UNAUTHORIZED,
)
if verify_request.key_ids_are_public_keys:
# No need to fetch keys as we have them already.
assert len(verify_request.key_ids) == 1
key_id = verify_request.key_ids[0]
key_bytes = decode_base64(key_id.removeprefix("ed25519:"))
verify_key = decode_verify_key_bytes(key_id, key_bytes)
await self._process_json(verify_key, verify_request)
return
found_keys: Dict[str, FetchKeyResult] = {}
# If we are the originating server, short-circuit the key-fetch for any keys
+32 -14
View File
@@ -241,16 +241,28 @@ async def _check_sigs_on_pdu(
sender_domain = get_domain_from_id(pdu.sender)
if not _is_invite_via_3pid(pdu):
try:
await keyring.verify_event_for_server(
sender_domain,
pdu,
pdu.origin_server_ts if room_version.enforce_key_validity else 0,
)
if pdu.room_version.msc4243_account_keys:
await keyring.verify_event_for_account_key(
pdu.sender,
pdu,
)
else:
await keyring.verify_event_for_server(
sender_domain,
pdu,
pdu.origin_server_ts if room_version.enforce_key_validity else 0,
)
except Exception as e:
raise InvalidEventSignatureError(
f"unable to verify signature for sender domain {sender_domain}: {e}",
pdu.event_id,
) from None
if pdu.room_version.msc4243_account_keys:
raise InvalidEventSignatureError(
f"unable to verify signature for account key {pdu.sender}: {e}",
pdu.event_id,
) from None
else:
raise InvalidEventSignatureError(
f"unable to verify signature for sender domain {sender_domain}: {e}",
pdu.event_id,
) from None
# now let's look for events where the sender's domain is different to the
# event id's domain (normally only the case for joins/leaves), and add additional
@@ -283,11 +295,17 @@ async def _check_sigs_on_pdu(
pdu.content[EventContentFields.AUTHORISING_USER]
)
try:
await keyring.verify_event_for_server(
authorising_server,
pdu,
pdu.origin_server_ts if room_version.enforce_key_validity else 0,
)
if pdu.room_version.msc4243_account_keys:
await keyring.verify_event_for_account_key(
pdu.content[EventContentFields.AUTHORISING_USER],
pdu,
)
else:
await keyring.verify_event_for_server(
authorising_server,
pdu,
pdu.origin_server_ts if room_version.enforce_key_validity else 0,
)
except Exception as e:
raise InvalidEventSignatureError(
f"unable to verify signature for authorising serve {authorising_server}: {e}",
+28 -15
View File
@@ -37,6 +37,7 @@ Events are replicated via a separate events stream.
"""
import logging
from enum import Enum
from typing import (
TYPE_CHECKING,
Dict,
@@ -67,6 +68,25 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
class QueueNames(str, Enum):
PRESENCE_MAP = "presence_map"
KEYED_EDU = "keyed_edu"
KEYED_EDU_CHANGED = "keyed_edu_changed"
EDUS = "edus"
POS_TIME = "pos_time"
PRESENCE_DESTINATIONS = "presence_destinations"
queue_name_to_gauge_map: Dict[QueueNames, LaterGauge] = {}
for queue_name in QueueNames:
queue_name_to_gauge_map[queue_name] = LaterGauge(
name=f"synapse_federation_send_queue_{queue_name.value}_size",
desc="",
labelnames=[SERVER_NAME_LABEL],
)
class FederationRemoteSendQueue(AbstractFederationSender):
"""A drop in replacement for FederationSender"""
@@ -111,23 +131,16 @@ class FederationRemoteSendQueue(AbstractFederationSender):
# we make a new function, so we need to make a new function so the inner
# lambda binds to the queue rather than to the name of the queue which
# changes. ARGH.
def register(name: str, queue: Sized) -> None:
LaterGauge(
name="synapse_federation_send_queue_%s_size" % (queue_name,),
desc="",
labelnames=[SERVER_NAME_LABEL],
caller=lambda: {(self.server_name,): len(queue)},
def register(queue_name: QueueNames, queue: Sized) -> None:
queue_name_to_gauge_map[queue_name].register_hook(
homeserver_instance_id=hs.get_instance_id(),
hook=lambda: {(self.server_name,): len(queue)},
)
for queue_name in [
"presence_map",
"keyed_edu",
"keyed_edu_changed",
"edus",
"pos_time",
"presence_destinations",
]:
register(queue_name, getattr(self, queue_name))
for queue_name in QueueNames:
queue = getattr(self, queue_name.value)
assert isinstance(queue, Sized)
register(queue_name, queue=queue)
self.clock.looping_call(self._clear_queue, 30 * 1000)
+53 -16
View File
@@ -150,6 +150,7 @@ from prometheus_client import Counter
from twisted.internet import defer
import synapse.metrics
from synapse.api.constants import EventTypes, Membership
from synapse.api.presence import UserPresenceState
from synapse.events import EventBase
from synapse.federation.sender.per_destination_queue import (
@@ -199,6 +200,24 @@ sent_pdus_destination_dist_total = Counter(
labelnames=[SERVER_NAME_LABEL],
)
transaction_queue_pending_destinations_gauge = LaterGauge(
name="synapse_federation_transaction_queue_pending_destinations",
desc="",
labelnames=[SERVER_NAME_LABEL],
)
transaction_queue_pending_pdus_gauge = LaterGauge(
name="synapse_federation_transaction_queue_pending_pdus",
desc="",
labelnames=[SERVER_NAME_LABEL],
)
transaction_queue_pending_edus_gauge = LaterGauge(
name="synapse_federation_transaction_queue_pending_edus",
desc="",
labelnames=[SERVER_NAME_LABEL],
)
# Time (in s) to wait before trying to wake up destinations that have
# catch-up outstanding.
# Please note that rate limiting still applies, so while the loop is
@@ -398,11 +417,9 @@ class FederationSender(AbstractFederationSender):
# map from destination to PerDestinationQueue
self._per_destination_queues: Dict[str, PerDestinationQueue] = {}
LaterGauge(
name="synapse_federation_transaction_queue_pending_destinations",
desc="",
labelnames=[SERVER_NAME_LABEL],
caller=lambda: {
transaction_queue_pending_destinations_gauge.register_hook(
homeserver_instance_id=hs.get_instance_id(),
hook=lambda: {
(self.server_name,): sum(
1
for d in self._per_destination_queues.values()
@@ -410,22 +427,17 @@ class FederationSender(AbstractFederationSender):
)
},
)
LaterGauge(
name="synapse_federation_transaction_queue_pending_pdus",
desc="",
labelnames=[SERVER_NAME_LABEL],
caller=lambda: {
transaction_queue_pending_pdus_gauge.register_hook(
homeserver_instance_id=hs.get_instance_id(),
hook=lambda: {
(self.server_name,): sum(
d.pending_pdu_count() for d in self._per_destination_queues.values()
)
},
)
LaterGauge(
name="synapse_federation_transaction_queue_pending_edus",
desc="",
labelnames=[SERVER_NAME_LABEL],
caller=lambda: {
transaction_queue_pending_edus_gauge.register_hook(
homeserver_instance_id=hs.get_instance_id(),
hook=lambda: {
(self.server_name,): sum(
d.pending_edu_count() for d in self._per_destination_queues.values()
)
@@ -644,6 +656,31 @@ class FederationSender(AbstractFederationSender):
)
return
# If we've rescinded an invite then we want to tell the
# other server.
if (
event.type == EventTypes.Member
and event.membership == Membership.LEAVE
and event.sender != event.state_key
):
# We check if this leave event is rescinding an invite
# by looking if there is an invite event for the user in
# the auth events. It could otherwise be a kick or
# unban, which we don't want to send (if the user wasn't
# already in the room).
auth_events = await self.store.get_events_as_list(
event.auth_event_ids()
)
for auth_event in auth_events:
if (
auth_event.type == EventTypes.Member
and auth_event.state_key == event.state_key
and auth_event.membership == Membership.INVITE
):
destinations = set(destinations)
destinations.add(get_domain_from_id(event.state_key))
break
sharded_destinations = {
d
for d in destinations
+41 -3
View File
@@ -248,9 +248,10 @@ class FederationEventHandler:
self.room_queues[room_id].append((pdu, origin))
return
# If we're not in the room just ditch the event entirely. This is
# probably an old server that has come back and thinks we're still in
# the room (or we've been rejoined to the room by a state reset).
# If we're not in the room just ditch the event entirely (and not
# invited). This is probably an old server that has come back and thinks
# we're still in the room (or we've been rejoined to the room by a state
# reset).
#
# Note that if we were never in the room then we would have already
# dropped the event, since we wouldn't know the room version.
@@ -258,6 +259,43 @@ class FederationEventHandler:
room_id, self.server_name
)
if not is_in_room:
# Check if this is a leave event rescinding an invite
if (
pdu.type == EventTypes.Member
and pdu.membership == Membership.LEAVE
and pdu.state_key != pdu.sender
and self._is_mine_id(pdu.state_key)
):
(
membership,
membership_event_id,
) = await self._store.get_local_current_membership_for_user_in_room(
pdu.state_key, pdu.room_id
)
if (
membership == Membership.INVITE
and membership_event_id
and membership_event_id
in pdu.auth_event_ids() # The invite should be in the auth events of the rescission.
):
invite_event = await self._store.get_event(
membership_event_id, allow_none=True
)
# We cannot fully auth the rescission event, but we can
# check if the sender of the leave event is the same as the
# invite.
#
# Technically, a room admin could rescind the invite, but we
# have no way of knowing who is and isn't a room admin.
if invite_event and pdu.sender == invite_event.sender:
# Handle the rescission event
pdu.internal_metadata.outlier = True
pdu.internal_metadata.out_of_band_membership = True
context = EventContext.for_outlier(self._storage_controllers)
await self.persist_events_and_notify(room_id, [(pdu, context)])
return
logger.info(
"Ignoring PDU from %s as we're not in the room",
origin,
+18 -10
View File
@@ -173,6 +173,18 @@ state_transition_counter = Counter(
labelnames=["locality", "from", "to", SERVER_NAME_LABEL],
)
presence_user_to_current_state_size_gauge = LaterGauge(
name="synapse_handlers_presence_user_to_current_state_size",
desc="",
labelnames=[SERVER_NAME_LABEL],
)
presence_wheel_timer_size_gauge = LaterGauge(
name="synapse_handlers_presence_wheel_timer_size",
desc="",
labelnames=[SERVER_NAME_LABEL],
)
# If a user was last active in the last LAST_ACTIVE_GRANULARITY, consider them
# "currently_active"
LAST_ACTIVE_GRANULARITY = 60 * 1000
@@ -779,11 +791,9 @@ class PresenceHandler(BasePresenceHandler):
EduTypes.PRESENCE, self.incoming_presence
)
LaterGauge(
name="synapse_handlers_presence_user_to_current_state_size",
desc="",
labelnames=[SERVER_NAME_LABEL],
caller=lambda: {(self.server_name,): len(self.user_to_current_state)},
presence_user_to_current_state_size_gauge.register_hook(
homeserver_instance_id=hs.get_instance_id(),
hook=lambda: {(self.server_name,): len(self.user_to_current_state)},
)
# The per-device presence state, maps user to devices to per-device presence state.
@@ -882,11 +892,9 @@ class PresenceHandler(BasePresenceHandler):
60 * 1000,
)
LaterGauge(
name="synapse_handlers_presence_wheel_timer_size",
desc="",
labelnames=[SERVER_NAME_LABEL],
caller=lambda: {(self.server_name,): len(self.wheel_timer)},
presence_wheel_timer_size_gauge.register_hook(
homeserver_instance_id=hs.get_instance_id(),
hook=lambda: {(self.server_name,): len(self.wheel_timer)},
)
# Used to handle sending of presence to newly joined users/servers
+14 -271
View File
@@ -20,7 +20,6 @@
#
import itertools
import logging
from enum import Enum
from typing import (
TYPE_CHECKING,
AbstractSet,
@@ -28,14 +27,11 @@ from typing import (
Dict,
FrozenSet,
List,
Literal,
Mapping,
Optional,
Sequence,
Set,
Tuple,
Union,
overload,
)
import attr
@@ -120,25 +116,6 @@ LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE = 100
SyncRequestKey = Tuple[Any, ...]
class SyncVersion(Enum):
"""
Enum for specifying the version of sync request. This is used to key which type of
sync response that we are generating.
This is different than the `sync_type` you might see used in other code below; which
specifies the sub-type sync request (e.g. initial_sync, full_state_sync,
incremental_sync) and is really only relevant for the `/sync` v2 endpoint.
"""
# These string values are semantically significant because they are used in the the
# metrics
# Traditional `/sync` endpoint
SYNC_V2 = "sync_v2"
# Part of MSC3575 Sliding Sync
E2EE_SYNC = "e2ee_sync"
@attr.s(slots=True, frozen=True, auto_attribs=True)
class SyncConfig:
user: UserID
@@ -308,26 +285,6 @@ class SyncResult:
)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class E2eeSyncResult:
"""
Attributes:
next_batch: Token for the next sync
to_device: List of direct messages for the device.
device_lists: List of user_ids whose devices have changed
device_one_time_keys_count: Dict of algorithm to count for one time keys
for this device
device_unused_fallback_key_types: List of key types that have an unused fallback
key
"""
next_batch: StreamToken
to_device: List[JsonDict]
device_lists: DeviceListUpdates
device_one_time_keys_count: JsonMapping
device_unused_fallback_key_types: List[str]
class SyncHandler:
def __init__(self, hs: "HomeServer"):
self.server_name = hs.hostname
@@ -373,52 +330,15 @@ class SyncHandler:
self.rooms_to_exclude_globally = hs.config.server.rooms_to_exclude_from_sync
@overload
async def wait_for_sync_for_user(
self,
requester: Requester,
sync_config: SyncConfig,
sync_version: Literal[SyncVersion.SYNC_V2],
request_key: SyncRequestKey,
since_token: Optional[StreamToken] = None,
timeout: int = 0,
full_state: bool = False,
) -> SyncResult: ...
@overload
async def wait_for_sync_for_user(
self,
requester: Requester,
sync_config: SyncConfig,
sync_version: Literal[SyncVersion.E2EE_SYNC],
request_key: SyncRequestKey,
since_token: Optional[StreamToken] = None,
timeout: int = 0,
full_state: bool = False,
) -> E2eeSyncResult: ...
@overload
async def wait_for_sync_for_user(
self,
requester: Requester,
sync_config: SyncConfig,
sync_version: SyncVersion,
request_key: SyncRequestKey,
since_token: Optional[StreamToken] = None,
timeout: int = 0,
full_state: bool = False,
) -> Union[SyncResult, E2eeSyncResult]: ...
async def wait_for_sync_for_user(
self,
requester: Requester,
sync_config: SyncConfig,
sync_version: SyncVersion,
request_key: SyncRequestKey,
since_token: Optional[StreamToken] = None,
timeout: int = 0,
full_state: bool = False,
) -> Union[SyncResult, E2eeSyncResult]:
) -> SyncResult:
"""Get the sync for a client if we have new data for it now. Otherwise
wait for new data to arrive on the server. If the timeout expires, then
return an empty sync result.
@@ -433,8 +353,7 @@ class SyncHandler:
full_state: Whether to return the full state for each room.
Returns:
When `SyncVersion.SYNC_V2`, returns a full `SyncResult`.
When `SyncVersion.E2EE_SYNC`, returns a `E2eeSyncResult`.
returns a full `SyncResult`.
"""
# If the user is not part of the mau group, then check that limits have
# not been exceeded (if not part of the group by this point, almost certain
@@ -446,7 +365,6 @@ class SyncHandler:
request_key,
self._wait_for_sync_for_user,
sync_config,
sync_version,
since_token,
timeout,
full_state,
@@ -455,48 +373,14 @@ class SyncHandler:
logger.debug("Returning sync response for %s", user_id)
return res
@overload
async def _wait_for_sync_for_user(
self,
sync_config: SyncConfig,
sync_version: Literal[SyncVersion.SYNC_V2],
since_token: Optional[StreamToken],
timeout: int,
full_state: bool,
cache_context: ResponseCacheContext[SyncRequestKey],
) -> SyncResult: ...
@overload
async def _wait_for_sync_for_user(
self,
sync_config: SyncConfig,
sync_version: Literal[SyncVersion.E2EE_SYNC],
since_token: Optional[StreamToken],
timeout: int,
full_state: bool,
cache_context: ResponseCacheContext[SyncRequestKey],
) -> E2eeSyncResult: ...
@overload
async def _wait_for_sync_for_user(
self,
sync_config: SyncConfig,
sync_version: SyncVersion,
since_token: Optional[StreamToken],
timeout: int,
full_state: bool,
cache_context: ResponseCacheContext[SyncRequestKey],
) -> Union[SyncResult, E2eeSyncResult]: ...
async def _wait_for_sync_for_user(
self,
sync_config: SyncConfig,
sync_version: SyncVersion,
since_token: Optional[StreamToken],
timeout: int,
full_state: bool,
cache_context: ResponseCacheContext[SyncRequestKey],
) -> Union[SyncResult, E2eeSyncResult]:
) -> SyncResult:
"""The start of the machinery that produces a /sync response.
See https://spec.matrix.org/v1.1/client-server-api/#syncing for full details.
@@ -517,7 +401,7 @@ class SyncHandler:
else:
sync_type = "incremental_sync"
sync_label = f"{sync_version}:{sync_type}"
sync_label = f"sync_v2:{sync_type}"
context = current_context()
if context:
@@ -578,19 +462,15 @@ class SyncHandler:
if timeout == 0 or since_token is None or full_state:
# we are going to return immediately, so don't bother calling
# notifier.wait_for_events.
result: Union[
SyncResult, E2eeSyncResult
] = await self.current_sync_for_user(
sync_config, sync_version, since_token, full_state=full_state
result = await self.current_sync_for_user(
sync_config, since_token, full_state=full_state
)
else:
# Otherwise, we wait for something to happen and report it to the user.
async def current_sync_callback(
before_token: StreamToken, after_token: StreamToken
) -> Union[SyncResult, E2eeSyncResult]:
return await self.current_sync_for_user(
sync_config, sync_version, since_token
)
) -> SyncResult:
return await self.current_sync_for_user(sync_config, since_token)
result = await self.notifier.wait_for_events(
sync_config.user.to_string(),
@@ -623,43 +503,15 @@ class SyncHandler:
return result
@overload
async def current_sync_for_user(
self,
sync_config: SyncConfig,
sync_version: Literal[SyncVersion.SYNC_V2],
since_token: Optional[StreamToken] = None,
full_state: bool = False,
) -> SyncResult: ...
@overload
async def current_sync_for_user(
self,
sync_config: SyncConfig,
sync_version: Literal[SyncVersion.E2EE_SYNC],
since_token: Optional[StreamToken] = None,
full_state: bool = False,
) -> E2eeSyncResult: ...
@overload
async def current_sync_for_user(
self,
sync_config: SyncConfig,
sync_version: SyncVersion,
since_token: Optional[StreamToken] = None,
full_state: bool = False,
) -> Union[SyncResult, E2eeSyncResult]: ...
async def current_sync_for_user(
self,
sync_config: SyncConfig,
sync_version: SyncVersion,
since_token: Optional[StreamToken] = None,
full_state: bool = False,
) -> Union[SyncResult, E2eeSyncResult]:
) -> SyncResult:
"""
Generates the response body of a sync result, represented as a
`SyncResult`/`E2eeSyncResult`.
`SyncResult`.
This is a wrapper around `generate_sync_result` which starts an open tracing
span to track the sync. See `generate_sync_result` for the next part of your
@@ -672,28 +524,15 @@ class SyncHandler:
full_state: Whether to return the full state for each room.
Returns:
When `SyncVersion.SYNC_V2`, returns a full `SyncResult`.
When `SyncVersion.E2EE_SYNC`, returns a `E2eeSyncResult`.
returns a full `SyncResult`.
"""
with start_active_span("sync.current_sync_for_user"):
log_kv({"since_token": since_token})
# Go through the `/sync` v2 path
if sync_version == SyncVersion.SYNC_V2:
sync_result: Union[
SyncResult, E2eeSyncResult
] = await self.generate_sync_result(
sync_config, since_token, full_state
)
# Go through the MSC3575 Sliding Sync `/sync/e2ee` path
elif sync_version == SyncVersion.E2EE_SYNC:
sync_result = await self.generate_e2ee_sync_result(
sync_config, since_token
)
else:
raise Exception(
f"Unknown sync_version (this is a Synapse problem): {sync_version}"
)
sync_result = await self.generate_sync_result(
sync_config, since_token, full_state
)
set_tag(SynapseTags.SYNC_RESULT, bool(sync_result))
return sync_result
@@ -1968,102 +1807,6 @@ class SyncHandler:
next_batch=sync_result_builder.now_token,
)
async def generate_e2ee_sync_result(
self,
sync_config: SyncConfig,
since_token: Optional[StreamToken] = None,
) -> E2eeSyncResult:
"""
Generates the response body of a MSC3575 Sliding Sync `/sync/e2ee` result.
This is represented by a `E2eeSyncResult` struct, which is built from small
pieces using a `SyncResultBuilder`. The `sync_result_builder` is passed as a
mutable ("inout") parameter to various helper functions. These retrieve and
process the data which forms the sync body, often writing to the
`sync_result_builder` to store their output.
At the end, we transfer data from the `sync_result_builder` to a new `E2eeSyncResult`
instance to signify that the sync calculation is complete.
"""
user_id = sync_config.user.to_string()
app_service = self.store.get_app_service_by_user_id(user_id)
if app_service:
# We no longer support AS users using /sync directly.
# See https://github.com/matrix-org/matrix-doc/issues/1144
raise NotImplementedError()
sync_result_builder = await self.get_sync_result_builder(
sync_config,
since_token,
full_state=False,
)
# 1. Calculate `to_device` events
await self._generate_sync_entry_for_to_device(sync_result_builder)
# 2. Calculate `device_lists`
# Device list updates are sent if a since token is provided.
device_lists = DeviceListUpdates()
include_device_list_updates = bool(since_token and since_token.device_list_key)
if include_device_list_updates:
# Note that _generate_sync_entry_for_rooms sets sync_result_builder.joined, which
# is used in calculate_user_changes below.
#
# TODO: Running `_generate_sync_entry_for_rooms()` is a lot of work just to
# figure out the membership changes/derived info needed for
# `_generate_sync_entry_for_device_list()`. In the future, we should try to
# refactor this away.
(
newly_joined_rooms,
newly_left_rooms,
) = await self._generate_sync_entry_for_rooms(sync_result_builder)
# This uses the sync_result_builder.joined which is set in
# `_generate_sync_entry_for_rooms`, if that didn't find any joined
# rooms for some reason it is a no-op.
(
newly_joined_or_invited_or_knocked_users,
newly_left_users,
) = sync_result_builder.calculate_user_changes()
# include_device_list_updates can only be True if we have a
# since token.
assert since_token is not None
device_lists = await self._device_handler.generate_sync_entry_for_device_list(
user_id=user_id,
since_token=since_token,
now_token=sync_result_builder.now_token,
joined_room_ids=sync_result_builder.joined_room_ids,
newly_joined_rooms=newly_joined_rooms,
newly_joined_or_invited_or_knocked_users=newly_joined_or_invited_or_knocked_users,
newly_left_rooms=newly_left_rooms,
newly_left_users=newly_left_users,
)
# 3. Calculate `device_one_time_keys_count` and `device_unused_fallback_key_types`
device_id = sync_config.device_id
one_time_keys_count: JsonMapping = {}
unused_fallback_key_types: List[str] = []
if device_id:
# TODO: We should have a way to let clients differentiate between the states of:
# * no change in OTK count since the provided since token
# * the server has zero OTKs left for this device
# Spec issue: https://github.com/matrix-org/matrix-doc/issues/3298
one_time_keys_count = await self.store.count_e2e_one_time_keys(
user_id, device_id
)
unused_fallback_key_types = list(
await self.store.get_e2e_unused_fallback_key_types(user_id, device_id)
)
return E2eeSyncResult(
to_device=sync_result_builder.to_device,
device_lists=device_lists,
device_one_time_keys_count=one_time_keys_count,
device_unused_fallback_key_types=unused_fallback_key_types,
next_batch=sync_result_builder.now_token,
)
async def get_sync_result_builder(
self,
sync_config: SyncConfig,
+4 -2
View File
@@ -164,11 +164,13 @@ def _get_in_flight_counts() -> Mapping[Tuple[str, ...], int]:
return counts
LaterGauge(
in_flight_requests = LaterGauge(
name="synapse_http_server_in_flight_requests_count",
desc="",
labelnames=["method", "servlet", SERVER_NAME_LABEL],
caller=_get_in_flight_counts,
)
in_flight_requests.register_hook(
homeserver_instance_id=None, hook=_get_in_flight_counts
)
+93 -35
View File
@@ -73,8 +73,6 @@ logger = logging.getLogger(__name__)
METRICS_PREFIX = "/_synapse/metrics"
all_gauges: Dict[str, Collector] = {}
HAVE_PROC_SELF_STAT = os.path.exists("/proc/self/stat")
SERVER_NAME_LABEL = "server_name"
@@ -163,42 +161,110 @@ class LaterGauge(Collector):
name: str
desc: str
labelnames: Optional[StrSequence] = attr.ib(hash=False)
# callback: should either return a value (if there are no labels for this metric),
# or dict mapping from a label tuple to a value
caller: Callable[
[], Union[Mapping[Tuple[str, ...], Union[int, float]], Union[int, float]]
]
_instance_id_to_hook_map: Dict[
Optional[str], # instance_id
Callable[
[], Union[Mapping[Tuple[str, ...], Union[int, float]], Union[int, float]]
],
] = attr.ib(factory=dict, hash=False)
"""
Map from homeserver instance_id to a callback. Each callback should either return a
value (if there are no labels for this metric), or dict mapping from a label tuple
to a value.
We use `instance_id` instead of `server_name` because it's possible to have multiple
workers running in the same process with the same `server_name`.
"""
def collect(self) -> Iterable[Metric]:
# The decision to add `SERVER_NAME_LABEL` is from the `LaterGauge` usage itself
# (we don't enforce it here, one level up).
g = GaugeMetricFamily(self.name, self.desc, labels=self.labelnames) # type: ignore[missing-server-name-label]
try:
calls = self.caller()
except Exception:
logger.exception("Exception running callback for LaterGauge(%s)", self.name)
yield g
return
for homeserver_instance_id, hook in self._instance_id_to_hook_map.items():
try:
hook_result = hook()
except Exception:
logger.exception(
"Exception running callback for LaterGauge(%s) for homeserver_instance_id=%s",
self.name,
homeserver_instance_id,
)
# Continue to return the rest of the metrics that aren't broken
continue
if isinstance(calls, (int, float)):
g.add_metric([], calls)
else:
for k, v in calls.items():
g.add_metric(k, v)
if isinstance(hook_result, (int, float)):
g.add_metric([], hook_result)
else:
for k, v in hook_result.items():
g.add_metric(k, v)
yield g
def register_hook(
self,
*,
homeserver_instance_id: Optional[str],
hook: Callable[
[], Union[Mapping[Tuple[str, ...], Union[int, float]], Union[int, float]]
],
) -> None:
"""
Register a callback/hook that will be called to generate a metric samples for
the gauge.
Args:
homeserver_instance_id: The unique ID for this Synapse process instance
(`hs.get_instance_id()`) that this hook is associated with. This can be used
later to lookup all hooks associated with a given server name in order to
unregister them. This should only be omitted for global hooks that work
across all homeservers.
hook: A callback that should either return a value (if there are no
labels for this metric), or dict mapping from a label tuple to a value
"""
# We shouldn't have multiple hooks registered for the same homeserver `instance_id`.
existing_hook = self._instance_id_to_hook_map.get(homeserver_instance_id)
assert existing_hook is None, (
f"LaterGauge(name={self.name}) hook already registered for homeserver_instance_id={homeserver_instance_id}. "
"This is likely a Synapse bug and you forgot to unregister the previous hooks for "
"the server (especially in tests)."
)
self._instance_id_to_hook_map[homeserver_instance_id] = hook
def unregister_hooks_for_homeserver_instance_id(
self, homeserver_instance_id: str
) -> None:
"""
Unregister all hooks associated with the given homeserver `instance_id`. This should be
called when a homeserver is shutdown to avoid extra hooks sitting around.
Args:
homeserver_instance_id: The unique ID for this Synapse process instance to
unregister hooks for (`hs.get_instance_id()`).
"""
self._instance_id_to_hook_map.pop(homeserver_instance_id, None)
def __attrs_post_init__(self) -> None:
self._register()
def _register(self) -> None:
if self.name in all_gauges.keys():
logger.warning("%s already registered, reregistering", self.name)
REGISTRY.unregister(all_gauges.pop(self.name))
REGISTRY.register(self)
all_gauges[self.name] = self
# We shouldn't have multiple metrics with the same name. Typically, metrics
# should be created globally so you shouldn't be running into this and this will
# catch any stupid mistakes. The `REGISTRY.register(self)` call above will also
# raise an error if the metric already exists but to make things explicit, we'll
# also check here.
existing_gauge = all_later_gauges_to_clean_up_on_shutdown.get(self.name)
assert existing_gauge is None, f"LaterGauge(name={self.name}) already exists. "
# Keep track of the gauge so we can clean it up later.
all_later_gauges_to_clean_up_on_shutdown[self.name] = self
all_later_gauges_to_clean_up_on_shutdown: Dict[str, LaterGauge] = {}
"""
Track all `LaterGauge` instances so we can remove any associated hooks during homeserver
shutdown.
"""
# `MetricsEntry` only makes sense when it is a `Protocol`,
@@ -250,7 +316,7 @@ class InFlightGauge(Generic[MetricsEntry], Collector):
# Protects access to _registrations
self._lock = threading.Lock()
self._register_with_collector()
REGISTRY.register(self)
def register(
self,
@@ -341,14 +407,6 @@ class InFlightGauge(Generic[MetricsEntry], Collector):
gauge.add_metric(labels=key, value=getattr(metrics, name))
yield gauge
def _register_with_collector(self) -> None:
if self.name in all_gauges.keys():
logger.warning("%s already registered, reregistering", self.name)
REGISTRY.unregister(all_gauges.pop(self.name))
REGISTRY.register(self)
all_gauges[self.name] = self
class GaugeHistogramMetricFamilyWithLabels(GaugeHistogramMetricFamily):
"""
+26 -16
View File
@@ -86,6 +86,24 @@ users_woken_by_stream_counter = Counter(
labelnames=["stream", SERVER_NAME_LABEL],
)
notifier_listeners_gauge = LaterGauge(
name="synapse_notifier_listeners",
desc="",
labelnames=[SERVER_NAME_LABEL],
)
notifier_rooms_gauge = LaterGauge(
name="synapse_notifier_rooms",
desc="",
labelnames=[SERVER_NAME_LABEL],
)
notifier_users_gauge = LaterGauge(
name="synapse_notifier_users",
desc="",
labelnames=[SERVER_NAME_LABEL],
)
T = TypeVar("T")
@@ -281,28 +299,20 @@ class Notifier:
)
}
LaterGauge(
name="synapse_notifier_listeners",
desc="",
labelnames=[SERVER_NAME_LABEL],
caller=count_listeners,
notifier_listeners_gauge.register_hook(
homeserver_instance_id=hs.get_instance_id(), hook=count_listeners
)
LaterGauge(
name="synapse_notifier_rooms",
desc="",
labelnames=[SERVER_NAME_LABEL],
caller=lambda: {
notifier_rooms_gauge.register_hook(
homeserver_instance_id=hs.get_instance_id(),
hook=lambda: {
(self.server_name,): count(
bool, list(self.room_to_user_streams.values())
)
},
)
LaterGauge(
name="synapse_notifier_users",
desc="",
labelnames=[SERVER_NAME_LABEL],
caller=lambda: {(self.server_name,): len(self.user_to_user_stream)},
notifier_users_gauge.register_hook(
homeserver_instance_id=hs.get_instance_id(),
hook=lambda: {(self.server_name,): len(self.user_to_user_stream)},
)
def add_replication_callback(self, cb: Callable[[], None]) -> None:
+1 -1
View File
@@ -91,7 +91,7 @@ def _rule_to_template(rule: PushRule) -> Optional[Dict[str, Any]]:
unscoped_rule_id = _rule_id_from_namespaced(rule.rule_id)
template_name = _priority_class_to_template_name(rule.priority_class)
if template_name in ["override", "underride"]:
if template_name in ["override", "underride", "postcontent"]:
templaterule = {"conditions": rule.conditions, "actions": rule.actions}
elif template_name in ["sender", "room"]:
templaterule = {"actions": rule.actions}
+4
View File
@@ -19,10 +19,14 @@
#
#
# Integer literals for push rule `kind`s
# This is used to store them in the database.
PRIORITY_CLASS_MAP = {
"underride": 1,
"sender": 2,
"room": 3,
# MSC4306
"postcontent": 6,
"content": 4,
"override": 5,
}
+18 -10
View File
@@ -106,6 +106,18 @@ user_ip_cache_counter = Counter(
"synapse_replication_tcp_resource_user_ip_cache", "", labelnames=[SERVER_NAME_LABEL]
)
tcp_resource_total_connections_gauge = LaterGauge(
name="synapse_replication_tcp_resource_total_connections",
desc="",
labelnames=[SERVER_NAME_LABEL],
)
tcp_command_queue_gauge = LaterGauge(
name="synapse_replication_tcp_command_queue",
desc="Number of inbound RDATA/POSITION commands queued for processing",
labelnames=["stream_name", SERVER_NAME_LABEL],
)
# the type of the entries in _command_queues_by_stream
_StreamCommandQueue = Deque[
@@ -243,11 +255,9 @@ class ReplicationCommandHandler:
# outgoing replication commands to.)
self._connections: List[IReplicationConnection] = []
LaterGauge(
name="synapse_replication_tcp_resource_total_connections",
desc="",
labelnames=[SERVER_NAME_LABEL],
caller=lambda: {(self.server_name,): len(self._connections)},
tcp_resource_total_connections_gauge.register_hook(
homeserver_instance_id=hs.get_instance_id(),
hook=lambda: {(self.server_name,): len(self._connections)},
)
# When POSITION or RDATA commands arrive, we stick them in a queue and process
@@ -266,11 +276,9 @@ class ReplicationCommandHandler:
# from that connection.
self._streams_by_connection: Dict[IReplicationConnection, Set[str]] = {}
LaterGauge(
name="synapse_replication_tcp_command_queue",
desc="Number of inbound RDATA/POSITION commands queued for processing",
labelnames=["stream_name", SERVER_NAME_LABEL],
caller=lambda: {
tcp_command_queue_gauge.register_hook(
homeserver_instance_id=hs.get_instance_id(),
hook=lambda: {
(stream_name, self.server_name): len(queue)
for stream_name, queue in self._command_queues_by_stream.items()
},
+16 -4
View File
@@ -527,7 +527,10 @@ pending_commands = LaterGauge(
name="synapse_replication_tcp_protocol_pending_commands",
desc="",
labelnames=["name", SERVER_NAME_LABEL],
caller=lambda: {
)
pending_commands.register_hook(
homeserver_instance_id=None,
hook=lambda: {
(p.name, p.server_name): len(p.pending_commands) for p in connected_connections
},
)
@@ -544,7 +547,10 @@ transport_send_buffer = LaterGauge(
name="synapse_replication_tcp_protocol_transport_send_buffer",
desc="",
labelnames=["name", SERVER_NAME_LABEL],
caller=lambda: {
)
transport_send_buffer.register_hook(
homeserver_instance_id=None,
hook=lambda: {
(p.name, p.server_name): transport_buffer_size(p) for p in connected_connections
},
)
@@ -571,7 +577,10 @@ tcp_transport_kernel_send_buffer = LaterGauge(
name="synapse_replication_tcp_protocol_transport_kernel_send_buffer",
desc="",
labelnames=["name", SERVER_NAME_LABEL],
caller=lambda: {
)
tcp_transport_kernel_send_buffer.register_hook(
homeserver_instance_id=None,
hook=lambda: {
(p.name, p.server_name): transport_kernel_read_buffer_size(p, False)
for p in connected_connections
},
@@ -582,7 +591,10 @@ tcp_transport_kernel_read_buffer = LaterGauge(
name="synapse_replication_tcp_protocol_transport_kernel_read_buffer",
desc="",
labelnames=["name", SERVER_NAME_LABEL],
caller=lambda: {
)
tcp_transport_kernel_read_buffer.register_hook(
homeserver_instance_id=None,
hook=lambda: {
(p.name, p.server_name): transport_kernel_read_buffer_size(p, True)
for p in connected_connections
},
+2 -2
View File
@@ -92,9 +92,9 @@ class ExperimentalFeaturesRestServlet(RestServlet):
user_features = {}
for feature in ExperimentalFeature:
if feature in enabled_features:
user_features[feature] = True
user_features[feature.value] = True
else:
user_features[feature] = False
user_features[feature.value] = False
return HTTPStatus.OK, {"features": user_features}
async def on_PUT(
+6
View File
@@ -109,6 +109,12 @@ class ProfileFieldRestServlet(RestServlet):
self.hs = hs
self.profile_handler = hs.get_profile_handler()
self.auth = hs.get_auth()
if hs.config.experimental.msc4133_enabled:
self.PATTERNS.append(
re.compile(
r"^/_matrix/client/unstable/uk\.tcpip\.msc4133/profile/(?P<user_id>[^/]*)/(?P<field_name>[^/]*)"
)
)
async def on_GET(
self, request: SynapseRequest, user_id: str, field_name: str
+11
View File
@@ -19,9 +19,11 @@
#
#
from http import HTTPStatus
from typing import TYPE_CHECKING, List, Tuple, Union
from synapse.api.errors import (
Codes,
NotFoundError,
StoreError,
SynapseError,
@@ -239,6 +241,15 @@ def _rule_spec_from_path(path: List[str]) -> RuleSpec:
def _rule_tuple_from_request_object(
rule_template: str, rule_id: str, req_obj: JsonDict
) -> Tuple[List[JsonDict], List[Union[str, JsonDict]]]:
if rule_template == "postcontent":
# postcontent is from MSC4306, which says that clients
# cannot create their own postcontent rules right now.
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"user-defined rules using `postcontent` are not accepted",
errcode=Codes.INVALID_PARAM,
)
if rule_template in ["override", "underride"]:
if "conditions" not in req_obj:
raise InvalidRuleException("Missing 'conditions'")
-174
View File
@@ -42,7 +42,6 @@ from synapse.handlers.sync import (
KnockedSyncResult,
SyncConfig,
SyncResult,
SyncVersion,
)
from synapse.http.server import HttpServer
from synapse.http.servlet import (
@@ -267,7 +266,6 @@ class SyncRestServlet(RestServlet):
sync_result = await self.sync_handler.wait_for_sync_for_user(
requester,
sync_config,
SyncVersion.SYNC_V2,
request_key,
since_token=since_token,
timeout=timeout,
@@ -632,177 +630,6 @@ class SyncRestServlet(RestServlet):
return result
class SlidingSyncE2eeRestServlet(RestServlet):
"""
API endpoint for MSC3575 Sliding Sync `/sync/e2ee`. This is being introduced as part
of Sliding Sync but doesn't have any sliding window component. It's just a way to
get E2EE events without having to sit through a big initial sync (`/sync` v2). And
we can avoid encryption events being backed up by the main sync response.
Having To-Device messages split out to this sync endpoint also helps when clients
need to have 2 or more sync streams open at a time, e.g a push notification process
and a main process. This can cause the two processes to race to fetch the To-Device
events, resulting in the need for complex synchronisation rules to ensure the token
is correctly and atomically exchanged between processes.
GET parameters::
timeout(int): How long to wait for new events in milliseconds.
since(batch_token): Batch token when asking for incremental deltas.
Response JSON::
{
"next_batch": // batch token for the next /sync
"to_device": {
// list of to-device events
"events": [
{
"content: { "algorithm": "m.olm.v1.curve25519-aes-sha2", "ciphertext": { ... }, "org.matrix.msgid": "abcd", "session_id": "abcd" },
"type": "m.room.encrypted",
"sender": "@alice:example.com",
}
// ...
]
},
"device_lists": {
"changed": ["@alice:example.com"],
"left": ["@bob:example.com"]
},
"device_one_time_keys_count": {
"signed_curve25519": 50
},
"device_unused_fallback_key_types": [
"signed_curve25519"
]
}
"""
PATTERNS = client_patterns(
"/org.matrix.msc3575/sync/e2ee$", releases=[], v1=False, unstable=True
)
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastores().main
self.sync_handler = hs.get_sync_handler()
# Filtering only matters for the `device_lists` because it requires a bunch of
# derived information from rooms (see how `_generate_sync_entry_for_rooms()`
# prepares a bunch of data for `_generate_sync_entry_for_device_list()`).
self.only_member_events_filter_collection = FilterCollection(
self.hs,
{
"room": {
# We only care about membership events for the `device_lists`.
# Membership will tell us whether a user has joined/left a room and
# if there are new devices to encrypt for.
"timeline": {
"types": ["m.room.member"],
},
"state": {
"types": ["m.room.member"],
},
# We don't want any extra account_data generated because it's not
# returned by this endpoint. This helps us avoid work in
# `_generate_sync_entry_for_rooms()`
"account_data": {
"not_types": ["*"],
},
# We don't want any extra ephemeral data generated because it's not
# returned by this endpoint. This helps us avoid work in
# `_generate_sync_entry_for_rooms()`
"ephemeral": {
"not_types": ["*"],
},
},
# We don't want any extra account_data generated because it's not
# returned by this endpoint. (This is just here for good measure)
"account_data": {
"not_types": ["*"],
},
# We don't want any extra presence data generated because it's not
# returned by this endpoint. (This is just here for good measure)
"presence": {
"not_types": ["*"],
},
},
)
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req_experimental_feature(
request, allow_guest=True, feature=ExperimentalFeature.MSC3575
)
user = requester.user
device_id = requester.device_id
timeout = parse_integer(request, "timeout", default=0)
since = parse_string(request, "since")
sync_config = SyncConfig(
user=user,
filter_collection=self.only_member_events_filter_collection,
is_guest=requester.is_guest,
device_id=device_id,
use_state_after=False, # We don't return any rooms so this flag is a no-op
)
since_token = None
if since is not None:
since_token = await StreamToken.from_string(self.store, since)
# Request cache key
request_key = (
SyncVersion.E2EE_SYNC,
user,
timeout,
since,
)
# Gather data for the response
sync_result = await self.sync_handler.wait_for_sync_for_user(
requester,
sync_config,
SyncVersion.E2EE_SYNC,
request_key,
since_token=since_token,
timeout=timeout,
full_state=False,
)
# The client may have disconnected by now; don't bother to serialize the
# response if so.
if request._disconnected:
logger.info("Client has disconnected; not serializing response.")
return 200, {}
response: JsonDict = defaultdict(dict)
response["next_batch"] = await sync_result.next_batch.to_string(self.store)
if sync_result.to_device:
response["to_device"] = {"events": sync_result.to_device}
if sync_result.device_lists.changed:
response["device_lists"]["changed"] = list(sync_result.device_lists.changed)
if sync_result.device_lists.left:
response["device_lists"]["left"] = list(sync_result.device_lists.left)
# We always include this because https://github.com/vector-im/element-android/issues/3725
# The spec isn't terribly clear on when this can be omitted and how a client would tell
# the difference between "no keys present" and "nothing changed" in terms of whole field
# absent / individual key type entry absent
# Corresponding synapse issue: https://github.com/matrix-org/synapse/issues/10456
response["device_one_time_keys_count"] = sync_result.device_one_time_keys_count
# https://github.com/matrix-org/matrix-doc/blob/54255851f642f84a4f1aaf7bc063eebe3d76752b/proposals/2732-olm-fallback-keys.md
# states that this field should always be included, as long as the server supports the feature.
response["device_unused_fallback_key_types"] = (
sync_result.device_unused_fallback_key_types
)
return 200, response
class SlidingSyncRestServlet(RestServlet):
"""
API endpoint for MSC3575 Sliding Sync `/sync`. Allows for clients to request a
@@ -1254,4 +1081,3 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
SyncRestServlet(hs).register(http_server)
SlidingSyncRestServlet(hs).register(http_server)
SlidingSyncE2eeRestServlet(hs).register(http_server)
+1
View File
@@ -175,6 +175,7 @@ class VersionsRestServlet(RestServlet):
"org.matrix.simplified_msc3575": msc3575_enabled,
# Arbitrary key-value profile fields.
"uk.tcpip.msc4133": self.config.experimental.msc4133_enabled,
"uk.tcpip.msc4133.stable": True,
# MSC4155: Invite filtering
"org.matrix.msc4155": self.config.experimental.msc4155_enabled,
# MSC4306: Support for thread subscriptions
+35 -1
View File
@@ -129,7 +129,10 @@ from synapse.http.client import (
)
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.media.media_repository import MediaRepository
from synapse.metrics import register_threadpool
from synapse.metrics import (
all_later_gauges_to_clean_up_on_shutdown,
register_threadpool,
)
from synapse.metrics.common_usage_metrics import CommonUsageMetricsManager
from synapse.module_api import ModuleApi
from synapse.module_api.callbacks import ModuleApiCallbacks
@@ -369,6 +372,37 @@ class HomeServer(metaclass=abc.ABCMeta):
if self.config.worker.run_background_tasks:
self.setup_background_tasks()
def __del__(self) -> None:
"""
Called when an the homeserver is garbage collected.
Make sure we actually do some clean-up, rather than leak data.
"""
self.cleanup()
def cleanup(self) -> None:
"""
WIP: Clean-up any references to the homeserver and stop any running related
processes, timers, loops, replication stream, etc.
This should be called wherever you care about the HomeServer being completely
garbage collected like in tests. It's not necessary to call if you plan to just
shut down the whole Python process anyway.
Can be called multiple times.
"""
logger.info("Received cleanup request for %s.", self.hostname)
# TODO: Stop background processes, timers, loops, replication stream, etc.
# Cleanup metrics associated with the homeserver
for later_gauge in all_later_gauges_to_clean_up_on_shutdown.values():
later_gauge.unregister_hooks_for_homeserver_instance_id(
self.get_instance_id()
)
logger.info("Cleanup complete for %s.", self.hostname)
def start_listening(self) -> None: # noqa: B027 (no-op by design)
"""Start the HTTP, manhole, metrics, etc listeners
+1 -7
View File
@@ -61,7 +61,7 @@ from synapse.logging.context import (
current_context,
make_deferred_yieldable,
)
from synapse.metrics import SERVER_NAME_LABEL, LaterGauge, register_threadpool
from synapse.metrics import SERVER_NAME_LABEL, register_threadpool
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.background_updates import BackgroundUpdater
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
@@ -611,12 +611,6 @@ class DatabasePool:
)
self.updates = BackgroundUpdater(hs, self)
LaterGauge(
name="synapse_background_update_status",
desc="Background update status",
labelnames=[SERVER_NAME_LABEL],
caller=lambda: {(self.server_name,): self.updates.get_status()},
)
self._previous_txn_total_time = 0.0
self._current_txn_total_time = 0.0
+17
View File
@@ -22,6 +22,7 @@
import logging
from typing import TYPE_CHECKING, Generic, List, Optional, Type, TypeVar
from synapse.metrics import SERVER_NAME_LABEL, LaterGauge
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, make_conn
from synapse.storage.databases.main.events import PersistEventsStore
@@ -40,6 +41,13 @@ logger = logging.getLogger(__name__)
DataStoreT = TypeVar("DataStoreT", bound=SQLBaseStore, covariant=True)
background_update_status = LaterGauge(
name="synapse_background_update_status",
desc="Background update status",
labelnames=["database_name", SERVER_NAME_LABEL],
)
class Databases(Generic[DataStoreT]):
"""The various databases.
@@ -143,6 +151,15 @@ class Databases(Generic[DataStoreT]):
db_conn.close()
# Track the background update status for each database
background_update_status.register_hook(
homeserver_instance_id=hs.get_instance_id(),
hook=lambda: {
(database.name(), server_name): database.updates.get_status()
for database in self.databases
},
)
# Sanity check that we have actually configured all the required stores.
if not main:
raise Exception("No 'main' database configured")
@@ -32,6 +32,7 @@ from synapse.storage.database import (
LoggingDatabaseConnection,
LoggingTransaction,
)
from synapse.storage.databases.main.account_keys import AccountKeysStore
from synapse.storage.databases.main.sliding_sync import SlidingSyncStore
from synapse.storage.databases.main.stats import UserSortOrder
from synapse.storage.databases.main.thread_subscriptions import (
@@ -163,6 +164,7 @@ class DataStore(
TaskSchedulerWorkerStore,
SlidingSyncStore,
DelayedEventsStore,
AccountKeysStore,
):
def __init__(
self,
@@ -0,0 +1,165 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright (C) 2025 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
#
from typing import TYPE_CHECKING, Collection, Dict, List, Tuple, cast
from signedjson.key import (
decode_signing_key_base64,
generate_signing_key,
get_verify_key,
)
from signedjson.types import SigningKey
from unpaddedbase64 import encode_base64
from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
make_in_list_sql_clause,
)
from synapse.types import get_domain_from_id, get_localpart_from_id
if TYPE_CHECKING:
from synapse.server import HomeServer
class AccountKeysStore(SQLBaseStore):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
async def get_or_create_account_key_user_id_for_account_name_user_id(
self, account_name_user_id: str
) -> Tuple[str, SigningKey]:
"""
Get or create an account key for the given account name user ID.
The user ID must belong to this server.
Args:
account_name_user_id: An account name user ID e.g "@alice:example.com"
Returns:
A tuple of account key user ID e.g @l8Hft5qXKn1vfHrg3p4+W8gELQVo8N13JkluMfmn2sQ:example.com
and the private key for the account.
Raises:
if the provided account name user ID is not owned by this homeserver, or if the user
ID is invalid in some way.
"""
if not self.hs.is_mine_id(account_name_user_id):
raise SynapseError(
500,
(
"get_or_create_account_key_user_id_for_account_name_user_id: this server cannot"
f" create an account key for other servers: {account_name_user_id}"
),
)
row = await self.db_pool.simple_select_one(
table="account_keys",
keyvalues={
"account_name_user_id": account_name_user_id,
},
retcols=["account_key_user_id", "account_key"],
allow_none=True,
desc="get_or_create_account_key_user_id_for_account_name_user_id.get_key_txn",
)
if row is not None:
return row[0], decode_account_key(row[1])
# create a new account key for this account inside a txn to ensure we lock correctly.
def create_key_txn(txn: LoggingTransaction) -> Tuple[str, str]:
key, public_key_str = generate_account_key()
account_key_user_id = (
f"@{public_key_str}:{get_domain_from_id(account_name_user_id)}"
)
# Race to insert the key. The first one to make it will be returned here as we don't clobber
sql = (
"INSERT INTO account_keys(account_name_user_id, account_key_user_id, account_key)"
" VALUES(?, ?, ?)"
" ON CONFLICT DO NOTHING"
)
txn.execute(
sql,
(
account_name_user_id,
account_key_user_id,
encode_base64(key.encode(), urlsafe=True),
),
)
sql = "SELECT account_key_user_id, account_key FROM account_keys WHERE account_name_user_id = ?"
txn.execute(sql, (account_name_user_id,))
return cast(Tuple[str, str], txn.fetchone())
row = await self.db_pool.runInteraction(
"get_or_create_account_key_user_id_for_account_name_user_id.create_key_txn",
create_key_txn,
)
return row[0], decode_account_key(row[1])
async def get_account_name_user_ids_for_account_key_user_ids(
self,
account_key_user_ids: Collection[str],
) -> Dict[str, str]:
"""
Fetch the verified account name user IDs for the given account key user IDs. Unknown account key
user IDs will be omitted from the dict.
Args:
account_key_user_ids: A list of user IDs in account key format e.g
["@l8Hft5qXKn1vfHrg3p4+W8gELQVo8N13JkluMfmn2sQ:example.com"]
Returns:
A map of account key user IDs to account name user IDs e.g.
{"@l8Hft5qXKn1vfHrg3p4+W8gELQVo8N13JkluMfmn2sQ:example.com":"@alice:example.com"}
"""
clause, args = make_in_list_sql_clause(
self.database_engine, "account_key_user_id", account_key_user_ids
)
def f(txn: LoggingTransaction) -> List[Tuple[str, str]]:
sql = f"SELECT account_key_user_id, account_name_user_id FROM account_keys WHERE {clause} AND account_name_user_id IS NOT NULL"
txn.execute(sql, args)
return cast(List[Tuple[str, str]], txn.fetchall())
rows = await self.db_pool.runInteraction(
"get_account_name_user_ids_for_account_key_user_ids", f
)
return {row[0]: row[1] for row in rows}
def generate_account_key() -> Tuple[SigningKey, str]:
signing_key = generate_signing_key("1")
verify_key_str = encode_base64(get_verify_key(signing_key).encode(), urlsafe=True)
return signing_key, verify_key_str
def decode_account_key(signing_key: str) -> SigningKey:
return decode_signing_key_base64(
"ed25519",
"1",
signing_key,
)
+10 -5
View File
@@ -84,6 +84,13 @@ _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership"
_POPULATE_PARTICIPANT_BG_UPDATE_BATCH_SIZE = 1000
federation_known_servers_gauge = LaterGauge(
name="synapse_federation_known_servers",
desc="",
labelnames=[SERVER_NAME_LABEL],
)
@attr.s(frozen=True, slots=True, auto_attribs=True)
class EventIdMembership:
"""Returned by `get_membership_from_event_ids`"""
@@ -116,11 +123,9 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
1,
self._count_known_servers,
)
LaterGauge(
name="synapse_federation_known_servers",
desc="",
labelnames=[SERVER_NAME_LABEL],
caller=lambda: {(self.server_name,): self._known_servers_count},
federation_known_servers_gauge.register_hook(
homeserver_instance_id=hs.get_instance_id(),
hook=lambda: {(self.server_name,): self._known_servers_count},
)
@wrap_as_background_process("_count_known_servers")
+1 -1
View File
@@ -19,7 +19,7 @@
#
#
SCHEMA_VERSION = 92 # remember to update the list below when updating
SCHEMA_VERSION = 93 # remember to update the list below when updating
"""Represents the expectations made by the codebase about the database schema
This should be incremented whenever the codebase changes its requirements on the
@@ -0,0 +1,25 @@
--
-- This file is licensed under the Affero General Public License (AGPL) version 3.
--
-- Copyright (C) 2025 New Vector, Ltd
--
-- This program is free software: you can redistribute it and/or modify
-- it under the terms of the GNU Affero General Public License as
-- published by the Free Software Foundation, either version 3 of the
-- License, or (at your option) any later version.
--
-- See the GNU Affero General Public License for more details:
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
-- Keeps a record of MSC4243 account key <--> account name mappings for all servers.
-- This mapping is permanent.
CREATE TABLE account_keys (
account_key_user_id TEXT PRIMARY KEY NOT NULL,
-- nullable if we cannot talk to the remote server.
account_name_user_id TEXT,
-- the private key as urlsafe base64, only for local accounts
account_key TEXT,
UNIQUE(account_key_user_id, account_name_user_id)
);
CREATE INDEX account_keys_key_for_name ON account_keys (account_name_user_id) WHERE account_name_user_id IS NOT NULL;
+10 -4
View File
@@ -131,22 +131,28 @@ def _get_counts_from_rate_limiter_instance(
# We track the number of affected hosts per time-period so we can
# differentiate one really noisy homeserver from a general
# ratelimit tuning problem across the federation.
LaterGauge(
sleep_affected_hosts_gauge = LaterGauge(
name="synapse_rate_limit_sleep_affected_hosts",
desc="Number of hosts that had requests put to sleep",
labelnames=["rate_limiter_name", SERVER_NAME_LABEL],
caller=lambda: _get_counts_from_rate_limiter_instance(
)
sleep_affected_hosts_gauge.register_hook(
homeserver_instance_id=None,
hook=lambda: _get_counts_from_rate_limiter_instance(
lambda rate_limiter_instance: sum(
ratelimiter.should_sleep()
for ratelimiter in rate_limiter_instance.ratelimiters.values()
)
),
)
LaterGauge(
reject_affected_hosts_gauge = LaterGauge(
name="synapse_rate_limit_reject_affected_hosts",
desc="Number of hosts that had requests rejected",
labelnames=["rate_limiter_name", SERVER_NAME_LABEL],
caller=lambda: _get_counts_from_rate_limiter_instance(
)
reject_affected_hosts_gauge.register_hook(
homeserver_instance_id=None,
hook=lambda: _get_counts_from_rate_limiter_instance(
lambda rate_limiter_instance: sum(
ratelimiter.should_reject()
for ratelimiter in rate_limiter_instance.ratelimiters.values()
+10 -5
View File
@@ -44,6 +44,13 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
running_tasks_gauge = LaterGauge(
name="synapse_scheduler_running_tasks",
desc="The number of concurrent running tasks handled by the TaskScheduler",
labelnames=[SERVER_NAME_LABEL],
)
class TaskScheduler:
"""
This is a simple task scheduler designed for resumable tasks. Normally,
@@ -130,11 +137,9 @@ class TaskScheduler:
TaskScheduler.SCHEDULE_INTERVAL_MS,
)
LaterGauge(
name="synapse_scheduler_running_tasks",
desc="The number of concurrent running tasks handled by the TaskScheduler",
labelnames=[SERVER_NAME_LABEL],
caller=lambda: {(self.server_name,): len(self._running_tasks)},
running_tasks_gauge.register_hook(
homeserver_instance_id=hs.get_instance_id(),
hook=lambda: {(self.server_name,): len(self._running_tasks)},
)
def register_action(
+32
View File
@@ -34,12 +34,15 @@ from twisted.internet.defer import Deferred, ensureDeferred
from twisted.internet.testing import MemoryReactor
from synapse.api.errors import SynapseError
from synapse.api.room_versions import RoomVersions
from synapse.crypto import keyring
from synapse.crypto.event_signing import compute_event_signature
from synapse.crypto.keyring import (
PerspectivesKeyFetcher,
ServerKeyFetcher,
StoreKeyFetcher,
)
from synapse.events import make_event_from_dict
from synapse.logging.context import (
ContextRequest,
LoggingContext,
@@ -388,6 +391,35 @@ class KeyringTestCase(unittest.HomeserverTestCase):
mock_fetcher1.get_keys.assert_called_once()
mock_fetcher2.get_keys.assert_called_once()
def test_verify_event_for_account_key(self) -> None:
"""Test basic functionality of verify_event_for_account_key.
- That it parses the user ID correctly.
- That it doesn't rely on key fetchers.
"""
room_version = RoomVersions.MSC4243v12
# Make a signing key and replace the key ID from '1' to be the base64 public key
signing_key = signedjson.key.generate_signing_key("1")
verify_key_str = encode_verify_key_base64(get_verify_key(signing_key))
signing_key.version = verify_key_str
domain = "can.be.anything.com"
signing_user_id = f"@{verify_key_str}:{domain}"
event_dict = {
"type": "m.room.create",
"state_key": "",
"sender": signing_user_id,
"content": {
"room_version": room_version.identifier,
},
}
event_dict["signatures"] = compute_event_signature(
room_version, event_dict, signature_name=domain, signing_key=signing_key
)
event = make_event_from_dict(event_dict, room_version)
kr = keyring.Keyring(self.hs, key_fetchers=None)
self.get_success(kr.verify_event_for_account_key(signing_user_id, event))
@logcontext_clean
class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
+1 -2
View File
@@ -35,7 +35,7 @@ from synapse.config._base import RootConfig
from synapse.config.auto_accept_invites import AutoAcceptInvitesConfig
from synapse.events.auto_accept_invites import InviteAutoAccepter
from synapse.federation.federation_base import event_from_pdu_json
from synapse.handlers.sync import JoinedSyncResult, SyncRequestKey, SyncVersion
from synapse.handlers.sync import JoinedSyncResult, SyncRequestKey
from synapse.module_api import ModuleApi
from synapse.rest import admin
from synapse.rest.client import login, room
@@ -548,7 +548,6 @@ def sync_join(
testcase.hs.get_sync_handler().wait_for_sync_for_user(
requester,
sync_config,
SyncVersion.SYNC_V2,
generate_request_key(),
since_token,
)
+1 -2
View File
@@ -36,7 +36,7 @@ from synapse.server import HomeServer
from synapse.types import JsonDict, StreamToken, create_requester
from synapse.util import Clock
from tests.handlers.test_sync import SyncRequestKey, SyncVersion, generate_sync_config
from tests.handlers.test_sync import SyncRequestKey, generate_sync_config
from tests.unittest import (
FederatingHomeserverTestCase,
HomeserverTestCase,
@@ -532,7 +532,6 @@ def sync_presence(
testcase.hs.get_sync_handler().wait_for_sync_for_user(
requester,
sync_config,
SyncVersion.SYNC_V2,
generate_request_key(),
since_token,
)
-30
View File
@@ -37,7 +37,6 @@ from synapse.handlers.sync import (
SyncConfig,
SyncRequestKey,
SyncResult,
SyncVersion,
TimelineBatch,
)
from synapse.rest import admin
@@ -113,7 +112,6 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
self.sync_handler.wait_for_sync_for_user(
requester,
sync_config,
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
)
)
@@ -124,7 +122,6 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
self.sync_handler.wait_for_sync_for_user(
requester,
sync_config,
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
),
ResourceLimitError,
@@ -142,7 +139,6 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
self.sync_handler.wait_for_sync_for_user(
requester,
sync_config,
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
),
ResourceLimitError,
@@ -167,7 +163,6 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
sync_config=generate_sync_config(
user, device_id="dev", use_state_after=self.use_state_after
),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
)
)
@@ -203,7 +198,6 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
sync_config=generate_sync_config(
user, use_state_after=self.use_state_after
),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
)
)
@@ -218,7 +212,6 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
sync_config=generate_sync_config(
user, device_id="dev", use_state_after=self.use_state_after
),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
since_token=initial_result.next_batch,
)
@@ -252,7 +245,6 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
sync_config=generate_sync_config(
user, use_state_after=self.use_state_after
),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
)
)
@@ -267,7 +259,6 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
sync_config=generate_sync_config(
user, device_id="dev", use_state_after=self.use_state_after
),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
since_token=initial_result.next_batch,
)
@@ -310,7 +301,6 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
self.sync_handler.wait_for_sync_for_user(
create_requester(owner),
generate_sync_config(owner, use_state_after=self.use_state_after),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
)
)
@@ -336,7 +326,6 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
self.sync_handler.wait_for_sync_for_user(
eve_requester,
eve_sync_config,
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
)
)
@@ -363,7 +352,6 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
self.sync_handler.wait_for_sync_for_user(
eve_requester,
eve_sync_config,
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
since_token=eve_sync_after_ban.next_batch,
)
@@ -376,7 +364,6 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
self.sync_handler.wait_for_sync_for_user(
eve_requester,
eve_sync_config,
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
since_token=None,
)
@@ -411,7 +398,6 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
self.sync_handler.wait_for_sync_for_user(
alice_requester,
generate_sync_config(alice, use_state_after=self.use_state_after),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
)
)
@@ -441,7 +427,6 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
),
use_state_after=self.use_state_after,
),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
since_token=initial_sync_result.next_batch,
)
@@ -487,7 +472,6 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
self.sync_handler.wait_for_sync_for_user(
alice_requester,
generate_sync_config(alice, use_state_after=self.use_state_after),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
)
)
@@ -527,7 +511,6 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
),
use_state_after=self.use_state_after,
),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
since_token=initial_sync_result.next_batch,
)
@@ -576,7 +559,6 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
self.sync_handler.wait_for_sync_for_user(
alice_requester,
generate_sync_config(alice, use_state_after=self.use_state_after),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
)
)
@@ -603,7 +585,6 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
),
use_state_after=self.use_state_after,
),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
since_token=initial_sync_result.next_batch,
)
@@ -643,7 +624,6 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
),
use_state_after=self.use_state_after,
),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
since_token=incremental_sync.next_batch,
)
@@ -717,7 +697,6 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
self.sync_handler.wait_for_sync_for_user(
alice_requester,
generate_sync_config(alice, use_state_after=self.use_state_after),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
)
)
@@ -743,7 +722,6 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
),
use_state_after=self.use_state_after,
),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
)
)
@@ -769,7 +747,6 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
self.sync_handler.wait_for_sync_for_user(
alice_requester,
generate_sync_config(alice, use_state_after=self.use_state_after),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
since_token=initial_sync_result.next_batch,
)
@@ -833,7 +810,6 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
self.sync_handler.wait_for_sync_for_user(
bob_requester,
generate_sync_config(bob, use_state_after=self.use_state_after),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
)
)
@@ -867,7 +843,6 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
filter_collection=FilterCollection(self.hs, filter_dict),
use_state_after=self.use_state_after,
),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
since_token=None if initial_sync else initial_sync_result.next_batch,
)
@@ -967,7 +942,6 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
self.sync_handler.wait_for_sync_for_user(
create_requester(user),
generate_sync_config(user, use_state_after=self.use_state_after),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
)
)
@@ -1016,7 +990,6 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
self.sync_handler.wait_for_sync_for_user(
create_requester(user2),
generate_sync_config(user2, use_state_after=self.use_state_after),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
)
)
@@ -1042,7 +1015,6 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
self.sync_handler.wait_for_sync_for_user(
create_requester(user),
generate_sync_config(user, use_state_after=self.use_state_after),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
)
)
@@ -1079,7 +1051,6 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
self.sync_handler.wait_for_sync_for_user(
create_requester(user),
generate_sync_config(user, use_state_after=self.use_state_after),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
since_token=since_token,
timeout=0,
@@ -1134,7 +1105,6 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
self.sync_handler.wait_for_sync_for_user(
create_requester(user),
generate_sync_config(user, use_state_after=self.use_state_after),
sync_version=SyncVersion.SYNC_V2,
request_key=generate_request_key(),
since_token=since_token,
timeout=0,
+98 -2
View File
@@ -18,11 +18,18 @@
# [This file includes modifications made by New Vector Limited]
#
#
from typing import Dict, Protocol, Tuple
from typing import Dict, NoReturn, Protocol, Tuple
from prometheus_client.core import Sample
from synapse.metrics import REGISTRY, InFlightGauge, generate_latest
from synapse.metrics import (
REGISTRY,
SERVER_NAME_LABEL,
InFlightGauge,
LaterGauge,
all_later_gauges_to_clean_up_on_shutdown,
generate_latest,
)
from synapse.util.caches.deferred_cache import DeferredCache
from tests import unittest
@@ -285,6 +292,95 @@ class CacheMetricsTests(unittest.HomeserverTestCase):
self.assertEqual(hs2_cache_max_size_metric_value, "777.0")
class LaterGaugeTests(unittest.HomeserverTestCase):
def setUp(self) -> None:
super().setUp()
self.later_gauge = LaterGauge(
name="foo",
desc="",
labelnames=[SERVER_NAME_LABEL],
)
def tearDown(self) -> None:
super().tearDown()
REGISTRY.unregister(self.later_gauge)
all_later_gauges_to_clean_up_on_shutdown.pop(self.later_gauge.name, None)
def test_later_gauge_multiple_servers(self) -> None:
"""
Test that LaterGauge metrics are reported correctly across multiple servers. We
will have an metrics entry for each homeserver that is labeled with the
`server_name` label.
"""
self.later_gauge.register_hook(
homeserver_instance_id="123", hook=lambda: {("hs1",): 1}
)
self.later_gauge.register_hook(
homeserver_instance_id="456", hook=lambda: {("hs2",): 2}
)
metrics_map = get_latest_metrics()
# Find the metrics from both homeservers
hs1_metric = 'foo{server_name="hs1"}'
hs1_metric_value = metrics_map.get(hs1_metric)
self.assertIsNotNone(
hs1_metric_value,
f"Missing metric {hs1_metric} in metrics {metrics_map}",
)
self.assertEqual(hs1_metric_value, "1.0")
hs2_metric = 'foo{server_name="hs2"}'
hs2_metric_value = metrics_map.get(hs2_metric)
self.assertIsNotNone(
hs2_metric_value,
f"Missing metric {hs2_metric} in metrics {metrics_map}",
)
self.assertEqual(hs2_metric_value, "2.0")
def test_later_gauge_hook_exception(self) -> None:
"""
Test that LaterGauge metrics are collected across multiple servers even if one
hooks is throwing an exception.
"""
def raise_exception() -> NoReturn:
raise Exception("fake error generating data")
# Make the hook for hs1 throw an exception
self.later_gauge.register_hook(
homeserver_instance_id="123", hook=raise_exception
)
# Metrics from hs2 still work fine
self.later_gauge.register_hook(
homeserver_instance_id="456", hook=lambda: {("hs2",): 2}
)
metrics_map = get_latest_metrics()
# Since we encountered an exception while trying to collect metrics from hs1, we
# don't expect to see it here.
hs1_metric = 'foo{server_name="hs1"}'
hs1_metric_value = metrics_map.get(hs1_metric)
self.assertIsNone(
hs1_metric_value,
(
"Since we encountered an exception while trying to collect metrics from hs1"
f"we don't expect to see it the metrics_map {metrics_map}"
),
)
# We should still see metrics from hs2 though
hs2_metric = 'foo{server_name="hs2"}'
hs2_metric_value = metrics_map.get(hs2_metric)
self.assertIsNotNone(
hs2_metric_value,
f"Missing metric {hs2_metric} in cache metrics {metrics_map}",
)
self.assertEqual(hs2_metric_value, "2.0")
def get_latest_metrics() -> Dict[str, str]:
"""
Collect the latest metrics from the registry and parse them into an easy to use map.
+1 -2
View File
@@ -32,7 +32,6 @@ from synapse.config.workers import InstanceTcpLocationConfig, InstanceUnixLocati
from synapse.http.site import SynapseRequest, SynapseSite
from synapse.replication.http import ReplicationRestResource
from synapse.replication.tcp.client import ReplicationDataHandler
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.protocol import (
ClientReplicationStreamProtocol,
ServerReplicationStreamProtocol,
@@ -97,7 +96,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self.test_handler = self._build_replication_data_handler()
self.worker_hs._replication_data_handler = self.test_handler # type: ignore[attr-defined]
repl_handler = ReplicationCommandHandler(self.worker_hs)
repl_handler = self.worker_hs.get_replication_command_handler()
self.client = ClientReplicationStreamProtocol(
self.worker_hs,
"client",
+22
View File
@@ -18,6 +18,8 @@
# [This file includes modifications made by New Vector Limited]
#
#
from http import HTTPStatus
import synapse
from synapse.api.errors import Codes
from synapse.rest.client import login, push_rule, room
@@ -486,3 +488,23 @@ class PushRuleAttributesTestCase(HomeserverTestCase):
},
channel.json_body,
)
def test_no_user_defined_postcontent_rules(self) -> None:
"""
Tests that clients are not permitted to create MSC4306 `postcontent` rules.
"""
self.register_user("bob", "pass")
token = self.login("bob", "pass")
channel = self.make_request(
"PUT",
"/pushrules/global/postcontent/some.user.rule",
{},
access_token=token,
)
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST)
self.assertEqual(
Codes.INVALID_PARAM,
channel.json_body["errcode"],
)
+10 -32
View File
@@ -18,27 +18,13 @@
# [This file includes modifications made by New Vector Limited]
#
#
from parameterized import parameterized_class
from synapse.api.constants import EduTypes
from synapse.rest import admin
from synapse.rest.client import login, sendtodevice, sync
from synapse.types import JsonDict
from tests.unittest import HomeserverTestCase, override_config
@parameterized_class(
("sync_endpoint", "experimental_features"),
[
("/sync", {}),
(
"/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee",
# Enable sliding sync
{"msc3575_enabled": True},
),
],
)
class SendToDeviceTestCase(HomeserverTestCase):
"""
Test `/sendToDevice` will deliver messages across to people receiving them over `/sync`.
@@ -48,9 +34,6 @@ class SendToDeviceTestCase(HomeserverTestCase):
experimental_features: The experimental features homeserver config to use.
"""
sync_endpoint: str
experimental_features: JsonDict
servlets = [
admin.register_servlets,
login.register_servlets,
@@ -58,11 +41,6 @@ class SendToDeviceTestCase(HomeserverTestCase):
sync.register_servlets,
]
def default_config(self) -> JsonDict:
config = super().default_config()
config["experimental_features"] = self.experimental_features
return config
def test_user_to_user(self) -> None:
"""A to-device message from one user to another should get delivered"""
@@ -83,7 +61,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
self.assertEqual(chan.code, 200, chan.result)
# check it appears
channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
channel = self.make_request("GET", "/sync", access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result)
expected_result = {
"events": [
@@ -99,7 +77,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
# it should re-appear if we do another sync because the to-device message is not
# deleted until we acknowledge it by sending a `?since=...` parameter in the
# next sync request corresponding to the `next_batch` value from the response.
channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
channel = self.make_request("GET", "/sync", access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.json_body["to_device"], expected_result)
@@ -107,7 +85,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
sync_token = channel.json_body["next_batch"]
channel = self.make_request(
"GET",
f"{self.sync_endpoint}?since={sync_token}",
f"/sync?since={sync_token}",
access_token=user2_tok,
)
self.assertEqual(channel.code, 200, channel.result)
@@ -133,7 +111,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
self.assertEqual(chan.code, 200, chan.result)
# now sync: we should get two of the three (because burst_count=2)
channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
channel = self.make_request("GET", "/sync", access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result)
msgs = channel.json_body["to_device"]["events"]
self.assertEqual(len(msgs), 2)
@@ -163,7 +141,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
# ... which should arrive
channel = self.make_request(
"GET",
f"{self.sync_endpoint}?since={sync_token}",
f"/sync?since={sync_token}",
access_token=user2_tok,
)
self.assertEqual(channel.code, 200, channel.result)
@@ -198,7 +176,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
)
# now sync: we should get two of the three
channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
channel = self.make_request("GET", "/sync", access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result)
msgs = channel.json_body["to_device"]["events"]
self.assertEqual(len(msgs), 2)
@@ -233,7 +211,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
# ... which should arrive
channel = self.make_request(
"GET",
f"{self.sync_endpoint}?since={sync_token}",
f"/sync?since={sync_token}",
access_token=user2_tok,
)
self.assertEqual(channel.code, 200, channel.result)
@@ -258,7 +236,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
user2_tok = self.login("u2", "pass", "d2")
# Do an initial sync
channel = self.make_request("GET", self.sync_endpoint, access_token=user2_tok)
channel = self.make_request("GET", "/sync", access_token=user2_tok)
self.assertEqual(channel.code, 200, channel.result)
sync_token = channel.json_body["next_batch"]
@@ -275,7 +253,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
channel = self.make_request(
"GET",
f"{self.sync_endpoint}?since={sync_token}&timeout=300000",
f"/sync?since={sync_token}&timeout=300000",
access_token=user2_tok,
)
self.assertEqual(channel.code, 200, channel.result)
@@ -285,7 +263,7 @@ class SendToDeviceTestCase(HomeserverTestCase):
channel = self.make_request(
"GET",
f"{self.sync_endpoint}?since={sync_token}&timeout=300000",
f"/sync?since={sync_token}&timeout=300000",
access_token=user2_tok,
)
self.assertEqual(channel.code, 200, channel.result)
+11 -82
View File
@@ -22,7 +22,7 @@ import json
import logging
from typing import List
from parameterized import parameterized, parameterized_class
from parameterized import parameterized
from twisted.internet.testing import MemoryReactor
@@ -702,29 +702,11 @@ class SyncCacheTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200, channel.json_body)
@parameterized_class(
("sync_endpoint", "experimental_features"),
[
("/sync", {}),
(
"/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee",
# Enable sliding sync
{"msc3575_enabled": True},
),
],
)
class DeviceListSyncTestCase(unittest.HomeserverTestCase):
"""
Tests regarding device list (`device_lists`) changes.
Attributes:
sync_endpoint: The endpoint under test to use for syncing.
experimental_features: The experimental features homeserver config to use.
"""
sync_endpoint: str
experimental_features: JsonDict
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
@@ -733,11 +715,6 @@ class DeviceListSyncTestCase(unittest.HomeserverTestCase):
devices.register_servlets,
]
def default_config(self) -> JsonDict:
config = super().default_config()
config["experimental_features"] = self.experimental_features
return config
def test_receiving_local_device_list_changes(self) -> None:
"""Tests that a local users that share a room receive each other's device list
changes.
@@ -767,7 +744,7 @@ class DeviceListSyncTestCase(unittest.HomeserverTestCase):
# Now have Bob initiate an initial sync (in order to get a since token)
channel = self.make_request(
"GET",
self.sync_endpoint,
"/sync",
access_token=bob_access_token,
)
self.assertEqual(channel.code, 200, channel.json_body)
@@ -777,7 +754,7 @@ class DeviceListSyncTestCase(unittest.HomeserverTestCase):
# which we hope will happen as a result of Alice updating their device list.
bob_sync_channel = self.make_request(
"GET",
f"{self.sync_endpoint}?since={next_batch_token}&timeout=30000",
f"/sync?since={next_batch_token}&timeout=30000",
access_token=bob_access_token,
# Start the request, then continue on.
await_result=False,
@@ -824,7 +801,7 @@ class DeviceListSyncTestCase(unittest.HomeserverTestCase):
# Have Bob initiate an initial sync (in order to get a since token)
channel = self.make_request(
"GET",
self.sync_endpoint,
"/sync",
access_token=bob_access_token,
)
self.assertEqual(channel.code, 200, channel.json_body)
@@ -834,7 +811,7 @@ class DeviceListSyncTestCase(unittest.HomeserverTestCase):
# which we hope will happen as a result of Alice updating their device list.
bob_sync_channel = self.make_request(
"GET",
f"{self.sync_endpoint}?since={next_batch_token}&timeout=1000",
f"/sync?since={next_batch_token}&timeout=1000",
access_token=bob_access_token,
# Start the request, then continue on.
await_result=False,
@@ -873,9 +850,7 @@ class DeviceListSyncTestCase(unittest.HomeserverTestCase):
)
# Request an initial sync
channel = self.make_request(
"GET", self.sync_endpoint, access_token=alice_access_token
)
channel = self.make_request("GET", "/sync", access_token=alice_access_token)
self.assertEqual(channel.code, 200, channel.json_body)
next_batch = channel.json_body["next_batch"]
@@ -883,7 +858,7 @@ class DeviceListSyncTestCase(unittest.HomeserverTestCase):
# It won't return until something has happened
incremental_sync_channel = self.make_request(
"GET",
f"{self.sync_endpoint}?since={next_batch}&timeout=30000",
f"/sync?since={next_batch}&timeout=30000",
access_token=alice_access_token,
await_result=False,
)
@@ -913,17 +888,6 @@ class DeviceListSyncTestCase(unittest.HomeserverTestCase):
)
@parameterized_class(
("sync_endpoint", "experimental_features"),
[
("/sync", {}),
(
"/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee",
# Enable sliding sync
{"msc3575_enabled": True},
),
],
)
class DeviceOneTimeKeysSyncTestCase(unittest.HomeserverTestCase):
"""
Tests regarding device one time keys (`device_one_time_keys_count`) changes.
@@ -933,9 +897,6 @@ class DeviceOneTimeKeysSyncTestCase(unittest.HomeserverTestCase):
experimental_features: The experimental features homeserver config to use.
"""
sync_endpoint: str
experimental_features: JsonDict
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
@@ -943,11 +904,6 @@ class DeviceOneTimeKeysSyncTestCase(unittest.HomeserverTestCase):
devices.register_servlets,
]
def default_config(self) -> JsonDict:
config = super().default_config()
config["experimental_features"] = self.experimental_features
return config
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.e2e_keys_handler = hs.get_e2e_keys_handler()
@@ -964,9 +920,7 @@ class DeviceOneTimeKeysSyncTestCase(unittest.HomeserverTestCase):
)
# Request an initial sync
channel = self.make_request(
"GET", self.sync_endpoint, access_token=alice_access_token
)
channel = self.make_request("GET", "/sync", access_token=alice_access_token)
self.assertEqual(channel.code, 200, channel.json_body)
# Check for those one time key counts
@@ -1011,9 +965,7 @@ class DeviceOneTimeKeysSyncTestCase(unittest.HomeserverTestCase):
)
# Request an initial sync
channel = self.make_request(
"GET", self.sync_endpoint, access_token=alice_access_token
)
channel = self.make_request("GET", "/sync", access_token=alice_access_token)
self.assertEqual(channel.code, 200, channel.json_body)
# Check for those one time key counts
@@ -1024,17 +976,6 @@ class DeviceOneTimeKeysSyncTestCase(unittest.HomeserverTestCase):
)
@parameterized_class(
("sync_endpoint", "experimental_features"),
[
("/sync", {}),
(
"/_matrix/client/unstable/org.matrix.msc3575/sync/e2ee",
# Enable sliding sync
{"msc3575_enabled": True},
),
],
)
class DeviceUnusedFallbackKeySyncTestCase(unittest.HomeserverTestCase):
"""
Tests regarding device one time keys (`device_unused_fallback_key_types`) changes.
@@ -1044,9 +985,6 @@ class DeviceUnusedFallbackKeySyncTestCase(unittest.HomeserverTestCase):
experimental_features: The experimental features homeserver config to use.
"""
sync_endpoint: str
experimental_features: JsonDict
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
@@ -1054,11 +992,6 @@ class DeviceUnusedFallbackKeySyncTestCase(unittest.HomeserverTestCase):
devices.register_servlets,
]
def default_config(self) -> JsonDict:
config = super().default_config()
config["experimental_features"] = self.experimental_features
return config
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = self.hs.get_datastores().main
self.e2e_keys_handler = hs.get_e2e_keys_handler()
@@ -1078,9 +1011,7 @@ class DeviceUnusedFallbackKeySyncTestCase(unittest.HomeserverTestCase):
)
# Request an initial sync
channel = self.make_request(
"GET", self.sync_endpoint, access_token=alice_access_token
)
channel = self.make_request("GET", "/sync", access_token=alice_access_token)
self.assertEqual(channel.code, 200, channel.json_body)
# Check for those one time key counts
@@ -1122,9 +1053,7 @@ class DeviceUnusedFallbackKeySyncTestCase(unittest.HomeserverTestCase):
self.assertEqual(fallback_res, ["alg1"], fallback_res)
# Request an initial sync
channel = self.make_request(
"GET", self.sync_endpoint, access_token=alice_access_token
)
channel = self.make_request("GET", "/sync", access_token=alice_access_token)
self.assertEqual(channel.code, 200, channel.json_body)
# Check for the unused fallback key types
+3
View File
@@ -1145,6 +1145,9 @@ def setup_test_homeserver(
reactor=reactor,
)
# Register the cleanup hook
cleanup_func(hs.cleanup)
# Install @cache_in_self attributes
for key, val in kwargs.items():
setattr(hs, "_" + key, val)
+92
View File
@@ -0,0 +1,92 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright (C) 2025 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
#
from signedjson.key import get_verify_key
from unpaddedbase64 import encode_base64
from twisted.internet.testing import MemoryReactor
from synapse.server import HomeServer
from synapse.types import get_localpart_from_id
from synapse.util import Clock
from tests import unittest
class AccountKeysTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = self.hs.get_datastores().main
self.user = "@user:test"
def test_get_or_create_account_key_user_id_for_account_name_user_id(self) -> None:
key_user_id, key = self.get_success(
self.store.get_or_create_account_key_user_id_for_account_name_user_id(
self.user
)
)
# asserts the localpart is unpadded urlsafe base64
self.assertRegex(key_user_id, r"^@[A-Za-z0-9\-_]{43}:test$")
# asserts the public key is the localpart
self.assertEquals(encode_base64(get_verify_key(key).encode(), urlsafe=True), get_localpart_from_id(key_user_id))
# asserts the key ID is 1
self.assertEquals(key.version, "1")
# assert that repeated calls return the same key
key_user_id2, key2 = self.get_success(
self.store.get_or_create_account_key_user_id_for_account_name_user_id(
self.user
)
)
self.assertEquals(key_user_id, key_user_id2)
self.assertEquals(key.encode(), key2.encode())
def test_get_account_name_user_ids_for_account_key_user_ids(self) -> None:
key_user_id, _ = self.get_success(
self.store.get_or_create_account_key_user_id_for_account_name_user_id(
self.user,
)
)
result = self.get_success(
self.store.get_account_name_user_ids_for_account_key_user_ids(
[key_user_id]
),
)
self.assertEquals(result[key_user_id], self.user)
def test_get_account_name_user_ids_for_account_key_user_ids_multiple(self) -> None:
key_user_id_alice, _ = self.get_success(
self.store.get_or_create_account_key_user_id_for_account_name_user_id(
"@alice:test",
)
)
key_user_id_bob, _ = self.get_success(
self.store.get_or_create_account_key_user_id_for_account_name_user_id(
"@bob:test",
)
)
key_user_id_unknown = "@6fey6W1wS3-vbvUmHZnTd6Gi3o-TIxvIcwtEQP4nrW0:test"
result = self.get_success(
self.store.get_account_name_user_ids_for_account_key_user_ids(
[key_user_id_alice, key_user_id_bob, key_user_id_unknown]
),
)
self.assertEquals(result[key_user_id_alice], "@alice:test")
self.assertEquals(result[key_user_id_bob], "@bob:test")
self.assertEquals(result.get(key_user_id_unknown, None), None)