Merge branch 'develop' of github.com:matrix-org/synapse into anoa/password_reset_confirmation
* 'develop' of github.com:matrix-org/synapse: (22 commits) Update the test federation client to handle streaming responses (#8130) Do not propagate profile changes of shadow-banned users into rooms. (#8157) Make SlavedIdTracker.advance have same interface as MultiWriterIDGenerator (#8171) Convert simple_select_one and simple_select_one_onecol to async (#8162) Fix rate limiting unit tests. (#8167) Add functions to `MultiWriterIdGen` used by events stream (#8164) Do not allow send_nonmember_event to be called with shadow-banned users. (#8158) Changelog fixes 1.19.1rc1 Make StreamIdGen `get_next` and `get_next_mult` async (#8161) Wording fixes to 'name' user admin api filter (#8163) Fix missing double-backtick in RST document Search in columns 'name' and 'displayname' in the admin users endpoint (#7377) Add type hints for state. (#8140) Stop shadow-banned users from sending non-member events. (#8142) Allow capping a room's retention policy (#8104) Add healthcheck for default localhost 8008 port on /health endpoint. (#8147) Fix flaky shadow-ban tests. (#8152) Fix join ratelimiter breaking profile updates and idempotency (#8153) Do not apply ratelimiting on joins to appservices (#8139) ...
This commit is contained in:
+10
@@ -12,6 +12,16 @@ from Synapse as most users have updated their client. Further context can be
|
||||
found at [\#6766](https://github.com/matrix-org/synapse/issues/6766).
|
||||
|
||||
|
||||
Synapse 1.19.1rc1 (2020-08-25)
|
||||
==============================
|
||||
|
||||
Bugfixes
|
||||
--------
|
||||
|
||||
- Fix a bug introduced in v1.19.0 where appservices with ratelimiting disabled would still be ratelimited when joining rooms. ([\#8139](https://github.com/matrix-org/synapse/issues/8139))
|
||||
- Fix a bug introduced in v1.19.0 that would cause e.g. profile updates to fail due to incorrect application of rate limits on join requests. ([\#8153](https://github.com/matrix-org/synapse/issues/8153))
|
||||
|
||||
|
||||
Synapse 1.19.0 (2020-08-17)
|
||||
===========================
|
||||
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
Add filter `name` to the `/users` admin API, which filters by user ID or displayname. Contributed by Awesome Technologies Innovationslabor GmbH.
|
||||
@@ -0,0 +1 @@
|
||||
Don't fail `/submit_token` requests on incorrect session ID if `request_token_inhibit_3pid_errors` is turned on.
|
||||
@@ -0,0 +1 @@
|
||||
Fix a bug introduced in v1.7.2 impacting message retention policies that would allow federated homeservers to dictate a retention period that's lower than the configured minimum allowed duration in the configuration file.
|
||||
@@ -0,0 +1 @@
|
||||
Update the test federation client to handle streaming responses.
|
||||
@@ -0,0 +1 @@
|
||||
Fixes a bug where appservices with ratelimiting disabled would still be ratelimited when joining rooms. This bug was introduced in v1.19.0.
|
||||
@@ -0,0 +1 @@
|
||||
Add type hints to `synapse.state`.
|
||||
@@ -0,0 +1 @@
|
||||
Add support for shadow-banning users (ignoring any message send requests).
|
||||
@@ -0,0 +1 @@
|
||||
Added curl for healthcheck support and readme updates for the change. Contributed by @maquis196.
|
||||
@@ -0,0 +1 @@
|
||||
Add support for shadow-banning users (ignoring any message send requests).
|
||||
@@ -0,0 +1 @@
|
||||
Add support for shadow-banning users (ignoring any message send requests).
|
||||
@@ -0,0 +1 @@
|
||||
Add support for shadow-banning users (ignoring any message send requests).
|
||||
@@ -0,0 +1 @@
|
||||
Refactor `StreamIdGenerator` and `MultiWriterIdGenerator` to have the same interface.
|
||||
@@ -0,0 +1 @@
|
||||
Convert various parts of the codebase to async/await.
|
||||
@@ -0,0 +1 @@
|
||||
Add filter `name` to the `/users` admin API, which filters by user ID or displayname. Contributed by Awesome Technologies Innovationslabor GmbH.
|
||||
@@ -0,0 +1 @@
|
||||
Add functions to `MultiWriterIdGen` used by events stream.
|
||||
@@ -0,0 +1 @@
|
||||
Fix tests that were broken due to the merge of 1.19.1.
|
||||
@@ -0,0 +1 @@
|
||||
Make `SlavedIdTracker.advance` have the same interface as `MultiWriterIDGenerator`.
|
||||
@@ -55,6 +55,7 @@ RUN pip install --prefix="/install" --no-warn-script-location \
|
||||
FROM docker.io/python:${PYTHON_VERSION}-slim
|
||||
|
||||
RUN apt-get update && apt-get install -y \
|
||||
curl \
|
||||
libpq5 \
|
||||
xmlsec1 \
|
||||
gosu \
|
||||
@@ -69,3 +70,6 @@ VOLUME ["/data"]
|
||||
EXPOSE 8008/tcp 8009/tcp 8448/tcp
|
||||
|
||||
ENTRYPOINT ["/start.py"]
|
||||
|
||||
HEALTHCHECK --interval=1m --timeout=5s \
|
||||
CMD curl -fSs http://localhost:8008/health || exit 1
|
||||
|
||||
@@ -162,3 +162,32 @@ docker build -t matrixdotorg/synapse -f docker/Dockerfile .
|
||||
|
||||
You can choose to build a different docker image by changing the value of the `-f` flag to
|
||||
point to another Dockerfile.
|
||||
|
||||
## Disabling the healthcheck
|
||||
|
||||
If you are using a non-standard port or tls inside docker you can disable the healthcheck
|
||||
whilst running the above `docker run` commands.
|
||||
|
||||
```
|
||||
--no-healthcheck
|
||||
```
|
||||
## Setting custom healthcheck on docker run
|
||||
|
||||
If you wish to point the healthcheck at a different port with docker command, add the following
|
||||
|
||||
```
|
||||
--health-cmd 'curl -fSs http://localhost:1234/health'
|
||||
```
|
||||
|
||||
## Setting the healthcheck in docker-compose file
|
||||
|
||||
You can add the following to set a custom healthcheck in a docker compose file.
|
||||
You will need version >2.1 for this to work.
|
||||
|
||||
```
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-fSs", "http://localhost:8008/health"]
|
||||
interval: 1m
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
```
|
||||
|
||||
@@ -108,7 +108,7 @@ The api is::
|
||||
|
||||
GET /_synapse/admin/v2/users?from=0&limit=10&guests=false
|
||||
|
||||
To use it, you will need to authenticate by providing an `access_token` for a
|
||||
To use it, you will need to authenticate by providing an ``access_token`` for a
|
||||
server admin: see `README.rst <README.rst>`_.
|
||||
|
||||
The parameter ``from`` is optional but used for pagination, denoting the
|
||||
@@ -119,8 +119,11 @@ from a previous call.
|
||||
The parameter ``limit`` is optional but is used for pagination, denoting the
|
||||
maximum number of items to return in this call. Defaults to ``100``.
|
||||
|
||||
The parameter ``user_id`` is optional and filters to only users with user IDs
|
||||
that contain this value.
|
||||
The parameter ``user_id`` is optional and filters to only return users with user IDs
|
||||
that contain this value. This parameter is ignored when using the ``name`` parameter.
|
||||
|
||||
The parameter ``name`` is optional and filters to only return users with user ID localparts
|
||||
**or** displaynames that contain this value.
|
||||
|
||||
The parameter ``guests`` is optional and if ``false`` will **exclude** guest users.
|
||||
Defaults to ``true`` to include guest users.
|
||||
|
||||
+14
-8
@@ -378,11 +378,10 @@ retention:
|
||||
# min_lifetime: 1d
|
||||
# max_lifetime: 1y
|
||||
|
||||
# Retention policy limits. If set, a user won't be able to send a
|
||||
# 'm.room.retention' event which features a 'min_lifetime' or a 'max_lifetime'
|
||||
# that's not within this range. This is especially useful in closed federations,
|
||||
# in which server admins can make sure every federating server applies the same
|
||||
# rules.
|
||||
# Retention policy limits. If set, and the state of a room contains a
|
||||
# 'm.room.retention' event in its state which contains a 'min_lifetime' or a
|
||||
# 'max_lifetime' that's out of these bounds, Synapse will cap the room's policy
|
||||
# to these limits when running purge jobs.
|
||||
#
|
||||
#allowed_lifetime_min: 1d
|
||||
#allowed_lifetime_max: 1y
|
||||
@@ -408,12 +407,19 @@ retention:
|
||||
# (e.g. every 12h), but not want that purge to be performed by a job that's
|
||||
# iterating over every room it knows, which could be heavy on the server.
|
||||
#
|
||||
# If any purge job is configured, it is strongly recommended to have at least
|
||||
# a single job with neither 'shortest_max_lifetime' nor 'longest_max_lifetime'
|
||||
# set, or one job without 'shortest_max_lifetime' and one job without
|
||||
# 'longest_max_lifetime' set. Otherwise some rooms might be ignored, even if
|
||||
# 'allowed_lifetime_min' and 'allowed_lifetime_max' are set, because capping a
|
||||
# room's policy to these values is done after the policies are retrieved from
|
||||
# Synapse's database (which is done using the range specified in a purge job's
|
||||
# configuration).
|
||||
#
|
||||
#purge_jobs:
|
||||
# - shortest_max_lifetime: 1d
|
||||
# longest_max_lifetime: 3d
|
||||
# - longest_max_lifetime: 3d
|
||||
# interval: 12h
|
||||
# - shortest_max_lifetime: 3d
|
||||
# longest_max_lifetime: 1y
|
||||
# interval: 1d
|
||||
|
||||
# Inhibits the /requestToken endpoints from returning an error that might leak
|
||||
|
||||
@@ -21,10 +21,12 @@ import argparse
|
||||
import base64
|
||||
import json
|
||||
import sys
|
||||
from typing import Any, Optional
|
||||
from urllib import parse as urlparse
|
||||
|
||||
import nacl.signing
|
||||
import requests
|
||||
import signedjson.types
|
||||
import srvlookup
|
||||
import yaml
|
||||
from requests.adapters import HTTPAdapter
|
||||
@@ -69,7 +71,9 @@ def encode_canonical_json(value):
|
||||
).encode("UTF-8")
|
||||
|
||||
|
||||
def sign_json(json_object, signing_key, signing_name):
|
||||
def sign_json(
|
||||
json_object: Any, signing_key: signedjson.types.SigningKey, signing_name: str
|
||||
) -> Any:
|
||||
signatures = json_object.pop("signatures", {})
|
||||
unsigned = json_object.pop("unsigned", None)
|
||||
|
||||
@@ -122,7 +126,14 @@ def read_signing_keys(stream):
|
||||
return keys
|
||||
|
||||
|
||||
def request_json(method, origin_name, origin_key, destination, path, content):
|
||||
def request(
|
||||
method: Optional[str],
|
||||
origin_name: str,
|
||||
origin_key: signedjson.types.SigningKey,
|
||||
destination: str,
|
||||
path: str,
|
||||
content: Optional[str],
|
||||
) -> requests.Response:
|
||||
if method is None:
|
||||
if content is None:
|
||||
method = "GET"
|
||||
@@ -159,11 +170,14 @@ def request_json(method, origin_name, origin_key, destination, path, content):
|
||||
if method == "POST":
|
||||
headers["Content-Type"] = "application/json"
|
||||
|
||||
result = s.request(
|
||||
method=method, url=dest, headers=headers, verify=False, data=content
|
||||
return s.request(
|
||||
method=method,
|
||||
url=dest,
|
||||
headers=headers,
|
||||
verify=False,
|
||||
data=content,
|
||||
stream=True,
|
||||
)
|
||||
sys.stderr.write("Status Code: %d\n" % (result.status_code,))
|
||||
return result.json()
|
||||
|
||||
|
||||
def main():
|
||||
@@ -222,7 +236,7 @@ def main():
|
||||
with open(args.signing_key_path) as f:
|
||||
key = read_signing_keys(f)[0]
|
||||
|
||||
result = request_json(
|
||||
result = request(
|
||||
args.method,
|
||||
args.server_name,
|
||||
key,
|
||||
@@ -231,7 +245,12 @@ def main():
|
||||
content=args.body,
|
||||
)
|
||||
|
||||
json.dump(result, sys.stdout)
|
||||
sys.stderr.write("Status Code: %d\n" % (result.status_code,))
|
||||
|
||||
for chunk in result.iter_content():
|
||||
# we write raw utf8 to stdout.
|
||||
sys.stdout.buffer.write(chunk)
|
||||
|
||||
print("")
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Stub for frozendict.
|
||||
|
||||
from typing import (
|
||||
Any,
|
||||
Hashable,
|
||||
Iterable,
|
||||
Iterator,
|
||||
Mapping,
|
||||
overload,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
_KT = TypeVar("_KT", bound=Hashable) # Key type.
|
||||
_VT = TypeVar("_VT") # Value type.
|
||||
|
||||
class frozendict(Mapping[_KT, _VT]):
|
||||
@overload
|
||||
def __init__(self, **kwargs: _VT) -> None: ...
|
||||
@overload
|
||||
def __init__(self, __map: Mapping[_KT, _VT], **kwargs: _VT) -> None: ...
|
||||
@overload
|
||||
def __init__(
|
||||
self, __iterable: Iterable[Tuple[_KT, _VT]], **kwargs: _VT
|
||||
) -> None: ...
|
||||
def __getitem__(self, key: _KT) -> _VT: ...
|
||||
def __contains__(self, key: Any) -> bool: ...
|
||||
def copy(self, **add_or_replace: Any) -> frozendict: ...
|
||||
def __iter__(self) -> Iterator[_KT]: ...
|
||||
def __len__(self) -> int: ...
|
||||
def __repr__(self) -> str: ...
|
||||
def __hash__(self) -> int: ...
|
||||
+1
-1
@@ -48,7 +48,7 @@ try:
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
__version__ = "1.19.0"
|
||||
__version__ = "1.19.1rc1"
|
||||
|
||||
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
|
||||
# We import here so that we don't have to install a bunch of deps when
|
||||
|
||||
@@ -17,6 +17,7 @@ from collections import OrderedDict
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
from synapse.api.errors import LimitExceededError
|
||||
from synapse.types import Requester
|
||||
from synapse.util import Clock
|
||||
|
||||
|
||||
@@ -43,6 +44,42 @@ class Ratelimiter(object):
|
||||
# * The rate_hz of this particular entry. This can vary per request
|
||||
self.actions = OrderedDict() # type: OrderedDict[Any, Tuple[float, int, float]]
|
||||
|
||||
def can_requester_do_action(
|
||||
self,
|
||||
requester: Requester,
|
||||
rate_hz: Optional[float] = None,
|
||||
burst_count: Optional[int] = None,
|
||||
update: bool = True,
|
||||
_time_now_s: Optional[int] = None,
|
||||
) -> Tuple[bool, float]:
|
||||
"""Can the requester perform the action?
|
||||
|
||||
Args:
|
||||
requester: The requester to key off when rate limiting. The user property
|
||||
will be used.
|
||||
rate_hz: The long term number of actions that can be performed in a second.
|
||||
Overrides the value set during instantiation if set.
|
||||
burst_count: How many actions that can be performed before being limited.
|
||||
Overrides the value set during instantiation if set.
|
||||
update: Whether to count this check as performing the action
|
||||
_time_now_s: The current time. Optional, defaults to the current time according
|
||||
to self.clock. Only used by tests.
|
||||
|
||||
Returns:
|
||||
A tuple containing:
|
||||
* A bool indicating if they can perform the action now
|
||||
* The reactor timestamp for when the action can be performed next.
|
||||
-1 if rate_hz is less than or equal to zero
|
||||
"""
|
||||
# Disable rate limiting of users belonging to any AS that is configured
|
||||
# not to be rate limited in its registration file (rate_limited: true|false).
|
||||
if requester.app_service and not requester.app_service.is_rate_limited():
|
||||
return True, -1.0
|
||||
|
||||
return self.can_do_action(
|
||||
requester.user.to_string(), rate_hz, burst_count, update, _time_now_s
|
||||
)
|
||||
|
||||
def can_do_action(
|
||||
self,
|
||||
key: Any,
|
||||
|
||||
@@ -961,11 +961,10 @@ class ServerConfig(Config):
|
||||
# min_lifetime: 1d
|
||||
# max_lifetime: 1y
|
||||
|
||||
# Retention policy limits. If set, a user won't be able to send a
|
||||
# 'm.room.retention' event which features a 'min_lifetime' or a 'max_lifetime'
|
||||
# that's not within this range. This is especially useful in closed federations,
|
||||
# in which server admins can make sure every federating server applies the same
|
||||
# rules.
|
||||
# Retention policy limits. If set, and the state of a room contains a
|
||||
# 'm.room.retention' event in its state which contains a 'min_lifetime' or a
|
||||
# 'max_lifetime' that's out of these bounds, Synapse will cap the room's policy
|
||||
# to these limits when running purge jobs.
|
||||
#
|
||||
#allowed_lifetime_min: 1d
|
||||
#allowed_lifetime_max: 1y
|
||||
@@ -991,12 +990,19 @@ class ServerConfig(Config):
|
||||
# (e.g. every 12h), but not want that purge to be performed by a job that's
|
||||
# iterating over every room it knows, which could be heavy on the server.
|
||||
#
|
||||
# If any purge job is configured, it is strongly recommended to have at least
|
||||
# a single job with neither 'shortest_max_lifetime' nor 'longest_max_lifetime'
|
||||
# set, or one job without 'shortest_max_lifetime' and one job without
|
||||
# 'longest_max_lifetime' set. Otherwise some rooms might be ignored, even if
|
||||
# 'allowed_lifetime_min' and 'allowed_lifetime_max' are set, because capping a
|
||||
# room's policy to these values is done after the policies are retrieved from
|
||||
# Synapse's database (which is done using the range specified in a purge job's
|
||||
# configuration).
|
||||
#
|
||||
#purge_jobs:
|
||||
# - shortest_max_lifetime: 1d
|
||||
# longest_max_lifetime: 3d
|
||||
# - longest_max_lifetime: 3d
|
||||
# interval: 12h
|
||||
# - shortest_max_lifetime: 3d
|
||||
# longest_max_lifetime: 1y
|
||||
# interval: 1d
|
||||
|
||||
# Inhibits the /requestToken endpoints from returning an error that might leak
|
||||
|
||||
@@ -74,15 +74,14 @@ class EventValidator(object):
|
||||
)
|
||||
|
||||
if event.type == EventTypes.Retention:
|
||||
self._validate_retention(event, config)
|
||||
self._validate_retention(event)
|
||||
|
||||
def _validate_retention(self, event, config):
|
||||
def _validate_retention(self, event):
|
||||
"""Checks that an event that defines the retention policy for a room respects the
|
||||
boundaries imposed by the server's administrator.
|
||||
format enforced by the spec.
|
||||
|
||||
Args:
|
||||
event (FrozenEvent): The event to validate.
|
||||
config (Config): The homeserver's configuration.
|
||||
"""
|
||||
min_lifetime = event.content.get("min_lifetime")
|
||||
max_lifetime = event.content.get("max_lifetime")
|
||||
@@ -95,32 +94,6 @@ class EventValidator(object):
|
||||
errcode=Codes.BAD_JSON,
|
||||
)
|
||||
|
||||
if (
|
||||
config.retention_allowed_lifetime_min is not None
|
||||
and min_lifetime < config.retention_allowed_lifetime_min
|
||||
):
|
||||
raise SynapseError(
|
||||
code=400,
|
||||
msg=(
|
||||
"'min_lifetime' can't be lower than the minimum allowed"
|
||||
" value enforced by the server's administrator"
|
||||
),
|
||||
errcode=Codes.BAD_JSON,
|
||||
)
|
||||
|
||||
if (
|
||||
config.retention_allowed_lifetime_max is not None
|
||||
and min_lifetime > config.retention_allowed_lifetime_max
|
||||
):
|
||||
raise SynapseError(
|
||||
code=400,
|
||||
msg=(
|
||||
"'min_lifetime' can't be greater than the maximum allowed"
|
||||
" value enforced by the server's administrator"
|
||||
),
|
||||
errcode=Codes.BAD_JSON,
|
||||
)
|
||||
|
||||
if max_lifetime is not None:
|
||||
if not isinstance(max_lifetime, int):
|
||||
raise SynapseError(
|
||||
@@ -129,32 +102,6 @@ class EventValidator(object):
|
||||
errcode=Codes.BAD_JSON,
|
||||
)
|
||||
|
||||
if (
|
||||
config.retention_allowed_lifetime_min is not None
|
||||
and max_lifetime < config.retention_allowed_lifetime_min
|
||||
):
|
||||
raise SynapseError(
|
||||
code=400,
|
||||
msg=(
|
||||
"'max_lifetime' can't be lower than the minimum allowed value"
|
||||
" enforced by the server's administrator"
|
||||
),
|
||||
errcode=Codes.BAD_JSON,
|
||||
)
|
||||
|
||||
if (
|
||||
config.retention_allowed_lifetime_max is not None
|
||||
and max_lifetime > config.retention_allowed_lifetime_max
|
||||
):
|
||||
raise SynapseError(
|
||||
code=400,
|
||||
msg=(
|
||||
"'max_lifetime' can't be greater than the maximum allowed"
|
||||
" value enforced by the server's administrator"
|
||||
),
|
||||
errcode=Codes.BAD_JSON,
|
||||
)
|
||||
|
||||
if (
|
||||
min_lifetime is not None
|
||||
and max_lifetime is not None
|
||||
|
||||
@@ -329,10 +329,10 @@ class FederationSender(object):
|
||||
room_id = receipt.room_id
|
||||
|
||||
# Work out which remote servers should be poked and poke them.
|
||||
domains = await self.state.get_current_hosts_in_room(room_id)
|
||||
domains_set = await self.state.get_current_hosts_in_room(room_id)
|
||||
domains = [
|
||||
d
|
||||
for d in domains
|
||||
for d in domains_set
|
||||
if d != self.server_name
|
||||
and self._federation_shard_config.should_handle(self._instance_name, d)
|
||||
]
|
||||
|
||||
@@ -23,6 +23,7 @@ from synapse.api.errors import (
|
||||
CodeMessageException,
|
||||
Codes,
|
||||
NotFoundError,
|
||||
ShadowBanError,
|
||||
StoreError,
|
||||
SynapseError,
|
||||
)
|
||||
@@ -199,6 +200,8 @@ class DirectoryHandler(BaseHandler):
|
||||
|
||||
try:
|
||||
await self._update_canonical_alias(requester, user_id, room_id, room_alias)
|
||||
except ShadowBanError as e:
|
||||
logger.info("Failed to update alias events due to shadow-ban: %s", e)
|
||||
except AuthError as e:
|
||||
logger.info("Failed to update alias events: %s", e)
|
||||
|
||||
@@ -292,6 +295,9 @@ class DirectoryHandler(BaseHandler):
|
||||
"""
|
||||
Send an updated canonical alias event if the removed alias was set as
|
||||
the canonical alias or listed in the alt_aliases field.
|
||||
|
||||
Raises:
|
||||
ShadowBanError if the requester has been shadow-banned.
|
||||
"""
|
||||
alias_event = await self.state.get_current_state(
|
||||
room_id, EventTypes.CanonicalAlias, ""
|
||||
|
||||
@@ -2134,10 +2134,10 @@ class FederationHandler(BaseHandler):
|
||||
)
|
||||
state_sets = list(state_sets.values())
|
||||
state_sets.append(state)
|
||||
current_state_ids = await self.state_handler.resolve_events(
|
||||
current_states = await self.state_handler.resolve_events(
|
||||
room_version, state_sets, event
|
||||
)
|
||||
current_state_ids = {k: e.event_id for k, e in current_state_ids.items()}
|
||||
current_state_ids = {k: e.event_id for k, e in current_states.items()}
|
||||
else:
|
||||
current_state_ids = await self.state_handler.get_current_state_ids(
|
||||
event.room_id, latest_event_ids=extrem_ids
|
||||
@@ -2149,9 +2149,11 @@ class FederationHandler(BaseHandler):
|
||||
|
||||
# Now check if event pass auth against said current state
|
||||
auth_types = auth_types_for_event(event)
|
||||
current_state_ids = [e for k, e in current_state_ids.items() if k in auth_types]
|
||||
current_state_ids_list = [
|
||||
e for k, e in current_state_ids.items() if k in auth_types
|
||||
]
|
||||
|
||||
auth_events_map = await self.store.get_events(current_state_ids)
|
||||
auth_events_map = await self.store.get_events(current_state_ids_list)
|
||||
current_auth_events = {
|
||||
(e.type, e.state_key): e for e in auth_events_map.values()
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import random
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
|
||||
from canonicaljson import encode_canonical_json
|
||||
@@ -34,6 +35,7 @@ from synapse.api.errors import (
|
||||
Codes,
|
||||
ConsentNotGivenError,
|
||||
NotFoundError,
|
||||
ShadowBanError,
|
||||
SynapseError,
|
||||
)
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
|
||||
@@ -645,24 +647,35 @@ class EventCreationHandler(object):
|
||||
event: EventBase,
|
||||
context: EventContext,
|
||||
ratelimit: bool = True,
|
||||
ignore_shadow_ban: bool = False,
|
||||
) -> int:
|
||||
"""
|
||||
Persists and notifies local clients and federation of an event.
|
||||
|
||||
Args:
|
||||
requester
|
||||
event the event to send.
|
||||
context: the context of the event.
|
||||
requester: The requester sending the event.
|
||||
event: The event to send.
|
||||
context: The context of the event.
|
||||
ratelimit: Whether to rate limit this send.
|
||||
ignore_shadow_ban: True if shadow-banned users should be allowed to
|
||||
send this event.
|
||||
|
||||
Return:
|
||||
The stream_id of the persisted event.
|
||||
|
||||
Raises:
|
||||
ShadowBanError if the requester has been shadow-banned.
|
||||
"""
|
||||
if event.type == EventTypes.Member:
|
||||
raise SynapseError(
|
||||
500, "Tried to send member event through non-member codepath"
|
||||
)
|
||||
|
||||
if not ignore_shadow_ban and requester.shadow_banned:
|
||||
# We randomly sleep a bit just to annoy the requester.
|
||||
await self.clock.sleep(random.randint(1, 10))
|
||||
raise ShadowBanError()
|
||||
|
||||
user = UserID.from_string(event.sender)
|
||||
|
||||
assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
|
||||
@@ -716,12 +729,28 @@ class EventCreationHandler(object):
|
||||
event_dict: dict,
|
||||
ratelimit: bool = True,
|
||||
txn_id: Optional[str] = None,
|
||||
ignore_shadow_ban: bool = False,
|
||||
) -> Tuple[EventBase, int]:
|
||||
"""
|
||||
Creates an event, then sends it.
|
||||
|
||||
See self.create_event and self.send_nonmember_event.
|
||||
|
||||
Args:
|
||||
requester: The requester sending the event.
|
||||
event_dict: An entire event.
|
||||
ratelimit: Whether to rate limit this send.
|
||||
txn_id: The transaction ID.
|
||||
ignore_shadow_ban: True if shadow-banned users should be allowed to
|
||||
send this event.
|
||||
|
||||
Raises:
|
||||
ShadowBanError if the requester has been shadow-banned.
|
||||
"""
|
||||
if not ignore_shadow_ban and requester.shadow_banned:
|
||||
# We randomly sleep a bit just to annoy the requester.
|
||||
await self.clock.sleep(random.randint(1, 10))
|
||||
raise ShadowBanError()
|
||||
|
||||
# We limit the number of concurrent event sends in a room so that we
|
||||
# don't fork the DAG too much. If we don't limit then we can end up in
|
||||
@@ -740,7 +769,11 @@ class EventCreationHandler(object):
|
||||
raise SynapseError(403, spam_error, Codes.FORBIDDEN)
|
||||
|
||||
stream_id = await self.send_nonmember_event(
|
||||
requester, event, context, ratelimit=ratelimit
|
||||
requester,
|
||||
event,
|
||||
context,
|
||||
ratelimit=ratelimit,
|
||||
ignore_shadow_ban=ignore_shadow_ban,
|
||||
)
|
||||
return event, stream_id
|
||||
|
||||
@@ -1180,8 +1213,14 @@ class EventCreationHandler(object):
|
||||
|
||||
event.internal_metadata.proactively_send = False
|
||||
|
||||
# Since this is a dummy-event it is OK if it is sent by a
|
||||
# shadow-banned user.
|
||||
await self.send_nonmember_event(
|
||||
requester, event, context, ratelimit=False
|
||||
requester,
|
||||
event,
|
||||
context,
|
||||
ratelimit=False,
|
||||
ignore_shadow_ban=True,
|
||||
)
|
||||
dummy_event_sent = True
|
||||
break
|
||||
|
||||
@@ -82,6 +82,9 @@ class PaginationHandler(object):
|
||||
|
||||
self._retention_default_max_lifetime = hs.config.retention_default_max_lifetime
|
||||
|
||||
self._retention_allowed_lifetime_min = hs.config.retention_allowed_lifetime_min
|
||||
self._retention_allowed_lifetime_max = hs.config.retention_allowed_lifetime_max
|
||||
|
||||
if hs.config.retention_enabled:
|
||||
# Run the purge jobs described in the configuration file.
|
||||
for job in hs.config.retention_purge_jobs:
|
||||
@@ -111,7 +114,7 @@ class PaginationHandler(object):
|
||||
the range to handle (inclusive). If None, it means that the range has no
|
||||
upper limit.
|
||||
"""
|
||||
# We want the storage layer to to include rooms with no retention policy in its
|
||||
# We want the storage layer to include rooms with no retention policy in its
|
||||
# return value only if a default retention policy is defined in the server's
|
||||
# configuration and that policy's 'max_lifetime' is either lower (or equal) than
|
||||
# max_ms or higher than min_ms (or both).
|
||||
@@ -152,13 +155,32 @@ class PaginationHandler(object):
|
||||
)
|
||||
continue
|
||||
|
||||
max_lifetime = retention_policy["max_lifetime"]
|
||||
# If max_lifetime is None, it means that the room has no retention policy.
|
||||
# Given we only retrieve such rooms when there's a default retention policy
|
||||
# defined in the server's configuration, we can safely assume that's the
|
||||
# case and use it for this room.
|
||||
max_lifetime = (
|
||||
retention_policy["max_lifetime"] or self._retention_default_max_lifetime
|
||||
)
|
||||
|
||||
if max_lifetime is None:
|
||||
# If max_lifetime is None, it means that include_null equals True,
|
||||
# therefore we can safely assume that there is a default policy defined
|
||||
# in the server's configuration.
|
||||
max_lifetime = self._retention_default_max_lifetime
|
||||
# Cap the effective max_lifetime to be within the range allowed in the
|
||||
# config.
|
||||
# We do this in two steps:
|
||||
# 1. Make sure it's higher or equal to the minimum allowed value, and if
|
||||
# it's not replace it with that value. This is because the server
|
||||
# operator can be required to not delete information before a given
|
||||
# time, e.g. to comply with freedom of information laws.
|
||||
# 2. Make sure the resulting value is lower or equal to the maximum allowed
|
||||
# value, and if it's not replace it with that value. This is because the
|
||||
# server operator can be required to delete any data after a specific
|
||||
# amount of time.
|
||||
if self._retention_allowed_lifetime_min is not None:
|
||||
max_lifetime = max(self._retention_allowed_lifetime_min, max_lifetime)
|
||||
|
||||
if self._retention_allowed_lifetime_max is not None:
|
||||
max_lifetime = min(max_lifetime, self._retention_allowed_lifetime_max)
|
||||
|
||||
logger.debug("[purge] max_lifetime for room %s: %s", room_id, max_lifetime)
|
||||
|
||||
# Figure out what token we should start purging at.
|
||||
ts = self.clock.time_msec() - max_lifetime
|
||||
|
||||
@@ -40,7 +40,7 @@ from synapse.metrics import LaterGauge
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.state import StateHandler
|
||||
from synapse.storage.databases.main import DataStore
|
||||
from synapse.types import JsonDict, UserID, get_domain_from_id
|
||||
from synapse.types import Collection, JsonDict, UserID, get_domain_from_id
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
from synapse.util.caches.descriptors import cached
|
||||
from synapse.util.metrics import Measure
|
||||
@@ -1318,7 +1318,7 @@ async def get_interested_parties(
|
||||
|
||||
async def get_interested_remotes(
|
||||
store: DataStore, states: List[UserPresenceState], state_handler: StateHandler
|
||||
) -> List[Tuple[List[str], List[UserPresenceState]]]:
|
||||
) -> List[Tuple[Collection[str], List[UserPresenceState]]]:
|
||||
"""Given a list of presence states figure out which remote servers
|
||||
should be sent which.
|
||||
|
||||
@@ -1334,7 +1334,7 @@ async def get_interested_remotes(
|
||||
each tuple the list of UserPresenceState should be sent to each
|
||||
destination
|
||||
"""
|
||||
hosts_and_states = []
|
||||
hosts_and_states = [] # type: List[Tuple[Collection[str], List[UserPresenceState]]]
|
||||
|
||||
# First we look up the rooms each user is in (as well as any explicit
|
||||
# subscriptions), then for each distinct room we look up the remote
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import random
|
||||
|
||||
from synapse.api.errors import (
|
||||
AuthError,
|
||||
@@ -213,8 +214,14 @@ class BaseProfileHandler(BaseHandler):
|
||||
async def set_avatar_url(
|
||||
self, target_user, requester, new_avatar_url, by_admin=False
|
||||
):
|
||||
"""target_user is the user whose avatar_url is to be changed;
|
||||
auth_user is the user attempting to make this change."""
|
||||
"""Set a new avatar URL for a user.
|
||||
|
||||
Args:
|
||||
target_user (UserID): the user whose avatar URL is to be changed.
|
||||
requester (Requester): The user attempting to make this change.
|
||||
new_avatar_url (str): The avatar URL to give this user.
|
||||
by_admin (bool): Whether this change was made by an administrator.
|
||||
"""
|
||||
if not self.hs.is_mine(target_user):
|
||||
raise SynapseError(400, "User is not hosted on this homeserver")
|
||||
|
||||
@@ -278,6 +285,12 @@ class BaseProfileHandler(BaseHandler):
|
||||
|
||||
await self.ratelimit(requester)
|
||||
|
||||
# Do not actually update the room state for shadow-banned users.
|
||||
if requester.shadow_banned:
|
||||
# We randomly sleep a bit just to annoy the requester.
|
||||
await self.clock.sleep(random.randint(1, 10))
|
||||
return
|
||||
|
||||
room_ids = await self.store.get_rooms_for_user(target_user.to_string())
|
||||
|
||||
for room_id in room_ids:
|
||||
|
||||
@@ -136,6 +136,9 @@ class RoomCreationHandler(BaseHandler):
|
||||
|
||||
Returns:
|
||||
the new room id
|
||||
|
||||
Raises:
|
||||
ShadowBanError if the requester is shadow-banned.
|
||||
"""
|
||||
await self.ratelimit(requester)
|
||||
|
||||
@@ -171,6 +174,15 @@ class RoomCreationHandler(BaseHandler):
|
||||
async def _upgrade_room(
|
||||
self, requester: Requester, old_room_id: str, new_version: RoomVersion
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
requester: the user requesting the upgrade
|
||||
old_room_id: the id of the room to be replaced
|
||||
new_versions: the version to upgrade the room to
|
||||
|
||||
Raises:
|
||||
ShadowBanError if the requester is shadow-banned.
|
||||
"""
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
# start by allocating a new room id
|
||||
@@ -257,6 +269,9 @@ class RoomCreationHandler(BaseHandler):
|
||||
old_room_id: the id of the room to be replaced
|
||||
new_room_id: the id of the replacement room
|
||||
old_room_state: the state map for the old room
|
||||
|
||||
Raises:
|
||||
ShadowBanError if the requester is shadow-banned.
|
||||
"""
|
||||
old_room_pl_event_id = old_room_state.get((EventTypes.PowerLevels, ""))
|
||||
|
||||
@@ -829,11 +844,13 @@ class RoomCreationHandler(BaseHandler):
|
||||
async def send(etype: str, content: JsonDict, **kwargs) -> int:
|
||||
event = create(etype, content, **kwargs)
|
||||
logger.debug("Sending %s in new room", etype)
|
||||
# Allow these events to be sent even if the user is shadow-banned to
|
||||
# allow the room creation to complete.
|
||||
(
|
||||
_,
|
||||
last_stream_id,
|
||||
) = await self.event_creation_handler.create_and_send_nonmember_event(
|
||||
creator, event, ratelimit=False
|
||||
creator, event, ratelimit=False, ignore_shadow_ban=True,
|
||||
)
|
||||
return last_stream_id
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ import abc
|
||||
import logging
|
||||
import random
|
||||
from http import HTTPStatus
|
||||
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
from unpaddedbase64 import encode_base64
|
||||
|
||||
@@ -38,7 +38,15 @@ from synapse.events.builder import create_local_event_from_event_dict
|
||||
from synapse.events.snapshot import EventContext
|
||||
from synapse.events.validator import EventValidator
|
||||
from synapse.storage.roommember import RoomsForUser
|
||||
from synapse.types import Collection, JsonDict, Requester, RoomAlias, RoomID, UserID
|
||||
from synapse.types import (
|
||||
Collection,
|
||||
JsonDict,
|
||||
Requester,
|
||||
RoomAlias,
|
||||
RoomID,
|
||||
StateMap,
|
||||
UserID,
|
||||
)
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
from synapse.util.distributor import user_joined_room, user_left_room
|
||||
|
||||
@@ -217,24 +225,40 @@ class RoomMemberHandler(object):
|
||||
_, stream_id = await self.store.get_event_ordering(duplicate.event_id)
|
||||
return duplicate.event_id, stream_id
|
||||
|
||||
stream_id = await self.event_creation_handler.handle_new_client_event(
|
||||
requester, event, context, extra_users=[target], ratelimit=ratelimit,
|
||||
)
|
||||
|
||||
prev_state_ids = await context.get_prev_state_ids()
|
||||
|
||||
prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
|
||||
|
||||
newly_joined = False
|
||||
if event.membership == Membership.JOIN:
|
||||
# Only fire user_joined_room if the user has actually joined the
|
||||
# room. Don't bother if the user is just changing their profile
|
||||
# info.
|
||||
newly_joined = True
|
||||
if prev_member_event_id:
|
||||
prev_member_event = await self.store.get_event(prev_member_event_id)
|
||||
newly_joined = prev_member_event.membership != Membership.JOIN
|
||||
|
||||
# Only rate-limit if the user actually joined the room, otherwise we'll end
|
||||
# up blocking profile updates.
|
||||
if newly_joined:
|
||||
await self._user_joined_room(target, room_id)
|
||||
time_now_s = self.clock.time()
|
||||
(
|
||||
allowed,
|
||||
time_allowed,
|
||||
) = self._join_rate_limiter_local.can_requester_do_action(requester)
|
||||
|
||||
if not allowed:
|
||||
raise LimitExceededError(
|
||||
retry_after_ms=int(1000 * (time_allowed - time_now_s))
|
||||
)
|
||||
|
||||
stream_id = await self.event_creation_handler.handle_new_client_event(
|
||||
requester, event, context, extra_users=[target], ratelimit=ratelimit,
|
||||
)
|
||||
|
||||
if event.membership == Membership.JOIN and newly_joined:
|
||||
# Only fire user_joined_room if the user has actually joined the
|
||||
# room. Don't bother if the user is just changing their profile
|
||||
# info.
|
||||
await self._user_joined_room(target, room_id)
|
||||
elif event.membership == Membership.LEAVE:
|
||||
if prev_member_event_id:
|
||||
prev_member_event = await self.store.get_event(prev_member_event_id)
|
||||
@@ -356,7 +380,7 @@ class RoomMemberHandler(object):
|
||||
# later on.
|
||||
content = dict(content)
|
||||
|
||||
if not self.allow_per_room_profiles:
|
||||
if not self.allow_per_room_profiles or requester.shadow_banned:
|
||||
# Strip profile data, knowing that new profile data will be added to the
|
||||
# event's content in event_creation_handler.create_event() using the target's
|
||||
# global profile.
|
||||
@@ -489,22 +513,12 @@ class RoomMemberHandler(object):
|
||||
# so don't really fit into the general auth process.
|
||||
raise AuthError(403, "Guest access not allowed")
|
||||
|
||||
if is_host_in_room:
|
||||
if not is_host_in_room:
|
||||
time_now_s = self.clock.time()
|
||||
allowed, time_allowed = self._join_rate_limiter_local.can_do_action(
|
||||
requester.user.to_string(),
|
||||
)
|
||||
|
||||
if not allowed:
|
||||
raise LimitExceededError(
|
||||
retry_after_ms=int(1000 * (time_allowed - time_now_s))
|
||||
)
|
||||
|
||||
else:
|
||||
time_now_s = self.clock.time()
|
||||
allowed, time_allowed = self._join_rate_limiter_remote.can_do_action(
|
||||
requester.user.to_string(),
|
||||
)
|
||||
(
|
||||
allowed,
|
||||
time_allowed,
|
||||
) = self._join_rate_limiter_remote.can_requester_do_action(requester,)
|
||||
|
||||
if not allowed:
|
||||
raise LimitExceededError(
|
||||
@@ -736,9 +750,7 @@ class RoomMemberHandler(object):
|
||||
if prev_member_event.membership == Membership.JOIN:
|
||||
await self._user_left_room(target_user, room_id)
|
||||
|
||||
async def _can_guest_join(
|
||||
self, current_state_ids: Dict[Tuple[str, str], str]
|
||||
) -> bool:
|
||||
async def _can_guest_join(self, current_state_ids: StateMap[str]) -> bool:
|
||||
"""
|
||||
Returns whether a guest can join a room based on its current state.
|
||||
"""
|
||||
@@ -967,9 +979,7 @@ class RoomMemberHandler(object):
|
||||
)
|
||||
return stream_id
|
||||
|
||||
async def _is_host_in_room(
|
||||
self, current_state_ids: Dict[Tuple[str, str], str]
|
||||
) -> bool:
|
||||
async def _is_host_in_room(self, current_state_ids: StateMap[str]) -> bool:
|
||||
# Have we just created the room, and is this about to be the very
|
||||
# first member event?
|
||||
create_event_id = current_state_ids.get(("m.room.create", ""))
|
||||
|
||||
@@ -21,9 +21,9 @@ class SlavedIdTracker(object):
|
||||
self.step = step
|
||||
self._current = _load_current_id(db_conn, table, column, step)
|
||||
for table, column in extra_tables:
|
||||
self.advance(_load_current_id(db_conn, table, column))
|
||||
self.advance(None, _load_current_id(db_conn, table, column))
|
||||
|
||||
def advance(self, new_id):
|
||||
def advance(self, instance_name, new_id):
|
||||
self._current = (max if self.step > 0 else min)(self._current, new_id)
|
||||
|
||||
def get_current_token(self):
|
||||
|
||||
@@ -41,12 +41,12 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved
|
||||
|
||||
def process_replication_rows(self, stream_name, instance_name, token, rows):
|
||||
if stream_name == TagAccountDataStream.NAME:
|
||||
self._account_data_id_gen.advance(token)
|
||||
self._account_data_id_gen.advance(instance_name, token)
|
||||
for row in rows:
|
||||
self.get_tags_for_user.invalidate((row.user_id,))
|
||||
self._account_data_stream_cache.entity_has_changed(row.user_id, token)
|
||||
elif stream_name == AccountDataStream.NAME:
|
||||
self._account_data_id_gen.advance(token)
|
||||
self._account_data_id_gen.advance(instance_name, token)
|
||||
for row in rows:
|
||||
if not row.room_id:
|
||||
self.get_global_account_data_by_type_for_user.invalidate(
|
||||
|
||||
@@ -46,7 +46,7 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
|
||||
|
||||
def process_replication_rows(self, stream_name, instance_name, token, rows):
|
||||
if stream_name == ToDeviceStream.NAME:
|
||||
self._device_inbox_id_gen.advance(token)
|
||||
self._device_inbox_id_gen.advance(instance_name, token)
|
||||
for row in rows:
|
||||
if row.entity.startswith("@"):
|
||||
self._device_inbox_stream_cache.entity_has_changed(
|
||||
|
||||
@@ -50,10 +50,10 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
|
||||
|
||||
def process_replication_rows(self, stream_name, instance_name, token, rows):
|
||||
if stream_name == DeviceListsStream.NAME:
|
||||
self._device_list_id_gen.advance(token)
|
||||
self._device_list_id_gen.advance(instance_name, token)
|
||||
self._invalidate_caches_for_devices(token, rows)
|
||||
elif stream_name == UserSignatureStream.NAME:
|
||||
self._device_list_id_gen.advance(token)
|
||||
self._device_list_id_gen.advance(instance_name, token)
|
||||
for row in rows:
|
||||
self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
|
||||
return super().process_replication_rows(stream_name, instance_name, token, rows)
|
||||
|
||||
@@ -40,7 +40,7 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
|
||||
|
||||
def process_replication_rows(self, stream_name, instance_name, token, rows):
|
||||
if stream_name == GroupServerStream.NAME:
|
||||
self._group_updates_id_gen.advance(token)
|
||||
self._group_updates_id_gen.advance(instance_name, token)
|
||||
for row in rows:
|
||||
self._group_updates_stream_cache.entity_has_changed(row.user_id, token)
|
||||
|
||||
|
||||
@@ -44,7 +44,7 @@ class SlavedPresenceStore(BaseSlavedStore):
|
||||
|
||||
def process_replication_rows(self, stream_name, instance_name, token, rows):
|
||||
if stream_name == PresenceStream.NAME:
|
||||
self._presence_id_gen.advance(token)
|
||||
self._presence_id_gen.advance(instance_name, token)
|
||||
for row in rows:
|
||||
self.presence_stream_cache.entity_has_changed(row.user_id, token)
|
||||
self._get_presence_for_user.invalidate((row.user_id,))
|
||||
|
||||
@@ -30,7 +30,7 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
|
||||
assert isinstance(self._push_rules_stream_id_gen, SlavedIdTracker)
|
||||
|
||||
if stream_name == PushRulesStream.NAME:
|
||||
self._push_rules_stream_id_gen.advance(token)
|
||||
self._push_rules_stream_id_gen.advance(instance_name, token)
|
||||
for row in rows:
|
||||
self.get_push_rules_for_user.invalidate((row.user_id,))
|
||||
self.get_push_rules_enabled_for_user.invalidate((row.user_id,))
|
||||
|
||||
@@ -34,5 +34,5 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
|
||||
|
||||
def process_replication_rows(self, stream_name, instance_name, token, rows):
|
||||
if stream_name == PushersStream.NAME:
|
||||
self._pushers_id_gen.advance(token)
|
||||
self._pushers_id_gen.advance(instance_name, token)
|
||||
return super().process_replication_rows(stream_name, instance_name, token, rows)
|
||||
|
||||
@@ -46,7 +46,7 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
|
||||
|
||||
def process_replication_rows(self, stream_name, instance_name, token, rows):
|
||||
if stream_name == ReceiptsStream.NAME:
|
||||
self._receipts_id_gen.advance(token)
|
||||
self._receipts_id_gen.advance(instance_name, token)
|
||||
for row in rows:
|
||||
self.invalidate_caches_for_receipt(
|
||||
row.room_id, row.receipt_type, row.user_id
|
||||
|
||||
@@ -33,6 +33,6 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore):
|
||||
|
||||
def process_replication_rows(self, stream_name, instance_name, token, rows):
|
||||
if stream_name == PublicRoomsStream.NAME:
|
||||
self._public_room_id_gen.advance(token)
|
||||
self._public_room_id_gen.advance(instance_name, token)
|
||||
|
||||
return super().process_replication_rows(stream_name, instance_name, token, rows)
|
||||
|
||||
@@ -73,6 +73,7 @@ class UsersRestServletV2(RestServlet):
|
||||
The parameters `from` and `limit` are required only for pagination.
|
||||
By default, a `limit` of 100 is used.
|
||||
The parameter `user_id` can be used to filter by user id.
|
||||
The parameter `name` can be used to filter by user id or display name.
|
||||
The parameter `guests` can be used to exclude guest users.
|
||||
The parameter `deactivated` can be used to include deactivated users.
|
||||
"""
|
||||
@@ -89,11 +90,12 @@ class UsersRestServletV2(RestServlet):
|
||||
start = parse_integer(request, "from", default=0)
|
||||
limit = parse_integer(request, "limit", default=100)
|
||||
user_id = parse_string(request, "user_id", default=None)
|
||||
name = parse_string(request, "name", default=None)
|
||||
guests = parse_boolean(request, "guests", default=True)
|
||||
deactivated = parse_boolean(request, "deactivated", default=False)
|
||||
|
||||
users, total = await self.store.get_users_paginate(
|
||||
start, limit, user_id, guests, deactivated
|
||||
start, limit, user_id, name, guests, deactivated
|
||||
)
|
||||
ret = {"users": users, "total": total}
|
||||
if len(users) >= limit:
|
||||
|
||||
@@ -201,8 +201,8 @@ class RoomStateEventRestServlet(TransactionRestServlet):
|
||||
if state_key is not None:
|
||||
event_dict["state_key"] = state_key
|
||||
|
||||
if event_type == EventTypes.Member:
|
||||
try:
|
||||
try:
|
||||
if event_type == EventTypes.Member:
|
||||
membership = content.get("membership", None)
|
||||
event_id, _ = await self.room_member_handler.update_membership(
|
||||
requester,
|
||||
@@ -211,16 +211,16 @@ class RoomStateEventRestServlet(TransactionRestServlet):
|
||||
action=membership,
|
||||
content=content,
|
||||
)
|
||||
except ShadowBanError:
|
||||
event_id = "$" + random_string(43)
|
||||
else:
|
||||
(
|
||||
event,
|
||||
_,
|
||||
) = await self.event_creation_handler.create_and_send_nonmember_event(
|
||||
requester, event_dict, txn_id=txn_id
|
||||
)
|
||||
event_id = event.event_id
|
||||
else:
|
||||
(
|
||||
event,
|
||||
_,
|
||||
) = await self.event_creation_handler.create_and_send_nonmember_event(
|
||||
requester, event_dict, txn_id=txn_id
|
||||
)
|
||||
event_id = event.event_id
|
||||
except ShadowBanError:
|
||||
event_id = "$" + random_string(43)
|
||||
|
||||
set_tag("event_id", event_id)
|
||||
ret = {"event_id": event_id}
|
||||
@@ -253,12 +253,19 @@ class RoomSendEventRestServlet(TransactionRestServlet):
|
||||
if b"ts" in request.args and requester.app_service:
|
||||
event_dict["origin_server_ts"] = parse_integer(request, "ts", 0)
|
||||
|
||||
event, _ = await self.event_creation_handler.create_and_send_nonmember_event(
|
||||
requester, event_dict, txn_id=txn_id
|
||||
)
|
||||
try:
|
||||
(
|
||||
event,
|
||||
_,
|
||||
) = await self.event_creation_handler.create_and_send_nonmember_event(
|
||||
requester, event_dict, txn_id=txn_id
|
||||
)
|
||||
event_id = event.event_id
|
||||
except ShadowBanError:
|
||||
event_id = "$" + random_string(43)
|
||||
|
||||
set_tag("event_id", event.event_id)
|
||||
return 200, {"event_id": event.event_id}
|
||||
set_tag("event_id", event_id)
|
||||
return 200, {"event_id": event_id}
|
||||
|
||||
def on_GET(self, request, room_id, event_type, txn_id):
|
||||
return 200, "Not implemented"
|
||||
@@ -799,20 +806,27 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
event, _ = await self.event_creation_handler.create_and_send_nonmember_event(
|
||||
requester,
|
||||
{
|
||||
"type": EventTypes.Redaction,
|
||||
"content": content,
|
||||
"room_id": room_id,
|
||||
"sender": requester.user.to_string(),
|
||||
"redacts": event_id,
|
||||
},
|
||||
txn_id=txn_id,
|
||||
)
|
||||
try:
|
||||
(
|
||||
event,
|
||||
_,
|
||||
) = await self.event_creation_handler.create_and_send_nonmember_event(
|
||||
requester,
|
||||
{
|
||||
"type": EventTypes.Redaction,
|
||||
"content": content,
|
||||
"room_id": room_id,
|
||||
"sender": requester.user.to_string(),
|
||||
"redacts": event_id,
|
||||
},
|
||||
txn_id=txn_id,
|
||||
)
|
||||
event_id = event.event_id
|
||||
except ShadowBanError:
|
||||
event_id = "$" + random_string(43)
|
||||
|
||||
set_tag("event_id", event.event_id)
|
||||
return 200, {"event_id": event.event_id}
|
||||
set_tag("event_id", event_id)
|
||||
return 200, {"event_id": event_id}
|
||||
|
||||
def on_PUT(self, request, room_id, event_id, txn_id):
|
||||
set_tag("txn_id", txn_id)
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import random
|
||||
from http import HTTPStatus
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
@@ -113,6 +114,9 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
|
||||
if self.config.request_token_inhibit_3pid_errors:
|
||||
# Make the client think the operation succeeded. See the rationale in the
|
||||
# comments for request_token_inhibit_3pid_errors.
|
||||
# Also wait for some random amount of time between 100ms and 1s to make it
|
||||
# look like we did something.
|
||||
await self.hs.clock.sleep(random.randint(1, 10) / 10)
|
||||
return 200, {"sid": random_string(16)}
|
||||
|
||||
raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND)
|
||||
@@ -504,6 +508,9 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
|
||||
if self.config.request_token_inhibit_3pid_errors:
|
||||
# Make the client think the operation succeeded. See the rationale in the
|
||||
# comments for request_token_inhibit_3pid_errors.
|
||||
# Also wait for some random amount of time between 100ms and 1s to make it
|
||||
# look like we did something.
|
||||
await self.hs.clock.sleep(random.randint(1, 10) / 10)
|
||||
return 200, {"sid": random_string(16)}
|
||||
|
||||
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
|
||||
@@ -572,6 +579,9 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
|
||||
if self.hs.config.request_token_inhibit_3pid_errors:
|
||||
# Make the client think the operation succeeded. See the rationale in the
|
||||
# comments for request_token_inhibit_3pid_errors.
|
||||
# Also wait for some random amount of time between 100ms and 1s to make it
|
||||
# look like we did something.
|
||||
await self.hs.clock.sleep(random.randint(1, 10) / 10)
|
||||
return 200, {"sid": random_string(16)}
|
||||
|
||||
raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE)
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
|
||||
import hmac
|
||||
import logging
|
||||
import random
|
||||
from typing import List, Union
|
||||
|
||||
import synapse
|
||||
@@ -131,6 +132,9 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
|
||||
if self.hs.config.request_token_inhibit_3pid_errors:
|
||||
# Make the client think the operation succeeded. See the rationale in the
|
||||
# comments for request_token_inhibit_3pid_errors.
|
||||
# Also wait for some random amount of time between 100ms and 1s to make it
|
||||
# look like we did something.
|
||||
await self.hs.clock.sleep(random.randint(1, 10) / 10)
|
||||
return 200, {"sid": random_string(16)}
|
||||
|
||||
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
|
||||
@@ -203,6 +207,9 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
|
||||
if self.hs.config.request_token_inhibit_3pid_errors:
|
||||
# Make the client think the operation succeeded. See the rationale in the
|
||||
# comments for request_token_inhibit_3pid_errors.
|
||||
# Also wait for some random amount of time between 100ms and 1s to make it
|
||||
# look like we did something.
|
||||
await self.hs.clock.sleep(random.randint(1, 10) / 10)
|
||||
return 200, {"sid": random_string(16)}
|
||||
|
||||
raise SynapseError(
|
||||
|
||||
@@ -22,7 +22,7 @@ any time to reflect changes in the MSC.
|
||||
import logging
|
||||
|
||||
from synapse.api.constants import EventTypes, RelationTypes
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.api.errors import ShadowBanError, SynapseError
|
||||
from synapse.http.servlet import (
|
||||
RestServlet,
|
||||
parse_integer,
|
||||
@@ -35,6 +35,7 @@ from synapse.storage.relations import (
|
||||
PaginationChunk,
|
||||
RelationPaginationToken,
|
||||
)
|
||||
from synapse.util.stringutils import random_string
|
||||
|
||||
from ._base import client_patterns
|
||||
|
||||
@@ -111,11 +112,18 @@ class RelationSendServlet(RestServlet):
|
||||
"sender": requester.user.to_string(),
|
||||
}
|
||||
|
||||
event, _ = await self.event_creation_handler.create_and_send_nonmember_event(
|
||||
requester, event_dict=event_dict, txn_id=txn_id
|
||||
)
|
||||
try:
|
||||
(
|
||||
event,
|
||||
_,
|
||||
) = await self.event_creation_handler.create_and_send_nonmember_event(
|
||||
requester, event_dict=event_dict, txn_id=txn_id
|
||||
)
|
||||
event_id = event.event_id
|
||||
except ShadowBanError:
|
||||
event_id = "$" + random_string(43)
|
||||
|
||||
return 200, {"event_id": event.event_id}
|
||||
return 200, {"event_id": event_id}
|
||||
|
||||
|
||||
class RelationPaginationServlet(RestServlet):
|
||||
|
||||
@@ -15,13 +15,14 @@
|
||||
|
||||
import logging
|
||||
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.api.errors import Codes, ShadowBanError, SynapseError
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
||||
from synapse.http.servlet import (
|
||||
RestServlet,
|
||||
assert_params_in_dict,
|
||||
parse_json_object_from_request,
|
||||
)
|
||||
from synapse.util import stringutils
|
||||
|
||||
from ._base import client_patterns
|
||||
|
||||
@@ -62,7 +63,6 @@ class RoomUpgradeRestServlet(RestServlet):
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
assert_params_in_dict(content, ("new_version",))
|
||||
new_version = content["new_version"]
|
||||
|
||||
new_version = KNOWN_ROOM_VERSIONS.get(content["new_version"])
|
||||
if new_version is None:
|
||||
@@ -72,9 +72,13 @@ class RoomUpgradeRestServlet(RestServlet):
|
||||
Codes.UNSUPPORTED_ROOM_VERSION,
|
||||
)
|
||||
|
||||
new_room_id = await self._room_creation_handler.upgrade_room(
|
||||
requester, room_id, new_version
|
||||
)
|
||||
try:
|
||||
new_room_id = await self._room_creation_handler.upgrade_room(
|
||||
requester, room_id, new_version
|
||||
)
|
||||
except ShadowBanError:
|
||||
# Generate a random room ID.
|
||||
new_room_id = stringutils.random_string(18)
|
||||
|
||||
ret = {"replacement_room": new_room_id}
|
||||
|
||||
|
||||
+122
-70
@@ -16,11 +16,22 @@
|
||||
|
||||
import logging
|
||||
from collections import namedtuple
|
||||
from typing import Awaitable, Dict, Iterable, List, Optional, Set
|
||||
from typing import (
|
||||
Awaitable,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Union,
|
||||
overload,
|
||||
)
|
||||
|
||||
import attr
|
||||
from frozendict import frozendict
|
||||
from prometheus_client import Histogram
|
||||
from typing_extensions import Literal
|
||||
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions
|
||||
@@ -30,7 +41,7 @@ from synapse.logging.utils import log_function
|
||||
from synapse.state import v1, v2
|
||||
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
|
||||
from synapse.storage.roommember import ProfileInfo
|
||||
from synapse.types import StateMap
|
||||
from synapse.types import Collection, StateMap
|
||||
from synapse.util import Clock
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
from synapse.util.caches.expiringcache import ExpiringCache
|
||||
@@ -68,8 +79,14 @@ def _gen_state_id():
|
||||
class _StateCacheEntry(object):
|
||||
__slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"]
|
||||
|
||||
def __init__(self, state, state_group, prev_group=None, delta_ids=None):
|
||||
# dict[(str, str), str] map from (type, state_key) to event_id
|
||||
def __init__(
|
||||
self,
|
||||
state: StateMap[str],
|
||||
state_group: Optional[int],
|
||||
prev_group: Optional[int] = None,
|
||||
delta_ids: Optional[StateMap[str]] = None,
|
||||
):
|
||||
# A map from (type, state_key) to event_id.
|
||||
self.state = frozendict(state)
|
||||
|
||||
# the ID of a state group if one and only one is involved.
|
||||
@@ -107,24 +124,49 @@ class StateHandler(object):
|
||||
self.hs = hs
|
||||
self._state_resolution_handler = hs.get_state_resolution_handler()
|
||||
|
||||
@overload
|
||||
async def get_current_state(
|
||||
self, room_id, event_type=None, state_key="", latest_event_ids=None
|
||||
):
|
||||
""" Retrieves the current state for the room. This is done by
|
||||
self,
|
||||
room_id: str,
|
||||
event_type: Literal[None] = None,
|
||||
state_key: str = "",
|
||||
latest_event_ids: Optional[List[str]] = None,
|
||||
) -> StateMap[EventBase]:
|
||||
...
|
||||
|
||||
@overload
|
||||
async def get_current_state(
|
||||
self,
|
||||
room_id: str,
|
||||
event_type: str,
|
||||
state_key: str = "",
|
||||
latest_event_ids: Optional[List[str]] = None,
|
||||
) -> Optional[EventBase]:
|
||||
...
|
||||
|
||||
async def get_current_state(
|
||||
self,
|
||||
room_id: str,
|
||||
event_type: Optional[str] = None,
|
||||
state_key: str = "",
|
||||
latest_event_ids: Optional[List[str]] = None,
|
||||
) -> Union[Optional[EventBase], StateMap[EventBase]]:
|
||||
"""Retrieves the current state for the room. This is done by
|
||||
calling `get_latest_events_in_room` to get the leading edges of the
|
||||
event graph and then resolving any of the state conflicts.
|
||||
|
||||
This is equivalent to getting the state of an event that were to send
|
||||
next before receiving any new events.
|
||||
|
||||
If `event_type` is specified, then the method returns only the one
|
||||
event (or None) with that `event_type` and `state_key`.
|
||||
|
||||
Returns:
|
||||
map from (type, state_key) to event
|
||||
If `event_type` is specified, then the method returns only the one
|
||||
event (or None) with that `event_type` and `state_key`.
|
||||
|
||||
Otherwise, a map from (type, state_key) to event.
|
||||
"""
|
||||
if not latest_event_ids:
|
||||
latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
|
||||
assert latest_event_ids is not None
|
||||
|
||||
logger.debug("calling resolve_state_groups from get_current_state")
|
||||
ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
|
||||
@@ -140,34 +182,30 @@ class StateHandler(object):
|
||||
state_map = await self.store.get_events(
|
||||
list(state.values()), get_prev_content=False
|
||||
)
|
||||
state = {
|
||||
return {
|
||||
key: state_map[e_id] for key, e_id in state.items() if e_id in state_map
|
||||
}
|
||||
|
||||
return state
|
||||
|
||||
async def get_current_state_ids(self, room_id, latest_event_ids=None):
|
||||
async def get_current_state_ids(
|
||||
self, room_id: str, latest_event_ids: Optional[Iterable[str]] = None
|
||||
) -> StateMap[str]:
|
||||
"""Get the current state, or the state at a set of events, for a room
|
||||
|
||||
Args:
|
||||
room_id (str):
|
||||
|
||||
latest_event_ids (iterable[str]|None): if given, the forward
|
||||
extremities to resolve. If None, we look them up from the
|
||||
database (via a cache)
|
||||
room_id:
|
||||
latest_event_ids: if given, the forward extremities to resolve. If
|
||||
None, we look them up from the database (via a cache).
|
||||
|
||||
Returns:
|
||||
Deferred[dict[(str, str), str)]]: the state dict, mapping from
|
||||
(event_type, state_key) -> event_id
|
||||
the state dict, mapping from (event_type, state_key) -> event_id
|
||||
"""
|
||||
if not latest_event_ids:
|
||||
latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
|
||||
assert latest_event_ids is not None
|
||||
|
||||
logger.debug("calling resolve_state_groups from get_current_state_ids")
|
||||
ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
|
||||
state = ret.state
|
||||
|
||||
return state
|
||||
return dict(ret.state)
|
||||
|
||||
async def get_current_users_in_room(
|
||||
self, room_id: str, latest_event_ids: Optional[List[str]] = None
|
||||
@@ -183,32 +221,34 @@ class StateHandler(object):
|
||||
"""
|
||||
if not latest_event_ids:
|
||||
latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
|
||||
assert latest_event_ids is not None
|
||||
|
||||
logger.debug("calling resolve_state_groups from get_current_users_in_room")
|
||||
entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
|
||||
joined_users = await self.store.get_joined_users_from_state(room_id, entry)
|
||||
return joined_users
|
||||
return await self.store.get_joined_users_from_state(room_id, entry)
|
||||
|
||||
async def get_current_hosts_in_room(self, room_id):
|
||||
async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
|
||||
event_ids = await self.store.get_latest_event_ids_in_room(room_id)
|
||||
return await self.get_hosts_in_room_at_events(room_id, event_ids)
|
||||
|
||||
async def get_hosts_in_room_at_events(self, room_id, event_ids):
|
||||
async def get_hosts_in_room_at_events(
|
||||
self, room_id: str, event_ids: List[str]
|
||||
) -> Set[str]:
|
||||
"""Get the hosts that were in a room at the given event ids
|
||||
|
||||
Args:
|
||||
room_id (str):
|
||||
event_ids (list[str]):
|
||||
room_id:
|
||||
event_ids:
|
||||
|
||||
Returns:
|
||||
Deferred[list[str]]: the hosts in the room at the given events
|
||||
The hosts in the room at the given events
|
||||
"""
|
||||
entry = await self.resolve_state_groups_for_events(room_id, event_ids)
|
||||
joined_hosts = await self.store.get_joined_hosts(room_id, entry)
|
||||
return joined_hosts
|
||||
return await self.store.get_joined_hosts(room_id, entry)
|
||||
|
||||
async def compute_event_context(
|
||||
self, event: EventBase, old_state: Optional[Iterable[EventBase]] = None
|
||||
):
|
||||
) -> EventContext:
|
||||
"""Build an EventContext structure for the event.
|
||||
|
||||
This works out what the current state should be for the event, and
|
||||
@@ -221,7 +261,7 @@ class StateHandler(object):
|
||||
when receiving an event from federation where we don't have the
|
||||
prev events for, e.g. when backfilling.
|
||||
Returns:
|
||||
synapse.events.snapshot.EventContext:
|
||||
The event context.
|
||||
"""
|
||||
|
||||
if event.internal_metadata.is_outlier():
|
||||
@@ -275,7 +315,7 @@ class StateHandler(object):
|
||||
event.room_id, event.prev_event_ids()
|
||||
)
|
||||
|
||||
state_ids_before_event = entry.state
|
||||
state_ids_before_event = dict(entry.state)
|
||||
state_group_before_event = entry.state_group
|
||||
state_group_before_event_prev_group = entry.prev_group
|
||||
deltas_to_state_group_before_event = entry.delta_ids
|
||||
@@ -346,19 +386,18 @@ class StateHandler(object):
|
||||
)
|
||||
|
||||
@measure_func()
|
||||
async def resolve_state_groups_for_events(self, room_id, event_ids):
|
||||
async def resolve_state_groups_for_events(
|
||||
self, room_id: str, event_ids: Iterable[str]
|
||||
) -> _StateCacheEntry:
|
||||
""" Given a list of event_ids this method fetches the state at each
|
||||
event, resolves conflicts between them and returns them.
|
||||
|
||||
Args:
|
||||
room_id (str)
|
||||
event_ids (list[str])
|
||||
explicit_room_version (str|None): If set uses the the given room
|
||||
version to choose the resolution algorithm. If None, then
|
||||
checks the database for room version.
|
||||
room_id
|
||||
event_ids
|
||||
|
||||
Returns:
|
||||
Deferred[_StateCacheEntry]: resolved state
|
||||
The resolved state
|
||||
"""
|
||||
logger.debug("resolve_state_groups event_ids %s", event_ids)
|
||||
|
||||
@@ -394,7 +433,12 @@ class StateHandler(object):
|
||||
)
|
||||
return result
|
||||
|
||||
async def resolve_events(self, room_version, state_sets, event):
|
||||
async def resolve_events(
|
||||
self,
|
||||
room_version: str,
|
||||
state_sets: Collection[Iterable[EventBase]],
|
||||
event: EventBase,
|
||||
) -> StateMap[EventBase]:
|
||||
logger.info(
|
||||
"Resolving state for %s with %d groups", event.room_id, len(state_sets)
|
||||
)
|
||||
@@ -414,9 +458,7 @@ class StateHandler(object):
|
||||
state_res_store=StateResolutionStore(self.store),
|
||||
)
|
||||
|
||||
new_state = {key: state_map[ev_id] for key, ev_id in new_state.items()}
|
||||
|
||||
return new_state
|
||||
return {key: state_map[ev_id] for key, ev_id in new_state.items()}
|
||||
|
||||
|
||||
class StateResolutionHandler(object):
|
||||
@@ -444,7 +486,12 @@ class StateResolutionHandler(object):
|
||||
|
||||
@log_function
|
||||
async def resolve_state_groups(
|
||||
self, room_id, room_version, state_groups_ids, event_map, state_res_store
|
||||
self,
|
||||
room_id: str,
|
||||
room_version: str,
|
||||
state_groups_ids: Dict[int, StateMap[str]],
|
||||
event_map: Optional[Dict[str, EventBase]],
|
||||
state_res_store: "StateResolutionStore",
|
||||
):
|
||||
"""Resolves conflicts between a set of state groups
|
||||
|
||||
@@ -452,13 +499,13 @@ class StateResolutionHandler(object):
|
||||
not be called for a single state group
|
||||
|
||||
Args:
|
||||
room_id (str): room we are resolving for (used for logging and sanity checks)
|
||||
room_version (str): version of the room
|
||||
state_groups_ids (dict[int, dict[(str, str), str]]):
|
||||
map from state group id to the state in that state group
|
||||
room_id: room we are resolving for (used for logging and sanity checks)
|
||||
room_version: version of the room
|
||||
state_groups_ids:
|
||||
A map from state group id to the state in that state group
|
||||
(where 'state' is a map from state key to event id)
|
||||
|
||||
event_map(dict[str,FrozenEvent]|None):
|
||||
event_map:
|
||||
a dict from event_id to event, for any events that we happen to
|
||||
have in flight (eg, those currently being persisted). This will be
|
||||
used as a starting point fof finding the state we need; any missing
|
||||
@@ -466,10 +513,10 @@ class StateResolutionHandler(object):
|
||||
|
||||
If None, all events will be fetched via state_res_store.
|
||||
|
||||
state_res_store (StateResolutionStore)
|
||||
state_res_store
|
||||
|
||||
Returns:
|
||||
_StateCacheEntry: resolved state
|
||||
The resolved state
|
||||
"""
|
||||
logger.debug("resolve_state_groups state_groups %s", state_groups_ids.keys())
|
||||
|
||||
@@ -530,21 +577,22 @@ class StateResolutionHandler(object):
|
||||
return cache
|
||||
|
||||
|
||||
def _make_state_cache_entry(new_state, state_groups_ids):
|
||||
def _make_state_cache_entry(
|
||||
new_state: StateMap[str], state_groups_ids: Dict[int, StateMap[str]]
|
||||
) -> _StateCacheEntry:
|
||||
"""Given a resolved state, and a set of input state groups, pick one to base
|
||||
a new state group on (if any), and return an appropriately-constructed
|
||||
_StateCacheEntry.
|
||||
|
||||
Args:
|
||||
new_state (dict[(str, str), str]): resolved state map (mapping from
|
||||
(type, state_key) to event_id)
|
||||
new_state: resolved state map (mapping from (type, state_key) to event_id)
|
||||
|
||||
state_groups_ids (dict[int, dict[(str, str), str]]):
|
||||
map from state group id to the state in that state group
|
||||
(where 'state' is a map from state key to event id)
|
||||
state_groups_ids:
|
||||
map from state group id to the state in that state group (where
|
||||
'state' is a map from state key to event id)
|
||||
|
||||
Returns:
|
||||
_StateCacheEntry
|
||||
The cache entry.
|
||||
"""
|
||||
# if the new state matches any of the input state groups, we can
|
||||
# use that state group again. Otherwise we will generate a state_id
|
||||
@@ -585,7 +633,7 @@ def resolve_events_with_store(
|
||||
clock: Clock,
|
||||
room_id: str,
|
||||
room_version: str,
|
||||
state_sets: List[StateMap[str]],
|
||||
state_sets: Sequence[StateMap[str]],
|
||||
event_map: Optional[Dict[str, EventBase]],
|
||||
state_res_store: "StateResolutionStore",
|
||||
) -> Awaitable[StateMap[str]]:
|
||||
@@ -633,15 +681,17 @@ class StateResolutionStore(object):
|
||||
|
||||
store = attr.ib()
|
||||
|
||||
def get_events(self, event_ids, allow_rejected=False):
|
||||
def get_events(
|
||||
self, event_ids: Iterable[str], allow_rejected: bool = False
|
||||
) -> Awaitable[Dict[str, EventBase]]:
|
||||
"""Get events from the database
|
||||
|
||||
Args:
|
||||
event_ids (list): The event_ids of the events to fetch
|
||||
allow_rejected (bool): If True return rejected events.
|
||||
event_ids: The event_ids of the events to fetch
|
||||
allow_rejected: If True return rejected events.
|
||||
|
||||
Returns:
|
||||
Awaitable[dict[str, FrozenEvent]]: Dict from event_id to event.
|
||||
An awaitable which resolves to a dict from event_id to event.
|
||||
"""
|
||||
|
||||
return self.store.get_events(
|
||||
@@ -651,7 +701,9 @@ class StateResolutionStore(object):
|
||||
allow_rejected=allow_rejected,
|
||||
)
|
||||
|
||||
def get_auth_chain_difference(self, state_sets: List[Set[str]]):
|
||||
def get_auth_chain_difference(
|
||||
self, state_sets: List[Set[str]]
|
||||
) -> Awaitable[Set[str]]:
|
||||
"""Given sets of state events figure out the auth chain difference (as
|
||||
per state res v2 algorithm).
|
||||
|
||||
@@ -660,7 +712,7 @@ class StateResolutionStore(object):
|
||||
chain.
|
||||
|
||||
Returns:
|
||||
Deferred[Set[str]]: Set of event IDs.
|
||||
An awaitable that resolves to a set of event IDs.
|
||||
"""
|
||||
|
||||
return self.store.get_auth_chain_difference(state_sets)
|
||||
|
||||
+59
-28
@@ -15,7 +15,17 @@
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
from typing import Awaitable, Callable, Dict, List, Optional
|
||||
from typing import (
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
)
|
||||
|
||||
from synapse import event_auth
|
||||
from synapse.api.constants import EventTypes
|
||||
@@ -32,10 +42,10 @@ POWER_KEY = (EventTypes.PowerLevels, "")
|
||||
|
||||
async def resolve_events_with_store(
|
||||
room_id: str,
|
||||
state_sets: List[StateMap[str]],
|
||||
state_sets: Sequence[StateMap[str]],
|
||||
event_map: Optional[Dict[str, EventBase]],
|
||||
state_map_factory: Callable[[List[str]], Awaitable],
|
||||
):
|
||||
state_map_factory: Callable[[Iterable[str]], Awaitable[Dict[str, EventBase]]],
|
||||
) -> StateMap[str]:
|
||||
"""
|
||||
Args:
|
||||
room_id: the room we are working in
|
||||
@@ -56,8 +66,7 @@ async def resolve_events_with_store(
|
||||
an Awaitable that resolves to a dict of event_id to event.
|
||||
|
||||
Returns:
|
||||
Deferred[dict[(str, str), str]]:
|
||||
a map from (type, state_key) to event_id.
|
||||
A map from (type, state_key) to event_id.
|
||||
"""
|
||||
if len(state_sets) == 1:
|
||||
return state_sets[0]
|
||||
@@ -75,8 +84,8 @@ async def resolve_events_with_store(
|
||||
"Asking for %d/%d conflicted events", len(needed_events), needed_event_count
|
||||
)
|
||||
|
||||
# dict[str, FrozenEvent]: a map from state event id to event. Only includes
|
||||
# the state events which are in conflict (and those in event_map)
|
||||
# A map from state event id to event. Only includes the state events which
|
||||
# are in conflict (and those in event_map).
|
||||
state_map = await state_map_factory(needed_events)
|
||||
if event_map is not None:
|
||||
state_map.update(event_map)
|
||||
@@ -91,8 +100,6 @@ async def resolve_events_with_store(
|
||||
|
||||
# get the ids of the auth events which allow us to authenticate the
|
||||
# conflicted state, picking only from the unconflicting state.
|
||||
#
|
||||
# dict[(str, str), str]: a map from state key to event id
|
||||
auth_events = _create_auth_events_from_maps(
|
||||
unconflicted_state, conflicted_state, state_map
|
||||
)
|
||||
@@ -122,29 +129,30 @@ async def resolve_events_with_store(
|
||||
)
|
||||
|
||||
|
||||
def _seperate(state_sets):
|
||||
def _seperate(
|
||||
state_sets: Iterable[StateMap[str]],
|
||||
) -> Tuple[StateMap[str], StateMap[Set[str]]]:
|
||||
"""Takes the state_sets and figures out which keys are conflicted and
|
||||
which aren't. i.e., which have multiple different event_ids associated
|
||||
with them in different state sets.
|
||||
|
||||
Args:
|
||||
state_sets(iterable[dict[(str, str), str]]):
|
||||
state_sets:
|
||||
List of dicts of (type, state_key) -> event_id, which are the
|
||||
different state groups to resolve.
|
||||
|
||||
Returns:
|
||||
(dict[(str, str), str], dict[(str, str), set[str]]):
|
||||
A tuple of (unconflicted_state, conflicted_state), where:
|
||||
A tuple of (unconflicted_state, conflicted_state), where:
|
||||
|
||||
unconflicted_state is a dict mapping (type, state_key)->event_id
|
||||
for unconflicted state keys.
|
||||
unconflicted_state is a dict mapping (type, state_key)->event_id
|
||||
for unconflicted state keys.
|
||||
|
||||
conflicted_state is a dict mapping (type, state_key) to a set of
|
||||
event ids for conflicted state keys.
|
||||
conflicted_state is a dict mapping (type, state_key) to a set of
|
||||
event ids for conflicted state keys.
|
||||
"""
|
||||
state_set_iterator = iter(state_sets)
|
||||
unconflicted_state = dict(next(state_set_iterator))
|
||||
conflicted_state = {}
|
||||
conflicted_state = {} # type: StateMap[Set[str]]
|
||||
|
||||
for state_set in state_set_iterator:
|
||||
for key, value in state_set.items():
|
||||
@@ -171,7 +179,21 @@ def _seperate(state_sets):
|
||||
return unconflicted_state, conflicted_state
|
||||
|
||||
|
||||
def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map):
|
||||
def _create_auth_events_from_maps(
|
||||
unconflicted_state: StateMap[str],
|
||||
conflicted_state: StateMap[Set[str]],
|
||||
state_map: Dict[str, EventBase],
|
||||
) -> StateMap[str]:
|
||||
"""
|
||||
|
||||
Args:
|
||||
unconflicted_state: The unconflicted state map.
|
||||
conflicted_state: The conflicted state map.
|
||||
state_map:
|
||||
|
||||
Returns:
|
||||
A map from state key to event id.
|
||||
"""
|
||||
auth_events = {}
|
||||
for event_ids in conflicted_state.values():
|
||||
for event_id in event_ids:
|
||||
@@ -179,14 +201,17 @@ def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_ma
|
||||
keys = event_auth.auth_types_for_event(state_map[event_id])
|
||||
for key in keys:
|
||||
if key not in auth_events:
|
||||
event_id = unconflicted_state.get(key, None)
|
||||
if event_id:
|
||||
auth_events[key] = event_id
|
||||
auth_event_id = unconflicted_state.get(key, None)
|
||||
if auth_event_id:
|
||||
auth_events[key] = auth_event_id
|
||||
return auth_events
|
||||
|
||||
|
||||
def _resolve_with_state(
|
||||
unconflicted_state_ids, conflicted_state_ids, auth_event_ids, state_map
|
||||
unconflicted_state_ids: StateMap[str],
|
||||
conflicted_state_ids: StateMap[Set[str]],
|
||||
auth_event_ids: StateMap[str],
|
||||
state_map: Dict[str, EventBase],
|
||||
):
|
||||
conflicted_state = {}
|
||||
for key, event_ids in conflicted_state_ids.items():
|
||||
@@ -215,7 +240,9 @@ def _resolve_with_state(
|
||||
return new_state
|
||||
|
||||
|
||||
def _resolve_state_events(conflicted_state, auth_events):
|
||||
def _resolve_state_events(
|
||||
conflicted_state: StateMap[List[EventBase]], auth_events: StateMap[EventBase]
|
||||
) -> StateMap[EventBase]:
|
||||
""" This is where we actually decide which of the conflicted state to
|
||||
use.
|
||||
|
||||
@@ -255,7 +282,9 @@ def _resolve_state_events(conflicted_state, auth_events):
|
||||
return resolved_state
|
||||
|
||||
|
||||
def _resolve_auth_events(events, auth_events):
|
||||
def _resolve_auth_events(
|
||||
events: List[EventBase], auth_events: StateMap[EventBase]
|
||||
) -> EventBase:
|
||||
reverse = list(reversed(_ordered_events(events)))
|
||||
|
||||
auth_keys = {
|
||||
@@ -289,7 +318,9 @@ def _resolve_auth_events(events, auth_events):
|
||||
return event
|
||||
|
||||
|
||||
def _resolve_normal_events(events, auth_events):
|
||||
def _resolve_normal_events(
|
||||
events: List[EventBase], auth_events: StateMap[EventBase]
|
||||
) -> EventBase:
|
||||
for event in _ordered_events(events):
|
||||
try:
|
||||
# The signatures have already been checked at this point
|
||||
@@ -309,7 +340,7 @@ def _resolve_normal_events(events, auth_events):
|
||||
return event
|
||||
|
||||
|
||||
def _ordered_events(events):
|
||||
def _ordered_events(events: Iterable[EventBase]) -> List[EventBase]:
|
||||
def key_func(e):
|
||||
# we have to use utf-8 rather than ascii here because it turns out we allow
|
||||
# people to send us events with non-ascii event IDs :/
|
||||
|
||||
+167
-88
@@ -16,7 +16,21 @@
|
||||
import heapq
|
||||
import itertools
|
||||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Generator,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
overload,
|
||||
)
|
||||
|
||||
from typing_extensions import Literal
|
||||
|
||||
import synapse.state
|
||||
from synapse import event_auth
|
||||
@@ -40,10 +54,10 @@ async def resolve_events_with_store(
|
||||
clock: Clock,
|
||||
room_id: str,
|
||||
room_version: str,
|
||||
state_sets: List[StateMap[str]],
|
||||
state_sets: Sequence[StateMap[str]],
|
||||
event_map: Optional[Dict[str, EventBase]],
|
||||
state_res_store: "synapse.state.StateResolutionStore",
|
||||
):
|
||||
) -> StateMap[str]:
|
||||
"""Resolves the state using the v2 state resolution algorithm
|
||||
|
||||
Args:
|
||||
@@ -63,8 +77,7 @@ async def resolve_events_with_store(
|
||||
state_res_store:
|
||||
|
||||
Returns:
|
||||
Deferred[dict[(str, str), str]]:
|
||||
a map from (type, state_key) to event_id.
|
||||
A map from (type, state_key) to event_id.
|
||||
"""
|
||||
|
||||
logger.debug("Computing conflicted state")
|
||||
@@ -171,18 +184,23 @@ async def resolve_events_with_store(
|
||||
return resolved_state
|
||||
|
||||
|
||||
async def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store):
|
||||
async def _get_power_level_for_sender(
|
||||
room_id: str,
|
||||
event_id: str,
|
||||
event_map: Dict[str, EventBase],
|
||||
state_res_store: "synapse.state.StateResolutionStore",
|
||||
) -> int:
|
||||
"""Return the power level of the sender of the given event according to
|
||||
their auth events.
|
||||
|
||||
Args:
|
||||
room_id (str)
|
||||
event_id (str)
|
||||
event_map (dict[str,FrozenEvent])
|
||||
state_res_store (StateResolutionStore)
|
||||
room_id
|
||||
event_id
|
||||
event_map
|
||||
state_res_store
|
||||
|
||||
Returns:
|
||||
Deferred[int]
|
||||
The power level.
|
||||
"""
|
||||
event = await _get_event(room_id, event_id, event_map, state_res_store)
|
||||
|
||||
@@ -217,17 +235,21 @@ async def _get_power_level_for_sender(room_id, event_id, event_map, state_res_st
|
||||
return int(level)
|
||||
|
||||
|
||||
async def _get_auth_chain_difference(state_sets, event_map, state_res_store):
|
||||
async def _get_auth_chain_difference(
|
||||
state_sets: Sequence[StateMap[str]],
|
||||
event_map: Dict[str, EventBase],
|
||||
state_res_store: "synapse.state.StateResolutionStore",
|
||||
) -> Set[str]:
|
||||
"""Compare the auth chains of each state set and return the set of events
|
||||
that only appear in some but not all of the auth chains.
|
||||
|
||||
Args:
|
||||
state_sets (list)
|
||||
event_map (dict[str,FrozenEvent])
|
||||
state_res_store (StateResolutionStore)
|
||||
state_sets
|
||||
event_map
|
||||
state_res_store
|
||||
|
||||
Returns:
|
||||
Deferred[set[str]]: Set of event IDs
|
||||
Set of event IDs
|
||||
"""
|
||||
|
||||
difference = await state_res_store.get_auth_chain_difference(
|
||||
@@ -237,17 +259,19 @@ async def _get_auth_chain_difference(state_sets, event_map, state_res_store):
|
||||
return difference
|
||||
|
||||
|
||||
def _seperate(state_sets):
|
||||
def _seperate(
|
||||
state_sets: Iterable[StateMap[str]],
|
||||
) -> Tuple[StateMap[str], StateMap[Set[str]]]:
|
||||
"""Return the unconflicted and conflicted state. This is different than in
|
||||
the original algorithm, as this defines a key to be conflicted if one of
|
||||
the state sets doesn't have that key.
|
||||
|
||||
Args:
|
||||
state_sets (list)
|
||||
state_sets
|
||||
|
||||
Returns:
|
||||
tuple[dict, dict]: A tuple of unconflicted and conflicted state. The
|
||||
conflicted state dict is a map from type/state_key to set of event IDs
|
||||
A tuple of unconflicted and conflicted state. The conflicted state dict
|
||||
is a map from type/state_key to set of event IDs
|
||||
"""
|
||||
unconflicted_state = {}
|
||||
conflicted_state = {}
|
||||
@@ -260,18 +284,20 @@ def _seperate(state_sets):
|
||||
event_ids.discard(None)
|
||||
conflicted_state[key] = event_ids
|
||||
|
||||
return unconflicted_state, conflicted_state
|
||||
# mypy doesn't understand that discarding None above means that conflicted
|
||||
# state is StateMap[Set[str]], not StateMap[Set[Optional[Str]]].
|
||||
return unconflicted_state, conflicted_state # type: ignore
|
||||
|
||||
|
||||
def _is_power_event(event):
|
||||
def _is_power_event(event: EventBase) -> bool:
|
||||
"""Return whether or not the event is a "power event", as defined by the
|
||||
v2 state resolution algorithm
|
||||
|
||||
Args:
|
||||
event (FrozenEvent)
|
||||
event
|
||||
|
||||
Returns:
|
||||
boolean
|
||||
True if the event is a power event.
|
||||
"""
|
||||
if (event.type, event.state_key) in (
|
||||
(EventTypes.PowerLevels, ""),
|
||||
@@ -288,19 +314,23 @@ def _is_power_event(event):
|
||||
|
||||
|
||||
async def _add_event_and_auth_chain_to_graph(
|
||||
graph, room_id, event_id, event_map, state_res_store, auth_diff
|
||||
):
|
||||
graph: Dict[str, Set[str]],
|
||||
room_id: str,
|
||||
event_id: str,
|
||||
event_map: Dict[str, EventBase],
|
||||
state_res_store: "synapse.state.StateResolutionStore",
|
||||
auth_diff: Set[str],
|
||||
) -> None:
|
||||
"""Helper function for _reverse_topological_power_sort that add the event
|
||||
and its auth chain (that is in the auth diff) to the graph
|
||||
|
||||
Args:
|
||||
graph (dict[str, set[str]]): A map from event ID to the events auth
|
||||
event IDs
|
||||
room_id (str): the room we are working in
|
||||
event_id (str): Event to add to the graph
|
||||
event_map (dict[str,FrozenEvent])
|
||||
state_res_store (StateResolutionStore)
|
||||
auth_diff (set[str]): Set of event IDs that are in the auth difference.
|
||||
graph: A map from event ID to the events auth event IDs
|
||||
room_id: the room we are working in
|
||||
event_id: Event to add to the graph
|
||||
event_map
|
||||
state_res_store
|
||||
auth_diff: Set of event IDs that are in the auth difference.
|
||||
"""
|
||||
|
||||
state = [event_id]
|
||||
@@ -318,24 +348,29 @@ async def _add_event_and_auth_chain_to_graph(
|
||||
|
||||
|
||||
async def _reverse_topological_power_sort(
|
||||
clock, room_id, event_ids, event_map, state_res_store, auth_diff
|
||||
):
|
||||
clock: Clock,
|
||||
room_id: str,
|
||||
event_ids: Iterable[str],
|
||||
event_map: Dict[str, EventBase],
|
||||
state_res_store: "synapse.state.StateResolutionStore",
|
||||
auth_diff: Set[str],
|
||||
) -> List[str]:
|
||||
"""Returns a list of the event_ids sorted by reverse topological ordering,
|
||||
and then by power level and origin_server_ts
|
||||
|
||||
Args:
|
||||
clock (Clock)
|
||||
room_id (str): the room we are working in
|
||||
event_ids (list[str]): The events to sort
|
||||
event_map (dict[str,FrozenEvent])
|
||||
state_res_store (StateResolutionStore)
|
||||
auth_diff (set[str]): Set of event IDs that are in the auth difference.
|
||||
clock
|
||||
room_id: the room we are working in
|
||||
event_ids: The events to sort
|
||||
event_map
|
||||
state_res_store
|
||||
auth_diff: Set of event IDs that are in the auth difference.
|
||||
|
||||
Returns:
|
||||
Deferred[list[str]]: The sorted list
|
||||
The sorted list
|
||||
"""
|
||||
|
||||
graph = {}
|
||||
graph = {} # type: Dict[str, Set[str]]
|
||||
for idx, event_id in enumerate(event_ids, start=1):
|
||||
await _add_event_and_auth_chain_to_graph(
|
||||
graph, room_id, event_id, event_map, state_res_store, auth_diff
|
||||
@@ -372,22 +407,28 @@ async def _reverse_topological_power_sort(
|
||||
|
||||
|
||||
async def _iterative_auth_checks(
|
||||
clock, room_id, room_version, event_ids, base_state, event_map, state_res_store
|
||||
):
|
||||
clock: Clock,
|
||||
room_id: str,
|
||||
room_version: str,
|
||||
event_ids: List[str],
|
||||
base_state: StateMap[str],
|
||||
event_map: Dict[str, EventBase],
|
||||
state_res_store: "synapse.state.StateResolutionStore",
|
||||
) -> StateMap[str]:
|
||||
"""Sequentially apply auth checks to each event in given list, updating the
|
||||
state as it goes along.
|
||||
|
||||
Args:
|
||||
clock (Clock)
|
||||
room_id (str)
|
||||
room_version (str)
|
||||
event_ids (list[str]): Ordered list of events to apply auth checks to
|
||||
base_state (StateMap[str]): The set of state to start with
|
||||
event_map (dict[str,FrozenEvent])
|
||||
state_res_store (StateResolutionStore)
|
||||
clock
|
||||
room_id
|
||||
room_version
|
||||
event_ids: Ordered list of events to apply auth checks to
|
||||
base_state: The set of state to start with
|
||||
event_map
|
||||
state_res_store
|
||||
|
||||
Returns:
|
||||
Deferred[StateMap[str]]: Returns the final updated state
|
||||
Returns the final updated state
|
||||
"""
|
||||
resolved_state = base_state.copy()
|
||||
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
|
||||
@@ -439,21 +480,26 @@ async def _iterative_auth_checks(
|
||||
|
||||
|
||||
async def _mainline_sort(
|
||||
clock, room_id, event_ids, resolved_power_event_id, event_map, state_res_store
|
||||
):
|
||||
clock: Clock,
|
||||
room_id: str,
|
||||
event_ids: List[str],
|
||||
resolved_power_event_id: Optional[str],
|
||||
event_map: Dict[str, EventBase],
|
||||
state_res_store: "synapse.state.StateResolutionStore",
|
||||
) -> List[str]:
|
||||
"""Returns a sorted list of event_ids sorted by mainline ordering based on
|
||||
the given event resolved_power_event_id
|
||||
|
||||
Args:
|
||||
clock (Clock)
|
||||
room_id (str): room we're working in
|
||||
event_ids (list[str]): Events to sort
|
||||
resolved_power_event_id (str): The final resolved power level event ID
|
||||
event_map (dict[str,FrozenEvent])
|
||||
state_res_store (StateResolutionStore)
|
||||
clock
|
||||
room_id: room we're working in
|
||||
event_ids: Events to sort
|
||||
resolved_power_event_id: The final resolved power level event ID
|
||||
event_map
|
||||
state_res_store
|
||||
|
||||
Returns:
|
||||
Deferred[list[str]]: The sorted list
|
||||
The sorted list
|
||||
"""
|
||||
if not event_ids:
|
||||
# It's possible for there to be no event IDs here to sort, so we can
|
||||
@@ -505,59 +551,90 @@ async def _mainline_sort(
|
||||
|
||||
|
||||
async def _get_mainline_depth_for_event(
|
||||
event, mainline_map, event_map, state_res_store
|
||||
):
|
||||
event: EventBase,
|
||||
mainline_map: Dict[str, int],
|
||||
event_map: Dict[str, EventBase],
|
||||
state_res_store: "synapse.state.StateResolutionStore",
|
||||
) -> int:
|
||||
"""Get the mainline depths for the given event based on the mainline map
|
||||
|
||||
Args:
|
||||
event (FrozenEvent)
|
||||
mainline_map (dict[str, int]): Map from event_id to mainline depth for
|
||||
events in the mainline.
|
||||
event_map (dict[str,FrozenEvent])
|
||||
state_res_store (StateResolutionStore)
|
||||
event
|
||||
mainline_map: Map from event_id to mainline depth for events in the mainline.
|
||||
event_map
|
||||
state_res_store
|
||||
|
||||
Returns:
|
||||
Deferred[int]
|
||||
The mainline depth
|
||||
"""
|
||||
|
||||
room_id = event.room_id
|
||||
tmp_event = event # type: Optional[EventBase]
|
||||
|
||||
# We do an iterative search, replacing `event with the power level in its
|
||||
# auth events (if any)
|
||||
while event:
|
||||
while tmp_event:
|
||||
depth = mainline_map.get(event.event_id)
|
||||
if depth is not None:
|
||||
return depth
|
||||
|
||||
auth_events = event.auth_event_ids()
|
||||
event = None
|
||||
auth_events = tmp_event.auth_event_ids()
|
||||
tmp_event = None
|
||||
|
||||
for aid in auth_events:
|
||||
aev = await _get_event(
|
||||
room_id, aid, event_map, state_res_store, allow_none=True
|
||||
)
|
||||
if aev and (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""):
|
||||
event = aev
|
||||
tmp_event = aev
|
||||
break
|
||||
|
||||
# Didn't find a power level auth event, so we just return 0
|
||||
return 0
|
||||
|
||||
|
||||
async def _get_event(room_id, event_id, event_map, state_res_store, allow_none=False):
|
||||
@overload
|
||||
async def _get_event(
|
||||
room_id: str,
|
||||
event_id: str,
|
||||
event_map: Dict[str, EventBase],
|
||||
state_res_store: "synapse.state.StateResolutionStore",
|
||||
allow_none: Literal[False] = False,
|
||||
) -> EventBase:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
async def _get_event(
|
||||
room_id: str,
|
||||
event_id: str,
|
||||
event_map: Dict[str, EventBase],
|
||||
state_res_store: "synapse.state.StateResolutionStore",
|
||||
allow_none: Literal[True],
|
||||
) -> Optional[EventBase]:
|
||||
...
|
||||
|
||||
|
||||
async def _get_event(
|
||||
room_id: str,
|
||||
event_id: str,
|
||||
event_map: Dict[str, EventBase],
|
||||
state_res_store: "synapse.state.StateResolutionStore",
|
||||
allow_none: bool = False,
|
||||
) -> Optional[EventBase]:
|
||||
"""Helper function to look up event in event_map, falling back to looking
|
||||
it up in the store
|
||||
|
||||
Args:
|
||||
room_id (str)
|
||||
event_id (str)
|
||||
event_map (dict[str,FrozenEvent])
|
||||
state_res_store (StateResolutionStore)
|
||||
allow_none (bool): if the event is not found, return None rather than raising
|
||||
room_id
|
||||
event_id
|
||||
event_map
|
||||
state_res_store
|
||||
allow_none: if the event is not found, return None rather than raising
|
||||
an exception
|
||||
|
||||
Returns:
|
||||
Deferred[Optional[FrozenEvent]]
|
||||
The event, or none if the event does not exist (and allow_none is True).
|
||||
"""
|
||||
if event_id not in event_map:
|
||||
events = await state_res_store.get_events([event_id], allow_rejected=True)
|
||||
@@ -577,7 +654,9 @@ async def _get_event(room_id, event_id, event_map, state_res_store, allow_none=F
|
||||
return event
|
||||
|
||||
|
||||
def lexicographical_topological_sort(graph, key):
|
||||
def lexicographical_topological_sort(
|
||||
graph: Dict[str, Set[str]], key: Callable[[str], Any]
|
||||
) -> Generator[str, None, None]:
|
||||
"""Performs a lexicographic reverse topological sort on the graph.
|
||||
|
||||
This returns a reverse topological sort (i.e. if node A references B then B
|
||||
@@ -587,20 +666,20 @@ def lexicographical_topological_sort(graph, key):
|
||||
NOTE: `graph` is modified during the sort.
|
||||
|
||||
Args:
|
||||
graph (dict[str, set[str]]): A representation of the graph where each
|
||||
node is a key in the dict and its value are the nodes edges.
|
||||
key (func): A function that takes a node and returns a value that is
|
||||
comparable and used to order nodes
|
||||
graph: A representation of the graph where each node is a key in the
|
||||
dict and its value are the nodes edges.
|
||||
key: A function that takes a node and returns a value that is comparable
|
||||
and used to order nodes
|
||||
|
||||
Yields:
|
||||
str: The next node in the topological sort
|
||||
The next node in the topological sort
|
||||
"""
|
||||
|
||||
# Note, this is basically Kahn's algorithm except we look at nodes with no
|
||||
# outgoing edges, c.f.
|
||||
# https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm
|
||||
outdegree_map = graph
|
||||
reverse_graph = {}
|
||||
reverse_graph = {} # type: Dict[str, Set[str]]
|
||||
|
||||
# Lists of nodes with zero out degree. Is actually a tuple of
|
||||
# `(key(node), node)` so that sorting does the right thing
|
||||
|
||||
@@ -29,9 +29,11 @@ from typing import (
|
||||
Tuple,
|
||||
TypeVar,
|
||||
Union,
|
||||
overload,
|
||||
)
|
||||
|
||||
from prometheus_client import Histogram
|
||||
from typing_extensions import Literal
|
||||
|
||||
from twisted.enterprise import adbapi
|
||||
from twisted.internet import defer
|
||||
@@ -1020,14 +1022,36 @@ class DatabasePool(object):
|
||||
|
||||
return txn.execute_batch(sql, args)
|
||||
|
||||
def simple_select_one(
|
||||
@overload
|
||||
async def simple_select_one(
|
||||
self,
|
||||
table: str,
|
||||
keyvalues: Dict[str, Any],
|
||||
retcols: Iterable[str],
|
||||
allow_none: Literal[False] = False,
|
||||
desc: str = "simple_select_one",
|
||||
) -> Dict[str, Any]:
|
||||
...
|
||||
|
||||
@overload
|
||||
async def simple_select_one(
|
||||
self,
|
||||
table: str,
|
||||
keyvalues: Dict[str, Any],
|
||||
retcols: Iterable[str],
|
||||
allow_none: Literal[True] = True,
|
||||
desc: str = "simple_select_one",
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
...
|
||||
|
||||
async def simple_select_one(
|
||||
self,
|
||||
table: str,
|
||||
keyvalues: Dict[str, Any],
|
||||
retcols: Iterable[str],
|
||||
allow_none: bool = False,
|
||||
desc: str = "simple_select_one",
|
||||
) -> defer.Deferred:
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Executes a SELECT query on the named table, which is expected to
|
||||
return a single row, returning multiple columns from it.
|
||||
|
||||
@@ -1038,18 +1062,18 @@ class DatabasePool(object):
|
||||
allow_none: If true, return None instead of failing if the SELECT
|
||||
statement returns no rows
|
||||
"""
|
||||
return self.runInteraction(
|
||||
return await self.runInteraction(
|
||||
desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none
|
||||
)
|
||||
|
||||
def simple_select_one_onecol(
|
||||
async def simple_select_one_onecol(
|
||||
self,
|
||||
table: str,
|
||||
keyvalues: Dict[str, Any],
|
||||
retcol: Iterable[str],
|
||||
allow_none: bool = False,
|
||||
desc: str = "simple_select_one_onecol",
|
||||
) -> defer.Deferred:
|
||||
) -> Optional[Any]:
|
||||
"""Executes a SELECT query on the named table, which is expected to
|
||||
return a single row, returning a single column from it.
|
||||
|
||||
@@ -1061,7 +1085,7 @@ class DatabasePool(object):
|
||||
statement returns no rows
|
||||
desc: description of the transaction, for logging and metrics
|
||||
"""
|
||||
return self.runInteraction(
|
||||
return await self.runInteraction(
|
||||
desc,
|
||||
self.simple_select_one_onecol_txn,
|
||||
table,
|
||||
|
||||
@@ -498,7 +498,7 @@ class DataStore(
|
||||
)
|
||||
|
||||
def get_users_paginate(
|
||||
self, start, limit, name=None, guests=True, deactivated=False
|
||||
self, start, limit, user_id=None, name=None, guests=True, deactivated=False
|
||||
):
|
||||
"""Function to retrieve a paginated list of users from
|
||||
users list. This will return a json list of users and the
|
||||
@@ -507,7 +507,8 @@ class DataStore(
|
||||
Args:
|
||||
start (int): start number to begin the query from
|
||||
limit (int): number of rows to retrieve
|
||||
name (string): filter for user names
|
||||
user_id (string): search for user_id. ignored if name is not None
|
||||
name (string): search for local part of user_id or display name
|
||||
guests (bool): whether to in include guest users
|
||||
deactivated (bool): whether to include deactivated users
|
||||
Returns:
|
||||
@@ -516,11 +517,14 @@ class DataStore(
|
||||
|
||||
def get_users_paginate_txn(txn):
|
||||
filters = []
|
||||
args = []
|
||||
args = [self.hs.config.server_name]
|
||||
|
||||
if name:
|
||||
filters.append("(name LIKE ? OR displayname LIKE ?)")
|
||||
args.extend(["@%" + name + "%:%", "%" + name + "%"])
|
||||
elif user_id:
|
||||
filters.append("name LIKE ?")
|
||||
args.append("%" + name + "%")
|
||||
args.extend(["%" + user_id + "%"])
|
||||
|
||||
if not guests:
|
||||
filters.append("is_guest = 0")
|
||||
@@ -530,20 +534,23 @@ class DataStore(
|
||||
|
||||
where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else ""
|
||||
|
||||
sql = "SELECT COUNT(*) as total_users FROM users %s" % (where_clause)
|
||||
txn.execute(sql, args)
|
||||
count = txn.fetchone()[0]
|
||||
|
||||
args = [self.hs.config.server_name] + args + [limit, start]
|
||||
sql = """
|
||||
SELECT name, user_type, is_guest, admin, deactivated, displayname, avatar_url
|
||||
sql_base = """
|
||||
FROM users as u
|
||||
LEFT JOIN profiles AS p ON u.name = '@' || p.user_id || ':' || ?
|
||||
{}
|
||||
ORDER BY u.name LIMIT ? OFFSET ?
|
||||
""".format(
|
||||
where_clause
|
||||
)
|
||||
sql = "SELECT COUNT(*) as total_users " + sql_base
|
||||
txn.execute(sql, args)
|
||||
count = txn.fetchone()[0]
|
||||
|
||||
sql = (
|
||||
"SELECT name, user_type, is_guest, admin, deactivated, displayname, avatar_url "
|
||||
+ sql_base
|
||||
+ " ORDER BY u.name LIMIT ? OFFSET ?"
|
||||
)
|
||||
args += [limit, start]
|
||||
txn.execute(sql, args)
|
||||
users = self.db_pool.cursor_to_dict(txn)
|
||||
return users, count
|
||||
|
||||
@@ -336,7 +336,7 @@ class AccountDataStore(AccountDataWorkerStore):
|
||||
"""
|
||||
content_json = json_encoder.encode(content)
|
||||
|
||||
with self._account_data_id_gen.get_next() as next_id:
|
||||
with await self._account_data_id_gen.get_next() as next_id:
|
||||
# no need to lock here as room_account_data has a unique constraint
|
||||
# on (user_id, room_id, account_data_type) so simple_upsert will
|
||||
# retry if there is a conflict.
|
||||
@@ -384,7 +384,7 @@ class AccountDataStore(AccountDataWorkerStore):
|
||||
"""
|
||||
content_json = json_encoder.encode(content)
|
||||
|
||||
with self._account_data_id_gen.get_next() as next_id:
|
||||
with await self._account_data_id_gen.get_next() as next_id:
|
||||
# no need to lock here as account_data has a unique constraint on
|
||||
# (user_id, account_data_type) so simple_upsert will retry if
|
||||
# there is a conflict.
|
||||
|
||||
@@ -362,7 +362,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
|
||||
rows.append((destination, stream_id, now_ms, edu_json))
|
||||
txn.executemany(sql, rows)
|
||||
|
||||
with self._device_inbox_id_gen.get_next() as stream_id:
|
||||
with await self._device_inbox_id_gen.get_next() as stream_id:
|
||||
now_ms = self.clock.time_msec()
|
||||
await self.db_pool.runInteraction(
|
||||
"add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
|
||||
@@ -411,7 +411,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
|
||||
txn, stream_id, local_messages_by_user_then_device
|
||||
)
|
||||
|
||||
with self._device_inbox_id_gen.get_next() as stream_id:
|
||||
with await self._device_inbox_id_gen.get_next() as stream_id:
|
||||
now_ms = self.clock.time_msec()
|
||||
await self.db_pool.runInteraction(
|
||||
"add_messages_from_remote_to_device_inbox",
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from typing import Dict, Iterable, List, Optional, Set, Tuple
|
||||
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
|
||||
|
||||
from synapse.api.errors import Codes, StoreError
|
||||
from synapse.logging.opentracing import (
|
||||
@@ -47,7 +47,7 @@ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
|
||||
|
||||
|
||||
class DeviceWorkerStore(SQLBaseStore):
|
||||
def get_device(self, user_id: str, device_id: str):
|
||||
async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]:
|
||||
"""Retrieve a device. Only returns devices that are not marked as
|
||||
hidden.
|
||||
|
||||
@@ -55,11 +55,11 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||
user_id: The ID of the user which owns the device
|
||||
device_id: The ID of the device to retrieve
|
||||
Returns:
|
||||
defer.Deferred for a dict containing the device information
|
||||
A dict containing the device information
|
||||
Raises:
|
||||
StoreError: if the device is not found
|
||||
"""
|
||||
return self.db_pool.simple_select_one(
|
||||
return await self.db_pool.simple_select_one(
|
||||
table="devices",
|
||||
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
|
||||
retcols=("user_id", "device_id", "display_name"),
|
||||
@@ -380,7 +380,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||
THe new stream ID.
|
||||
"""
|
||||
|
||||
with self._device_list_id_gen.get_next() as stream_id:
|
||||
with await self._device_list_id_gen.get_next() as stream_id:
|
||||
await self.db_pool.runInteraction(
|
||||
"add_user_sig_change_to_streams",
|
||||
self._add_user_signature_change_txn,
|
||||
@@ -656,11 +656,13 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||
)
|
||||
|
||||
@cached(max_entries=10000)
|
||||
def get_device_list_last_stream_id_for_remote(self, user_id: str):
|
||||
async def get_device_list_last_stream_id_for_remote(
|
||||
self, user_id: str
|
||||
) -> Optional[Any]:
|
||||
"""Get the last stream_id we got for a user. May be None if we haven't
|
||||
got any information for them.
|
||||
"""
|
||||
return self.db_pool.simple_select_one_onecol(
|
||||
return await self.db_pool.simple_select_one_onecol(
|
||||
table="device_lists_remote_extremeties",
|
||||
keyvalues={"user_id": user_id},
|
||||
retcol="stream_id",
|
||||
@@ -1146,7 +1148,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||
if not device_ids:
|
||||
return
|
||||
|
||||
with self._device_list_id_gen.get_next_mult(len(device_ids)) as stream_ids:
|
||||
with await self._device_list_id_gen.get_next_mult(
|
||||
len(device_ids)
|
||||
) as stream_ids:
|
||||
await self.db_pool.runInteraction(
|
||||
"add_device_change_to_stream",
|
||||
self._add_device_change_to_stream_txn,
|
||||
@@ -1159,7 +1163,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||
return stream_ids[-1]
|
||||
|
||||
context = get_active_span_text_map()
|
||||
with self._device_list_id_gen.get_next_mult(
|
||||
with await self._device_list_id_gen.get_next_mult(
|
||||
len(hosts) * len(device_ids)
|
||||
) as stream_ids:
|
||||
await self.db_pool.runInteraction(
|
||||
|
||||
@@ -59,8 +59,8 @@ class DirectoryWorkerStore(SQLBaseStore):
|
||||
|
||||
return RoomAliasMapping(room_id, room_alias.to_string(), servers)
|
||||
|
||||
def get_room_alias_creator(self, room_alias):
|
||||
return self.db_pool.simple_select_one_onecol(
|
||||
async def get_room_alias_creator(self, room_alias: str) -> str:
|
||||
return await self.db_pool.simple_select_one_onecol(
|
||||
table="room_aliases",
|
||||
keyvalues={"room_alias": room_alias},
|
||||
retcol="creator",
|
||||
|
||||
@@ -223,15 +223,15 @@ class EndToEndRoomKeyStore(SQLBaseStore):
|
||||
|
||||
return ret
|
||||
|
||||
def count_e2e_room_keys(self, user_id, version):
|
||||
async def count_e2e_room_keys(self, user_id: str, version: str) -> int:
|
||||
"""Get the number of keys in a backup version.
|
||||
|
||||
Args:
|
||||
user_id (str): the user whose backup we're querying
|
||||
version (str): the version ID of the backup we're querying about
|
||||
user_id: the user whose backup we're querying
|
||||
version: the version ID of the backup we're querying about
|
||||
"""
|
||||
|
||||
return self.db_pool.simple_select_one_onecol(
|
||||
return await self.db_pool.simple_select_one_onecol(
|
||||
table="e2e_room_keys",
|
||||
keyvalues={"user_id": user_id, "version": version},
|
||||
retcol="COUNT(*)",
|
||||
|
||||
@@ -648,7 +648,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
||||
"delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
|
||||
)
|
||||
|
||||
def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key):
|
||||
def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key, stream_id):
|
||||
"""Set a user's cross-signing key.
|
||||
|
||||
Args:
|
||||
@@ -658,6 +658,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
||||
for a master key, 'self_signing' for a self-signing key, or
|
||||
'user_signing' for a user-signing key
|
||||
key (dict): the key data
|
||||
stream_id (int)
|
||||
"""
|
||||
# the 'key' dict will look something like:
|
||||
# {
|
||||
@@ -695,23 +696,22 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
||||
)
|
||||
|
||||
# and finally, store the key itself
|
||||
with self._cross_signing_id_gen.get_next() as stream_id:
|
||||
self.db_pool.simple_insert_txn(
|
||||
txn,
|
||||
"e2e_cross_signing_keys",
|
||||
values={
|
||||
"user_id": user_id,
|
||||
"keytype": key_type,
|
||||
"keydata": json_encoder.encode(key),
|
||||
"stream_id": stream_id,
|
||||
},
|
||||
)
|
||||
self.db_pool.simple_insert_txn(
|
||||
txn,
|
||||
"e2e_cross_signing_keys",
|
||||
values={
|
||||
"user_id": user_id,
|
||||
"keytype": key_type,
|
||||
"keydata": json_encoder.encode(key),
|
||||
"stream_id": stream_id,
|
||||
},
|
||||
)
|
||||
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self._get_bare_e2e_cross_signing_keys, (user_id,)
|
||||
)
|
||||
|
||||
def set_e2e_cross_signing_key(self, user_id, key_type, key):
|
||||
async def set_e2e_cross_signing_key(self, user_id, key_type, key):
|
||||
"""Set a user's cross-signing key.
|
||||
|
||||
Args:
|
||||
@@ -719,13 +719,16 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
||||
key_type (str): the type of cross-signing key to set
|
||||
key (dict): the key data
|
||||
"""
|
||||
return self.db_pool.runInteraction(
|
||||
"add_e2e_cross_signing_key",
|
||||
self._set_e2e_cross_signing_key_txn,
|
||||
user_id,
|
||||
key_type,
|
||||
key,
|
||||
)
|
||||
|
||||
with await self._cross_signing_id_gen.get_next() as stream_id:
|
||||
return await self.db_pool.runInteraction(
|
||||
"add_e2e_cross_signing_key",
|
||||
self._set_e2e_cross_signing_key_txn,
|
||||
user_id,
|
||||
key_type,
|
||||
key,
|
||||
stream_id,
|
||||
)
|
||||
|
||||
def store_e2e_cross_signing_signatures(self, user_id, signatures):
|
||||
"""Stores cross-signing signatures.
|
||||
|
||||
@@ -153,11 +153,11 @@ class PersistEventsStore:
|
||||
# Note: Multiple instances of this function cannot be in flight at
|
||||
# the same time for the same room.
|
||||
if backfilled:
|
||||
stream_ordering_manager = self._backfill_id_gen.get_next_mult(
|
||||
stream_ordering_manager = await self._backfill_id_gen.get_next_mult(
|
||||
len(events_and_contexts)
|
||||
)
|
||||
else:
|
||||
stream_ordering_manager = self._stream_id_gen.get_next_mult(
|
||||
stream_ordering_manager = await self._stream_id_gen.get_next_mult(
|
||||
len(events_and_contexts)
|
||||
)
|
||||
|
||||
|
||||
@@ -113,25 +113,25 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
|
||||
def process_replication_rows(self, stream_name, instance_name, token, rows):
|
||||
if stream_name == EventsStream.NAME:
|
||||
self._stream_id_gen.advance(token)
|
||||
self._stream_id_gen.advance(instance_name, token)
|
||||
elif stream_name == BackfillStream.NAME:
|
||||
self._backfill_id_gen.advance(-token)
|
||||
self._backfill_id_gen.advance(instance_name, -token)
|
||||
|
||||
super().process_replication_rows(stream_name, instance_name, token, rows)
|
||||
|
||||
def get_received_ts(self, event_id):
|
||||
async def get_received_ts(self, event_id: str) -> Optional[int]:
|
||||
"""Get received_ts (when it was persisted) for the event.
|
||||
|
||||
Raises an exception for unknown events.
|
||||
|
||||
Args:
|
||||
event_id (str)
|
||||
event_id: The event ID to query.
|
||||
|
||||
Returns:
|
||||
Deferred[int|None]: Timestamp in milliseconds, or None for events
|
||||
that were persisted before received_ts was implemented.
|
||||
Timestamp in milliseconds, or None for events that were persisted
|
||||
before received_ts was implemented.
|
||||
"""
|
||||
return self.db_pool.simple_select_one_onecol(
|
||||
return await self.db_pool.simple_select_one_onecol(
|
||||
table="events",
|
||||
keyvalues={"event_id": event_id},
|
||||
retcol="received_ts",
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||
@@ -28,8 +28,8 @@ _DEFAULT_ROLE_ID = ""
|
||||
|
||||
|
||||
class GroupServerWorkerStore(SQLBaseStore):
|
||||
def get_group(self, group_id):
|
||||
return self.db_pool.simple_select_one(
|
||||
async def get_group(self, group_id: str) -> Optional[Dict[str, Any]]:
|
||||
return await self.db_pool.simple_select_one(
|
||||
table="groups",
|
||||
keyvalues={"group_id": group_id},
|
||||
retcols=(
|
||||
@@ -351,8 +351,10 @@ class GroupServerWorkerStore(SQLBaseStore):
|
||||
)
|
||||
return bool(result)
|
||||
|
||||
def is_user_admin_in_group(self, group_id, user_id):
|
||||
return self.db_pool.simple_select_one_onecol(
|
||||
async def is_user_admin_in_group(
|
||||
self, group_id: str, user_id: str
|
||||
) -> Optional[bool]:
|
||||
return await self.db_pool.simple_select_one_onecol(
|
||||
table="group_users",
|
||||
keyvalues={"group_id": group_id, "user_id": user_id},
|
||||
retcol="is_admin",
|
||||
@@ -360,10 +362,12 @@ class GroupServerWorkerStore(SQLBaseStore):
|
||||
desc="is_user_admin_in_group",
|
||||
)
|
||||
|
||||
def is_user_invited_to_local_group(self, group_id, user_id):
|
||||
async def is_user_invited_to_local_group(
|
||||
self, group_id: str, user_id: str
|
||||
) -> Optional[bool]:
|
||||
"""Has the group server invited a user?
|
||||
"""
|
||||
return self.db_pool.simple_select_one_onecol(
|
||||
return await self.db_pool.simple_select_one_onecol(
|
||||
table="group_invites",
|
||||
keyvalues={"group_id": group_id, "user_id": user_id},
|
||||
retcol="user_id",
|
||||
@@ -1182,7 +1186,7 @@ class GroupServerStore(GroupServerWorkerStore):
|
||||
|
||||
return next_id
|
||||
|
||||
with self._group_updates_id_gen.get_next() as next_id:
|
||||
with await self._group_updates_id_gen.get_next() as next_id:
|
||||
res = await self.db_pool.runInteraction(
|
||||
"register_user_group_membership",
|
||||
_register_user_group_membership_txn,
|
||||
|
||||
@@ -12,6 +12,8 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
from synapse.storage.database import DatabasePool
|
||||
|
||||
@@ -37,12 +39,13 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
||||
super(MediaRepositoryStore, self).__init__(database, db_conn, hs)
|
||||
|
||||
def get_local_media(self, media_id):
|
||||
async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get the metadata for a local piece of media
|
||||
|
||||
Returns:
|
||||
None if the media_id doesn't exist.
|
||||
"""
|
||||
return self.db_pool.simple_select_one(
|
||||
return await self.db_pool.simple_select_one(
|
||||
"local_media_repository",
|
||||
{"media_id": media_id},
|
||||
(
|
||||
@@ -191,8 +194,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
||||
desc="store_local_thumbnail",
|
||||
)
|
||||
|
||||
def get_cached_remote_media(self, origin, media_id):
|
||||
return self.db_pool.simple_select_one(
|
||||
async def get_cached_remote_media(
|
||||
self, origin, media_id: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
return await self.db_pool.simple_select_one(
|
||||
"remote_media_cache",
|
||||
{"media_origin": origin, "media_id": media_id},
|
||||
(
|
||||
|
||||
@@ -99,17 +99,18 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
|
||||
return users
|
||||
|
||||
@cached(num_args=1)
|
||||
def user_last_seen_monthly_active(self, user_id):
|
||||
async def user_last_seen_monthly_active(self, user_id: str) -> int:
|
||||
"""
|
||||
Checks if a given user is part of the monthly active user group
|
||||
Arguments:
|
||||
user_id (str): user to add/update
|
||||
Return:
|
||||
Deferred[int] : timestamp since last seen, None if never seen
|
||||
Checks if a given user is part of the monthly active user group
|
||||
|
||||
Arguments:
|
||||
user_id: user to add/update
|
||||
|
||||
Return:
|
||||
Timestamp since last seen, None if never seen
|
||||
"""
|
||||
|
||||
return self.db_pool.simple_select_one_onecol(
|
||||
return await self.db_pool.simple_select_one_onecol(
|
||||
table="monthly_active_users",
|
||||
keyvalues={"user_id": user_id},
|
||||
retcol="timestamp",
|
||||
|
||||
@@ -23,7 +23,7 @@ from synapse.util.iterutils import batch_iter
|
||||
|
||||
class PresenceStore(SQLBaseStore):
|
||||
async def update_presence(self, presence_states):
|
||||
stream_ordering_manager = self._presence_id_gen.get_next_mult(
|
||||
stream_ordering_manager = await self._presence_id_gen.get_next_mult(
|
||||
len(presence_states)
|
||||
)
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from synapse.api.errors import StoreError
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
@@ -19,7 +20,7 @@ from synapse.storage.databases.main.roommember import ProfileInfo
|
||||
|
||||
|
||||
class ProfileWorkerStore(SQLBaseStore):
|
||||
async def get_profileinfo(self, user_localpart):
|
||||
async def get_profileinfo(self, user_localpart: str) -> ProfileInfo:
|
||||
try:
|
||||
profile = await self.db_pool.simple_select_one(
|
||||
table="profiles",
|
||||
@@ -38,24 +39,26 @@ class ProfileWorkerStore(SQLBaseStore):
|
||||
avatar_url=profile["avatar_url"], display_name=profile["displayname"]
|
||||
)
|
||||
|
||||
def get_profile_displayname(self, user_localpart):
|
||||
return self.db_pool.simple_select_one_onecol(
|
||||
async def get_profile_displayname(self, user_localpart: str) -> str:
|
||||
return await self.db_pool.simple_select_one_onecol(
|
||||
table="profiles",
|
||||
keyvalues={"user_id": user_localpart},
|
||||
retcol="displayname",
|
||||
desc="get_profile_displayname",
|
||||
)
|
||||
|
||||
def get_profile_avatar_url(self, user_localpart):
|
||||
return self.db_pool.simple_select_one_onecol(
|
||||
async def get_profile_avatar_url(self, user_localpart: str) -> str:
|
||||
return await self.db_pool.simple_select_one_onecol(
|
||||
table="profiles",
|
||||
keyvalues={"user_id": user_localpart},
|
||||
retcol="avatar_url",
|
||||
desc="get_profile_avatar_url",
|
||||
)
|
||||
|
||||
def get_from_remote_profile_cache(self, user_id):
|
||||
return self.db_pool.simple_select_one(
|
||||
async def get_from_remote_profile_cache(
|
||||
self, user_id: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
return await self.db_pool.simple_select_one(
|
||||
table="remote_profile_cache",
|
||||
keyvalues={"user_id": user_id},
|
||||
retcols=("displayname", "avatar_url"),
|
||||
|
||||
@@ -338,7 +338,7 @@ class PushRuleStore(PushRulesWorkerStore):
|
||||
) -> None:
|
||||
conditions_json = json_encoder.encode(conditions)
|
||||
actions_json = json_encoder.encode(actions)
|
||||
with self._push_rules_stream_id_gen.get_next() as stream_id:
|
||||
with await self._push_rules_stream_id_gen.get_next() as stream_id:
|
||||
event_stream_ordering = self._stream_id_gen.get_current_token()
|
||||
|
||||
if before or after:
|
||||
@@ -560,7 +560,7 @@ class PushRuleStore(PushRulesWorkerStore):
|
||||
txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE"
|
||||
)
|
||||
|
||||
with self._push_rules_stream_id_gen.get_next() as stream_id:
|
||||
with await self._push_rules_stream_id_gen.get_next() as stream_id:
|
||||
event_stream_ordering = self._stream_id_gen.get_current_token()
|
||||
|
||||
await self.db_pool.runInteraction(
|
||||
@@ -571,7 +571,7 @@ class PushRuleStore(PushRulesWorkerStore):
|
||||
)
|
||||
|
||||
async def set_push_rule_enabled(self, user_id, rule_id, enabled) -> None:
|
||||
with self._push_rules_stream_id_gen.get_next() as stream_id:
|
||||
with await self._push_rules_stream_id_gen.get_next() as stream_id:
|
||||
event_stream_ordering = self._stream_id_gen.get_current_token()
|
||||
|
||||
await self.db_pool.runInteraction(
|
||||
@@ -646,7 +646,7 @@ class PushRuleStore(PushRulesWorkerStore):
|
||||
data={"actions": actions_json},
|
||||
)
|
||||
|
||||
with self._push_rules_stream_id_gen.get_next() as stream_id:
|
||||
with await self._push_rules_stream_id_gen.get_next() as stream_id:
|
||||
event_stream_ordering = self._stream_id_gen.get_current_token()
|
||||
|
||||
await self.db_pool.runInteraction(
|
||||
|
||||
@@ -281,7 +281,7 @@ class PusherStore(PusherWorkerStore):
|
||||
last_stream_ordering,
|
||||
profile_tag="",
|
||||
) -> None:
|
||||
with self._pushers_id_gen.get_next() as stream_id:
|
||||
with await self._pushers_id_gen.get_next() as stream_id:
|
||||
# no need to lock because `pushers` has a unique key on
|
||||
# (app_id, pushkey, user_name) so simple_upsert will retry
|
||||
await self.db_pool.simple_upsert(
|
||||
@@ -344,7 +344,7 @@ class PusherStore(PusherWorkerStore):
|
||||
},
|
||||
)
|
||||
|
||||
with self._pushers_id_gen.get_next() as stream_id:
|
||||
with await self._pushers_id_gen.get_next() as stream_id:
|
||||
await self.db_pool.runInteraction(
|
||||
"delete_pusher", delete_pusher_txn, stream_id
|
||||
)
|
||||
|
||||
@@ -71,8 +71,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||
)
|
||||
|
||||
@cached(num_args=3)
|
||||
def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type):
|
||||
return self.db_pool.simple_select_one_onecol(
|
||||
async def get_last_receipt_event_id_for_user(
|
||||
self, user_id: str, room_id: str, receipt_type: str
|
||||
) -> Optional[str]:
|
||||
return await self.db_pool.simple_select_one_onecol(
|
||||
table="receipts_linearized",
|
||||
keyvalues={
|
||||
"room_id": room_id,
|
||||
@@ -520,8 +522,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
|
||||
"insert_receipt_conv", graph_to_linear
|
||||
)
|
||||
|
||||
stream_id_manager = self._receipts_id_gen.get_next()
|
||||
with stream_id_manager as stream_id:
|
||||
with await self._receipts_id_gen.get_next() as stream_id:
|
||||
event_ts = await self.db_pool.runInteraction(
|
||||
"insert_linearized_receipt",
|
||||
self.insert_linearized_receipt_txn,
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Awaitable, Dict, List, Optional
|
||||
from typing import Any, Awaitable, Dict, List, Optional
|
||||
|
||||
from synapse.api.constants import UserTypes
|
||||
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
|
||||
@@ -46,8 +46,8 @@ class RegistrationWorkerStore(SQLBaseStore):
|
||||
)
|
||||
|
||||
@cached()
|
||||
def get_user_by_id(self, user_id):
|
||||
return self.db_pool.simple_select_one(
|
||||
async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
|
||||
return await self.db_pool.simple_select_one(
|
||||
table="users",
|
||||
keyvalues={"name": user_id},
|
||||
retcols=[
|
||||
@@ -889,6 +889,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
|
||||
super(RegistrationStore, self).__init__(database, db_conn, hs)
|
||||
|
||||
self._account_validity = hs.config.account_validity
|
||||
self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors
|
||||
|
||||
if self._account_validity.enabled:
|
||||
self._clock.call_later(
|
||||
@@ -1258,12 +1259,12 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
|
||||
desc="del_user_pending_deactivation",
|
||||
)
|
||||
|
||||
def get_user_pending_deactivation(self):
|
||||
async def get_user_pending_deactivation(self) -> Optional[str]:
|
||||
"""
|
||||
Gets one user from the table of users waiting to be parted from all the rooms
|
||||
they're in.
|
||||
"""
|
||||
return self.db_pool.simple_select_one_onecol(
|
||||
return await self.db_pool.simple_select_one_onecol(
|
||||
"users_pending_deactivation",
|
||||
keyvalues={},
|
||||
retcol="user_id",
|
||||
@@ -1302,15 +1303,22 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
|
||||
)
|
||||
|
||||
if not row:
|
||||
raise ThreepidValidationError(400, "Unknown session_id")
|
||||
if self._ignore_unknown_session_error:
|
||||
# If we need to inhibit the error caused by an incorrect session ID,
|
||||
# use None as placeholder values for the client secret and the
|
||||
# validation timestamp.
|
||||
# It shouldn't be an issue because they're both only checked after
|
||||
# the token check, which should fail. And if it doesn't for some
|
||||
# reason, the next check is on the client secret, which is NOT NULL,
|
||||
# so we don't have to worry about the client secret matching by
|
||||
# accident.
|
||||
row = {"client_secret": None, "validated_at": None}
|
||||
else:
|
||||
raise ThreepidValidationError(400, "Unknown session_id")
|
||||
|
||||
retrieved_client_secret = row["client_secret"]
|
||||
validated_at = row["validated_at"]
|
||||
|
||||
if retrieved_client_secret != client_secret:
|
||||
raise ThreepidValidationError(
|
||||
400, "This client_secret does not match the provided session_id"
|
||||
)
|
||||
|
||||
row = self.db_pool.simple_select_one_txn(
|
||||
txn,
|
||||
table="threepid_validation_token",
|
||||
@@ -1326,6 +1334,11 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
|
||||
expires = row["expires"]
|
||||
next_link = row["next_link"]
|
||||
|
||||
if retrieved_client_secret != client_secret:
|
||||
raise ThreepidValidationError(
|
||||
400, "This client_secret does not match the provided session_id"
|
||||
)
|
||||
|
||||
# If the session is already validated, no need to revalidate
|
||||
if validated_at:
|
||||
return next_link
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
|
||||
@@ -21,8 +22,8 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RejectionsStore(SQLBaseStore):
|
||||
def get_rejection_reason(self, event_id):
|
||||
return self.db_pool.simple_select_one_onecol(
|
||||
async def get_rejection_reason(self, event_id: str) -> Optional[str]:
|
||||
return await self.db_pool.simple_select_one_onecol(
|
||||
table="rejections",
|
||||
retcol="reason",
|
||||
keyvalues={"event_id": event_id},
|
||||
|
||||
@@ -73,15 +73,15 @@ class RoomWorkerStore(SQLBaseStore):
|
||||
|
||||
self.config = hs.config
|
||||
|
||||
def get_room(self, room_id):
|
||||
async def get_room(self, room_id: str) -> dict:
|
||||
"""Retrieve a room.
|
||||
|
||||
Args:
|
||||
room_id (str): The ID of the room to retrieve.
|
||||
room_id: The ID of the room to retrieve.
|
||||
Returns:
|
||||
A dict containing the room information, or None if the room is unknown.
|
||||
"""
|
||||
return self.db_pool.simple_select_one(
|
||||
return await self.db_pool.simple_select_one(
|
||||
table="rooms",
|
||||
keyvalues={"room_id": room_id},
|
||||
retcols=("room_id", "is_public", "creator"),
|
||||
@@ -330,8 +330,8 @@ class RoomWorkerStore(SQLBaseStore):
|
||||
return ret_val
|
||||
|
||||
@cached(max_entries=10000)
|
||||
def is_room_blocked(self, room_id):
|
||||
return self.db_pool.simple_select_one_onecol(
|
||||
async def is_room_blocked(self, room_id: str) -> Optional[bool]:
|
||||
return await self.db_pool.simple_select_one_onecol(
|
||||
table="blocked_rooms",
|
||||
keyvalues={"room_id": room_id},
|
||||
retcol="1",
|
||||
@@ -1129,7 +1129,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
||||
},
|
||||
)
|
||||
|
||||
with self._public_room_id_gen.get_next() as next_id:
|
||||
with await self._public_room_id_gen.get_next() as next_id:
|
||||
await self.db_pool.runInteraction(
|
||||
"store_room_txn", store_room_txn, next_id
|
||||
)
|
||||
@@ -1196,7 +1196,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
||||
},
|
||||
)
|
||||
|
||||
with self._public_room_id_gen.get_next() as next_id:
|
||||
with await self._public_room_id_gen.get_next() as next_id:
|
||||
await self.db_pool.runInteraction(
|
||||
"set_room_is_public", set_room_is_public_txn, next_id
|
||||
)
|
||||
@@ -1276,7 +1276,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
||||
},
|
||||
)
|
||||
|
||||
with self._public_room_id_gen.get_next() as next_id:
|
||||
with await self._public_room_id_gen.get_next() as next_id:
|
||||
await self.db_pool.runInteraction(
|
||||
"set_room_is_public_appservice",
|
||||
set_room_is_public_appservice_txn,
|
||||
|
||||
@@ -260,8 +260,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
return event.content.get("canonical_alias")
|
||||
|
||||
@cached(max_entries=50000)
|
||||
def _get_state_group_for_event(self, event_id):
|
||||
return self.db_pool.simple_select_one_onecol(
|
||||
async def _get_state_group_for_event(self, event_id: str) -> Optional[int]:
|
||||
return await self.db_pool.simple_select_one_onecol(
|
||||
table="event_to_state_groups",
|
||||
keyvalues={"event_id": event_id},
|
||||
retcol="state_group",
|
||||
|
||||
@@ -211,11 +211,11 @@ class StatsStore(StateDeltasStore):
|
||||
|
||||
return len(rooms_to_work_on)
|
||||
|
||||
def get_stats_positions(self):
|
||||
async def get_stats_positions(self) -> int:
|
||||
"""
|
||||
Returns the stats processor positions.
|
||||
"""
|
||||
return self.db_pool.simple_select_one_onecol(
|
||||
return await self.db_pool.simple_select_one_onecol(
|
||||
table="stats_incremental_position",
|
||||
keyvalues={},
|
||||
retcol="stream_id",
|
||||
@@ -300,7 +300,7 @@ class StatsStore(StateDeltasStore):
|
||||
return slice_list
|
||||
|
||||
@cached()
|
||||
def get_earliest_token_for_stats(self, stats_type, id):
|
||||
async def get_earliest_token_for_stats(self, stats_type: str, id: str) -> int:
|
||||
"""
|
||||
Fetch the "earliest token". This is used by the room stats delta
|
||||
processor to ignore deltas that have been processed between the
|
||||
@@ -308,11 +308,11 @@ class StatsStore(StateDeltasStore):
|
||||
being calculated.
|
||||
|
||||
Returns:
|
||||
Deferred[int]
|
||||
The earliest token.
|
||||
"""
|
||||
table, id_col = TYPE_TO_TABLE[stats_type]
|
||||
|
||||
return self.db_pool.simple_select_one_onecol(
|
||||
return await self.db_pool.simple_select_one_onecol(
|
||||
"%s_current" % (table,),
|
||||
keyvalues={id_col: id},
|
||||
retcol="completed_delta_stream_id",
|
||||
|
||||
@@ -210,7 +210,7 @@ class TagsStore(TagsWorkerStore):
|
||||
)
|
||||
self._update_revision_txn(txn, user_id, room_id, next_id)
|
||||
|
||||
with self._account_data_id_gen.get_next() as next_id:
|
||||
with await self._account_data_id_gen.get_next() as next_id:
|
||||
await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id)
|
||||
|
||||
self.get_tags_for_user.invalidate((user_id,))
|
||||
@@ -232,7 +232,7 @@ class TagsStore(TagsWorkerStore):
|
||||
txn.execute(sql, (user_id, room_id, tag))
|
||||
self._update_revision_txn(txn, user_id, room_id, next_id)
|
||||
|
||||
with self._account_data_id_gen.get_next() as next_id:
|
||||
with await self._account_data_id_gen.get_next() as next_id:
|
||||
await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id)
|
||||
|
||||
self.get_tags_for_user.invalidate((user_id,))
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from synapse.api.constants import EventTypes, JoinRules
|
||||
from synapse.storage.database import DatabasePool
|
||||
@@ -527,8 +528,8 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
||||
)
|
||||
|
||||
@cached()
|
||||
def get_user_in_directory(self, user_id):
|
||||
return self.db_pool.simple_select_one(
|
||||
async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, Any]]:
|
||||
return await self.db_pool.simple_select_one(
|
||||
table="user_directory",
|
||||
keyvalues={"user_id": user_id},
|
||||
retcols=("display_name", "avatar_url"),
|
||||
@@ -663,8 +664,8 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
|
||||
users.update(rows)
|
||||
return list(users)
|
||||
|
||||
def get_user_directory_stream_pos(self):
|
||||
return self.db_pool.simple_select_one_onecol(
|
||||
async def get_user_directory_stream_pos(self) -> int:
|
||||
return await self.db_pool.simple_select_one_onecol(
|
||||
table="user_directory_stream_pos",
|
||||
keyvalues={},
|
||||
retcol="stream_id",
|
||||
|
||||
@@ -14,9 +14,10 @@
|
||||
# limitations under the License.
|
||||
|
||||
import contextlib
|
||||
import heapq
|
||||
import threading
|
||||
from collections import deque
|
||||
from typing import Dict, Set
|
||||
from typing import Dict, List, Set
|
||||
|
||||
from typing_extensions import Deque
|
||||
|
||||
@@ -80,7 +81,7 @@ class StreamIdGenerator(object):
|
||||
upwards, -1 to grow downwards.
|
||||
|
||||
Usage:
|
||||
with stream_id_gen.get_next() as stream_id:
|
||||
with await stream_id_gen.get_next() as stream_id:
|
||||
# ... persist event ...
|
||||
"""
|
||||
|
||||
@@ -95,10 +96,10 @@ class StreamIdGenerator(object):
|
||||
)
|
||||
self._unfinished_ids = deque() # type: Deque[int]
|
||||
|
||||
def get_next(self):
|
||||
async def get_next(self):
|
||||
"""
|
||||
Usage:
|
||||
with stream_id_gen.get_next() as stream_id:
|
||||
with await stream_id_gen.get_next() as stream_id:
|
||||
# ... persist event ...
|
||||
"""
|
||||
with self._lock:
|
||||
@@ -117,10 +118,10 @@ class StreamIdGenerator(object):
|
||||
|
||||
return manager()
|
||||
|
||||
def get_next_mult(self, n):
|
||||
async def get_next_mult(self, n):
|
||||
"""
|
||||
Usage:
|
||||
with stream_id_gen.get_next(n) as stream_ids:
|
||||
with await stream_id_gen.get_next(n) as stream_ids:
|
||||
# ... persist events ...
|
||||
"""
|
||||
with self._lock:
|
||||
@@ -210,6 +211,23 @@ class MultiWriterIdGenerator:
|
||||
# should be less than the minimum of this set (if not empty).
|
||||
self._unfinished_ids = set() # type: Set[int]
|
||||
|
||||
# We track the max position where we know everything before has been
|
||||
# persisted. This is done by a) looking at the min across all instances
|
||||
# and b) noting that if we have seen a run of persisted positions
|
||||
# without gaps (e.g. 5, 6, 7) then we can skip forward (e.g. to 7).
|
||||
#
|
||||
# Note: There is no guarentee that the IDs generated by the sequence
|
||||
# will be gapless; gaps can form when e.g. a transaction was rolled
|
||||
# back. This means that sometimes we won't be able to skip forward the
|
||||
# position even though everything has been persisted. However, since
|
||||
# gaps should be relatively rare it's still worth doing the book keeping
|
||||
# that allows us to skip forwards when there are gapless runs of
|
||||
# positions.
|
||||
self._persisted_upto_position = (
|
||||
min(self._current_positions.values()) if self._current_positions else 0
|
||||
)
|
||||
self._known_persisted_positions = [] # type: List[int]
|
||||
|
||||
self._sequence_gen = PostgresSequenceGenerator(sequence_name)
|
||||
|
||||
def _load_current_ids(
|
||||
@@ -234,9 +252,12 @@ class MultiWriterIdGenerator:
|
||||
|
||||
return current_positions
|
||||
|
||||
def _load_next_id_txn(self, txn):
|
||||
def _load_next_id_txn(self, txn) -> int:
|
||||
return self._sequence_gen.get_next_id_txn(txn)
|
||||
|
||||
def _load_next_mult_id_txn(self, txn, n: int) -> List[int]:
|
||||
return self._sequence_gen.get_next_mult_txn(txn, n)
|
||||
|
||||
async def get_next(self):
|
||||
"""
|
||||
Usage:
|
||||
@@ -262,6 +283,34 @@ class MultiWriterIdGenerator:
|
||||
|
||||
return manager()
|
||||
|
||||
async def get_next_mult(self, n: int):
|
||||
"""
|
||||
Usage:
|
||||
with await stream_id_gen.get_next_mult(5) as stream_ids:
|
||||
# ... persist events ...
|
||||
"""
|
||||
next_ids = await self._db.runInteraction(
|
||||
"_load_next_mult_id", self._load_next_mult_id_txn, n
|
||||
)
|
||||
|
||||
# Assert the fetched ID is actually greater than any ID we've already
|
||||
# seen. If not, then the sequence and table have got out of sync
|
||||
# somehow.
|
||||
assert max(self.get_positions().values(), default=0) < min(next_ids)
|
||||
|
||||
with self._lock:
|
||||
self._unfinished_ids.update(next_ids)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def manager():
|
||||
try:
|
||||
yield next_ids
|
||||
finally:
|
||||
for i in next_ids:
|
||||
self._mark_id_as_finished(i)
|
||||
|
||||
return manager()
|
||||
|
||||
def get_next_txn(self, txn: LoggingTransaction):
|
||||
"""
|
||||
Usage:
|
||||
@@ -326,3 +375,53 @@ class MultiWriterIdGenerator:
|
||||
self._current_positions[instance_name] = max(
|
||||
new_id, self._current_positions.get(instance_name, 0)
|
||||
)
|
||||
|
||||
self._add_persisted_position(new_id)
|
||||
|
||||
def get_persisted_upto_position(self) -> int:
|
||||
"""Get the max position where all previous positions have been
|
||||
persisted.
|
||||
|
||||
Note: In the worst case scenario this will be equal to the minimum
|
||||
position across writers. This means that the returned position here can
|
||||
lag if one writer doesn't write very often.
|
||||
"""
|
||||
|
||||
with self._lock:
|
||||
return self._persisted_upto_position
|
||||
|
||||
def _add_persisted_position(self, new_id: int):
|
||||
"""Record that we have persisted a position.
|
||||
|
||||
This is used to keep the `_current_positions` up to date.
|
||||
"""
|
||||
|
||||
# We require that the lock is locked by caller
|
||||
assert self._lock.locked()
|
||||
|
||||
heapq.heappush(self._known_persisted_positions, new_id)
|
||||
|
||||
# We move the current min position up if the minimum current positions
|
||||
# of all instances is higher (since by definition all positions less
|
||||
# that that have been persisted).
|
||||
min_curr = min(self._current_positions.values())
|
||||
self._persisted_upto_position = max(min_curr, self._persisted_upto_position)
|
||||
|
||||
# We now iterate through the seen positions, discarding those that are
|
||||
# less than the current min positions, and incrementing the min position
|
||||
# if its exactly one greater.
|
||||
#
|
||||
# This is also where we discard items from `_known_persisted_positions`
|
||||
# (to ensure the list doesn't infinitely grow).
|
||||
while self._known_persisted_positions:
|
||||
if self._known_persisted_positions[0] <= self._persisted_upto_position:
|
||||
heapq.heappop(self._known_persisted_positions)
|
||||
elif (
|
||||
self._known_persisted_positions[0] == self._persisted_upto_position + 1
|
||||
):
|
||||
heapq.heappop(self._known_persisted_positions)
|
||||
self._persisted_upto_position += 1
|
||||
else:
|
||||
# There was a gap in seen positions, so there is nothing more to
|
||||
# do.
|
||||
break
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
import abc
|
||||
import threading
|
||||
from typing import Callable, Optional
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
|
||||
from synapse.storage.types import Cursor
|
||||
@@ -39,6 +39,12 @@ class PostgresSequenceGenerator(SequenceGenerator):
|
||||
txn.execute("SELECT nextval(?)", (self._sequence_name,))
|
||||
return txn.fetchone()[0]
|
||||
|
||||
def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
|
||||
txn.execute(
|
||||
"SELECT nextval(?) FROM generate_series(1, ?)", (self._sequence_name, n)
|
||||
)
|
||||
return [i for (i,) in txn]
|
||||
|
||||
|
||||
GetFirstCallbackType = Callable[[Cursor], int]
|
||||
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from synapse.api.ratelimiting import LimitExceededError, Ratelimiter
|
||||
from synapse.appservice import ApplicationService
|
||||
from synapse.types import create_requester
|
||||
|
||||
from tests import unittest
|
||||
|
||||
@@ -18,6 +20,77 @@ class TestRatelimiter(unittest.TestCase):
|
||||
self.assertTrue(allowed)
|
||||
self.assertEquals(20.0, time_allowed)
|
||||
|
||||
def test_allowed_user_via_can_requester_do_action(self):
|
||||
user_requester = create_requester("@user:example.com")
|
||||
limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
|
||||
allowed, time_allowed = limiter.can_requester_do_action(
|
||||
user_requester, _time_now_s=0
|
||||
)
|
||||
self.assertTrue(allowed)
|
||||
self.assertEquals(10.0, time_allowed)
|
||||
|
||||
allowed, time_allowed = limiter.can_requester_do_action(
|
||||
user_requester, _time_now_s=5
|
||||
)
|
||||
self.assertFalse(allowed)
|
||||
self.assertEquals(10.0, time_allowed)
|
||||
|
||||
allowed, time_allowed = limiter.can_requester_do_action(
|
||||
user_requester, _time_now_s=10
|
||||
)
|
||||
self.assertTrue(allowed)
|
||||
self.assertEquals(20.0, time_allowed)
|
||||
|
||||
def test_allowed_appservice_ratelimited_via_can_requester_do_action(self):
|
||||
appservice = ApplicationService(
|
||||
None, "example.com", id="foo", rate_limited=True,
|
||||
)
|
||||
as_requester = create_requester("@user:example.com", app_service=appservice)
|
||||
|
||||
limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
|
||||
allowed, time_allowed = limiter.can_requester_do_action(
|
||||
as_requester, _time_now_s=0
|
||||
)
|
||||
self.assertTrue(allowed)
|
||||
self.assertEquals(10.0, time_allowed)
|
||||
|
||||
allowed, time_allowed = limiter.can_requester_do_action(
|
||||
as_requester, _time_now_s=5
|
||||
)
|
||||
self.assertFalse(allowed)
|
||||
self.assertEquals(10.0, time_allowed)
|
||||
|
||||
allowed, time_allowed = limiter.can_requester_do_action(
|
||||
as_requester, _time_now_s=10
|
||||
)
|
||||
self.assertTrue(allowed)
|
||||
self.assertEquals(20.0, time_allowed)
|
||||
|
||||
def test_allowed_appservice_via_can_requester_do_action(self):
|
||||
appservice = ApplicationService(
|
||||
None, "example.com", id="foo", rate_limited=False,
|
||||
)
|
||||
as_requester = create_requester("@user:example.com", app_service=appservice)
|
||||
|
||||
limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
|
||||
allowed, time_allowed = limiter.can_requester_do_action(
|
||||
as_requester, _time_now_s=0
|
||||
)
|
||||
self.assertTrue(allowed)
|
||||
self.assertEquals(-1, time_allowed)
|
||||
|
||||
allowed, time_allowed = limiter.can_requester_do_action(
|
||||
as_requester, _time_now_s=5
|
||||
)
|
||||
self.assertTrue(allowed)
|
||||
self.assertEquals(-1, time_allowed)
|
||||
|
||||
allowed, time_allowed = limiter.can_requester_do_action(
|
||||
as_requester, _time_now_s=10
|
||||
)
|
||||
self.assertTrue(allowed)
|
||||
self.assertEquals(-1, time_allowed)
|
||||
|
||||
def test_allowed_via_ratelimit(self):
|
||||
limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
|
||||
|
||||
|
||||
@@ -71,7 +71,9 @@ class ProfileTestCase(unittest.TestCase):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_get_my_name(self):
|
||||
yield self.store.set_profile_displayname(self.frank.localpart, "Frank")
|
||||
yield defer.ensureDeferred(
|
||||
self.store.set_profile_displayname(self.frank.localpart, "Frank")
|
||||
)
|
||||
|
||||
displayname = yield defer.ensureDeferred(
|
||||
self.handler.get_displayname(self.frank)
|
||||
@@ -104,7 +106,12 @@ class ProfileTestCase(unittest.TestCase):
|
||||
)
|
||||
|
||||
self.assertEquals(
|
||||
(yield self.store.get_profile_displayname(self.frank.localpart)), "Frank",
|
||||
(
|
||||
yield defer.ensureDeferred(
|
||||
self.store.get_profile_displayname(self.frank.localpart)
|
||||
)
|
||||
),
|
||||
"Frank",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@@ -112,10 +119,17 @@ class ProfileTestCase(unittest.TestCase):
|
||||
self.hs.config.enable_set_displayname = False
|
||||
|
||||
# Setting displayname for the first time is allowed
|
||||
yield self.store.set_profile_displayname(self.frank.localpart, "Frank")
|
||||
yield defer.ensureDeferred(
|
||||
self.store.set_profile_displayname(self.frank.localpart, "Frank")
|
||||
)
|
||||
|
||||
self.assertEquals(
|
||||
(yield self.store.get_profile_displayname(self.frank.localpart)), "Frank",
|
||||
(
|
||||
yield defer.ensureDeferred(
|
||||
self.store.get_profile_displayname(self.frank.localpart)
|
||||
)
|
||||
),
|
||||
"Frank",
|
||||
)
|
||||
|
||||
# Setting displayname a second time is forbidden
|
||||
@@ -158,7 +172,9 @@ class ProfileTestCase(unittest.TestCase):
|
||||
@defer.inlineCallbacks
|
||||
def test_incoming_fed_query(self):
|
||||
yield defer.ensureDeferred(self.store.create_profile("caroline"))
|
||||
yield self.store.set_profile_displayname("caroline", "Caroline")
|
||||
yield defer.ensureDeferred(
|
||||
self.store.set_profile_displayname("caroline", "Caroline")
|
||||
)
|
||||
|
||||
response = yield defer.ensureDeferred(
|
||||
self.query_handlers["profile"](
|
||||
@@ -170,8 +186,10 @@ class ProfileTestCase(unittest.TestCase):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_get_my_avatar(self):
|
||||
yield self.store.set_profile_avatar_url(
|
||||
self.frank.localpart, "http://my.server/me.png"
|
||||
yield defer.ensureDeferred(
|
||||
self.store.set_profile_avatar_url(
|
||||
self.frank.localpart, "http://my.server/me.png"
|
||||
)
|
||||
)
|
||||
avatar_url = yield defer.ensureDeferred(self.handler.get_avatar_url(self.frank))
|
||||
|
||||
@@ -188,7 +206,11 @@ class ProfileTestCase(unittest.TestCase):
|
||||
)
|
||||
|
||||
self.assertEquals(
|
||||
(yield self.store.get_profile_avatar_url(self.frank.localpart)),
|
||||
(
|
||||
yield defer.ensureDeferred(
|
||||
self.store.get_profile_avatar_url(self.frank.localpart)
|
||||
)
|
||||
),
|
||||
"http://my.server/pic.gif",
|
||||
)
|
||||
|
||||
@@ -202,7 +224,11 @@ class ProfileTestCase(unittest.TestCase):
|
||||
)
|
||||
|
||||
self.assertEquals(
|
||||
(yield self.store.get_profile_avatar_url(self.frank.localpart)),
|
||||
(
|
||||
yield defer.ensureDeferred(
|
||||
self.store.get_profile_avatar_url(self.frank.localpart)
|
||||
)
|
||||
),
|
||||
"http://my.server/me.png",
|
||||
)
|
||||
|
||||
@@ -211,12 +237,18 @@ class ProfileTestCase(unittest.TestCase):
|
||||
self.hs.config.enable_set_avatar_url = False
|
||||
|
||||
# Setting displayname for the first time is allowed
|
||||
yield self.store.set_profile_avatar_url(
|
||||
self.frank.localpart, "http://my.server/me.png"
|
||||
yield defer.ensureDeferred(
|
||||
self.store.set_profile_avatar_url(
|
||||
self.frank.localpart, "http://my.server/me.png"
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEquals(
|
||||
(yield self.store.get_profile_avatar_url(self.frank.localpart)),
|
||||
(
|
||||
yield defer.ensureDeferred(
|
||||
self.store.get_profile_avatar_url(self.frank.localpart)
|
||||
)
|
||||
),
|
||||
"http://my.server/me.png",
|
||||
)
|
||||
|
||||
|
||||
@@ -144,9 +144,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
self.datastore.get_users_in_room = get_users_in_room
|
||||
|
||||
self.datastore.get_user_directory_stream_pos.return_value = (
|
||||
self.datastore.get_user_directory_stream_pos.side_effect = (
|
||||
# we deliberately return a non-None stream pos to avoid doing an initial_spam
|
||||
defer.succeed(1)
|
||||
lambda: make_awaitable(1)
|
||||
)
|
||||
|
||||
self.datastore.get_current_state_deltas.return_value = (0, None)
|
||||
|
||||
@@ -35,7 +35,7 @@ class ModuleApiTestCase(HomeserverTestCase):
|
||||
# Check that the new user exists with all provided attributes
|
||||
self.assertEqual(user_id, "@bob:test")
|
||||
self.assertTrue(access_token)
|
||||
self.assertTrue(self.store.get_user_by_id(user_id))
|
||||
self.assertTrue(self.get_success(self.store.get_user_by_id(user_id)))
|
||||
|
||||
# Check that the email was assigned
|
||||
emails = self.get_success(self.store.user_get_threepids(user_id))
|
||||
|
||||
@@ -45,33 +45,16 @@ class RetentionTestCase(unittest.HomeserverTestCase):
|
||||
}
|
||||
|
||||
self.hs = self.setup_test_homeserver(config=config)
|
||||
|
||||
return self.hs
|
||||
|
||||
def prepare(self, reactor, clock, homeserver):
|
||||
self.user_id = self.register_user("user", "password")
|
||||
self.token = self.login("user", "password")
|
||||
|
||||
def test_retention_state_event(self):
|
||||
"""Tests that the server configuration can limit the values a user can set to the
|
||||
room's retention policy.
|
||||
"""
|
||||
room_id = self.helper.create_room_as(self.user_id, tok=self.token)
|
||||
|
||||
self.helper.send_state(
|
||||
room_id=room_id,
|
||||
event_type=EventTypes.Retention,
|
||||
body={"max_lifetime": one_day_ms * 4},
|
||||
tok=self.token,
|
||||
expect_code=400,
|
||||
)
|
||||
|
||||
self.helper.send_state(
|
||||
room_id=room_id,
|
||||
event_type=EventTypes.Retention,
|
||||
body={"max_lifetime": one_hour_ms},
|
||||
tok=self.token,
|
||||
expect_code=400,
|
||||
)
|
||||
self.store = self.hs.get_datastore()
|
||||
self.serializer = self.hs.get_event_client_serializer()
|
||||
self.clock = self.hs.get_clock()
|
||||
|
||||
def test_retention_event_purged_with_state_event(self):
|
||||
"""Tests that expired events are correctly purged when the room's retention policy
|
||||
@@ -90,6 +73,36 @@ class RetentionTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
self._test_retention_event_purged(room_id, one_day_ms * 1.5)
|
||||
|
||||
def test_retention_event_purged_with_state_event_outside_allowed(self):
|
||||
"""Tests that the server configuration can override the policy for a room when
|
||||
running the purge jobs.
|
||||
"""
|
||||
room_id = self.helper.create_room_as(self.user_id, tok=self.token)
|
||||
|
||||
# Set a max_lifetime higher than the maximum allowed value.
|
||||
self.helper.send_state(
|
||||
room_id=room_id,
|
||||
event_type=EventTypes.Retention,
|
||||
body={"max_lifetime": one_day_ms * 4},
|
||||
tok=self.token,
|
||||
)
|
||||
|
||||
# Check that the event is purged after waiting for the maximum allowed duration
|
||||
# instead of the one specified in the room's policy.
|
||||
self._test_retention_event_purged(room_id, one_day_ms * 1.5)
|
||||
|
||||
# Set a max_lifetime lower than the minimum allowed value.
|
||||
self.helper.send_state(
|
||||
room_id=room_id,
|
||||
event_type=EventTypes.Retention,
|
||||
body={"max_lifetime": one_hour_ms},
|
||||
tok=self.token,
|
||||
)
|
||||
|
||||
# Check that the event is purged after waiting for the minimum allowed duration
|
||||
# instead of the one specified in the room's policy.
|
||||
self._test_retention_event_purged(room_id, one_day_ms * 0.5)
|
||||
|
||||
def test_retention_event_purged_without_state_event(self):
|
||||
"""Tests that expired events are correctly purged when the room's retention policy
|
||||
is defined by the server's configuration's default retention policy.
|
||||
@@ -140,7 +153,27 @@ class RetentionTestCase(unittest.HomeserverTestCase):
|
||||
# That event should be the second, not outdated event.
|
||||
self.assertEqual(filtered_events[0].event_id, valid_event_id, filtered_events)
|
||||
|
||||
def _test_retention_event_purged(self, room_id, increment):
|
||||
def _test_retention_event_purged(self, room_id: str, increment: float):
|
||||
"""Run the following test scenario to test the message retention policy support:
|
||||
|
||||
1. Send event 1
|
||||
2. Increment time by `increment`
|
||||
3. Send event 2
|
||||
4. Increment time by `increment`
|
||||
5. Check that event 1 has been purged
|
||||
6. Check that event 2 has not been purged
|
||||
7. Check that state events that were sent before event 1 aren't purged.
|
||||
The main reason for sending a second event is because currently Synapse won't
|
||||
purge the latest message in a room because it would otherwise result in a lack of
|
||||
forward extremities for this room. It's also a good thing to ensure the purge jobs
|
||||
aren't too greedy and purge messages they shouldn't.
|
||||
|
||||
Args:
|
||||
room_id: The ID of the room to test retention in.
|
||||
increment: The number of milliseconds to advance the clock each time. Must be
|
||||
defined so that events in the room aren't purged if they are `increment`
|
||||
old but are purged if they are `increment * 2` old.
|
||||
"""
|
||||
# Get the create event to, later, check that we can still access it.
|
||||
message_handler = self.hs.get_message_handler()
|
||||
create_event = self.get_success(
|
||||
@@ -156,7 +189,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
|
||||
expired_event_id = resp.get("event_id")
|
||||
|
||||
# Check that we can retrieve the event.
|
||||
expired_event = self.get_event(room_id, expired_event_id)
|
||||
expired_event = self.get_event(expired_event_id)
|
||||
self.assertEqual(
|
||||
expired_event.get("content", {}).get("body"), "1", expired_event
|
||||
)
|
||||
@@ -174,26 +207,31 @@ class RetentionTestCase(unittest.HomeserverTestCase):
|
||||
# one should still be kept.
|
||||
self.reactor.advance(increment / 1000)
|
||||
|
||||
# Check that the event has been purged from the database.
|
||||
self.get_event(room_id, expired_event_id, expected_code=404)
|
||||
# Check that the first event has been purged from the database, i.e. that we
|
||||
# can't retrieve it anymore, because it has expired.
|
||||
self.get_event(expired_event_id, expect_none=True)
|
||||
|
||||
# Check that the event that hasn't been purged can still be retrieved.
|
||||
valid_event = self.get_event(room_id, valid_event_id)
|
||||
# Check that the event that hasn't expired can still be retrieved.
|
||||
valid_event = self.get_event(valid_event_id)
|
||||
self.assertEqual(valid_event.get("content", {}).get("body"), "2", valid_event)
|
||||
|
||||
# Check that we can still access state events that were sent before the event that
|
||||
# has been purged.
|
||||
self.get_event(room_id, create_event.event_id)
|
||||
|
||||
def get_event(self, room_id, event_id, expected_code=200):
|
||||
url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id)
|
||||
def get_event(self, event_id, expect_none=False):
|
||||
event = self.get_success(self.store.get_event(event_id, allow_none=True))
|
||||
|
||||
request, channel = self.make_request("GET", url, access_token=self.token)
|
||||
self.render(request)
|
||||
if expect_none:
|
||||
self.assertIsNone(event)
|
||||
return {}
|
||||
|
||||
self.assertEqual(channel.code, expected_code, channel.result)
|
||||
self.assertIsNotNone(event)
|
||||
|
||||
return channel.json_body
|
||||
time_now = self.clock.time_msec()
|
||||
serialized = self.get_success(self.serializer.serialize_event(event, time_now))
|
||||
|
||||
return serialized
|
||||
|
||||
|
||||
class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
@@ -0,0 +1,272 @@
|
||||
# Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from mock import Mock, patch
|
||||
|
||||
import synapse.rest.admin
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.rest.client.v1 import directory, login, profile, room
|
||||
from synapse.rest.client.v2_alpha import room_upgrade_rest_servlet
|
||||
|
||||
from tests import unittest
|
||||
|
||||
|
||||
class _ShadowBannedBase(unittest.HomeserverTestCase):
|
||||
def prepare(self, reactor, clock, homeserver):
|
||||
# Create two users, one of which is shadow-banned.
|
||||
self.banned_user_id = self.register_user("banned", "test")
|
||||
self.banned_access_token = self.login("banned", "test")
|
||||
|
||||
self.store = self.hs.get_datastore()
|
||||
|
||||
self.get_success(
|
||||
self.store.db_pool.simple_update(
|
||||
table="users",
|
||||
keyvalues={"name": self.banned_user_id},
|
||||
updatevalues={"shadow_banned": True},
|
||||
desc="shadow_ban",
|
||||
)
|
||||
)
|
||||
|
||||
self.other_user_id = self.register_user("otheruser", "pass")
|
||||
self.other_access_token = self.login("otheruser", "pass")
|
||||
|
||||
|
||||
# To avoid the tests timing out don't add a delay to "annoy the requester".
|
||||
@patch("random.randint", new=lambda a, b: 0)
|
||||
class RoomTestCase(_ShadowBannedBase):
|
||||
servlets = [
|
||||
synapse.rest.admin.register_servlets_for_client_rest_resource,
|
||||
directory.register_servlets,
|
||||
login.register_servlets,
|
||||
room.register_servlets,
|
||||
room_upgrade_rest_servlet.register_servlets,
|
||||
]
|
||||
|
||||
def test_invite(self):
|
||||
"""Invites from shadow-banned users don't actually get sent."""
|
||||
|
||||
# The create works fine.
|
||||
room_id = self.helper.create_room_as(
|
||||
self.banned_user_id, tok=self.banned_access_token
|
||||
)
|
||||
|
||||
# Inviting the user completes successfully.
|
||||
self.helper.invite(
|
||||
room=room_id,
|
||||
src=self.banned_user_id,
|
||||
tok=self.banned_access_token,
|
||||
targ=self.other_user_id,
|
||||
)
|
||||
|
||||
# But the user wasn't actually invited.
|
||||
invited_rooms = self.get_success(
|
||||
self.store.get_invited_rooms_for_local_user(self.other_user_id)
|
||||
)
|
||||
self.assertEqual(invited_rooms, [])
|
||||
|
||||
def test_invite_3pid(self):
|
||||
"""Ensure that a 3PID invite does not attempt to contact the identity server."""
|
||||
identity_handler = self.hs.get_handlers().identity_handler
|
||||
identity_handler.lookup_3pid = Mock(
|
||||
side_effect=AssertionError("This should not get called")
|
||||
)
|
||||
|
||||
# The create works fine.
|
||||
room_id = self.helper.create_room_as(
|
||||
self.banned_user_id, tok=self.banned_access_token
|
||||
)
|
||||
|
||||
# Inviting the user completes successfully.
|
||||
request, channel = self.make_request(
|
||||
"POST",
|
||||
"/rooms/%s/invite" % (room_id,),
|
||||
{"id_server": "test", "medium": "email", "address": "test@test.test"},
|
||||
access_token=self.banned_access_token,
|
||||
)
|
||||
self.render(request)
|
||||
self.assertEquals(200, channel.code, channel.result)
|
||||
|
||||
# This should have raised an error earlier, but double check this wasn't called.
|
||||
identity_handler.lookup_3pid.assert_not_called()
|
||||
|
||||
def test_create_room(self):
|
||||
"""Invitations during a room creation should be discarded, but the room still gets created."""
|
||||
# The room creation is successful.
|
||||
request, channel = self.make_request(
|
||||
"POST",
|
||||
"/_matrix/client/r0/createRoom",
|
||||
{"visibility": "public", "invite": [self.other_user_id]},
|
||||
access_token=self.banned_access_token,
|
||||
)
|
||||
self.render(request)
|
||||
self.assertEquals(200, channel.code, channel.result)
|
||||
room_id = channel.json_body["room_id"]
|
||||
|
||||
# But the user wasn't actually invited.
|
||||
invited_rooms = self.get_success(
|
||||
self.store.get_invited_rooms_for_local_user(self.other_user_id)
|
||||
)
|
||||
self.assertEqual(invited_rooms, [])
|
||||
|
||||
# Since a real room was created, the other user should be able to join it.
|
||||
self.helper.join(room_id, self.other_user_id, tok=self.other_access_token)
|
||||
|
||||
# Both users should be in the room.
|
||||
users = self.get_success(self.store.get_users_in_room(room_id))
|
||||
self.assertCountEqual(users, ["@banned:test", "@otheruser:test"])
|
||||
|
||||
def test_message(self):
|
||||
"""Messages from shadow-banned users don't actually get sent."""
|
||||
|
||||
room_id = self.helper.create_room_as(
|
||||
self.other_user_id, tok=self.other_access_token
|
||||
)
|
||||
|
||||
# The user should be in the room.
|
||||
self.helper.join(room_id, self.banned_user_id, tok=self.banned_access_token)
|
||||
|
||||
# Sending a message should complete successfully.
|
||||
result = self.helper.send_event(
|
||||
room_id=room_id,
|
||||
type=EventTypes.Message,
|
||||
content={"msgtype": "m.text", "body": "with right label"},
|
||||
tok=self.banned_access_token,
|
||||
)
|
||||
self.assertIn("event_id", result)
|
||||
event_id = result["event_id"]
|
||||
|
||||
latest_events = self.get_success(
|
||||
self.store.get_latest_event_ids_in_room(room_id)
|
||||
)
|
||||
self.assertNotIn(event_id, latest_events)
|
||||
|
||||
def test_upgrade(self):
|
||||
"""A room upgrade should fail, but look like it succeeded."""
|
||||
|
||||
# The create works fine.
|
||||
room_id = self.helper.create_room_as(
|
||||
self.banned_user_id, tok=self.banned_access_token
|
||||
)
|
||||
|
||||
request, channel = self.make_request(
|
||||
"POST",
|
||||
"/_matrix/client/r0/rooms/%s/upgrade" % (room_id,),
|
||||
{"new_version": "6"},
|
||||
access_token=self.banned_access_token,
|
||||
)
|
||||
self.render(request)
|
||||
self.assertEquals(200, channel.code, channel.result)
|
||||
# A new room_id should be returned.
|
||||
self.assertIn("replacement_room", channel.json_body)
|
||||
|
||||
new_room_id = channel.json_body["replacement_room"]
|
||||
|
||||
# It doesn't really matter what API we use here, we just want to assert
|
||||
# that the room doesn't exist.
|
||||
summary = self.get_success(self.store.get_room_summary(new_room_id))
|
||||
# The summary should be empty since the room doesn't exist.
|
||||
self.assertEqual(summary, {})
|
||||
|
||||
|
||||
# To avoid the tests timing out don't add a delay to "annoy the requester".
|
||||
@patch("random.randint", new=lambda a, b: 0)
|
||||
class ProfileTestCase(_ShadowBannedBase):
|
||||
servlets = [
|
||||
synapse.rest.admin.register_servlets_for_client_rest_resource,
|
||||
login.register_servlets,
|
||||
profile.register_servlets,
|
||||
room.register_servlets,
|
||||
]
|
||||
|
||||
def test_displayname(self):
|
||||
"""Profile changes should succeed, but don't end up in a room."""
|
||||
original_display_name = "banned"
|
||||
new_display_name = "new name"
|
||||
|
||||
# Join a room.
|
||||
room_id = self.helper.create_room_as(
|
||||
self.banned_user_id, tok=self.banned_access_token
|
||||
)
|
||||
|
||||
# The update should succeed.
|
||||
request, channel = self.make_request(
|
||||
"PUT",
|
||||
"/_matrix/client/r0/profile/%s/displayname" % (self.banned_user_id,),
|
||||
{"displayname": new_display_name},
|
||||
access_token=self.banned_access_token,
|
||||
)
|
||||
self.render(request)
|
||||
self.assertEquals(200, channel.code, channel.result)
|
||||
self.assertEqual(channel.json_body, {})
|
||||
|
||||
# The user's display name should be updated.
|
||||
request, channel = self.make_request(
|
||||
"GET", "/profile/%s/displayname" % (self.banned_user_id,)
|
||||
)
|
||||
self.render(request)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
self.assertEqual(channel.json_body["displayname"], new_display_name)
|
||||
|
||||
# But the display name in the room should not be.
|
||||
message_handler = self.hs.get_message_handler()
|
||||
event = self.get_success(
|
||||
message_handler.get_room_data(
|
||||
self.banned_user_id,
|
||||
room_id,
|
||||
"m.room.member",
|
||||
self.banned_user_id,
|
||||
False,
|
||||
)
|
||||
)
|
||||
self.assertEqual(
|
||||
event.content, {"membership": "join", "displayname": original_display_name}
|
||||
)
|
||||
|
||||
def test_room_displayname(self):
|
||||
"""Changes to state events for a room should be processed, but not end up in the room."""
|
||||
original_display_name = "banned"
|
||||
new_display_name = "new name"
|
||||
|
||||
# Join a room.
|
||||
room_id = self.helper.create_room_as(
|
||||
self.banned_user_id, tok=self.banned_access_token
|
||||
)
|
||||
|
||||
# The update should succeed.
|
||||
request, channel = self.make_request(
|
||||
"PUT",
|
||||
"/_matrix/client/r0/rooms/%s/state/m.room.member/%s"
|
||||
% (room_id, self.banned_user_id),
|
||||
{"membership": "join", "displayname": new_display_name},
|
||||
access_token=self.banned_access_token,
|
||||
)
|
||||
self.render(request)
|
||||
self.assertEquals(200, channel.code, channel.result)
|
||||
self.assertIn("event_id", channel.json_body)
|
||||
|
||||
# The display name in the room should not be changed.
|
||||
message_handler = self.hs.get_message_handler()
|
||||
event = self.get_success(
|
||||
message_handler.get_room_data(
|
||||
self.banned_user_id,
|
||||
room_id,
|
||||
"m.room.member",
|
||||
self.banned_user_id,
|
||||
False,
|
||||
)
|
||||
)
|
||||
self.assertEqual(
|
||||
event.content, {"membership": "join", "displayname": original_display_name}
|
||||
)
|
||||
@@ -28,7 +28,7 @@ from synapse.api.constants import EventContentFields, EventTypes, Membership
|
||||
from synapse.handlers.pagination import PurgeStatus
|
||||
from synapse.rest.client.v1 import directory, login, profile, room
|
||||
from synapse.rest.client.v2_alpha import account
|
||||
from synapse.types import JsonDict, RoomAlias
|
||||
from synapse.types import JsonDict, RoomAlias, UserID
|
||||
from synapse.util.stringutils import random_string
|
||||
|
||||
from tests import unittest
|
||||
@@ -675,6 +675,92 @@ class RoomMemberStateTestCase(RoomBase):
|
||||
self.assertEquals(json.loads(content), channel.json_body)
|
||||
|
||||
|
||||
class RoomJoinRatelimitTestCase(RoomBase):
|
||||
user_id = "@sid1:red"
|
||||
|
||||
servlets = [
|
||||
profile.register_servlets,
|
||||
room.register_servlets,
|
||||
]
|
||||
|
||||
@unittest.override_config(
|
||||
{"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
|
||||
)
|
||||
def test_join_local_ratelimit(self):
|
||||
"""Tests that local joins are actually rate-limited."""
|
||||
for i in range(3):
|
||||
self.helper.create_room_as(self.user_id)
|
||||
|
||||
self.helper.create_room_as(self.user_id, expect_code=429)
|
||||
|
||||
@unittest.override_config(
|
||||
{"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
|
||||
)
|
||||
def test_join_local_ratelimit_profile_change(self):
|
||||
"""Tests that sending a profile update into all of the user's joined rooms isn't
|
||||
rate-limited by the rate-limiter on joins."""
|
||||
|
||||
# Create and join as many rooms as the rate-limiting config allows in a second.
|
||||
room_ids = [
|
||||
self.helper.create_room_as(self.user_id),
|
||||
self.helper.create_room_as(self.user_id),
|
||||
self.helper.create_room_as(self.user_id),
|
||||
]
|
||||
# Let some time for the rate-limiter to forget about our multi-join.
|
||||
self.reactor.advance(2)
|
||||
# Add one to make sure we're joined to more rooms than the config allows us to
|
||||
# join in a second.
|
||||
room_ids.append(self.helper.create_room_as(self.user_id))
|
||||
|
||||
# Create a profile for the user, since it hasn't been done on registration.
|
||||
store = self.hs.get_datastore()
|
||||
self.get_success(
|
||||
store.create_profile(UserID.from_string(self.user_id).localpart)
|
||||
)
|
||||
|
||||
# Update the display name for the user.
|
||||
path = "/_matrix/client/r0/profile/%s/displayname" % self.user_id
|
||||
request, channel = self.make_request("PUT", path, {"displayname": "John Doe"})
|
||||
self.render(request)
|
||||
self.assertEquals(channel.code, 200, channel.json_body)
|
||||
|
||||
# Check that all the rooms have been sent a profile update into.
|
||||
for room_id in room_ids:
|
||||
path = "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" % (
|
||||
room_id,
|
||||
self.user_id,
|
||||
)
|
||||
|
||||
request, channel = self.make_request("GET", path)
|
||||
self.render(request)
|
||||
self.assertEquals(channel.code, 200)
|
||||
|
||||
self.assertIn("displayname", channel.json_body)
|
||||
self.assertEquals(channel.json_body["displayname"], "John Doe")
|
||||
|
||||
@unittest.override_config(
|
||||
{"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
|
||||
)
|
||||
def test_join_local_ratelimit_idempotent(self):
|
||||
"""Tests that the room join endpoints remain idempotent despite rate-limiting
|
||||
on room joins."""
|
||||
room_id = self.helper.create_room_as(self.user_id)
|
||||
|
||||
# Let's test both paths to be sure.
|
||||
paths_to_test = [
|
||||
"/_matrix/client/r0/rooms/%s/join",
|
||||
"/_matrix/client/r0/join/%s",
|
||||
]
|
||||
|
||||
for path in paths_to_test:
|
||||
# Make sure we send more requests than the rate-limiting config would allow
|
||||
# if all of these requests ended up joining the user to a room.
|
||||
for i in range(4):
|
||||
request, channel = self.make_request("POST", path % room_id, {})
|
||||
self.render(request)
|
||||
self.assertEquals(channel.code, 200)
|
||||
|
||||
|
||||
class RoomMessagesTestCase(RoomBase):
|
||||
""" Tests /rooms/$room_id/messages/$user_id/$msg_id REST events. """
|
||||
|
||||
@@ -1974,103 +2060,3 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
|
||||
"""An alias which does not point to the room raises a SynapseError."""
|
||||
self._set_canonical_alias({"alias": "@unknown:test"}, expected_code=400)
|
||||
self._set_canonical_alias({"alt_aliases": ["@unknown:test"]}, expected_code=400)
|
||||
|
||||
|
||||
class ShadowBannedTestCase(unittest.HomeserverTestCase):
|
||||
servlets = [
|
||||
synapse.rest.admin.register_servlets_for_client_rest_resource,
|
||||
directory.register_servlets,
|
||||
login.register_servlets,
|
||||
room.register_servlets,
|
||||
]
|
||||
|
||||
def prepare(self, reactor, clock, homeserver):
|
||||
self.banned_user_id = self.register_user("banned", "test")
|
||||
self.banned_access_token = self.login("banned", "test")
|
||||
|
||||
self.store = self.hs.get_datastore()
|
||||
|
||||
self.get_success(
|
||||
self.store.db_pool.simple_update(
|
||||
table="users",
|
||||
keyvalues={"name": self.banned_user_id},
|
||||
updatevalues={"shadow_banned": True},
|
||||
desc="shadow_ban",
|
||||
)
|
||||
)
|
||||
|
||||
self.other_user_id = self.register_user("otheruser", "pass")
|
||||
self.other_access_token = self.login("otheruser", "pass")
|
||||
|
||||
def test_invite(self):
|
||||
"""Invites from shadow-banned users don't actually get sent."""
|
||||
|
||||
# The create works fine.
|
||||
room_id = self.helper.create_room_as(
|
||||
self.banned_user_id, tok=self.banned_access_token
|
||||
)
|
||||
|
||||
# Inviting the user completes successfully.
|
||||
self.helper.invite(
|
||||
room=room_id,
|
||||
src=self.banned_user_id,
|
||||
tok=self.banned_access_token,
|
||||
targ=self.other_user_id,
|
||||
)
|
||||
|
||||
# But the user wasn't actually invited.
|
||||
invited_rooms = self.get_success(
|
||||
self.store.get_invited_rooms_for_local_user(self.other_user_id)
|
||||
)
|
||||
self.assertEqual(invited_rooms, [])
|
||||
|
||||
def test_invite_3pid(self):
|
||||
"""Ensure that a 3PID invite does not attempt to contact the identity server."""
|
||||
identity_handler = self.hs.get_handlers().identity_handler
|
||||
identity_handler.lookup_3pid = Mock(
|
||||
side_effect=AssertionError("This should not get called")
|
||||
)
|
||||
|
||||
# The create works fine.
|
||||
room_id = self.helper.create_room_as(
|
||||
self.banned_user_id, tok=self.banned_access_token
|
||||
)
|
||||
|
||||
# Inviting the user completes successfully.
|
||||
request, channel = self.make_request(
|
||||
"POST",
|
||||
"/rooms/%s/invite" % (room_id,),
|
||||
{"id_server": "test", "medium": "email", "address": "test@test.test"},
|
||||
access_token=self.banned_access_token,
|
||||
)
|
||||
self.render(request)
|
||||
self.assertEquals(200, channel.code, channel.result)
|
||||
|
||||
# This should have raised an error earlier, but double check this wasn't called.
|
||||
identity_handler.lookup_3pid.assert_not_called()
|
||||
|
||||
def test_create_room(self):
|
||||
"""Invitations during a room creation should be discarded, but the room still gets created."""
|
||||
# The room creation is successful.
|
||||
request, channel = self.make_request(
|
||||
"POST",
|
||||
"/_matrix/client/r0/createRoom",
|
||||
{"visibility": "public", "invite": [self.other_user_id]},
|
||||
access_token=self.banned_access_token,
|
||||
)
|
||||
self.render(request)
|
||||
self.assertEquals(200, channel.code, channel.result)
|
||||
room_id = channel.json_body["room_id"]
|
||||
|
||||
# But the user wasn't actually invited.
|
||||
invited_rooms = self.get_success(
|
||||
self.store.get_invited_rooms_for_local_user(self.other_user_id)
|
||||
)
|
||||
self.assertEqual(invited_rooms, [])
|
||||
|
||||
# Since a real room was created, the other user should be able to join it.
|
||||
self.helper.join(room_id, self.other_user_id, tok=self.other_access_token)
|
||||
|
||||
# Both users should be in the room.
|
||||
users = self.get_success(self.store.get_users_in_room(room_id))
|
||||
self.assertCountEqual(users, ["@banned:test", "@otheruser:test"])
|
||||
|
||||
@@ -39,7 +39,9 @@ class RestHelper(object):
|
||||
resource = attr.ib()
|
||||
auth_user_id = attr.ib()
|
||||
|
||||
def create_room_as(self, room_creator=None, is_public=True, tok=None):
|
||||
def create_room_as(
|
||||
self, room_creator=None, is_public=True, tok=None, expect_code=200,
|
||||
):
|
||||
temp_id = self.auth_user_id
|
||||
self.auth_user_id = room_creator
|
||||
path = "/_matrix/client/r0/createRoom"
|
||||
@@ -54,9 +56,11 @@ class RestHelper(object):
|
||||
)
|
||||
render(request, self.resource, self.hs.get_reactor())
|
||||
|
||||
assert channel.result["code"] == b"200", channel.result
|
||||
assert channel.result["code"] == b"%d" % expect_code, channel.result
|
||||
self.auth_user_id = temp_id
|
||||
return channel.json_body["room_id"]
|
||||
|
||||
if expect_code == 200:
|
||||
return channel.json_body["room_id"]
|
||||
|
||||
def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None):
|
||||
self.change_membership(
|
||||
|
||||
+17
-11
@@ -97,8 +97,10 @@ class SQLBaseStoreTestCase(unittest.TestCase):
|
||||
self.mock_txn.rowcount = 1
|
||||
self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)]))
|
||||
|
||||
value = yield self.datastore.db_pool.simple_select_one_onecol(
|
||||
table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol"
|
||||
value = yield defer.ensureDeferred(
|
||||
self.datastore.db_pool.simple_select_one_onecol(
|
||||
table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol"
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEquals("Value", value)
|
||||
@@ -111,10 +113,12 @@ class SQLBaseStoreTestCase(unittest.TestCase):
|
||||
self.mock_txn.rowcount = 1
|
||||
self.mock_txn.fetchone.return_value = (1, 2, 3)
|
||||
|
||||
ret = yield self.datastore.db_pool.simple_select_one(
|
||||
table="tablename",
|
||||
keyvalues={"keycol": "TheKey"},
|
||||
retcols=["colA", "colB", "colC"],
|
||||
ret = yield defer.ensureDeferred(
|
||||
self.datastore.db_pool.simple_select_one(
|
||||
table="tablename",
|
||||
keyvalues={"keycol": "TheKey"},
|
||||
retcols=["colA", "colB", "colC"],
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEquals({"colA": 1, "colB": 2, "colC": 3}, ret)
|
||||
@@ -127,11 +131,13 @@ class SQLBaseStoreTestCase(unittest.TestCase):
|
||||
self.mock_txn.rowcount = 0
|
||||
self.mock_txn.fetchone.return_value = None
|
||||
|
||||
ret = yield self.datastore.db_pool.simple_select_one(
|
||||
table="tablename",
|
||||
keyvalues={"keycol": "Not here"},
|
||||
retcols=["colA"],
|
||||
allow_none=True,
|
||||
ret = yield defer.ensureDeferred(
|
||||
self.datastore.db_pool.simple_select_one(
|
||||
table="tablename",
|
||||
keyvalues={"keycol": "Not here"},
|
||||
retcols=["colA"],
|
||||
allow_none=True,
|
||||
)
|
||||
)
|
||||
|
||||
self.assertFalse(ret)
|
||||
|
||||
@@ -38,7 +38,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
|
||||
self.store.store_device("user_id", "device_id", "display_name")
|
||||
)
|
||||
|
||||
res = yield self.store.get_device("user_id", "device_id")
|
||||
res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
|
||||
self.assertDictContainsSubset(
|
||||
{
|
||||
"user_id": "user_id",
|
||||
@@ -111,12 +111,12 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
|
||||
self.store.store_device("user_id", "device_id", "display_name 1")
|
||||
)
|
||||
|
||||
res = yield self.store.get_device("user_id", "device_id")
|
||||
res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
|
||||
self.assertEqual("display_name 1", res["display_name"])
|
||||
|
||||
# do a no-op first
|
||||
yield defer.ensureDeferred(self.store.update_device("user_id", "device_id"))
|
||||
res = yield self.store.get_device("user_id", "device_id")
|
||||
res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
|
||||
self.assertEqual("display_name 1", res["display_name"])
|
||||
|
||||
# do the update
|
||||
@@ -127,7 +127,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
|
||||
)
|
||||
|
||||
# check it worked
|
||||
res = yield self.store.get_device("user_id", "device_id")
|
||||
res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
|
||||
self.assertEqual("display_name 2", res["display_name"])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
||||
@@ -182,3 +182,39 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||
|
||||
self.assertEqual(id_gen.get_positions(), {"master": 8})
|
||||
self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
|
||||
|
||||
def test_get_persisted_upto_position(self):
|
||||
"""Test that `get_persisted_upto_position` correctly tracks updates to
|
||||
positions.
|
||||
"""
|
||||
|
||||
self._insert_rows("first", 3)
|
||||
self._insert_rows("second", 5)
|
||||
|
||||
id_gen = self._create_id_generator("first")
|
||||
|
||||
# Min is 3 and there is a gap between 5, so we expect it to be 3.
|
||||
self.assertEqual(id_gen.get_persisted_upto_position(), 3)
|
||||
|
||||
# We advance "first" straight to 6. Min is now 5 but there is no gap so
|
||||
# we expect it to be 6
|
||||
id_gen.advance("first", 6)
|
||||
self.assertEqual(id_gen.get_persisted_upto_position(), 6)
|
||||
|
||||
# No gap, so we expect 7.
|
||||
id_gen.advance("second", 7)
|
||||
self.assertEqual(id_gen.get_persisted_upto_position(), 7)
|
||||
|
||||
# We haven't seen 8 yet, so we expect 7 still.
|
||||
id_gen.advance("second", 9)
|
||||
self.assertEqual(id_gen.get_persisted_upto_position(), 7)
|
||||
|
||||
# Now that we've seen 7, 8 and 9 we can got straight to 9.
|
||||
id_gen.advance("first", 8)
|
||||
self.assertEqual(id_gen.get_persisted_upto_position(), 9)
|
||||
|
||||
# Jump forward with gaps. The minimum is 11, even though we haven't seen
|
||||
# 10 we know that everything before 11 must be persisted.
|
||||
id_gen.advance("first", 11)
|
||||
id_gen.advance("second", 15)
|
||||
self.assertEqual(id_gen.get_persisted_upto_position(), 11)
|
||||
|
||||
@@ -35,21 +35,34 @@ class ProfileStoreTestCase(unittest.TestCase):
|
||||
def test_displayname(self):
|
||||
yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
|
||||
|
||||
yield self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
|
||||
yield defer.ensureDeferred(
|
||||
self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
|
||||
)
|
||||
|
||||
self.assertEquals(
|
||||
"Frank", (yield self.store.get_profile_displayname(self.u_frank.localpart))
|
||||
"Frank",
|
||||
(
|
||||
yield defer.ensureDeferred(
|
||||
self.store.get_profile_displayname(self.u_frank.localpart)
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_avatar_url(self):
|
||||
yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
|
||||
|
||||
yield self.store.set_profile_avatar_url(
|
||||
self.u_frank.localpart, "http://my.site/here"
|
||||
yield defer.ensureDeferred(
|
||||
self.store.set_profile_avatar_url(
|
||||
self.u_frank.localpart, "http://my.site/here"
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEquals(
|
||||
"http://my.site/here",
|
||||
(yield self.store.get_profile_avatar_url(self.u_frank.localpart)),
|
||||
(
|
||||
yield defer.ensureDeferred(
|
||||
self.store.get_profile_avatar_url(self.u_frank.localpart)
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import UserTypes
|
||||
from synapse.api.errors import ThreepidValidationError
|
||||
|
||||
from tests import unittest
|
||||
from tests.utils import setup_test_homeserver
|
||||
@@ -52,7 +53,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
|
||||
"user_type": None,
|
||||
"deactivated": 0,
|
||||
},
|
||||
(yield self.store.get_user_by_id(self.user_id)),
|
||||
(yield defer.ensureDeferred(self.store.get_user_by_id(self.user_id))),
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@@ -122,3 +123,33 @@ class RegistrationStoreTestCase(unittest.TestCase):
|
||||
)
|
||||
res = yield self.store.is_support_user(SUPPORT_USER)
|
||||
self.assertTrue(res)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_3pid_inhibit_invalid_validation_session_error(self):
|
||||
"""Tests that enabling the configuration option to inhibit 3PID errors on
|
||||
/requestToken also inhibits validation errors caused by an unknown session ID.
|
||||
"""
|
||||
|
||||
# Check that, with the config setting set to false (the default value), a
|
||||
# validation error is caused by the unknown session ID.
|
||||
try:
|
||||
yield defer.ensureDeferred(
|
||||
self.store.validate_threepid_session(
|
||||
"fake_sid", "fake_client_secret", "fake_token", 0,
|
||||
)
|
||||
)
|
||||
except ThreepidValidationError as e:
|
||||
self.assertEquals(e.msg, "Unknown session_id", e)
|
||||
|
||||
# Set the config setting to true.
|
||||
self.store._ignore_unknown_session_error = True
|
||||
|
||||
# Check that now the validation error is caused by the token not matching.
|
||||
try:
|
||||
yield defer.ensureDeferred(
|
||||
self.store.validate_threepid_session(
|
||||
"fake_sid", "fake_client_secret", "fake_token", 0,
|
||||
)
|
||||
)
|
||||
except ThreepidValidationError as e:
|
||||
self.assertEquals(e.msg, "Validation token not found or has expired", e)
|
||||
|
||||
@@ -54,12 +54,14 @@ class RoomStoreTestCase(unittest.TestCase):
|
||||
"creator": self.u_creator.to_string(),
|
||||
"is_public": True,
|
||||
},
|
||||
(yield self.store.get_room(self.room.to_string())),
|
||||
(yield defer.ensureDeferred(self.store.get_room(self.room.to_string()))),
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_get_room_unknown_room(self):
|
||||
self.assertIsNone((yield self.store.get_room("!uknown:test")),)
|
||||
self.assertIsNone(
|
||||
(yield defer.ensureDeferred(self.store.get_room("!uknown:test")))
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_get_room_with_stats(self):
|
||||
@@ -69,12 +71,22 @@ class RoomStoreTestCase(unittest.TestCase):
|
||||
"creator": self.u_creator.to_string(),
|
||||
"public": True,
|
||||
},
|
||||
(yield self.store.get_room_with_stats(self.room.to_string())),
|
||||
(
|
||||
yield defer.ensureDeferred(
|
||||
self.store.get_room_with_stats(self.room.to_string())
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_get_room_with_stats_unknown_room(self):
|
||||
self.assertIsNone((yield self.store.get_room_with_stats("!uknown:test")),)
|
||||
self.assertIsNone(
|
||||
(
|
||||
yield defer.ensureDeferred(
|
||||
self.store.get_room_with_stats("!uknown:test")
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class RoomEventsStoreTestCase(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user