Merge remote-tracking branch 'origin/rei/as_device_masquerading_msc3202' into anoa/e2e_as_internal_testing
This commit is contained in:
1
changelog.d/11243.misc
Normal file
1
changelog.d/11243.misc
Normal file
@@ -0,0 +1 @@
|
||||
Allow specific, experimental events to be created without `prev_events`. Used by [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716).
|
||||
1
changelog.d/11378.feature
Normal file
1
changelog.d/11378.feature
Normal file
@@ -0,0 +1 @@
|
||||
Allow guests to send state events per [MSC3419](https://github.com/matrix-org/matrix-doc/pull/3419).
|
||||
1
changelog.d/11427.doc
Normal file
1
changelog.d/11427.doc
Normal file
@@ -0,0 +1 @@
|
||||
Document the usage of refresh tokens.
|
||||
1
changelog.d/11480.misc
Normal file
1
changelog.d/11480.misc
Normal file
@@ -0,0 +1 @@
|
||||
Add missing type hints to `synapse.config` module.
|
||||
1
changelog.d/11487.misc
Normal file
1
changelog.d/11487.misc
Normal file
@@ -0,0 +1 @@
|
||||
Add test to ensure we share the same `state_group` across the whole historical batch when using the [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) `/batch_send` endpoint.
|
||||
1
changelog.d/11516.bugfix
Normal file
1
changelog.d/11516.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Fix a long-standing bug where relations from other rooms could be included in the bundled aggregations of an event.
|
||||
1
changelog.d/11520.misc
Normal file
1
changelog.d/11520.misc
Normal file
@@ -0,0 +1 @@
|
||||
Use HTTPStatus constants in place of literals in `tests.rest.client.test_auth`.
|
||||
1
changelog.d/11531.misc
Normal file
1
changelog.d/11531.misc
Normal file
@@ -0,0 +1 @@
|
||||
Add a receipt types constant for `m.read`.
|
||||
1
changelog.d/11535.misc
Normal file
1
changelog.d/11535.misc
Normal file
@@ -0,0 +1 @@
|
||||
Clean up `synapse.rest.admin`.
|
||||
1
changelog.d/11541.misc
Normal file
1
changelog.d/11541.misc
Normal file
@@ -0,0 +1 @@
|
||||
Support unprefixed versions of fallback key property names.
|
||||
1
changelog.d/11542.misc
Normal file
1
changelog.d/11542.misc
Normal file
@@ -0,0 +1 @@
|
||||
Add missing `errcode` to `parse_string` and `parse_boolean`.
|
||||
1
changelog.d/11543.misc
Normal file
1
changelog.d/11543.misc
Normal file
@@ -0,0 +1 @@
|
||||
Use HTTPStatus constants in place of literals in `synapse.http`.
|
||||
1
changelog.d/11547.bugfix
Normal file
1
changelog.d/11547.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Fix a bug introduced in Synapse 1.17.0 where a pusher created for an email with capital letters would fail to be created.
|
||||
1
changelog.d/11550.misc
Normal file
1
changelog.d/11550.misc
Normal file
@@ -0,0 +1 @@
|
||||
Fix an inaccurate and misleading comment in the `/sync` code.
|
||||
1
changelog.d/11558.misc
Normal file
1
changelog.d/11558.misc
Normal file
@@ -0,0 +1 @@
|
||||
Stop populating unused database column `state_events.prev_state`.
|
||||
1
changelog.d/11560.misc
Normal file
1
changelog.d/11560.misc
Normal file
@@ -0,0 +1 @@
|
||||
Minor efficiency improvements in event persistence.
|
||||
1
changelog.d/11565.misc
Normal file
1
changelog.d/11565.misc
Normal file
@@ -0,0 +1 @@
|
||||
Make `get_device` return `None` if the device doesn't exist rather than raising an exception.
|
||||
@@ -30,6 +30,7 @@
|
||||
- [SSO Mapping Providers](sso_mapping_providers.md)
|
||||
- [Password Auth Providers](password_auth_providers.md)
|
||||
- [JSON Web Tokens](jwt.md)
|
||||
- [Refresh Tokens](usage/configuration/user_authentication/refresh_tokens.md)
|
||||
- [Registration Captcha](CAPTCHA_SETUP.md)
|
||||
- [Application Services](application_services.md)
|
||||
- [Server Notices](server_notices.md)
|
||||
|
||||
139
docs/usage/configuration/user_authentication/refresh_tokens.md
Normal file
139
docs/usage/configuration/user_authentication/refresh_tokens.md
Normal file
@@ -0,0 +1,139 @@
|
||||
# Refresh Tokens
|
||||
|
||||
Synapse supports refresh tokens since version 1.49 (some earlier versions had support for an earlier, experimental draft of [MSC2918] which is not compatible).
|
||||
|
||||
|
||||
[MSC2918]: https://github.com/matrix-org/matrix-doc/blob/main/proposals/2918-refreshtokens.md#msc2918-refresh-tokens
|
||||
|
||||
|
||||
## Background and motivation
|
||||
|
||||
Synapse users' sessions are identified by **access tokens**; access tokens are
|
||||
issued to users on login. Each session gets a unique access token which identifies
|
||||
it; the access token must be kept secret as it grants access to the user's account.
|
||||
|
||||
Traditionally, these access tokens were eternally valid (at least until the user
|
||||
explicitly chose to log out).
|
||||
|
||||
In some cases, it may be desirable for these access tokens to expire so that the
|
||||
potential damage caused by leaking an access token is reduced.
|
||||
On the other hand, forcing a user to re-authenticate (log in again) often might
|
||||
be too much of an inconvenience.
|
||||
|
||||
**Refresh tokens** are a mechanism to avoid some of this inconvenience whilst
|
||||
still getting most of the benefits of short access token lifetimes.
|
||||
Refresh tokens are also a concept present in OAuth 2 — further reading is available
|
||||
[here](https://datatracker.ietf.org/doc/html/rfc6749#section-1.5).
|
||||
|
||||
When refresh tokens are in use, both an access token and a refresh token will be
|
||||
issued to users on login. The access token will expire after a predetermined amount
|
||||
of time, but otherwise works in the same way as before. When the access token is
|
||||
close to expiring (or has expired), the user's client should present the homeserver
|
||||
(Synapse) with the refresh token.
|
||||
|
||||
The homeserver will then generate a new access token and refresh token for the user
|
||||
and return them. The old refresh token is invalidated and can not be used again*.
|
||||
|
||||
Finally, refresh tokens also make it possible for sessions to be logged out if they
|
||||
are inactive for too long, before the session naturally ends; see the configuration
|
||||
guide below.
|
||||
|
||||
|
||||
*To prevent issues if clients lose connection half-way through refreshing a token,
|
||||
the refresh token is only invalidated once the new access token has been used at
|
||||
least once. For all intents and purposes, the above simplification is sufficient.
|
||||
|
||||
|
||||
## Caveats
|
||||
|
||||
There are some caveats:
|
||||
|
||||
* If a third party gets both your access token and refresh token, they will be able to
|
||||
continue to enjoy access to your session.
|
||||
* This is still an improvement because you (the user) will notice when *your*
|
||||
session expires and you're not able to use your refresh token.
|
||||
That would be a giveaway that someone else has compromised your session.
|
||||
You would be able to log in again and terminate that session.
|
||||
Previously (with long-lived access tokens), a third party that has your access
|
||||
token could go undetected for a very long time.
|
||||
* Clients need to implement support for refresh tokens in order for them to be a
|
||||
useful mechanism.
|
||||
* It is up to homeserver administrators if they want to issue long-lived access
|
||||
tokens to clients not implementing refresh tokens.
|
||||
* For compatibility, it is likely that they should, at least until client support
|
||||
is widespread.
|
||||
* Users with clients that support refresh tokens will still benefit from the
|
||||
added security; it's not possible to downgrade a session to using long-lived
|
||||
access tokens so this effectively gives users the choice.
|
||||
* In a closed environment where all users use known clients, this may not be
|
||||
an issue as the homeserver administrator can know if the clients have refresh
|
||||
token support. In that case, the non-refreshable access token lifetime
|
||||
may be set to a short duration so that a similar level of security is provided.
|
||||
|
||||
|
||||
## Configuration Guide
|
||||
|
||||
The following configuration options, in the `registration` section, are related:
|
||||
|
||||
* `session_lifetime`: maximum length of a session, even if it's refreshed.
|
||||
In other words, the client must log in again after this time period.
|
||||
In most cases, this can be unset (infinite) or set to a long time (years or months).
|
||||
* `refreshable_access_token_lifetime`: lifetime of access tokens that are created
|
||||
by clients supporting refresh tokens.
|
||||
This should be short; a good value might be 5 minutes (`5m`).
|
||||
* `nonrefreshable_access_token_lifetime`: lifetime of access tokens that are created
|
||||
by clients which don't support refresh tokens.
|
||||
Make this short if you want to effectively force use of refresh tokens.
|
||||
Make this long if you don't want to inconvenience users of clients which don't
|
||||
support refresh tokens (by forcing them to frequently re-authenticate using
|
||||
login credentials).
|
||||
* `refresh_token_lifetime`: lifetime of refresh tokens.
|
||||
In other words, the client must refresh within this time period to maintain its session.
|
||||
Unless you want to log inactive sessions out, it is often fine to use a long
|
||||
value here or even leave it unset (infinite).
|
||||
Beware that making it too short will inconvenience clients that do not connect
|
||||
very often, including mobile clients and clients of infrequent users (by making
|
||||
it more difficult for them to refresh in time, which may force them to need to
|
||||
re-authenticate using login credentials).
|
||||
|
||||
**Note:** All four options above only apply when tokens are created (by logging in or refreshing).
|
||||
Changes to these settings do not apply retroactively.
|
||||
|
||||
|
||||
### Using refresh token expiry to log out inactive sessions
|
||||
|
||||
If you'd like to force sessions to be logged out upon inactivity, you can enable
|
||||
refreshable access token expiry and refresh token expiry.
|
||||
|
||||
This works because a client must refresh at least once within a period of
|
||||
`refresh_token_lifetime` in order to maintain valid credentials to access the
|
||||
account.
|
||||
|
||||
(It's suggested that `refresh_token_lifetime` should be longer than
|
||||
`refreshable_access_token_lifetime` and this section assumes that to be the case
|
||||
for simplicity.)
|
||||
|
||||
Note: this will only affect sessions using refresh tokens. You may wish to
|
||||
set a short `nonrefreshable_access_token_lifetime` to prevent this being bypassed
|
||||
by clients that do not support refresh tokens.
|
||||
|
||||
|
||||
#### Choosing values that guarantee permitting some inactivity
|
||||
|
||||
It may be desirable to permit some short periods of inactivity, for example to
|
||||
accommodate brief outages in client connectivity.
|
||||
|
||||
The following model aims to provide guidance for choosing `refresh_token_lifetime`
|
||||
and `refreshable_access_token_lifetime` to satisfy requirements of the form:
|
||||
|
||||
1. inactivity longer than `L` **MUST** cause the session to be logged out; and
|
||||
2. inactivity shorter than `S` **MUST NOT** cause the session to be logged out.
|
||||
|
||||
This model makes the weakest assumption that all active clients will refresh as
|
||||
needed to maintain an active access token, but no sooner.
|
||||
*In reality, clients may refresh more often than this model assumes, but the
|
||||
above requirements will still hold.*
|
||||
|
||||
To satisfy the above model,
|
||||
* `refresh_token_lifetime` should be set to `L`; and
|
||||
* `refreshable_access_token_lifetime` should be set to `L - S`.
|
||||
@@ -346,7 +346,7 @@ class Auth:
|
||||
effective_device_id = request.args[DEVICE_ID_ARG_NAME][0].decode("utf8")
|
||||
# We only just set this so it can't be None!
|
||||
assert effective_device_id is not None
|
||||
device_opt = await self.store.get_device_opt(
|
||||
device_opt = await self.store.get_device(
|
||||
effective_user_id, effective_device_id
|
||||
)
|
||||
if device_opt is None:
|
||||
|
||||
@@ -253,5 +253,9 @@ class GuestAccess:
|
||||
FORBIDDEN: Final = "forbidden"
|
||||
|
||||
|
||||
class ReceiptTypes:
|
||||
READ: Final = "m.read"
|
||||
|
||||
|
||||
class ReadReceiptEventFields:
|
||||
MSC2285_HIDDEN: Final = "org.matrix.msc2285.hidden"
|
||||
|
||||
@@ -16,12 +16,14 @@
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict
|
||||
from typing import Any, Dict, Iterator, List, Optional
|
||||
|
||||
import attr
|
||||
import jsonschema
|
||||
from signedjson.key import (
|
||||
NACL_ED25519,
|
||||
SigningKey,
|
||||
VerifyKey,
|
||||
decode_signing_key_base64,
|
||||
decode_verify_key_bytes,
|
||||
generate_signing_key,
|
||||
@@ -31,6 +33,7 @@ from signedjson.key import (
|
||||
)
|
||||
from unpaddedbase64 import decode_base64
|
||||
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util.stringutils import random_string, random_string_with_symbols
|
||||
|
||||
from ._base import Config, ConfigError
|
||||
@@ -81,14 +84,13 @@ To suppress this warning and continue using 'matrix.org', admins should set
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@attr.s
|
||||
@attr.s(slots=True, auto_attribs=True)
|
||||
class TrustedKeyServer:
|
||||
# string: name of the server.
|
||||
server_name = attr.ib()
|
||||
# name of the server.
|
||||
server_name: str
|
||||
|
||||
# dict[str,VerifyKey]|None: map from key id to key object, or None to disable
|
||||
# signature verification.
|
||||
verify_keys = attr.ib(default=None)
|
||||
# map from key id to key object, or None to disable signature verification.
|
||||
verify_keys: Optional[Dict[str, VerifyKey]] = None
|
||||
|
||||
|
||||
class KeyConfig(Config):
|
||||
@@ -279,15 +281,15 @@ class KeyConfig(Config):
|
||||
% locals()
|
||||
)
|
||||
|
||||
def read_signing_keys(self, signing_key_path, name):
|
||||
def read_signing_keys(self, signing_key_path: str, name: str) -> List[SigningKey]:
|
||||
"""Read the signing keys in the given path.
|
||||
|
||||
Args:
|
||||
signing_key_path (str)
|
||||
name (str): Associated config key name
|
||||
signing_key_path
|
||||
name: Associated config key name
|
||||
|
||||
Returns:
|
||||
list[SigningKey]
|
||||
The signing keys read from the given path.
|
||||
"""
|
||||
|
||||
signing_keys = self.read_file(signing_key_path, name)
|
||||
@@ -296,7 +298,9 @@ class KeyConfig(Config):
|
||||
except Exception as e:
|
||||
raise ConfigError("Error reading %s: %s" % (name, str(e)))
|
||||
|
||||
def read_old_signing_keys(self, old_signing_keys):
|
||||
def read_old_signing_keys(
|
||||
self, old_signing_keys: Optional[JsonDict]
|
||||
) -> Dict[str, VerifyKey]:
|
||||
if old_signing_keys is None:
|
||||
return {}
|
||||
keys = {}
|
||||
@@ -340,7 +344,7 @@ class KeyConfig(Config):
|
||||
write_signing_keys(signing_key_file, (key,))
|
||||
|
||||
|
||||
def _perspectives_to_key_servers(config):
|
||||
def _perspectives_to_key_servers(config: JsonDict) -> Iterator[JsonDict]:
|
||||
"""Convert old-style 'perspectives' configs into new-style 'trusted_key_servers'
|
||||
|
||||
Returns an iterable of entries to add to trusted_key_servers.
|
||||
@@ -402,7 +406,9 @@ TRUSTED_KEY_SERVERS_SCHEMA = {
|
||||
}
|
||||
|
||||
|
||||
def _parse_key_servers(key_servers, federation_verify_certificates):
|
||||
def _parse_key_servers(
|
||||
key_servers: List[Any], federation_verify_certificates: bool
|
||||
) -> Iterator[TrustedKeyServer]:
|
||||
try:
|
||||
jsonschema.validate(key_servers, TRUSTED_KEY_SERVERS_SCHEMA)
|
||||
except jsonschema.ValidationError as e:
|
||||
@@ -444,7 +450,7 @@ def _parse_key_servers(key_servers, federation_verify_certificates):
|
||||
yield result
|
||||
|
||||
|
||||
def _assert_keyserver_has_verify_keys(trusted_key_server):
|
||||
def _assert_keyserver_has_verify_keys(trusted_key_server: TrustedKeyServer) -> None:
|
||||
if not trusted_key_server.verify_keys:
|
||||
raise ConfigError(INSECURE_NOTARY_ERROR)
|
||||
|
||||
|
||||
@@ -22,10 +22,12 @@ from ._base import Config, ConfigError
|
||||
|
||||
@attr.s
|
||||
class MetricsFlags:
|
||||
known_servers = attr.ib(default=False, validator=attr.validators.instance_of(bool))
|
||||
known_servers: bool = attr.ib(
|
||||
default=False, validator=attr.validators.instance_of(bool)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def all_off(cls):
|
||||
def all_off(cls) -> "MetricsFlags":
|
||||
"""
|
||||
Instantiate the flags with all options set to off.
|
||||
"""
|
||||
|
||||
@@ -1257,7 +1257,7 @@ class ServerConfig(Config):
|
||||
help="Turn on the twisted telnet manhole service on the given port.",
|
||||
)
|
||||
|
||||
def read_gc_intervals(self, durations) -> Optional[Tuple[float, float, float]]:
|
||||
def read_gc_intervals(self, durations: Any) -> Optional[Tuple[float, float, float]]:
|
||||
"""Reads the three durations for the GC min interval option, returning seconds."""
|
||||
if durations is None:
|
||||
return None
|
||||
|
||||
@@ -132,7 +132,7 @@ class TlsConfig(Config):
|
||||
self.tls_certificate: Optional[crypto.X509] = None
|
||||
self.tls_private_key: Optional[crypto.PKey] = None
|
||||
|
||||
def read_certificate_from_disk(self):
|
||||
def read_certificate_from_disk(self) -> None:
|
||||
"""
|
||||
Read the certificates and private key from disk.
|
||||
"""
|
||||
|
||||
@@ -454,23 +454,26 @@ class EventClientSerializer:
|
||||
return
|
||||
|
||||
event_id = event.event_id
|
||||
room_id = event.room_id
|
||||
|
||||
# The bundled aggregations to include.
|
||||
aggregations = {}
|
||||
|
||||
annotations = await self.store.get_aggregation_groups_for_event(event_id)
|
||||
annotations = await self.store.get_aggregation_groups_for_event(
|
||||
event_id, room_id
|
||||
)
|
||||
if annotations.chunk:
|
||||
aggregations[RelationTypes.ANNOTATION] = annotations.to_dict()
|
||||
|
||||
references = await self.store.get_relations_for_event(
|
||||
event_id, RelationTypes.REFERENCE, direction="f"
|
||||
event_id, room_id, RelationTypes.REFERENCE, direction="f"
|
||||
)
|
||||
if references.chunk:
|
||||
aggregations[RelationTypes.REFERENCE] = references.to_dict()
|
||||
|
||||
edit = None
|
||||
if event.type == EventTypes.Message:
|
||||
edit = await self.store.get_applicable_edit(event_id)
|
||||
edit = await self.store.get_applicable_edit(event_id, room_id)
|
||||
|
||||
if edit:
|
||||
# If there is an edit replace the content, preserving existing
|
||||
@@ -503,7 +506,7 @@ class EventClientSerializer:
|
||||
(
|
||||
thread_count,
|
||||
latest_thread_event,
|
||||
) = await self.store.get_thread_summary(event_id)
|
||||
) = await self.store.get_thread_summary(event_id, room_id)
|
||||
if latest_thread_event:
|
||||
aggregations[RelationTypes.THREAD] = {
|
||||
# Don't bundle aggregations as this could recurse forever.
|
||||
|
||||
@@ -997,9 +997,7 @@ class AuthHandler:
|
||||
# really don't want is active access_tokens without a record of the
|
||||
# device, so we double-check it here.
|
||||
if device_id is not None:
|
||||
try:
|
||||
await self.store.get_device(user_id, device_id)
|
||||
except StoreError:
|
||||
if await self.store.get_device(user_id, device_id) is None:
|
||||
await self.store.delete_access_token(access_token)
|
||||
raise StoreError(400, "Login raced against device deletion")
|
||||
|
||||
|
||||
@@ -106,10 +106,10 @@ class DeviceWorkerHandler:
|
||||
Raises:
|
||||
errors.NotFoundError: if the device was not found
|
||||
"""
|
||||
try:
|
||||
device = await self.store.get_device(user_id, device_id)
|
||||
except errors.StoreError:
|
||||
raise errors.NotFoundError
|
||||
device = await self.store.get_device(user_id, device_id)
|
||||
if device is None:
|
||||
raise errors.NotFoundError()
|
||||
|
||||
ips = await self.store.get_last_client_ip_by_device(user_id, device_id)
|
||||
_update_device_from_client_ips(device, ips)
|
||||
|
||||
@@ -600,6 +600,8 @@ class DeviceHandler(DeviceWorkerHandler):
|
||||
access_token, device_id
|
||||
)
|
||||
old_device = await self.store.get_device(user_id, old_device_id)
|
||||
if old_device is None:
|
||||
raise errors.NotFoundError()
|
||||
await self.store.update_device(user_id, device_id, old_device["display_name"])
|
||||
# can't call self.delete_device because that will clobber the
|
||||
# access token so call the storage layer directly
|
||||
|
||||
@@ -580,7 +580,9 @@ class E2eKeysHandler:
|
||||
log_kv(
|
||||
{"message": "Did not update one_time_keys", "reason": "no keys given"}
|
||||
)
|
||||
fallback_keys = keys.get("org.matrix.msc2732.fallback_keys", None)
|
||||
fallback_keys = keys.get("fallback_keys") or keys.get(
|
||||
"org.matrix.msc2732.fallback_keys"
|
||||
)
|
||||
if fallback_keys and isinstance(fallback_keys, dict):
|
||||
log_kv(
|
||||
{
|
||||
|
||||
@@ -496,6 +496,7 @@ class EventCreationHandler:
|
||||
require_consent: bool = True,
|
||||
outlier: bool = False,
|
||||
historical: bool = False,
|
||||
allow_no_prev_events: bool = False,
|
||||
depth: Optional[int] = None,
|
||||
) -> Tuple[EventBase, EventContext]:
|
||||
"""
|
||||
@@ -607,6 +608,7 @@ class EventCreationHandler:
|
||||
prev_event_ids=prev_event_ids,
|
||||
auth_event_ids=auth_event_ids,
|
||||
depth=depth,
|
||||
allow_no_prev_events=allow_no_prev_events,
|
||||
)
|
||||
|
||||
# In an ideal world we wouldn't need the second part of this condition. However,
|
||||
@@ -882,6 +884,7 @@ class EventCreationHandler:
|
||||
prev_event_ids: Optional[List[str]] = None,
|
||||
auth_event_ids: Optional[List[str]] = None,
|
||||
depth: Optional[int] = None,
|
||||
allow_no_prev_events: bool = False,
|
||||
) -> Tuple[EventBase, EventContext]:
|
||||
"""Create a new event for a local client
|
||||
|
||||
@@ -912,6 +915,7 @@ class EventCreationHandler:
|
||||
full_state_ids_at_event = None
|
||||
if auth_event_ids is not None:
|
||||
# If auth events are provided, prev events must be also.
|
||||
# prev_event_ids could be an empty array though.
|
||||
assert prev_event_ids is not None
|
||||
|
||||
# Copy the full auth state before it stripped down
|
||||
@@ -943,14 +947,22 @@ class EventCreationHandler:
|
||||
else:
|
||||
prev_event_ids = await self.store.get_prev_events_for_room(builder.room_id)
|
||||
|
||||
# we now ought to have some prev_events (unless it's a create event).
|
||||
#
|
||||
# do a quick sanity check here, rather than waiting until we've created the
|
||||
# Do a quick sanity check here, rather than waiting until we've created the
|
||||
# event and then try to auth it (which fails with a somewhat confusing "No
|
||||
# create event in auth events")
|
||||
assert (
|
||||
builder.type == EventTypes.Create or len(prev_event_ids) > 0
|
||||
), "Attempting to create an event with no prev_events"
|
||||
if allow_no_prev_events:
|
||||
# We allow events with no `prev_events` but it better have some `auth_events`
|
||||
assert (
|
||||
builder.type == EventTypes.Create
|
||||
# Allow an event to have empty list of prev_event_ids
|
||||
# only if it has auth_event_ids.
|
||||
or auth_event_ids
|
||||
), "Attempting to create a non-m.room.create event with no prev_events or auth_event_ids"
|
||||
else:
|
||||
# we now ought to have some prev_events (unless it's a create event).
|
||||
assert (
|
||||
builder.type == EventTypes.Create or prev_event_ids
|
||||
), "Attempting to create a non-m.room.create event with no prev_events"
|
||||
|
||||
event = await builder.build(
|
||||
prev_event_ids=prev_event_ids,
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
|
||||
|
||||
from synapse.api.constants import ReadReceiptEventFields
|
||||
from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes
|
||||
from synapse.appservice import ApplicationService
|
||||
from synapse.streams import EventSource
|
||||
from synapse.types import JsonDict, ReadReceipt, UserID, get_domain_from_id
|
||||
@@ -178,7 +178,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
|
||||
|
||||
for event_id in content.keys():
|
||||
event_content = content.get(event_id, {})
|
||||
m_read = event_content.get("m.read", {})
|
||||
m_read = event_content.get(ReceiptTypes.READ, {})
|
||||
|
||||
# If m_read is missing copy over the original event_content as there is nothing to process here
|
||||
if not m_read:
|
||||
@@ -206,7 +206,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
|
||||
|
||||
# Set new users unless empty
|
||||
if len(new_users.keys()) > 0:
|
||||
new_event["content"][event_id] = {"m.read": new_users}
|
||||
new_event["content"][event_id] = {ReceiptTypes.READ: new_users}
|
||||
|
||||
# Append new_event to visible_events unless empty
|
||||
if len(new_event["content"].keys()) > 0:
|
||||
|
||||
@@ -658,7 +658,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
||||
if block_invite:
|
||||
raise SynapseError(403, "Invites have been disabled on this server")
|
||||
|
||||
if prev_event_ids:
|
||||
# An empty prev_events list is allowed as long as the auth_event_ids are present
|
||||
if prev_event_ids is not None:
|
||||
return await self._local_membership_update(
|
||||
requester=requester,
|
||||
target=target,
|
||||
|
||||
@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Optional, Set, Tup
|
||||
import attr
|
||||
from prometheus_client import Counter
|
||||
|
||||
from synapse.api.constants import AccountDataTypes, EventTypes, Membership
|
||||
from synapse.api.constants import AccountDataTypes, EventTypes, Membership, ReceiptTypes
|
||||
from synapse.api.filtering import FilterCollection
|
||||
from synapse.api.presence import UserPresenceState
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
||||
@@ -1022,7 +1022,7 @@ class SyncHandler:
|
||||
last_unread_event_id = await self.store.get_last_receipt_event_id_for_user(
|
||||
user_id=sync_config.user.to_string(),
|
||||
room_id=room_id,
|
||||
receipt_type="m.read",
|
||||
receipt_type=ReceiptTypes.READ,
|
||||
)
|
||||
|
||||
notifs = await self.store.get_unread_event_push_actions_by_room_for_user(
|
||||
@@ -1638,20 +1638,20 @@ class SyncHandler:
|
||||
) -> _RoomChanges:
|
||||
"""Determine the changes in rooms to report to the user.
|
||||
|
||||
Ideally, we want to report all events whose stream ordering `s` lies in the
|
||||
range `since_token < s <= now_token`, where the two tokens are read from the
|
||||
sync_result_builder.
|
||||
This function is a first pass at generating the rooms part of the sync response.
|
||||
It determines which rooms have changed during the sync period, and categorises
|
||||
them into four buckets: "knock", "invite", "join" and "leave".
|
||||
|
||||
If there are too many events in that range to report, things get complicated.
|
||||
In this situation we return a truncated list of the most recent events, and
|
||||
indicate in the response that there is a "gap" of omitted events. Additionally:
|
||||
1. Finds all membership changes for the user in the sync period (from
|
||||
`since_token` up to `now_token`).
|
||||
2. Uses those to place the room in one of the four categories above.
|
||||
3. Builds a `_RoomChanges` struct to record this, and return that struct.
|
||||
|
||||
- we include a "state_delta", to describe the changes in state over the gap,
|
||||
- we include all membership events applying to the user making the request,
|
||||
even those in the gap.
|
||||
|
||||
See the spec for the rationale:
|
||||
https://spec.matrix.org/v1.1/client-server-api/#syncing
|
||||
For rooms classified as "knock", "invite" or "leave", we just need to report
|
||||
a single membership event in the eventual /sync response. For "join" we need
|
||||
to fetch additional non-membership events, e.g. messages in the room. That is
|
||||
more complicated, so instead we report an intermediary `RoomSyncResultBuilder`
|
||||
struct, and leave the additional work to `_generate_room_entry`.
|
||||
|
||||
The sync_result_builder is not modified by this function.
|
||||
"""
|
||||
@@ -1662,16 +1662,6 @@ class SyncHandler:
|
||||
|
||||
assert since_token
|
||||
|
||||
# The spec
|
||||
# https://spec.matrix.org/v1.1/client-server-api/#get_matrixclientv3sync
|
||||
# notes that membership events need special consideration:
|
||||
#
|
||||
# > When a sync is limited, the server MUST return membership events for events
|
||||
# > in the gap (between since and the start of the returned timeline), regardless
|
||||
# > as to whether or not they are redundant.
|
||||
#
|
||||
# We fetch such events here, but we only seem to use them for categorising rooms
|
||||
# as newly joined, newly left, invited or knocked.
|
||||
# TODO: we've already called this function and ran this query in
|
||||
# _have_rooms_changed. We could keep the results in memory to avoid a
|
||||
# second query, at the cost of more complicated source code.
|
||||
@@ -1985,6 +1975,23 @@ class SyncHandler:
|
||||
"""Populates the `joined` and `archived` section of `sync_result_builder`
|
||||
based on the `room_builder`.
|
||||
|
||||
Ideally, we want to report all events whose stream ordering `s` lies in the
|
||||
range `since_token < s <= now_token`, where the two tokens are read from the
|
||||
sync_result_builder.
|
||||
|
||||
If there are too many events in that range to report, things get complicated.
|
||||
In this situation we return a truncated list of the most recent events, and
|
||||
indicate in the response that there is a "gap" of omitted events. Lots of this
|
||||
is handled in `_load_filtered_recents`, but some of is handled in this method.
|
||||
|
||||
Additionally:
|
||||
- we include a "state_delta", to describe the changes in state over the gap,
|
||||
- we include all membership events applying to the user making the request,
|
||||
even those in the gap.
|
||||
|
||||
See the spec for the rationale:
|
||||
https://spec.matrix.org/v1.1/client-server-api/#syncing
|
||||
|
||||
Args:
|
||||
sync_result_builder
|
||||
ignored_users: Set of users ignored by user.
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import urllib.parse
|
||||
from http import HTTPStatus
|
||||
from io import BytesIO
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
@@ -280,7 +281,9 @@ class BlacklistingAgentWrapper(Agent):
|
||||
ip_address, self._ip_whitelist, self._ip_blacklist
|
||||
):
|
||||
logger.info("Blocking access to %s due to blacklist" % (ip_address,))
|
||||
e = SynapseError(403, "IP address blocked by IP blacklist entry")
|
||||
e = SynapseError(
|
||||
HTTPStatus.FORBIDDEN, "IP address blocked by IP blacklist entry"
|
||||
)
|
||||
return defer.fail(Failure(e))
|
||||
|
||||
return self._agent.request(
|
||||
@@ -719,7 +722,9 @@ class SimpleHttpClient:
|
||||
|
||||
if response.code > 299:
|
||||
logger.warning("Got %d when downloading %s" % (response.code, url))
|
||||
raise SynapseError(502, "Got error %d" % (response.code,), Codes.UNKNOWN)
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_GATEWAY, "Got error %d" % (response.code,), Codes.UNKNOWN
|
||||
)
|
||||
|
||||
# TODO: if our Content-Type is HTML or something, just read the first
|
||||
# N bytes into RAM rather than saving it all to disk only to read it
|
||||
@@ -731,12 +736,14 @@ class SimpleHttpClient:
|
||||
)
|
||||
except BodyExceededMaxSize:
|
||||
raise SynapseError(
|
||||
502,
|
||||
HTTPStatus.BAD_GATEWAY,
|
||||
"Requested file is too large > %r bytes" % (max_size,),
|
||||
Codes.TOO_LARGE,
|
||||
)
|
||||
except Exception as e:
|
||||
raise SynapseError(502, ("Failed to download remote body: %s" % e)) from e
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_GATEWAY, ("Failed to download remote body: %s" % e)
|
||||
) from e
|
||||
|
||||
return (
|
||||
length,
|
||||
|
||||
@@ -19,6 +19,7 @@ import random
|
||||
import sys
|
||||
import typing
|
||||
import urllib.parse
|
||||
from http import HTTPStatus
|
||||
from io import BytesIO, StringIO
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
@@ -1154,7 +1155,7 @@ class MatrixFederationHttpClient:
|
||||
request.destination,
|
||||
msg,
|
||||
)
|
||||
raise SynapseError(502, msg, Codes.TOO_LARGE)
|
||||
raise SynapseError(HTTPStatus.BAD_GATEWAY, msg, Codes.TOO_LARGE)
|
||||
except defer.TimeoutError as e:
|
||||
logger.warning(
|
||||
"{%s} [%s] Timed out reading response - %s %s",
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
|
||||
""" This module contains base REST classes for constructing REST servlets. """
|
||||
import logging
|
||||
from http import HTTPStatus
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Iterable,
|
||||
@@ -137,11 +138,15 @@ def parse_integer_from_args(
|
||||
return int(args[name_bytes][0])
|
||||
except Exception:
|
||||
message = "Query parameter %r must be an integer" % (name,)
|
||||
raise SynapseError(400, message, errcode=Codes.INVALID_PARAM)
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM
|
||||
)
|
||||
else:
|
||||
if required:
|
||||
message = "Missing integer query parameter %r" % (name,)
|
||||
raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_REQUEST, message, errcode=Codes.MISSING_PARAM
|
||||
)
|
||||
else:
|
||||
return default
|
||||
|
||||
@@ -246,11 +251,15 @@ def parse_boolean_from_args(
|
||||
message = (
|
||||
"Boolean query parameter %r must be one of ['true', 'false']"
|
||||
) % (name,)
|
||||
raise SynapseError(400, message)
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM
|
||||
)
|
||||
else:
|
||||
if required:
|
||||
message = "Missing boolean query parameter %r" % (name,)
|
||||
raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_REQUEST, message, errcode=Codes.MISSING_PARAM
|
||||
)
|
||||
else:
|
||||
return default
|
||||
|
||||
@@ -313,7 +322,7 @@ def parse_bytes_from_args(
|
||||
return args[name_bytes][0]
|
||||
elif required:
|
||||
message = "Missing string query parameter %s" % (name,)
|
||||
raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
|
||||
raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.MISSING_PARAM)
|
||||
|
||||
return default
|
||||
|
||||
@@ -407,14 +416,16 @@ def _parse_string_value(
|
||||
try:
|
||||
value_str = value.decode(encoding)
|
||||
except ValueError:
|
||||
raise SynapseError(400, "Query parameter %r must be %s" % (name, encoding))
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_REQUEST, "Query parameter %r must be %s" % (name, encoding)
|
||||
)
|
||||
|
||||
if allowed_values is not None and value_str not in allowed_values:
|
||||
message = "Query parameter %r must be one of [%s]" % (
|
||||
name,
|
||||
", ".join(repr(v) for v in allowed_values),
|
||||
)
|
||||
raise SynapseError(400, message)
|
||||
raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM)
|
||||
else:
|
||||
return value_str
|
||||
|
||||
@@ -510,7 +521,9 @@ def parse_strings_from_args(
|
||||
else:
|
||||
if required:
|
||||
message = "Missing string query parameter %r" % (name,)
|
||||
raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_REQUEST, message, errcode=Codes.MISSING_PARAM
|
||||
)
|
||||
|
||||
return default
|
||||
|
||||
@@ -638,7 +651,7 @@ def parse_json_value_from_request(
|
||||
try:
|
||||
content_bytes = request.content.read() # type: ignore
|
||||
except Exception:
|
||||
raise SynapseError(400, "Error reading JSON content.")
|
||||
raise SynapseError(HTTPStatus.BAD_REQUEST, "Error reading JSON content.")
|
||||
|
||||
if not content_bytes and allow_empty_body:
|
||||
return None
|
||||
@@ -647,7 +660,9 @@ def parse_json_value_from_request(
|
||||
content = json_decoder.decode(content_bytes.decode("utf-8"))
|
||||
except Exception as e:
|
||||
logger.warning("Unable to parse JSON: %s (%s)", e, content_bytes)
|
||||
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_REQUEST, "Content not JSON.", errcode=Codes.NOT_JSON
|
||||
)
|
||||
|
||||
return content
|
||||
|
||||
@@ -673,7 +688,7 @@ def parse_json_object_from_request(
|
||||
|
||||
if not isinstance(content, dict):
|
||||
message = "Content must be a JSON object."
|
||||
raise SynapseError(400, message, errcode=Codes.BAD_JSON)
|
||||
raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.BAD_JSON)
|
||||
|
||||
return content
|
||||
|
||||
@@ -685,7 +700,9 @@ def assert_params_in_dict(body: JsonDict, required: Iterable[str]) -> None:
|
||||
absent.append(k)
|
||||
|
||||
if len(absent) > 0:
|
||||
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_REQUEST, "Missing params: %r" % absent, Codes.MISSING_PARAM
|
||||
)
|
||||
|
||||
|
||||
class RestServlet:
|
||||
@@ -758,10 +775,12 @@ class ResolveRoomIdMixin:
|
||||
resolved_room_id = room_id.to_string()
|
||||
else:
|
||||
raise SynapseError(
|
||||
400, "%s was not legal room ID or room alias" % (room_identifier,)
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
"%s was not legal room ID or room alias" % (room_identifier,),
|
||||
)
|
||||
if not resolved_room_id:
|
||||
raise SynapseError(
|
||||
400, "Unknown room ID or room alias %s" % room_identifier
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
"Unknown room ID or room alias %s" % room_identifier,
|
||||
)
|
||||
return resolved_room_id, remote_room_hosts
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
from typing import Dict
|
||||
|
||||
from synapse.api.constants import ReceiptTypes
|
||||
from synapse.events import EventBase
|
||||
from synapse.push.presentable_names import calculate_room_name, name_from_member_event
|
||||
from synapse.storage import Storage
|
||||
@@ -23,7 +24,7 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
|
||||
invites = await store.get_invited_rooms_for_local_user(user_id)
|
||||
joins = await store.get_rooms_for_user(user_id)
|
||||
|
||||
my_receipts_by_room = await store.get_receipts_for_user(user_id, "m.read")
|
||||
my_receipts_by_room = await store.get_receipts_for_user(user_id, ReceiptTypes.READ)
|
||||
|
||||
badge = len(invites)
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@ from synapse.push.pusher import PusherFactory
|
||||
from synapse.replication.http.push import ReplicationRemovePusherRestServlet
|
||||
from synapse.types import JsonDict, RoomStreamToken
|
||||
from synapse.util.async_helpers import concurrently_execute
|
||||
from synapse.util.threepids import canonicalise_email
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
@@ -113,7 +114,9 @@ class PusherPool:
|
||||
"""
|
||||
|
||||
if kind == "email":
|
||||
email_owner = await self.store.get_user_id_by_threepid("email", pushkey)
|
||||
email_owner = await self.store.get_user_id_by_threepid(
|
||||
"email", canonicalise_email(pushkey)
|
||||
)
|
||||
if email_owner != user_id:
|
||||
raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND)
|
||||
|
||||
|
||||
@@ -108,7 +108,7 @@ class VersionServlet(RestServlet):
|
||||
|
||||
class PurgeHistoryRestServlet(RestServlet):
|
||||
PATTERNS = admin_patterns(
|
||||
"/purge_history/(?P<room_id>[^/]*)(/(?P<event_id>[^/]+))?"
|
||||
"/purge_history/(?P<room_id>[^/]*)(/(?P<event_id>[^/]*))?$"
|
||||
)
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
@@ -195,7 +195,7 @@ class PurgeHistoryRestServlet(RestServlet):
|
||||
|
||||
|
||||
class PurgeHistoryStatusRestServlet(RestServlet):
|
||||
PATTERNS = admin_patterns("/purge_history_status/(?P<purge_id>[^/]+)")
|
||||
PATTERNS = admin_patterns("/purge_history_status/(?P<purge_id>[^/]*)$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.pagination_handler = hs.get_pagination_handler()
|
||||
|
||||
@@ -22,7 +22,7 @@ from synapse.http.servlet import (
|
||||
parse_json_object_from_request,
|
||||
)
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.rest.admin._base import admin_patterns, assert_user_is_admin
|
||||
from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
|
||||
from synapse.types import JsonDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -41,8 +41,7 @@ class BackgroundUpdateEnabledRestServlet(RestServlet):
|
||||
self._data_stores = hs.get_datastores()
|
||||
|
||||
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
requester = await self._auth.get_user_by_req(request)
|
||||
await assert_user_is_admin(self._auth, requester.user)
|
||||
await assert_requester_is_admin(self._auth, request)
|
||||
|
||||
# We need to check that all configured databases have updates enabled.
|
||||
# (They *should* all be in sync.)
|
||||
@@ -51,8 +50,7 @@ class BackgroundUpdateEnabledRestServlet(RestServlet):
|
||||
return HTTPStatus.OK, {"enabled": enabled}
|
||||
|
||||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
requester = await self._auth.get_user_by_req(request)
|
||||
await assert_user_is_admin(self._auth, requester.user)
|
||||
await assert_requester_is_admin(self._auth, request)
|
||||
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
@@ -84,8 +82,7 @@ class BackgroundUpdateRestServlet(RestServlet):
|
||||
self._data_stores = hs.get_datastores()
|
||||
|
||||
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
requester = await self._auth.get_user_by_req(request)
|
||||
await assert_user_is_admin(self._auth, requester.user)
|
||||
await assert_requester_is_admin(self._auth, request)
|
||||
|
||||
# We need to check that all configured databases have updates enabled.
|
||||
# (They *should* all be in sync.)
|
||||
@@ -111,15 +108,14 @@ class BackgroundUpdateRestServlet(RestServlet):
|
||||
class BackgroundUpdateStartJobRestServlet(RestServlet):
|
||||
"""Allows to start specific background updates"""
|
||||
|
||||
PATTERNS = admin_patterns("/background_updates/start_job")
|
||||
PATTERNS = admin_patterns("/background_updates/start_job$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self._auth = hs.get_auth()
|
||||
self._store = hs.get_datastore()
|
||||
|
||||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
requester = await self._auth.get_user_by_req(request)
|
||||
await assert_user_is_admin(self._auth, requester.user)
|
||||
await assert_requester_is_admin(self._auth, request)
|
||||
|
||||
body = parse_json_object_from_request(request)
|
||||
assert_params_in_dict(body, ["job_name"])
|
||||
|
||||
@@ -42,10 +42,10 @@ class DeviceRestServlet(RestServlet):
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.device_handler = hs.get_device_handler()
|
||||
self.store = hs.get_datastore()
|
||||
self.is_mine = hs.is_mine
|
||||
|
||||
async def on_GET(
|
||||
self, request: SynapseRequest, user_id: str, device_id: str
|
||||
@@ -53,7 +53,7 @@ class DeviceRestServlet(RestServlet):
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
|
||||
target_user = UserID.from_string(user_id)
|
||||
if not self.hs.is_mine(target_user):
|
||||
if not self.is_mine(target_user):
|
||||
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
|
||||
|
||||
u = await self.store.get_user_by_id(target_user.to_string())
|
||||
@@ -63,6 +63,8 @@ class DeviceRestServlet(RestServlet):
|
||||
device = await self.device_handler.get_device(
|
||||
target_user.to_string(), device_id
|
||||
)
|
||||
if device is None:
|
||||
raise NotFoundError("No device found")
|
||||
return HTTPStatus.OK, device
|
||||
|
||||
async def on_DELETE(
|
||||
@@ -71,7 +73,7 @@ class DeviceRestServlet(RestServlet):
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
|
||||
target_user = UserID.from_string(user_id)
|
||||
if not self.hs.is_mine(target_user):
|
||||
if not self.is_mine(target_user):
|
||||
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
|
||||
|
||||
u = await self.store.get_user_by_id(target_user.to_string())
|
||||
@@ -87,7 +89,7 @@ class DeviceRestServlet(RestServlet):
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
|
||||
target_user = UserID.from_string(user_id)
|
||||
if not self.hs.is_mine(target_user):
|
||||
if not self.is_mine(target_user):
|
||||
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
|
||||
|
||||
u = await self.store.get_user_by_id(target_user.to_string())
|
||||
@@ -109,14 +111,10 @@ class DevicesRestServlet(RestServlet):
|
||||
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/devices$", "v2")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
"""
|
||||
Args:
|
||||
hs: server
|
||||
"""
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.device_handler = hs.get_device_handler()
|
||||
self.store = hs.get_datastore()
|
||||
self.is_mine = hs.is_mine
|
||||
|
||||
async def on_GET(
|
||||
self, request: SynapseRequest, user_id: str
|
||||
@@ -124,7 +122,7 @@ class DevicesRestServlet(RestServlet):
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
|
||||
target_user = UserID.from_string(user_id)
|
||||
if not self.hs.is_mine(target_user):
|
||||
if not self.is_mine(target_user):
|
||||
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
|
||||
|
||||
u = await self.store.get_user_by_id(target_user.to_string())
|
||||
@@ -144,10 +142,10 @@ class DeleteDevicesRestServlet(RestServlet):
|
||||
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/delete_devices$", "v2")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.device_handler = hs.get_device_handler()
|
||||
self.store = hs.get_datastore()
|
||||
self.is_mine = hs.is_mine
|
||||
|
||||
async def on_POST(
|
||||
self, request: SynapseRequest, user_id: str
|
||||
@@ -155,7 +153,7 @@ class DeleteDevicesRestServlet(RestServlet):
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
|
||||
target_user = UserID.from_string(user_id)
|
||||
if not self.hs.is_mine(target_user):
|
||||
if not self.is_mine(target_user):
|
||||
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only lookup local users")
|
||||
|
||||
u = await self.store.get_user_by_id(target_user.to_string())
|
||||
|
||||
@@ -52,7 +52,6 @@ class EventReportsRestServlet(RestServlet):
|
||||
PATTERNS = admin_patterns("/event_reports$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
@@ -115,7 +114,6 @@ class EventReportDetailRestServlet(RestServlet):
|
||||
PATTERNS = admin_patterns("/event_reports/(?P<report_id>[^/]*)$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
|
||||
@@ -100,7 +100,7 @@ class DestinationsRestServlet(RestServlet):
|
||||
200 OK with details of a destination if success otherwise an error.
|
||||
"""
|
||||
|
||||
PATTERNS = admin_patterns("/federation/destinations/(?P<destination>[^/]+)$")
|
||||
PATTERNS = admin_patterns("/federation/destinations/(?P<destination>[^/]*)$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self._auth = hs.get_auth()
|
||||
|
||||
@@ -30,7 +30,7 @@ logger = logging.getLogger(__name__)
|
||||
class DeleteGroupAdminRestServlet(RestServlet):
|
||||
"""Allows deleting of local groups"""
|
||||
|
||||
PATTERNS = admin_patterns("/delete_group/(?P<group_id>[^/]*)")
|
||||
PATTERNS = admin_patterns("/delete_group/(?P<group_id>[^/]*)$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.group_server = hs.get_groups_server_handler()
|
||||
|
||||
@@ -17,7 +17,7 @@ import logging
|
||||
from http import HTTPStatus
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
|
||||
from synapse.api.errors import Codes, NotFoundError, SynapseError
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
|
||||
from synapse.http.site import SynapseRequest
|
||||
@@ -41,9 +41,9 @@ class QuarantineMediaInRoom(RestServlet):
|
||||
"""
|
||||
|
||||
PATTERNS = [
|
||||
*admin_patterns("/room/(?P<room_id>[^/]+)/media/quarantine$"),
|
||||
*admin_patterns("/room/(?P<room_id>[^/]*)/media/quarantine$"),
|
||||
# This path kept around for legacy reasons
|
||||
*admin_patterns("/quarantine_media/(?P<room_id>[^/]+)"),
|
||||
*admin_patterns("/quarantine_media/(?P<room_id>[^/]*)$"),
|
||||
]
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
@@ -71,7 +71,7 @@ class QuarantineMediaByUser(RestServlet):
|
||||
this server.
|
||||
"""
|
||||
|
||||
PATTERNS = admin_patterns("/user/(?P<user_id>[^/]+)/media/quarantine$")
|
||||
PATTERNS = admin_patterns("/user/(?P<user_id>[^/]*)/media/quarantine$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.store = hs.get_datastore()
|
||||
@@ -99,7 +99,7 @@ class QuarantineMediaByID(RestServlet):
|
||||
"""
|
||||
|
||||
PATTERNS = admin_patterns(
|
||||
"/media/quarantine/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)"
|
||||
"/media/quarantine/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)$"
|
||||
)
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
@@ -128,7 +128,7 @@ class UnquarantineMediaByID(RestServlet):
|
||||
"""
|
||||
|
||||
PATTERNS = admin_patterns(
|
||||
"/media/unquarantine/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)"
|
||||
"/media/unquarantine/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)$"
|
||||
)
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
@@ -138,8 +138,7 @@ class UnquarantineMediaByID(RestServlet):
|
||||
async def on_POST(
|
||||
self, request: SynapseRequest, server_name: str, media_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
await assert_user_is_admin(self.auth, requester.user)
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
|
||||
logging.info(
|
||||
"Remove from quarantine local media by ID: %s/%s", server_name, media_id
|
||||
@@ -154,7 +153,7 @@ class UnquarantineMediaByID(RestServlet):
|
||||
class ProtectMediaByID(RestServlet):
|
||||
"""Protect local media from being quarantined."""
|
||||
|
||||
PATTERNS = admin_patterns("/media/protect/(?P<media_id>[^/]+)")
|
||||
PATTERNS = admin_patterns("/media/protect/(?P<media_id>[^/]*)$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.store = hs.get_datastore()
|
||||
@@ -163,8 +162,7 @@ class ProtectMediaByID(RestServlet):
|
||||
async def on_POST(
|
||||
self, request: SynapseRequest, media_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
await assert_user_is_admin(self.auth, requester.user)
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
|
||||
logging.info("Protecting local media by ID: %s", media_id)
|
||||
|
||||
@@ -177,7 +175,7 @@ class ProtectMediaByID(RestServlet):
|
||||
class UnprotectMediaByID(RestServlet):
|
||||
"""Unprotect local media from being quarantined."""
|
||||
|
||||
PATTERNS = admin_patterns("/media/unprotect/(?P<media_id>[^/]+)")
|
||||
PATTERNS = admin_patterns("/media/unprotect/(?P<media_id>[^/]*)$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.store = hs.get_datastore()
|
||||
@@ -186,8 +184,7 @@ class UnprotectMediaByID(RestServlet):
|
||||
async def on_POST(
|
||||
self, request: SynapseRequest, media_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
await assert_user_is_admin(self.auth, requester.user)
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
|
||||
logging.info("Unprotecting local media by ID: %s", media_id)
|
||||
|
||||
@@ -200,7 +197,7 @@ class UnprotectMediaByID(RestServlet):
|
||||
class ListMediaInRoom(RestServlet):
|
||||
"""Lists all of the media in a given room."""
|
||||
|
||||
PATTERNS = admin_patterns("/room/(?P<room_id>[^/]+)/media$")
|
||||
PATTERNS = admin_patterns("/room/(?P<room_id>[^/]*)/media$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.store = hs.get_datastore()
|
||||
@@ -209,10 +206,7 @@ class ListMediaInRoom(RestServlet):
|
||||
async def on_GET(
|
||||
self, request: SynapseRequest, room_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
is_admin = await self.auth.is_server_admin(requester.user)
|
||||
if not is_admin:
|
||||
raise AuthError(HTTPStatus.FORBIDDEN, "You are not a server admin")
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
|
||||
local_mxcs, remote_mxcs = await self.store.get_media_mxcs_in_room(room_id)
|
||||
|
||||
@@ -254,7 +248,7 @@ class PurgeMediaCacheRestServlet(RestServlet):
|
||||
class DeleteMediaByID(RestServlet):
|
||||
"""Delete local media by a given ID. Removes it from this server."""
|
||||
|
||||
PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)")
|
||||
PATTERNS = admin_patterns("/media/(?P<server_name>[^/]*)/(?P<media_id>[^/]*)$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.store = hs.get_datastore()
|
||||
@@ -286,7 +280,7 @@ class DeleteMediaByDateSize(RestServlet):
|
||||
timestamp and size.
|
||||
"""
|
||||
|
||||
PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/delete$")
|
||||
PATTERNS = admin_patterns("/media/(?P<server_name>[^/]*)/delete$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.store = hs.get_datastore()
|
||||
@@ -353,7 +347,7 @@ class UserMediaRestServlet(RestServlet):
|
||||
media that exist given for this user
|
||||
"""
|
||||
|
||||
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/media$")
|
||||
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/media$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.is_mine = hs.is_mine
|
||||
@@ -403,16 +397,7 @@ class UserMediaRestServlet(RestServlet):
|
||||
request,
|
||||
"order_by",
|
||||
default=MediaSortOrder.CREATED_TS.value,
|
||||
allowed_values=(
|
||||
MediaSortOrder.MEDIA_ID.value,
|
||||
MediaSortOrder.UPLOAD_NAME.value,
|
||||
MediaSortOrder.CREATED_TS.value,
|
||||
MediaSortOrder.LAST_ACCESS_TS.value,
|
||||
MediaSortOrder.MEDIA_LENGTH.value,
|
||||
MediaSortOrder.MEDIA_TYPE.value,
|
||||
MediaSortOrder.QUARANTINED_BY.value,
|
||||
MediaSortOrder.SAFE_FROM_QUARANTINE.value,
|
||||
),
|
||||
allowed_values=[sort_order.value for sort_order in MediaSortOrder],
|
||||
)
|
||||
direction = parse_string(
|
||||
request, "dir", default="f", allowed_values=("f", "b")
|
||||
@@ -470,16 +455,7 @@ class UserMediaRestServlet(RestServlet):
|
||||
request,
|
||||
"order_by",
|
||||
default=MediaSortOrder.CREATED_TS.value,
|
||||
allowed_values=(
|
||||
MediaSortOrder.MEDIA_ID.value,
|
||||
MediaSortOrder.UPLOAD_NAME.value,
|
||||
MediaSortOrder.CREATED_TS.value,
|
||||
MediaSortOrder.LAST_ACCESS_TS.value,
|
||||
MediaSortOrder.MEDIA_LENGTH.value,
|
||||
MediaSortOrder.MEDIA_TYPE.value,
|
||||
MediaSortOrder.QUARANTINED_BY.value,
|
||||
MediaSortOrder.SAFE_FROM_QUARANTINE.value,
|
||||
),
|
||||
allowed_values=[sort_order.value for sort_order in MediaSortOrder],
|
||||
)
|
||||
direction = parse_string(
|
||||
request, "dir", default="f", allowed_values=("f", "b")
|
||||
|
||||
@@ -70,7 +70,6 @@ class ListRegistrationTokensRestServlet(RestServlet):
|
||||
PATTERNS = admin_patterns("/registration_tokens$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
@@ -109,7 +108,6 @@ class NewRegistrationTokenRestServlet(RestServlet):
|
||||
PATTERNS = admin_patterns("/registration_tokens/new$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastore()
|
||||
self.clock = hs.get_clock()
|
||||
@@ -260,7 +258,6 @@ class RegistrationTokenRestServlet(RestServlet):
|
||||
PATTERNS = admin_patterns("/registration_tokens/(?P<token>[^/]*)$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.clock = hs.get_clock()
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
@@ -61,7 +61,7 @@ class RoomRestV2Servlet(RestServlet):
|
||||
If 'purge' is true, it will remove all traces of a room from the database.
|
||||
"""
|
||||
|
||||
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)$", "v2")
|
||||
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)$", "v2")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self._auth = hs.get_auth()
|
||||
@@ -123,7 +123,7 @@ class RoomRestV2Servlet(RestServlet):
|
||||
class DeleteRoomStatusByRoomIdRestServlet(RestServlet):
|
||||
"""Get the status of the delete room background task."""
|
||||
|
||||
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/delete_status$", "v2")
|
||||
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)/delete_status$", "v2")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self._auth = hs.get_auth()
|
||||
@@ -160,7 +160,7 @@ class DeleteRoomStatusByRoomIdRestServlet(RestServlet):
|
||||
class DeleteRoomStatusByDeleteIdRestServlet(RestServlet):
|
||||
"""Get the status of the delete room background task."""
|
||||
|
||||
PATTERNS = admin_patterns("/rooms/delete_status/(?P<delete_id>[^/]+)$", "v2")
|
||||
PATTERNS = admin_patterns("/rooms/delete_status/(?P<delete_id>[^/]*)$", "v2")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self._auth = hs.get_auth()
|
||||
@@ -193,35 +193,17 @@ class ListRoomRestServlet(RestServlet):
|
||||
self.admin_handler = hs.get_admin_handler()
|
||||
|
||||
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
await assert_user_is_admin(self.auth, requester.user)
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
|
||||
# Extract query parameters
|
||||
start = parse_integer(request, "from", default=0)
|
||||
limit = parse_integer(request, "limit", default=100)
|
||||
order_by = parse_string(request, "order_by", default=RoomSortOrder.NAME.value)
|
||||
if order_by not in (
|
||||
RoomSortOrder.ALPHABETICAL.value,
|
||||
RoomSortOrder.SIZE.value,
|
||||
RoomSortOrder.NAME.value,
|
||||
RoomSortOrder.CANONICAL_ALIAS.value,
|
||||
RoomSortOrder.JOINED_MEMBERS.value,
|
||||
RoomSortOrder.JOINED_LOCAL_MEMBERS.value,
|
||||
RoomSortOrder.VERSION.value,
|
||||
RoomSortOrder.CREATOR.value,
|
||||
RoomSortOrder.ENCRYPTION.value,
|
||||
RoomSortOrder.FEDERATABLE.value,
|
||||
RoomSortOrder.PUBLIC.value,
|
||||
RoomSortOrder.JOIN_RULES.value,
|
||||
RoomSortOrder.GUEST_ACCESS.value,
|
||||
RoomSortOrder.HISTORY_VISIBILITY.value,
|
||||
RoomSortOrder.STATE_EVENTS.value,
|
||||
):
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
"Unknown value for order_by: %s" % (order_by,),
|
||||
errcode=Codes.INVALID_PARAM,
|
||||
)
|
||||
order_by = parse_string(
|
||||
request,
|
||||
"order_by",
|
||||
default=RoomSortOrder.NAME.value,
|
||||
allowed_values=[sort_order.value for sort_order in RoomSortOrder],
|
||||
)
|
||||
|
||||
search_term = parse_string(request, "search_term", encoding="utf-8")
|
||||
if search_term == "":
|
||||
@@ -292,10 +274,9 @@ class RoomRestServlet(RestServlet):
|
||||
TODO: Add on_POST to allow room creation without joining the room
|
||||
"""
|
||||
|
||||
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)$")
|
||||
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastore()
|
||||
self.room_shutdown_handler = hs.get_room_shutdown_handler()
|
||||
@@ -397,10 +378,9 @@ class RoomMembersRestServlet(RestServlet):
|
||||
Get members list of a room.
|
||||
"""
|
||||
|
||||
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/members")
|
||||
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)/members$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
@@ -424,10 +404,9 @@ class RoomStateRestServlet(RestServlet):
|
||||
Get full state within a room.
|
||||
"""
|
||||
|
||||
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/state")
|
||||
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)/state$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastore()
|
||||
self.clock = hs.get_clock()
|
||||
@@ -436,8 +415,7 @@ class RoomStateRestServlet(RestServlet):
|
||||
async def on_GET(
|
||||
self, request: SynapseRequest, room_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
await assert_user_is_admin(self.auth, requester.user)
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
|
||||
ret = await self.store.get_room(room_id)
|
||||
if not ret:
|
||||
@@ -454,14 +432,14 @@ class RoomStateRestServlet(RestServlet):
|
||||
|
||||
class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
|
||||
|
||||
PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)")
|
||||
PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__(hs)
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.admin_handler = hs.get_admin_handler()
|
||||
self.state_handler = hs.get_state_handler()
|
||||
self.is_mine = hs.is_mine
|
||||
|
||||
async def on_POST(
|
||||
self, request: SynapseRequest, room_identifier: str
|
||||
@@ -477,7 +455,7 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
|
||||
assert_params_in_dict(content, ["user_id"])
|
||||
target_user = UserID.from_string(content["user_id"])
|
||||
|
||||
if not self.hs.is_mine(target_user):
|
||||
if not self.is_mine(target_user):
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
"This endpoint can only be used with local users",
|
||||
@@ -542,11 +520,10 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
|
||||
}
|
||||
"""
|
||||
|
||||
PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/make_room_admin")
|
||||
PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/make_room_admin$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__(hs)
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastore()
|
||||
self.event_creation_handler = hs.get_event_creation_handler()
|
||||
@@ -688,19 +665,17 @@ class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet):
|
||||
GET /_synapse/admin/v1/rooms/<room_id_or_alias>/forward_extremities
|
||||
"""
|
||||
|
||||
PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/forward_extremities")
|
||||
PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/forward_extremities$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__(hs)
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
async def on_DELETE(
|
||||
self, request: SynapseRequest, room_identifier: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
await assert_user_is_admin(self.auth, requester.user)
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
|
||||
room_id, _ = await self.resolve_room_id(room_identifier)
|
||||
|
||||
@@ -710,8 +685,7 @@ class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet):
|
||||
async def on_GET(
|
||||
self, request: SynapseRequest, room_identifier: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
await assert_user_is_admin(self.auth, requester.user)
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
|
||||
room_id, _ = await self.resolve_room_id(room_identifier)
|
||||
|
||||
@@ -793,7 +767,7 @@ class BlockRoomRestServlet(RestServlet):
|
||||
On GET: Get blocking status of room and user who has blocked this room.
|
||||
"""
|
||||
|
||||
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/block$")
|
||||
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)/block$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self._auth = hs.get_auth()
|
||||
|
||||
@@ -52,11 +52,11 @@ class SendServerNoticeServlet(RestServlet):
|
||||
"""
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.server_notices_manager = hs.get_server_notices_manager()
|
||||
self.admin_handler = hs.get_admin_handler()
|
||||
self.txns = HttpTransactionCache(hs)
|
||||
self.is_mine = hs.is_mine
|
||||
|
||||
def register(self, json_resource: HttpServer) -> None:
|
||||
PATTERN = "/send_server_notice"
|
||||
@@ -88,7 +88,7 @@ class SendServerNoticeServlet(RestServlet):
|
||||
)
|
||||
|
||||
target_user = UserID.from_string(body["user_id"])
|
||||
if not self.hs.is_mine(target_user):
|
||||
if not self.is_mine(target_user):
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_REQUEST, "Server notices can only be sent to local users"
|
||||
)
|
||||
|
||||
@@ -37,7 +37,6 @@ class UserMediaStatisticsRestServlet(RestServlet):
|
||||
PATTERNS = admin_patterns("/statistics/users/media$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
@@ -45,19 +44,16 @@ class UserMediaStatisticsRestServlet(RestServlet):
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
|
||||
order_by = parse_string(
|
||||
request, "order_by", default=UserSortOrder.USER_ID.value
|
||||
request,
|
||||
"order_by",
|
||||
default=UserSortOrder.USER_ID.value,
|
||||
allowed_values=(
|
||||
UserSortOrder.MEDIA_LENGTH.value,
|
||||
UserSortOrder.MEDIA_COUNT.value,
|
||||
UserSortOrder.USER_ID.value,
|
||||
UserSortOrder.DISPLAYNAME.value,
|
||||
),
|
||||
)
|
||||
if order_by not in (
|
||||
UserSortOrder.MEDIA_LENGTH.value,
|
||||
UserSortOrder.MEDIA_COUNT.value,
|
||||
UserSortOrder.USER_ID.value,
|
||||
UserSortOrder.DISPLAYNAME.value,
|
||||
):
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
"Unknown value for order_by: %s" % (order_by,),
|
||||
errcode=Codes.INVALID_PARAM,
|
||||
)
|
||||
|
||||
start = parse_integer(request, "from", default=0)
|
||||
if start < 0:
|
||||
|
||||
@@ -37,7 +37,7 @@ class UsernameAvailableRestServlet(RestServlet):
|
||||
}
|
||||
"""
|
||||
|
||||
PATTERNS = admin_patterns("/username_available")
|
||||
PATTERNS = admin_patterns("/username_available$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
@@ -66,7 +66,6 @@ class UsersRestServletV2(RestServlet):
|
||||
"""
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.store = hs.get_datastore()
|
||||
self.auth = hs.get_auth()
|
||||
self.admin_handler = hs.get_admin_handler()
|
||||
@@ -126,7 +125,7 @@ class UsersRestServletV2(RestServlet):
|
||||
|
||||
|
||||
class UserRestServletV2(RestServlet):
|
||||
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)$", "v2")
|
||||
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)$", "v2")
|
||||
|
||||
"""Get request to list user details.
|
||||
This needs user to have administrator access in Synapse.
|
||||
@@ -414,7 +413,7 @@ class UserRegisterServlet(RestServlet):
|
||||
nonce to the time it was generated, in int seconds.
|
||||
"""
|
||||
|
||||
PATTERNS = admin_patterns("/register")
|
||||
PATTERNS = admin_patterns("/register$")
|
||||
NONCE_TIMEOUT = 60
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
@@ -561,9 +560,9 @@ class WhoisRestServlet(RestServlet):
|
||||
]
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.admin_handler = hs.get_admin_handler()
|
||||
self.is_mine = hs.is_mine
|
||||
|
||||
async def on_GET(
|
||||
self, request: SynapseRequest, user_id: str
|
||||
@@ -575,7 +574,7 @@ class WhoisRestServlet(RestServlet):
|
||||
if target_user != auth_user:
|
||||
await assert_user_is_admin(self.auth, auth_user)
|
||||
|
||||
if not self.hs.is_mine(target_user):
|
||||
if not self.is_mine(target_user):
|
||||
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only whois a local user")
|
||||
|
||||
ret = await self.admin_handler.get_whois(target_user)
|
||||
@@ -584,7 +583,7 @@ class WhoisRestServlet(RestServlet):
|
||||
|
||||
|
||||
class DeactivateAccountRestServlet(RestServlet):
|
||||
PATTERNS = admin_patterns("/deactivate/(?P<target_user_id>[^/]*)")
|
||||
PATTERNS = admin_patterns("/deactivate/(?P<target_user_id>[^/]*)$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self._deactivate_account_handler = hs.get_deactivate_account_handler()
|
||||
@@ -630,7 +629,6 @@ class AccountValidityRenewServlet(RestServlet):
|
||||
PATTERNS = admin_patterns("/account_validity/validity$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.account_activity_handler = hs.get_account_validity_handler()
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
@@ -674,11 +672,10 @@ class ResetPasswordRestServlet(RestServlet):
|
||||
200 OK with empty object if success otherwise an error.
|
||||
"""
|
||||
|
||||
PATTERNS = admin_patterns("/reset_password/(?P<target_user_id>[^/]*)")
|
||||
PATTERNS = admin_patterns("/reset_password/(?P<target_user_id>[^/]*)$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.store = hs.get_datastore()
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.auth_handler = hs.get_auth_handler()
|
||||
self._set_password_handler = hs.get_set_password_handler()
|
||||
@@ -718,12 +715,12 @@ class SearchUsersRestServlet(RestServlet):
|
||||
200 OK with json object {list[dict[str, Any]], count} or empty object.
|
||||
"""
|
||||
|
||||
PATTERNS = admin_patterns("/search_users/(?P<target_user_id>[^/]*)")
|
||||
PATTERNS = admin_patterns("/search_users/(?P<target_user_id>[^/]*)$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.store = hs.get_datastore()
|
||||
self.auth = hs.get_auth()
|
||||
self.is_mine = hs.is_mine
|
||||
|
||||
async def on_GET(
|
||||
self, request: SynapseRequest, target_user_id: str
|
||||
@@ -740,7 +737,7 @@ class SearchUsersRestServlet(RestServlet):
|
||||
# if not is_admin and target_user != auth_user:
|
||||
# raise AuthError(HTTPStatus.FORBIDDEN, "You are not a server admin")
|
||||
|
||||
if not self.hs.is_mine(target_user):
|
||||
if not self.is_mine(target_user):
|
||||
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only users a local user")
|
||||
|
||||
term = parse_string(request, "term", required=True)
|
||||
@@ -779,9 +776,9 @@ class UserAdminServlet(RestServlet):
|
||||
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/admin$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.store = hs.get_datastore()
|
||||
self.auth = hs.get_auth()
|
||||
self.is_mine = hs.is_mine
|
||||
|
||||
async def on_GET(
|
||||
self, request: SynapseRequest, user_id: str
|
||||
@@ -790,7 +787,7 @@ class UserAdminServlet(RestServlet):
|
||||
|
||||
target_user = UserID.from_string(user_id)
|
||||
|
||||
if not self.hs.is_mine(target_user):
|
||||
if not self.is_mine(target_user):
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
"Only local users can be admins of this homeserver",
|
||||
@@ -813,7 +810,7 @@ class UserAdminServlet(RestServlet):
|
||||
|
||||
assert_params_in_dict(body, ["admin"])
|
||||
|
||||
if not self.hs.is_mine(target_user):
|
||||
if not self.is_mine(target_user):
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
"Only local users can be admins of this homeserver",
|
||||
@@ -834,7 +831,7 @@ class UserMembershipRestServlet(RestServlet):
|
||||
Get room list of an user.
|
||||
"""
|
||||
|
||||
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/joined_rooms$")
|
||||
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/joined_rooms$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.is_mine = hs.is_mine
|
||||
@@ -909,10 +906,10 @@ class UserTokenRestServlet(RestServlet):
|
||||
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/login$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.store = hs.get_datastore()
|
||||
self.auth = hs.get_auth()
|
||||
self.auth_handler = hs.get_auth_handler()
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
|
||||
async def on_POST(
|
||||
self, request: SynapseRequest, user_id: str
|
||||
@@ -921,7 +918,7 @@ class UserTokenRestServlet(RestServlet):
|
||||
await assert_user_is_admin(self.auth, requester.user)
|
||||
auth_user = requester.user
|
||||
|
||||
if not self.hs.is_mine_id(user_id):
|
||||
if not self.is_mine_id(user_id):
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_REQUEST, "Only local users can be logged in as"
|
||||
)
|
||||
@@ -975,19 +972,19 @@ class ShadowBanRestServlet(RestServlet):
|
||||
{}
|
||||
"""
|
||||
|
||||
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/shadow_ban")
|
||||
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/shadow_ban$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.store = hs.get_datastore()
|
||||
self.auth = hs.get_auth()
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
|
||||
async def on_POST(
|
||||
self, request: SynapseRequest, user_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
|
||||
if not self.hs.is_mine_id(user_id):
|
||||
if not self.is_mine_id(user_id):
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_REQUEST, "Only local users can be shadow-banned"
|
||||
)
|
||||
@@ -1001,7 +998,7 @@ class ShadowBanRestServlet(RestServlet):
|
||||
) -> Tuple[int, JsonDict]:
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
|
||||
if not self.hs.is_mine_id(user_id):
|
||||
if not self.is_mine_id(user_id):
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_REQUEST, "Only local users can be shadow-banned"
|
||||
)
|
||||
@@ -1027,19 +1024,19 @@ class RateLimitRestServlet(RestServlet):
|
||||
}
|
||||
"""
|
||||
|
||||
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/override_ratelimit")
|
||||
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/override_ratelimit$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.store = hs.get_datastore()
|
||||
self.auth = hs.get_auth()
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
|
||||
async def on_GET(
|
||||
self, request: SynapseRequest, user_id: str
|
||||
) -> Tuple[int, JsonDict]:
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
|
||||
if not self.hs.is_mine_id(user_id):
|
||||
if not self.is_mine_id(user_id):
|
||||
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only look up local users")
|
||||
|
||||
if not await self.store.get_user_by_id(user_id):
|
||||
@@ -1068,7 +1065,7 @@ class RateLimitRestServlet(RestServlet):
|
||||
) -> Tuple[int, JsonDict]:
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
|
||||
if not self.hs.is_mine_id(user_id):
|
||||
if not self.is_mine_id(user_id):
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_REQUEST, "Only local users can be ratelimited"
|
||||
)
|
||||
@@ -1113,7 +1110,7 @@ class RateLimitRestServlet(RestServlet):
|
||||
) -> Tuple[int, JsonDict]:
|
||||
await assert_requester_is_admin(self.auth, request)
|
||||
|
||||
if not self.hs.is_mine_id(user_id):
|
||||
if not self.is_mine_id(user_id):
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_REQUEST, "Only local users can be ratelimited"
|
||||
)
|
||||
|
||||
@@ -17,6 +17,7 @@ import logging
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from synapse.api import errors
|
||||
from synapse.api.errors import NotFoundError
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import (
|
||||
RestServlet,
|
||||
@@ -24,10 +25,9 @@ from synapse.http.servlet import (
|
||||
parse_json_object_from_request,
|
||||
)
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.rest.client._base import client_patterns, interactive_auth_handler
|
||||
from synapse.types import JsonDict
|
||||
|
||||
from ._base import client_patterns, interactive_auth_handler
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
@@ -116,6 +116,8 @@ class DeviceRestServlet(RestServlet):
|
||||
device = await self.device_handler.get_device(
|
||||
requester.user.to_string(), device_id
|
||||
)
|
||||
if device is None:
|
||||
raise NotFoundError("No device found")
|
||||
return 200, device
|
||||
|
||||
@interactive_auth_handler
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from synapse.api.constants import ReceiptTypes
|
||||
from synapse.events.utils import format_event_for_client_v2_without_room_id
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import RestServlet, parse_integer, parse_string
|
||||
@@ -54,7 +55,7 @@ class NotificationsServlet(RestServlet):
|
||||
)
|
||||
|
||||
receipts_by_room = await self.store.get_receipts_for_user_with_orderings(
|
||||
user_id, "m.read"
|
||||
user_id, ReceiptTypes.READ
|
||||
)
|
||||
|
||||
notif_event_ids = [pa["event_id"] for pa in push_actions]
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from synapse.api.constants import ReadReceiptEventFields
|
||||
from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
||||
@@ -48,7 +48,7 @@ class ReadMarkerRestServlet(RestServlet):
|
||||
await self.presence_handler.bump_presence_active_time(requester.user)
|
||||
|
||||
body = parse_json_object_from_request(request)
|
||||
read_event_id = body.get("m.read", None)
|
||||
read_event_id = body.get(ReceiptTypes.READ, None)
|
||||
hidden = body.get(ReadReceiptEventFields.MSC2285_HIDDEN, False)
|
||||
|
||||
if not isinstance(hidden, bool):
|
||||
@@ -62,7 +62,7 @@ class ReadMarkerRestServlet(RestServlet):
|
||||
if read_event_id:
|
||||
await self.receipts_handler.received_client_receipt(
|
||||
room_id,
|
||||
"m.read",
|
||||
ReceiptTypes.READ,
|
||||
user_id=requester.user.to_string(),
|
||||
event_id=read_event_id,
|
||||
hidden=hidden,
|
||||
|
||||
@@ -16,7 +16,7 @@ import logging
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from synapse.api.constants import ReadReceiptEventFields
|
||||
from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.http import get_request_user_agent
|
||||
from synapse.http.server import HttpServer
|
||||
@@ -53,7 +53,7 @@ class ReceiptRestServlet(RestServlet):
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
|
||||
if receipt_type != "m.read":
|
||||
if receipt_type != ReceiptTypes.READ:
|
||||
raise SynapseError(400, "Receipt type must be 'm.read'")
|
||||
|
||||
# Do not allow older SchildiChat and Element Android clients (prior to Element/1.[012].x) to send an empty body.
|
||||
|
||||
@@ -212,6 +212,7 @@ class RelationPaginationServlet(RestServlet):
|
||||
|
||||
pagination_chunk = await self.store.get_relations_for_event(
|
||||
event_id=parent_id,
|
||||
room_id=room_id,
|
||||
relation_type=relation_type,
|
||||
event_type=event_type,
|
||||
limit=limit,
|
||||
@@ -317,6 +318,7 @@ class RelationAggregationPaginationServlet(RestServlet):
|
||||
|
||||
pagination_chunk = await self.store.get_aggregation_groups_for_event(
|
||||
event_id=parent_id,
|
||||
room_id=room_id,
|
||||
event_type=event_type,
|
||||
limit=limit,
|
||||
from_token=from_token,
|
||||
@@ -383,7 +385,9 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
|
||||
|
||||
# This checks that a) the event exists and b) the user is allowed to
|
||||
# view it.
|
||||
await self.event_handler.get_event(requester.user, room_id, parent_id)
|
||||
event = await self.event_handler.get_event(requester.user, room_id, parent_id)
|
||||
if event is None:
|
||||
raise SynapseError(404, "Unknown parent event.")
|
||||
|
||||
if relation_type != RelationTypes.ANNOTATION:
|
||||
raise SynapseError(400, "Relation type must be 'annotation'")
|
||||
@@ -402,6 +406,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
|
||||
|
||||
result = await self.store.get_relations_for_event(
|
||||
event_id=parent_id,
|
||||
room_id=room_id,
|
||||
relation_type=relation_type,
|
||||
event_type=event_type,
|
||||
aggregation_key=key,
|
||||
|
||||
@@ -187,7 +187,7 @@ class RoomStateEventRestServlet(TransactionRestServlet):
|
||||
state_key: str,
|
||||
txn_id: Optional[str] = None,
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
requester = await self.auth.get_user_by_req(request, allow_guest=True)
|
||||
|
||||
if txn_id:
|
||||
set_tag("txn_id", txn_id)
|
||||
|
||||
@@ -293,6 +293,9 @@ class SyncRestServlet(RestServlet):
|
||||
response[
|
||||
"org.matrix.msc2732.device_unused_fallback_key_types"
|
||||
] = sync_result.device_unused_fallback_key_types
|
||||
response[
|
||||
"device_unused_fallback_key_types"
|
||||
] = sync_result.device_unused_fallback_key_types
|
||||
|
||||
if joined:
|
||||
response["rooms"][Membership.JOIN] = joined
|
||||
|
||||
@@ -896,6 +896,9 @@ class DatabasePool:
|
||||
) -> None:
|
||||
"""Executes an INSERT query on the named table.
|
||||
|
||||
The input is given as a list of dicts, with one dict per row.
|
||||
Generally simple_insert_many_values should be preferred for new code.
|
||||
|
||||
Args:
|
||||
table: string giving the table name
|
||||
values: dict of new column names and values for them
|
||||
@@ -909,6 +912,9 @@ class DatabasePool:
|
||||
) -> None:
|
||||
"""Executes an INSERT query on the named table.
|
||||
|
||||
The input is given as a list of dicts, with one dict per row.
|
||||
Generally simple_insert_many_values_txn should be preferred for new code.
|
||||
|
||||
Args:
|
||||
txn: The transaction to use.
|
||||
table: string giving the table name
|
||||
@@ -933,23 +939,66 @@ class DatabasePool:
|
||||
if k != keys[0]:
|
||||
raise RuntimeError("All items must have the same keys")
|
||||
|
||||
return DatabasePool.simple_insert_many_values_txn(txn, table, keys[0], vals)
|
||||
|
||||
async def simple_insert_many_values(
|
||||
self,
|
||||
table: str,
|
||||
keys: Collection[str],
|
||||
values: Iterable[Iterable[Any]],
|
||||
desc: str,
|
||||
) -> None:
|
||||
"""Executes an INSERT query on the named table.
|
||||
|
||||
The input is given as a list of rows, where each row is a list of values.
|
||||
(Actually any iterable is fine.)
|
||||
|
||||
Args:
|
||||
table: string giving the table name
|
||||
keys: list of column names
|
||||
values: for each row, a list of values in the same order as `keys`
|
||||
desc: description of the transaction, for logging and metrics
|
||||
"""
|
||||
await self.runInteraction(
|
||||
desc, self.simple_insert_many_values_txn, table, keys, values
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def simple_insert_many_values_txn(
|
||||
txn: LoggingTransaction,
|
||||
table: str,
|
||||
keys: Collection[str],
|
||||
values: Iterable[Iterable[Any]],
|
||||
) -> None:
|
||||
"""Executes an INSERT query on the named table.
|
||||
|
||||
The input is given as a list of rows, where each row is a list of values.
|
||||
(Actually any iterable is fine.)
|
||||
|
||||
Args:
|
||||
txn: The transaction to use.
|
||||
table: string giving the table name
|
||||
keys: list of column names
|
||||
values: for each row, a list of values in the same order as `keys`
|
||||
"""
|
||||
|
||||
if isinstance(txn.database_engine, PostgresEngine):
|
||||
# We use `execute_values` as it can be a lot faster than `execute_batch`,
|
||||
# but it's only available on postgres.
|
||||
sql = "INSERT INTO %s (%s) VALUES ?" % (
|
||||
table,
|
||||
", ".join(k for k in keys[0]),
|
||||
", ".join(k for k in keys),
|
||||
)
|
||||
|
||||
txn.execute_values(sql, vals, fetch=False)
|
||||
txn.execute_values(sql, values, fetch=False)
|
||||
else:
|
||||
sql = "INSERT INTO %s (%s) VALUES(%s)" % (
|
||||
table,
|
||||
", ".join(k for k in keys[0]),
|
||||
", ".join("?" for _ in keys[0]),
|
||||
", ".join(k for k in keys),
|
||||
", ".join("?" for _ in keys),
|
||||
)
|
||||
|
||||
txn.execute_batch(sql, vals)
|
||||
txn.execute_batch(sql, values)
|
||||
|
||||
async def simple_upsert(
|
||||
self,
|
||||
|
||||
@@ -101,7 +101,9 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||
"count_devices_by_users", count_devices_by_users_txn, user_ids
|
||||
)
|
||||
|
||||
async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]:
|
||||
async def get_device(
|
||||
self, user_id: str, device_id: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Retrieve a device. Only returns devices that are not marked as
|
||||
hidden.
|
||||
|
||||
@@ -109,17 +111,15 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||
user_id: The ID of the user which owns the device
|
||||
device_id: The ID of the device to retrieve
|
||||
Returns:
|
||||
A dict containing the device information
|
||||
Raises:
|
||||
StoreError: if the device is not found
|
||||
See also:
|
||||
`get_device_opt` which returns None instead if the device is not found
|
||||
A dict containing the device information, or `None` if the device does not
|
||||
exist.
|
||||
"""
|
||||
return await self.db_pool.simple_select_one(
|
||||
table="devices",
|
||||
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
|
||||
retcols=("user_id", "device_id", "display_name"),
|
||||
desc="get_device",
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
async def get_device_opt(
|
||||
|
||||
@@ -19,6 +19,7 @@ from collections import OrderedDict
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Collection,
|
||||
Dict,
|
||||
Generator,
|
||||
Iterable,
|
||||
@@ -1319,14 +1320,13 @@ class PersistEventsStore:
|
||||
|
||||
return [ec for ec in events_and_contexts if ec[0] not in to_remove]
|
||||
|
||||
def _store_event_txn(self, txn, events_and_contexts):
|
||||
def _store_event_txn(
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
events_and_contexts: Collection[Tuple[EventBase, EventContext]],
|
||||
) -> None:
|
||||
"""Insert new events into the event, event_json, redaction and
|
||||
state_events tables.
|
||||
|
||||
Args:
|
||||
txn (twisted.enterprise.adbapi.Connection): db connection
|
||||
events_and_contexts (list[(EventBase, EventContext)]): events
|
||||
we are persisting
|
||||
"""
|
||||
|
||||
if not events_and_contexts:
|
||||
@@ -1339,46 +1339,58 @@ class PersistEventsStore:
|
||||
d.pop("redacted_because", None)
|
||||
return d
|
||||
|
||||
self.db_pool.simple_insert_many_txn(
|
||||
self.db_pool.simple_insert_many_values_txn(
|
||||
txn,
|
||||
table="event_json",
|
||||
values=[
|
||||
{
|
||||
"event_id": event.event_id,
|
||||
"room_id": event.room_id,
|
||||
"internal_metadata": json_encoder.encode(
|
||||
event.internal_metadata.get_dict()
|
||||
),
|
||||
"json": json_encoder.encode(event_dict(event)),
|
||||
"format_version": event.format_version,
|
||||
}
|
||||
keys=("event_id", "room_id", "internal_metadata", "json", "format_version"),
|
||||
values=(
|
||||
(
|
||||
event.event_id,
|
||||
event.room_id,
|
||||
json_encoder.encode(event.internal_metadata.get_dict()),
|
||||
json_encoder.encode(event_dict(event)),
|
||||
event.format_version,
|
||||
)
|
||||
for event, _ in events_and_contexts
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
self.db_pool.simple_insert_many_txn(
|
||||
self.db_pool.simple_insert_many_values_txn(
|
||||
txn,
|
||||
table="events",
|
||||
values=[
|
||||
{
|
||||
"instance_name": self._instance_name,
|
||||
"stream_ordering": event.internal_metadata.stream_ordering,
|
||||
"topological_ordering": event.depth,
|
||||
"depth": event.depth,
|
||||
"event_id": event.event_id,
|
||||
"room_id": event.room_id,
|
||||
"type": event.type,
|
||||
"processed": True,
|
||||
"outlier": event.internal_metadata.is_outlier(),
|
||||
"origin_server_ts": int(event.origin_server_ts),
|
||||
"received_ts": self._clock.time_msec(),
|
||||
"sender": event.sender,
|
||||
"contains_url": (
|
||||
"url" in event.content and isinstance(event.content["url"], str)
|
||||
),
|
||||
}
|
||||
keys=(
|
||||
"instance_name",
|
||||
"stream_ordering",
|
||||
"topological_ordering",
|
||||
"depth",
|
||||
"event_id",
|
||||
"room_id",
|
||||
"type",
|
||||
"processed",
|
||||
"outlier",
|
||||
"origin_server_ts",
|
||||
"received_ts",
|
||||
"sender",
|
||||
"contains_url",
|
||||
),
|
||||
values=(
|
||||
(
|
||||
self._instance_name,
|
||||
event.internal_metadata.stream_ordering,
|
||||
event.depth, # topological_ordering
|
||||
event.depth, # depth
|
||||
event.event_id,
|
||||
event.room_id,
|
||||
event.type,
|
||||
True, # processed
|
||||
event.internal_metadata.is_outlier(),
|
||||
int(event.origin_server_ts),
|
||||
self._clock.time_msec(),
|
||||
event.sender,
|
||||
"url" in event.content and isinstance(event.content["url"], str),
|
||||
)
|
||||
for event, _ in events_and_contexts
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
# If we're persisting an unredacted event we go and ensure
|
||||
@@ -1397,27 +1409,15 @@ class PersistEventsStore:
|
||||
)
|
||||
txn.execute(sql + clause, [False] + args)
|
||||
|
||||
state_events_and_contexts = [
|
||||
ec for ec in events_and_contexts if ec[0].is_state()
|
||||
]
|
||||
|
||||
state_values = []
|
||||
for event, _ in state_events_and_contexts:
|
||||
vals = {
|
||||
"event_id": event.event_id,
|
||||
"room_id": event.room_id,
|
||||
"type": event.type,
|
||||
"state_key": event.state_key,
|
||||
}
|
||||
|
||||
# TODO: How does this work with backfilling?
|
||||
if hasattr(event, "replaces_state"):
|
||||
vals["prev_state"] = event.replaces_state
|
||||
|
||||
state_values.append(vals)
|
||||
|
||||
self.db_pool.simple_insert_many_txn(
|
||||
txn, table="state_events", values=state_values
|
||||
self.db_pool.simple_insert_many_values_txn(
|
||||
txn,
|
||||
table="state_events",
|
||||
keys=("event_id", "room_id", "type", "state_key"),
|
||||
values=(
|
||||
(event.event_id, event.room_id, event.type, event.state_key)
|
||||
for event, _ in events_and_contexts
|
||||
if event.is_state()
|
||||
),
|
||||
)
|
||||
|
||||
def _store_rejected_events_txn(self, txn, events_and_contexts):
|
||||
@@ -1780,10 +1780,14 @@ class PersistEventsStore:
|
||||
)
|
||||
|
||||
if rel_type == RelationTypes.REPLACE:
|
||||
txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,))
|
||||
txn.call_after(
|
||||
self.store.get_applicable_edit.invalidate, (parent_id, event.room_id)
|
||||
)
|
||||
|
||||
if rel_type == RelationTypes.THREAD:
|
||||
txn.call_after(self.store.get_thread_summary.invalidate, (parent_id,))
|
||||
txn.call_after(
|
||||
self.store.get_thread_summary.invalidate, (parent_id, event.room_id)
|
||||
)
|
||||
|
||||
def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase):
|
||||
"""Handles keeping track of insertion events and edges/connections.
|
||||
|
||||
@@ -18,6 +18,7 @@ from synapse.metrics.background_process_metrics import wrap_as_background_proces
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
from synapse.storage.database import DatabasePool, make_in_list_sql_clause
|
||||
from synapse.util.caches.descriptors import cached
|
||||
from synapse.util.threepids import canonicalise_email
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
@@ -103,7 +104,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
|
||||
: self.hs.config.server.max_mau_value
|
||||
]:
|
||||
user_id = await self.hs.get_datastore().get_user_id_by_threepid(
|
||||
tp["medium"], tp["address"]
|
||||
tp["medium"], canonicalise_email(tp["address"])
|
||||
)
|
||||
if user_id:
|
||||
users.append(user_id)
|
||||
|
||||
@@ -14,14 +14,25 @@
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Collection,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
)
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import ReceiptTypes
|
||||
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
|
||||
from synapse.replication.tcp.streams import ReceiptsStream
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
|
||||
from synapse.storage.database import DatabasePool
|
||||
from synapse.storage.database import DatabasePool, LoggingTransaction
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
|
||||
from synapse.types import JsonDict
|
||||
@@ -78,17 +89,13 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||
"ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
|
||||
)
|
||||
|
||||
def get_max_receipt_stream_id(self):
|
||||
"""Get the current max stream ID for receipts stream
|
||||
|
||||
Returns:
|
||||
int
|
||||
"""
|
||||
def get_max_receipt_stream_id(self) -> int:
|
||||
"""Get the current max stream ID for receipts stream"""
|
||||
return self._receipts_id_gen.get_current_token()
|
||||
|
||||
@cached()
|
||||
async def get_users_with_read_receipts_in_room(self, room_id):
|
||||
receipts = await self.get_receipts_for_room(room_id, "m.read")
|
||||
async def get_users_with_read_receipts_in_room(self, room_id: str) -> Set[str]:
|
||||
receipts = await self.get_receipts_for_room(room_id, ReceiptTypes.READ)
|
||||
return {r["user_id"] for r in receipts}
|
||||
|
||||
@cached(num_args=2)
|
||||
@@ -119,7 +126,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||
)
|
||||
|
||||
@cached(num_args=2)
|
||||
async def get_receipts_for_user(self, user_id, receipt_type):
|
||||
async def get_receipts_for_user(
|
||||
self, user_id: str, receipt_type: str
|
||||
) -> Dict[str, str]:
|
||||
rows = await self.db_pool.simple_select_list(
|
||||
table="receipts_linearized",
|
||||
keyvalues={"user_id": user_id, "receipt_type": receipt_type},
|
||||
@@ -129,8 +138,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||
|
||||
return {row["room_id"]: row["event_id"] for row in rows}
|
||||
|
||||
async def get_receipts_for_user_with_orderings(self, user_id, receipt_type):
|
||||
def f(txn):
|
||||
async def get_receipts_for_user_with_orderings(
|
||||
self, user_id: str, receipt_type: str
|
||||
) -> JsonDict:
|
||||
def f(txn: LoggingTransaction) -> List[Tuple[str, str, int, int]]:
|
||||
sql = (
|
||||
"SELECT rl.room_id, rl.event_id,"
|
||||
" e.topological_ordering, e.stream_ordering"
|
||||
@@ -209,10 +220,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||
@cached(num_args=3, tree=True)
|
||||
async def _get_linearized_receipts_for_room(
|
||||
self, room_id: str, to_key: int, from_key: Optional[int] = None
|
||||
) -> List[dict]:
|
||||
) -> List[JsonDict]:
|
||||
"""See get_linearized_receipts_for_room"""
|
||||
|
||||
def f(txn):
|
||||
def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
|
||||
if from_key:
|
||||
sql = (
|
||||
"SELECT * FROM receipts_linearized WHERE"
|
||||
@@ -250,11 +261,13 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||
list_name="room_ids",
|
||||
num_args=3,
|
||||
)
|
||||
async def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
|
||||
async def _get_linearized_receipts_for_rooms(
|
||||
self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None
|
||||
) -> Dict[str, List[JsonDict]]:
|
||||
if not room_ids:
|
||||
return {}
|
||||
|
||||
def f(txn):
|
||||
def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
|
||||
if from_key:
|
||||
sql = """
|
||||
SELECT * FROM receipts_linearized WHERE
|
||||
@@ -323,7 +336,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||
A dictionary of roomids to a list of receipts.
|
||||
"""
|
||||
|
||||
def f(txn):
|
||||
def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
|
||||
if from_key:
|
||||
sql = """
|
||||
SELECT * FROM receipts_linearized WHERE
|
||||
@@ -379,7 +392,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||
if last_id == current_id:
|
||||
return defer.succeed([])
|
||||
|
||||
def _get_users_sent_receipts_between_txn(txn):
|
||||
def _get_users_sent_receipts_between_txn(txn: LoggingTransaction) -> List[str]:
|
||||
sql = """
|
||||
SELECT DISTINCT user_id FROM receipts_linearized
|
||||
WHERE ? < stream_id AND stream_id <= ?
|
||||
@@ -419,7 +432,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||
if last_id == current_id:
|
||||
return [], current_id, False
|
||||
|
||||
def get_all_updated_receipts_txn(txn):
|
||||
def get_all_updated_receipts_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> Tuple[List[Tuple[int, list]], int, bool]:
|
||||
sql = """
|
||||
SELECT stream_id, room_id, receipt_type, user_id, event_id, data
|
||||
FROM receipts_linearized
|
||||
@@ -446,8 +461,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||
|
||||
def _invalidate_get_users_with_receipts_in_room(
|
||||
self, room_id: str, receipt_type: str, user_id: str
|
||||
):
|
||||
if receipt_type != "m.read":
|
||||
) -> None:
|
||||
if receipt_type != ReceiptTypes.READ:
|
||||
return
|
||||
|
||||
res = self.get_users_with_read_receipts_in_room.cache.get_immediate(
|
||||
@@ -461,7 +476,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||
|
||||
self.get_users_with_read_receipts_in_room.invalidate((room_id,))
|
||||
|
||||
def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
|
||||
def invalidate_caches_for_receipt(
|
||||
self, room_id: str, receipt_type: str, user_id: str
|
||||
) -> None:
|
||||
self.get_receipts_for_user.invalidate((user_id, receipt_type))
|
||||
self._get_linearized_receipts_for_room.invalidate((room_id,))
|
||||
self.get_last_receipt_event_id_for_user.invalidate(
|
||||
@@ -482,11 +499,18 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||
return super().process_replication_rows(stream_name, instance_name, token, rows)
|
||||
|
||||
def insert_linearized_receipt_txn(
|
||||
self, txn, room_id, receipt_type, user_id, event_id, data, stream_id
|
||||
):
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
room_id: str,
|
||||
receipt_type: str,
|
||||
user_id: str,
|
||||
event_id: str,
|
||||
data: JsonDict,
|
||||
stream_id: int,
|
||||
) -> Optional[int]:
|
||||
"""Inserts a read-receipt into the database if it's newer than the current RR
|
||||
|
||||
Returns: int|None
|
||||
Returns:
|
||||
None if the RR is older than the current RR
|
||||
otherwise, the rx timestamp of the event that the RR corresponds to
|
||||
(or 0 if the event is unknown)
|
||||
@@ -550,7 +574,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||
lock=False,
|
||||
)
|
||||
|
||||
if receipt_type == "m.read" and stream_ordering is not None:
|
||||
if receipt_type == ReceiptTypes.READ and stream_ordering is not None:
|
||||
self._remove_old_push_actions_before_txn(
|
||||
txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering
|
||||
)
|
||||
@@ -580,7 +604,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||
else:
|
||||
# we need to points in graph -> linearized form.
|
||||
# TODO: Make this better.
|
||||
def graph_to_linear(txn):
|
||||
def graph_to_linear(txn: LoggingTransaction) -> str:
|
||||
clause, args = make_in_list_sql_clause(
|
||||
self.database_engine, "event_id", event_ids
|
||||
)
|
||||
@@ -634,11 +658,16 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||
return stream_id, max_persisted_id
|
||||
|
||||
async def insert_graph_receipt(
|
||||
self, room_id, receipt_type, user_id, event_ids, data
|
||||
):
|
||||
self,
|
||||
room_id: str,
|
||||
receipt_type: str,
|
||||
user_id: str,
|
||||
event_ids: List[str],
|
||||
data: JsonDict,
|
||||
) -> None:
|
||||
assert self._can_write_to_receipts
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"insert_graph_receipt",
|
||||
self.insert_graph_receipt_txn,
|
||||
room_id,
|
||||
@@ -649,8 +678,14 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||
)
|
||||
|
||||
def insert_graph_receipt_txn(
|
||||
self, txn, room_id, receipt_type, user_id, event_ids, data
|
||||
):
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
room_id: str,
|
||||
receipt_type: str,
|
||||
user_id: str,
|
||||
event_ids: List[str],
|
||||
data: JsonDict,
|
||||
) -> None:
|
||||
assert self._can_write_to_receipts
|
||||
|
||||
txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type))
|
||||
|
||||
@@ -856,7 +856,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||
|
||||
Args:
|
||||
medium: threepid medium e.g. email
|
||||
address: threepid address e.g. me@example.com
|
||||
address: threepid address e.g. me@example.com. This must already be
|
||||
in canonical form.
|
||||
|
||||
Returns:
|
||||
The user ID or None if no user id/threepid mapping exists
|
||||
|
||||
@@ -37,6 +37,7 @@ class RelationsWorkerStore(SQLBaseStore):
|
||||
async def get_relations_for_event(
|
||||
self,
|
||||
event_id: str,
|
||||
room_id: str,
|
||||
relation_type: Optional[str] = None,
|
||||
event_type: Optional[str] = None,
|
||||
aggregation_key: Optional[str] = None,
|
||||
@@ -49,6 +50,7 @@ class RelationsWorkerStore(SQLBaseStore):
|
||||
|
||||
Args:
|
||||
event_id: Fetch events that relate to this event ID.
|
||||
room_id: The room the event belongs to.
|
||||
relation_type: Only fetch events with this relation type, if given.
|
||||
event_type: Only fetch events with this event type, if given.
|
||||
aggregation_key: Only fetch events with this aggregation key, if given.
|
||||
@@ -63,8 +65,8 @@ class RelationsWorkerStore(SQLBaseStore):
|
||||
the form `{"event_id": "..."}`.
|
||||
"""
|
||||
|
||||
where_clause = ["relates_to_id = ?"]
|
||||
where_args: List[Union[str, int]] = [event_id]
|
||||
where_clause = ["relates_to_id = ?", "room_id = ?"]
|
||||
where_args: List[Union[str, int]] = [event_id, room_id]
|
||||
|
||||
if relation_type is not None:
|
||||
where_clause.append("relation_type = ?")
|
||||
@@ -199,6 +201,7 @@ class RelationsWorkerStore(SQLBaseStore):
|
||||
async def get_aggregation_groups_for_event(
|
||||
self,
|
||||
event_id: str,
|
||||
room_id: str,
|
||||
event_type: Optional[str] = None,
|
||||
limit: int = 5,
|
||||
direction: str = "b",
|
||||
@@ -213,6 +216,7 @@ class RelationsWorkerStore(SQLBaseStore):
|
||||
|
||||
Args:
|
||||
event_id: Fetch events that relate to this event ID.
|
||||
room_id: The room the event belongs to.
|
||||
event_type: Only fetch events with this event type, if given.
|
||||
limit: Only fetch the `limit` groups.
|
||||
direction: Whether to fetch the highest count first (`"b"`) or
|
||||
@@ -225,8 +229,12 @@ class RelationsWorkerStore(SQLBaseStore):
|
||||
`type`, `key` and `count` fields.
|
||||
"""
|
||||
|
||||
where_clause = ["relates_to_id = ?", "relation_type = ?"]
|
||||
where_args: List[Union[str, int]] = [event_id, RelationTypes.ANNOTATION]
|
||||
where_clause = ["relates_to_id = ?", "room_id = ?", "relation_type = ?"]
|
||||
where_args: List[Union[str, int]] = [
|
||||
event_id,
|
||||
room_id,
|
||||
RelationTypes.ANNOTATION,
|
||||
]
|
||||
|
||||
if event_type:
|
||||
where_clause.append("type = ?")
|
||||
@@ -288,7 +296,9 @@ class RelationsWorkerStore(SQLBaseStore):
|
||||
)
|
||||
|
||||
@cached()
|
||||
async def get_applicable_edit(self, event_id: str) -> Optional[EventBase]:
|
||||
async def get_applicable_edit(
|
||||
self, event_id: str, room_id: str
|
||||
) -> Optional[EventBase]:
|
||||
"""Get the most recent edit (if any) that has happened for the given
|
||||
event.
|
||||
|
||||
@@ -296,6 +306,7 @@ class RelationsWorkerStore(SQLBaseStore):
|
||||
|
||||
Args:
|
||||
event_id: The original event ID
|
||||
room_id: The original event's room ID
|
||||
|
||||
Returns:
|
||||
The most recent edit, if any.
|
||||
@@ -317,13 +328,14 @@ class RelationsWorkerStore(SQLBaseStore):
|
||||
WHERE
|
||||
relates_to_id = ?
|
||||
AND relation_type = ?
|
||||
AND edit.room_id = ?
|
||||
AND edit.type = 'm.room.message'
|
||||
ORDER by edit.origin_server_ts DESC, edit.event_id DESC
|
||||
LIMIT 1
|
||||
"""
|
||||
|
||||
def _get_applicable_edit_txn(txn: LoggingTransaction) -> Optional[str]:
|
||||
txn.execute(sql, (event_id, RelationTypes.REPLACE))
|
||||
txn.execute(sql, (event_id, RelationTypes.REPLACE, room_id))
|
||||
row = txn.fetchone()
|
||||
if row:
|
||||
return row[0]
|
||||
@@ -340,13 +352,14 @@ class RelationsWorkerStore(SQLBaseStore):
|
||||
|
||||
@cached()
|
||||
async def get_thread_summary(
|
||||
self, event_id: str
|
||||
self, event_id: str, room_id: str
|
||||
) -> Tuple[int, Optional[EventBase]]:
|
||||
"""Get the number of threaded replies, the senders of those replies, and
|
||||
the latest reply (if any) for the given event.
|
||||
|
||||
Args:
|
||||
event_id: The original event ID
|
||||
event_id: Summarize the thread related to this event ID.
|
||||
room_id: The room the event belongs to.
|
||||
|
||||
Returns:
|
||||
The number of items in the thread and the most recent response, if any.
|
||||
@@ -363,12 +376,13 @@ class RelationsWorkerStore(SQLBaseStore):
|
||||
INNER JOIN events USING (event_id)
|
||||
WHERE
|
||||
relates_to_id = ?
|
||||
AND room_id = ?
|
||||
AND relation_type = ?
|
||||
ORDER BY topological_ordering DESC, stream_ordering DESC
|
||||
LIMIT 1
|
||||
"""
|
||||
|
||||
txn.execute(sql, (event_id, RelationTypes.THREAD))
|
||||
txn.execute(sql, (event_id, room_id, RelationTypes.THREAD))
|
||||
row = txn.fetchone()
|
||||
if row is None:
|
||||
return 0, None
|
||||
@@ -378,11 +392,13 @@ class RelationsWorkerStore(SQLBaseStore):
|
||||
sql = """
|
||||
SELECT COALESCE(COUNT(event_id), 0)
|
||||
FROM event_relations
|
||||
INNER JOIN events USING (event_id)
|
||||
WHERE
|
||||
relates_to_id = ?
|
||||
AND room_id = ?
|
||||
AND relation_type = ?
|
||||
"""
|
||||
txn.execute(sql, (event_id, RelationTypes.THREAD))
|
||||
txn.execute(sql, (event_id, room_id, RelationTypes.THREAD))
|
||||
count = txn.fetchone()[0] # type: ignore[index]
|
||||
|
||||
return count, latest_event_id
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
SCHEMA_VERSION = 66 # remember to update the list below when updating
|
||||
SCHEMA_VERSION = 67 # remember to update the list below when updating
|
||||
"""Represents the expectations made by the codebase about the database schema
|
||||
|
||||
This should be incremented whenever the codebase changes its requirements on the
|
||||
@@ -50,6 +50,9 @@ Changes in SCHEMA_VERSION = 65:
|
||||
Changes in SCHEMA_VERSION = 66:
|
||||
- Queries on state_key columns are now disambiguated (ie, the codebase can handle
|
||||
the `events` table having a `state_key` column).
|
||||
|
||||
Changes in SCHEMA_VERSION = 67:
|
||||
- state_events.prev_state is no longer written to.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
@@ -260,7 +260,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||
# This just needs to return a truth-y value.
|
||||
self.store.get_user_by_id = simple_async_mock({"is_guest": False})
|
||||
self.store.get_user_by_access_token = simple_async_mock(None)
|
||||
# This also needs to just return a truth-y value
|
||||
# This also needs to just return a falsey value
|
||||
self.store.get_device_opt = simple_async_mock(None)
|
||||
|
||||
request = Mock(args={})
|
||||
|
||||
@@ -161,8 +161,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||
def test_fallback_key(self):
|
||||
local_user = "@boris:" + self.hs.hostname
|
||||
device_id = "xyz"
|
||||
fallback_key = {"alg1:k1": "key1"}
|
||||
fallback_key2 = {"alg1:k2": "key2"}
|
||||
fallback_key = {"alg1:k1": "fallback_key1"}
|
||||
fallback_key2 = {"alg1:k2": "fallback_key2"}
|
||||
fallback_key3 = {"alg1:k2": "fallback_key3"}
|
||||
otk = {"alg1:k2": "key2"}
|
||||
|
||||
# we shouldn't have any unused fallback keys yet
|
||||
@@ -175,7 +176,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||
self.handler.upload_keys_for_user(
|
||||
local_user,
|
||||
device_id,
|
||||
{"org.matrix.msc2732.fallback_keys": fallback_key},
|
||||
{"fallback_keys": fallback_key},
|
||||
)
|
||||
)
|
||||
|
||||
@@ -220,7 +221,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||
self.handler.upload_keys_for_user(
|
||||
local_user,
|
||||
device_id,
|
||||
{"org.matrix.msc2732.fallback_keys": fallback_key},
|
||||
{"fallback_keys": fallback_key},
|
||||
)
|
||||
)
|
||||
|
||||
@@ -234,7 +235,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||
self.handler.upload_keys_for_user(
|
||||
local_user,
|
||||
device_id,
|
||||
{"org.matrix.msc2732.fallback_keys": fallback_key2},
|
||||
{"fallback_keys": fallback_key2},
|
||||
)
|
||||
)
|
||||
|
||||
@@ -271,6 +272,25 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
|
||||
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key2}}},
|
||||
)
|
||||
|
||||
# using the unstable prefix should also set the fallback key
|
||||
self.get_success(
|
||||
self.handler.upload_keys_for_user(
|
||||
local_user,
|
||||
device_id,
|
||||
{"org.matrix.msc2732.fallback_keys": fallback_key3},
|
||||
)
|
||||
)
|
||||
|
||||
res = self.get_success(
|
||||
self.handler.claim_one_time_keys(
|
||||
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
|
||||
)
|
||||
)
|
||||
self.assertEqual(
|
||||
res,
|
||||
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}},
|
||||
)
|
||||
|
||||
def test_replace_master_key(self):
|
||||
"""uploading a new signing key should make the old signing key unavailable"""
|
||||
local_user = "@boris:" + self.hs.hostname
|
||||
|
||||
@@ -23,6 +23,7 @@ from synapse.types import create_requester
|
||||
from synapse.util.stringutils import random_string
|
||||
|
||||
from tests import unittest
|
||||
from tests.test_utils.event_injection import create_event
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -51,6 +52,24 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
self.requester = create_requester(self.user_id, access_token_id=self.token_id)
|
||||
|
||||
def _create_and_persist_member_event(self) -> Tuple[EventBase, EventContext]:
|
||||
# Create a member event we can use as an auth_event
|
||||
memberEvent, memberEventContext = self.get_success(
|
||||
create_event(
|
||||
self.hs,
|
||||
room_id=self.room_id,
|
||||
type="m.room.member",
|
||||
sender=self.requester.user.to_string(),
|
||||
state_key=self.requester.user.to_string(),
|
||||
content={"membership": "join"},
|
||||
)
|
||||
)
|
||||
self.get_success(
|
||||
self.persist_event_storage.persist_event(memberEvent, memberEventContext)
|
||||
)
|
||||
|
||||
return memberEvent, memberEventContext
|
||||
|
||||
def _create_duplicate_event(self, txn_id: str) -> Tuple[EventBase, EventContext]:
|
||||
"""Create a new event with the given transaction ID. All events produced
|
||||
by this method will be considered duplicates.
|
||||
@@ -156,6 +175,90 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEqual(len(events), 2)
|
||||
self.assertEqual(events[0].event_id, events[1].event_id)
|
||||
|
||||
def test_when_empty_prev_events_allowed_create_event_with_empty_prev_events(self):
|
||||
"""When we set allow_no_prev_events=True, should be able to create a
|
||||
event without any prev_events (only auth_events).
|
||||
"""
|
||||
# Create a member event we can use as an auth_event
|
||||
memberEvent, _ = self._create_and_persist_member_event()
|
||||
|
||||
# Try to create the event with empty prev_events bit with some auth_events
|
||||
event, _ = self.get_success(
|
||||
self.handler.create_event(
|
||||
self.requester,
|
||||
{
|
||||
"type": EventTypes.Message,
|
||||
"room_id": self.room_id,
|
||||
"sender": self.requester.user.to_string(),
|
||||
"content": {"msgtype": "m.text", "body": random_string(5)},
|
||||
},
|
||||
# Empty prev_events is the key thing we're testing here
|
||||
prev_event_ids=[],
|
||||
# But with some auth_events
|
||||
auth_event_ids=[memberEvent.event_id],
|
||||
# Allow no prev_events!
|
||||
allow_no_prev_events=True,
|
||||
)
|
||||
)
|
||||
self.assertIsNotNone(event)
|
||||
|
||||
def test_when_empty_prev_events_not_allowed_reject_event_with_empty_prev_events(
|
||||
self,
|
||||
):
|
||||
"""When we set allow_no_prev_events=False, shouldn't be able to create a
|
||||
event without any prev_events even if it has auth_events. Expect an
|
||||
exception to be raised.
|
||||
"""
|
||||
# Create a member event we can use as an auth_event
|
||||
memberEvent, _ = self._create_and_persist_member_event()
|
||||
|
||||
# Try to create the event with empty prev_events but with some auth_events
|
||||
self.get_failure(
|
||||
self.handler.create_event(
|
||||
self.requester,
|
||||
{
|
||||
"type": EventTypes.Message,
|
||||
"room_id": self.room_id,
|
||||
"sender": self.requester.user.to_string(),
|
||||
"content": {"msgtype": "m.text", "body": random_string(5)},
|
||||
},
|
||||
# Empty prev_events is the key thing we're testing here
|
||||
prev_event_ids=[],
|
||||
# But with some auth_events
|
||||
auth_event_ids=[memberEvent.event_id],
|
||||
# We expect the test to fail because empty prev_events are not
|
||||
# allowed here!
|
||||
allow_no_prev_events=False,
|
||||
),
|
||||
AssertionError,
|
||||
)
|
||||
|
||||
def test_when_empty_prev_events_allowed_reject_event_with_empty_prev_events_and_auth_events(
|
||||
self,
|
||||
):
|
||||
"""When we set allow_no_prev_events=True, should be able to create a
|
||||
event without any prev_events or auth_events. Expect an exception to be
|
||||
raised.
|
||||
"""
|
||||
# Try to create the event with empty prev_events and empty auth_events
|
||||
self.get_failure(
|
||||
self.handler.create_event(
|
||||
self.requester,
|
||||
{
|
||||
"type": EventTypes.Message,
|
||||
"room_id": self.room_id,
|
||||
"sender": self.requester.user.to_string(),
|
||||
"content": {"msgtype": "m.text", "body": random_string(5)},
|
||||
},
|
||||
prev_event_ids=[],
|
||||
# The event should be rejected when there are no auth_events
|
||||
auth_event_ids=[],
|
||||
# Allow no prev_events!
|
||||
allow_no_prev_events=True,
|
||||
),
|
||||
AssertionError,
|
||||
)
|
||||
|
||||
|
||||
class ServerAclValidationTestCase(unittest.HomeserverTestCase):
|
||||
servlets = [
|
||||
|
||||
@@ -95,7 +95,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
|
||||
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
|
||||
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
|
||||
|
||||
# invalid search order
|
||||
channel = self.make_request(
|
||||
@@ -105,7 +105,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
|
||||
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
|
||||
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
|
||||
|
||||
# invalid destination
|
||||
channel = self.make_request(
|
||||
|
||||
@@ -360,7 +360,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
|
||||
channel.code,
|
||||
msg=channel.json_body,
|
||||
)
|
||||
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
|
||||
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
|
||||
self.assertEqual(
|
||||
"Boolean query parameter 'keep_profiles' must be one of ['true', 'false']",
|
||||
channel.json_body["error"],
|
||||
|
||||
@@ -608,7 +608,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
|
||||
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
|
||||
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
|
||||
|
||||
# invalid deactivated
|
||||
channel = self.make_request(
|
||||
@@ -618,7 +618,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
|
||||
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
|
||||
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
|
||||
|
||||
# unkown order_by
|
||||
channel = self.make_request(
|
||||
@@ -628,7 +628,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
|
||||
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
|
||||
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
|
||||
|
||||
# invalid search order
|
||||
channel = self.make_request(
|
||||
@@ -638,7 +638,7 @@ class UsersListTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
|
||||
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
|
||||
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
|
||||
|
||||
def test_limit(self):
|
||||
"""
|
||||
@@ -1550,7 +1550,8 @@ class UserRestTestCase(unittest.HomeserverTestCase):
|
||||
# Create user
|
||||
body = {
|
||||
"password": "abc123",
|
||||
"threepids": [{"medium": "email", "address": "bob@bob.bob"}],
|
||||
# Note that the given email is not in canonical form.
|
||||
"threepids": [{"medium": "email", "address": "Bob@bob.bob"}],
|
||||
}
|
||||
|
||||
channel = self.make_request(
|
||||
@@ -2896,7 +2897,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
|
||||
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
|
||||
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
|
||||
|
||||
# invalid search order
|
||||
channel = self.make_request(
|
||||
@@ -2906,7 +2907,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
self.assertEqual(HTTPStatus.BAD_REQUEST, channel.code, msg=channel.json_body)
|
||||
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
|
||||
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
|
||||
|
||||
# negative limit
|
||||
channel = self.make_request(
|
||||
|
||||
@@ -85,7 +85,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
|
||||
channel = self.make_request(
|
||||
"GET", "auth/m.login.recaptcha/fallback/web?session=" + session
|
||||
)
|
||||
self.assertEqual(channel.code, 200)
|
||||
self.assertEqual(channel.code, HTTPStatus.OK)
|
||||
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
@@ -104,7 +104,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
|
||||
"""Ensure that fallback auth via a captcha works."""
|
||||
# Returns a 401 as per the spec
|
||||
channel = self.register(
|
||||
401,
|
||||
HTTPStatus.UNAUTHORIZED,
|
||||
{"username": "user", "type": "m.login.password", "password": "bar"},
|
||||
)
|
||||
|
||||
@@ -116,15 +116,17 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
# Complete the recaptcha step.
|
||||
self.recaptcha(session, 200)
|
||||
self.recaptcha(session, HTTPStatus.OK)
|
||||
|
||||
# also complete the dummy auth
|
||||
self.register(200, {"auth": {"session": session, "type": "m.login.dummy"}})
|
||||
self.register(
|
||||
HTTPStatus.OK, {"auth": {"session": session, "type": "m.login.dummy"}}
|
||||
)
|
||||
|
||||
# Now we should have fulfilled a complete auth flow, including
|
||||
# the recaptcha fallback step, we can then send a
|
||||
# request to the register API with the session in the authdict.
|
||||
channel = self.register(200, {"auth": {"session": session}})
|
||||
channel = self.register(HTTPStatus.OK, {"auth": {"session": session}})
|
||||
|
||||
# We're given a registered user.
|
||||
self.assertEqual(channel.json_body["user_id"], "@user:test")
|
||||
@@ -137,7 +139,8 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
|
||||
# will be used.)
|
||||
# Returns a 401 as per the spec
|
||||
channel = self.register(
|
||||
401, {"username": "user", "type": "m.login.password", "password": "bar"}
|
||||
HTTPStatus.UNAUTHORIZED,
|
||||
{"username": "user", "type": "m.login.password", "password": "bar"},
|
||||
)
|
||||
|
||||
# Grab the session
|
||||
@@ -231,7 +234,9 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
||||
"""
|
||||
# Attempt to delete this device.
|
||||
# Returns a 401 as per the spec
|
||||
channel = self.delete_device(self.user_tok, self.device_id, 401)
|
||||
channel = self.delete_device(
|
||||
self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED
|
||||
)
|
||||
|
||||
# Grab the session
|
||||
session = channel.json_body["session"]
|
||||
@@ -242,7 +247,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
||||
self.delete_device(
|
||||
self.user_tok,
|
||||
self.device_id,
|
||||
200,
|
||||
HTTPStatus.OK,
|
||||
{
|
||||
"auth": {
|
||||
"type": "m.login.password",
|
||||
@@ -260,14 +265,16 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
||||
UIA - check that still works.
|
||||
"""
|
||||
|
||||
channel = self.delete_device(self.user_tok, self.device_id, 401)
|
||||
channel = self.delete_device(
|
||||
self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED
|
||||
)
|
||||
session = channel.json_body["session"]
|
||||
|
||||
# Make another request providing the UI auth flow.
|
||||
self.delete_device(
|
||||
self.user_tok,
|
||||
self.device_id,
|
||||
200,
|
||||
HTTPStatus.OK,
|
||||
{
|
||||
"auth": {
|
||||
"type": "m.login.password",
|
||||
@@ -293,7 +300,9 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
||||
|
||||
# Attempt to delete the first device.
|
||||
# Returns a 401 as per the spec
|
||||
channel = self.delete_devices(401, {"devices": [self.device_id]})
|
||||
channel = self.delete_devices(
|
||||
HTTPStatus.UNAUTHORIZED, {"devices": [self.device_id]}
|
||||
)
|
||||
|
||||
# Grab the session
|
||||
session = channel.json_body["session"]
|
||||
@@ -303,7 +312,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
||||
# Make another request providing the UI auth flow, but try to delete the
|
||||
# second device.
|
||||
self.delete_devices(
|
||||
200,
|
||||
HTTPStatus.OK,
|
||||
{
|
||||
"devices": ["dev2"],
|
||||
"auth": {
|
||||
@@ -324,7 +333,9 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
||||
|
||||
# Attempt to delete the first device.
|
||||
# Returns a 401 as per the spec
|
||||
channel = self.delete_device(self.user_tok, self.device_id, 401)
|
||||
channel = self.delete_device(
|
||||
self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED
|
||||
)
|
||||
|
||||
# Grab the session
|
||||
session = channel.json_body["session"]
|
||||
@@ -338,7 +349,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
||||
self.delete_device(
|
||||
self.user_tok,
|
||||
"dev2",
|
||||
403,
|
||||
HTTPStatus.FORBIDDEN,
|
||||
{
|
||||
"auth": {
|
||||
"type": "m.login.password",
|
||||
@@ -361,13 +372,13 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
||||
self.login("test", self.user_pass, "dev3")
|
||||
|
||||
# Attempt to delete a device. This works since the user just logged in.
|
||||
self.delete_device(self.user_tok, "dev2", 200)
|
||||
self.delete_device(self.user_tok, "dev2", HTTPStatus.OK)
|
||||
|
||||
# Move the clock forward past the validation timeout.
|
||||
self.reactor.advance(6)
|
||||
|
||||
# Deleting another devices throws the user into UI auth.
|
||||
channel = self.delete_device(self.user_tok, "dev3", 401)
|
||||
channel = self.delete_device(self.user_tok, "dev3", HTTPStatus.UNAUTHORIZED)
|
||||
|
||||
# Grab the session
|
||||
session = channel.json_body["session"]
|
||||
@@ -378,7 +389,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
||||
self.delete_device(
|
||||
self.user_tok,
|
||||
"dev3",
|
||||
200,
|
||||
HTTPStatus.OK,
|
||||
{
|
||||
"auth": {
|
||||
"type": "m.login.password",
|
||||
@@ -393,7 +404,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
||||
# due to re-using the previous session.
|
||||
#
|
||||
# Note that *no auth* information is provided, not even a session iD!
|
||||
self.delete_device(self.user_tok, self.device_id, 200)
|
||||
self.delete_device(self.user_tok, self.device_id, HTTPStatus.OK)
|
||||
|
||||
@skip_unless(HAS_OIDC, "requires OIDC")
|
||||
@override_config({"oidc_config": TEST_OIDC_CONFIG})
|
||||
@@ -413,7 +424,9 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
||||
self.assertEqual(login_resp["user_id"], self.user)
|
||||
|
||||
# initiate a UI Auth process by attempting to delete the device
|
||||
channel = self.delete_device(self.user_tok, self.device_id, 401)
|
||||
channel = self.delete_device(
|
||||
self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED
|
||||
)
|
||||
|
||||
# check that SSO is offered
|
||||
flows = channel.json_body["flows"]
|
||||
@@ -426,13 +439,13 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
# that should serve a confirmation page
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
||||
|
||||
# and now the delete request should succeed.
|
||||
self.delete_device(
|
||||
self.user_tok,
|
||||
self.device_id,
|
||||
200,
|
||||
HTTPStatus.OK,
|
||||
body={"auth": {"session": session_id}},
|
||||
)
|
||||
|
||||
@@ -445,13 +458,15 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
||||
|
||||
# now call the device deletion API: we should get the option to auth with SSO
|
||||
# and not password.
|
||||
channel = self.delete_device(user_tok, device_id, 401)
|
||||
channel = self.delete_device(user_tok, device_id, HTTPStatus.UNAUTHORIZED)
|
||||
|
||||
flows = channel.json_body["flows"]
|
||||
self.assertEqual(flows, [{"stages": ["m.login.sso"]}])
|
||||
|
||||
def test_does_not_offer_sso_for_password_user(self):
|
||||
channel = self.delete_device(self.user_tok, self.device_id, 401)
|
||||
channel = self.delete_device(
|
||||
self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED
|
||||
)
|
||||
|
||||
flows = channel.json_body["flows"]
|
||||
self.assertEqual(flows, [{"stages": ["m.login.password"]}])
|
||||
@@ -463,7 +478,9 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
||||
login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
|
||||
self.assertEqual(login_resp["user_id"], self.user)
|
||||
|
||||
channel = self.delete_device(self.user_tok, self.device_id, 401)
|
||||
channel = self.delete_device(
|
||||
self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED
|
||||
)
|
||||
|
||||
flows = channel.json_body["flows"]
|
||||
# we have no particular expectations of ordering here
|
||||
@@ -480,7 +497,9 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
||||
self.assertEqual(login_resp["user_id"], self.user)
|
||||
|
||||
# start a UI Auth flow by attempting to delete a device
|
||||
channel = self.delete_device(self.user_tok, self.device_id, 401)
|
||||
channel = self.delete_device(
|
||||
self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED
|
||||
)
|
||||
|
||||
flows = channel.json_body["flows"]
|
||||
self.assertIn({"stages": ["m.login.sso"]}, flows)
|
||||
@@ -496,7 +515,10 @@ class UIAuthTests(unittest.HomeserverTestCase):
|
||||
|
||||
# ... and the delete op should now fail with a 403
|
||||
self.delete_device(
|
||||
self.user_tok, self.device_id, 403, body={"auth": {"session": session_id}}
|
||||
self.user_tok,
|
||||
self.device_id,
|
||||
HTTPStatus.FORBIDDEN,
|
||||
body={"auth": {"session": session_id}},
|
||||
)
|
||||
|
||||
|
||||
@@ -551,7 +573,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
|
||||
login_without_refresh = self.make_request(
|
||||
"POST", "/_matrix/client/r0/login", body
|
||||
)
|
||||
self.assertEqual(login_without_refresh.code, 200, login_without_refresh.result)
|
||||
self.assertEqual(
|
||||
login_without_refresh.code, HTTPStatus.OK, login_without_refresh.result
|
||||
)
|
||||
self.assertNotIn("refresh_token", login_without_refresh.json_body)
|
||||
|
||||
login_with_refresh = self.make_request(
|
||||
@@ -559,7 +583,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
|
||||
"/_matrix/client/r0/login",
|
||||
{"refresh_token": True, **body},
|
||||
)
|
||||
self.assertEqual(login_with_refresh.code, 200, login_with_refresh.result)
|
||||
self.assertEqual(
|
||||
login_with_refresh.code, HTTPStatus.OK, login_with_refresh.result
|
||||
)
|
||||
self.assertIn("refresh_token", login_with_refresh.json_body)
|
||||
self.assertIn("expires_in_ms", login_with_refresh.json_body)
|
||||
|
||||
@@ -577,7 +603,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
|
||||
},
|
||||
)
|
||||
self.assertEqual(
|
||||
register_without_refresh.code, 200, register_without_refresh.result
|
||||
register_without_refresh.code,
|
||||
HTTPStatus.OK,
|
||||
register_without_refresh.result,
|
||||
)
|
||||
self.assertNotIn("refresh_token", register_without_refresh.json_body)
|
||||
|
||||
@@ -591,7 +619,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
|
||||
"refresh_token": True,
|
||||
},
|
||||
)
|
||||
self.assertEqual(register_with_refresh.code, 200, register_with_refresh.result)
|
||||
self.assertEqual(
|
||||
register_with_refresh.code, HTTPStatus.OK, register_with_refresh.result
|
||||
)
|
||||
self.assertIn("refresh_token", register_with_refresh.json_body)
|
||||
self.assertIn("expires_in_ms", register_with_refresh.json_body)
|
||||
|
||||
@@ -610,14 +640,14 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
|
||||
"/_matrix/client/r0/login",
|
||||
body,
|
||||
)
|
||||
self.assertEqual(login_response.code, 200, login_response.result)
|
||||
self.assertEqual(login_response.code, HTTPStatus.OK, login_response.result)
|
||||
|
||||
refresh_response = self.make_request(
|
||||
"POST",
|
||||
"/_matrix/client/v1/refresh",
|
||||
{"refresh_token": login_response.json_body["refresh_token"]},
|
||||
)
|
||||
self.assertEqual(refresh_response.code, 200, refresh_response.result)
|
||||
self.assertEqual(refresh_response.code, HTTPStatus.OK, refresh_response.result)
|
||||
self.assertIn("access_token", refresh_response.json_body)
|
||||
self.assertIn("refresh_token", refresh_response.json_body)
|
||||
self.assertIn("expires_in_ms", refresh_response.json_body)
|
||||
@@ -648,7 +678,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
|
||||
"/_matrix/client/r0/login",
|
||||
body,
|
||||
)
|
||||
self.assertEqual(login_response.code, 200, login_response.result)
|
||||
self.assertEqual(login_response.code, HTTPStatus.OK, login_response.result)
|
||||
self.assertApproximates(
|
||||
login_response.json_body["expires_in_ms"], 60 * 1000, 100
|
||||
)
|
||||
@@ -658,7 +688,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
|
||||
"/_matrix/client/v1/refresh",
|
||||
{"refresh_token": login_response.json_body["refresh_token"]},
|
||||
)
|
||||
self.assertEqual(refresh_response.code, 200, refresh_response.result)
|
||||
self.assertEqual(refresh_response.code, HTTPStatus.OK, refresh_response.result)
|
||||
self.assertApproximates(
|
||||
refresh_response.json_body["expires_in_ms"], 60 * 1000, 100
|
||||
)
|
||||
@@ -705,7 +735,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
|
||||
"/_matrix/client/r0/login",
|
||||
{"refresh_token": True, **body},
|
||||
)
|
||||
self.assertEqual(login_response1.code, 200, login_response1.result)
|
||||
self.assertEqual(login_response1.code, HTTPStatus.OK, login_response1.result)
|
||||
self.assertApproximates(
|
||||
login_response1.json_body["expires_in_ms"], 60 * 1000, 100
|
||||
)
|
||||
@@ -716,7 +746,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
|
||||
"/_matrix/client/r0/login",
|
||||
body,
|
||||
)
|
||||
self.assertEqual(login_response2.code, 200, login_response2.result)
|
||||
self.assertEqual(login_response2.code, HTTPStatus.OK, login_response2.result)
|
||||
nonrefreshable_access_token = login_response2.json_body["access_token"]
|
||||
|
||||
# Advance 59 seconds in the future (just shy of 1 minute, the time of expiry)
|
||||
@@ -818,7 +848,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
|
||||
"/_matrix/client/r0/login",
|
||||
body,
|
||||
)
|
||||
self.assertEqual(login_response.code, 200, login_response.result)
|
||||
self.assertEqual(login_response.code, HTTPStatus.OK, login_response.result)
|
||||
refresh_token = login_response.json_body["refresh_token"]
|
||||
|
||||
# Advance shy of 2 minutes into the future
|
||||
@@ -826,7 +856,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
|
||||
|
||||
# Refresh our session. The refresh token should still be valid right now.
|
||||
refresh_response = self.use_refresh_token(refresh_token)
|
||||
self.assertEqual(refresh_response.code, 200, refresh_response.result)
|
||||
self.assertEqual(refresh_response.code, HTTPStatus.OK, refresh_response.result)
|
||||
self.assertIn(
|
||||
"refresh_token",
|
||||
refresh_response.json_body,
|
||||
@@ -846,7 +876,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
|
||||
# This should fail because the refresh token's lifetime has also been
|
||||
# diminished as our session expired.
|
||||
refresh_response = self.use_refresh_token(refresh_token)
|
||||
self.assertEqual(refresh_response.code, 403, refresh_response.result)
|
||||
self.assertEqual(
|
||||
refresh_response.code, HTTPStatus.FORBIDDEN, refresh_response.result
|
||||
)
|
||||
|
||||
def test_refresh_token_invalidation(self):
|
||||
"""Refresh tokens are invalidated after first use of the next token.
|
||||
@@ -875,7 +907,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
|
||||
"/_matrix/client/r0/login",
|
||||
body,
|
||||
)
|
||||
self.assertEqual(login_response.code, 200, login_response.result)
|
||||
self.assertEqual(login_response.code, HTTPStatus.OK, login_response.result)
|
||||
|
||||
# This first refresh should work properly
|
||||
first_refresh_response = self.make_request(
|
||||
@@ -884,7 +916,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
|
||||
{"refresh_token": login_response.json_body["refresh_token"]},
|
||||
)
|
||||
self.assertEqual(
|
||||
first_refresh_response.code, 200, first_refresh_response.result
|
||||
first_refresh_response.code, HTTPStatus.OK, first_refresh_response.result
|
||||
)
|
||||
|
||||
# This one as well, since the token in the first one was never used
|
||||
@@ -894,7 +926,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
|
||||
{"refresh_token": login_response.json_body["refresh_token"]},
|
||||
)
|
||||
self.assertEqual(
|
||||
second_refresh_response.code, 200, second_refresh_response.result
|
||||
second_refresh_response.code, HTTPStatus.OK, second_refresh_response.result
|
||||
)
|
||||
|
||||
# This one should not, since the token from the first refresh is not valid anymore
|
||||
@@ -904,7 +936,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
|
||||
{"refresh_token": first_refresh_response.json_body["refresh_token"]},
|
||||
)
|
||||
self.assertEqual(
|
||||
third_refresh_response.code, 401, third_refresh_response.result
|
||||
third_refresh_response.code,
|
||||
HTTPStatus.UNAUTHORIZED,
|
||||
third_refresh_response.result,
|
||||
)
|
||||
|
||||
# The associated access token should also be invalid
|
||||
@@ -913,7 +947,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
|
||||
"/_matrix/client/r0/account/whoami",
|
||||
access_token=first_refresh_response.json_body["access_token"],
|
||||
)
|
||||
self.assertEqual(whoami_response.code, 401, whoami_response.result)
|
||||
self.assertEqual(
|
||||
whoami_response.code, HTTPStatus.UNAUTHORIZED, whoami_response.result
|
||||
)
|
||||
|
||||
# But all other tokens should work (they will expire after some time)
|
||||
for access_token in [
|
||||
@@ -923,7 +959,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
|
||||
whoami_response = self.make_request(
|
||||
"GET", "/_matrix/client/r0/account/whoami", access_token=access_token
|
||||
)
|
||||
self.assertEqual(whoami_response.code, 200, whoami_response.result)
|
||||
self.assertEqual(
|
||||
whoami_response.code, HTTPStatus.OK, whoami_response.result
|
||||
)
|
||||
|
||||
# Now that the access token from the last valid refresh was used once, refreshing with the N-1 token should fail
|
||||
fourth_refresh_response = self.make_request(
|
||||
@@ -932,7 +970,9 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
|
||||
{"refresh_token": login_response.json_body["refresh_token"]},
|
||||
)
|
||||
self.assertEqual(
|
||||
fourth_refresh_response.code, 403, fourth_refresh_response.result
|
||||
fourth_refresh_response.code,
|
||||
HTTPStatus.FORBIDDEN,
|
||||
fourth_refresh_response.result,
|
||||
)
|
||||
|
||||
# But refreshing from the last valid refresh token still works
|
||||
@@ -942,5 +982,5 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
|
||||
{"refresh_token": second_refresh_response.json_body["refresh_token"]},
|
||||
)
|
||||
self.assertEqual(
|
||||
fifth_refresh_response.code, 200, fifth_refresh_response.result
|
||||
fifth_refresh_response.code, HTTPStatus.OK, fifth_refresh_response.result
|
||||
)
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
import itertools
|
||||
import urllib.parse
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from unittest.mock import patch
|
||||
|
||||
from synapse.api.constants import EventTypes, RelationTypes
|
||||
from synapse.rest import admin
|
||||
@@ -23,6 +24,8 @@ from synapse.rest.client import login, register, relations, room, sync
|
||||
|
||||
from tests import unittest
|
||||
from tests.server import FakeChannel
|
||||
from tests.test_utils import make_awaitable
|
||||
from tests.test_utils.event_injection import inject_event
|
||||
|
||||
|
||||
class RelationsTestCase(unittest.HomeserverTestCase):
|
||||
@@ -651,6 +654,118 @@ class RelationsTestCase(unittest.HomeserverTestCase):
|
||||
},
|
||||
)
|
||||
|
||||
@unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
|
||||
def test_ignore_invalid_room(self):
|
||||
"""Test that we ignore invalid relations over federation."""
|
||||
# Create another room and send a message in it.
|
||||
room2 = self.helper.create_room_as(self.user_id, tok=self.user_token)
|
||||
res = self.helper.send(room2, body="Hi!", tok=self.user_token)
|
||||
parent_id = res["event_id"]
|
||||
|
||||
# Disable the validation to pretend this came over federation.
|
||||
with patch(
|
||||
"synapse.handlers.message.EventCreationHandler._validate_event_relation",
|
||||
new=lambda self, event: make_awaitable(None),
|
||||
):
|
||||
# Generate a various relations from a different room.
|
||||
self.get_success(
|
||||
inject_event(
|
||||
self.hs,
|
||||
room_id=self.room,
|
||||
type="m.reaction",
|
||||
sender=self.user_id,
|
||||
content={
|
||||
"m.relates_to": {
|
||||
"rel_type": RelationTypes.ANNOTATION,
|
||||
"event_id": parent_id,
|
||||
"key": "A",
|
||||
}
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
self.get_success(
|
||||
inject_event(
|
||||
self.hs,
|
||||
room_id=self.room,
|
||||
type="m.room.message",
|
||||
sender=self.user_id,
|
||||
content={
|
||||
"body": "foo",
|
||||
"msgtype": "m.text",
|
||||
"m.relates_to": {
|
||||
"rel_type": RelationTypes.REFERENCE,
|
||||
"event_id": parent_id,
|
||||
},
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
self.get_success(
|
||||
inject_event(
|
||||
self.hs,
|
||||
room_id=self.room,
|
||||
type="m.room.message",
|
||||
sender=self.user_id,
|
||||
content={
|
||||
"body": "foo",
|
||||
"msgtype": "m.text",
|
||||
"m.relates_to": {
|
||||
"rel_type": RelationTypes.THREAD,
|
||||
"event_id": parent_id,
|
||||
},
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
self.get_success(
|
||||
inject_event(
|
||||
self.hs,
|
||||
room_id=self.room,
|
||||
type="m.room.message",
|
||||
sender=self.user_id,
|
||||
content={
|
||||
"body": "foo",
|
||||
"msgtype": "m.text",
|
||||
"new_content": {
|
||||
"body": "new content",
|
||||
"msgtype": "m.text",
|
||||
},
|
||||
"m.relates_to": {
|
||||
"rel_type": RelationTypes.REPLACE,
|
||||
"event_id": parent_id,
|
||||
},
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# They should be ignored when fetching relations.
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
f"/_matrix/client/unstable/rooms/{room2}/relations/{parent_id}",
|
||||
access_token=self.user_token,
|
||||
)
|
||||
self.assertEquals(200, channel.code, channel.json_body)
|
||||
self.assertEqual(channel.json_body["chunk"], [])
|
||||
|
||||
# And when fetching aggregations.
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
f"/_matrix/client/unstable/rooms/{room2}/aggregations/{parent_id}",
|
||||
access_token=self.user_token,
|
||||
)
|
||||
self.assertEquals(200, channel.code, channel.json_body)
|
||||
self.assertEqual(channel.json_body["chunk"], [])
|
||||
|
||||
# And for bundled aggregations.
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
f"/rooms/{room2}/event/{parent_id}",
|
||||
access_token=self.user_token,
|
||||
)
|
||||
self.assertEquals(200, channel.code, channel.json_body)
|
||||
self.assertNotIn("m.relations", channel.json_body["unsigned"])
|
||||
|
||||
def test_edit(self):
|
||||
"""Test that a simple edit works."""
|
||||
|
||||
|
||||
180
tests/rest/client/test_room_batch.py
Normal file
180
tests/rest/client/test_room_batch.py
Normal file
@@ -0,0 +1,180 @@
|
||||
import logging
|
||||
from typing import List, Tuple
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
from synapse.api.constants import EventContentFields, EventTypes
|
||||
from synapse.appservice import ApplicationService
|
||||
from synapse.rest import admin
|
||||
from synapse.rest.client import login, register, room, room_batch
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _create_join_state_events_for_batch_send_request(
|
||||
virtual_user_ids: List[str],
|
||||
insert_time: int,
|
||||
) -> List[JsonDict]:
|
||||
return [
|
||||
{
|
||||
"type": EventTypes.Member,
|
||||
"sender": virtual_user_id,
|
||||
"origin_server_ts": insert_time,
|
||||
"content": {
|
||||
"membership": "join",
|
||||
"displayname": "display-name-for-%s" % (virtual_user_id,),
|
||||
},
|
||||
"state_key": virtual_user_id,
|
||||
}
|
||||
for virtual_user_id in virtual_user_ids
|
||||
]
|
||||
|
||||
|
||||
def _create_message_events_for_batch_send_request(
|
||||
virtual_user_id: str, insert_time: int, count: int
|
||||
) -> List[JsonDict]:
|
||||
return [
|
||||
{
|
||||
"type": EventTypes.Message,
|
||||
"sender": virtual_user_id,
|
||||
"origin_server_ts": insert_time,
|
||||
"content": {
|
||||
"msgtype": "m.text",
|
||||
"body": "Historical %d" % (i),
|
||||
EventContentFields.MSC2716_HISTORICAL: True,
|
||||
},
|
||||
}
|
||||
for i in range(count)
|
||||
]
|
||||
|
||||
|
||||
class RoomBatchTestCase(unittest.HomeserverTestCase):
|
||||
"""Test importing batches of historical messages."""
|
||||
|
||||
servlets = [
|
||||
admin.register_servlets_for_client_rest_resource,
|
||||
room_batch.register_servlets,
|
||||
room.register_servlets,
|
||||
register.register_servlets,
|
||||
login.register_servlets,
|
||||
]
|
||||
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
config = self.default_config()
|
||||
|
||||
self.appservice = ApplicationService(
|
||||
token="i_am_an_app_service",
|
||||
hostname="test",
|
||||
id="1234",
|
||||
namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
|
||||
# Note: this user does not have to match the regex above
|
||||
sender="@as_main:test",
|
||||
)
|
||||
|
||||
mock_load_appservices = Mock(return_value=[self.appservice])
|
||||
with patch(
|
||||
"synapse.storage.databases.main.appservice.load_appservices",
|
||||
mock_load_appservices,
|
||||
):
|
||||
hs = self.setup_test_homeserver(config=config)
|
||||
return hs
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.clock = clock
|
||||
self.storage = hs.get_storage()
|
||||
|
||||
self.virtual_user_id = self.register_appservice_user(
|
||||
"as_user_potato", self.appservice.token
|
||||
)
|
||||
|
||||
def _create_test_room(self) -> Tuple[str, str, str, str]:
|
||||
room_id = self.helper.create_room_as(
|
||||
self.appservice.sender, tok=self.appservice.token
|
||||
)
|
||||
|
||||
res_a = self.helper.send_event(
|
||||
room_id=room_id,
|
||||
type=EventTypes.Message,
|
||||
content={
|
||||
"msgtype": "m.text",
|
||||
"body": "A",
|
||||
},
|
||||
tok=self.appservice.token,
|
||||
)
|
||||
event_id_a = res_a["event_id"]
|
||||
|
||||
res_b = self.helper.send_event(
|
||||
room_id=room_id,
|
||||
type=EventTypes.Message,
|
||||
content={
|
||||
"msgtype": "m.text",
|
||||
"body": "B",
|
||||
},
|
||||
tok=self.appservice.token,
|
||||
)
|
||||
event_id_b = res_b["event_id"]
|
||||
|
||||
res_c = self.helper.send_event(
|
||||
room_id=room_id,
|
||||
type=EventTypes.Message,
|
||||
content={
|
||||
"msgtype": "m.text",
|
||||
"body": "C",
|
||||
},
|
||||
tok=self.appservice.token,
|
||||
)
|
||||
event_id_c = res_c["event_id"]
|
||||
|
||||
return room_id, event_id_a, event_id_b, event_id_c
|
||||
|
||||
@unittest.override_config({"experimental_features": {"msc2716_enabled": True}})
|
||||
def test_same_state_groups_for_whole_historical_batch(self):
|
||||
"""Make sure that when using the `/batch_send` endpoint to import a
|
||||
bunch of historical messages, it re-uses the same `state_group` across
|
||||
the whole batch. This is an easy optimization to make sure we're getting
|
||||
right because the state for the whole batch is contained in
|
||||
`state_events_at_start` and can be shared across everything.
|
||||
"""
|
||||
|
||||
time_before_room = int(self.clock.time_msec())
|
||||
room_id, event_id_a, _, _ = self._create_test_room()
|
||||
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
"/_matrix/client/unstable/org.matrix.msc2716/rooms/%s/batch_send?prev_event_id=%s"
|
||||
% (room_id, event_id_a),
|
||||
content={
|
||||
"events": _create_message_events_for_batch_send_request(
|
||||
self.virtual_user_id, time_before_room, 3
|
||||
),
|
||||
"state_events_at_start": _create_join_state_events_for_batch_send_request(
|
||||
[self.virtual_user_id], time_before_room
|
||||
),
|
||||
},
|
||||
access_token=self.appservice.token,
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
# Get the historical event IDs that we just imported
|
||||
historical_event_ids = channel.json_body["event_ids"]
|
||||
self.assertEqual(len(historical_event_ids), 3)
|
||||
|
||||
# Fetch the state_groups
|
||||
state_group_map = self.get_success(
|
||||
self.storage.state.get_state_groups_ids(room_id, historical_event_ids)
|
||||
)
|
||||
|
||||
# We expect all of the historical events to be using the same state_group
|
||||
# so there should only be a single state_group here!
|
||||
self.assertEqual(
|
||||
len(state_group_map.keys()),
|
||||
1,
|
||||
"Expected a single state_group to be returned by saw state_groups=%s"
|
||||
% (state_group_map.keys(),),
|
||||
)
|
||||
Reference in New Issue
Block a user