Merge commit '56efa9ec7' into anoa/dinsic_release_1_21_x
* commit '56efa9ec7': (22 commits) Fix rate limiting unit tests. (#8167) Add functions to `MultiWriterIdGen` used by events stream (#8164) Do not allow send_nonmember_event to be called with shadow-banned users. (#8158) Changelog fixes Make StreamIdGen `get_next` and `get_next_mult` async (#8161) Wording fixes to 'name' user admin api filter (#8163) Fix missing double-backtick in RST document Search in columns 'name' and 'displayname' in the admin users endpoint (#7377) Add type hints for state. (#8140) Stop shadow-banned users from sending non-member events. (#8142) Allow capping a room's retention policy (#8104) Add healthcheck for default localhost 8008 port on /health endpoint. (#8147) Fix flaky shadow-ban tests. (#8152) Don't fail /submit_token requests on incorrect session ID if request_token_inhibit_3pid_errors is turned on (#7991) Do not apply ratelimiting on joins to appservices (#8139) Micro-optimisations to get_auth_chain_ids (#8132) Allow denying or shadow banning registrations via the spam checker (#8034) Stop shadow-banned users from sending invites. (#8095) Be more tolerant of membership events in unknown rooms (#8110) Improve the error code when trying to register using a name reserved for guests. (#8135) ...
This commit is contained in:
+15
-1
@@ -1,10 +1,24 @@
|
||||
For the next release
|
||||
====================
|
||||
|
||||
Removal warning
|
||||
---------------
|
||||
|
||||
Some older clients used a
|
||||
[disallowed character](https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-register-email-requesttoken)
|
||||
(`:`) in the `client_secret` parameter of various endpoints. The incorrect
|
||||
behaviour was allowed for backwards compatibility, but is now being removed
|
||||
from Synapse as most users have updated their client. Further context can be
|
||||
found at [\#6766](https://github.com/matrix-org/synapse/issues/6766).
|
||||
|
||||
|
||||
Synapse 1.19.1rc1 (2020-08-25)
|
||||
==============================
|
||||
|
||||
Bugfixes
|
||||
--------
|
||||
|
||||
- Fixes a bug where appservices with ratelimiting disabled would still be ratelimited when joining rooms. This bug was introduced in v1.19.0. ([\#8139](https://github.com/matrix-org/synapse/issues/8139))
|
||||
- Fix a bug introduced in v1.19.0 where appservices with ratelimiting disabled would still be ratelimited when joining rooms. ([\#8139](https://github.com/matrix-org/synapse/issues/8139))
|
||||
- Fix a bug introduced in v1.19.0 that would cause e.g. profile updates to fail due to incorrect application of rate limits on join requests. ([\#8153](https://github.com/matrix-org/synapse/issues/8153))
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
Add filter `name` to the `/users` admin API, which filters by user ID or displayname. Contributed by Awesome Technologies Innovationslabor GmbH.
|
||||
@@ -0,0 +1 @@
|
||||
Don't fail `/submit_token` requests on incorrect session ID if `request_token_inhibit_3pid_errors` is turned on.
|
||||
@@ -0,0 +1 @@
|
||||
Add support for shadow-banning users (ignoring any message send requests).
|
||||
@@ -0,0 +1 @@
|
||||
Add support for shadow-banning users (ignoring any message send requests).
|
||||
@@ -0,0 +1 @@
|
||||
Fix a bug introduced in v1.7.2 impacting message retention policies that would allow federated homeservers to dictate a retention period that's lower than the configured minimum allowed duration in the configuration file.
|
||||
@@ -0,0 +1 @@
|
||||
Fix a bug introduced in Synapse 1.12.0 which could cause `/sync` requests to fail with a 404 if you had a very old outstanding room invite.
|
||||
@@ -0,0 +1 @@
|
||||
Reduce the amount of whitespace in JSON stored and sent in responses.
|
||||
@@ -0,0 +1 @@
|
||||
Add type hints to `synapse.storage.database`.
|
||||
@@ -0,0 +1 @@
|
||||
Micro-optimisations to get_auth_chain_ids.
|
||||
@@ -0,0 +1 @@
|
||||
Clarify the error code if a user tries to register with a numeric ID. This bug was introduced in v1.15.0.
|
||||
@@ -0,0 +1 @@
|
||||
Fixes a bug where appservices with ratelimiting disabled would still be ratelimited when joining rooms. This bug was introduced in v1.19.0.
|
||||
@@ -0,0 +1 @@
|
||||
Add type hints to `synapse.state`.
|
||||
@@ -0,0 +1 @@
|
||||
Add support for shadow-banning users (ignoring any message send requests).
|
||||
@@ -0,0 +1 @@
|
||||
Added curl for healthcheck support and readme updates for the change. Contributed by @maquis196.
|
||||
@@ -0,0 +1 @@
|
||||
Add support for shadow-banning users (ignoring any message send requests).
|
||||
@@ -0,0 +1 @@
|
||||
Add support for shadow-banning users (ignoring any message send requests).
|
||||
@@ -0,0 +1 @@
|
||||
Refactor `StreamIdGenerator` and `MultiWriterIdGenerator` to have the same interface.
|
||||
@@ -0,0 +1 @@
|
||||
Add filter `name` to the `/users` admin API, which filters by user ID or displayname. Contributed by Awesome Technologies Innovationslabor GmbH.
|
||||
@@ -0,0 +1 @@
|
||||
Add functions to `MultiWriterIdGen` used by events stream.
|
||||
@@ -0,0 +1 @@
|
||||
Fix tests that were broken due to the merge of 1.19.1.
|
||||
@@ -55,6 +55,7 @@ RUN pip install --prefix="/install" --no-warn-script-location \
|
||||
FROM docker.io/python:${PYTHON_VERSION}-slim
|
||||
|
||||
RUN apt-get update && apt-get install -y \
|
||||
curl \
|
||||
libpq5 \
|
||||
xmlsec1 \
|
||||
gosu \
|
||||
@@ -69,3 +70,6 @@ VOLUME ["/data"]
|
||||
EXPOSE 8008/tcp 8009/tcp 8448/tcp
|
||||
|
||||
ENTRYPOINT ["/start.py"]
|
||||
|
||||
HEALTHCHECK --interval=1m --timeout=5s \
|
||||
CMD curl -fSs http://localhost:8008/health || exit 1
|
||||
|
||||
@@ -162,3 +162,32 @@ docker build -t matrixdotorg/synapse -f docker/Dockerfile .
|
||||
|
||||
You can choose to build a different docker image by changing the value of the `-f` flag to
|
||||
point to another Dockerfile.
|
||||
|
||||
## Disabling the healthcheck
|
||||
|
||||
If you are using a non-standard port or tls inside docker you can disable the healthcheck
|
||||
whilst running the above `docker run` commands.
|
||||
|
||||
```
|
||||
--no-healthcheck
|
||||
```
|
||||
## Setting custom healthcheck on docker run
|
||||
|
||||
If you wish to point the healthcheck at a different port with docker command, add the following
|
||||
|
||||
```
|
||||
--health-cmd 'curl -fSs http://localhost:1234/health'
|
||||
```
|
||||
|
||||
## Setting the healthcheck in docker-compose file
|
||||
|
||||
You can add the following to set a custom healthcheck in a docker compose file.
|
||||
You will need version >2.1 for this to work.
|
||||
|
||||
```
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-fSs", "http://localhost:8008/health"]
|
||||
interval: 1m
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
```
|
||||
|
||||
@@ -108,7 +108,7 @@ The api is::
|
||||
|
||||
GET /_synapse/admin/v2/users?from=0&limit=10&guests=false
|
||||
|
||||
To use it, you will need to authenticate by providing an `access_token` for a
|
||||
To use it, you will need to authenticate by providing an ``access_token`` for a
|
||||
server admin: see `README.rst <README.rst>`_.
|
||||
|
||||
The parameter ``from`` is optional but used for pagination, denoting the
|
||||
@@ -119,8 +119,11 @@ from a previous call.
|
||||
The parameter ``limit`` is optional but is used for pagination, denoting the
|
||||
maximum number of items to return in this call. Defaults to ``100``.
|
||||
|
||||
The parameter ``user_id`` is optional and filters to only users with user IDs
|
||||
that contain this value.
|
||||
The parameter ``user_id`` is optional and filters to only return users with user IDs
|
||||
that contain this value. This parameter is ignored when using the ``name`` parameter.
|
||||
|
||||
The parameter ``name`` is optional and filters to only return users with user ID localparts
|
||||
**or** displaynames that contain this value.
|
||||
|
||||
The parameter ``guests`` is optional and if ``false`` will **exclude** guest users.
|
||||
Defaults to ``true`` to include guest users.
|
||||
|
||||
+14
-8
@@ -446,11 +446,10 @@ retention:
|
||||
# min_lifetime: 1d
|
||||
# max_lifetime: 1y
|
||||
|
||||
# Retention policy limits. If set, a user won't be able to send a
|
||||
# 'm.room.retention' event which features a 'min_lifetime' or a 'max_lifetime'
|
||||
# that's not within this range. This is especially useful in closed federations,
|
||||
# in which server admins can make sure every federating server applies the same
|
||||
# rules.
|
||||
# Retention policy limits. If set, and the state of a room contains a
|
||||
# 'm.room.retention' event in its state which contains a 'min_lifetime' or a
|
||||
# 'max_lifetime' that's out of these bounds, Synapse will cap the room's policy
|
||||
# to these limits when running purge jobs.
|
||||
#
|
||||
#allowed_lifetime_min: 1d
|
||||
#allowed_lifetime_max: 1y
|
||||
@@ -476,12 +475,19 @@ retention:
|
||||
# (e.g. every 12h), but not want that purge to be performed by a job that's
|
||||
# iterating over every room it knows, which could be heavy on the server.
|
||||
#
|
||||
# If any purge job is configured, it is strongly recommended to have at least
|
||||
# a single job with neither 'shortest_max_lifetime' nor 'longest_max_lifetime'
|
||||
# set, or one job without 'shortest_max_lifetime' and one job without
|
||||
# 'longest_max_lifetime' set. Otherwise some rooms might be ignored, even if
|
||||
# 'allowed_lifetime_min' and 'allowed_lifetime_max' are set, because capping a
|
||||
# room's policy to these values is done after the policies are retrieved from
|
||||
# Synapse's database (which is done using the range specified in a purge job's
|
||||
# configuration).
|
||||
#
|
||||
#purge_jobs:
|
||||
# - shortest_max_lifetime: 1d
|
||||
# longest_max_lifetime: 3d
|
||||
# - longest_max_lifetime: 3d
|
||||
# interval: 12h
|
||||
# - shortest_max_lifetime: 3d
|
||||
# longest_max_lifetime: 1y
|
||||
# interval: 1d
|
||||
|
||||
# Inhibits the /requestToken endpoints from returning an error that might leak
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Stub for frozendict.
|
||||
|
||||
from typing import (
|
||||
Any,
|
||||
Hashable,
|
||||
Iterable,
|
||||
Iterator,
|
||||
Mapping,
|
||||
overload,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
_KT = TypeVar("_KT", bound=Hashable) # Key type.
|
||||
_VT = TypeVar("_VT") # Value type.
|
||||
|
||||
class frozendict(Mapping[_KT, _VT]):
|
||||
@overload
|
||||
def __init__(self, **kwargs: _VT) -> None: ...
|
||||
@overload
|
||||
def __init__(self, __map: Mapping[_KT, _VT], **kwargs: _VT) -> None: ...
|
||||
@overload
|
||||
def __init__(
|
||||
self, __iterable: Iterable[Tuple[_KT, _VT]], **kwargs: _VT
|
||||
) -> None: ...
|
||||
def __getitem__(self, key: _KT) -> _VT: ...
|
||||
def __contains__(self, key: Any) -> bool: ...
|
||||
def copy(self, **add_or_replace: Any) -> frozendict: ...
|
||||
def __iter__(self) -> Iterator[_KT]: ...
|
||||
def __len__(self) -> int: ...
|
||||
def __repr__(self) -> str: ...
|
||||
def __hash__(self) -> int: ...
|
||||
@@ -605,3 +605,11 @@ class HttpResponseException(CodeMessageException):
|
||||
errmsg = j.pop("error", self.msg)
|
||||
|
||||
return ProxiedRequestError(self.code, errmsg, errcode, j)
|
||||
|
||||
|
||||
class ShadowBanError(Exception):
|
||||
"""
|
||||
Raised when a shadow-banned user attempts to perform an action.
|
||||
|
||||
This should be caught and a proper "fake" success response sent to the user.
|
||||
"""
|
||||
|
||||
@@ -1048,11 +1048,10 @@ class ServerConfig(Config):
|
||||
# min_lifetime: 1d
|
||||
# max_lifetime: 1y
|
||||
|
||||
# Retention policy limits. If set, a user won't be able to send a
|
||||
# 'm.room.retention' event which features a 'min_lifetime' or a 'max_lifetime'
|
||||
# that's not within this range. This is especially useful in closed federations,
|
||||
# in which server admins can make sure every federating server applies the same
|
||||
# rules.
|
||||
# Retention policy limits. If set, and the state of a room contains a
|
||||
# 'm.room.retention' event in its state which contains a 'min_lifetime' or a
|
||||
# 'max_lifetime' that's out of these bounds, Synapse will cap the room's policy
|
||||
# to these limits when running purge jobs.
|
||||
#
|
||||
#allowed_lifetime_min: 1d
|
||||
#allowed_lifetime_max: 1y
|
||||
@@ -1078,12 +1077,19 @@ class ServerConfig(Config):
|
||||
# (e.g. every 12h), but not want that purge to be performed by a job that's
|
||||
# iterating over every room it knows, which could be heavy on the server.
|
||||
#
|
||||
# If any purge job is configured, it is strongly recommended to have at least
|
||||
# a single job with neither 'shortest_max_lifetime' nor 'longest_max_lifetime'
|
||||
# set, or one job without 'shortest_max_lifetime' and one job without
|
||||
# 'longest_max_lifetime' set. Otherwise some rooms might be ignored, even if
|
||||
# 'allowed_lifetime_min' and 'allowed_lifetime_max' are set, because capping a
|
||||
# room's policy to these values is done after the policies are retrieved from
|
||||
# Synapse's database (which is done using the range specified in a purge job's
|
||||
# configuration).
|
||||
#
|
||||
#purge_jobs:
|
||||
# - shortest_max_lifetime: 1d
|
||||
# longest_max_lifetime: 3d
|
||||
# - longest_max_lifetime: 3d
|
||||
# interval: 12h
|
||||
# - shortest_max_lifetime: 3d
|
||||
# longest_max_lifetime: 1y
|
||||
# interval: 1d
|
||||
|
||||
# Inhibits the /requestToken endpoints from returning an error that might leak
|
||||
|
||||
@@ -133,6 +133,8 @@ class _EventInternalMetadata(object):
|
||||
rejection. This is needed as those events are marked as outliers, but
|
||||
they still need to be processed as if they're new events (e.g. updating
|
||||
invite state in the database, relaying to clients, etc).
|
||||
|
||||
(Added in synapse 0.99.0, so may be unreliable for events received before that)
|
||||
"""
|
||||
return self._dict.get("out_of_band_membership", False)
|
||||
|
||||
|
||||
@@ -15,9 +15,10 @@
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from synapse.spam_checker_api import SpamCheckerApi
|
||||
from synapse.spam_checker_api import RegistrationBehaviour, SpamCheckerApi
|
||||
from synapse.types import Collection
|
||||
|
||||
MYPY = False
|
||||
if MYPY:
|
||||
@@ -219,3 +220,33 @@ class SpamChecker(object):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def check_registration_for_spam(
|
||||
self,
|
||||
email_threepid: Optional[dict],
|
||||
username: Optional[str],
|
||||
request_info: Collection[Tuple[str, str]],
|
||||
) -> RegistrationBehaviour:
|
||||
"""Checks if we should allow the given registration request.
|
||||
|
||||
Args:
|
||||
email_threepid: The email threepid used for registering, if any
|
||||
username: The request user name, if any
|
||||
request_info: List of tuples of user agent and IP that
|
||||
were used during the registration process.
|
||||
|
||||
Returns:
|
||||
Enum for how the request should be handled
|
||||
"""
|
||||
|
||||
for spam_checker in self.spam_checkers:
|
||||
# For backwards compatibility, only run if the method exists on the
|
||||
# spam checker
|
||||
checker = getattr(spam_checker, "check_registration_for_spam", None)
|
||||
if checker:
|
||||
behaviour = checker(email_threepid, username, request_info)
|
||||
assert isinstance(behaviour, RegistrationBehaviour)
|
||||
if behaviour != RegistrationBehaviour.ALLOW:
|
||||
return behaviour
|
||||
|
||||
return RegistrationBehaviour.ALLOW
|
||||
|
||||
@@ -74,15 +74,14 @@ class EventValidator(object):
|
||||
)
|
||||
|
||||
if event.type == EventTypes.Retention:
|
||||
self._validate_retention(event, config)
|
||||
self._validate_retention(event)
|
||||
|
||||
def _validate_retention(self, event, config):
|
||||
def _validate_retention(self, event):
|
||||
"""Checks that an event that defines the retention policy for a room respects the
|
||||
boundaries imposed by the server's administrator.
|
||||
format enforced by the spec.
|
||||
|
||||
Args:
|
||||
event (FrozenEvent): The event to validate.
|
||||
config (Config): The homeserver's configuration.
|
||||
"""
|
||||
min_lifetime = event.content.get("min_lifetime")
|
||||
max_lifetime = event.content.get("max_lifetime")
|
||||
@@ -95,32 +94,6 @@ class EventValidator(object):
|
||||
errcode=Codes.BAD_JSON,
|
||||
)
|
||||
|
||||
if (
|
||||
config.retention_allowed_lifetime_min is not None
|
||||
and min_lifetime < config.retention_allowed_lifetime_min
|
||||
):
|
||||
raise SynapseError(
|
||||
code=400,
|
||||
msg=(
|
||||
"'min_lifetime' can't be lower than the minimum allowed"
|
||||
" value enforced by the server's administrator"
|
||||
),
|
||||
errcode=Codes.BAD_JSON,
|
||||
)
|
||||
|
||||
if (
|
||||
config.retention_allowed_lifetime_max is not None
|
||||
and min_lifetime > config.retention_allowed_lifetime_max
|
||||
):
|
||||
raise SynapseError(
|
||||
code=400,
|
||||
msg=(
|
||||
"'min_lifetime' can't be greater than the maximum allowed"
|
||||
" value enforced by the server's administrator"
|
||||
),
|
||||
errcode=Codes.BAD_JSON,
|
||||
)
|
||||
|
||||
if max_lifetime is not None:
|
||||
if not isinstance(max_lifetime, int):
|
||||
raise SynapseError(
|
||||
@@ -129,32 +102,6 @@ class EventValidator(object):
|
||||
errcode=Codes.BAD_JSON,
|
||||
)
|
||||
|
||||
if (
|
||||
config.retention_allowed_lifetime_min is not None
|
||||
and max_lifetime < config.retention_allowed_lifetime_min
|
||||
):
|
||||
raise SynapseError(
|
||||
code=400,
|
||||
msg=(
|
||||
"'max_lifetime' can't be lower than the minimum allowed value"
|
||||
" enforced by the server's administrator"
|
||||
),
|
||||
errcode=Codes.BAD_JSON,
|
||||
)
|
||||
|
||||
if (
|
||||
config.retention_allowed_lifetime_max is not None
|
||||
and max_lifetime > config.retention_allowed_lifetime_max
|
||||
):
|
||||
raise SynapseError(
|
||||
code=400,
|
||||
msg=(
|
||||
"'max_lifetime' can't be greater than the maximum allowed"
|
||||
" value enforced by the server's administrator"
|
||||
),
|
||||
errcode=Codes.BAD_JSON,
|
||||
)
|
||||
|
||||
if (
|
||||
min_lifetime is not None
|
||||
and max_lifetime is not None
|
||||
|
||||
@@ -329,10 +329,10 @@ class FederationSender(object):
|
||||
room_id = receipt.room_id
|
||||
|
||||
# Work out which remote servers should be poked and poke them.
|
||||
domains = await self.state.get_current_hosts_in_room(room_id)
|
||||
domains_set = await self.state.get_current_hosts_in_room(room_id)
|
||||
domains = [
|
||||
d
|
||||
for d in domains
|
||||
for d in domains_set
|
||||
if d != self.server_name
|
||||
and self._federation_shard_config.should_handle(self._instance_name, d)
|
||||
]
|
||||
|
||||
@@ -364,6 +364,14 @@ class AuthHandler(BaseHandler):
|
||||
# authentication flow.
|
||||
await self.store.set_ui_auth_clientdict(sid, clientdict)
|
||||
|
||||
user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[
|
||||
0
|
||||
].decode("ascii", "surrogateescape")
|
||||
|
||||
await self.store.add_user_agent_ip_to_ui_auth_session(
|
||||
session.session_id, user_agent, clientip
|
||||
)
|
||||
|
||||
if not authdict:
|
||||
raise InteractiveAuthIncompleteError(
|
||||
session.session_id, self._auth_dict_for_flows(flows, session.session_id)
|
||||
|
||||
@@ -35,6 +35,7 @@ class CasHandler:
|
||||
"""
|
||||
|
||||
def __init__(self, hs):
|
||||
self.hs = hs
|
||||
self._hostname = hs.hostname
|
||||
self._auth_handler = hs.get_auth_handler()
|
||||
self._registration_handler = hs.get_registration_handler()
|
||||
@@ -210,8 +211,16 @@ class CasHandler:
|
||||
|
||||
else:
|
||||
if not registered_user_id:
|
||||
# Pull out the user-agent and IP from the request.
|
||||
user_agent = request.requestHeaders.getRawHeaders(
|
||||
b"User-Agent", default=[b""]
|
||||
)[0].decode("ascii", "surrogateescape")
|
||||
ip_address = self.hs.get_ip_from_request(request)
|
||||
|
||||
registered_user_id = await self._registration_handler.register_user(
|
||||
localpart=localpart, default_display_name=user_display_name
|
||||
localpart=localpart,
|
||||
default_display_name=user_display_name,
|
||||
user_agent_ips=(user_agent, ip_address),
|
||||
)
|
||||
|
||||
await self._auth_handler.complete_sso_login(
|
||||
|
||||
@@ -16,8 +16,6 @@
|
||||
import logging
|
||||
from typing import Any, Dict
|
||||
|
||||
from canonicaljson import json
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.logging.context import run_in_background
|
||||
from synapse.logging.opentracing import (
|
||||
@@ -27,6 +25,7 @@ from synapse.logging.opentracing import (
|
||||
start_active_span,
|
||||
)
|
||||
from synapse.types import UserID, get_domain_from_id
|
||||
from synapse.util import json_encoder
|
||||
from synapse.util.stringutils import random_string
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -174,7 +173,7 @@ class DeviceMessageHandler(object):
|
||||
"sender": sender_user_id,
|
||||
"type": message_type,
|
||||
"message_id": message_id,
|
||||
"org.matrix.opentracing_context": json.dumps(context),
|
||||
"org.matrix.opentracing_context": json_encoder.encode(context),
|
||||
}
|
||||
|
||||
log_kv({"local_messages": local_messages})
|
||||
|
||||
@@ -23,6 +23,7 @@ from synapse.api.errors import (
|
||||
CodeMessageException,
|
||||
Codes,
|
||||
NotFoundError,
|
||||
ShadowBanError,
|
||||
StoreError,
|
||||
SynapseError,
|
||||
)
|
||||
@@ -200,6 +201,8 @@ class DirectoryHandler(BaseHandler):
|
||||
|
||||
try:
|
||||
await self._update_canonical_alias(requester, user_id, room_id, room_alias)
|
||||
except ShadowBanError as e:
|
||||
logger.info("Failed to update alias events due to shadow-ban: %s", e)
|
||||
except AuthError as e:
|
||||
logger.info("Failed to update alias events: %s", e)
|
||||
|
||||
@@ -293,6 +296,9 @@ class DirectoryHandler(BaseHandler):
|
||||
"""
|
||||
Send an updated canonical alias event if the removed alias was set as
|
||||
the canonical alias or listed in the alt_aliases field.
|
||||
|
||||
Raises:
|
||||
ShadowBanError if the requester has been shadow-banned.
|
||||
"""
|
||||
alias_event = await self.state.get_current_state(
|
||||
room_id, EventTypes.CanonicalAlias, ""
|
||||
|
||||
@@ -2144,10 +2144,10 @@ class FederationHandler(BaseHandler):
|
||||
)
|
||||
state_sets = list(state_sets.values())
|
||||
state_sets.append(state)
|
||||
current_state_ids = await self.state_handler.resolve_events(
|
||||
current_states = await self.state_handler.resolve_events(
|
||||
room_version, state_sets, event
|
||||
)
|
||||
current_state_ids = {k: e.event_id for k, e in current_state_ids.items()}
|
||||
current_state_ids = {k: e.event_id for k, e in current_states.items()}
|
||||
else:
|
||||
current_state_ids = await self.state_handler.get_current_state_ids(
|
||||
event.room_id, latest_event_ids=extrem_ids
|
||||
@@ -2159,9 +2159,11 @@ class FederationHandler(BaseHandler):
|
||||
|
||||
# Now check if event pass auth against said current state
|
||||
auth_types = auth_types_for_event(event)
|
||||
current_state_ids = [e for k, e in current_state_ids.items() if k in auth_types]
|
||||
current_state_ids_list = [
|
||||
e for k, e in current_state_ids.items() if k in auth_types
|
||||
]
|
||||
|
||||
auth_events_map = await self.store.get_events(current_state_ids)
|
||||
auth_events_map = await self.store.get_events(current_state_ids_list)
|
||||
current_auth_events = {
|
||||
(e.type, e.state_key): e for e in auth_events_map.values()
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import random
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
|
||||
from canonicaljson import encode_canonical_json
|
||||
@@ -34,6 +35,7 @@ from synapse.api.errors import (
|
||||
Codes,
|
||||
ConsentNotGivenError,
|
||||
NotFoundError,
|
||||
ShadowBanError,
|
||||
SynapseError,
|
||||
)
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
|
||||
@@ -648,24 +650,35 @@ class EventCreationHandler(object):
|
||||
event: EventBase,
|
||||
context: EventContext,
|
||||
ratelimit: bool = True,
|
||||
ignore_shadow_ban: bool = False,
|
||||
) -> int:
|
||||
"""
|
||||
Persists and notifies local clients and federation of an event.
|
||||
|
||||
Args:
|
||||
requester
|
||||
event the event to send.
|
||||
context: the context of the event.
|
||||
requester: The requester sending the event.
|
||||
event: The event to send.
|
||||
context: The context of the event.
|
||||
ratelimit: Whether to rate limit this send.
|
||||
ignore_shadow_ban: True if shadow-banned users should be allowed to
|
||||
send this event.
|
||||
|
||||
Return:
|
||||
The stream_id of the persisted event.
|
||||
|
||||
Raises:
|
||||
ShadowBanError if the requester has been shadow-banned.
|
||||
"""
|
||||
if event.type == EventTypes.Member:
|
||||
raise SynapseError(
|
||||
500, "Tried to send member event through non-member codepath"
|
||||
)
|
||||
|
||||
if not ignore_shadow_ban and requester.shadow_banned:
|
||||
# We randomly sleep a bit just to annoy the requester.
|
||||
await self.clock.sleep(random.randint(1, 10))
|
||||
raise ShadowBanError()
|
||||
|
||||
user = UserID.from_string(event.sender)
|
||||
|
||||
assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
|
||||
@@ -719,12 +732,28 @@ class EventCreationHandler(object):
|
||||
event_dict: dict,
|
||||
ratelimit: bool = True,
|
||||
txn_id: Optional[str] = None,
|
||||
ignore_shadow_ban: bool = False,
|
||||
) -> Tuple[EventBase, int]:
|
||||
"""
|
||||
Creates an event, then sends it.
|
||||
|
||||
See self.create_event and self.send_nonmember_event.
|
||||
|
||||
Args:
|
||||
requester: The requester sending the event.
|
||||
event_dict: An entire event.
|
||||
ratelimit: Whether to rate limit this send.
|
||||
txn_id: The transaction ID.
|
||||
ignore_shadow_ban: True if shadow-banned users should be allowed to
|
||||
send this event.
|
||||
|
||||
Raises:
|
||||
ShadowBanError if the requester has been shadow-banned.
|
||||
"""
|
||||
if not ignore_shadow_ban and requester.shadow_banned:
|
||||
# We randomly sleep a bit just to annoy the requester.
|
||||
await self.clock.sleep(random.randint(1, 10))
|
||||
raise ShadowBanError()
|
||||
|
||||
# We limit the number of concurrent event sends in a room so that we
|
||||
# don't fork the DAG too much. If we don't limit then we can end up in
|
||||
@@ -743,7 +772,11 @@ class EventCreationHandler(object):
|
||||
raise SynapseError(403, spam_error, Codes.FORBIDDEN)
|
||||
|
||||
stream_id = await self.send_nonmember_event(
|
||||
requester, event, context, ratelimit=ratelimit
|
||||
requester,
|
||||
event,
|
||||
context,
|
||||
ratelimit=ratelimit,
|
||||
ignore_shadow_ban=ignore_shadow_ban,
|
||||
)
|
||||
return event, stream_id
|
||||
|
||||
@@ -1183,8 +1216,14 @@ class EventCreationHandler(object):
|
||||
|
||||
event.internal_metadata.proactively_send = False
|
||||
|
||||
# Since this is a dummy-event it is OK if it is sent by a
|
||||
# shadow-banned user.
|
||||
await self.send_nonmember_event(
|
||||
requester, event, context, ratelimit=False
|
||||
requester,
|
||||
event,
|
||||
context,
|
||||
ratelimit=False,
|
||||
ignore_shadow_ban=True,
|
||||
)
|
||||
dummy_event_sent = True
|
||||
break
|
||||
|
||||
@@ -93,6 +93,7 @@ class OidcHandler:
|
||||
"""
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self._callback_url = hs.config.oidc_callback_url # type: str
|
||||
self._scopes = hs.config.oidc_scopes # type: List[str]
|
||||
self._client_auth = ClientAuth(
|
||||
@@ -689,9 +690,17 @@ class OidcHandler:
|
||||
self._render_error(request, "invalid_token", str(e))
|
||||
return
|
||||
|
||||
# Pull out the user-agent and IP from the request.
|
||||
user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[
|
||||
0
|
||||
].decode("ascii", "surrogateescape")
|
||||
ip_address = self.hs.get_ip_from_request(request)
|
||||
|
||||
# Call the mapper to register/login the user
|
||||
try:
|
||||
user_id = await self._map_userinfo_to_user(userinfo, token)
|
||||
user_id = await self._map_userinfo_to_user(
|
||||
userinfo, token, user_agent, ip_address
|
||||
)
|
||||
except MappingException as e:
|
||||
logger.exception("Could not map user")
|
||||
self._render_error(request, "mapping_error", str(e))
|
||||
@@ -828,7 +837,9 @@ class OidcHandler:
|
||||
now = self._clock.time_msec()
|
||||
return now < expiry
|
||||
|
||||
async def _map_userinfo_to_user(self, userinfo: UserInfo, token: Token) -> str:
|
||||
async def _map_userinfo_to_user(
|
||||
self, userinfo: UserInfo, token: Token, user_agent: str, ip_address: str
|
||||
) -> str:
|
||||
"""Maps a UserInfo object to a mxid.
|
||||
|
||||
UserInfo should have a claim that uniquely identifies users. This claim
|
||||
@@ -843,6 +854,8 @@ class OidcHandler:
|
||||
Args:
|
||||
userinfo: an object representing the user
|
||||
token: a dict with the tokens obtained from the provider
|
||||
user_agent: The user agent of the client making the request.
|
||||
ip_address: The IP address of the client making the request.
|
||||
|
||||
Raises:
|
||||
MappingException: if there was an error while mapping some properties
|
||||
@@ -899,7 +912,9 @@ class OidcHandler:
|
||||
# It's the first time this user is logging in and the mapped mxid was
|
||||
# not taken, register the user
|
||||
registered_user_id = await self._registration_handler.register_user(
|
||||
localpart=localpart, default_display_name=attributes["display_name"],
|
||||
localpart=localpart,
|
||||
default_display_name=attributes["display_name"],
|
||||
user_agent_ips=(user_agent, ip_address),
|
||||
)
|
||||
|
||||
await self._datastore.record_user_external_id(
|
||||
|
||||
@@ -82,6 +82,9 @@ class PaginationHandler(object):
|
||||
|
||||
self._retention_default_max_lifetime = hs.config.retention_default_max_lifetime
|
||||
|
||||
self._retention_allowed_lifetime_min = hs.config.retention_allowed_lifetime_min
|
||||
self._retention_allowed_lifetime_max = hs.config.retention_allowed_lifetime_max
|
||||
|
||||
if hs.config.retention_enabled:
|
||||
# Run the purge jobs described in the configuration file.
|
||||
for job in hs.config.retention_purge_jobs:
|
||||
@@ -111,7 +114,7 @@ class PaginationHandler(object):
|
||||
the range to handle (inclusive). If None, it means that the range has no
|
||||
upper limit.
|
||||
"""
|
||||
# We want the storage layer to to include rooms with no retention policy in its
|
||||
# We want the storage layer to include rooms with no retention policy in its
|
||||
# return value only if a default retention policy is defined in the server's
|
||||
# configuration and that policy's 'max_lifetime' is either lower (or equal) than
|
||||
# max_ms or higher than min_ms (or both).
|
||||
@@ -152,13 +155,32 @@ class PaginationHandler(object):
|
||||
)
|
||||
continue
|
||||
|
||||
max_lifetime = retention_policy["max_lifetime"]
|
||||
# If max_lifetime is None, it means that the room has no retention policy.
|
||||
# Given we only retrieve such rooms when there's a default retention policy
|
||||
# defined in the server's configuration, we can safely assume that's the
|
||||
# case and use it for this room.
|
||||
max_lifetime = (
|
||||
retention_policy["max_lifetime"] or self._retention_default_max_lifetime
|
||||
)
|
||||
|
||||
if max_lifetime is None:
|
||||
# If max_lifetime is None, it means that include_null equals True,
|
||||
# therefore we can safely assume that there is a default policy defined
|
||||
# in the server's configuration.
|
||||
max_lifetime = self._retention_default_max_lifetime
|
||||
# Cap the effective max_lifetime to be within the range allowed in the
|
||||
# config.
|
||||
# We do this in two steps:
|
||||
# 1. Make sure it's higher or equal to the minimum allowed value, and if
|
||||
# it's not replace it with that value. This is because the server
|
||||
# operator can be required to not delete information before a given
|
||||
# time, e.g. to comply with freedom of information laws.
|
||||
# 2. Make sure the resulting value is lower or equal to the maximum allowed
|
||||
# value, and if it's not replace it with that value. This is because the
|
||||
# server operator can be required to delete any data after a specific
|
||||
# amount of time.
|
||||
if self._retention_allowed_lifetime_min is not None:
|
||||
max_lifetime = max(self._retention_allowed_lifetime_min, max_lifetime)
|
||||
|
||||
if self._retention_allowed_lifetime_max is not None:
|
||||
max_lifetime = min(max_lifetime, self._retention_allowed_lifetime_max)
|
||||
|
||||
logger.debug("[purge] max_lifetime for room %s: %s", room_id, max_lifetime)
|
||||
|
||||
# Figure out what token we should start purging at.
|
||||
ts = self.clock.time_msec() - max_lifetime
|
||||
|
||||
@@ -40,7 +40,7 @@ from synapse.metrics import LaterGauge
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.state import StateHandler
|
||||
from synapse.storage.databases.main import DataStore
|
||||
from synapse.types import JsonDict, UserID, get_domain_from_id
|
||||
from synapse.types import Collection, JsonDict, UserID, get_domain_from_id
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
from synapse.util.caches.descriptors import cached
|
||||
from synapse.util.metrics import Measure
|
||||
@@ -1318,7 +1318,7 @@ async def get_interested_parties(
|
||||
|
||||
async def get_interested_remotes(
|
||||
store: DataStore, states: List[UserPresenceState], state_handler: StateHandler
|
||||
) -> List[Tuple[List[str], List[UserPresenceState]]]:
|
||||
) -> List[Tuple[Collection[str], List[UserPresenceState]]]:
|
||||
"""Given a list of presence states figure out which remote servers
|
||||
should be sent which.
|
||||
|
||||
@@ -1334,7 +1334,7 @@ async def get_interested_remotes(
|
||||
each tuple the list of UserPresenceState should be sent to each
|
||||
destination
|
||||
"""
|
||||
hosts_and_states = []
|
||||
hosts_and_states = [] # type: List[Tuple[Collection[str], List[UserPresenceState]]]
|
||||
|
||||
# First we look up the rooms each user is in (as well as any explicit
|
||||
# subscriptions), then for each distinct room we look up the remote
|
||||
|
||||
@@ -26,6 +26,7 @@ from synapse.replication.http.register import (
|
||||
ReplicationPostRegisterActionsServlet,
|
||||
ReplicationRegisterServlet,
|
||||
)
|
||||
from synapse.spam_checker_api import RegistrationBehaviour
|
||||
from synapse.storage.state import StateFilter
|
||||
from synapse.types import RoomAlias, UserID, create_requester
|
||||
|
||||
@@ -53,6 +54,8 @@ class RegistrationHandler(BaseHandler):
|
||||
self.macaroon_gen = hs.get_macaroon_generator()
|
||||
self._server_notices_mxid = hs.config.server_notices_mxid
|
||||
|
||||
self.spam_checker = hs.get_spam_checker()
|
||||
|
||||
self._show_in_user_directory = self.hs.config.show_users_in_user_directory
|
||||
|
||||
if hs.config.worker_app:
|
||||
@@ -139,7 +142,9 @@ class RegistrationHandler(BaseHandler):
|
||||
try:
|
||||
int(localpart)
|
||||
raise SynapseError(
|
||||
400, "Numeric user IDs are reserved for guest users."
|
||||
400,
|
||||
"Numeric user IDs are reserved for guest users.",
|
||||
errcode=Codes.INVALID_USERNAME,
|
||||
)
|
||||
except ValueError:
|
||||
pass
|
||||
@@ -157,7 +162,7 @@ class RegistrationHandler(BaseHandler):
|
||||
address=None,
|
||||
bind_emails=[],
|
||||
by_admin=False,
|
||||
shadow_banned=False,
|
||||
user_agent_ips=None,
|
||||
):
|
||||
"""Registers a new client on the server.
|
||||
|
||||
@@ -175,7 +180,8 @@ class RegistrationHandler(BaseHandler):
|
||||
bind_emails (List[str]): list of emails to bind to this account.
|
||||
by_admin (bool): True if this registration is being made via the
|
||||
admin api, otherwise False.
|
||||
shadow_banned (bool): Shadow-ban the created user.
|
||||
user_agent_ips (List[(str, str)]): Tuples of IP addresses and user-agents used
|
||||
during the registration process.
|
||||
Returns:
|
||||
str: user_id
|
||||
Raises:
|
||||
@@ -183,6 +189,24 @@ class RegistrationHandler(BaseHandler):
|
||||
"""
|
||||
self.check_registration_ratelimit(address)
|
||||
|
||||
result = self.spam_checker.check_registration_for_spam(
|
||||
threepid, localpart, user_agent_ips or [],
|
||||
)
|
||||
|
||||
if result == RegistrationBehaviour.DENY:
|
||||
logger.info(
|
||||
"Blocked registration of %r", localpart,
|
||||
)
|
||||
# We return a 429 to make it not obvious that they've been
|
||||
# denied.
|
||||
raise SynapseError(429, "Rate limited")
|
||||
|
||||
shadow_banned = result == RegistrationBehaviour.SHADOW_BAN
|
||||
if shadow_banned:
|
||||
logger.info(
|
||||
"Shadow banning registration of %r", localpart,
|
||||
)
|
||||
|
||||
# do not check_auth_blocking if the call is coming through the Admin API
|
||||
if not by_admin:
|
||||
await self.auth.check_auth_blocking(threepid=threepid)
|
||||
|
||||
@@ -20,6 +20,7 @@
|
||||
import itertools
|
||||
import logging
|
||||
import math
|
||||
import random
|
||||
import string
|
||||
from collections import OrderedDict
|
||||
from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple
|
||||
@@ -135,6 +136,9 @@ class RoomCreationHandler(BaseHandler):
|
||||
|
||||
Returns:
|
||||
the new room id
|
||||
|
||||
Raises:
|
||||
ShadowBanError if the requester is shadow-banned.
|
||||
"""
|
||||
await self.ratelimit(requester)
|
||||
|
||||
@@ -170,6 +174,15 @@ class RoomCreationHandler(BaseHandler):
|
||||
async def _upgrade_room(
|
||||
self, requester: Requester, old_room_id: str, new_version: RoomVersion
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
requester: the user requesting the upgrade
|
||||
old_room_id: the id of the room to be replaced
|
||||
new_versions: the version to upgrade the room to
|
||||
|
||||
Raises:
|
||||
ShadowBanError if the requester is shadow-banned.
|
||||
"""
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
# start by allocating a new room id
|
||||
@@ -256,6 +269,9 @@ class RoomCreationHandler(BaseHandler):
|
||||
old_room_id: the id of the room to be replaced
|
||||
new_room_id: the id of the replacement room
|
||||
old_room_state: the state map for the old room
|
||||
|
||||
Raises:
|
||||
ShadowBanError if the requester is shadow-banned.
|
||||
"""
|
||||
old_room_pl_event_id = old_room_state.get((EventTypes.PowerLevels, ""))
|
||||
|
||||
@@ -644,6 +660,8 @@ class RoomCreationHandler(BaseHandler):
|
||||
if mapping:
|
||||
raise SynapseError(400, "Room alias already taken", Codes.ROOM_IN_USE)
|
||||
|
||||
invite_3pid_list = config.get("invite_3pid", [])
|
||||
invite_list = config.get("invite", [])
|
||||
for i in invite_list:
|
||||
try:
|
||||
uid = UserID.from_string(i)
|
||||
@@ -651,6 +669,14 @@ class RoomCreationHandler(BaseHandler):
|
||||
except Exception:
|
||||
raise SynapseError(400, "Invalid user_id: %s" % (i,))
|
||||
|
||||
if (invite_list or invite_3pid_list) and requester.shadow_banned:
|
||||
# We randomly sleep a bit just to annoy the requester.
|
||||
await self.clock.sleep(random.randint(1, 10))
|
||||
|
||||
# Allow the request to go through, but remove any associated invites.
|
||||
invite_3pid_list = []
|
||||
invite_list = []
|
||||
|
||||
await self.event_creation_handler.assert_accepted_privacy_policy(requester)
|
||||
|
||||
power_level_content_override = config.get("power_level_content_override")
|
||||
@@ -768,6 +794,8 @@ class RoomCreationHandler(BaseHandler):
|
||||
if is_direct:
|
||||
content["is_direct"] = is_direct
|
||||
|
||||
# Note that update_membership with an action of "invite" can raise a
|
||||
# ShadowBanError, but this was handled above by emptying invite_list.
|
||||
_, last_stream_id = await self.room_member_handler.update_membership(
|
||||
requester,
|
||||
UserID.from_string(invitee),
|
||||
@@ -783,6 +811,8 @@ class RoomCreationHandler(BaseHandler):
|
||||
id_access_token = invite_3pid.get("id_access_token") # optional
|
||||
address = invite_3pid["address"]
|
||||
medium = invite_3pid["medium"]
|
||||
# Note that do_3pid_invite can raise a ShadowBanError, but this was
|
||||
# handled above by emptying invite_3pid_list.
|
||||
last_stream_id = await self.hs.get_room_member_handler().do_3pid_invite(
|
||||
room_id,
|
||||
requester.user,
|
||||
@@ -843,11 +873,13 @@ class RoomCreationHandler(BaseHandler):
|
||||
async def send(etype: str, content: JsonDict, **kwargs) -> int:
|
||||
event = create(etype, content, **kwargs)
|
||||
logger.debug("Sending %s in new room", etype)
|
||||
# Allow these events to be sent even if the user is shadow-banned to
|
||||
# allow the room creation to complete.
|
||||
(
|
||||
_,
|
||||
last_stream_id,
|
||||
) = await self.event_creation_handler.create_and_send_nonmember_event(
|
||||
creator, event, ratelimit=False
|
||||
creator, event, ratelimit=False, ignore_shadow_ban=True,
|
||||
)
|
||||
return last_stream_id
|
||||
|
||||
|
||||
@@ -15,14 +15,21 @@
|
||||
|
||||
import abc
|
||||
import logging
|
||||
import random
|
||||
from http import HTTPStatus
|
||||
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
from unpaddedbase64 import encode_base64
|
||||
|
||||
from synapse import types
|
||||
from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
|
||||
from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError
|
||||
from synapse.api.errors import (
|
||||
AuthError,
|
||||
Codes,
|
||||
LimitExceededError,
|
||||
ShadowBanError,
|
||||
SynapseError,
|
||||
)
|
||||
from synapse.api.ratelimiting import Ratelimiter
|
||||
from synapse.api.room_versions import EventFormatVersions
|
||||
from synapse.crypto.event_signing import compute_event_reference_hash
|
||||
@@ -31,7 +38,15 @@ from synapse.events.builder import create_local_event_from_event_dict
|
||||
from synapse.events.snapshot import EventContext
|
||||
from synapse.events.validator import EventValidator
|
||||
from synapse.storage.roommember import RoomsForUser
|
||||
from synapse.types import Collection, JsonDict, Requester, RoomAlias, RoomID, UserID
|
||||
from synapse.types import (
|
||||
Collection,
|
||||
JsonDict,
|
||||
Requester,
|
||||
RoomAlias,
|
||||
RoomID,
|
||||
StateMap,
|
||||
UserID,
|
||||
)
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
from synapse.util.distributor import user_joined_room, user_left_room
|
||||
|
||||
@@ -303,7 +318,31 @@ class RoomMemberHandler(object):
|
||||
new_room: bool = False,
|
||||
require_consent: bool = True,
|
||||
) -> Tuple[str, int]:
|
||||
"""Update a user's membership in a room"""
|
||||
"""Update a user's membership in a room.
|
||||
|
||||
Params:
|
||||
requester: The user who is performing the update.
|
||||
target: The user whose membership is being updated.
|
||||
room_id: The room ID whose membership is being updated.
|
||||
action: The membership change, see synapse.api.constants.Membership.
|
||||
txn_id: The transaction ID, if given.
|
||||
remote_room_hosts: Remote servers to send the update to.
|
||||
third_party_signed: Information from a 3PID invite.
|
||||
ratelimit: Whether to rate limit the request.
|
||||
content: The content of the created event.
|
||||
require_consent: Whether consent is required.
|
||||
|
||||
Returns:
|
||||
A tuple of the new event ID and stream ID.
|
||||
|
||||
Raises:
|
||||
ShadowBanError if a shadow-banned requester attempts to send an invite.
|
||||
"""
|
||||
if action == Membership.INVITE and requester.shadow_banned:
|
||||
# We randomly sleep a bit just to annoy the requester.
|
||||
await self.clock.sleep(random.randint(1, 10))
|
||||
raise ShadowBanError()
|
||||
|
||||
key = (room_id,)
|
||||
|
||||
with (await self.member_linearizer.queue(key)):
|
||||
@@ -741,9 +780,7 @@ class RoomMemberHandler(object):
|
||||
if prev_member_event.membership == Membership.JOIN:
|
||||
await self._user_left_room(target_user, room_id)
|
||||
|
||||
async def _can_guest_join(
|
||||
self, current_state_ids: Dict[Tuple[str, str], str]
|
||||
) -> bool:
|
||||
async def _can_guest_join(self, current_state_ids: StateMap[str]) -> bool:
|
||||
"""
|
||||
Returns whether a guest can join a room based on its current state.
|
||||
"""
|
||||
@@ -811,6 +848,25 @@ class RoomMemberHandler(object):
|
||||
new_room: bool = False,
|
||||
id_access_token: Optional[str] = None,
|
||||
) -> int:
|
||||
"""Invite a 3PID to a room.
|
||||
|
||||
Args:
|
||||
room_id: The room to invite the 3PID to.
|
||||
inviter: The user sending the invite.
|
||||
medium: The 3PID's medium.
|
||||
address: The 3PID's address.
|
||||
id_server: The identity server to use.
|
||||
requester: The user making the request.
|
||||
txn_id: The transaction ID this is part of, or None if this is not
|
||||
part of a transaction.
|
||||
id_access_token: The optional identity server access token.
|
||||
|
||||
Returns:
|
||||
The new stream ID.
|
||||
|
||||
Raises:
|
||||
ShadowBanError if the requester has been shadow-banned.
|
||||
"""
|
||||
if self.config.block_non_admin_invites:
|
||||
is_requester_admin = await self.auth.is_server_admin(requester.user)
|
||||
if not is_requester_admin:
|
||||
@@ -818,6 +874,11 @@ class RoomMemberHandler(object):
|
||||
403, "Invites have been disabled on this server", Codes.FORBIDDEN
|
||||
)
|
||||
|
||||
if requester.shadow_banned:
|
||||
# We randomly sleep a bit just to annoy the requester.
|
||||
await self.clock.sleep(random.randint(1, 10))
|
||||
raise ShadowBanError()
|
||||
|
||||
# We need to rate limit *before* we send out any 3PID invites, so we
|
||||
# can't just rely on the standard ratelimiting of events.
|
||||
await self.base_handler.ratelimit(requester)
|
||||
@@ -865,6 +926,8 @@ class RoomMemberHandler(object):
|
||||
raise SynapseError(403, "Invites have been disabled on this server")
|
||||
|
||||
if invitee:
|
||||
# Note that update_membership with an action of "invite" can raise
|
||||
# a ShadowBanError, but this was done above already.
|
||||
_, stream_id = await self.update_membership(
|
||||
requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id
|
||||
)
|
||||
@@ -970,9 +1033,7 @@ class RoomMemberHandler(object):
|
||||
)
|
||||
return stream_id
|
||||
|
||||
async def _is_host_in_room(
|
||||
self, current_state_ids: Dict[Tuple[str, str], str]
|
||||
) -> bool:
|
||||
async def _is_host_in_room(self, current_state_ids: StateMap[str]) -> bool:
|
||||
# Have we just created the room, and is this about to be the very
|
||||
# first member event?
|
||||
create_event_id = current_state_ids.get(("m.room.create", ""))
|
||||
@@ -1103,7 +1164,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
|
||||
return event_id, stream_id
|
||||
|
||||
# The room is too large. Leave.
|
||||
requester = types.create_requester(user, None, False, None)
|
||||
requester = types.create_requester(user, None, False, False, None)
|
||||
await self.update_membership(
|
||||
requester=requester, target=user, room_id=room_id, action="leave"
|
||||
)
|
||||
|
||||
@@ -54,6 +54,7 @@ class Saml2SessionData:
|
||||
|
||||
class SamlHandler:
|
||||
def __init__(self, hs: "synapse.server.HomeServer"):
|
||||
self.hs = hs
|
||||
self._saml_client = Saml2Client(hs.config.saml2_sp_config)
|
||||
self._auth = hs.get_auth()
|
||||
self._auth_handler = hs.get_auth_handler()
|
||||
@@ -133,8 +134,14 @@ class SamlHandler:
|
||||
# the dict.
|
||||
self.expire_sessions()
|
||||
|
||||
# Pull out the user-agent and IP from the request.
|
||||
user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[
|
||||
0
|
||||
].decode("ascii", "surrogateescape")
|
||||
ip_address = self.hs.get_ip_from_request(request)
|
||||
|
||||
user_id, current_session = await self._map_saml_response_to_user(
|
||||
resp_bytes, relay_state
|
||||
resp_bytes, relay_state, user_agent, ip_address
|
||||
)
|
||||
|
||||
# Complete the interactive auth session or the login.
|
||||
@@ -147,7 +154,11 @@ class SamlHandler:
|
||||
await self._auth_handler.complete_sso_login(user_id, request, relay_state)
|
||||
|
||||
async def _map_saml_response_to_user(
|
||||
self, resp_bytes: str, client_redirect_url: str
|
||||
self,
|
||||
resp_bytes: str,
|
||||
client_redirect_url: str,
|
||||
user_agent: str,
|
||||
ip_address: str,
|
||||
) -> Tuple[str, Optional[Saml2SessionData]]:
|
||||
"""
|
||||
Given a sample response, retrieve the cached session and user for it.
|
||||
@@ -155,6 +166,8 @@ class SamlHandler:
|
||||
Args:
|
||||
resp_bytes: The SAML response.
|
||||
client_redirect_url: The redirect URL passed in by the client.
|
||||
user_agent: The user agent of the client making the request.
|
||||
ip_address: The IP address of the client making the request.
|
||||
|
||||
Returns:
|
||||
Tuple of the user ID and SAML session associated with this response.
|
||||
@@ -291,6 +304,7 @@ class SamlHandler:
|
||||
localpart=localpart,
|
||||
default_display_name=displayname,
|
||||
bind_emails=emails,
|
||||
user_agent_ips=(user_agent, ip_address),
|
||||
)
|
||||
|
||||
await self._datastore.record_user_external_id(
|
||||
|
||||
@@ -172,12 +172,11 @@ from functools import wraps
|
||||
from typing import TYPE_CHECKING, Dict, Optional, Type
|
||||
|
||||
import attr
|
||||
from canonicaljson import json
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.config import ConfigError
|
||||
from synapse.util import json_decoder
|
||||
from synapse.util import json_decoder, json_encoder
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.http.site import SynapseRequest
|
||||
@@ -693,7 +692,7 @@ def active_span_context_as_string():
|
||||
opentracing.tracer.inject(
|
||||
opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier
|
||||
)
|
||||
return json.dumps(carrier)
|
||||
return json_encoder.encode(carrier)
|
||||
|
||||
|
||||
@only_if_tracing
|
||||
|
||||
@@ -316,6 +316,9 @@ class JoinRoomAliasServlet(RestServlet):
|
||||
join_rules_event = room_state.get((EventTypes.JoinRules, ""))
|
||||
if join_rules_event:
|
||||
if not (join_rules_event.content.get("join_rule") == JoinRules.PUBLIC):
|
||||
# update_membership with an action of "invite" can raise a
|
||||
# ShadowBanError. This is not handled since it is assumed that
|
||||
# an admin isn't going to call this API with a shadow-banned user.
|
||||
await self.room_member_handler.update_membership(
|
||||
requester=requester,
|
||||
target=fake_requester.user,
|
||||
|
||||
@@ -73,6 +73,7 @@ class UsersRestServletV2(RestServlet):
|
||||
The parameters `from` and `limit` are required only for pagination.
|
||||
By default, a `limit` of 100 is used.
|
||||
The parameter `user_id` can be used to filter by user id.
|
||||
The parameter `name` can be used to filter by user id or display name.
|
||||
The parameter `guests` can be used to exclude guest users.
|
||||
The parameter `deactivated` can be used to include deactivated users.
|
||||
"""
|
||||
@@ -89,11 +90,12 @@ class UsersRestServletV2(RestServlet):
|
||||
start = parse_integer(request, "from", default=0)
|
||||
limit = parse_integer(request, "limit", default=100)
|
||||
user_id = parse_string(request, "user_id", default=None)
|
||||
name = parse_string(request, "name", default=None)
|
||||
guests = parse_boolean(request, "guests", default=True)
|
||||
deactivated = parse_boolean(request, "deactivated", default=False)
|
||||
|
||||
users, total = await self.store.get_users_paginate(
|
||||
start, limit, user_id, guests, deactivated
|
||||
start, limit, user_id, name, guests, deactivated
|
||||
)
|
||||
ret = {"users": users, "total": total}
|
||||
if len(users) >= limit:
|
||||
|
||||
@@ -27,6 +27,7 @@ from synapse.api.errors import (
|
||||
Codes,
|
||||
HttpResponseException,
|
||||
InvalidClientCredentialsError,
|
||||
ShadowBanError,
|
||||
SynapseError,
|
||||
)
|
||||
from synapse.api.filtering import Filter
|
||||
@@ -45,6 +46,7 @@ from synapse.storage.state import StateFilter
|
||||
from synapse.streams.config import PaginationConfig
|
||||
from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID
|
||||
from synapse.util import json_decoder
|
||||
from synapse.util.stringutils import random_string
|
||||
|
||||
MYPY = False
|
||||
if MYPY:
|
||||
@@ -199,23 +201,26 @@ class RoomStateEventRestServlet(TransactionRestServlet):
|
||||
if state_key is not None:
|
||||
event_dict["state_key"] = state_key
|
||||
|
||||
if event_type == EventTypes.Member:
|
||||
membership = content.get("membership", None)
|
||||
event_id, _ = await self.room_member_handler.update_membership(
|
||||
requester,
|
||||
target=UserID.from_string(state_key),
|
||||
room_id=room_id,
|
||||
action=membership,
|
||||
content=content,
|
||||
)
|
||||
else:
|
||||
(
|
||||
event,
|
||||
_,
|
||||
) = await self.event_creation_handler.create_and_send_nonmember_event(
|
||||
requester, event_dict, txn_id=txn_id
|
||||
)
|
||||
event_id = event.event_id
|
||||
try:
|
||||
if event_type == EventTypes.Member:
|
||||
membership = content.get("membership", None)
|
||||
event_id, _ = await self.room_member_handler.update_membership(
|
||||
requester,
|
||||
target=UserID.from_string(state_key),
|
||||
room_id=room_id,
|
||||
action=membership,
|
||||
content=content,
|
||||
)
|
||||
else:
|
||||
(
|
||||
event,
|
||||
_,
|
||||
) = await self.event_creation_handler.create_and_send_nonmember_event(
|
||||
requester, event_dict, txn_id=txn_id
|
||||
)
|
||||
event_id = event.event_id
|
||||
except ShadowBanError:
|
||||
event_id = "$" + random_string(43)
|
||||
|
||||
set_tag("event_id", event_id)
|
||||
ret = {"event_id": event_id}
|
||||
@@ -248,12 +253,19 @@ class RoomSendEventRestServlet(TransactionRestServlet):
|
||||
if b"ts" in request.args and requester.app_service:
|
||||
event_dict["origin_server_ts"] = parse_integer(request, "ts", 0)
|
||||
|
||||
event, _ = await self.event_creation_handler.create_and_send_nonmember_event(
|
||||
requester, event_dict, txn_id=txn_id
|
||||
)
|
||||
try:
|
||||
(
|
||||
event,
|
||||
_,
|
||||
) = await self.event_creation_handler.create_and_send_nonmember_event(
|
||||
requester, event_dict, txn_id=txn_id
|
||||
)
|
||||
event_id = event.event_id
|
||||
except ShadowBanError:
|
||||
event_id = "$" + random_string(43)
|
||||
|
||||
set_tag("event_id", event.event_id)
|
||||
return 200, {"event_id": event.event_id}
|
||||
set_tag("event_id", event_id)
|
||||
return 200, {"event_id": event_id}
|
||||
|
||||
def on_GET(self, request, room_id, event_type, txn_id):
|
||||
return 200, "Not implemented"
|
||||
@@ -719,17 +731,21 @@ class RoomMembershipRestServlet(TransactionRestServlet):
|
||||
content = {}
|
||||
|
||||
if membership_action == "invite" and self._has_3pid_invite_keys(content):
|
||||
await self.room_member_handler.do_3pid_invite(
|
||||
room_id,
|
||||
requester.user,
|
||||
content["medium"],
|
||||
content["address"],
|
||||
content["id_server"],
|
||||
requester,
|
||||
txn_id,
|
||||
new_room=False,
|
||||
id_access_token=content.get("id_access_token"),
|
||||
)
|
||||
try:
|
||||
await self.room_member_handler.do_3pid_invite(
|
||||
room_id,
|
||||
requester.user,
|
||||
content["medium"],
|
||||
content["address"],
|
||||
content["id_server"],
|
||||
requester,
|
||||
txn_id,
|
||||
new_room=False,
|
||||
id_access_token=content.get("id_access_token"),
|
||||
)
|
||||
except ShadowBanError:
|
||||
# Pretend the request succeeded.
|
||||
pass
|
||||
return 200, {}
|
||||
|
||||
target = requester.user
|
||||
@@ -741,15 +757,19 @@ class RoomMembershipRestServlet(TransactionRestServlet):
|
||||
if "reason" in content:
|
||||
event_content = {"reason": content["reason"]}
|
||||
|
||||
await self.room_member_handler.update_membership(
|
||||
requester=requester,
|
||||
target=target,
|
||||
room_id=room_id,
|
||||
action=membership_action,
|
||||
txn_id=txn_id,
|
||||
third_party_signed=content.get("third_party_signed", None),
|
||||
content=event_content,
|
||||
)
|
||||
try:
|
||||
await self.room_member_handler.update_membership(
|
||||
requester=requester,
|
||||
target=target,
|
||||
room_id=room_id,
|
||||
action=membership_action,
|
||||
txn_id=txn_id,
|
||||
third_party_signed=content.get("third_party_signed", None),
|
||||
content=event_content,
|
||||
)
|
||||
except ShadowBanError:
|
||||
# Pretend the request succeeded.
|
||||
pass
|
||||
|
||||
return_value = {}
|
||||
|
||||
@@ -787,20 +807,27 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
event, _ = await self.event_creation_handler.create_and_send_nonmember_event(
|
||||
requester,
|
||||
{
|
||||
"type": EventTypes.Redaction,
|
||||
"content": content,
|
||||
"room_id": room_id,
|
||||
"sender": requester.user.to_string(),
|
||||
"redacts": event_id,
|
||||
},
|
||||
txn_id=txn_id,
|
||||
)
|
||||
try:
|
||||
(
|
||||
event,
|
||||
_,
|
||||
) = await self.event_creation_handler.create_and_send_nonmember_event(
|
||||
requester,
|
||||
{
|
||||
"type": EventTypes.Redaction,
|
||||
"content": content,
|
||||
"room_id": room_id,
|
||||
"sender": requester.user.to_string(),
|
||||
"redacts": event_id,
|
||||
},
|
||||
txn_id=txn_id,
|
||||
)
|
||||
event_id = event.event_id
|
||||
except ShadowBanError:
|
||||
event_id = "$" + random_string(43)
|
||||
|
||||
set_tag("event_id", event.event_id)
|
||||
return 200, {"event_id": event.event_id}
|
||||
set_tag("event_id", event_id)
|
||||
return 200, {"event_id": event_id}
|
||||
|
||||
def on_PUT(self, request, room_id, event_id, txn_id):
|
||||
set_tag("txn_id", txn_id)
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import random
|
||||
import re
|
||||
from http import HTTPStatus
|
||||
from typing import TYPE_CHECKING
|
||||
@@ -122,6 +123,9 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
|
||||
if self.config.request_token_inhibit_3pid_errors:
|
||||
# Make the client think the operation succeeded. See the rationale in the
|
||||
# comments for request_token_inhibit_3pid_errors.
|
||||
# Also wait for some random amount of time between 100ms and 1s to make it
|
||||
# look like we did something.
|
||||
await self.hs.clock.sleep(random.randint(1, 10) / 10)
|
||||
return 200, {"sid": random_string(16)}
|
||||
|
||||
raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND)
|
||||
@@ -491,6 +495,9 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
|
||||
if self.config.request_token_inhibit_3pid_errors:
|
||||
# Make the client think the operation succeeded. See the rationale in the
|
||||
# comments for request_token_inhibit_3pid_errors.
|
||||
# Also wait for some random amount of time between 100ms and 1s to make it
|
||||
# look like we did something.
|
||||
await self.hs.clock.sleep(random.randint(1, 10) / 10)
|
||||
return 200, {"sid": random_string(16)}
|
||||
|
||||
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
|
||||
@@ -563,6 +570,9 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
|
||||
if self.hs.config.request_token_inhibit_3pid_errors:
|
||||
# Make the client think the operation succeeded. See the rationale in the
|
||||
# comments for request_token_inhibit_3pid_errors.
|
||||
# Also wait for some random amount of time between 100ms and 1s to make it
|
||||
# look like we did something.
|
||||
await self.hs.clock.sleep(random.randint(1, 10) / 10)
|
||||
return 200, {"sid": random_string(16)}
|
||||
|
||||
raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE)
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
|
||||
import hmac
|
||||
import logging
|
||||
import random
|
||||
import re
|
||||
from typing import List, Union
|
||||
|
||||
@@ -133,6 +134,9 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
|
||||
if self.hs.config.request_token_inhibit_3pid_errors:
|
||||
# Make the client think the operation succeeded. See the rationale in the
|
||||
# comments for request_token_inhibit_3pid_errors.
|
||||
# Also wait for some random amount of time between 100ms and 1s to make it
|
||||
# look like we did something.
|
||||
await self.hs.clock.sleep(random.randint(1, 10) / 10)
|
||||
return 200, {"sid": random_string(16)}
|
||||
|
||||
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
|
||||
@@ -207,6 +211,9 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
|
||||
if self.hs.config.request_token_inhibit_3pid_errors:
|
||||
# Make the client think the operation succeeded. See the rationale in the
|
||||
# comments for request_token_inhibit_3pid_errors.
|
||||
# Also wait for some random amount of time between 100ms and 1s to make it
|
||||
# look like we did something.
|
||||
await self.hs.clock.sleep(random.randint(1, 10) / 10)
|
||||
return 200, {"sid": random_string(16)}
|
||||
|
||||
raise SynapseError(
|
||||
@@ -658,6 +665,10 @@ class RegisterRestServlet(RestServlet):
|
||||
Codes.THREEPID_IN_USE,
|
||||
)
|
||||
|
||||
entries = await self.store.get_user_agents_ips_to_ui_auth_session(
|
||||
session_id
|
||||
)
|
||||
|
||||
registered_user_id = await self.registration_handler.register_user(
|
||||
localpart=desired_username,
|
||||
password_hash=password_hash,
|
||||
@@ -665,6 +676,7 @@ class RegisterRestServlet(RestServlet):
|
||||
default_display_name=desired_display_name,
|
||||
threepid=threepid,
|
||||
address=client_addr,
|
||||
user_agent_ips=entries,
|
||||
)
|
||||
# Necessary due to auth checks prior to the threepid being
|
||||
# written to the db
|
||||
|
||||
@@ -22,7 +22,7 @@ any time to reflect changes in the MSC.
|
||||
import logging
|
||||
|
||||
from synapse.api.constants import EventTypes, RelationTypes
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.api.errors import ShadowBanError, SynapseError
|
||||
from synapse.http.servlet import (
|
||||
RestServlet,
|
||||
parse_integer,
|
||||
@@ -35,6 +35,7 @@ from synapse.storage.relations import (
|
||||
PaginationChunk,
|
||||
RelationPaginationToken,
|
||||
)
|
||||
from synapse.util.stringutils import random_string
|
||||
|
||||
from ._base import client_patterns
|
||||
|
||||
@@ -111,11 +112,18 @@ class RelationSendServlet(RestServlet):
|
||||
"sender": requester.user.to_string(),
|
||||
}
|
||||
|
||||
event, _ = await self.event_creation_handler.create_and_send_nonmember_event(
|
||||
requester, event_dict=event_dict, txn_id=txn_id
|
||||
)
|
||||
try:
|
||||
(
|
||||
event,
|
||||
_,
|
||||
) = await self.event_creation_handler.create_and_send_nonmember_event(
|
||||
requester, event_dict=event_dict, txn_id=txn_id
|
||||
)
|
||||
event_id = event.event_id
|
||||
except ShadowBanError:
|
||||
event_id = "$" + random_string(43)
|
||||
|
||||
return 200, {"event_id": event.event_id}
|
||||
return 200, {"event_id": event_id}
|
||||
|
||||
|
||||
class RelationPaginationServlet(RestServlet):
|
||||
|
||||
@@ -15,13 +15,14 @@
|
||||
|
||||
import logging
|
||||
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.api.errors import Codes, ShadowBanError, SynapseError
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
||||
from synapse.http.servlet import (
|
||||
RestServlet,
|
||||
assert_params_in_dict,
|
||||
parse_json_object_from_request,
|
||||
)
|
||||
from synapse.util import stringutils
|
||||
|
||||
from ._base import client_patterns
|
||||
|
||||
@@ -62,7 +63,6 @@ class RoomUpgradeRestServlet(RestServlet):
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
assert_params_in_dict(content, ("new_version",))
|
||||
new_version = content["new_version"]
|
||||
|
||||
new_version = KNOWN_ROOM_VERSIONS.get(content["new_version"])
|
||||
if new_version is None:
|
||||
@@ -72,9 +72,13 @@ class RoomUpgradeRestServlet(RestServlet):
|
||||
Codes.UNSUPPORTED_ROOM_VERSION,
|
||||
)
|
||||
|
||||
new_room_id = await self._room_creation_handler.upgrade_room(
|
||||
requester, room_id, new_version
|
||||
)
|
||||
try:
|
||||
new_room_id = await self._room_creation_handler.upgrade_room(
|
||||
requester, room_id, new_version
|
||||
)
|
||||
except ShadowBanError:
|
||||
# Generate a random room ID.
|
||||
new_room_id = stringutils.random_string(18)
|
||||
|
||||
ret = {"replacement_room": new_room_id}
|
||||
|
||||
|
||||
@@ -13,12 +13,12 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
from twisted.web.resource import Resource
|
||||
|
||||
from synapse.http.server import set_cors_headers
|
||||
from synapse.util import json_encoder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -67,4 +67,4 @@ class WellKnownResource(Resource):
|
||||
|
||||
logger.debug("returning: %s", r)
|
||||
request.setHeader(b"Content-Type", b"application/json")
|
||||
return json.dumps(r).encode("utf-8")
|
||||
return json_encoder.encode(r).encode("utf-8")
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from enum import Enum
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
@@ -25,6 +26,16 @@ if MYPY:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RegistrationBehaviour(Enum):
|
||||
"""
|
||||
Enum to define whether a registration request should allowed, denied, or shadow-banned.
|
||||
"""
|
||||
|
||||
ALLOW = "allow"
|
||||
SHADOW_BAN = "shadow_ban"
|
||||
DENY = "deny"
|
||||
|
||||
|
||||
class SpamCheckerApi(object):
|
||||
"""A proxy object that gets passed to spam checkers so they can get
|
||||
access to rooms and other relevant information.
|
||||
|
||||
+122
-70
@@ -16,11 +16,22 @@
|
||||
|
||||
import logging
|
||||
from collections import namedtuple
|
||||
from typing import Awaitable, Dict, Iterable, List, Optional, Set
|
||||
from typing import (
|
||||
Awaitable,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Union,
|
||||
overload,
|
||||
)
|
||||
|
||||
import attr
|
||||
from frozendict import frozendict
|
||||
from prometheus_client import Histogram
|
||||
from typing_extensions import Literal
|
||||
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions
|
||||
@@ -30,7 +41,7 @@ from synapse.logging.utils import log_function
|
||||
from synapse.state import v1, v2
|
||||
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
|
||||
from synapse.storage.roommember import ProfileInfo
|
||||
from synapse.types import StateMap
|
||||
from synapse.types import Collection, StateMap
|
||||
from synapse.util import Clock
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
from synapse.util.caches.expiringcache import ExpiringCache
|
||||
@@ -68,8 +79,14 @@ def _gen_state_id():
|
||||
class _StateCacheEntry(object):
|
||||
__slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"]
|
||||
|
||||
def __init__(self, state, state_group, prev_group=None, delta_ids=None):
|
||||
# dict[(str, str), str] map from (type, state_key) to event_id
|
||||
def __init__(
|
||||
self,
|
||||
state: StateMap[str],
|
||||
state_group: Optional[int],
|
||||
prev_group: Optional[int] = None,
|
||||
delta_ids: Optional[StateMap[str]] = None,
|
||||
):
|
||||
# A map from (type, state_key) to event_id.
|
||||
self.state = frozendict(state)
|
||||
|
||||
# the ID of a state group if one and only one is involved.
|
||||
@@ -107,24 +124,49 @@ class StateHandler(object):
|
||||
self.hs = hs
|
||||
self._state_resolution_handler = hs.get_state_resolution_handler()
|
||||
|
||||
@overload
|
||||
async def get_current_state(
|
||||
self, room_id, event_type=None, state_key="", latest_event_ids=None
|
||||
):
|
||||
""" Retrieves the current state for the room. This is done by
|
||||
self,
|
||||
room_id: str,
|
||||
event_type: Literal[None] = None,
|
||||
state_key: str = "",
|
||||
latest_event_ids: Optional[List[str]] = None,
|
||||
) -> StateMap[EventBase]:
|
||||
...
|
||||
|
||||
@overload
|
||||
async def get_current_state(
|
||||
self,
|
||||
room_id: str,
|
||||
event_type: str,
|
||||
state_key: str = "",
|
||||
latest_event_ids: Optional[List[str]] = None,
|
||||
) -> Optional[EventBase]:
|
||||
...
|
||||
|
||||
async def get_current_state(
|
||||
self,
|
||||
room_id: str,
|
||||
event_type: Optional[str] = None,
|
||||
state_key: str = "",
|
||||
latest_event_ids: Optional[List[str]] = None,
|
||||
) -> Union[Optional[EventBase], StateMap[EventBase]]:
|
||||
"""Retrieves the current state for the room. This is done by
|
||||
calling `get_latest_events_in_room` to get the leading edges of the
|
||||
event graph and then resolving any of the state conflicts.
|
||||
|
||||
This is equivalent to getting the state of an event that were to send
|
||||
next before receiving any new events.
|
||||
|
||||
If `event_type` is specified, then the method returns only the one
|
||||
event (or None) with that `event_type` and `state_key`.
|
||||
|
||||
Returns:
|
||||
map from (type, state_key) to event
|
||||
If `event_type` is specified, then the method returns only the one
|
||||
event (or None) with that `event_type` and `state_key`.
|
||||
|
||||
Otherwise, a map from (type, state_key) to event.
|
||||
"""
|
||||
if not latest_event_ids:
|
||||
latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
|
||||
assert latest_event_ids is not None
|
||||
|
||||
logger.debug("calling resolve_state_groups from get_current_state")
|
||||
ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
|
||||
@@ -140,34 +182,30 @@ class StateHandler(object):
|
||||
state_map = await self.store.get_events(
|
||||
list(state.values()), get_prev_content=False
|
||||
)
|
||||
state = {
|
||||
return {
|
||||
key: state_map[e_id] for key, e_id in state.items() if e_id in state_map
|
||||
}
|
||||
|
||||
return state
|
||||
|
||||
async def get_current_state_ids(self, room_id, latest_event_ids=None):
|
||||
async def get_current_state_ids(
|
||||
self, room_id: str, latest_event_ids: Optional[Iterable[str]] = None
|
||||
) -> StateMap[str]:
|
||||
"""Get the current state, or the state at a set of events, for a room
|
||||
|
||||
Args:
|
||||
room_id (str):
|
||||
|
||||
latest_event_ids (iterable[str]|None): if given, the forward
|
||||
extremities to resolve. If None, we look them up from the
|
||||
database (via a cache)
|
||||
room_id:
|
||||
latest_event_ids: if given, the forward extremities to resolve. If
|
||||
None, we look them up from the database (via a cache).
|
||||
|
||||
Returns:
|
||||
Deferred[dict[(str, str), str)]]: the state dict, mapping from
|
||||
(event_type, state_key) -> event_id
|
||||
the state dict, mapping from (event_type, state_key) -> event_id
|
||||
"""
|
||||
if not latest_event_ids:
|
||||
latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
|
||||
assert latest_event_ids is not None
|
||||
|
||||
logger.debug("calling resolve_state_groups from get_current_state_ids")
|
||||
ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
|
||||
state = ret.state
|
||||
|
||||
return state
|
||||
return dict(ret.state)
|
||||
|
||||
async def get_current_users_in_room(
|
||||
self, room_id: str, latest_event_ids: Optional[List[str]] = None
|
||||
@@ -183,32 +221,34 @@ class StateHandler(object):
|
||||
"""
|
||||
if not latest_event_ids:
|
||||
latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
|
||||
assert latest_event_ids is not None
|
||||
|
||||
logger.debug("calling resolve_state_groups from get_current_users_in_room")
|
||||
entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
|
||||
joined_users = await self.store.get_joined_users_from_state(room_id, entry)
|
||||
return joined_users
|
||||
return await self.store.get_joined_users_from_state(room_id, entry)
|
||||
|
||||
async def get_current_hosts_in_room(self, room_id):
|
||||
async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
|
||||
event_ids = await self.store.get_latest_event_ids_in_room(room_id)
|
||||
return await self.get_hosts_in_room_at_events(room_id, event_ids)
|
||||
|
||||
async def get_hosts_in_room_at_events(self, room_id, event_ids):
|
||||
async def get_hosts_in_room_at_events(
|
||||
self, room_id: str, event_ids: List[str]
|
||||
) -> Set[str]:
|
||||
"""Get the hosts that were in a room at the given event ids
|
||||
|
||||
Args:
|
||||
room_id (str):
|
||||
event_ids (list[str]):
|
||||
room_id:
|
||||
event_ids:
|
||||
|
||||
Returns:
|
||||
Deferred[list[str]]: the hosts in the room at the given events
|
||||
The hosts in the room at the given events
|
||||
"""
|
||||
entry = await self.resolve_state_groups_for_events(room_id, event_ids)
|
||||
joined_hosts = await self.store.get_joined_hosts(room_id, entry)
|
||||
return joined_hosts
|
||||
return await self.store.get_joined_hosts(room_id, entry)
|
||||
|
||||
async def compute_event_context(
|
||||
self, event: EventBase, old_state: Optional[Iterable[EventBase]] = None
|
||||
):
|
||||
) -> EventContext:
|
||||
"""Build an EventContext structure for the event.
|
||||
|
||||
This works out what the current state should be for the event, and
|
||||
@@ -221,7 +261,7 @@ class StateHandler(object):
|
||||
when receiving an event from federation where we don't have the
|
||||
prev events for, e.g. when backfilling.
|
||||
Returns:
|
||||
synapse.events.snapshot.EventContext:
|
||||
The event context.
|
||||
"""
|
||||
|
||||
if event.internal_metadata.is_outlier():
|
||||
@@ -275,7 +315,7 @@ class StateHandler(object):
|
||||
event.room_id, event.prev_event_ids()
|
||||
)
|
||||
|
||||
state_ids_before_event = entry.state
|
||||
state_ids_before_event = dict(entry.state)
|
||||
state_group_before_event = entry.state_group
|
||||
state_group_before_event_prev_group = entry.prev_group
|
||||
deltas_to_state_group_before_event = entry.delta_ids
|
||||
@@ -346,19 +386,18 @@ class StateHandler(object):
|
||||
)
|
||||
|
||||
@measure_func()
|
||||
async def resolve_state_groups_for_events(self, room_id, event_ids):
|
||||
async def resolve_state_groups_for_events(
|
||||
self, room_id: str, event_ids: Iterable[str]
|
||||
) -> _StateCacheEntry:
|
||||
""" Given a list of event_ids this method fetches the state at each
|
||||
event, resolves conflicts between them and returns them.
|
||||
|
||||
Args:
|
||||
room_id (str)
|
||||
event_ids (list[str])
|
||||
explicit_room_version (str|None): If set uses the the given room
|
||||
version to choose the resolution algorithm. If None, then
|
||||
checks the database for room version.
|
||||
room_id
|
||||
event_ids
|
||||
|
||||
Returns:
|
||||
Deferred[_StateCacheEntry]: resolved state
|
||||
The resolved state
|
||||
"""
|
||||
logger.debug("resolve_state_groups event_ids %s", event_ids)
|
||||
|
||||
@@ -394,7 +433,12 @@ class StateHandler(object):
|
||||
)
|
||||
return result
|
||||
|
||||
async def resolve_events(self, room_version, state_sets, event):
|
||||
async def resolve_events(
|
||||
self,
|
||||
room_version: str,
|
||||
state_sets: Collection[Iterable[EventBase]],
|
||||
event: EventBase,
|
||||
) -> StateMap[EventBase]:
|
||||
logger.info(
|
||||
"Resolving state for %s with %d groups", event.room_id, len(state_sets)
|
||||
)
|
||||
@@ -414,9 +458,7 @@ class StateHandler(object):
|
||||
state_res_store=StateResolutionStore(self.store),
|
||||
)
|
||||
|
||||
new_state = {key: state_map[ev_id] for key, ev_id in new_state.items()}
|
||||
|
||||
return new_state
|
||||
return {key: state_map[ev_id] for key, ev_id in new_state.items()}
|
||||
|
||||
|
||||
class StateResolutionHandler(object):
|
||||
@@ -444,7 +486,12 @@ class StateResolutionHandler(object):
|
||||
|
||||
@log_function
|
||||
async def resolve_state_groups(
|
||||
self, room_id, room_version, state_groups_ids, event_map, state_res_store
|
||||
self,
|
||||
room_id: str,
|
||||
room_version: str,
|
||||
state_groups_ids: Dict[int, StateMap[str]],
|
||||
event_map: Optional[Dict[str, EventBase]],
|
||||
state_res_store: "StateResolutionStore",
|
||||
):
|
||||
"""Resolves conflicts between a set of state groups
|
||||
|
||||
@@ -452,13 +499,13 @@ class StateResolutionHandler(object):
|
||||
not be called for a single state group
|
||||
|
||||
Args:
|
||||
room_id (str): room we are resolving for (used for logging and sanity checks)
|
||||
room_version (str): version of the room
|
||||
state_groups_ids (dict[int, dict[(str, str), str]]):
|
||||
map from state group id to the state in that state group
|
||||
room_id: room we are resolving for (used for logging and sanity checks)
|
||||
room_version: version of the room
|
||||
state_groups_ids:
|
||||
A map from state group id to the state in that state group
|
||||
(where 'state' is a map from state key to event id)
|
||||
|
||||
event_map(dict[str,FrozenEvent]|None):
|
||||
event_map:
|
||||
a dict from event_id to event, for any events that we happen to
|
||||
have in flight (eg, those currently being persisted). This will be
|
||||
used as a starting point fof finding the state we need; any missing
|
||||
@@ -466,10 +513,10 @@ class StateResolutionHandler(object):
|
||||
|
||||
If None, all events will be fetched via state_res_store.
|
||||
|
||||
state_res_store (StateResolutionStore)
|
||||
state_res_store
|
||||
|
||||
Returns:
|
||||
_StateCacheEntry: resolved state
|
||||
The resolved state
|
||||
"""
|
||||
logger.debug("resolve_state_groups state_groups %s", state_groups_ids.keys())
|
||||
|
||||
@@ -530,21 +577,22 @@ class StateResolutionHandler(object):
|
||||
return cache
|
||||
|
||||
|
||||
def _make_state_cache_entry(new_state, state_groups_ids):
|
||||
def _make_state_cache_entry(
|
||||
new_state: StateMap[str], state_groups_ids: Dict[int, StateMap[str]]
|
||||
) -> _StateCacheEntry:
|
||||
"""Given a resolved state, and a set of input state groups, pick one to base
|
||||
a new state group on (if any), and return an appropriately-constructed
|
||||
_StateCacheEntry.
|
||||
|
||||
Args:
|
||||
new_state (dict[(str, str), str]): resolved state map (mapping from
|
||||
(type, state_key) to event_id)
|
||||
new_state: resolved state map (mapping from (type, state_key) to event_id)
|
||||
|
||||
state_groups_ids (dict[int, dict[(str, str), str]]):
|
||||
map from state group id to the state in that state group
|
||||
(where 'state' is a map from state key to event id)
|
||||
state_groups_ids:
|
||||
map from state group id to the state in that state group (where
|
||||
'state' is a map from state key to event id)
|
||||
|
||||
Returns:
|
||||
_StateCacheEntry
|
||||
The cache entry.
|
||||
"""
|
||||
# if the new state matches any of the input state groups, we can
|
||||
# use that state group again. Otherwise we will generate a state_id
|
||||
@@ -585,7 +633,7 @@ def resolve_events_with_store(
|
||||
clock: Clock,
|
||||
room_id: str,
|
||||
room_version: str,
|
||||
state_sets: List[StateMap[str]],
|
||||
state_sets: Sequence[StateMap[str]],
|
||||
event_map: Optional[Dict[str, EventBase]],
|
||||
state_res_store: "StateResolutionStore",
|
||||
) -> Awaitable[StateMap[str]]:
|
||||
@@ -633,15 +681,17 @@ class StateResolutionStore(object):
|
||||
|
||||
store = attr.ib()
|
||||
|
||||
def get_events(self, event_ids, allow_rejected=False):
|
||||
def get_events(
|
||||
self, event_ids: Iterable[str], allow_rejected: bool = False
|
||||
) -> Awaitable[Dict[str, EventBase]]:
|
||||
"""Get events from the database
|
||||
|
||||
Args:
|
||||
event_ids (list): The event_ids of the events to fetch
|
||||
allow_rejected (bool): If True return rejected events.
|
||||
event_ids: The event_ids of the events to fetch
|
||||
allow_rejected: If True return rejected events.
|
||||
|
||||
Returns:
|
||||
Awaitable[dict[str, FrozenEvent]]: Dict from event_id to event.
|
||||
An awaitable which resolves to a dict from event_id to event.
|
||||
"""
|
||||
|
||||
return self.store.get_events(
|
||||
@@ -651,7 +701,9 @@ class StateResolutionStore(object):
|
||||
allow_rejected=allow_rejected,
|
||||
)
|
||||
|
||||
def get_auth_chain_difference(self, state_sets: List[Set[str]]):
|
||||
def get_auth_chain_difference(
|
||||
self, state_sets: List[Set[str]]
|
||||
) -> Awaitable[Set[str]]:
|
||||
"""Given sets of state events figure out the auth chain difference (as
|
||||
per state res v2 algorithm).
|
||||
|
||||
@@ -660,7 +712,7 @@ class StateResolutionStore(object):
|
||||
chain.
|
||||
|
||||
Returns:
|
||||
Deferred[Set[str]]: Set of event IDs.
|
||||
An awaitable that resolves to a set of event IDs.
|
||||
"""
|
||||
|
||||
return self.store.get_auth_chain_difference(state_sets)
|
||||
|
||||
+59
-28
@@ -15,7 +15,17 @@
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
from typing import Awaitable, Callable, Dict, List, Optional
|
||||
from typing import (
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
)
|
||||
|
||||
from synapse import event_auth
|
||||
from synapse.api.constants import EventTypes
|
||||
@@ -32,10 +42,10 @@ POWER_KEY = (EventTypes.PowerLevels, "")
|
||||
|
||||
async def resolve_events_with_store(
|
||||
room_id: str,
|
||||
state_sets: List[StateMap[str]],
|
||||
state_sets: Sequence[StateMap[str]],
|
||||
event_map: Optional[Dict[str, EventBase]],
|
||||
state_map_factory: Callable[[List[str]], Awaitable],
|
||||
):
|
||||
state_map_factory: Callable[[Iterable[str]], Awaitable[Dict[str, EventBase]]],
|
||||
) -> StateMap[str]:
|
||||
"""
|
||||
Args:
|
||||
room_id: the room we are working in
|
||||
@@ -56,8 +66,7 @@ async def resolve_events_with_store(
|
||||
an Awaitable that resolves to a dict of event_id to event.
|
||||
|
||||
Returns:
|
||||
Deferred[dict[(str, str), str]]:
|
||||
a map from (type, state_key) to event_id.
|
||||
A map from (type, state_key) to event_id.
|
||||
"""
|
||||
if len(state_sets) == 1:
|
||||
return state_sets[0]
|
||||
@@ -75,8 +84,8 @@ async def resolve_events_with_store(
|
||||
"Asking for %d/%d conflicted events", len(needed_events), needed_event_count
|
||||
)
|
||||
|
||||
# dict[str, FrozenEvent]: a map from state event id to event. Only includes
|
||||
# the state events which are in conflict (and those in event_map)
|
||||
# A map from state event id to event. Only includes the state events which
|
||||
# are in conflict (and those in event_map).
|
||||
state_map = await state_map_factory(needed_events)
|
||||
if event_map is not None:
|
||||
state_map.update(event_map)
|
||||
@@ -91,8 +100,6 @@ async def resolve_events_with_store(
|
||||
|
||||
# get the ids of the auth events which allow us to authenticate the
|
||||
# conflicted state, picking only from the unconflicting state.
|
||||
#
|
||||
# dict[(str, str), str]: a map from state key to event id
|
||||
auth_events = _create_auth_events_from_maps(
|
||||
unconflicted_state, conflicted_state, state_map
|
||||
)
|
||||
@@ -122,29 +129,30 @@ async def resolve_events_with_store(
|
||||
)
|
||||
|
||||
|
||||
def _seperate(state_sets):
|
||||
def _seperate(
|
||||
state_sets: Iterable[StateMap[str]],
|
||||
) -> Tuple[StateMap[str], StateMap[Set[str]]]:
|
||||
"""Takes the state_sets and figures out which keys are conflicted and
|
||||
which aren't. i.e., which have multiple different event_ids associated
|
||||
with them in different state sets.
|
||||
|
||||
Args:
|
||||
state_sets(iterable[dict[(str, str), str]]):
|
||||
state_sets:
|
||||
List of dicts of (type, state_key) -> event_id, which are the
|
||||
different state groups to resolve.
|
||||
|
||||
Returns:
|
||||
(dict[(str, str), str], dict[(str, str), set[str]]):
|
||||
A tuple of (unconflicted_state, conflicted_state), where:
|
||||
A tuple of (unconflicted_state, conflicted_state), where:
|
||||
|
||||
unconflicted_state is a dict mapping (type, state_key)->event_id
|
||||
for unconflicted state keys.
|
||||
unconflicted_state is a dict mapping (type, state_key)->event_id
|
||||
for unconflicted state keys.
|
||||
|
||||
conflicted_state is a dict mapping (type, state_key) to a set of
|
||||
event ids for conflicted state keys.
|
||||
conflicted_state is a dict mapping (type, state_key) to a set of
|
||||
event ids for conflicted state keys.
|
||||
"""
|
||||
state_set_iterator = iter(state_sets)
|
||||
unconflicted_state = dict(next(state_set_iterator))
|
||||
conflicted_state = {}
|
||||
conflicted_state = {} # type: StateMap[Set[str]]
|
||||
|
||||
for state_set in state_set_iterator:
|
||||
for key, value in state_set.items():
|
||||
@@ -171,7 +179,21 @@ def _seperate(state_sets):
|
||||
return unconflicted_state, conflicted_state
|
||||
|
||||
|
||||
def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map):
|
||||
def _create_auth_events_from_maps(
|
||||
unconflicted_state: StateMap[str],
|
||||
conflicted_state: StateMap[Set[str]],
|
||||
state_map: Dict[str, EventBase],
|
||||
) -> StateMap[str]:
|
||||
"""
|
||||
|
||||
Args:
|
||||
unconflicted_state: The unconflicted state map.
|
||||
conflicted_state: The conflicted state map.
|
||||
state_map:
|
||||
|
||||
Returns:
|
||||
A map from state key to event id.
|
||||
"""
|
||||
auth_events = {}
|
||||
for event_ids in conflicted_state.values():
|
||||
for event_id in event_ids:
|
||||
@@ -179,14 +201,17 @@ def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_ma
|
||||
keys = event_auth.auth_types_for_event(state_map[event_id])
|
||||
for key in keys:
|
||||
if key not in auth_events:
|
||||
event_id = unconflicted_state.get(key, None)
|
||||
if event_id:
|
||||
auth_events[key] = event_id
|
||||
auth_event_id = unconflicted_state.get(key, None)
|
||||
if auth_event_id:
|
||||
auth_events[key] = auth_event_id
|
||||
return auth_events
|
||||
|
||||
|
||||
def _resolve_with_state(
|
||||
unconflicted_state_ids, conflicted_state_ids, auth_event_ids, state_map
|
||||
unconflicted_state_ids: StateMap[str],
|
||||
conflicted_state_ids: StateMap[Set[str]],
|
||||
auth_event_ids: StateMap[str],
|
||||
state_map: Dict[str, EventBase],
|
||||
):
|
||||
conflicted_state = {}
|
||||
for key, event_ids in conflicted_state_ids.items():
|
||||
@@ -215,7 +240,9 @@ def _resolve_with_state(
|
||||
return new_state
|
||||
|
||||
|
||||
def _resolve_state_events(conflicted_state, auth_events):
|
||||
def _resolve_state_events(
|
||||
conflicted_state: StateMap[List[EventBase]], auth_events: StateMap[EventBase]
|
||||
) -> StateMap[EventBase]:
|
||||
""" This is where we actually decide which of the conflicted state to
|
||||
use.
|
||||
|
||||
@@ -255,7 +282,9 @@ def _resolve_state_events(conflicted_state, auth_events):
|
||||
return resolved_state
|
||||
|
||||
|
||||
def _resolve_auth_events(events, auth_events):
|
||||
def _resolve_auth_events(
|
||||
events: List[EventBase], auth_events: StateMap[EventBase]
|
||||
) -> EventBase:
|
||||
reverse = list(reversed(_ordered_events(events)))
|
||||
|
||||
auth_keys = {
|
||||
@@ -289,7 +318,9 @@ def _resolve_auth_events(events, auth_events):
|
||||
return event
|
||||
|
||||
|
||||
def _resolve_normal_events(events, auth_events):
|
||||
def _resolve_normal_events(
|
||||
events: List[EventBase], auth_events: StateMap[EventBase]
|
||||
) -> EventBase:
|
||||
for event in _ordered_events(events):
|
||||
try:
|
||||
# The signatures have already been checked at this point
|
||||
@@ -309,7 +340,7 @@ def _resolve_normal_events(events, auth_events):
|
||||
return event
|
||||
|
||||
|
||||
def _ordered_events(events):
|
||||
def _ordered_events(events: Iterable[EventBase]) -> List[EventBase]:
|
||||
def key_func(e):
|
||||
# we have to use utf-8 rather than ascii here because it turns out we allow
|
||||
# people to send us events with non-ascii event IDs :/
|
||||
|
||||
+167
-88
@@ -16,7 +16,21 @@
|
||||
import heapq
|
||||
import itertools
|
||||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Generator,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
overload,
|
||||
)
|
||||
|
||||
from typing_extensions import Literal
|
||||
|
||||
import synapse.state
|
||||
from synapse import event_auth
|
||||
@@ -40,10 +54,10 @@ async def resolve_events_with_store(
|
||||
clock: Clock,
|
||||
room_id: str,
|
||||
room_version: str,
|
||||
state_sets: List[StateMap[str]],
|
||||
state_sets: Sequence[StateMap[str]],
|
||||
event_map: Optional[Dict[str, EventBase]],
|
||||
state_res_store: "synapse.state.StateResolutionStore",
|
||||
):
|
||||
) -> StateMap[str]:
|
||||
"""Resolves the state using the v2 state resolution algorithm
|
||||
|
||||
Args:
|
||||
@@ -63,8 +77,7 @@ async def resolve_events_with_store(
|
||||
state_res_store:
|
||||
|
||||
Returns:
|
||||
Deferred[dict[(str, str), str]]:
|
||||
a map from (type, state_key) to event_id.
|
||||
A map from (type, state_key) to event_id.
|
||||
"""
|
||||
|
||||
logger.debug("Computing conflicted state")
|
||||
@@ -171,18 +184,23 @@ async def resolve_events_with_store(
|
||||
return resolved_state
|
||||
|
||||
|
||||
async def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store):
|
||||
async def _get_power_level_for_sender(
|
||||
room_id: str,
|
||||
event_id: str,
|
||||
event_map: Dict[str, EventBase],
|
||||
state_res_store: "synapse.state.StateResolutionStore",
|
||||
) -> int:
|
||||
"""Return the power level of the sender of the given event according to
|
||||
their auth events.
|
||||
|
||||
Args:
|
||||
room_id (str)
|
||||
event_id (str)
|
||||
event_map (dict[str,FrozenEvent])
|
||||
state_res_store (StateResolutionStore)
|
||||
room_id
|
||||
event_id
|
||||
event_map
|
||||
state_res_store
|
||||
|
||||
Returns:
|
||||
Deferred[int]
|
||||
The power level.
|
||||
"""
|
||||
event = await _get_event(room_id, event_id, event_map, state_res_store)
|
||||
|
||||
@@ -217,17 +235,21 @@ async def _get_power_level_for_sender(room_id, event_id, event_map, state_res_st
|
||||
return int(level)
|
||||
|
||||
|
||||
async def _get_auth_chain_difference(state_sets, event_map, state_res_store):
|
||||
async def _get_auth_chain_difference(
|
||||
state_sets: Sequence[StateMap[str]],
|
||||
event_map: Dict[str, EventBase],
|
||||
state_res_store: "synapse.state.StateResolutionStore",
|
||||
) -> Set[str]:
|
||||
"""Compare the auth chains of each state set and return the set of events
|
||||
that only appear in some but not all of the auth chains.
|
||||
|
||||
Args:
|
||||
state_sets (list)
|
||||
event_map (dict[str,FrozenEvent])
|
||||
state_res_store (StateResolutionStore)
|
||||
state_sets
|
||||
event_map
|
||||
state_res_store
|
||||
|
||||
Returns:
|
||||
Deferred[set[str]]: Set of event IDs
|
||||
Set of event IDs
|
||||
"""
|
||||
|
||||
difference = await state_res_store.get_auth_chain_difference(
|
||||
@@ -237,17 +259,19 @@ async def _get_auth_chain_difference(state_sets, event_map, state_res_store):
|
||||
return difference
|
||||
|
||||
|
||||
def _seperate(state_sets):
|
||||
def _seperate(
|
||||
state_sets: Iterable[StateMap[str]],
|
||||
) -> Tuple[StateMap[str], StateMap[Set[str]]]:
|
||||
"""Return the unconflicted and conflicted state. This is different than in
|
||||
the original algorithm, as this defines a key to be conflicted if one of
|
||||
the state sets doesn't have that key.
|
||||
|
||||
Args:
|
||||
state_sets (list)
|
||||
state_sets
|
||||
|
||||
Returns:
|
||||
tuple[dict, dict]: A tuple of unconflicted and conflicted state. The
|
||||
conflicted state dict is a map from type/state_key to set of event IDs
|
||||
A tuple of unconflicted and conflicted state. The conflicted state dict
|
||||
is a map from type/state_key to set of event IDs
|
||||
"""
|
||||
unconflicted_state = {}
|
||||
conflicted_state = {}
|
||||
@@ -260,18 +284,20 @@ def _seperate(state_sets):
|
||||
event_ids.discard(None)
|
||||
conflicted_state[key] = event_ids
|
||||
|
||||
return unconflicted_state, conflicted_state
|
||||
# mypy doesn't understand that discarding None above means that conflicted
|
||||
# state is StateMap[Set[str]], not StateMap[Set[Optional[Str]]].
|
||||
return unconflicted_state, conflicted_state # type: ignore
|
||||
|
||||
|
||||
def _is_power_event(event):
|
||||
def _is_power_event(event: EventBase) -> bool:
|
||||
"""Return whether or not the event is a "power event", as defined by the
|
||||
v2 state resolution algorithm
|
||||
|
||||
Args:
|
||||
event (FrozenEvent)
|
||||
event
|
||||
|
||||
Returns:
|
||||
boolean
|
||||
True if the event is a power event.
|
||||
"""
|
||||
if (event.type, event.state_key) in (
|
||||
(EventTypes.PowerLevels, ""),
|
||||
@@ -288,19 +314,23 @@ def _is_power_event(event):
|
||||
|
||||
|
||||
async def _add_event_and_auth_chain_to_graph(
|
||||
graph, room_id, event_id, event_map, state_res_store, auth_diff
|
||||
):
|
||||
graph: Dict[str, Set[str]],
|
||||
room_id: str,
|
||||
event_id: str,
|
||||
event_map: Dict[str, EventBase],
|
||||
state_res_store: "synapse.state.StateResolutionStore",
|
||||
auth_diff: Set[str],
|
||||
) -> None:
|
||||
"""Helper function for _reverse_topological_power_sort that add the event
|
||||
and its auth chain (that is in the auth diff) to the graph
|
||||
|
||||
Args:
|
||||
graph (dict[str, set[str]]): A map from event ID to the events auth
|
||||
event IDs
|
||||
room_id (str): the room we are working in
|
||||
event_id (str): Event to add to the graph
|
||||
event_map (dict[str,FrozenEvent])
|
||||
state_res_store (StateResolutionStore)
|
||||
auth_diff (set[str]): Set of event IDs that are in the auth difference.
|
||||
graph: A map from event ID to the events auth event IDs
|
||||
room_id: the room we are working in
|
||||
event_id: Event to add to the graph
|
||||
event_map
|
||||
state_res_store
|
||||
auth_diff: Set of event IDs that are in the auth difference.
|
||||
"""
|
||||
|
||||
state = [event_id]
|
||||
@@ -318,24 +348,29 @@ async def _add_event_and_auth_chain_to_graph(
|
||||
|
||||
|
||||
async def _reverse_topological_power_sort(
|
||||
clock, room_id, event_ids, event_map, state_res_store, auth_diff
|
||||
):
|
||||
clock: Clock,
|
||||
room_id: str,
|
||||
event_ids: Iterable[str],
|
||||
event_map: Dict[str, EventBase],
|
||||
state_res_store: "synapse.state.StateResolutionStore",
|
||||
auth_diff: Set[str],
|
||||
) -> List[str]:
|
||||
"""Returns a list of the event_ids sorted by reverse topological ordering,
|
||||
and then by power level and origin_server_ts
|
||||
|
||||
Args:
|
||||
clock (Clock)
|
||||
room_id (str): the room we are working in
|
||||
event_ids (list[str]): The events to sort
|
||||
event_map (dict[str,FrozenEvent])
|
||||
state_res_store (StateResolutionStore)
|
||||
auth_diff (set[str]): Set of event IDs that are in the auth difference.
|
||||
clock
|
||||
room_id: the room we are working in
|
||||
event_ids: The events to sort
|
||||
event_map
|
||||
state_res_store
|
||||
auth_diff: Set of event IDs that are in the auth difference.
|
||||
|
||||
Returns:
|
||||
Deferred[list[str]]: The sorted list
|
||||
The sorted list
|
||||
"""
|
||||
|
||||
graph = {}
|
||||
graph = {} # type: Dict[str, Set[str]]
|
||||
for idx, event_id in enumerate(event_ids, start=1):
|
||||
await _add_event_and_auth_chain_to_graph(
|
||||
graph, room_id, event_id, event_map, state_res_store, auth_diff
|
||||
@@ -372,22 +407,28 @@ async def _reverse_topological_power_sort(
|
||||
|
||||
|
||||
async def _iterative_auth_checks(
|
||||
clock, room_id, room_version, event_ids, base_state, event_map, state_res_store
|
||||
):
|
||||
clock: Clock,
|
||||
room_id: str,
|
||||
room_version: str,
|
||||
event_ids: List[str],
|
||||
base_state: StateMap[str],
|
||||
event_map: Dict[str, EventBase],
|
||||
state_res_store: "synapse.state.StateResolutionStore",
|
||||
) -> StateMap[str]:
|
||||
"""Sequentially apply auth checks to each event in given list, updating the
|
||||
state as it goes along.
|
||||
|
||||
Args:
|
||||
clock (Clock)
|
||||
room_id (str)
|
||||
room_version (str)
|
||||
event_ids (list[str]): Ordered list of events to apply auth checks to
|
||||
base_state (StateMap[str]): The set of state to start with
|
||||
event_map (dict[str,FrozenEvent])
|
||||
state_res_store (StateResolutionStore)
|
||||
clock
|
||||
room_id
|
||||
room_version
|
||||
event_ids: Ordered list of events to apply auth checks to
|
||||
base_state: The set of state to start with
|
||||
event_map
|
||||
state_res_store
|
||||
|
||||
Returns:
|
||||
Deferred[StateMap[str]]: Returns the final updated state
|
||||
Returns the final updated state
|
||||
"""
|
||||
resolved_state = base_state.copy()
|
||||
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
|
||||
@@ -439,21 +480,26 @@ async def _iterative_auth_checks(
|
||||
|
||||
|
||||
async def _mainline_sort(
|
||||
clock, room_id, event_ids, resolved_power_event_id, event_map, state_res_store
|
||||
):
|
||||
clock: Clock,
|
||||
room_id: str,
|
||||
event_ids: List[str],
|
||||
resolved_power_event_id: Optional[str],
|
||||
event_map: Dict[str, EventBase],
|
||||
state_res_store: "synapse.state.StateResolutionStore",
|
||||
) -> List[str]:
|
||||
"""Returns a sorted list of event_ids sorted by mainline ordering based on
|
||||
the given event resolved_power_event_id
|
||||
|
||||
Args:
|
||||
clock (Clock)
|
||||
room_id (str): room we're working in
|
||||
event_ids (list[str]): Events to sort
|
||||
resolved_power_event_id (str): The final resolved power level event ID
|
||||
event_map (dict[str,FrozenEvent])
|
||||
state_res_store (StateResolutionStore)
|
||||
clock
|
||||
room_id: room we're working in
|
||||
event_ids: Events to sort
|
||||
resolved_power_event_id: The final resolved power level event ID
|
||||
event_map
|
||||
state_res_store
|
||||
|
||||
Returns:
|
||||
Deferred[list[str]]: The sorted list
|
||||
The sorted list
|
||||
"""
|
||||
if not event_ids:
|
||||
# It's possible for there to be no event IDs here to sort, so we can
|
||||
@@ -505,59 +551,90 @@ async def _mainline_sort(
|
||||
|
||||
|
||||
async def _get_mainline_depth_for_event(
|
||||
event, mainline_map, event_map, state_res_store
|
||||
):
|
||||
event: EventBase,
|
||||
mainline_map: Dict[str, int],
|
||||
event_map: Dict[str, EventBase],
|
||||
state_res_store: "synapse.state.StateResolutionStore",
|
||||
) -> int:
|
||||
"""Get the mainline depths for the given event based on the mainline map
|
||||
|
||||
Args:
|
||||
event (FrozenEvent)
|
||||
mainline_map (dict[str, int]): Map from event_id to mainline depth for
|
||||
events in the mainline.
|
||||
event_map (dict[str,FrozenEvent])
|
||||
state_res_store (StateResolutionStore)
|
||||
event
|
||||
mainline_map: Map from event_id to mainline depth for events in the mainline.
|
||||
event_map
|
||||
state_res_store
|
||||
|
||||
Returns:
|
||||
Deferred[int]
|
||||
The mainline depth
|
||||
"""
|
||||
|
||||
room_id = event.room_id
|
||||
tmp_event = event # type: Optional[EventBase]
|
||||
|
||||
# We do an iterative search, replacing `event with the power level in its
|
||||
# auth events (if any)
|
||||
while event:
|
||||
while tmp_event:
|
||||
depth = mainline_map.get(event.event_id)
|
||||
if depth is not None:
|
||||
return depth
|
||||
|
||||
auth_events = event.auth_event_ids()
|
||||
event = None
|
||||
auth_events = tmp_event.auth_event_ids()
|
||||
tmp_event = None
|
||||
|
||||
for aid in auth_events:
|
||||
aev = await _get_event(
|
||||
room_id, aid, event_map, state_res_store, allow_none=True
|
||||
)
|
||||
if aev and (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""):
|
||||
event = aev
|
||||
tmp_event = aev
|
||||
break
|
||||
|
||||
# Didn't find a power level auth event, so we just return 0
|
||||
return 0
|
||||
|
||||
|
||||
async def _get_event(room_id, event_id, event_map, state_res_store, allow_none=False):
|
||||
@overload
|
||||
async def _get_event(
|
||||
room_id: str,
|
||||
event_id: str,
|
||||
event_map: Dict[str, EventBase],
|
||||
state_res_store: "synapse.state.StateResolutionStore",
|
||||
allow_none: Literal[False] = False,
|
||||
) -> EventBase:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
async def _get_event(
|
||||
room_id: str,
|
||||
event_id: str,
|
||||
event_map: Dict[str, EventBase],
|
||||
state_res_store: "synapse.state.StateResolutionStore",
|
||||
allow_none: Literal[True],
|
||||
) -> Optional[EventBase]:
|
||||
...
|
||||
|
||||
|
||||
async def _get_event(
|
||||
room_id: str,
|
||||
event_id: str,
|
||||
event_map: Dict[str, EventBase],
|
||||
state_res_store: "synapse.state.StateResolutionStore",
|
||||
allow_none: bool = False,
|
||||
) -> Optional[EventBase]:
|
||||
"""Helper function to look up event in event_map, falling back to looking
|
||||
it up in the store
|
||||
|
||||
Args:
|
||||
room_id (str)
|
||||
event_id (str)
|
||||
event_map (dict[str,FrozenEvent])
|
||||
state_res_store (StateResolutionStore)
|
||||
allow_none (bool): if the event is not found, return None rather than raising
|
||||
room_id
|
||||
event_id
|
||||
event_map
|
||||
state_res_store
|
||||
allow_none: if the event is not found, return None rather than raising
|
||||
an exception
|
||||
|
||||
Returns:
|
||||
Deferred[Optional[FrozenEvent]]
|
||||
The event, or none if the event does not exist (and allow_none is True).
|
||||
"""
|
||||
if event_id not in event_map:
|
||||
events = await state_res_store.get_events([event_id], allow_rejected=True)
|
||||
@@ -577,7 +654,9 @@ async def _get_event(room_id, event_id, event_map, state_res_store, allow_none=F
|
||||
return event
|
||||
|
||||
|
||||
def lexicographical_topological_sort(graph, key):
|
||||
def lexicographical_topological_sort(
|
||||
graph: Dict[str, Set[str]], key: Callable[[str], Any]
|
||||
) -> Generator[str, None, None]:
|
||||
"""Performs a lexicographic reverse topological sort on the graph.
|
||||
|
||||
This returns a reverse topological sort (i.e. if node A references B then B
|
||||
@@ -587,20 +666,20 @@ def lexicographical_topological_sort(graph, key):
|
||||
NOTE: `graph` is modified during the sort.
|
||||
|
||||
Args:
|
||||
graph (dict[str, set[str]]): A representation of the graph where each
|
||||
node is a key in the dict and its value are the nodes edges.
|
||||
key (func): A function that takes a node and returns a value that is
|
||||
comparable and used to order nodes
|
||||
graph: A representation of the graph where each node is a key in the
|
||||
dict and its value are the nodes edges.
|
||||
key: A function that takes a node and returns a value that is comparable
|
||||
and used to order nodes
|
||||
|
||||
Yields:
|
||||
str: The next node in the topological sort
|
||||
The next node in the topological sort
|
||||
"""
|
||||
|
||||
# Note, this is basically Kahn's algorithm except we look at nodes with no
|
||||
# outgoing edges, c.f.
|
||||
# https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm
|
||||
outdegree_map = graph
|
||||
reverse_graph = {}
|
||||
reverse_graph = {} # type: Dict[str, Set[str]]
|
||||
|
||||
# Lists of nodes with zero out degree. Is actually a tuple of
|
||||
# `(key(node), node)` so that sorting does the right thing
|
||||
|
||||
@@ -16,9 +16,8 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from canonicaljson import json
|
||||
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.util import json_encoder
|
||||
|
||||
from . import engines
|
||||
|
||||
@@ -457,7 +456,7 @@ class BackgroundUpdater(object):
|
||||
progress(dict): The progress of the update.
|
||||
"""
|
||||
|
||||
progress_json = json.dumps(progress)
|
||||
progress_json = json_encoder.encode(progress)
|
||||
|
||||
self.db_pool.simple_update_one_txn(
|
||||
txn,
|
||||
|
||||
+358
-219
File diff suppressed because it is too large
Load Diff
@@ -498,7 +498,7 @@ class DataStore(
|
||||
)
|
||||
|
||||
def get_users_paginate(
|
||||
self, start, limit, name=None, guests=True, deactivated=False
|
||||
self, start, limit, user_id=None, name=None, guests=True, deactivated=False
|
||||
):
|
||||
"""Function to retrieve a paginated list of users from
|
||||
users list. This will return a json list of users and the
|
||||
@@ -507,7 +507,8 @@ class DataStore(
|
||||
Args:
|
||||
start (int): start number to begin the query from
|
||||
limit (int): number of rows to retrieve
|
||||
name (string): filter for user names
|
||||
user_id (string): search for user_id. ignored if name is not None
|
||||
name (string): search for local part of user_id or display name
|
||||
guests (bool): whether to in include guest users
|
||||
deactivated (bool): whether to include deactivated users
|
||||
Returns:
|
||||
@@ -516,11 +517,14 @@ class DataStore(
|
||||
|
||||
def get_users_paginate_txn(txn):
|
||||
filters = []
|
||||
args = []
|
||||
args = [self.hs.config.server_name]
|
||||
|
||||
if name:
|
||||
filters.append("(name LIKE ? OR displayname LIKE ?)")
|
||||
args.extend(["@%" + name + "%:%", "%" + name + "%"])
|
||||
elif user_id:
|
||||
filters.append("name LIKE ?")
|
||||
args.append("%" + name + "%")
|
||||
args.extend(["%" + user_id + "%"])
|
||||
|
||||
if not guests:
|
||||
filters.append("is_guest = 0")
|
||||
@@ -530,20 +534,23 @@ class DataStore(
|
||||
|
||||
where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else ""
|
||||
|
||||
sql = "SELECT COUNT(*) as total_users FROM users %s" % (where_clause)
|
||||
txn.execute(sql, args)
|
||||
count = txn.fetchone()[0]
|
||||
|
||||
args = [self.hs.config.server_name] + args + [limit, start]
|
||||
sql = """
|
||||
SELECT name, user_type, is_guest, admin, deactivated, displayname, avatar_url
|
||||
sql_base = """
|
||||
FROM users as u
|
||||
LEFT JOIN profiles AS p ON u.name = '@' || p.user_id || ':' || ?
|
||||
{}
|
||||
ORDER BY u.name LIMIT ? OFFSET ?
|
||||
""".format(
|
||||
where_clause
|
||||
)
|
||||
sql = "SELECT COUNT(*) as total_users " + sql_base
|
||||
txn.execute(sql, args)
|
||||
count = txn.fetchone()[0]
|
||||
|
||||
sql = (
|
||||
"SELECT name, user_type, is_guest, admin, deactivated, displayname, avatar_url "
|
||||
+ sql_base
|
||||
+ " ORDER BY u.name LIMIT ? OFFSET ?"
|
||||
)
|
||||
args += [limit, start]
|
||||
txn.execute(sql, args)
|
||||
users = self.db_pool.cursor_to_dict(txn)
|
||||
return users, count
|
||||
|
||||
@@ -336,7 +336,7 @@ class AccountDataStore(AccountDataWorkerStore):
|
||||
"""
|
||||
content_json = json_encoder.encode(content)
|
||||
|
||||
with self._account_data_id_gen.get_next() as next_id:
|
||||
with await self._account_data_id_gen.get_next() as next_id:
|
||||
# no need to lock here as room_account_data has a unique constraint
|
||||
# on (user_id, room_id, account_data_type) so simple_upsert will
|
||||
# retry if there is a conflict.
|
||||
@@ -384,7 +384,7 @@ class AccountDataStore(AccountDataWorkerStore):
|
||||
"""
|
||||
content_json = json_encoder.encode(content)
|
||||
|
||||
with self._account_data_id_gen.get_next() as next_id:
|
||||
with await self._account_data_id_gen.get_next() as next_id:
|
||||
# no need to lock here as account_data has a unique constraint on
|
||||
# (user_id, account_data_type) so simple_upsert will retry if
|
||||
# there is a conflict.
|
||||
|
||||
@@ -16,13 +16,12 @@
|
||||
import logging
|
||||
import re
|
||||
|
||||
from canonicaljson import json
|
||||
|
||||
from synapse.appservice import AppServiceTransaction
|
||||
from synapse.config.appservice import load_appservices
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||
from synapse.storage.database import DatabasePool
|
||||
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
||||
from synapse.util import json_encoder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -204,7 +203,7 @@ class ApplicationServiceTransactionWorkerStore(
|
||||
new_txn_id = max(highest_txn_id, last_txn_id) + 1
|
||||
|
||||
# Insert new txn into txn table
|
||||
event_ids = json.dumps([e.event_id for e in events])
|
||||
event_ids = json_encoder.encode([e.event_id for e in events])
|
||||
txn.execute(
|
||||
"INSERT INTO application_services_txns(as_id, txn_id, event_ids) "
|
||||
"VALUES(?,?,?)",
|
||||
|
||||
@@ -362,7 +362,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
|
||||
rows.append((destination, stream_id, now_ms, edu_json))
|
||||
txn.executemany(sql, rows)
|
||||
|
||||
with self._device_inbox_id_gen.get_next() as stream_id:
|
||||
with await self._device_inbox_id_gen.get_next() as stream_id:
|
||||
now_ms = self.clock.time_msec()
|
||||
await self.db_pool.runInteraction(
|
||||
"add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
|
||||
@@ -411,7 +411,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
|
||||
txn, stream_id, local_messages_by_user_then_device
|
||||
)
|
||||
|
||||
with self._device_inbox_id_gen.get_next() as stream_id:
|
||||
with await self._device_inbox_id_gen.get_next() as stream_id:
|
||||
now_ms = self.clock.time_msec()
|
||||
await self.db_pool.runInteraction(
|
||||
"add_messages_from_remote_to_device_inbox",
|
||||
|
||||
@@ -380,7 +380,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||
THe new stream ID.
|
||||
"""
|
||||
|
||||
with self._device_list_id_gen.get_next() as stream_id:
|
||||
with await self._device_list_id_gen.get_next() as stream_id:
|
||||
await self.db_pool.runInteraction(
|
||||
"add_user_sig_change_to_streams",
|
||||
self._add_user_signature_change_txn,
|
||||
@@ -1146,7 +1146,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||
if not device_ids:
|
||||
return
|
||||
|
||||
with self._device_list_id_gen.get_next_mult(len(device_ids)) as stream_ids:
|
||||
with await self._device_list_id_gen.get_next_mult(
|
||||
len(device_ids)
|
||||
) as stream_ids:
|
||||
await self.db_pool.runInteraction(
|
||||
"add_device_change_to_stream",
|
||||
self._add_device_change_to_stream_txn,
|
||||
@@ -1159,7 +1161,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||
return stream_ids[-1]
|
||||
|
||||
context = get_active_span_text_map()
|
||||
with self._device_list_id_gen.get_next_mult(
|
||||
with await self._device_list_id_gen.get_next_mult(
|
||||
len(hosts) * len(device_ids)
|
||||
) as stream_ids:
|
||||
await self.db_pool.runInteraction(
|
||||
|
||||
@@ -648,7 +648,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
||||
"delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
|
||||
)
|
||||
|
||||
def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key):
|
||||
def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key, stream_id):
|
||||
"""Set a user's cross-signing key.
|
||||
|
||||
Args:
|
||||
@@ -658,6 +658,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
||||
for a master key, 'self_signing' for a self-signing key, or
|
||||
'user_signing' for a user-signing key
|
||||
key (dict): the key data
|
||||
stream_id (int)
|
||||
"""
|
||||
# the 'key' dict will look something like:
|
||||
# {
|
||||
@@ -695,23 +696,22 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
||||
)
|
||||
|
||||
# and finally, store the key itself
|
||||
with self._cross_signing_id_gen.get_next() as stream_id:
|
||||
self.db_pool.simple_insert_txn(
|
||||
txn,
|
||||
"e2e_cross_signing_keys",
|
||||
values={
|
||||
"user_id": user_id,
|
||||
"keytype": key_type,
|
||||
"keydata": json_encoder.encode(key),
|
||||
"stream_id": stream_id,
|
||||
},
|
||||
)
|
||||
self.db_pool.simple_insert_txn(
|
||||
txn,
|
||||
"e2e_cross_signing_keys",
|
||||
values={
|
||||
"user_id": user_id,
|
||||
"keytype": key_type,
|
||||
"keydata": json_encoder.encode(key),
|
||||
"stream_id": stream_id,
|
||||
},
|
||||
)
|
||||
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self._get_bare_e2e_cross_signing_keys, (user_id,)
|
||||
)
|
||||
|
||||
def set_e2e_cross_signing_key(self, user_id, key_type, key):
|
||||
async def set_e2e_cross_signing_key(self, user_id, key_type, key):
|
||||
"""Set a user's cross-signing key.
|
||||
|
||||
Args:
|
||||
@@ -719,13 +719,16 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
||||
key_type (str): the type of cross-signing key to set
|
||||
key (dict): the key data
|
||||
"""
|
||||
return self.db_pool.runInteraction(
|
||||
"add_e2e_cross_signing_key",
|
||||
self._set_e2e_cross_signing_key_txn,
|
||||
user_id,
|
||||
key_type,
|
||||
key,
|
||||
)
|
||||
|
||||
with await self._cross_signing_id_gen.get_next() as stream_id:
|
||||
return await self.db_pool.runInteraction(
|
||||
"add_e2e_cross_signing_key",
|
||||
self._set_e2e_cross_signing_key_txn,
|
||||
user_id,
|
||||
key_type,
|
||||
key,
|
||||
stream_id,
|
||||
)
|
||||
|
||||
def store_e2e_cross_signing_signatures(self, user_id, signatures):
|
||||
"""Stores cross-signing signatures.
|
||||
|
||||
@@ -15,14 +15,16 @@
|
||||
import itertools
|
||||
import logging
|
||||
from queue import Empty, PriorityQueue
|
||||
from typing import Dict, Iterable, List, Optional, Set, Tuple
|
||||
from typing import Dict, Iterable, List, Set, Tuple
|
||||
|
||||
from synapse.api.errors import StoreError
|
||||
from synapse.events import EventBase
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
|
||||
from synapse.storage.database import DatabasePool
|
||||
from synapse.storage.database import DatabasePool, LoggingTransaction
|
||||
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
||||
from synapse.storage.databases.main.signatures import SignatureWorkerStore
|
||||
from synapse.types import Collection
|
||||
from synapse.util.caches.descriptors import cached
|
||||
from synapse.util.iterutils import batch_iter
|
||||
|
||||
@@ -30,12 +32,14 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
|
||||
async def get_auth_chain(self, event_ids, include_given=False):
|
||||
async def get_auth_chain(
|
||||
self, event_ids: Collection[str], include_given: bool = False
|
||||
) -> List[EventBase]:
|
||||
"""Get auth events for given event_ids. The events *must* be state events.
|
||||
|
||||
Args:
|
||||
event_ids (list): state events
|
||||
include_given (bool): include the given events in result
|
||||
event_ids: state events
|
||||
include_given: include the given events in result
|
||||
|
||||
Returns:
|
||||
list of events
|
||||
@@ -45,43 +49,34 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
|
||||
)
|
||||
return await self.get_events_as_list(event_ids)
|
||||
|
||||
def get_auth_chain_ids(
|
||||
self,
|
||||
event_ids: List[str],
|
||||
include_given: bool = False,
|
||||
ignore_events: Optional[Set[str]] = None,
|
||||
):
|
||||
async def get_auth_chain_ids(
|
||||
self, event_ids: Collection[str], include_given: bool = False,
|
||||
) -> List[str]:
|
||||
"""Get auth events for given event_ids. The events *must* be state events.
|
||||
|
||||
Args:
|
||||
event_ids: state events
|
||||
include_given: include the given events in result
|
||||
ignore_events: Set of events to exclude from the returned auth
|
||||
chain. This is useful if the caller will just discard the
|
||||
given events anyway, and saves us from figuring out their auth
|
||||
chains if not required.
|
||||
|
||||
Returns:
|
||||
list of event_ids
|
||||
"""
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_auth_chain_ids",
|
||||
self._get_auth_chain_ids_txn,
|
||||
event_ids,
|
||||
include_given,
|
||||
ignore_events,
|
||||
)
|
||||
|
||||
def _get_auth_chain_ids_txn(self, txn, event_ids, include_given, ignore_events):
|
||||
if ignore_events is None:
|
||||
ignore_events = set()
|
||||
|
||||
def _get_auth_chain_ids_txn(
|
||||
self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool
|
||||
) -> List[str]:
|
||||
if include_given:
|
||||
results = set(event_ids)
|
||||
else:
|
||||
results = set()
|
||||
|
||||
base_sql = "SELECT auth_id FROM event_auth WHERE "
|
||||
base_sql = "SELECT DISTINCT auth_id FROM event_auth WHERE "
|
||||
|
||||
front = set(event_ids)
|
||||
while front:
|
||||
@@ -93,7 +88,6 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
|
||||
txn.execute(base_sql + clause, args)
|
||||
new_front.update(r[0] for r in txn)
|
||||
|
||||
new_front -= ignore_events
|
||||
new_front -= results
|
||||
|
||||
front = new_front
|
||||
|
||||
@@ -153,11 +153,11 @@ class PersistEventsStore:
|
||||
# Note: Multiple instances of this function cannot be in flight at
|
||||
# the same time for the same room.
|
||||
if backfilled:
|
||||
stream_ordering_manager = self._backfill_id_gen.get_next_mult(
|
||||
stream_ordering_manager = await self._backfill_id_gen.get_next_mult(
|
||||
len(events_and_contexts)
|
||||
)
|
||||
else:
|
||||
stream_ordering_manager = self._stream_id_gen.get_next_mult(
|
||||
stream_ordering_manager = await self._stream_id_gen.get_next_mult(
|
||||
len(events_and_contexts)
|
||||
)
|
||||
|
||||
|
||||
@@ -620,19 +620,38 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
room_version_id = row["room_version_id"]
|
||||
|
||||
if not room_version_id:
|
||||
# this should only happen for out-of-band membership events
|
||||
if not internal_metadata.get("out_of_band_membership"):
|
||||
logger.warning(
|
||||
"Room %s for event %s is unknown", d["room_id"], event_id
|
||||
# this should only happen for out-of-band membership events which
|
||||
# arrived before #6983 landed. For all other events, we should have
|
||||
# an entry in the 'rooms' table.
|
||||
#
|
||||
# However, the 'out_of_band_membership' flag is unreliable for older
|
||||
# invites, so just accept it for all membership events.
|
||||
#
|
||||
if d["type"] != EventTypes.Member:
|
||||
raise Exception(
|
||||
"Room %s for event %s is unknown" % (d["room_id"], event_id)
|
||||
)
|
||||
continue
|
||||
|
||||
# take a wild stab at the room version based on the event format
|
||||
# so, assuming this is an out-of-band-invite that arrived before #6983
|
||||
# landed, we know that the room version must be v5 or earlier (because
|
||||
# v6 hadn't been invented at that point, so invites from such rooms
|
||||
# would have been rejected.)
|
||||
#
|
||||
# The main reason we need to know the room version here (other than
|
||||
# choosing the right python Event class) is in case the event later has
|
||||
# to be redacted - and all the room versions up to v5 used the same
|
||||
# redaction algorithm.
|
||||
#
|
||||
# So, the following approximations should be adequate.
|
||||
|
||||
if format_version == EventFormatVersions.V1:
|
||||
# if it's event format v1 then it must be room v1 or v2
|
||||
room_version = RoomVersions.V1
|
||||
elif format_version == EventFormatVersions.V2:
|
||||
# if it's event format v2 then it must be room v3
|
||||
room_version = RoomVersions.V3
|
||||
else:
|
||||
# if it's event format v3 then it must be room v4 or v5
|
||||
room_version = RoomVersions.V5
|
||||
else:
|
||||
room_version = KNOWN_ROOM_VERSIONS.get(room_version_id)
|
||||
|
||||
@@ -1182,7 +1182,7 @@ class GroupServerStore(GroupServerWorkerStore):
|
||||
|
||||
return next_id
|
||||
|
||||
with self._group_updates_id_gen.get_next() as next_id:
|
||||
with await self._group_updates_id_gen.get_next() as next_id:
|
||||
res = await self.db_pool.runInteraction(
|
||||
"register_user_group_membership",
|
||||
_register_user_group_membership_txn,
|
||||
|
||||
@@ -23,7 +23,7 @@ from synapse.util.iterutils import batch_iter
|
||||
|
||||
class PresenceStore(SQLBaseStore):
|
||||
async def update_presence(self, presence_states):
|
||||
stream_ordering_manager = self._presence_id_gen.get_next_mult(
|
||||
stream_ordering_manager = await self._presence_id_gen.get_next_mult(
|
||||
len(presence_states)
|
||||
)
|
||||
|
||||
|
||||
@@ -338,7 +338,7 @@ class PushRuleStore(PushRulesWorkerStore):
|
||||
) -> None:
|
||||
conditions_json = json_encoder.encode(conditions)
|
||||
actions_json = json_encoder.encode(actions)
|
||||
with self._push_rules_stream_id_gen.get_next() as stream_id:
|
||||
with await self._push_rules_stream_id_gen.get_next() as stream_id:
|
||||
event_stream_ordering = self._stream_id_gen.get_current_token()
|
||||
|
||||
if before or after:
|
||||
@@ -560,7 +560,7 @@ class PushRuleStore(PushRulesWorkerStore):
|
||||
txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE"
|
||||
)
|
||||
|
||||
with self._push_rules_stream_id_gen.get_next() as stream_id:
|
||||
with await self._push_rules_stream_id_gen.get_next() as stream_id:
|
||||
event_stream_ordering = self._stream_id_gen.get_current_token()
|
||||
|
||||
await self.db_pool.runInteraction(
|
||||
@@ -571,7 +571,7 @@ class PushRuleStore(PushRulesWorkerStore):
|
||||
)
|
||||
|
||||
async def set_push_rule_enabled(self, user_id, rule_id, enabled) -> None:
|
||||
with self._push_rules_stream_id_gen.get_next() as stream_id:
|
||||
with await self._push_rules_stream_id_gen.get_next() as stream_id:
|
||||
event_stream_ordering = self._stream_id_gen.get_current_token()
|
||||
|
||||
await self.db_pool.runInteraction(
|
||||
@@ -646,7 +646,7 @@ class PushRuleStore(PushRulesWorkerStore):
|
||||
data={"actions": actions_json},
|
||||
)
|
||||
|
||||
with self._push_rules_stream_id_gen.get_next() as stream_id:
|
||||
with await self._push_rules_stream_id_gen.get_next() as stream_id:
|
||||
event_stream_ordering = self._stream_id_gen.get_current_token()
|
||||
|
||||
await self.db_pool.runInteraction(
|
||||
|
||||
@@ -281,7 +281,7 @@ class PusherStore(PusherWorkerStore):
|
||||
last_stream_ordering,
|
||||
profile_tag="",
|
||||
) -> None:
|
||||
with self._pushers_id_gen.get_next() as stream_id:
|
||||
with await self._pushers_id_gen.get_next() as stream_id:
|
||||
# no need to lock because `pushers` has a unique key on
|
||||
# (app_id, pushkey, user_name) so simple_upsert will retry
|
||||
await self.db_pool.simple_upsert(
|
||||
@@ -344,7 +344,7 @@ class PusherStore(PusherWorkerStore):
|
||||
},
|
||||
)
|
||||
|
||||
with self._pushers_id_gen.get_next() as stream_id:
|
||||
with await self._pushers_id_gen.get_next() as stream_id:
|
||||
await self.db_pool.runInteraction(
|
||||
"delete_pusher", delete_pusher_txn, stream_id
|
||||
)
|
||||
|
||||
@@ -520,8 +520,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
|
||||
"insert_receipt_conv", graph_to_linear
|
||||
)
|
||||
|
||||
stream_id_manager = self._receipts_id_gen.get_next()
|
||||
with stream_id_manager as stream_id:
|
||||
with await self._receipts_id_gen.get_next() as stream_id:
|
||||
event_ts = await self.db_pool.runInteraction(
|
||||
"insert_linearized_receipt",
|
||||
self.insert_linearized_receipt_txn,
|
||||
|
||||
@@ -968,6 +968,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
|
||||
super(RegistrationStore, self).__init__(database, db_conn, hs)
|
||||
|
||||
self._account_validity = hs.config.account_validity
|
||||
self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors
|
||||
|
||||
if self._account_validity.enabled:
|
||||
self._clock.call_later(
|
||||
@@ -1381,15 +1382,22 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
|
||||
)
|
||||
|
||||
if not row:
|
||||
raise ThreepidValidationError(400, "Unknown session_id")
|
||||
if self._ignore_unknown_session_error:
|
||||
# If we need to inhibit the error caused by an incorrect session ID,
|
||||
# use None as placeholder values for the client secret and the
|
||||
# validation timestamp.
|
||||
# It shouldn't be an issue because they're both only checked after
|
||||
# the token check, which should fail. And if it doesn't for some
|
||||
# reason, the next check is on the client secret, which is NOT NULL,
|
||||
# so we don't have to worry about the client secret matching by
|
||||
# accident.
|
||||
row = {"client_secret": None, "validated_at": None}
|
||||
else:
|
||||
raise ThreepidValidationError(400, "Unknown session_id")
|
||||
|
||||
retrieved_client_secret = row["client_secret"]
|
||||
validated_at = row["validated_at"]
|
||||
|
||||
if retrieved_client_secret != client_secret:
|
||||
raise ThreepidValidationError(
|
||||
400, "This client_secret does not match the provided session_id"
|
||||
)
|
||||
|
||||
row = self.db_pool.simple_select_one_txn(
|
||||
txn,
|
||||
table="threepid_validation_token",
|
||||
@@ -1405,6 +1413,11 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
|
||||
expires = row["expires"]
|
||||
next_link = row["next_link"]
|
||||
|
||||
if retrieved_client_secret != client_secret:
|
||||
raise ThreepidValidationError(
|
||||
400, "This client_secret does not match the provided session_id"
|
||||
)
|
||||
|
||||
# If the session is already validated, no need to revalidate
|
||||
if validated_at:
|
||||
return next_link
|
||||
|
||||
@@ -21,10 +21,6 @@ from abc import abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from canonicaljson import json
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.api.errors import StoreError
|
||||
from synapse.api.room_versions import RoomVersion, RoomVersions
|
||||
@@ -32,6 +28,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json
|
||||
from synapse.storage.database import DatabasePool, LoggingTransaction
|
||||
from synapse.storage.databases.main.search import SearchStore
|
||||
from synapse.types import ThirdPartyInstanceID
|
||||
from synapse.util import json_encoder
|
||||
from synapse.util.caches.descriptors import cached
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -342,23 +339,22 @@ class RoomWorkerStore(SQLBaseStore):
|
||||
desc="is_room_blocked",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def is_room_published(self, room_id):
|
||||
async def is_room_published(self, room_id: str) -> bool:
|
||||
"""Check whether a room has been published in the local public room
|
||||
directory.
|
||||
|
||||
Args:
|
||||
room_id (str)
|
||||
room_id
|
||||
Returns:
|
||||
bool: Whether the room is currently published in the room directory
|
||||
Whether the room is currently published in the room directory
|
||||
"""
|
||||
# Get room information
|
||||
room_info = yield self.get_room(room_id)
|
||||
room_info = await self.get_room(room_id)
|
||||
if not room_info:
|
||||
defer.returnValue(False)
|
||||
return False
|
||||
|
||||
# Check the is_public value
|
||||
defer.returnValue(room_info.get("is_public", False))
|
||||
return room_info.get("is_public", False)
|
||||
|
||||
async def get_rooms_paginate(
|
||||
self,
|
||||
@@ -572,7 +568,7 @@ class RoomWorkerStore(SQLBaseStore):
|
||||
# maximum, in order not to filter out events we should filter out when sending to
|
||||
# the client.
|
||||
if not self.config.retention_enabled:
|
||||
defer.returnValue({"min_lifetime": None, "max_lifetime": None})
|
||||
return {"min_lifetime": None, "max_lifetime": None}
|
||||
|
||||
def get_retention_policy_for_room_txn(txn):
|
||||
txn.execute(
|
||||
@@ -1155,7 +1151,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
||||
},
|
||||
)
|
||||
|
||||
with self._public_room_id_gen.get_next() as next_id:
|
||||
with await self._public_room_id_gen.get_next() as next_id:
|
||||
await self.db_pool.runInteraction(
|
||||
"store_room_txn", store_room_txn, next_id
|
||||
)
|
||||
@@ -1222,7 +1218,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
||||
},
|
||||
)
|
||||
|
||||
with self._public_room_id_gen.get_next() as next_id:
|
||||
with await self._public_room_id_gen.get_next() as next_id:
|
||||
await self.db_pool.runInteraction(
|
||||
"set_room_is_public", set_room_is_public_txn, next_id
|
||||
)
|
||||
@@ -1302,7 +1298,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
||||
},
|
||||
)
|
||||
|
||||
with self._public_room_id_gen.get_next() as next_id:
|
||||
with await self._public_room_id_gen.get_next() as next_id:
|
||||
await self.db_pool.runInteraction(
|
||||
"set_room_is_public_appservice",
|
||||
set_room_is_public_appservice_txn,
|
||||
@@ -1335,7 +1331,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
||||
"event_id": event_id,
|
||||
"user_id": user_id,
|
||||
"reason": reason,
|
||||
"content": json.dumps(content),
|
||||
"content": json_encoder.encode(content),
|
||||
},
|
||||
desc="add_event_report",
|
||||
)
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
/* Copyright 2020 The Matrix.org Foundation C.I.C
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
-- A table of the IP address and user-agent used to complete each step of a
|
||||
-- user-interactive authentication session.
|
||||
CREATE TABLE IF NOT EXISTS ui_auth_sessions_ips(
|
||||
session_id TEXT NOT NULL,
|
||||
ip TEXT NOT NULL,
|
||||
user_agent TEXT NOT NULL,
|
||||
UNIQUE (session_id, ip, user_agent),
|
||||
FOREIGN KEY (session_id)
|
||||
REFERENCES ui_auth_sessions (session_id)
|
||||
);
|
||||
@@ -17,11 +17,10 @@
|
||||
import logging
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from canonicaljson import json
|
||||
|
||||
from synapse.storage._base import db_to_json
|
||||
from synapse.storage.databases.main.account_data import AccountDataWorkerStore
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util import json_encoder
|
||||
from synapse.util.caches.descriptors import cached
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -98,7 +97,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
|
||||
txn.execute(sql, (user_id, room_id))
|
||||
tags = []
|
||||
for tag, content in txn:
|
||||
tags.append(json.dumps(tag) + ":" + content)
|
||||
tags.append(json_encoder.encode(tag) + ":" + content)
|
||||
tag_json = "{" + ",".join(tags) + "}"
|
||||
results.append((stream_id, (user_id, room_id, tag_json)))
|
||||
|
||||
@@ -200,7 +199,7 @@ class TagsStore(TagsWorkerStore):
|
||||
Returns:
|
||||
The next account data ID.
|
||||
"""
|
||||
content_json = json.dumps(content)
|
||||
content_json = json_encoder.encode(content)
|
||||
|
||||
def add_tag_txn(txn, next_id):
|
||||
self.db_pool.simple_upsert_txn(
|
||||
@@ -211,7 +210,7 @@ class TagsStore(TagsWorkerStore):
|
||||
)
|
||||
self._update_revision_txn(txn, user_id, room_id, next_id)
|
||||
|
||||
with self._account_data_id_gen.get_next() as next_id:
|
||||
with await self._account_data_id_gen.get_next() as next_id:
|
||||
await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id)
|
||||
|
||||
self.get_tags_for_user.invalidate((user_id,))
|
||||
@@ -233,7 +232,7 @@ class TagsStore(TagsWorkerStore):
|
||||
txn.execute(sql, (user_id, room_id, tag))
|
||||
self._update_revision_txn(txn, user_id, room_id, next_id)
|
||||
|
||||
with self._account_data_id_gen.get_next() as next_id:
|
||||
with await self._account_data_id_gen.get_next() as next_id:
|
||||
await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id)
|
||||
|
||||
self.get_tags_for_user.invalidate((user_id,))
|
||||
|
||||
@@ -12,15 +12,15 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import attr
|
||||
from canonicaljson import json
|
||||
|
||||
from synapse.api.errors import StoreError
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||
from synapse.storage.database import LoggingTransaction
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util import stringutils as stringutils
|
||||
from synapse.util import json_encoder, stringutils
|
||||
|
||||
|
||||
@attr.s
|
||||
@@ -72,7 +72,7 @@ class UIAuthWorkerStore(SQLBaseStore):
|
||||
StoreError if a unique session ID cannot be generated.
|
||||
"""
|
||||
# The clientdict gets stored as JSON.
|
||||
clientdict_json = json.dumps(clientdict)
|
||||
clientdict_json = json_encoder.encode(clientdict)
|
||||
|
||||
# autogen a session ID and try to create it. We may clash, so just
|
||||
# try a few times till one goes through, giving up eventually.
|
||||
@@ -143,7 +143,7 @@ class UIAuthWorkerStore(SQLBaseStore):
|
||||
await self.db_pool.simple_upsert(
|
||||
table="ui_auth_sessions_credentials",
|
||||
keyvalues={"session_id": session_id, "stage_type": stage_type},
|
||||
values={"result": json.dumps(result)},
|
||||
values={"result": json_encoder.encode(result)},
|
||||
desc="mark_ui_auth_stage_complete",
|
||||
)
|
||||
except self.db_pool.engine.module.IntegrityError:
|
||||
@@ -184,7 +184,7 @@ class UIAuthWorkerStore(SQLBaseStore):
|
||||
The dictionary from the client root level, not the 'auth' key.
|
||||
"""
|
||||
# The clientdict gets stored as JSON.
|
||||
clientdict_json = json.dumps(clientdict)
|
||||
clientdict_json = json_encoder.encode(clientdict)
|
||||
|
||||
await self.db_pool.simple_update_one(
|
||||
table="ui_auth_sessions",
|
||||
@@ -214,14 +214,16 @@ class UIAuthWorkerStore(SQLBaseStore):
|
||||
value,
|
||||
)
|
||||
|
||||
def _set_ui_auth_session_data_txn(self, txn, session_id: str, key: str, value: Any):
|
||||
def _set_ui_auth_session_data_txn(
|
||||
self, txn: LoggingTransaction, session_id: str, key: str, value: Any
|
||||
):
|
||||
# Get the current value.
|
||||
result = self.db_pool.simple_select_one_txn(
|
||||
txn,
|
||||
table="ui_auth_sessions",
|
||||
keyvalues={"session_id": session_id},
|
||||
retcols=("serverdict",),
|
||||
)
|
||||
) # type: Dict[str, Any] # type: ignore
|
||||
|
||||
# Update it and add it back to the database.
|
||||
serverdict = db_to_json(result["serverdict"])
|
||||
@@ -231,7 +233,7 @@ class UIAuthWorkerStore(SQLBaseStore):
|
||||
txn,
|
||||
table="ui_auth_sessions",
|
||||
keyvalues={"session_id": session_id},
|
||||
updatevalues={"serverdict": json.dumps(serverdict)},
|
||||
updatevalues={"serverdict": json_encoder.encode(serverdict)},
|
||||
)
|
||||
|
||||
async def get_ui_auth_session_data(
|
||||
@@ -258,6 +260,34 @@ class UIAuthWorkerStore(SQLBaseStore):
|
||||
|
||||
return serverdict.get(key, default)
|
||||
|
||||
async def add_user_agent_ip_to_ui_auth_session(
|
||||
self, session_id: str, user_agent: str, ip: str,
|
||||
):
|
||||
"""Add the given user agent / IP to the tracking table
|
||||
"""
|
||||
await self.db_pool.simple_upsert(
|
||||
table="ui_auth_sessions_ips",
|
||||
keyvalues={"session_id": session_id, "user_agent": user_agent, "ip": ip},
|
||||
values={},
|
||||
desc="add_user_agent_ip_to_ui_auth_session",
|
||||
)
|
||||
|
||||
async def get_user_agents_ips_to_ui_auth_session(
|
||||
self, session_id: str,
|
||||
) -> List[Tuple[str, str]]:
|
||||
"""Get the given user agents / IPs used during the ui auth process
|
||||
|
||||
Returns:
|
||||
List of user_agent/ip pairs
|
||||
"""
|
||||
rows = await self.db_pool.simple_select_list(
|
||||
table="ui_auth_sessions_ips",
|
||||
keyvalues={"session_id": session_id},
|
||||
retcols=("user_agent", "ip"),
|
||||
desc="get_user_agents_ips_to_ui_auth_session",
|
||||
)
|
||||
return [(row["user_agent"], row["ip"]) for row in rows]
|
||||
|
||||
|
||||
class UIAuthStore(UIAuthWorkerStore):
|
||||
def delete_old_ui_auth_sessions(self, expiration_time: int):
|
||||
@@ -275,12 +305,23 @@ class UIAuthStore(UIAuthWorkerStore):
|
||||
expiration_time,
|
||||
)
|
||||
|
||||
def _delete_old_ui_auth_sessions_txn(self, txn, expiration_time: int):
|
||||
def _delete_old_ui_auth_sessions_txn(
|
||||
self, txn: LoggingTransaction, expiration_time: int
|
||||
):
|
||||
# Get the expired sessions.
|
||||
sql = "SELECT session_id FROM ui_auth_sessions WHERE creation_time <= ?"
|
||||
txn.execute(sql, [expiration_time])
|
||||
session_ids = [r[0] for r in txn.fetchall()]
|
||||
|
||||
# Delete the corresponding IP/user agents.
|
||||
self.db_pool.simple_delete_many_txn(
|
||||
txn,
|
||||
table="ui_auth_sessions_ips",
|
||||
column="session_id",
|
||||
iterable=session_ids,
|
||||
keyvalues={},
|
||||
)
|
||||
|
||||
# Delete the corresponding completed credentials.
|
||||
self.db_pool.simple_delete_many_txn(
|
||||
txn,
|
||||
|
||||
@@ -14,9 +14,10 @@
|
||||
# limitations under the License.
|
||||
|
||||
import contextlib
|
||||
import heapq
|
||||
import threading
|
||||
from collections import deque
|
||||
from typing import Dict, Set
|
||||
from typing import Dict, List, Set
|
||||
|
||||
from typing_extensions import Deque
|
||||
|
||||
@@ -80,7 +81,7 @@ class StreamIdGenerator(object):
|
||||
upwards, -1 to grow downwards.
|
||||
|
||||
Usage:
|
||||
with stream_id_gen.get_next() as stream_id:
|
||||
with await stream_id_gen.get_next() as stream_id:
|
||||
# ... persist event ...
|
||||
"""
|
||||
|
||||
@@ -95,10 +96,10 @@ class StreamIdGenerator(object):
|
||||
)
|
||||
self._unfinished_ids = deque() # type: Deque[int]
|
||||
|
||||
def get_next(self):
|
||||
async def get_next(self):
|
||||
"""
|
||||
Usage:
|
||||
with stream_id_gen.get_next() as stream_id:
|
||||
with await stream_id_gen.get_next() as stream_id:
|
||||
# ... persist event ...
|
||||
"""
|
||||
with self._lock:
|
||||
@@ -117,10 +118,10 @@ class StreamIdGenerator(object):
|
||||
|
||||
return manager()
|
||||
|
||||
def get_next_mult(self, n):
|
||||
async def get_next_mult(self, n):
|
||||
"""
|
||||
Usage:
|
||||
with stream_id_gen.get_next(n) as stream_ids:
|
||||
with await stream_id_gen.get_next(n) as stream_ids:
|
||||
# ... persist events ...
|
||||
"""
|
||||
with self._lock:
|
||||
@@ -210,6 +211,23 @@ class MultiWriterIdGenerator:
|
||||
# should be less than the minimum of this set (if not empty).
|
||||
self._unfinished_ids = set() # type: Set[int]
|
||||
|
||||
# We track the max position where we know everything before has been
|
||||
# persisted. This is done by a) looking at the min across all instances
|
||||
# and b) noting that if we have seen a run of persisted positions
|
||||
# without gaps (e.g. 5, 6, 7) then we can skip forward (e.g. to 7).
|
||||
#
|
||||
# Note: There is no guarentee that the IDs generated by the sequence
|
||||
# will be gapless; gaps can form when e.g. a transaction was rolled
|
||||
# back. This means that sometimes we won't be able to skip forward the
|
||||
# position even though everything has been persisted. However, since
|
||||
# gaps should be relatively rare it's still worth doing the book keeping
|
||||
# that allows us to skip forwards when there are gapless runs of
|
||||
# positions.
|
||||
self._persisted_upto_position = (
|
||||
min(self._current_positions.values()) if self._current_positions else 0
|
||||
)
|
||||
self._known_persisted_positions = [] # type: List[int]
|
||||
|
||||
self._sequence_gen = PostgresSequenceGenerator(sequence_name)
|
||||
|
||||
def _load_current_ids(
|
||||
@@ -234,9 +252,12 @@ class MultiWriterIdGenerator:
|
||||
|
||||
return current_positions
|
||||
|
||||
def _load_next_id_txn(self, txn):
|
||||
def _load_next_id_txn(self, txn) -> int:
|
||||
return self._sequence_gen.get_next_id_txn(txn)
|
||||
|
||||
def _load_next_mult_id_txn(self, txn, n: int) -> List[int]:
|
||||
return self._sequence_gen.get_next_mult_txn(txn, n)
|
||||
|
||||
async def get_next(self):
|
||||
"""
|
||||
Usage:
|
||||
@@ -262,6 +283,34 @@ class MultiWriterIdGenerator:
|
||||
|
||||
return manager()
|
||||
|
||||
async def get_next_mult(self, n: int):
|
||||
"""
|
||||
Usage:
|
||||
with await stream_id_gen.get_next_mult(5) as stream_ids:
|
||||
# ... persist events ...
|
||||
"""
|
||||
next_ids = await self._db.runInteraction(
|
||||
"_load_next_mult_id", self._load_next_mult_id_txn, n
|
||||
)
|
||||
|
||||
# Assert the fetched ID is actually greater than any ID we've already
|
||||
# seen. If not, then the sequence and table have got out of sync
|
||||
# somehow.
|
||||
assert max(self.get_positions().values(), default=0) < min(next_ids)
|
||||
|
||||
with self._lock:
|
||||
self._unfinished_ids.update(next_ids)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def manager():
|
||||
try:
|
||||
yield next_ids
|
||||
finally:
|
||||
for i in next_ids:
|
||||
self._mark_id_as_finished(i)
|
||||
|
||||
return manager()
|
||||
|
||||
def get_next_txn(self, txn: LoggingTransaction):
|
||||
"""
|
||||
Usage:
|
||||
@@ -326,3 +375,53 @@ class MultiWriterIdGenerator:
|
||||
self._current_positions[instance_name] = max(
|
||||
new_id, self._current_positions.get(instance_name, 0)
|
||||
)
|
||||
|
||||
self._add_persisted_position(new_id)
|
||||
|
||||
def get_persisted_upto_position(self) -> int:
|
||||
"""Get the max position where all previous positions have been
|
||||
persisted.
|
||||
|
||||
Note: In the worst case scenario this will be equal to the minimum
|
||||
position across writers. This means that the returned position here can
|
||||
lag if one writer doesn't write very often.
|
||||
"""
|
||||
|
||||
with self._lock:
|
||||
return self._persisted_upto_position
|
||||
|
||||
def _add_persisted_position(self, new_id: int):
|
||||
"""Record that we have persisted a position.
|
||||
|
||||
This is used to keep the `_current_positions` up to date.
|
||||
"""
|
||||
|
||||
# We require that the lock is locked by caller
|
||||
assert self._lock.locked()
|
||||
|
||||
heapq.heappush(self._known_persisted_positions, new_id)
|
||||
|
||||
# We move the current min position up if the minimum current positions
|
||||
# of all instances is higher (since by definition all positions less
|
||||
# that that have been persisted).
|
||||
min_curr = min(self._current_positions.values())
|
||||
self._persisted_upto_position = max(min_curr, self._persisted_upto_position)
|
||||
|
||||
# We now iterate through the seen positions, discarding those that are
|
||||
# less than the current min positions, and incrementing the min position
|
||||
# if its exactly one greater.
|
||||
#
|
||||
# This is also where we discard items from `_known_persisted_positions`
|
||||
# (to ensure the list doesn't infinitely grow).
|
||||
while self._known_persisted_positions:
|
||||
if self._known_persisted_positions[0] <= self._persisted_upto_position:
|
||||
heapq.heappop(self._known_persisted_positions)
|
||||
elif (
|
||||
self._known_persisted_positions[0] == self._persisted_upto_position + 1
|
||||
):
|
||||
heapq.heappop(self._known_persisted_positions)
|
||||
self._persisted_upto_position += 1
|
||||
else:
|
||||
# There was a gap in seen positions, so there is nothing more to
|
||||
# do.
|
||||
break
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
import abc
|
||||
import threading
|
||||
from typing import Callable, Optional
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
|
||||
from synapse.storage.types import Cursor
|
||||
@@ -39,6 +39,12 @@ class PostgresSequenceGenerator(SequenceGenerator):
|
||||
txn.execute("SELECT nextval(?)", (self._sequence_name,))
|
||||
return txn.fetchone()[0]
|
||||
|
||||
def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
|
||||
txn.execute(
|
||||
"SELECT nextval(?) FROM generate_series(1, ?)", (self._sequence_name, n)
|
||||
)
|
||||
return [i for (i,) in txn]
|
||||
|
||||
|
||||
GetFirstCallbackType = Callable[[Cursor], int]
|
||||
|
||||
|
||||
@@ -374,12 +374,16 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
|
||||
self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
|
||||
self.handler._auth_handler.complete_sso_login = simple_async_mock()
|
||||
request = Mock(spec=["args", "getCookie", "addCookie"])
|
||||
request = Mock(
|
||||
spec=["args", "getCookie", "addCookie", "requestHeaders", "getClientIP"]
|
||||
)
|
||||
|
||||
code = "code"
|
||||
state = "state"
|
||||
nonce = "nonce"
|
||||
client_redirect_url = "http://client/redirect"
|
||||
user_agent = "Browser"
|
||||
ip_address = "10.0.0.1"
|
||||
session = self.handler._generate_oidc_session_token(
|
||||
state=state,
|
||||
nonce=nonce,
|
||||
@@ -392,6 +396,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
request.args[b"code"] = [code.encode("utf-8")]
|
||||
request.args[b"state"] = [state.encode("utf-8")]
|
||||
|
||||
request.requestHeaders = Mock(spec=["getRawHeaders"])
|
||||
request.requestHeaders.getRawHeaders.return_value = [user_agent.encode("ascii")]
|
||||
request.getClientIP.return_value = ip_address
|
||||
|
||||
yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
|
||||
|
||||
self.handler._auth_handler.complete_sso_login.assert_called_once_with(
|
||||
@@ -399,7 +407,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
)
|
||||
self.handler._exchange_code.assert_called_once_with(code)
|
||||
self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce)
|
||||
self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token)
|
||||
self.handler._map_userinfo_to_user.assert_called_once_with(
|
||||
userinfo, token, user_agent, ip_address
|
||||
)
|
||||
self.handler._fetch_userinfo.assert_not_called()
|
||||
self.handler._render_error.assert_not_called()
|
||||
|
||||
@@ -431,7 +441,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
)
|
||||
self.handler._exchange_code.assert_called_once_with(code)
|
||||
self.handler._parse_id_token.assert_not_called()
|
||||
self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token)
|
||||
self.handler._map_userinfo_to_user.assert_called_once_with(
|
||||
userinfo, token, user_agent, ip_address
|
||||
)
|
||||
self.handler._fetch_userinfo.assert_called_once_with(token)
|
||||
self.handler._render_error.assert_not_called()
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ from mock import Mock
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.auth import Auth
|
||||
from synapse.api.constants import UserTypes
|
||||
from synapse.api.errors import Codes, ResourceLimitError, SynapseError
|
||||
from synapse.handlers.register import RegistrationHandler
|
||||
@@ -25,16 +26,18 @@ from synapse.rest.client.v2_alpha.register import (
|
||||
_map_email_to_displayname,
|
||||
register_servlets,
|
||||
)
|
||||
from synapse.spam_checker_api import RegistrationBehaviour
|
||||
from synapse.types import RoomAlias, UserID, create_requester
|
||||
|
||||
from tests.server import FakeChannel
|
||||
from tests.test_utils import make_awaitable
|
||||
from tests.unittest import override_config
|
||||
from tests.utils import mock_getRawHeaders
|
||||
|
||||
from .. import unittest
|
||||
|
||||
|
||||
class RegistrationHandlers(object):
|
||||
class RegistrationHandlers:
|
||||
def __init__(self, hs):
|
||||
self.registration_handler = RegistrationHandler(hs)
|
||||
|
||||
@@ -485,6 +488,53 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
|
||||
self.handler.register_user(localpart=invalid_user_id), SynapseError
|
||||
)
|
||||
|
||||
def test_spam_checker_deny(self):
|
||||
"""A spam checker can deny registration, which results in an error."""
|
||||
|
||||
class DenyAll:
|
||||
def check_registration_for_spam(
|
||||
self, email_threepid, username, request_info
|
||||
):
|
||||
return RegistrationBehaviour.DENY
|
||||
|
||||
# Configure a spam checker that denies all users.
|
||||
spam_checker = self.hs.get_spam_checker()
|
||||
spam_checker.spam_checkers = [DenyAll()]
|
||||
|
||||
self.get_failure(self.handler.register_user(localpart="user"), SynapseError)
|
||||
|
||||
def test_spam_checker_shadow_ban(self):
|
||||
"""A spam checker can choose to shadow-ban a user, which allows registration to succeed."""
|
||||
|
||||
class BanAll:
|
||||
def check_registration_for_spam(
|
||||
self, email_threepid, username, request_info
|
||||
):
|
||||
return RegistrationBehaviour.SHADOW_BAN
|
||||
|
||||
# Configure a spam checker that denies all users.
|
||||
spam_checker = self.hs.get_spam_checker()
|
||||
spam_checker.spam_checkers = [BanAll()]
|
||||
|
||||
user_id = self.get_success(self.handler.register_user(localpart="user"))
|
||||
|
||||
# Get an access token.
|
||||
token = self.macaroon_generator.generate_access_token(user_id)
|
||||
self.get_success(
|
||||
self.store.add_access_token_to_user(
|
||||
user_id=user_id, token=token, device_id=None, valid_until_ms=None
|
||||
)
|
||||
)
|
||||
|
||||
# Ensure the user was marked as shadow-banned.
|
||||
request = Mock(args={})
|
||||
request.args[b"access_token"] = [token.encode("ascii")]
|
||||
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
|
||||
auth = Auth(self.hs)
|
||||
requester = self.get_success(auth.get_user_by_req(request))
|
||||
|
||||
self.assertTrue(requester.shadow_banned)
|
||||
|
||||
def test_email_to_displayname_mapping(self):
|
||||
"""Test that custom emails are mapped to new user displaynames correctly"""
|
||||
self._check_mapping(
|
||||
|
||||
@@ -238,7 +238,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
def test_spam_checker(self):
|
||||
"""
|
||||
A user which fails to the spam checks will not appear in search results.
|
||||
A user which fails the spam checks will not appear in search results.
|
||||
"""
|
||||
u1 = self.register_user("user1", "pass")
|
||||
u1_token = self.login(u1, "pass")
|
||||
@@ -269,7 +269,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
|
||||
# Configure a spam checker that does not filter any users.
|
||||
spam_checker = self.hs.get_spam_checker()
|
||||
|
||||
class AllowAll(object):
|
||||
class AllowAll:
|
||||
def check_username_for_spam(self, user_profile):
|
||||
# Allow all users.
|
||||
return False
|
||||
@@ -282,7 +282,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEqual(len(s["results"]), 1)
|
||||
|
||||
# Configure a spam checker that filters all users.
|
||||
class BlockAll(object):
|
||||
class BlockAll:
|
||||
def check_username_for_spam(self, user_profile):
|
||||
# All users are spammy.
|
||||
return True
|
||||
|
||||
@@ -46,33 +46,16 @@ class RetentionTestCase(unittest.HomeserverTestCase):
|
||||
}
|
||||
|
||||
self.hs = self.setup_test_homeserver(config=config)
|
||||
|
||||
return self.hs
|
||||
|
||||
def prepare(self, reactor, clock, homeserver):
|
||||
self.user_id = self.register_user("user", "password")
|
||||
self.token = self.login("user", "password")
|
||||
|
||||
def test_retention_state_event(self):
|
||||
"""Tests that the server configuration can limit the values a user can set to the
|
||||
room's retention policy.
|
||||
"""
|
||||
room_id = self.helper.create_room_as(self.user_id, tok=self.token)
|
||||
|
||||
self.helper.send_state(
|
||||
room_id=room_id,
|
||||
event_type=EventTypes.Retention,
|
||||
body={"max_lifetime": one_day_ms * 4},
|
||||
tok=self.token,
|
||||
expect_code=400,
|
||||
)
|
||||
|
||||
self.helper.send_state(
|
||||
room_id=room_id,
|
||||
event_type=EventTypes.Retention,
|
||||
body={"max_lifetime": one_hour_ms},
|
||||
tok=self.token,
|
||||
expect_code=400,
|
||||
)
|
||||
self.store = self.hs.get_datastore()
|
||||
self.serializer = self.hs.get_event_client_serializer()
|
||||
self.clock = self.hs.get_clock()
|
||||
|
||||
def test_retention_event_purged_with_state_event(self):
|
||||
"""Tests that expired events are correctly purged when the room's retention policy
|
||||
@@ -91,6 +74,36 @@ class RetentionTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
self._test_retention_event_purged(room_id, one_day_ms * 1.5)
|
||||
|
||||
def test_retention_event_purged_with_state_event_outside_allowed(self):
|
||||
"""Tests that the server configuration can override the policy for a room when
|
||||
running the purge jobs.
|
||||
"""
|
||||
room_id = self.helper.create_room_as(self.user_id, tok=self.token)
|
||||
|
||||
# Set a max_lifetime higher than the maximum allowed value.
|
||||
self.helper.send_state(
|
||||
room_id=room_id,
|
||||
event_type=EventTypes.Retention,
|
||||
body={"max_lifetime": one_day_ms * 4},
|
||||
tok=self.token,
|
||||
)
|
||||
|
||||
# Check that the event is purged after waiting for the maximum allowed duration
|
||||
# instead of the one specified in the room's policy.
|
||||
self._test_retention_event_purged(room_id, one_day_ms * 1.5)
|
||||
|
||||
# Set a max_lifetime lower than the minimum allowed value.
|
||||
self.helper.send_state(
|
||||
room_id=room_id,
|
||||
event_type=EventTypes.Retention,
|
||||
body={"max_lifetime": one_hour_ms},
|
||||
tok=self.token,
|
||||
)
|
||||
|
||||
# Check that the event is purged after waiting for the minimum allowed duration
|
||||
# instead of the one specified in the room's policy.
|
||||
self._test_retention_event_purged(room_id, one_day_ms * 0.5)
|
||||
|
||||
def test_retention_event_purged_without_state_event(self):
|
||||
"""Tests that expired events are correctly purged when the room's retention policy
|
||||
is defined by the server's configuration's default retention policy.
|
||||
@@ -141,7 +154,27 @@ class RetentionTestCase(unittest.HomeserverTestCase):
|
||||
# That event should be the second, not outdated event.
|
||||
self.assertEqual(filtered_events[0].event_id, valid_event_id, filtered_events)
|
||||
|
||||
def _test_retention_event_purged(self, room_id, increment):
|
||||
def _test_retention_event_purged(self, room_id: str, increment: float):
|
||||
"""Run the following test scenario to test the message retention policy support:
|
||||
|
||||
1. Send event 1
|
||||
2. Increment time by `increment`
|
||||
3. Send event 2
|
||||
4. Increment time by `increment`
|
||||
5. Check that event 1 has been purged
|
||||
6. Check that event 2 has not been purged
|
||||
7. Check that state events that were sent before event 1 aren't purged.
|
||||
The main reason for sending a second event is because currently Synapse won't
|
||||
purge the latest message in a room because it would otherwise result in a lack of
|
||||
forward extremities for this room. It's also a good thing to ensure the purge jobs
|
||||
aren't too greedy and purge messages they shouldn't.
|
||||
|
||||
Args:
|
||||
room_id: The ID of the room to test retention in.
|
||||
increment: The number of milliseconds to advance the clock each time. Must be
|
||||
defined so that events in the room aren't purged if they are `increment`
|
||||
old but are purged if they are `increment * 2` old.
|
||||
"""
|
||||
# Get the create event to, later, check that we can still access it.
|
||||
message_handler = self.hs.get_message_handler()
|
||||
create_event = self.get_success(
|
||||
@@ -157,7 +190,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
|
||||
expired_event_id = resp.get("event_id")
|
||||
|
||||
# Check that we can retrieve the event.
|
||||
expired_event = self.get_event(room_id, expired_event_id)
|
||||
expired_event = self.get_event(expired_event_id)
|
||||
self.assertEqual(
|
||||
expired_event.get("content", {}).get("body"), "1", expired_event
|
||||
)
|
||||
@@ -175,26 +208,31 @@ class RetentionTestCase(unittest.HomeserverTestCase):
|
||||
# one should still be kept.
|
||||
self.reactor.advance(increment / 1000)
|
||||
|
||||
# Check that the event has been purged from the database.
|
||||
self.get_event(room_id, expired_event_id, expected_code=404)
|
||||
# Check that the first event has been purged from the database, i.e. that we
|
||||
# can't retrieve it anymore, because it has expired.
|
||||
self.get_event(expired_event_id, expect_none=True)
|
||||
|
||||
# Check that the event that hasn't been purged can still be retrieved.
|
||||
valid_event = self.get_event(room_id, valid_event_id)
|
||||
# Check that the event that hasn't expired can still be retrieved.
|
||||
valid_event = self.get_event(valid_event_id)
|
||||
self.assertEqual(valid_event.get("content", {}).get("body"), "2", valid_event)
|
||||
|
||||
# Check that we can still access state events that were sent before the event that
|
||||
# has been purged.
|
||||
self.get_event(room_id, create_event.event_id)
|
||||
|
||||
def get_event(self, room_id, event_id, expected_code=200):
|
||||
url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id)
|
||||
def get_event(self, event_id, expect_none=False):
|
||||
event = self.get_success(self.store.get_event(event_id, allow_none=True))
|
||||
|
||||
request, channel = self.make_request("GET", url, access_token=self.token)
|
||||
self.render(request)
|
||||
if expect_none:
|
||||
self.assertIsNone(event)
|
||||
return {}
|
||||
|
||||
self.assertEqual(channel.code, expected_code, channel.result)
|
||||
self.assertIsNotNone(event)
|
||||
|
||||
return channel.json_body
|
||||
time_now = self.clock.time_msec()
|
||||
serialized = self.get_success(self.serializer.serialize_event(event, time_now))
|
||||
|
||||
return serialized
|
||||
|
||||
|
||||
class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
@@ -21,13 +21,13 @@
|
||||
import json
|
||||
from urllib import parse as urlparse
|
||||
|
||||
from mock import Mock
|
||||
from mock import Mock, patch
|
||||
|
||||
import synapse.rest.admin
|
||||
from synapse.api.constants import EventContentFields, EventTypes, Membership
|
||||
from synapse.handlers.pagination import PurgeStatus
|
||||
from synapse.rest.client.v1 import directory, login, profile, room
|
||||
from synapse.rest.client.v2_alpha import account
|
||||
from synapse.rest.client.v2_alpha import account, room_upgrade_rest_servlet
|
||||
from synapse.types import JsonDict, RoomAlias, UserID
|
||||
from synapse.util.stringutils import random_string
|
||||
|
||||
@@ -684,38 +684,39 @@ class RoomJoinRatelimitTestCase(RoomBase):
|
||||
]
|
||||
|
||||
@unittest.override_config(
|
||||
{"rc_joins": {"local": {"per_second": 3, "burst_count": 3}}}
|
||||
{"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
|
||||
)
|
||||
def test_join_local_ratelimit(self):
|
||||
"""Tests that local joins are actually rate-limited."""
|
||||
for i in range(5):
|
||||
for i in range(3):
|
||||
self.helper.create_room_as(self.user_id)
|
||||
|
||||
self.helper.create_room_as(self.user_id, expect_code=429)
|
||||
|
||||
@unittest.override_config(
|
||||
{"rc_joins": {"local": {"per_second": 3, "burst_count": 3}}}
|
||||
{"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
|
||||
)
|
||||
def test_join_local_ratelimit_profile_change(self):
|
||||
"""Tests that sending a profile update into all of the user's joined rooms isn't
|
||||
rate-limited by the rate-limiter on joins."""
|
||||
|
||||
# Create and join more rooms than the rate-limiting config allows in a second.
|
||||
# Create and join as many rooms as the rate-limiting config allows in a second.
|
||||
room_ids = [
|
||||
self.helper.create_room_as(self.user_id),
|
||||
self.helper.create_room_as(self.user_id),
|
||||
self.helper.create_room_as(self.user_id),
|
||||
]
|
||||
self.reactor.advance(1)
|
||||
room_ids = room_ids + [
|
||||
self.helper.create_room_as(self.user_id),
|
||||
self.helper.create_room_as(self.user_id),
|
||||
self.helper.create_room_as(self.user_id),
|
||||
]
|
||||
# Let some time for the rate-limiter to forget about our multi-join.
|
||||
self.reactor.advance(2)
|
||||
# Add one to make sure we're joined to more rooms than the config allows us to
|
||||
# join in a second.
|
||||
room_ids.append(self.helper.create_room_as(self.user_id))
|
||||
|
||||
# Create a profile for the user, since it hasn't been done on registration.
|
||||
store = self.hs.get_datastore()
|
||||
store.create_profile(UserID.from_string(self.user_id).localpart)
|
||||
self.get_success(
|
||||
store.create_profile(UserID.from_string(self.user_id).localpart)
|
||||
)
|
||||
|
||||
# Update the display name for the user.
|
||||
path = "/_matrix/client/r0/profile/%s/displayname" % self.user_id
|
||||
@@ -738,7 +739,7 @@ class RoomJoinRatelimitTestCase(RoomBase):
|
||||
self.assertEquals(channel.json_body["displayname"], "John Doe")
|
||||
|
||||
@unittest.override_config(
|
||||
{"rc_joins": {"local": {"per_second": 3, "burst_count": 3}}}
|
||||
{"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
|
||||
)
|
||||
def test_join_local_ratelimit_idempotent(self):
|
||||
"""Tests that the room join endpoints remain idempotent despite rate-limiting
|
||||
@@ -754,7 +755,7 @@ class RoomJoinRatelimitTestCase(RoomBase):
|
||||
for path in paths_to_test:
|
||||
# Make sure we send more requests than the rate-limiting config would allow
|
||||
# if all of these requests ended up joining the user to a room.
|
||||
for i in range(6):
|
||||
for i in range(4):
|
||||
request, channel = self.make_request("POST", path % room_id, {})
|
||||
self.render(request)
|
||||
self.assertEquals(channel.code, 200)
|
||||
@@ -2059,3 +2060,158 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
|
||||
"""An alias which does not point to the room raises a SynapseError."""
|
||||
self._set_canonical_alias({"alias": "@unknown:test"}, expected_code=400)
|
||||
self._set_canonical_alias({"alt_aliases": ["@unknown:test"]}, expected_code=400)
|
||||
|
||||
|
||||
# To avoid the tests timing out don't add a delay to "annoy the requester".
|
||||
@patch("random.randint", new=lambda a, b: 0)
|
||||
class ShadowBannedTestCase(unittest.HomeserverTestCase):
|
||||
servlets = [
|
||||
synapse.rest.admin.register_servlets_for_client_rest_resource,
|
||||
directory.register_servlets,
|
||||
login.register_servlets,
|
||||
room.register_servlets,
|
||||
room_upgrade_rest_servlet.register_servlets,
|
||||
]
|
||||
|
||||
def prepare(self, reactor, clock, homeserver):
|
||||
self.banned_user_id = self.register_user("banned", "test")
|
||||
self.banned_access_token = self.login("banned", "test")
|
||||
|
||||
self.store = self.hs.get_datastore()
|
||||
|
||||
self.get_success(
|
||||
self.store.db_pool.simple_update(
|
||||
table="users",
|
||||
keyvalues={"name": self.banned_user_id},
|
||||
updatevalues={"shadow_banned": True},
|
||||
desc="shadow_ban",
|
||||
)
|
||||
)
|
||||
|
||||
self.other_user_id = self.register_user("otheruser", "pass")
|
||||
self.other_access_token = self.login("otheruser", "pass")
|
||||
|
||||
def test_invite(self):
|
||||
"""Invites from shadow-banned users don't actually get sent."""
|
||||
|
||||
# The create works fine.
|
||||
room_id = self.helper.create_room_as(
|
||||
self.banned_user_id, tok=self.banned_access_token
|
||||
)
|
||||
|
||||
# Inviting the user completes successfully.
|
||||
self.helper.invite(
|
||||
room=room_id,
|
||||
src=self.banned_user_id,
|
||||
tok=self.banned_access_token,
|
||||
targ=self.other_user_id,
|
||||
)
|
||||
|
||||
# But the user wasn't actually invited.
|
||||
invited_rooms = self.get_success(
|
||||
self.store.get_invited_rooms_for_local_user(self.other_user_id)
|
||||
)
|
||||
self.assertEqual(invited_rooms, [])
|
||||
|
||||
def test_invite_3pid(self):
|
||||
"""Ensure that a 3PID invite does not attempt to contact the identity server."""
|
||||
identity_handler = self.hs.get_handlers().identity_handler
|
||||
identity_handler.lookup_3pid = Mock(
|
||||
side_effect=AssertionError("This should not get called")
|
||||
)
|
||||
|
||||
# The create works fine.
|
||||
room_id = self.helper.create_room_as(
|
||||
self.banned_user_id, tok=self.banned_access_token
|
||||
)
|
||||
|
||||
# Inviting the user completes successfully.
|
||||
request, channel = self.make_request(
|
||||
"POST",
|
||||
"/rooms/%s/invite" % (room_id,),
|
||||
{"id_server": "test", "medium": "email", "address": "test@test.test"},
|
||||
access_token=self.banned_access_token,
|
||||
)
|
||||
self.render(request)
|
||||
self.assertEquals(200, channel.code, channel.result)
|
||||
|
||||
# This should have raised an error earlier, but double check this wasn't called.
|
||||
identity_handler.lookup_3pid.assert_not_called()
|
||||
|
||||
def test_create_room(self):
|
||||
"""Invitations during a room creation should be discarded, but the room still gets created."""
|
||||
# The room creation is successful.
|
||||
request, channel = self.make_request(
|
||||
"POST",
|
||||
"/_matrix/client/r0/createRoom",
|
||||
{"visibility": "public", "invite": [self.other_user_id]},
|
||||
access_token=self.banned_access_token,
|
||||
)
|
||||
self.render(request)
|
||||
self.assertEquals(200, channel.code, channel.result)
|
||||
room_id = channel.json_body["room_id"]
|
||||
|
||||
# But the user wasn't actually invited.
|
||||
invited_rooms = self.get_success(
|
||||
self.store.get_invited_rooms_for_local_user(self.other_user_id)
|
||||
)
|
||||
self.assertEqual(invited_rooms, [])
|
||||
|
||||
# Since a real room was created, the other user should be able to join it.
|
||||
self.helper.join(room_id, self.other_user_id, tok=self.other_access_token)
|
||||
|
||||
# Both users should be in the room.
|
||||
users = self.get_success(self.store.get_users_in_room(room_id))
|
||||
self.assertCountEqual(users, ["@banned:test", "@otheruser:test"])
|
||||
|
||||
def test_message(self):
|
||||
"""Messages from shadow-banned users don't actually get sent."""
|
||||
|
||||
room_id = self.helper.create_room_as(
|
||||
self.other_user_id, tok=self.other_access_token
|
||||
)
|
||||
|
||||
# The user should be in the room.
|
||||
self.helper.join(room_id, self.banned_user_id, tok=self.banned_access_token)
|
||||
|
||||
# Sending a message should complete successfully.
|
||||
result = self.helper.send_event(
|
||||
room_id=room_id,
|
||||
type=EventTypes.Message,
|
||||
content={"msgtype": "m.text", "body": "with right label"},
|
||||
tok=self.banned_access_token,
|
||||
)
|
||||
self.assertIn("event_id", result)
|
||||
event_id = result["event_id"]
|
||||
|
||||
latest_events = self.get_success(
|
||||
self.store.get_latest_event_ids_in_room(room_id)
|
||||
)
|
||||
self.assertNotIn(event_id, latest_events)
|
||||
|
||||
def test_upgrade(self):
|
||||
"""A room upgrade should fail, but look like it succeeded."""
|
||||
|
||||
# The create works fine.
|
||||
room_id = self.helper.create_room_as(
|
||||
self.banned_user_id, tok=self.banned_access_token
|
||||
)
|
||||
|
||||
request, channel = self.make_request(
|
||||
"POST",
|
||||
"/_matrix/client/r0/rooms/%s/upgrade" % (room_id,),
|
||||
{"new_version": "6"},
|
||||
access_token=self.banned_access_token,
|
||||
)
|
||||
self.render(request)
|
||||
self.assertEquals(200, channel.code, channel.result)
|
||||
# A new room_id should be returned.
|
||||
self.assertIn("replacement_room", channel.json_body)
|
||||
|
||||
new_room_id = channel.json_body["replacement_room"]
|
||||
|
||||
# It doesn't really matter what API we use here, we just want to assert
|
||||
# that the room doesn't exist.
|
||||
summary = self.get_success(self.store.get_room_summary(new_room_id))
|
||||
# The summary should be empty since the room doesn't exist.
|
||||
self.assertEqual(summary, {})
|
||||
|
||||
@@ -182,3 +182,39 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||
|
||||
self.assertEqual(id_gen.get_positions(), {"master": 8})
|
||||
self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
|
||||
|
||||
def test_get_persisted_upto_position(self):
|
||||
"""Test that `get_persisted_upto_position` correctly tracks updates to
|
||||
positions.
|
||||
"""
|
||||
|
||||
self._insert_rows("first", 3)
|
||||
self._insert_rows("second", 5)
|
||||
|
||||
id_gen = self._create_id_generator("first")
|
||||
|
||||
# Min is 3 and there is a gap between 5, so we expect it to be 3.
|
||||
self.assertEqual(id_gen.get_persisted_upto_position(), 3)
|
||||
|
||||
# We advance "first" straight to 6. Min is now 5 but there is no gap so
|
||||
# we expect it to be 6
|
||||
id_gen.advance("first", 6)
|
||||
self.assertEqual(id_gen.get_persisted_upto_position(), 6)
|
||||
|
||||
# No gap, so we expect 7.
|
||||
id_gen.advance("second", 7)
|
||||
self.assertEqual(id_gen.get_persisted_upto_position(), 7)
|
||||
|
||||
# We haven't seen 8 yet, so we expect 7 still.
|
||||
id_gen.advance("second", 9)
|
||||
self.assertEqual(id_gen.get_persisted_upto_position(), 7)
|
||||
|
||||
# Now that we've seen 7, 8 and 9 we can got straight to 9.
|
||||
id_gen.advance("first", 8)
|
||||
self.assertEqual(id_gen.get_persisted_upto_position(), 9)
|
||||
|
||||
# Jump forward with gaps. The minimum is 11, even though we haven't seen
|
||||
# 10 we know that everything before 11 must be persisted.
|
||||
id_gen.advance("first", 11)
|
||||
id_gen.advance("second", 15)
|
||||
self.assertEqual(id_gen.get_persisted_upto_position(), 11)
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import UserTypes
|
||||
from synapse.api.errors import ThreepidValidationError
|
||||
|
||||
from tests import unittest
|
||||
from tests.utils import setup_test_homeserver
|
||||
@@ -122,3 +123,33 @@ class RegistrationStoreTestCase(unittest.TestCase):
|
||||
)
|
||||
res = yield self.store.is_support_user(SUPPORT_USER)
|
||||
self.assertTrue(res)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_3pid_inhibit_invalid_validation_session_error(self):
|
||||
"""Tests that enabling the configuration option to inhibit 3PID errors on
|
||||
/requestToken also inhibits validation errors caused by an unknown session ID.
|
||||
"""
|
||||
|
||||
# Check that, with the config setting set to false (the default value), a
|
||||
# validation error is caused by the unknown session ID.
|
||||
try:
|
||||
yield defer.ensureDeferred(
|
||||
self.store.validate_threepid_session(
|
||||
"fake_sid", "fake_client_secret", "fake_token", 0,
|
||||
)
|
||||
)
|
||||
except ThreepidValidationError as e:
|
||||
self.assertEquals(e.msg, "Unknown session_id", e)
|
||||
|
||||
# Set the config setting to true.
|
||||
self.store._ignore_unknown_session_error = True
|
||||
|
||||
# Check that now the validation error is caused by the token not matching.
|
||||
try:
|
||||
yield defer.ensureDeferred(
|
||||
self.store.validate_threepid_session(
|
||||
"fake_sid", "fake_client_secret", "fake_token", 0,
|
||||
)
|
||||
)
|
||||
except ThreepidValidationError as e:
|
||||
self.assertEquals(e.msg, "Validation token not found or has expired", e)
|
||||
|
||||
Reference in New Issue
Block a user