1
0

Merge branch 'develop' of github.com:matrix-org/synapse into anoa/password_reset_confirmation

* 'develop' of github.com:matrix-org/synapse: (22 commits)
  Update the test federation client to handle streaming responses (#8130)
  Do not propagate profile changes of shadow-banned users into rooms. (#8157)
  Make SlavedIdTracker.advance have same interface as MultiWriterIDGenerator (#8171)
  Convert simple_select_one and simple_select_one_onecol to async (#8162)
  Fix rate limiting unit tests. (#8167)
  Add functions to `MultiWriterIdGen` used by events stream (#8164)
  Do not allow send_nonmember_event to be called with shadow-banned users. (#8158)
  Changelog fixes
  1.19.1rc1
  Make StreamIdGen `get_next` and `get_next_mult` async  (#8161)
  Wording fixes to 'name' user admin api filter (#8163)
  Fix missing double-backtick in RST document
  Search in columns 'name' and 'displayname' in the admin users endpoint (#7377)
  Add type hints for state. (#8140)
  Stop shadow-banned users from sending non-member events. (#8142)
  Allow capping a room's retention policy (#8104)
  Add healthcheck for default localhost 8008 port on /health endpoint. (#8147)
  Fix flaky shadow-ban tests. (#8152)
  Fix join ratelimiter breaking profile updates and idempotency (#8153)
  Do not apply ratelimiting on joins to appservices (#8139)
  ...
This commit is contained in:
Andrew Morgan
2020-08-26 14:28:12 +01:00
98 changed files with 1819 additions and 702 deletions
+10
View File
@@ -12,6 +12,16 @@ from Synapse as most users have updated their client. Further context can be
found at [\#6766](https://github.com/matrix-org/synapse/issues/6766).
Synapse 1.19.1rc1 (2020-08-25)
==============================
Bugfixes
--------
- Fix a bug introduced in v1.19.0 where appservices with ratelimiting disabled would still be ratelimited when joining rooms. ([\#8139](https://github.com/matrix-org/synapse/issues/8139))
- Fix a bug introduced in v1.19.0 that would cause e.g. profile updates to fail due to incorrect application of rate limits on join requests. ([\#8153](https://github.com/matrix-org/synapse/issues/8153))
Synapse 1.19.0 (2020-08-17)
===========================
+1
View File
@@ -0,0 +1 @@
Add filter `name` to the `/users` admin API, which filters by user ID or displayname. Contributed by Awesome Technologies Innovationslabor GmbH.
+1
View File
@@ -0,0 +1 @@
Don't fail `/submit_token` requests on incorrect session ID if `request_token_inhibit_3pid_errors` is turned on.
+1
View File
@@ -0,0 +1 @@
Fix a bug introduced in v1.7.2 impacting message retention policies that would allow federated homeservers to dictate a retention period that's lower than the configured minimum allowed duration in the configuration file.
+1
View File
@@ -0,0 +1 @@
Update the test federation client to handle streaming responses.
+1
View File
@@ -0,0 +1 @@
Fixes a bug where appservices with ratelimiting disabled would still be ratelimited when joining rooms. This bug was introduced in v1.19.0.
+1
View File
@@ -0,0 +1 @@
Add type hints to `synapse.state`.
+1
View File
@@ -0,0 +1 @@
Add support for shadow-banning users (ignoring any message send requests).
+1
View File
@@ -0,0 +1 @@
Added curl for healthcheck support and readme updates for the change. Contributed by @maquis196.
+1
View File
@@ -0,0 +1 @@
Add support for shadow-banning users (ignoring any message send requests).
+1
View File
@@ -0,0 +1 @@
Add support for shadow-banning users (ignoring any message send requests).
+1
View File
@@ -0,0 +1 @@
Add support for shadow-banning users (ignoring any message send requests).
+1
View File
@@ -0,0 +1 @@
Refactor `StreamIdGenerator` and `MultiWriterIdGenerator` to have the same interface.
+1
View File
@@ -0,0 +1 @@
Convert various parts of the codebase to async/await.
+1
View File
@@ -0,0 +1 @@
Add filter `name` to the `/users` admin API, which filters by user ID or displayname. Contributed by Awesome Technologies Innovationslabor GmbH.
+1
View File
@@ -0,0 +1 @@
Add functions to `MultiWriterIdGen` used by events stream.
+1
View File
@@ -0,0 +1 @@
Fix tests that were broken due to the merge of 1.19.1.
+1
View File
@@ -0,0 +1 @@
Make `SlavedIdTracker.advance` have the same interface as `MultiWriterIDGenerator`.
+4
View File
@@ -55,6 +55,7 @@ RUN pip install --prefix="/install" --no-warn-script-location \
FROM docker.io/python:${PYTHON_VERSION}-slim
RUN apt-get update && apt-get install -y \
curl \
libpq5 \
xmlsec1 \
gosu \
@@ -69,3 +70,6 @@ VOLUME ["/data"]
EXPOSE 8008/tcp 8009/tcp 8448/tcp
ENTRYPOINT ["/start.py"]
HEALTHCHECK --interval=1m --timeout=5s \
CMD curl -fSs http://localhost:8008/health || exit 1
+29
View File
@@ -162,3 +162,32 @@ docker build -t matrixdotorg/synapse -f docker/Dockerfile .
You can choose to build a different docker image by changing the value of the `-f` flag to
point to another Dockerfile.
## Disabling the healthcheck
If you are using a non-standard port or tls inside docker you can disable the healthcheck
whilst running the above `docker run` commands.
```
--no-healthcheck
```
## Setting custom healthcheck on docker run
If you wish to point the healthcheck at a different port with docker command, add the following
```
--health-cmd 'curl -fSs http://localhost:1234/health'
```
## Setting the healthcheck in docker-compose file
You can add the following to set a custom healthcheck in a docker compose file.
You will need version >2.1 for this to work.
```
healthcheck:
test: ["CMD", "curl", "-fSs", "http://localhost:8008/health"]
interval: 1m
timeout: 10s
retries: 3
```
+6 -3
View File
@@ -108,7 +108,7 @@ The api is::
GET /_synapse/admin/v2/users?from=0&limit=10&guests=false
To use it, you will need to authenticate by providing an `access_token` for a
To use it, you will need to authenticate by providing an ``access_token`` for a
server admin: see `README.rst <README.rst>`_.
The parameter ``from`` is optional but used for pagination, denoting the
@@ -119,8 +119,11 @@ from a previous call.
The parameter ``limit`` is optional but is used for pagination, denoting the
maximum number of items to return in this call. Defaults to ``100``.
The parameter ``user_id`` is optional and filters to only users with user IDs
that contain this value.
The parameter ``user_id`` is optional and filters to only return users with user IDs
that contain this value. This parameter is ignored when using the ``name`` parameter.
The parameter ``name`` is optional and filters to only return users with user ID localparts
**or** displaynames that contain this value.
The parameter ``guests`` is optional and if ``false`` will **exclude** guest users.
Defaults to ``true`` to include guest users.
+14 -8
View File
@@ -378,11 +378,10 @@ retention:
# min_lifetime: 1d
# max_lifetime: 1y
# Retention policy limits. If set, a user won't be able to send a
# 'm.room.retention' event which features a 'min_lifetime' or a 'max_lifetime'
# that's not within this range. This is especially useful in closed federations,
# in which server admins can make sure every federating server applies the same
# rules.
# Retention policy limits. If set, and the state of a room contains a
# 'm.room.retention' event in its state which contains a 'min_lifetime' or a
# 'max_lifetime' that's out of these bounds, Synapse will cap the room's policy
# to these limits when running purge jobs.
#
#allowed_lifetime_min: 1d
#allowed_lifetime_max: 1y
@@ -408,12 +407,19 @@ retention:
# (e.g. every 12h), but not want that purge to be performed by a job that's
# iterating over every room it knows, which could be heavy on the server.
#
# If any purge job is configured, it is strongly recommended to have at least
# a single job with neither 'shortest_max_lifetime' nor 'longest_max_lifetime'
# set, or one job without 'shortest_max_lifetime' and one job without
# 'longest_max_lifetime' set. Otherwise some rooms might be ignored, even if
# 'allowed_lifetime_min' and 'allowed_lifetime_max' are set, because capping a
# room's policy to these values is done after the policies are retrieved from
# Synapse's database (which is done using the range specified in a purge job's
# configuration).
#
#purge_jobs:
# - shortest_max_lifetime: 1d
# longest_max_lifetime: 3d
# - longest_max_lifetime: 3d
# interval: 12h
# - shortest_max_lifetime: 3d
# longest_max_lifetime: 1y
# interval: 1d
# Inhibits the /requestToken endpoints from returning an error that might leak
+27 -8
View File
@@ -21,10 +21,12 @@ import argparse
import base64
import json
import sys
from typing import Any, Optional
from urllib import parse as urlparse
import nacl.signing
import requests
import signedjson.types
import srvlookup
import yaml
from requests.adapters import HTTPAdapter
@@ -69,7 +71,9 @@ def encode_canonical_json(value):
).encode("UTF-8")
def sign_json(json_object, signing_key, signing_name):
def sign_json(
json_object: Any, signing_key: signedjson.types.SigningKey, signing_name: str
) -> Any:
signatures = json_object.pop("signatures", {})
unsigned = json_object.pop("unsigned", None)
@@ -122,7 +126,14 @@ def read_signing_keys(stream):
return keys
def request_json(method, origin_name, origin_key, destination, path, content):
def request(
method: Optional[str],
origin_name: str,
origin_key: signedjson.types.SigningKey,
destination: str,
path: str,
content: Optional[str],
) -> requests.Response:
if method is None:
if content is None:
method = "GET"
@@ -159,11 +170,14 @@ def request_json(method, origin_name, origin_key, destination, path, content):
if method == "POST":
headers["Content-Type"] = "application/json"
result = s.request(
method=method, url=dest, headers=headers, verify=False, data=content
return s.request(
method=method,
url=dest,
headers=headers,
verify=False,
data=content,
stream=True,
)
sys.stderr.write("Status Code: %d\n" % (result.status_code,))
return result.json()
def main():
@@ -222,7 +236,7 @@ def main():
with open(args.signing_key_path) as f:
key = read_signing_keys(f)[0]
result = request_json(
result = request(
args.method,
args.server_name,
key,
@@ -231,7 +245,12 @@ def main():
content=args.body,
)
json.dump(result, sys.stdout)
sys.stderr.write("Status Code: %d\n" % (result.status_code,))
for chunk in result.iter_content():
# we write raw utf8 to stdout.
sys.stdout.buffer.write(chunk)
print("")
+47
View File
@@ -0,0 +1,47 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Stub for frozendict.
from typing import (
Any,
Hashable,
Iterable,
Iterator,
Mapping,
overload,
Tuple,
TypeVar,
)
_KT = TypeVar("_KT", bound=Hashable) # Key type.
_VT = TypeVar("_VT") # Value type.
class frozendict(Mapping[_KT, _VT]):
@overload
def __init__(self, **kwargs: _VT) -> None: ...
@overload
def __init__(self, __map: Mapping[_KT, _VT], **kwargs: _VT) -> None: ...
@overload
def __init__(
self, __iterable: Iterable[Tuple[_KT, _VT]], **kwargs: _VT
) -> None: ...
def __getitem__(self, key: _KT) -> _VT: ...
def __contains__(self, key: Any) -> bool: ...
def copy(self, **add_or_replace: Any) -> frozendict: ...
def __iter__(self) -> Iterator[_KT]: ...
def __len__(self) -> int: ...
def __repr__(self) -> str: ...
def __hash__(self) -> int: ...
+1 -1
View File
@@ -48,7 +48,7 @@ try:
except ImportError:
pass
__version__ = "1.19.0"
__version__ = "1.19.1rc1"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when
+37
View File
@@ -17,6 +17,7 @@ from collections import OrderedDict
from typing import Any, Optional, Tuple
from synapse.api.errors import LimitExceededError
from synapse.types import Requester
from synapse.util import Clock
@@ -43,6 +44,42 @@ class Ratelimiter(object):
# * The rate_hz of this particular entry. This can vary per request
self.actions = OrderedDict() # type: OrderedDict[Any, Tuple[float, int, float]]
def can_requester_do_action(
self,
requester: Requester,
rate_hz: Optional[float] = None,
burst_count: Optional[int] = None,
update: bool = True,
_time_now_s: Optional[int] = None,
) -> Tuple[bool, float]:
"""Can the requester perform the action?
Args:
requester: The requester to key off when rate limiting. The user property
will be used.
rate_hz: The long term number of actions that can be performed in a second.
Overrides the value set during instantiation if set.
burst_count: How many actions that can be performed before being limited.
Overrides the value set during instantiation if set.
update: Whether to count this check as performing the action
_time_now_s: The current time. Optional, defaults to the current time according
to self.clock. Only used by tests.
Returns:
A tuple containing:
* A bool indicating if they can perform the action now
* The reactor timestamp for when the action can be performed next.
-1 if rate_hz is less than or equal to zero
"""
# Disable rate limiting of users belonging to any AS that is configured
# not to be rate limited in its registration file (rate_limited: true|false).
if requester.app_service and not requester.app_service.is_rate_limited():
return True, -1.0
return self.can_do_action(
requester.user.to_string(), rate_hz, burst_count, update, _time_now_s
)
def can_do_action(
self,
key: Any,
+14 -8
View File
@@ -961,11 +961,10 @@ class ServerConfig(Config):
# min_lifetime: 1d
# max_lifetime: 1y
# Retention policy limits. If set, a user won't be able to send a
# 'm.room.retention' event which features a 'min_lifetime' or a 'max_lifetime'
# that's not within this range. This is especially useful in closed federations,
# in which server admins can make sure every federating server applies the same
# rules.
# Retention policy limits. If set, and the state of a room contains a
# 'm.room.retention' event in its state which contains a 'min_lifetime' or a
# 'max_lifetime' that's out of these bounds, Synapse will cap the room's policy
# to these limits when running purge jobs.
#
#allowed_lifetime_min: 1d
#allowed_lifetime_max: 1y
@@ -991,12 +990,19 @@ class ServerConfig(Config):
# (e.g. every 12h), but not want that purge to be performed by a job that's
# iterating over every room it knows, which could be heavy on the server.
#
# If any purge job is configured, it is strongly recommended to have at least
# a single job with neither 'shortest_max_lifetime' nor 'longest_max_lifetime'
# set, or one job without 'shortest_max_lifetime' and one job without
# 'longest_max_lifetime' set. Otherwise some rooms might be ignored, even if
# 'allowed_lifetime_min' and 'allowed_lifetime_max' are set, because capping a
# room's policy to these values is done after the policies are retrieved from
# Synapse's database (which is done using the range specified in a purge job's
# configuration).
#
#purge_jobs:
# - shortest_max_lifetime: 1d
# longest_max_lifetime: 3d
# - longest_max_lifetime: 3d
# interval: 12h
# - shortest_max_lifetime: 3d
# longest_max_lifetime: 1y
# interval: 1d
# Inhibits the /requestToken endpoints from returning an error that might leak
+3 -56
View File
@@ -74,15 +74,14 @@ class EventValidator(object):
)
if event.type == EventTypes.Retention:
self._validate_retention(event, config)
self._validate_retention(event)
def _validate_retention(self, event, config):
def _validate_retention(self, event):
"""Checks that an event that defines the retention policy for a room respects the
boundaries imposed by the server's administrator.
format enforced by the spec.
Args:
event (FrozenEvent): The event to validate.
config (Config): The homeserver's configuration.
"""
min_lifetime = event.content.get("min_lifetime")
max_lifetime = event.content.get("max_lifetime")
@@ -95,32 +94,6 @@ class EventValidator(object):
errcode=Codes.BAD_JSON,
)
if (
config.retention_allowed_lifetime_min is not None
and min_lifetime < config.retention_allowed_lifetime_min
):
raise SynapseError(
code=400,
msg=(
"'min_lifetime' can't be lower than the minimum allowed"
" value enforced by the server's administrator"
),
errcode=Codes.BAD_JSON,
)
if (
config.retention_allowed_lifetime_max is not None
and min_lifetime > config.retention_allowed_lifetime_max
):
raise SynapseError(
code=400,
msg=(
"'min_lifetime' can't be greater than the maximum allowed"
" value enforced by the server's administrator"
),
errcode=Codes.BAD_JSON,
)
if max_lifetime is not None:
if not isinstance(max_lifetime, int):
raise SynapseError(
@@ -129,32 +102,6 @@ class EventValidator(object):
errcode=Codes.BAD_JSON,
)
if (
config.retention_allowed_lifetime_min is not None
and max_lifetime < config.retention_allowed_lifetime_min
):
raise SynapseError(
code=400,
msg=(
"'max_lifetime' can't be lower than the minimum allowed value"
" enforced by the server's administrator"
),
errcode=Codes.BAD_JSON,
)
if (
config.retention_allowed_lifetime_max is not None
and max_lifetime > config.retention_allowed_lifetime_max
):
raise SynapseError(
code=400,
msg=(
"'max_lifetime' can't be greater than the maximum allowed"
" value enforced by the server's administrator"
),
errcode=Codes.BAD_JSON,
)
if (
min_lifetime is not None
and max_lifetime is not None
+2 -2
View File
@@ -329,10 +329,10 @@ class FederationSender(object):
room_id = receipt.room_id
# Work out which remote servers should be poked and poke them.
domains = await self.state.get_current_hosts_in_room(room_id)
domains_set = await self.state.get_current_hosts_in_room(room_id)
domains = [
d
for d in domains
for d in domains_set
if d != self.server_name
and self._federation_shard_config.should_handle(self._instance_name, d)
]
+6
View File
@@ -23,6 +23,7 @@ from synapse.api.errors import (
CodeMessageException,
Codes,
NotFoundError,
ShadowBanError,
StoreError,
SynapseError,
)
@@ -199,6 +200,8 @@ class DirectoryHandler(BaseHandler):
try:
await self._update_canonical_alias(requester, user_id, room_id, room_alias)
except ShadowBanError as e:
logger.info("Failed to update alias events due to shadow-ban: %s", e)
except AuthError as e:
logger.info("Failed to update alias events: %s", e)
@@ -292,6 +295,9 @@ class DirectoryHandler(BaseHandler):
"""
Send an updated canonical alias event if the removed alias was set as
the canonical alias or listed in the alt_aliases field.
Raises:
ShadowBanError if the requester has been shadow-banned.
"""
alias_event = await self.state.get_current_state(
room_id, EventTypes.CanonicalAlias, ""
+6 -4
View File
@@ -2134,10 +2134,10 @@ class FederationHandler(BaseHandler):
)
state_sets = list(state_sets.values())
state_sets.append(state)
current_state_ids = await self.state_handler.resolve_events(
current_states = await self.state_handler.resolve_events(
room_version, state_sets, event
)
current_state_ids = {k: e.event_id for k, e in current_state_ids.items()}
current_state_ids = {k: e.event_id for k, e in current_states.items()}
else:
current_state_ids = await self.state_handler.get_current_state_ids(
event.room_id, latest_event_ids=extrem_ids
@@ -2149,9 +2149,11 @@ class FederationHandler(BaseHandler):
# Now check if event pass auth against said current state
auth_types = auth_types_for_event(event)
current_state_ids = [e for k, e in current_state_ids.items() if k in auth_types]
current_state_ids_list = [
e for k, e in current_state_ids.items() if k in auth_types
]
auth_events_map = await self.store.get_events(current_state_ids)
auth_events_map = await self.store.get_events(current_state_ids_list)
current_auth_events = {
(e.type, e.state_key): e for e in auth_events_map.values()
}
+44 -5
View File
@@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from canonicaljson import encode_canonical_json
@@ -34,6 +35,7 @@ from synapse.api.errors import (
Codes,
ConsentNotGivenError,
NotFoundError,
ShadowBanError,
SynapseError,
)
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
@@ -645,24 +647,35 @@ class EventCreationHandler(object):
event: EventBase,
context: EventContext,
ratelimit: bool = True,
ignore_shadow_ban: bool = False,
) -> int:
"""
Persists and notifies local clients and federation of an event.
Args:
requester
event the event to send.
context: the context of the event.
requester: The requester sending the event.
event: The event to send.
context: The context of the event.
ratelimit: Whether to rate limit this send.
ignore_shadow_ban: True if shadow-banned users should be allowed to
send this event.
Return:
The stream_id of the persisted event.
Raises:
ShadowBanError if the requester has been shadow-banned.
"""
if event.type == EventTypes.Member:
raise SynapseError(
500, "Tried to send member event through non-member codepath"
)
if not ignore_shadow_ban and requester.shadow_banned:
# We randomly sleep a bit just to annoy the requester.
await self.clock.sleep(random.randint(1, 10))
raise ShadowBanError()
user = UserID.from_string(event.sender)
assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
@@ -716,12 +729,28 @@ class EventCreationHandler(object):
event_dict: dict,
ratelimit: bool = True,
txn_id: Optional[str] = None,
ignore_shadow_ban: bool = False,
) -> Tuple[EventBase, int]:
"""
Creates an event, then sends it.
See self.create_event and self.send_nonmember_event.
Args:
requester: The requester sending the event.
event_dict: An entire event.
ratelimit: Whether to rate limit this send.
txn_id: The transaction ID.
ignore_shadow_ban: True if shadow-banned users should be allowed to
send this event.
Raises:
ShadowBanError if the requester has been shadow-banned.
"""
if not ignore_shadow_ban and requester.shadow_banned:
# We randomly sleep a bit just to annoy the requester.
await self.clock.sleep(random.randint(1, 10))
raise ShadowBanError()
# We limit the number of concurrent event sends in a room so that we
# don't fork the DAG too much. If we don't limit then we can end up in
@@ -740,7 +769,11 @@ class EventCreationHandler(object):
raise SynapseError(403, spam_error, Codes.FORBIDDEN)
stream_id = await self.send_nonmember_event(
requester, event, context, ratelimit=ratelimit
requester,
event,
context,
ratelimit=ratelimit,
ignore_shadow_ban=ignore_shadow_ban,
)
return event, stream_id
@@ -1180,8 +1213,14 @@ class EventCreationHandler(object):
event.internal_metadata.proactively_send = False
# Since this is a dummy-event it is OK if it is sent by a
# shadow-banned user.
await self.send_nonmember_event(
requester, event, context, ratelimit=False
requester,
event,
context,
ratelimit=False,
ignore_shadow_ban=True,
)
dummy_event_sent = True
break
+29 -7
View File
@@ -82,6 +82,9 @@ class PaginationHandler(object):
self._retention_default_max_lifetime = hs.config.retention_default_max_lifetime
self._retention_allowed_lifetime_min = hs.config.retention_allowed_lifetime_min
self._retention_allowed_lifetime_max = hs.config.retention_allowed_lifetime_max
if hs.config.retention_enabled:
# Run the purge jobs described in the configuration file.
for job in hs.config.retention_purge_jobs:
@@ -111,7 +114,7 @@ class PaginationHandler(object):
the range to handle (inclusive). If None, it means that the range has no
upper limit.
"""
# We want the storage layer to to include rooms with no retention policy in its
# We want the storage layer to include rooms with no retention policy in its
# return value only if a default retention policy is defined in the server's
# configuration and that policy's 'max_lifetime' is either lower (or equal) than
# max_ms or higher than min_ms (or both).
@@ -152,13 +155,32 @@ class PaginationHandler(object):
)
continue
max_lifetime = retention_policy["max_lifetime"]
# If max_lifetime is None, it means that the room has no retention policy.
# Given we only retrieve such rooms when there's a default retention policy
# defined in the server's configuration, we can safely assume that's the
# case and use it for this room.
max_lifetime = (
retention_policy["max_lifetime"] or self._retention_default_max_lifetime
)
if max_lifetime is None:
# If max_lifetime is None, it means that include_null equals True,
# therefore we can safely assume that there is a default policy defined
# in the server's configuration.
max_lifetime = self._retention_default_max_lifetime
# Cap the effective max_lifetime to be within the range allowed in the
# config.
# We do this in two steps:
# 1. Make sure it's higher or equal to the minimum allowed value, and if
# it's not replace it with that value. This is because the server
# operator can be required to not delete information before a given
# time, e.g. to comply with freedom of information laws.
# 2. Make sure the resulting value is lower or equal to the maximum allowed
# value, and if it's not replace it with that value. This is because the
# server operator can be required to delete any data after a specific
# amount of time.
if self._retention_allowed_lifetime_min is not None:
max_lifetime = max(self._retention_allowed_lifetime_min, max_lifetime)
if self._retention_allowed_lifetime_max is not None:
max_lifetime = min(max_lifetime, self._retention_allowed_lifetime_max)
logger.debug("[purge] max_lifetime for room %s: %s", room_id, max_lifetime)
# Figure out what token we should start purging at.
ts = self.clock.time_msec() - max_lifetime
+3 -3
View File
@@ -40,7 +40,7 @@ from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.state import StateHandler
from synapse.storage.databases.main import DataStore
from synapse.types import JsonDict, UserID, get_domain_from_id
from synapse.types import Collection, JsonDict, UserID, get_domain_from_id
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.descriptors import cached
from synapse.util.metrics import Measure
@@ -1318,7 +1318,7 @@ async def get_interested_parties(
async def get_interested_remotes(
store: DataStore, states: List[UserPresenceState], state_handler: StateHandler
) -> List[Tuple[List[str], List[UserPresenceState]]]:
) -> List[Tuple[Collection[str], List[UserPresenceState]]]:
"""Given a list of presence states figure out which remote servers
should be sent which.
@@ -1334,7 +1334,7 @@ async def get_interested_remotes(
each tuple the list of UserPresenceState should be sent to each
destination
"""
hosts_and_states = []
hosts_and_states = [] # type: List[Tuple[Collection[str], List[UserPresenceState]]]
# First we look up the rooms each user is in (as well as any explicit
# subscriptions), then for each distinct room we look up the remote
+15 -2
View File
@@ -14,6 +14,7 @@
# limitations under the License.
import logging
import random
from synapse.api.errors import (
AuthError,
@@ -213,8 +214,14 @@ class BaseProfileHandler(BaseHandler):
async def set_avatar_url(
self, target_user, requester, new_avatar_url, by_admin=False
):
"""target_user is the user whose avatar_url is to be changed;
auth_user is the user attempting to make this change."""
"""Set a new avatar URL for a user.
Args:
target_user (UserID): the user whose avatar URL is to be changed.
requester (Requester): The user attempting to make this change.
new_avatar_url (str): The avatar URL to give this user.
by_admin (bool): Whether this change was made by an administrator.
"""
if not self.hs.is_mine(target_user):
raise SynapseError(400, "User is not hosted on this homeserver")
@@ -278,6 +285,12 @@ class BaseProfileHandler(BaseHandler):
await self.ratelimit(requester)
# Do not actually update the room state for shadow-banned users.
if requester.shadow_banned:
# We randomly sleep a bit just to annoy the requester.
await self.clock.sleep(random.randint(1, 10))
return
room_ids = await self.store.get_rooms_for_user(target_user.to_string())
for room_id in room_ids:
+18 -1
View File
@@ -136,6 +136,9 @@ class RoomCreationHandler(BaseHandler):
Returns:
the new room id
Raises:
ShadowBanError if the requester is shadow-banned.
"""
await self.ratelimit(requester)
@@ -171,6 +174,15 @@ class RoomCreationHandler(BaseHandler):
async def _upgrade_room(
self, requester: Requester, old_room_id: str, new_version: RoomVersion
):
"""
Args:
requester: the user requesting the upgrade
old_room_id: the id of the room to be replaced
new_versions: the version to upgrade the room to
Raises:
ShadowBanError if the requester is shadow-banned.
"""
user_id = requester.user.to_string()
# start by allocating a new room id
@@ -257,6 +269,9 @@ class RoomCreationHandler(BaseHandler):
old_room_id: the id of the room to be replaced
new_room_id: the id of the replacement room
old_room_state: the state map for the old room
Raises:
ShadowBanError if the requester is shadow-banned.
"""
old_room_pl_event_id = old_room_state.get((EventTypes.PowerLevels, ""))
@@ -829,11 +844,13 @@ class RoomCreationHandler(BaseHandler):
async def send(etype: str, content: JsonDict, **kwargs) -> int:
event = create(etype, content, **kwargs)
logger.debug("Sending %s in new room", etype)
# Allow these events to be sent even if the user is shadow-banned to
# allow the room creation to complete.
(
_,
last_stream_id,
) = await self.event_creation_handler.create_and_send_nonmember_event(
creator, event, ratelimit=False
creator, event, ratelimit=False, ignore_shadow_ban=True,
)
return last_stream_id
+42 -32
View File
@@ -17,7 +17,7 @@ import abc
import logging
import random
from http import HTTPStatus
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union
from unpaddedbase64 import encode_base64
@@ -38,7 +38,15 @@ from synapse.events.builder import create_local_event_from_event_dict
from synapse.events.snapshot import EventContext
from synapse.events.validator import EventValidator
from synapse.storage.roommember import RoomsForUser
from synapse.types import Collection, JsonDict, Requester, RoomAlias, RoomID, UserID
from synapse.types import (
Collection,
JsonDict,
Requester,
RoomAlias,
RoomID,
StateMap,
UserID,
)
from synapse.util.async_helpers import Linearizer
from synapse.util.distributor import user_joined_room, user_left_room
@@ -217,24 +225,40 @@ class RoomMemberHandler(object):
_, stream_id = await self.store.get_event_ordering(duplicate.event_id)
return duplicate.event_id, stream_id
stream_id = await self.event_creation_handler.handle_new_client_event(
requester, event, context, extra_users=[target], ratelimit=ratelimit,
)
prev_state_ids = await context.get_prev_state_ids()
prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
newly_joined = False
if event.membership == Membership.JOIN:
# Only fire user_joined_room if the user has actually joined the
# room. Don't bother if the user is just changing their profile
# info.
newly_joined = True
if prev_member_event_id:
prev_member_event = await self.store.get_event(prev_member_event_id)
newly_joined = prev_member_event.membership != Membership.JOIN
# Only rate-limit if the user actually joined the room, otherwise we'll end
# up blocking profile updates.
if newly_joined:
await self._user_joined_room(target, room_id)
time_now_s = self.clock.time()
(
allowed,
time_allowed,
) = self._join_rate_limiter_local.can_requester_do_action(requester)
if not allowed:
raise LimitExceededError(
retry_after_ms=int(1000 * (time_allowed - time_now_s))
)
stream_id = await self.event_creation_handler.handle_new_client_event(
requester, event, context, extra_users=[target], ratelimit=ratelimit,
)
if event.membership == Membership.JOIN and newly_joined:
# Only fire user_joined_room if the user has actually joined the
# room. Don't bother if the user is just changing their profile
# info.
await self._user_joined_room(target, room_id)
elif event.membership == Membership.LEAVE:
if prev_member_event_id:
prev_member_event = await self.store.get_event(prev_member_event_id)
@@ -356,7 +380,7 @@ class RoomMemberHandler(object):
# later on.
content = dict(content)
if not self.allow_per_room_profiles:
if not self.allow_per_room_profiles or requester.shadow_banned:
# Strip profile data, knowing that new profile data will be added to the
# event's content in event_creation_handler.create_event() using the target's
# global profile.
@@ -489,22 +513,12 @@ class RoomMemberHandler(object):
# so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed")
if is_host_in_room:
if not is_host_in_room:
time_now_s = self.clock.time()
allowed, time_allowed = self._join_rate_limiter_local.can_do_action(
requester.user.to_string(),
)
if not allowed:
raise LimitExceededError(
retry_after_ms=int(1000 * (time_allowed - time_now_s))
)
else:
time_now_s = self.clock.time()
allowed, time_allowed = self._join_rate_limiter_remote.can_do_action(
requester.user.to_string(),
)
(
allowed,
time_allowed,
) = self._join_rate_limiter_remote.can_requester_do_action(requester,)
if not allowed:
raise LimitExceededError(
@@ -736,9 +750,7 @@ class RoomMemberHandler(object):
if prev_member_event.membership == Membership.JOIN:
await self._user_left_room(target_user, room_id)
async def _can_guest_join(
self, current_state_ids: Dict[Tuple[str, str], str]
) -> bool:
async def _can_guest_join(self, current_state_ids: StateMap[str]) -> bool:
"""
Returns whether a guest can join a room based on its current state.
"""
@@ -967,9 +979,7 @@ class RoomMemberHandler(object):
)
return stream_id
async def _is_host_in_room(
self, current_state_ids: Dict[Tuple[str, str], str]
) -> bool:
async def _is_host_in_room(self, current_state_ids: StateMap[str]) -> bool:
# Have we just created the room, and is this about to be the very
# first member event?
create_event_id = current_state_ids.get(("m.room.create", ""))
@@ -21,9 +21,9 @@ class SlavedIdTracker(object):
self.step = step
self._current = _load_current_id(db_conn, table, column, step)
for table, column in extra_tables:
self.advance(_load_current_id(db_conn, table, column))
self.advance(None, _load_current_id(db_conn, table, column))
def advance(self, new_id):
def advance(self, instance_name, new_id):
self._current = (max if self.step > 0 else min)(self._current, new_id)
def get_current_token(self):
@@ -41,12 +41,12 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == TagAccountDataStream.NAME:
self._account_data_id_gen.advance(token)
self._account_data_id_gen.advance(instance_name, token)
for row in rows:
self.get_tags_for_user.invalidate((row.user_id,))
self._account_data_stream_cache.entity_has_changed(row.user_id, token)
elif stream_name == AccountDataStream.NAME:
self._account_data_id_gen.advance(token)
self._account_data_id_gen.advance(instance_name, token)
for row in rows:
if not row.room_id:
self.get_global_account_data_by_type_for_user.invalidate(
@@ -46,7 +46,7 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == ToDeviceStream.NAME:
self._device_inbox_id_gen.advance(token)
self._device_inbox_id_gen.advance(instance_name, token)
for row in rows:
if row.entity.startswith("@"):
self._device_inbox_stream_cache.entity_has_changed(
+2 -2
View File
@@ -50,10 +50,10 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == DeviceListsStream.NAME:
self._device_list_id_gen.advance(token)
self._device_list_id_gen.advance(instance_name, token)
self._invalidate_caches_for_devices(token, rows)
elif stream_name == UserSignatureStream.NAME:
self._device_list_id_gen.advance(token)
self._device_list_id_gen.advance(instance_name, token)
for row in rows:
self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
return super().process_replication_rows(stream_name, instance_name, token, rows)
+1 -1
View File
@@ -40,7 +40,7 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == GroupServerStream.NAME:
self._group_updates_id_gen.advance(token)
self._group_updates_id_gen.advance(instance_name, token)
for row in rows:
self._group_updates_stream_cache.entity_has_changed(row.user_id, token)
@@ -44,7 +44,7 @@ class SlavedPresenceStore(BaseSlavedStore):
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == PresenceStream.NAME:
self._presence_id_gen.advance(token)
self._presence_id_gen.advance(instance_name, token)
for row in rows:
self.presence_stream_cache.entity_has_changed(row.user_id, token)
self._get_presence_for_user.invalidate((row.user_id,))
@@ -30,7 +30,7 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
assert isinstance(self._push_rules_stream_id_gen, SlavedIdTracker)
if stream_name == PushRulesStream.NAME:
self._push_rules_stream_id_gen.advance(token)
self._push_rules_stream_id_gen.advance(instance_name, token)
for row in rows:
self.get_push_rules_for_user.invalidate((row.user_id,))
self.get_push_rules_enabled_for_user.invalidate((row.user_id,))
+1 -1
View File
@@ -34,5 +34,5 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == PushersStream.NAME:
self._pushers_id_gen.advance(token)
self._pushers_id_gen.advance(instance_name, token)
return super().process_replication_rows(stream_name, instance_name, token, rows)
@@ -46,7 +46,7 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == ReceiptsStream.NAME:
self._receipts_id_gen.advance(token)
self._receipts_id_gen.advance(instance_name, token)
for row in rows:
self.invalidate_caches_for_receipt(
row.room_id, row.receipt_type, row.user_id
+1 -1
View File
@@ -33,6 +33,6 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore):
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == PublicRoomsStream.NAME:
self._public_room_id_gen.advance(token)
self._public_room_id_gen.advance(instance_name, token)
return super().process_replication_rows(stream_name, instance_name, token, rows)
+3 -1
View File
@@ -73,6 +73,7 @@ class UsersRestServletV2(RestServlet):
The parameters `from` and `limit` are required only for pagination.
By default, a `limit` of 100 is used.
The parameter `user_id` can be used to filter by user id.
The parameter `name` can be used to filter by user id or display name.
The parameter `guests` can be used to exclude guest users.
The parameter `deactivated` can be used to include deactivated users.
"""
@@ -89,11 +90,12 @@ class UsersRestServletV2(RestServlet):
start = parse_integer(request, "from", default=0)
limit = parse_integer(request, "limit", default=100)
user_id = parse_string(request, "user_id", default=None)
name = parse_string(request, "name", default=None)
guests = parse_boolean(request, "guests", default=True)
deactivated = parse_boolean(request, "deactivated", default=False)
users, total = await self.store.get_users_paginate(
start, limit, user_id, guests, deactivated
start, limit, user_id, name, guests, deactivated
)
ret = {"users": users, "total": total}
if len(users) >= limit:
+44 -30
View File
@@ -201,8 +201,8 @@ class RoomStateEventRestServlet(TransactionRestServlet):
if state_key is not None:
event_dict["state_key"] = state_key
if event_type == EventTypes.Member:
try:
try:
if event_type == EventTypes.Member:
membership = content.get("membership", None)
event_id, _ = await self.room_member_handler.update_membership(
requester,
@@ -211,16 +211,16 @@ class RoomStateEventRestServlet(TransactionRestServlet):
action=membership,
content=content,
)
except ShadowBanError:
event_id = "$" + random_string(43)
else:
(
event,
_,
) = await self.event_creation_handler.create_and_send_nonmember_event(
requester, event_dict, txn_id=txn_id
)
event_id = event.event_id
else:
(
event,
_,
) = await self.event_creation_handler.create_and_send_nonmember_event(
requester, event_dict, txn_id=txn_id
)
event_id = event.event_id
except ShadowBanError:
event_id = "$" + random_string(43)
set_tag("event_id", event_id)
ret = {"event_id": event_id}
@@ -253,12 +253,19 @@ class RoomSendEventRestServlet(TransactionRestServlet):
if b"ts" in request.args and requester.app_service:
event_dict["origin_server_ts"] = parse_integer(request, "ts", 0)
event, _ = await self.event_creation_handler.create_and_send_nonmember_event(
requester, event_dict, txn_id=txn_id
)
try:
(
event,
_,
) = await self.event_creation_handler.create_and_send_nonmember_event(
requester, event_dict, txn_id=txn_id
)
event_id = event.event_id
except ShadowBanError:
event_id = "$" + random_string(43)
set_tag("event_id", event.event_id)
return 200, {"event_id": event.event_id}
set_tag("event_id", event_id)
return 200, {"event_id": event_id}
def on_GET(self, request, room_id, event_type, txn_id):
return 200, "Not implemented"
@@ -799,20 +806,27 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
requester = await self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request)
event, _ = await self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.Redaction,
"content": content,
"room_id": room_id,
"sender": requester.user.to_string(),
"redacts": event_id,
},
txn_id=txn_id,
)
try:
(
event,
_,
) = await self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.Redaction,
"content": content,
"room_id": room_id,
"sender": requester.user.to_string(),
"redacts": event_id,
},
txn_id=txn_id,
)
event_id = event.event_id
except ShadowBanError:
event_id = "$" + random_string(43)
set_tag("event_id", event.event_id)
return 200, {"event_id": event.event_id}
set_tag("event_id", event_id)
return 200, {"event_id": event_id}
def on_PUT(self, request, room_id, event_id, txn_id):
set_tag("txn_id", txn_id)
+10
View File
@@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
from http import HTTPStatus
from typing import TYPE_CHECKING
@@ -113,6 +114,9 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
if self.config.request_token_inhibit_3pid_errors:
# Make the client think the operation succeeded. See the rationale in the
# comments for request_token_inhibit_3pid_errors.
# Also wait for some random amount of time between 100ms and 1s to make it
# look like we did something.
await self.hs.clock.sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)}
raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND)
@@ -504,6 +508,9 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
if self.config.request_token_inhibit_3pid_errors:
# Make the client think the operation succeeded. See the rationale in the
# comments for request_token_inhibit_3pid_errors.
# Also wait for some random amount of time between 100ms and 1s to make it
# look like we did something.
await self.hs.clock.sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)}
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
@@ -572,6 +579,9 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
if self.hs.config.request_token_inhibit_3pid_errors:
# Make the client think the operation succeeded. See the rationale in the
# comments for request_token_inhibit_3pid_errors.
# Also wait for some random amount of time between 100ms and 1s to make it
# look like we did something.
await self.hs.clock.sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)}
raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE)
+7
View File
@@ -16,6 +16,7 @@
import hmac
import logging
import random
from typing import List, Union
import synapse
@@ -131,6 +132,9 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
if self.hs.config.request_token_inhibit_3pid_errors:
# Make the client think the operation succeeded. See the rationale in the
# comments for request_token_inhibit_3pid_errors.
# Also wait for some random amount of time between 100ms and 1s to make it
# look like we did something.
await self.hs.clock.sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)}
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
@@ -203,6 +207,9 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
if self.hs.config.request_token_inhibit_3pid_errors:
# Make the client think the operation succeeded. See the rationale in the
# comments for request_token_inhibit_3pid_errors.
# Also wait for some random amount of time between 100ms and 1s to make it
# look like we did something.
await self.hs.clock.sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)}
raise SynapseError(
+13 -5
View File
@@ -22,7 +22,7 @@ any time to reflect changes in the MSC.
import logging
from synapse.api.constants import EventTypes, RelationTypes
from synapse.api.errors import SynapseError
from synapse.api.errors import ShadowBanError, SynapseError
from synapse.http.servlet import (
RestServlet,
parse_integer,
@@ -35,6 +35,7 @@ from synapse.storage.relations import (
PaginationChunk,
RelationPaginationToken,
)
from synapse.util.stringutils import random_string
from ._base import client_patterns
@@ -111,11 +112,18 @@ class RelationSendServlet(RestServlet):
"sender": requester.user.to_string(),
}
event, _ = await self.event_creation_handler.create_and_send_nonmember_event(
requester, event_dict=event_dict, txn_id=txn_id
)
try:
(
event,
_,
) = await self.event_creation_handler.create_and_send_nonmember_event(
requester, event_dict=event_dict, txn_id=txn_id
)
event_id = event.event_id
except ShadowBanError:
event_id = "$" + random_string(43)
return 200, {"event_id": event.event_id}
return 200, {"event_id": event_id}
class RelationPaginationServlet(RestServlet):
@@ -15,13 +15,14 @@
import logging
from synapse.api.errors import Codes, SynapseError
from synapse.api.errors import Codes, ShadowBanError, SynapseError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_json_object_from_request,
)
from synapse.util import stringutils
from ._base import client_patterns
@@ -62,7 +63,6 @@ class RoomUpgradeRestServlet(RestServlet):
content = parse_json_object_from_request(request)
assert_params_in_dict(content, ("new_version",))
new_version = content["new_version"]
new_version = KNOWN_ROOM_VERSIONS.get(content["new_version"])
if new_version is None:
@@ -72,9 +72,13 @@ class RoomUpgradeRestServlet(RestServlet):
Codes.UNSUPPORTED_ROOM_VERSION,
)
new_room_id = await self._room_creation_handler.upgrade_room(
requester, room_id, new_version
)
try:
new_room_id = await self._room_creation_handler.upgrade_room(
requester, room_id, new_version
)
except ShadowBanError:
# Generate a random room ID.
new_room_id = stringutils.random_string(18)
ret = {"replacement_room": new_room_id}
+122 -70
View File
@@ -16,11 +16,22 @@
import logging
from collections import namedtuple
from typing import Awaitable, Dict, Iterable, List, Optional, Set
from typing import (
Awaitable,
Dict,
Iterable,
List,
Optional,
Sequence,
Set,
Union,
overload,
)
import attr
from frozendict import frozendict
from prometheus_client import Histogram
from typing_extensions import Literal
from synapse.api.constants import EventTypes
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions
@@ -30,7 +41,7 @@ from synapse.logging.utils import log_function
from synapse.state import v1, v2
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.roommember import ProfileInfo
from synapse.types import StateMap
from synapse.types import Collection, StateMap
from synapse.util import Clock
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache
@@ -68,8 +79,14 @@ def _gen_state_id():
class _StateCacheEntry(object):
__slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"]
def __init__(self, state, state_group, prev_group=None, delta_ids=None):
# dict[(str, str), str] map from (type, state_key) to event_id
def __init__(
self,
state: StateMap[str],
state_group: Optional[int],
prev_group: Optional[int] = None,
delta_ids: Optional[StateMap[str]] = None,
):
# A map from (type, state_key) to event_id.
self.state = frozendict(state)
# the ID of a state group if one and only one is involved.
@@ -107,24 +124,49 @@ class StateHandler(object):
self.hs = hs
self._state_resolution_handler = hs.get_state_resolution_handler()
@overload
async def get_current_state(
self, room_id, event_type=None, state_key="", latest_event_ids=None
):
""" Retrieves the current state for the room. This is done by
self,
room_id: str,
event_type: Literal[None] = None,
state_key: str = "",
latest_event_ids: Optional[List[str]] = None,
) -> StateMap[EventBase]:
...
@overload
async def get_current_state(
self,
room_id: str,
event_type: str,
state_key: str = "",
latest_event_ids: Optional[List[str]] = None,
) -> Optional[EventBase]:
...
async def get_current_state(
self,
room_id: str,
event_type: Optional[str] = None,
state_key: str = "",
latest_event_ids: Optional[List[str]] = None,
) -> Union[Optional[EventBase], StateMap[EventBase]]:
"""Retrieves the current state for the room. This is done by
calling `get_latest_events_in_room` to get the leading edges of the
event graph and then resolving any of the state conflicts.
This is equivalent to getting the state of an event that were to send
next before receiving any new events.
If `event_type` is specified, then the method returns only the one
event (or None) with that `event_type` and `state_key`.
Returns:
map from (type, state_key) to event
If `event_type` is specified, then the method returns only the one
event (or None) with that `event_type` and `state_key`.
Otherwise, a map from (type, state_key) to event.
"""
if not latest_event_ids:
latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
assert latest_event_ids is not None
logger.debug("calling resolve_state_groups from get_current_state")
ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
@@ -140,34 +182,30 @@ class StateHandler(object):
state_map = await self.store.get_events(
list(state.values()), get_prev_content=False
)
state = {
return {
key: state_map[e_id] for key, e_id in state.items() if e_id in state_map
}
return state
async def get_current_state_ids(self, room_id, latest_event_ids=None):
async def get_current_state_ids(
self, room_id: str, latest_event_ids: Optional[Iterable[str]] = None
) -> StateMap[str]:
"""Get the current state, or the state at a set of events, for a room
Args:
room_id (str):
latest_event_ids (iterable[str]|None): if given, the forward
extremities to resolve. If None, we look them up from the
database (via a cache)
room_id:
latest_event_ids: if given, the forward extremities to resolve. If
None, we look them up from the database (via a cache).
Returns:
Deferred[dict[(str, str), str)]]: the state dict, mapping from
(event_type, state_key) -> event_id
the state dict, mapping from (event_type, state_key) -> event_id
"""
if not latest_event_ids:
latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
assert latest_event_ids is not None
logger.debug("calling resolve_state_groups from get_current_state_ids")
ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
state = ret.state
return state
return dict(ret.state)
async def get_current_users_in_room(
self, room_id: str, latest_event_ids: Optional[List[str]] = None
@@ -183,32 +221,34 @@ class StateHandler(object):
"""
if not latest_event_ids:
latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
assert latest_event_ids is not None
logger.debug("calling resolve_state_groups from get_current_users_in_room")
entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
joined_users = await self.store.get_joined_users_from_state(room_id, entry)
return joined_users
return await self.store.get_joined_users_from_state(room_id, entry)
async def get_current_hosts_in_room(self, room_id):
async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
event_ids = await self.store.get_latest_event_ids_in_room(room_id)
return await self.get_hosts_in_room_at_events(room_id, event_ids)
async def get_hosts_in_room_at_events(self, room_id, event_ids):
async def get_hosts_in_room_at_events(
self, room_id: str, event_ids: List[str]
) -> Set[str]:
"""Get the hosts that were in a room at the given event ids
Args:
room_id (str):
event_ids (list[str]):
room_id:
event_ids:
Returns:
Deferred[list[str]]: the hosts in the room at the given events
The hosts in the room at the given events
"""
entry = await self.resolve_state_groups_for_events(room_id, event_ids)
joined_hosts = await self.store.get_joined_hosts(room_id, entry)
return joined_hosts
return await self.store.get_joined_hosts(room_id, entry)
async def compute_event_context(
self, event: EventBase, old_state: Optional[Iterable[EventBase]] = None
):
) -> EventContext:
"""Build an EventContext structure for the event.
This works out what the current state should be for the event, and
@@ -221,7 +261,7 @@ class StateHandler(object):
when receiving an event from federation where we don't have the
prev events for, e.g. when backfilling.
Returns:
synapse.events.snapshot.EventContext:
The event context.
"""
if event.internal_metadata.is_outlier():
@@ -275,7 +315,7 @@ class StateHandler(object):
event.room_id, event.prev_event_ids()
)
state_ids_before_event = entry.state
state_ids_before_event = dict(entry.state)
state_group_before_event = entry.state_group
state_group_before_event_prev_group = entry.prev_group
deltas_to_state_group_before_event = entry.delta_ids
@@ -346,19 +386,18 @@ class StateHandler(object):
)
@measure_func()
async def resolve_state_groups_for_events(self, room_id, event_ids):
async def resolve_state_groups_for_events(
self, room_id: str, event_ids: Iterable[str]
) -> _StateCacheEntry:
""" Given a list of event_ids this method fetches the state at each
event, resolves conflicts between them and returns them.
Args:
room_id (str)
event_ids (list[str])
explicit_room_version (str|None): If set uses the the given room
version to choose the resolution algorithm. If None, then
checks the database for room version.
room_id
event_ids
Returns:
Deferred[_StateCacheEntry]: resolved state
The resolved state
"""
logger.debug("resolve_state_groups event_ids %s", event_ids)
@@ -394,7 +433,12 @@ class StateHandler(object):
)
return result
async def resolve_events(self, room_version, state_sets, event):
async def resolve_events(
self,
room_version: str,
state_sets: Collection[Iterable[EventBase]],
event: EventBase,
) -> StateMap[EventBase]:
logger.info(
"Resolving state for %s with %d groups", event.room_id, len(state_sets)
)
@@ -414,9 +458,7 @@ class StateHandler(object):
state_res_store=StateResolutionStore(self.store),
)
new_state = {key: state_map[ev_id] for key, ev_id in new_state.items()}
return new_state
return {key: state_map[ev_id] for key, ev_id in new_state.items()}
class StateResolutionHandler(object):
@@ -444,7 +486,12 @@ class StateResolutionHandler(object):
@log_function
async def resolve_state_groups(
self, room_id, room_version, state_groups_ids, event_map, state_res_store
self,
room_id: str,
room_version: str,
state_groups_ids: Dict[int, StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
state_res_store: "StateResolutionStore",
):
"""Resolves conflicts between a set of state groups
@@ -452,13 +499,13 @@ class StateResolutionHandler(object):
not be called for a single state group
Args:
room_id (str): room we are resolving for (used for logging and sanity checks)
room_version (str): version of the room
state_groups_ids (dict[int, dict[(str, str), str]]):
map from state group id to the state in that state group
room_id: room we are resolving for (used for logging and sanity checks)
room_version: version of the room
state_groups_ids:
A map from state group id to the state in that state group
(where 'state' is a map from state key to event id)
event_map(dict[str,FrozenEvent]|None):
event_map:
a dict from event_id to event, for any events that we happen to
have in flight (eg, those currently being persisted). This will be
used as a starting point fof finding the state we need; any missing
@@ -466,10 +513,10 @@ class StateResolutionHandler(object):
If None, all events will be fetched via state_res_store.
state_res_store (StateResolutionStore)
state_res_store
Returns:
_StateCacheEntry: resolved state
The resolved state
"""
logger.debug("resolve_state_groups state_groups %s", state_groups_ids.keys())
@@ -530,21 +577,22 @@ class StateResolutionHandler(object):
return cache
def _make_state_cache_entry(new_state, state_groups_ids):
def _make_state_cache_entry(
new_state: StateMap[str], state_groups_ids: Dict[int, StateMap[str]]
) -> _StateCacheEntry:
"""Given a resolved state, and a set of input state groups, pick one to base
a new state group on (if any), and return an appropriately-constructed
_StateCacheEntry.
Args:
new_state (dict[(str, str), str]): resolved state map (mapping from
(type, state_key) to event_id)
new_state: resolved state map (mapping from (type, state_key) to event_id)
state_groups_ids (dict[int, dict[(str, str), str]]):
map from state group id to the state in that state group
(where 'state' is a map from state key to event id)
state_groups_ids:
map from state group id to the state in that state group (where
'state' is a map from state key to event id)
Returns:
_StateCacheEntry
The cache entry.
"""
# if the new state matches any of the input state groups, we can
# use that state group again. Otherwise we will generate a state_id
@@ -585,7 +633,7 @@ def resolve_events_with_store(
clock: Clock,
room_id: str,
room_version: str,
state_sets: List[StateMap[str]],
state_sets: Sequence[StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
state_res_store: "StateResolutionStore",
) -> Awaitable[StateMap[str]]:
@@ -633,15 +681,17 @@ class StateResolutionStore(object):
store = attr.ib()
def get_events(self, event_ids, allow_rejected=False):
def get_events(
self, event_ids: Iterable[str], allow_rejected: bool = False
) -> Awaitable[Dict[str, EventBase]]:
"""Get events from the database
Args:
event_ids (list): The event_ids of the events to fetch
allow_rejected (bool): If True return rejected events.
event_ids: The event_ids of the events to fetch
allow_rejected: If True return rejected events.
Returns:
Awaitable[dict[str, FrozenEvent]]: Dict from event_id to event.
An awaitable which resolves to a dict from event_id to event.
"""
return self.store.get_events(
@@ -651,7 +701,9 @@ class StateResolutionStore(object):
allow_rejected=allow_rejected,
)
def get_auth_chain_difference(self, state_sets: List[Set[str]]):
def get_auth_chain_difference(
self, state_sets: List[Set[str]]
) -> Awaitable[Set[str]]:
"""Given sets of state events figure out the auth chain difference (as
per state res v2 algorithm).
@@ -660,7 +712,7 @@ class StateResolutionStore(object):
chain.
Returns:
Deferred[Set[str]]: Set of event IDs.
An awaitable that resolves to a set of event IDs.
"""
return self.store.get_auth_chain_difference(state_sets)
+59 -28
View File
@@ -15,7 +15,17 @@
import hashlib
import logging
from typing import Awaitable, Callable, Dict, List, Optional
from typing import (
Awaitable,
Callable,
Dict,
Iterable,
List,
Optional,
Sequence,
Set,
Tuple,
)
from synapse import event_auth
from synapse.api.constants import EventTypes
@@ -32,10 +42,10 @@ POWER_KEY = (EventTypes.PowerLevels, "")
async def resolve_events_with_store(
room_id: str,
state_sets: List[StateMap[str]],
state_sets: Sequence[StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
state_map_factory: Callable[[List[str]], Awaitable],
):
state_map_factory: Callable[[Iterable[str]], Awaitable[Dict[str, EventBase]]],
) -> StateMap[str]:
"""
Args:
room_id: the room we are working in
@@ -56,8 +66,7 @@ async def resolve_events_with_store(
an Awaitable that resolves to a dict of event_id to event.
Returns:
Deferred[dict[(str, str), str]]:
a map from (type, state_key) to event_id.
A map from (type, state_key) to event_id.
"""
if len(state_sets) == 1:
return state_sets[0]
@@ -75,8 +84,8 @@ async def resolve_events_with_store(
"Asking for %d/%d conflicted events", len(needed_events), needed_event_count
)
# dict[str, FrozenEvent]: a map from state event id to event. Only includes
# the state events which are in conflict (and those in event_map)
# A map from state event id to event. Only includes the state events which
# are in conflict (and those in event_map).
state_map = await state_map_factory(needed_events)
if event_map is not None:
state_map.update(event_map)
@@ -91,8 +100,6 @@ async def resolve_events_with_store(
# get the ids of the auth events which allow us to authenticate the
# conflicted state, picking only from the unconflicting state.
#
# dict[(str, str), str]: a map from state key to event id
auth_events = _create_auth_events_from_maps(
unconflicted_state, conflicted_state, state_map
)
@@ -122,29 +129,30 @@ async def resolve_events_with_store(
)
def _seperate(state_sets):
def _seperate(
state_sets: Iterable[StateMap[str]],
) -> Tuple[StateMap[str], StateMap[Set[str]]]:
"""Takes the state_sets and figures out which keys are conflicted and
which aren't. i.e., which have multiple different event_ids associated
with them in different state sets.
Args:
state_sets(iterable[dict[(str, str), str]]):
state_sets:
List of dicts of (type, state_key) -> event_id, which are the
different state groups to resolve.
Returns:
(dict[(str, str), str], dict[(str, str), set[str]]):
A tuple of (unconflicted_state, conflicted_state), where:
A tuple of (unconflicted_state, conflicted_state), where:
unconflicted_state is a dict mapping (type, state_key)->event_id
for unconflicted state keys.
unconflicted_state is a dict mapping (type, state_key)->event_id
for unconflicted state keys.
conflicted_state is a dict mapping (type, state_key) to a set of
event ids for conflicted state keys.
conflicted_state is a dict mapping (type, state_key) to a set of
event ids for conflicted state keys.
"""
state_set_iterator = iter(state_sets)
unconflicted_state = dict(next(state_set_iterator))
conflicted_state = {}
conflicted_state = {} # type: StateMap[Set[str]]
for state_set in state_set_iterator:
for key, value in state_set.items():
@@ -171,7 +179,21 @@ def _seperate(state_sets):
return unconflicted_state, conflicted_state
def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map):
def _create_auth_events_from_maps(
unconflicted_state: StateMap[str],
conflicted_state: StateMap[Set[str]],
state_map: Dict[str, EventBase],
) -> StateMap[str]:
"""
Args:
unconflicted_state: The unconflicted state map.
conflicted_state: The conflicted state map.
state_map:
Returns:
A map from state key to event id.
"""
auth_events = {}
for event_ids in conflicted_state.values():
for event_id in event_ids:
@@ -179,14 +201,17 @@ def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_ma
keys = event_auth.auth_types_for_event(state_map[event_id])
for key in keys:
if key not in auth_events:
event_id = unconflicted_state.get(key, None)
if event_id:
auth_events[key] = event_id
auth_event_id = unconflicted_state.get(key, None)
if auth_event_id:
auth_events[key] = auth_event_id
return auth_events
def _resolve_with_state(
unconflicted_state_ids, conflicted_state_ids, auth_event_ids, state_map
unconflicted_state_ids: StateMap[str],
conflicted_state_ids: StateMap[Set[str]],
auth_event_ids: StateMap[str],
state_map: Dict[str, EventBase],
):
conflicted_state = {}
for key, event_ids in conflicted_state_ids.items():
@@ -215,7 +240,9 @@ def _resolve_with_state(
return new_state
def _resolve_state_events(conflicted_state, auth_events):
def _resolve_state_events(
conflicted_state: StateMap[List[EventBase]], auth_events: StateMap[EventBase]
) -> StateMap[EventBase]:
""" This is where we actually decide which of the conflicted state to
use.
@@ -255,7 +282,9 @@ def _resolve_state_events(conflicted_state, auth_events):
return resolved_state
def _resolve_auth_events(events, auth_events):
def _resolve_auth_events(
events: List[EventBase], auth_events: StateMap[EventBase]
) -> EventBase:
reverse = list(reversed(_ordered_events(events)))
auth_keys = {
@@ -289,7 +318,9 @@ def _resolve_auth_events(events, auth_events):
return event
def _resolve_normal_events(events, auth_events):
def _resolve_normal_events(
events: List[EventBase], auth_events: StateMap[EventBase]
) -> EventBase:
for event in _ordered_events(events):
try:
# The signatures have already been checked at this point
@@ -309,7 +340,7 @@ def _resolve_normal_events(events, auth_events):
return event
def _ordered_events(events):
def _ordered_events(events: Iterable[EventBase]) -> List[EventBase]:
def key_func(e):
# we have to use utf-8 rather than ascii here because it turns out we allow
# people to send us events with non-ascii event IDs :/
+167 -88
View File
@@ -16,7 +16,21 @@
import heapq
import itertools
import logging
from typing import Dict, List, Optional
from typing import (
Any,
Callable,
Dict,
Generator,
Iterable,
List,
Optional,
Sequence,
Set,
Tuple,
overload,
)
from typing_extensions import Literal
import synapse.state
from synapse import event_auth
@@ -40,10 +54,10 @@ async def resolve_events_with_store(
clock: Clock,
room_id: str,
room_version: str,
state_sets: List[StateMap[str]],
state_sets: Sequence[StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
state_res_store: "synapse.state.StateResolutionStore",
):
) -> StateMap[str]:
"""Resolves the state using the v2 state resolution algorithm
Args:
@@ -63,8 +77,7 @@ async def resolve_events_with_store(
state_res_store:
Returns:
Deferred[dict[(str, str), str]]:
a map from (type, state_key) to event_id.
A map from (type, state_key) to event_id.
"""
logger.debug("Computing conflicted state")
@@ -171,18 +184,23 @@ async def resolve_events_with_store(
return resolved_state
async def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store):
async def _get_power_level_for_sender(
room_id: str,
event_id: str,
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
) -> int:
"""Return the power level of the sender of the given event according to
their auth events.
Args:
room_id (str)
event_id (str)
event_map (dict[str,FrozenEvent])
state_res_store (StateResolutionStore)
room_id
event_id
event_map
state_res_store
Returns:
Deferred[int]
The power level.
"""
event = await _get_event(room_id, event_id, event_map, state_res_store)
@@ -217,17 +235,21 @@ async def _get_power_level_for_sender(room_id, event_id, event_map, state_res_st
return int(level)
async def _get_auth_chain_difference(state_sets, event_map, state_res_store):
async def _get_auth_chain_difference(
state_sets: Sequence[StateMap[str]],
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
) -> Set[str]:
"""Compare the auth chains of each state set and return the set of events
that only appear in some but not all of the auth chains.
Args:
state_sets (list)
event_map (dict[str,FrozenEvent])
state_res_store (StateResolutionStore)
state_sets
event_map
state_res_store
Returns:
Deferred[set[str]]: Set of event IDs
Set of event IDs
"""
difference = await state_res_store.get_auth_chain_difference(
@@ -237,17 +259,19 @@ async def _get_auth_chain_difference(state_sets, event_map, state_res_store):
return difference
def _seperate(state_sets):
def _seperate(
state_sets: Iterable[StateMap[str]],
) -> Tuple[StateMap[str], StateMap[Set[str]]]:
"""Return the unconflicted and conflicted state. This is different than in
the original algorithm, as this defines a key to be conflicted if one of
the state sets doesn't have that key.
Args:
state_sets (list)
state_sets
Returns:
tuple[dict, dict]: A tuple of unconflicted and conflicted state. The
conflicted state dict is a map from type/state_key to set of event IDs
A tuple of unconflicted and conflicted state. The conflicted state dict
is a map from type/state_key to set of event IDs
"""
unconflicted_state = {}
conflicted_state = {}
@@ -260,18 +284,20 @@ def _seperate(state_sets):
event_ids.discard(None)
conflicted_state[key] = event_ids
return unconflicted_state, conflicted_state
# mypy doesn't understand that discarding None above means that conflicted
# state is StateMap[Set[str]], not StateMap[Set[Optional[Str]]].
return unconflicted_state, conflicted_state # type: ignore
def _is_power_event(event):
def _is_power_event(event: EventBase) -> bool:
"""Return whether or not the event is a "power event", as defined by the
v2 state resolution algorithm
Args:
event (FrozenEvent)
event
Returns:
boolean
True if the event is a power event.
"""
if (event.type, event.state_key) in (
(EventTypes.PowerLevels, ""),
@@ -288,19 +314,23 @@ def _is_power_event(event):
async def _add_event_and_auth_chain_to_graph(
graph, room_id, event_id, event_map, state_res_store, auth_diff
):
graph: Dict[str, Set[str]],
room_id: str,
event_id: str,
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
auth_diff: Set[str],
) -> None:
"""Helper function for _reverse_topological_power_sort that add the event
and its auth chain (that is in the auth diff) to the graph
Args:
graph (dict[str, set[str]]): A map from event ID to the events auth
event IDs
room_id (str): the room we are working in
event_id (str): Event to add to the graph
event_map (dict[str,FrozenEvent])
state_res_store (StateResolutionStore)
auth_diff (set[str]): Set of event IDs that are in the auth difference.
graph: A map from event ID to the events auth event IDs
room_id: the room we are working in
event_id: Event to add to the graph
event_map
state_res_store
auth_diff: Set of event IDs that are in the auth difference.
"""
state = [event_id]
@@ -318,24 +348,29 @@ async def _add_event_and_auth_chain_to_graph(
async def _reverse_topological_power_sort(
clock, room_id, event_ids, event_map, state_res_store, auth_diff
):
clock: Clock,
room_id: str,
event_ids: Iterable[str],
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
auth_diff: Set[str],
) -> List[str]:
"""Returns a list of the event_ids sorted by reverse topological ordering,
and then by power level and origin_server_ts
Args:
clock (Clock)
room_id (str): the room we are working in
event_ids (list[str]): The events to sort
event_map (dict[str,FrozenEvent])
state_res_store (StateResolutionStore)
auth_diff (set[str]): Set of event IDs that are in the auth difference.
clock
room_id: the room we are working in
event_ids: The events to sort
event_map
state_res_store
auth_diff: Set of event IDs that are in the auth difference.
Returns:
Deferred[list[str]]: The sorted list
The sorted list
"""
graph = {}
graph = {} # type: Dict[str, Set[str]]
for idx, event_id in enumerate(event_ids, start=1):
await _add_event_and_auth_chain_to_graph(
graph, room_id, event_id, event_map, state_res_store, auth_diff
@@ -372,22 +407,28 @@ async def _reverse_topological_power_sort(
async def _iterative_auth_checks(
clock, room_id, room_version, event_ids, base_state, event_map, state_res_store
):
clock: Clock,
room_id: str,
room_version: str,
event_ids: List[str],
base_state: StateMap[str],
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
) -> StateMap[str]:
"""Sequentially apply auth checks to each event in given list, updating the
state as it goes along.
Args:
clock (Clock)
room_id (str)
room_version (str)
event_ids (list[str]): Ordered list of events to apply auth checks to
base_state (StateMap[str]): The set of state to start with
event_map (dict[str,FrozenEvent])
state_res_store (StateResolutionStore)
clock
room_id
room_version
event_ids: Ordered list of events to apply auth checks to
base_state: The set of state to start with
event_map
state_res_store
Returns:
Deferred[StateMap[str]]: Returns the final updated state
Returns the final updated state
"""
resolved_state = base_state.copy()
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
@@ -439,21 +480,26 @@ async def _iterative_auth_checks(
async def _mainline_sort(
clock, room_id, event_ids, resolved_power_event_id, event_map, state_res_store
):
clock: Clock,
room_id: str,
event_ids: List[str],
resolved_power_event_id: Optional[str],
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
) -> List[str]:
"""Returns a sorted list of event_ids sorted by mainline ordering based on
the given event resolved_power_event_id
Args:
clock (Clock)
room_id (str): room we're working in
event_ids (list[str]): Events to sort
resolved_power_event_id (str): The final resolved power level event ID
event_map (dict[str,FrozenEvent])
state_res_store (StateResolutionStore)
clock
room_id: room we're working in
event_ids: Events to sort
resolved_power_event_id: The final resolved power level event ID
event_map
state_res_store
Returns:
Deferred[list[str]]: The sorted list
The sorted list
"""
if not event_ids:
# It's possible for there to be no event IDs here to sort, so we can
@@ -505,59 +551,90 @@ async def _mainline_sort(
async def _get_mainline_depth_for_event(
event, mainline_map, event_map, state_res_store
):
event: EventBase,
mainline_map: Dict[str, int],
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
) -> int:
"""Get the mainline depths for the given event based on the mainline map
Args:
event (FrozenEvent)
mainline_map (dict[str, int]): Map from event_id to mainline depth for
events in the mainline.
event_map (dict[str,FrozenEvent])
state_res_store (StateResolutionStore)
event
mainline_map: Map from event_id to mainline depth for events in the mainline.
event_map
state_res_store
Returns:
Deferred[int]
The mainline depth
"""
room_id = event.room_id
tmp_event = event # type: Optional[EventBase]
# We do an iterative search, replacing `event with the power level in its
# auth events (if any)
while event:
while tmp_event:
depth = mainline_map.get(event.event_id)
if depth is not None:
return depth
auth_events = event.auth_event_ids()
event = None
auth_events = tmp_event.auth_event_ids()
tmp_event = None
for aid in auth_events:
aev = await _get_event(
room_id, aid, event_map, state_res_store, allow_none=True
)
if aev and (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""):
event = aev
tmp_event = aev
break
# Didn't find a power level auth event, so we just return 0
return 0
async def _get_event(room_id, event_id, event_map, state_res_store, allow_none=False):
@overload
async def _get_event(
room_id: str,
event_id: str,
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
allow_none: Literal[False] = False,
) -> EventBase:
...
@overload
async def _get_event(
room_id: str,
event_id: str,
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
allow_none: Literal[True],
) -> Optional[EventBase]:
...
async def _get_event(
room_id: str,
event_id: str,
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
allow_none: bool = False,
) -> Optional[EventBase]:
"""Helper function to look up event in event_map, falling back to looking
it up in the store
Args:
room_id (str)
event_id (str)
event_map (dict[str,FrozenEvent])
state_res_store (StateResolutionStore)
allow_none (bool): if the event is not found, return None rather than raising
room_id
event_id
event_map
state_res_store
allow_none: if the event is not found, return None rather than raising
an exception
Returns:
Deferred[Optional[FrozenEvent]]
The event, or none if the event does not exist (and allow_none is True).
"""
if event_id not in event_map:
events = await state_res_store.get_events([event_id], allow_rejected=True)
@@ -577,7 +654,9 @@ async def _get_event(room_id, event_id, event_map, state_res_store, allow_none=F
return event
def lexicographical_topological_sort(graph, key):
def lexicographical_topological_sort(
graph: Dict[str, Set[str]], key: Callable[[str], Any]
) -> Generator[str, None, None]:
"""Performs a lexicographic reverse topological sort on the graph.
This returns a reverse topological sort (i.e. if node A references B then B
@@ -587,20 +666,20 @@ def lexicographical_topological_sort(graph, key):
NOTE: `graph` is modified during the sort.
Args:
graph (dict[str, set[str]]): A representation of the graph where each
node is a key in the dict and its value are the nodes edges.
key (func): A function that takes a node and returns a value that is
comparable and used to order nodes
graph: A representation of the graph where each node is a key in the
dict and its value are the nodes edges.
key: A function that takes a node and returns a value that is comparable
and used to order nodes
Yields:
str: The next node in the topological sort
The next node in the topological sort
"""
# Note, this is basically Kahn's algorithm except we look at nodes with no
# outgoing edges, c.f.
# https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm
outdegree_map = graph
reverse_graph = {}
reverse_graph = {} # type: Dict[str, Set[str]]
# Lists of nodes with zero out degree. Is actually a tuple of
# `(key(node), node)` so that sorting does the right thing
+30 -6
View File
@@ -29,9 +29,11 @@ from typing import (
Tuple,
TypeVar,
Union,
overload,
)
from prometheus_client import Histogram
from typing_extensions import Literal
from twisted.enterprise import adbapi
from twisted.internet import defer
@@ -1020,14 +1022,36 @@ class DatabasePool(object):
return txn.execute_batch(sql, args)
def simple_select_one(
@overload
async def simple_select_one(
self,
table: str,
keyvalues: Dict[str, Any],
retcols: Iterable[str],
allow_none: Literal[False] = False,
desc: str = "simple_select_one",
) -> Dict[str, Any]:
...
@overload
async def simple_select_one(
self,
table: str,
keyvalues: Dict[str, Any],
retcols: Iterable[str],
allow_none: Literal[True] = True,
desc: str = "simple_select_one",
) -> Optional[Dict[str, Any]]:
...
async def simple_select_one(
self,
table: str,
keyvalues: Dict[str, Any],
retcols: Iterable[str],
allow_none: bool = False,
desc: str = "simple_select_one",
) -> defer.Deferred:
) -> Optional[Dict[str, Any]]:
"""Executes a SELECT query on the named table, which is expected to
return a single row, returning multiple columns from it.
@@ -1038,18 +1062,18 @@ class DatabasePool(object):
allow_none: If true, return None instead of failing if the SELECT
statement returns no rows
"""
return self.runInteraction(
return await self.runInteraction(
desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none
)
def simple_select_one_onecol(
async def simple_select_one_onecol(
self,
table: str,
keyvalues: Dict[str, Any],
retcol: Iterable[str],
allow_none: bool = False,
desc: str = "simple_select_one_onecol",
) -> defer.Deferred:
) -> Optional[Any]:
"""Executes a SELECT query on the named table, which is expected to
return a single row, returning a single column from it.
@@ -1061,7 +1085,7 @@ class DatabasePool(object):
statement returns no rows
desc: description of the transaction, for logging and metrics
"""
return self.runInteraction(
return await self.runInteraction(
desc,
self.simple_select_one_onecol_txn,
table,
+19 -12
View File
@@ -498,7 +498,7 @@ class DataStore(
)
def get_users_paginate(
self, start, limit, name=None, guests=True, deactivated=False
self, start, limit, user_id=None, name=None, guests=True, deactivated=False
):
"""Function to retrieve a paginated list of users from
users list. This will return a json list of users and the
@@ -507,7 +507,8 @@ class DataStore(
Args:
start (int): start number to begin the query from
limit (int): number of rows to retrieve
name (string): filter for user names
user_id (string): search for user_id. ignored if name is not None
name (string): search for local part of user_id or display name
guests (bool): whether to in include guest users
deactivated (bool): whether to include deactivated users
Returns:
@@ -516,11 +517,14 @@ class DataStore(
def get_users_paginate_txn(txn):
filters = []
args = []
args = [self.hs.config.server_name]
if name:
filters.append("(name LIKE ? OR displayname LIKE ?)")
args.extend(["@%" + name + "%:%", "%" + name + "%"])
elif user_id:
filters.append("name LIKE ?")
args.append("%" + name + "%")
args.extend(["%" + user_id + "%"])
if not guests:
filters.append("is_guest = 0")
@@ -530,20 +534,23 @@ class DataStore(
where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else ""
sql = "SELECT COUNT(*) as total_users FROM users %s" % (where_clause)
txn.execute(sql, args)
count = txn.fetchone()[0]
args = [self.hs.config.server_name] + args + [limit, start]
sql = """
SELECT name, user_type, is_guest, admin, deactivated, displayname, avatar_url
sql_base = """
FROM users as u
LEFT JOIN profiles AS p ON u.name = '@' || p.user_id || ':' || ?
{}
ORDER BY u.name LIMIT ? OFFSET ?
""".format(
where_clause
)
sql = "SELECT COUNT(*) as total_users " + sql_base
txn.execute(sql, args)
count = txn.fetchone()[0]
sql = (
"SELECT name, user_type, is_guest, admin, deactivated, displayname, avatar_url "
+ sql_base
+ " ORDER BY u.name LIMIT ? OFFSET ?"
)
args += [limit, start]
txn.execute(sql, args)
users = self.db_pool.cursor_to_dict(txn)
return users, count
@@ -336,7 +336,7 @@ class AccountDataStore(AccountDataWorkerStore):
"""
content_json = json_encoder.encode(content)
with self._account_data_id_gen.get_next() as next_id:
with await self._account_data_id_gen.get_next() as next_id:
# no need to lock here as room_account_data has a unique constraint
# on (user_id, room_id, account_data_type) so simple_upsert will
# retry if there is a conflict.
@@ -384,7 +384,7 @@ class AccountDataStore(AccountDataWorkerStore):
"""
content_json = json_encoder.encode(content)
with self._account_data_id_gen.get_next() as next_id:
with await self._account_data_id_gen.get_next() as next_id:
# no need to lock here as account_data has a unique constraint on
# (user_id, account_data_type) so simple_upsert will retry if
# there is a conflict.
@@ -362,7 +362,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
rows.append((destination, stream_id, now_ms, edu_json))
txn.executemany(sql, rows)
with self._device_inbox_id_gen.get_next() as stream_id:
with await self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec()
await self.db_pool.runInteraction(
"add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
@@ -411,7 +411,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
txn, stream_id, local_messages_by_user_then_device
)
with self._device_inbox_id_gen.get_next() as stream_id:
with await self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec()
await self.db_pool.runInteraction(
"add_messages_from_remote_to_device_inbox",
+13 -9
View File
@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Dict, Iterable, List, Optional, Set, Tuple
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
from synapse.api.errors import Codes, StoreError
from synapse.logging.opentracing import (
@@ -47,7 +47,7 @@ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
class DeviceWorkerStore(SQLBaseStore):
def get_device(self, user_id: str, device_id: str):
async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]:
"""Retrieve a device. Only returns devices that are not marked as
hidden.
@@ -55,11 +55,11 @@ class DeviceWorkerStore(SQLBaseStore):
user_id: The ID of the user which owns the device
device_id: The ID of the device to retrieve
Returns:
defer.Deferred for a dict containing the device information
A dict containing the device information
Raises:
StoreError: if the device is not found
"""
return self.db_pool.simple_select_one(
return await self.db_pool.simple_select_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
retcols=("user_id", "device_id", "display_name"),
@@ -380,7 +380,7 @@ class DeviceWorkerStore(SQLBaseStore):
THe new stream ID.
"""
with self._device_list_id_gen.get_next() as stream_id:
with await self._device_list_id_gen.get_next() as stream_id:
await self.db_pool.runInteraction(
"add_user_sig_change_to_streams",
self._add_user_signature_change_txn,
@@ -656,11 +656,13 @@ class DeviceWorkerStore(SQLBaseStore):
)
@cached(max_entries=10000)
def get_device_list_last_stream_id_for_remote(self, user_id: str):
async def get_device_list_last_stream_id_for_remote(
self, user_id: str
) -> Optional[Any]:
"""Get the last stream_id we got for a user. May be None if we haven't
got any information for them.
"""
return self.db_pool.simple_select_one_onecol(
return await self.db_pool.simple_select_one_onecol(
table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id},
retcol="stream_id",
@@ -1146,7 +1148,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
if not device_ids:
return
with self._device_list_id_gen.get_next_mult(len(device_ids)) as stream_ids:
with await self._device_list_id_gen.get_next_mult(
len(device_ids)
) as stream_ids:
await self.db_pool.runInteraction(
"add_device_change_to_stream",
self._add_device_change_to_stream_txn,
@@ -1159,7 +1163,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
return stream_ids[-1]
context = get_active_span_text_map()
with self._device_list_id_gen.get_next_mult(
with await self._device_list_id_gen.get_next_mult(
len(hosts) * len(device_ids)
) as stream_ids:
await self.db_pool.runInteraction(
+2 -2
View File
@@ -59,8 +59,8 @@ class DirectoryWorkerStore(SQLBaseStore):
return RoomAliasMapping(room_id, room_alias.to_string(), servers)
def get_room_alias_creator(self, room_alias):
return self.db_pool.simple_select_one_onecol(
async def get_room_alias_creator(self, room_alias: str) -> str:
return await self.db_pool.simple_select_one_onecol(
table="room_aliases",
keyvalues={"room_alias": room_alias},
retcol="creator",
@@ -223,15 +223,15 @@ class EndToEndRoomKeyStore(SQLBaseStore):
return ret
def count_e2e_room_keys(self, user_id, version):
async def count_e2e_room_keys(self, user_id: str, version: str) -> int:
"""Get the number of keys in a backup version.
Args:
user_id (str): the user whose backup we're querying
version (str): the version ID of the backup we're querying about
user_id: the user whose backup we're querying
version: the version ID of the backup we're querying about
"""
return self.db_pool.simple_select_one_onecol(
return await self.db_pool.simple_select_one_onecol(
table="e2e_room_keys",
keyvalues={"user_id": user_id, "version": version},
retcol="COUNT(*)",
@@ -648,7 +648,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
"delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
)
def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key):
def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key, stream_id):
"""Set a user's cross-signing key.
Args:
@@ -658,6 +658,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
for a master key, 'self_signing' for a self-signing key, or
'user_signing' for a user-signing key
key (dict): the key data
stream_id (int)
"""
# the 'key' dict will look something like:
# {
@@ -695,23 +696,22 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
)
# and finally, store the key itself
with self._cross_signing_id_gen.get_next() as stream_id:
self.db_pool.simple_insert_txn(
txn,
"e2e_cross_signing_keys",
values={
"user_id": user_id,
"keytype": key_type,
"keydata": json_encoder.encode(key),
"stream_id": stream_id,
},
)
self.db_pool.simple_insert_txn(
txn,
"e2e_cross_signing_keys",
values={
"user_id": user_id,
"keytype": key_type,
"keydata": json_encoder.encode(key),
"stream_id": stream_id,
},
)
self._invalidate_cache_and_stream(
txn, self._get_bare_e2e_cross_signing_keys, (user_id,)
)
def set_e2e_cross_signing_key(self, user_id, key_type, key):
async def set_e2e_cross_signing_key(self, user_id, key_type, key):
"""Set a user's cross-signing key.
Args:
@@ -719,13 +719,16 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
key_type (str): the type of cross-signing key to set
key (dict): the key data
"""
return self.db_pool.runInteraction(
"add_e2e_cross_signing_key",
self._set_e2e_cross_signing_key_txn,
user_id,
key_type,
key,
)
with await self._cross_signing_id_gen.get_next() as stream_id:
return await self.db_pool.runInteraction(
"add_e2e_cross_signing_key",
self._set_e2e_cross_signing_key_txn,
user_id,
key_type,
key,
stream_id,
)
def store_e2e_cross_signing_signatures(self, user_id, signatures):
"""Stores cross-signing signatures.
+2 -2
View File
@@ -153,11 +153,11 @@ class PersistEventsStore:
# Note: Multiple instances of this function cannot be in flight at
# the same time for the same room.
if backfilled:
stream_ordering_manager = self._backfill_id_gen.get_next_mult(
stream_ordering_manager = await self._backfill_id_gen.get_next_mult(
len(events_and_contexts)
)
else:
stream_ordering_manager = self._stream_id_gen.get_next_mult(
stream_ordering_manager = await self._stream_id_gen.get_next_mult(
len(events_and_contexts)
)
@@ -113,25 +113,25 @@ class EventsWorkerStore(SQLBaseStore):
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == EventsStream.NAME:
self._stream_id_gen.advance(token)
self._stream_id_gen.advance(instance_name, token)
elif stream_name == BackfillStream.NAME:
self._backfill_id_gen.advance(-token)
self._backfill_id_gen.advance(instance_name, -token)
super().process_replication_rows(stream_name, instance_name, token, rows)
def get_received_ts(self, event_id):
async def get_received_ts(self, event_id: str) -> Optional[int]:
"""Get received_ts (when it was persisted) for the event.
Raises an exception for unknown events.
Args:
event_id (str)
event_id: The event ID to query.
Returns:
Deferred[int|None]: Timestamp in milliseconds, or None for events
that were persisted before received_ts was implemented.
Timestamp in milliseconds, or None for events that were persisted
before received_ts was implemented.
"""
return self.db_pool.simple_select_one_onecol(
return await self.db_pool.simple_select_one_onecol(
table="events",
keyvalues={"event_id": event_id},
retcol="received_ts",
+12 -8
View File
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple
from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json
@@ -28,8 +28,8 @@ _DEFAULT_ROLE_ID = ""
class GroupServerWorkerStore(SQLBaseStore):
def get_group(self, group_id):
return self.db_pool.simple_select_one(
async def get_group(self, group_id: str) -> Optional[Dict[str, Any]]:
return await self.db_pool.simple_select_one(
table="groups",
keyvalues={"group_id": group_id},
retcols=(
@@ -351,8 +351,10 @@ class GroupServerWorkerStore(SQLBaseStore):
)
return bool(result)
def is_user_admin_in_group(self, group_id, user_id):
return self.db_pool.simple_select_one_onecol(
async def is_user_admin_in_group(
self, group_id: str, user_id: str
) -> Optional[bool]:
return await self.db_pool.simple_select_one_onecol(
table="group_users",
keyvalues={"group_id": group_id, "user_id": user_id},
retcol="is_admin",
@@ -360,10 +362,12 @@ class GroupServerWorkerStore(SQLBaseStore):
desc="is_user_admin_in_group",
)
def is_user_invited_to_local_group(self, group_id, user_id):
async def is_user_invited_to_local_group(
self, group_id: str, user_id: str
) -> Optional[bool]:
"""Has the group server invited a user?
"""
return self.db_pool.simple_select_one_onecol(
return await self.db_pool.simple_select_one_onecol(
table="group_invites",
keyvalues={"group_id": group_id, "user_id": user_id},
retcol="user_id",
@@ -1182,7 +1186,7 @@ class GroupServerStore(GroupServerWorkerStore):
return next_id
with self._group_updates_id_gen.get_next() as next_id:
with await self._group_updates_id_gen.get_next() as next_id:
res = await self.db_pool.runInteraction(
"register_user_group_membership",
_register_user_group_membership_txn,
@@ -12,6 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
@@ -37,12 +39,13 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super(MediaRepositoryStore, self).__init__(database, db_conn, hs)
def get_local_media(self, media_id):
async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]:
"""Get the metadata for a local piece of media
Returns:
None if the media_id doesn't exist.
"""
return self.db_pool.simple_select_one(
return await self.db_pool.simple_select_one(
"local_media_repository",
{"media_id": media_id},
(
@@ -191,8 +194,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="store_local_thumbnail",
)
def get_cached_remote_media(self, origin, media_id):
return self.db_pool.simple_select_one(
async def get_cached_remote_media(
self, origin, media_id: str
) -> Optional[Dict[str, Any]]:
return await self.db_pool.simple_select_one(
"remote_media_cache",
{"media_origin": origin, "media_id": media_id},
(
@@ -99,17 +99,18 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
return users
@cached(num_args=1)
def user_last_seen_monthly_active(self, user_id):
async def user_last_seen_monthly_active(self, user_id: str) -> int:
"""
Checks if a given user is part of the monthly active user group
Arguments:
user_id (str): user to add/update
Return:
Deferred[int] : timestamp since last seen, None if never seen
Checks if a given user is part of the monthly active user group
Arguments:
user_id: user to add/update
Return:
Timestamp since last seen, None if never seen
"""
return self.db_pool.simple_select_one_onecol(
return await self.db_pool.simple_select_one_onecol(
table="monthly_active_users",
keyvalues={"user_id": user_id},
retcol="timestamp",
+1 -1
View File
@@ -23,7 +23,7 @@ from synapse.util.iterutils import batch_iter
class PresenceStore(SQLBaseStore):
async def update_presence(self, presence_states):
stream_ordering_manager = self._presence_id_gen.get_next_mult(
stream_ordering_manager = await self._presence_id_gen.get_next_mult(
len(presence_states)
)
+10 -7
View File
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional
from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore
@@ -19,7 +20,7 @@ from synapse.storage.databases.main.roommember import ProfileInfo
class ProfileWorkerStore(SQLBaseStore):
async def get_profileinfo(self, user_localpart):
async def get_profileinfo(self, user_localpart: str) -> ProfileInfo:
try:
profile = await self.db_pool.simple_select_one(
table="profiles",
@@ -38,24 +39,26 @@ class ProfileWorkerStore(SQLBaseStore):
avatar_url=profile["avatar_url"], display_name=profile["displayname"]
)
def get_profile_displayname(self, user_localpart):
return self.db_pool.simple_select_one_onecol(
async def get_profile_displayname(self, user_localpart: str) -> str:
return await self.db_pool.simple_select_one_onecol(
table="profiles",
keyvalues={"user_id": user_localpart},
retcol="displayname",
desc="get_profile_displayname",
)
def get_profile_avatar_url(self, user_localpart):
return self.db_pool.simple_select_one_onecol(
async def get_profile_avatar_url(self, user_localpart: str) -> str:
return await self.db_pool.simple_select_one_onecol(
table="profiles",
keyvalues={"user_id": user_localpart},
retcol="avatar_url",
desc="get_profile_avatar_url",
)
def get_from_remote_profile_cache(self, user_id):
return self.db_pool.simple_select_one(
async def get_from_remote_profile_cache(
self, user_id: str
) -> Optional[Dict[str, Any]]:
return await self.db_pool.simple_select_one(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
retcols=("displayname", "avatar_url"),
+4 -4
View File
@@ -338,7 +338,7 @@ class PushRuleStore(PushRulesWorkerStore):
) -> None:
conditions_json = json_encoder.encode(conditions)
actions_json = json_encoder.encode(actions)
with self._push_rules_stream_id_gen.get_next() as stream_id:
with await self._push_rules_stream_id_gen.get_next() as stream_id:
event_stream_ordering = self._stream_id_gen.get_current_token()
if before or after:
@@ -560,7 +560,7 @@ class PushRuleStore(PushRulesWorkerStore):
txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE"
)
with self._push_rules_stream_id_gen.get_next() as stream_id:
with await self._push_rules_stream_id_gen.get_next() as stream_id:
event_stream_ordering = self._stream_id_gen.get_current_token()
await self.db_pool.runInteraction(
@@ -571,7 +571,7 @@ class PushRuleStore(PushRulesWorkerStore):
)
async def set_push_rule_enabled(self, user_id, rule_id, enabled) -> None:
with self._push_rules_stream_id_gen.get_next() as stream_id:
with await self._push_rules_stream_id_gen.get_next() as stream_id:
event_stream_ordering = self._stream_id_gen.get_current_token()
await self.db_pool.runInteraction(
@@ -646,7 +646,7 @@ class PushRuleStore(PushRulesWorkerStore):
data={"actions": actions_json},
)
with self._push_rules_stream_id_gen.get_next() as stream_id:
with await self._push_rules_stream_id_gen.get_next() as stream_id:
event_stream_ordering = self._stream_id_gen.get_current_token()
await self.db_pool.runInteraction(
+2 -2
View File
@@ -281,7 +281,7 @@ class PusherStore(PusherWorkerStore):
last_stream_ordering,
profile_tag="",
) -> None:
with self._pushers_id_gen.get_next() as stream_id:
with await self._pushers_id_gen.get_next() as stream_id:
# no need to lock because `pushers` has a unique key on
# (app_id, pushkey, user_name) so simple_upsert will retry
await self.db_pool.simple_upsert(
@@ -344,7 +344,7 @@ class PusherStore(PusherWorkerStore):
},
)
with self._pushers_id_gen.get_next() as stream_id:
with await self._pushers_id_gen.get_next() as stream_id:
await self.db_pool.runInteraction(
"delete_pusher", delete_pusher_txn, stream_id
)
+5 -4
View File
@@ -71,8 +71,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
)
@cached(num_args=3)
def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type):
return self.db_pool.simple_select_one_onecol(
async def get_last_receipt_event_id_for_user(
self, user_id: str, room_id: str, receipt_type: str
) -> Optional[str]:
return await self.db_pool.simple_select_one_onecol(
table="receipts_linearized",
keyvalues={
"room_id": room_id,
@@ -520,8 +522,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
"insert_receipt_conv", graph_to_linear
)
stream_id_manager = self._receipts_id_gen.get_next()
with stream_id_manager as stream_id:
with await self._receipts_id_gen.get_next() as stream_id:
event_ts = await self.db_pool.runInteraction(
"insert_linearized_receipt",
self.insert_linearized_receipt_txn,
+24 -11
View File
@@ -17,7 +17,7 @@
import logging
import re
from typing import Awaitable, Dict, List, Optional
from typing import Any, Awaitable, Dict, List, Optional
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
@@ -46,8 +46,8 @@ class RegistrationWorkerStore(SQLBaseStore):
)
@cached()
def get_user_by_id(self, user_id):
return self.db_pool.simple_select_one(
async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
return await self.db_pool.simple_select_one(
table="users",
keyvalues={"name": user_id},
retcols=[
@@ -889,6 +889,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
super(RegistrationStore, self).__init__(database, db_conn, hs)
self._account_validity = hs.config.account_validity
self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors
if self._account_validity.enabled:
self._clock.call_later(
@@ -1258,12 +1259,12 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
desc="del_user_pending_deactivation",
)
def get_user_pending_deactivation(self):
async def get_user_pending_deactivation(self) -> Optional[str]:
"""
Gets one user from the table of users waiting to be parted from all the rooms
they're in.
"""
return self.db_pool.simple_select_one_onecol(
return await self.db_pool.simple_select_one_onecol(
"users_pending_deactivation",
keyvalues={},
retcol="user_id",
@@ -1302,15 +1303,22 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
if not row:
raise ThreepidValidationError(400, "Unknown session_id")
if self._ignore_unknown_session_error:
# If we need to inhibit the error caused by an incorrect session ID,
# use None as placeholder values for the client secret and the
# validation timestamp.
# It shouldn't be an issue because they're both only checked after
# the token check, which should fail. And if it doesn't for some
# reason, the next check is on the client secret, which is NOT NULL,
# so we don't have to worry about the client secret matching by
# accident.
row = {"client_secret": None, "validated_at": None}
else:
raise ThreepidValidationError(400, "Unknown session_id")
retrieved_client_secret = row["client_secret"]
validated_at = row["validated_at"]
if retrieved_client_secret != client_secret:
raise ThreepidValidationError(
400, "This client_secret does not match the provided session_id"
)
row = self.db_pool.simple_select_one_txn(
txn,
table="threepid_validation_token",
@@ -1326,6 +1334,11 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
expires = row["expires"]
next_link = row["next_link"]
if retrieved_client_secret != client_secret:
raise ThreepidValidationError(
400, "This client_secret does not match the provided session_id"
)
# If the session is already validated, no need to revalidate
if validated_at:
return next_link
+3 -2
View File
@@ -14,6 +14,7 @@
# limitations under the License.
import logging
from typing import Optional
from synapse.storage._base import SQLBaseStore
@@ -21,8 +22,8 @@ logger = logging.getLogger(__name__)
class RejectionsStore(SQLBaseStore):
def get_rejection_reason(self, event_id):
return self.db_pool.simple_select_one_onecol(
async def get_rejection_reason(self, event_id: str) -> Optional[str]:
return await self.db_pool.simple_select_one_onecol(
table="rejections",
retcol="reason",
keyvalues={"event_id": event_id},
+8 -8
View File
@@ -73,15 +73,15 @@ class RoomWorkerStore(SQLBaseStore):
self.config = hs.config
def get_room(self, room_id):
async def get_room(self, room_id: str) -> dict:
"""Retrieve a room.
Args:
room_id (str): The ID of the room to retrieve.
room_id: The ID of the room to retrieve.
Returns:
A dict containing the room information, or None if the room is unknown.
"""
return self.db_pool.simple_select_one(
return await self.db_pool.simple_select_one(
table="rooms",
keyvalues={"room_id": room_id},
retcols=("room_id", "is_public", "creator"),
@@ -330,8 +330,8 @@ class RoomWorkerStore(SQLBaseStore):
return ret_val
@cached(max_entries=10000)
def is_room_blocked(self, room_id):
return self.db_pool.simple_select_one_onecol(
async def is_room_blocked(self, room_id: str) -> Optional[bool]:
return await self.db_pool.simple_select_one_onecol(
table="blocked_rooms",
keyvalues={"room_id": room_id},
retcol="1",
@@ -1129,7 +1129,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
},
)
with self._public_room_id_gen.get_next() as next_id:
with await self._public_room_id_gen.get_next() as next_id:
await self.db_pool.runInteraction(
"store_room_txn", store_room_txn, next_id
)
@@ -1196,7 +1196,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
},
)
with self._public_room_id_gen.get_next() as next_id:
with await self._public_room_id_gen.get_next() as next_id:
await self.db_pool.runInteraction(
"set_room_is_public", set_room_is_public_txn, next_id
)
@@ -1276,7 +1276,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
},
)
with self._public_room_id_gen.get_next() as next_id:
with await self._public_room_id_gen.get_next() as next_id:
await self.db_pool.runInteraction(
"set_room_is_public_appservice",
set_room_is_public_appservice_txn,
+2 -2
View File
@@ -260,8 +260,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return event.content.get("canonical_alias")
@cached(max_entries=50000)
def _get_state_group_for_event(self, event_id):
return self.db_pool.simple_select_one_onecol(
async def _get_state_group_for_event(self, event_id: str) -> Optional[int]:
return await self.db_pool.simple_select_one_onecol(
table="event_to_state_groups",
keyvalues={"event_id": event_id},
retcol="state_group",
+5 -5
View File
@@ -211,11 +211,11 @@ class StatsStore(StateDeltasStore):
return len(rooms_to_work_on)
def get_stats_positions(self):
async def get_stats_positions(self) -> int:
"""
Returns the stats processor positions.
"""
return self.db_pool.simple_select_one_onecol(
return await self.db_pool.simple_select_one_onecol(
table="stats_incremental_position",
keyvalues={},
retcol="stream_id",
@@ -300,7 +300,7 @@ class StatsStore(StateDeltasStore):
return slice_list
@cached()
def get_earliest_token_for_stats(self, stats_type, id):
async def get_earliest_token_for_stats(self, stats_type: str, id: str) -> int:
"""
Fetch the "earliest token". This is used by the room stats delta
processor to ignore deltas that have been processed between the
@@ -308,11 +308,11 @@ class StatsStore(StateDeltasStore):
being calculated.
Returns:
Deferred[int]
The earliest token.
"""
table, id_col = TYPE_TO_TABLE[stats_type]
return self.db_pool.simple_select_one_onecol(
return await self.db_pool.simple_select_one_onecol(
"%s_current" % (table,),
keyvalues={id_col: id},
retcol="completed_delta_stream_id",
+2 -2
View File
@@ -210,7 +210,7 @@ class TagsStore(TagsWorkerStore):
)
self._update_revision_txn(txn, user_id, room_id, next_id)
with self._account_data_id_gen.get_next() as next_id:
with await self._account_data_id_gen.get_next() as next_id:
await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,))
@@ -232,7 +232,7 @@ class TagsStore(TagsWorkerStore):
txn.execute(sql, (user_id, room_id, tag))
self._update_revision_txn(txn, user_id, room_id, next_id)
with self._account_data_id_gen.get_next() as next_id:
with await self._account_data_id_gen.get_next() as next_id:
await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,))
@@ -15,6 +15,7 @@
import logging
import re
from typing import Any, Dict, Optional
from synapse.api.constants import EventTypes, JoinRules
from synapse.storage.database import DatabasePool
@@ -527,8 +528,8 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
)
@cached()
def get_user_in_directory(self, user_id):
return self.db_pool.simple_select_one(
async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, Any]]:
return await self.db_pool.simple_select_one(
table="user_directory",
keyvalues={"user_id": user_id},
retcols=("display_name", "avatar_url"),
@@ -663,8 +664,8 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
users.update(rows)
return list(users)
def get_user_directory_stream_pos(self):
return self.db_pool.simple_select_one_onecol(
async def get_user_directory_stream_pos(self) -> int:
return await self.db_pool.simple_select_one_onecol(
table="user_directory_stream_pos",
keyvalues={},
retcol="stream_id",
+106 -7
View File
@@ -14,9 +14,10 @@
# limitations under the License.
import contextlib
import heapq
import threading
from collections import deque
from typing import Dict, Set
from typing import Dict, List, Set
from typing_extensions import Deque
@@ -80,7 +81,7 @@ class StreamIdGenerator(object):
upwards, -1 to grow downwards.
Usage:
with stream_id_gen.get_next() as stream_id:
with await stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
@@ -95,10 +96,10 @@ class StreamIdGenerator(object):
)
self._unfinished_ids = deque() # type: Deque[int]
def get_next(self):
async def get_next(self):
"""
Usage:
with stream_id_gen.get_next() as stream_id:
with await stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
with self._lock:
@@ -117,10 +118,10 @@ class StreamIdGenerator(object):
return manager()
def get_next_mult(self, n):
async def get_next_mult(self, n):
"""
Usage:
with stream_id_gen.get_next(n) as stream_ids:
with await stream_id_gen.get_next(n) as stream_ids:
# ... persist events ...
"""
with self._lock:
@@ -210,6 +211,23 @@ class MultiWriterIdGenerator:
# should be less than the minimum of this set (if not empty).
self._unfinished_ids = set() # type: Set[int]
# We track the max position where we know everything before has been
# persisted. This is done by a) looking at the min across all instances
# and b) noting that if we have seen a run of persisted positions
# without gaps (e.g. 5, 6, 7) then we can skip forward (e.g. to 7).
#
# Note: There is no guarentee that the IDs generated by the sequence
# will be gapless; gaps can form when e.g. a transaction was rolled
# back. This means that sometimes we won't be able to skip forward the
# position even though everything has been persisted. However, since
# gaps should be relatively rare it's still worth doing the book keeping
# that allows us to skip forwards when there are gapless runs of
# positions.
self._persisted_upto_position = (
min(self._current_positions.values()) if self._current_positions else 0
)
self._known_persisted_positions = [] # type: List[int]
self._sequence_gen = PostgresSequenceGenerator(sequence_name)
def _load_current_ids(
@@ -234,9 +252,12 @@ class MultiWriterIdGenerator:
return current_positions
def _load_next_id_txn(self, txn):
def _load_next_id_txn(self, txn) -> int:
return self._sequence_gen.get_next_id_txn(txn)
def _load_next_mult_id_txn(self, txn, n: int) -> List[int]:
return self._sequence_gen.get_next_mult_txn(txn, n)
async def get_next(self):
"""
Usage:
@@ -262,6 +283,34 @@ class MultiWriterIdGenerator:
return manager()
async def get_next_mult(self, n: int):
"""
Usage:
with await stream_id_gen.get_next_mult(5) as stream_ids:
# ... persist events ...
"""
next_ids = await self._db.runInteraction(
"_load_next_mult_id", self._load_next_mult_id_txn, n
)
# Assert the fetched ID is actually greater than any ID we've already
# seen. If not, then the sequence and table have got out of sync
# somehow.
assert max(self.get_positions().values(), default=0) < min(next_ids)
with self._lock:
self._unfinished_ids.update(next_ids)
@contextlib.contextmanager
def manager():
try:
yield next_ids
finally:
for i in next_ids:
self._mark_id_as_finished(i)
return manager()
def get_next_txn(self, txn: LoggingTransaction):
"""
Usage:
@@ -326,3 +375,53 @@ class MultiWriterIdGenerator:
self._current_positions[instance_name] = max(
new_id, self._current_positions.get(instance_name, 0)
)
self._add_persisted_position(new_id)
def get_persisted_upto_position(self) -> int:
"""Get the max position where all previous positions have been
persisted.
Note: In the worst case scenario this will be equal to the minimum
position across writers. This means that the returned position here can
lag if one writer doesn't write very often.
"""
with self._lock:
return self._persisted_upto_position
def _add_persisted_position(self, new_id: int):
"""Record that we have persisted a position.
This is used to keep the `_current_positions` up to date.
"""
# We require that the lock is locked by caller
assert self._lock.locked()
heapq.heappush(self._known_persisted_positions, new_id)
# We move the current min position up if the minimum current positions
# of all instances is higher (since by definition all positions less
# that that have been persisted).
min_curr = min(self._current_positions.values())
self._persisted_upto_position = max(min_curr, self._persisted_upto_position)
# We now iterate through the seen positions, discarding those that are
# less than the current min positions, and incrementing the min position
# if its exactly one greater.
#
# This is also where we discard items from `_known_persisted_positions`
# (to ensure the list doesn't infinitely grow).
while self._known_persisted_positions:
if self._known_persisted_positions[0] <= self._persisted_upto_position:
heapq.heappop(self._known_persisted_positions)
elif (
self._known_persisted_positions[0] == self._persisted_upto_position + 1
):
heapq.heappop(self._known_persisted_positions)
self._persisted_upto_position += 1
else:
# There was a gap in seen positions, so there is nothing more to
# do.
break
+7 -1
View File
@@ -14,7 +14,7 @@
# limitations under the License.
import abc
import threading
from typing import Callable, Optional
from typing import Callable, List, Optional
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
from synapse.storage.types import Cursor
@@ -39,6 +39,12 @@ class PostgresSequenceGenerator(SequenceGenerator):
txn.execute("SELECT nextval(?)", (self._sequence_name,))
return txn.fetchone()[0]
def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
txn.execute(
"SELECT nextval(?) FROM generate_series(1, ?)", (self._sequence_name, n)
)
return [i for (i,) in txn]
GetFirstCallbackType = Callable[[Cursor], int]
+73
View File
@@ -1,4 +1,6 @@
from synapse.api.ratelimiting import LimitExceededError, Ratelimiter
from synapse.appservice import ApplicationService
from synapse.types import create_requester
from tests import unittest
@@ -18,6 +20,77 @@ class TestRatelimiter(unittest.TestCase):
self.assertTrue(allowed)
self.assertEquals(20.0, time_allowed)
def test_allowed_user_via_can_requester_do_action(self):
user_requester = create_requester("@user:example.com")
limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
allowed, time_allowed = limiter.can_requester_do_action(
user_requester, _time_now_s=0
)
self.assertTrue(allowed)
self.assertEquals(10.0, time_allowed)
allowed, time_allowed = limiter.can_requester_do_action(
user_requester, _time_now_s=5
)
self.assertFalse(allowed)
self.assertEquals(10.0, time_allowed)
allowed, time_allowed = limiter.can_requester_do_action(
user_requester, _time_now_s=10
)
self.assertTrue(allowed)
self.assertEquals(20.0, time_allowed)
def test_allowed_appservice_ratelimited_via_can_requester_do_action(self):
appservice = ApplicationService(
None, "example.com", id="foo", rate_limited=True,
)
as_requester = create_requester("@user:example.com", app_service=appservice)
limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
allowed, time_allowed = limiter.can_requester_do_action(
as_requester, _time_now_s=0
)
self.assertTrue(allowed)
self.assertEquals(10.0, time_allowed)
allowed, time_allowed = limiter.can_requester_do_action(
as_requester, _time_now_s=5
)
self.assertFalse(allowed)
self.assertEquals(10.0, time_allowed)
allowed, time_allowed = limiter.can_requester_do_action(
as_requester, _time_now_s=10
)
self.assertTrue(allowed)
self.assertEquals(20.0, time_allowed)
def test_allowed_appservice_via_can_requester_do_action(self):
appservice = ApplicationService(
None, "example.com", id="foo", rate_limited=False,
)
as_requester = create_requester("@user:example.com", app_service=appservice)
limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
allowed, time_allowed = limiter.can_requester_do_action(
as_requester, _time_now_s=0
)
self.assertTrue(allowed)
self.assertEquals(-1, time_allowed)
allowed, time_allowed = limiter.can_requester_do_action(
as_requester, _time_now_s=5
)
self.assertTrue(allowed)
self.assertEquals(-1, time_allowed)
allowed, time_allowed = limiter.can_requester_do_action(
as_requester, _time_now_s=10
)
self.assertTrue(allowed)
self.assertEquals(-1, time_allowed)
def test_allowed_via_ratelimit(self):
limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
+44 -12
View File
@@ -71,7 +71,9 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_my_name(self):
yield self.store.set_profile_displayname(self.frank.localpart, "Frank")
yield defer.ensureDeferred(
self.store.set_profile_displayname(self.frank.localpart, "Frank")
)
displayname = yield defer.ensureDeferred(
self.handler.get_displayname(self.frank)
@@ -104,7 +106,12 @@ class ProfileTestCase(unittest.TestCase):
)
self.assertEquals(
(yield self.store.get_profile_displayname(self.frank.localpart)), "Frank",
(
yield defer.ensureDeferred(
self.store.get_profile_displayname(self.frank.localpart)
)
),
"Frank",
)
@defer.inlineCallbacks
@@ -112,10 +119,17 @@ class ProfileTestCase(unittest.TestCase):
self.hs.config.enable_set_displayname = False
# Setting displayname for the first time is allowed
yield self.store.set_profile_displayname(self.frank.localpart, "Frank")
yield defer.ensureDeferred(
self.store.set_profile_displayname(self.frank.localpart, "Frank")
)
self.assertEquals(
(yield self.store.get_profile_displayname(self.frank.localpart)), "Frank",
(
yield defer.ensureDeferred(
self.store.get_profile_displayname(self.frank.localpart)
)
),
"Frank",
)
# Setting displayname a second time is forbidden
@@ -158,7 +172,9 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_incoming_fed_query(self):
yield defer.ensureDeferred(self.store.create_profile("caroline"))
yield self.store.set_profile_displayname("caroline", "Caroline")
yield defer.ensureDeferred(
self.store.set_profile_displayname("caroline", "Caroline")
)
response = yield defer.ensureDeferred(
self.query_handlers["profile"](
@@ -170,8 +186,10 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_my_avatar(self):
yield self.store.set_profile_avatar_url(
self.frank.localpart, "http://my.server/me.png"
yield defer.ensureDeferred(
self.store.set_profile_avatar_url(
self.frank.localpart, "http://my.server/me.png"
)
)
avatar_url = yield defer.ensureDeferred(self.handler.get_avatar_url(self.frank))
@@ -188,7 +206,11 @@ class ProfileTestCase(unittest.TestCase):
)
self.assertEquals(
(yield self.store.get_profile_avatar_url(self.frank.localpart)),
(
yield defer.ensureDeferred(
self.store.get_profile_avatar_url(self.frank.localpart)
)
),
"http://my.server/pic.gif",
)
@@ -202,7 +224,11 @@ class ProfileTestCase(unittest.TestCase):
)
self.assertEquals(
(yield self.store.get_profile_avatar_url(self.frank.localpart)),
(
yield defer.ensureDeferred(
self.store.get_profile_avatar_url(self.frank.localpart)
)
),
"http://my.server/me.png",
)
@@ -211,12 +237,18 @@ class ProfileTestCase(unittest.TestCase):
self.hs.config.enable_set_avatar_url = False
# Setting displayname for the first time is allowed
yield self.store.set_profile_avatar_url(
self.frank.localpart, "http://my.server/me.png"
yield defer.ensureDeferred(
self.store.set_profile_avatar_url(
self.frank.localpart, "http://my.server/me.png"
)
)
self.assertEquals(
(yield self.store.get_profile_avatar_url(self.frank.localpart)),
(
yield defer.ensureDeferred(
self.store.get_profile_avatar_url(self.frank.localpart)
)
),
"http://my.server/me.png",
)
+2 -2
View File
@@ -144,9 +144,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.datastore.get_users_in_room = get_users_in_room
self.datastore.get_user_directory_stream_pos.return_value = (
self.datastore.get_user_directory_stream_pos.side_effect = (
# we deliberately return a non-None stream pos to avoid doing an initial_spam
defer.succeed(1)
lambda: make_awaitable(1)
)
self.datastore.get_current_state_deltas.return_value = (0, None)
+1 -1
View File
@@ -35,7 +35,7 @@ class ModuleApiTestCase(HomeserverTestCase):
# Check that the new user exists with all provided attributes
self.assertEqual(user_id, "@bob:test")
self.assertTrue(access_token)
self.assertTrue(self.store.get_user_by_id(user_id))
self.assertTrue(self.get_success(self.store.get_user_by_id(user_id)))
# Check that the email was assigned
emails = self.get_success(self.store.user_get_threepids(user_id))
+71 -33
View File
@@ -45,33 +45,16 @@ class RetentionTestCase(unittest.HomeserverTestCase):
}
self.hs = self.setup_test_homeserver(config=config)
return self.hs
def prepare(self, reactor, clock, homeserver):
self.user_id = self.register_user("user", "password")
self.token = self.login("user", "password")
def test_retention_state_event(self):
"""Tests that the server configuration can limit the values a user can set to the
room's retention policy.
"""
room_id = self.helper.create_room_as(self.user_id, tok=self.token)
self.helper.send_state(
room_id=room_id,
event_type=EventTypes.Retention,
body={"max_lifetime": one_day_ms * 4},
tok=self.token,
expect_code=400,
)
self.helper.send_state(
room_id=room_id,
event_type=EventTypes.Retention,
body={"max_lifetime": one_hour_ms},
tok=self.token,
expect_code=400,
)
self.store = self.hs.get_datastore()
self.serializer = self.hs.get_event_client_serializer()
self.clock = self.hs.get_clock()
def test_retention_event_purged_with_state_event(self):
"""Tests that expired events are correctly purged when the room's retention policy
@@ -90,6 +73,36 @@ class RetentionTestCase(unittest.HomeserverTestCase):
self._test_retention_event_purged(room_id, one_day_ms * 1.5)
def test_retention_event_purged_with_state_event_outside_allowed(self):
"""Tests that the server configuration can override the policy for a room when
running the purge jobs.
"""
room_id = self.helper.create_room_as(self.user_id, tok=self.token)
# Set a max_lifetime higher than the maximum allowed value.
self.helper.send_state(
room_id=room_id,
event_type=EventTypes.Retention,
body={"max_lifetime": one_day_ms * 4},
tok=self.token,
)
# Check that the event is purged after waiting for the maximum allowed duration
# instead of the one specified in the room's policy.
self._test_retention_event_purged(room_id, one_day_ms * 1.5)
# Set a max_lifetime lower than the minimum allowed value.
self.helper.send_state(
room_id=room_id,
event_type=EventTypes.Retention,
body={"max_lifetime": one_hour_ms},
tok=self.token,
)
# Check that the event is purged after waiting for the minimum allowed duration
# instead of the one specified in the room's policy.
self._test_retention_event_purged(room_id, one_day_ms * 0.5)
def test_retention_event_purged_without_state_event(self):
"""Tests that expired events are correctly purged when the room's retention policy
is defined by the server's configuration's default retention policy.
@@ -140,7 +153,27 @@ class RetentionTestCase(unittest.HomeserverTestCase):
# That event should be the second, not outdated event.
self.assertEqual(filtered_events[0].event_id, valid_event_id, filtered_events)
def _test_retention_event_purged(self, room_id, increment):
def _test_retention_event_purged(self, room_id: str, increment: float):
"""Run the following test scenario to test the message retention policy support:
1. Send event 1
2. Increment time by `increment`
3. Send event 2
4. Increment time by `increment`
5. Check that event 1 has been purged
6. Check that event 2 has not been purged
7. Check that state events that were sent before event 1 aren't purged.
The main reason for sending a second event is because currently Synapse won't
purge the latest message in a room because it would otherwise result in a lack of
forward extremities for this room. It's also a good thing to ensure the purge jobs
aren't too greedy and purge messages they shouldn't.
Args:
room_id: The ID of the room to test retention in.
increment: The number of milliseconds to advance the clock each time. Must be
defined so that events in the room aren't purged if they are `increment`
old but are purged if they are `increment * 2` old.
"""
# Get the create event to, later, check that we can still access it.
message_handler = self.hs.get_message_handler()
create_event = self.get_success(
@@ -156,7 +189,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
expired_event_id = resp.get("event_id")
# Check that we can retrieve the event.
expired_event = self.get_event(room_id, expired_event_id)
expired_event = self.get_event(expired_event_id)
self.assertEqual(
expired_event.get("content", {}).get("body"), "1", expired_event
)
@@ -174,26 +207,31 @@ class RetentionTestCase(unittest.HomeserverTestCase):
# one should still be kept.
self.reactor.advance(increment / 1000)
# Check that the event has been purged from the database.
self.get_event(room_id, expired_event_id, expected_code=404)
# Check that the first event has been purged from the database, i.e. that we
# can't retrieve it anymore, because it has expired.
self.get_event(expired_event_id, expect_none=True)
# Check that the event that hasn't been purged can still be retrieved.
valid_event = self.get_event(room_id, valid_event_id)
# Check that the event that hasn't expired can still be retrieved.
valid_event = self.get_event(valid_event_id)
self.assertEqual(valid_event.get("content", {}).get("body"), "2", valid_event)
# Check that we can still access state events that were sent before the event that
# has been purged.
self.get_event(room_id, create_event.event_id)
def get_event(self, room_id, event_id, expected_code=200):
url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id)
def get_event(self, event_id, expect_none=False):
event = self.get_success(self.store.get_event(event_id, allow_none=True))
request, channel = self.make_request("GET", url, access_token=self.token)
self.render(request)
if expect_none:
self.assertIsNone(event)
return {}
self.assertEqual(channel.code, expected_code, channel.result)
self.assertIsNotNone(event)
return channel.json_body
time_now = self.clock.time_msec()
serialized = self.get_success(self.serializer.serialize_event(event, time_now))
return serialized
class RetentionNoDefaultPolicyTestCase(unittest.HomeserverTestCase):
+272
View File
@@ -0,0 +1,272 @@
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from mock import Mock, patch
import synapse.rest.admin
from synapse.api.constants import EventTypes
from synapse.rest.client.v1 import directory, login, profile, room
from synapse.rest.client.v2_alpha import room_upgrade_rest_servlet
from tests import unittest
class _ShadowBannedBase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, homeserver):
# Create two users, one of which is shadow-banned.
self.banned_user_id = self.register_user("banned", "test")
self.banned_access_token = self.login("banned", "test")
self.store = self.hs.get_datastore()
self.get_success(
self.store.db_pool.simple_update(
table="users",
keyvalues={"name": self.banned_user_id},
updatevalues={"shadow_banned": True},
desc="shadow_ban",
)
)
self.other_user_id = self.register_user("otheruser", "pass")
self.other_access_token = self.login("otheruser", "pass")
# To avoid the tests timing out don't add a delay to "annoy the requester".
@patch("random.randint", new=lambda a, b: 0)
class RoomTestCase(_ShadowBannedBase):
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
directory.register_servlets,
login.register_servlets,
room.register_servlets,
room_upgrade_rest_servlet.register_servlets,
]
def test_invite(self):
"""Invites from shadow-banned users don't actually get sent."""
# The create works fine.
room_id = self.helper.create_room_as(
self.banned_user_id, tok=self.banned_access_token
)
# Inviting the user completes successfully.
self.helper.invite(
room=room_id,
src=self.banned_user_id,
tok=self.banned_access_token,
targ=self.other_user_id,
)
# But the user wasn't actually invited.
invited_rooms = self.get_success(
self.store.get_invited_rooms_for_local_user(self.other_user_id)
)
self.assertEqual(invited_rooms, [])
def test_invite_3pid(self):
"""Ensure that a 3PID invite does not attempt to contact the identity server."""
identity_handler = self.hs.get_handlers().identity_handler
identity_handler.lookup_3pid = Mock(
side_effect=AssertionError("This should not get called")
)
# The create works fine.
room_id = self.helper.create_room_as(
self.banned_user_id, tok=self.banned_access_token
)
# Inviting the user completes successfully.
request, channel = self.make_request(
"POST",
"/rooms/%s/invite" % (room_id,),
{"id_server": "test", "medium": "email", "address": "test@test.test"},
access_token=self.banned_access_token,
)
self.render(request)
self.assertEquals(200, channel.code, channel.result)
# This should have raised an error earlier, but double check this wasn't called.
identity_handler.lookup_3pid.assert_not_called()
def test_create_room(self):
"""Invitations during a room creation should be discarded, but the room still gets created."""
# The room creation is successful.
request, channel = self.make_request(
"POST",
"/_matrix/client/r0/createRoom",
{"visibility": "public", "invite": [self.other_user_id]},
access_token=self.banned_access_token,
)
self.render(request)
self.assertEquals(200, channel.code, channel.result)
room_id = channel.json_body["room_id"]
# But the user wasn't actually invited.
invited_rooms = self.get_success(
self.store.get_invited_rooms_for_local_user(self.other_user_id)
)
self.assertEqual(invited_rooms, [])
# Since a real room was created, the other user should be able to join it.
self.helper.join(room_id, self.other_user_id, tok=self.other_access_token)
# Both users should be in the room.
users = self.get_success(self.store.get_users_in_room(room_id))
self.assertCountEqual(users, ["@banned:test", "@otheruser:test"])
def test_message(self):
"""Messages from shadow-banned users don't actually get sent."""
room_id = self.helper.create_room_as(
self.other_user_id, tok=self.other_access_token
)
# The user should be in the room.
self.helper.join(room_id, self.banned_user_id, tok=self.banned_access_token)
# Sending a message should complete successfully.
result = self.helper.send_event(
room_id=room_id,
type=EventTypes.Message,
content={"msgtype": "m.text", "body": "with right label"},
tok=self.banned_access_token,
)
self.assertIn("event_id", result)
event_id = result["event_id"]
latest_events = self.get_success(
self.store.get_latest_event_ids_in_room(room_id)
)
self.assertNotIn(event_id, latest_events)
def test_upgrade(self):
"""A room upgrade should fail, but look like it succeeded."""
# The create works fine.
room_id = self.helper.create_room_as(
self.banned_user_id, tok=self.banned_access_token
)
request, channel = self.make_request(
"POST",
"/_matrix/client/r0/rooms/%s/upgrade" % (room_id,),
{"new_version": "6"},
access_token=self.banned_access_token,
)
self.render(request)
self.assertEquals(200, channel.code, channel.result)
# A new room_id should be returned.
self.assertIn("replacement_room", channel.json_body)
new_room_id = channel.json_body["replacement_room"]
# It doesn't really matter what API we use here, we just want to assert
# that the room doesn't exist.
summary = self.get_success(self.store.get_room_summary(new_room_id))
# The summary should be empty since the room doesn't exist.
self.assertEqual(summary, {})
# To avoid the tests timing out don't add a delay to "annoy the requester".
@patch("random.randint", new=lambda a, b: 0)
class ProfileTestCase(_ShadowBannedBase):
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
login.register_servlets,
profile.register_servlets,
room.register_servlets,
]
def test_displayname(self):
"""Profile changes should succeed, but don't end up in a room."""
original_display_name = "banned"
new_display_name = "new name"
# Join a room.
room_id = self.helper.create_room_as(
self.banned_user_id, tok=self.banned_access_token
)
# The update should succeed.
request, channel = self.make_request(
"PUT",
"/_matrix/client/r0/profile/%s/displayname" % (self.banned_user_id,),
{"displayname": new_display_name},
access_token=self.banned_access_token,
)
self.render(request)
self.assertEquals(200, channel.code, channel.result)
self.assertEqual(channel.json_body, {})
# The user's display name should be updated.
request, channel = self.make_request(
"GET", "/profile/%s/displayname" % (self.banned_user_id,)
)
self.render(request)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.json_body["displayname"], new_display_name)
# But the display name in the room should not be.
message_handler = self.hs.get_message_handler()
event = self.get_success(
message_handler.get_room_data(
self.banned_user_id,
room_id,
"m.room.member",
self.banned_user_id,
False,
)
)
self.assertEqual(
event.content, {"membership": "join", "displayname": original_display_name}
)
def test_room_displayname(self):
"""Changes to state events for a room should be processed, but not end up in the room."""
original_display_name = "banned"
new_display_name = "new name"
# Join a room.
room_id = self.helper.create_room_as(
self.banned_user_id, tok=self.banned_access_token
)
# The update should succeed.
request, channel = self.make_request(
"PUT",
"/_matrix/client/r0/rooms/%s/state/m.room.member/%s"
% (room_id, self.banned_user_id),
{"membership": "join", "displayname": new_display_name},
access_token=self.banned_access_token,
)
self.render(request)
self.assertEquals(200, channel.code, channel.result)
self.assertIn("event_id", channel.json_body)
# The display name in the room should not be changed.
message_handler = self.hs.get_message_handler()
event = self.get_success(
message_handler.get_room_data(
self.banned_user_id,
room_id,
"m.room.member",
self.banned_user_id,
False,
)
)
self.assertEqual(
event.content, {"membership": "join", "displayname": original_display_name}
)
+87 -101
View File
@@ -28,7 +28,7 @@ from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.handlers.pagination import PurgeStatus
from synapse.rest.client.v1 import directory, login, profile, room
from synapse.rest.client.v2_alpha import account
from synapse.types import JsonDict, RoomAlias
from synapse.types import JsonDict, RoomAlias, UserID
from synapse.util.stringutils import random_string
from tests import unittest
@@ -675,6 +675,92 @@ class RoomMemberStateTestCase(RoomBase):
self.assertEquals(json.loads(content), channel.json_body)
class RoomJoinRatelimitTestCase(RoomBase):
user_id = "@sid1:red"
servlets = [
profile.register_servlets,
room.register_servlets,
]
@unittest.override_config(
{"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
)
def test_join_local_ratelimit(self):
"""Tests that local joins are actually rate-limited."""
for i in range(3):
self.helper.create_room_as(self.user_id)
self.helper.create_room_as(self.user_id, expect_code=429)
@unittest.override_config(
{"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
)
def test_join_local_ratelimit_profile_change(self):
"""Tests that sending a profile update into all of the user's joined rooms isn't
rate-limited by the rate-limiter on joins."""
# Create and join as many rooms as the rate-limiting config allows in a second.
room_ids = [
self.helper.create_room_as(self.user_id),
self.helper.create_room_as(self.user_id),
self.helper.create_room_as(self.user_id),
]
# Let some time for the rate-limiter to forget about our multi-join.
self.reactor.advance(2)
# Add one to make sure we're joined to more rooms than the config allows us to
# join in a second.
room_ids.append(self.helper.create_room_as(self.user_id))
# Create a profile for the user, since it hasn't been done on registration.
store = self.hs.get_datastore()
self.get_success(
store.create_profile(UserID.from_string(self.user_id).localpart)
)
# Update the display name for the user.
path = "/_matrix/client/r0/profile/%s/displayname" % self.user_id
request, channel = self.make_request("PUT", path, {"displayname": "John Doe"})
self.render(request)
self.assertEquals(channel.code, 200, channel.json_body)
# Check that all the rooms have been sent a profile update into.
for room_id in room_ids:
path = "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" % (
room_id,
self.user_id,
)
request, channel = self.make_request("GET", path)
self.render(request)
self.assertEquals(channel.code, 200)
self.assertIn("displayname", channel.json_body)
self.assertEquals(channel.json_body["displayname"], "John Doe")
@unittest.override_config(
{"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
)
def test_join_local_ratelimit_idempotent(self):
"""Tests that the room join endpoints remain idempotent despite rate-limiting
on room joins."""
room_id = self.helper.create_room_as(self.user_id)
# Let's test both paths to be sure.
paths_to_test = [
"/_matrix/client/r0/rooms/%s/join",
"/_matrix/client/r0/join/%s",
]
for path in paths_to_test:
# Make sure we send more requests than the rate-limiting config would allow
# if all of these requests ended up joining the user to a room.
for i in range(4):
request, channel = self.make_request("POST", path % room_id, {})
self.render(request)
self.assertEquals(channel.code, 200)
class RoomMessagesTestCase(RoomBase):
""" Tests /rooms/$room_id/messages/$user_id/$msg_id REST events. """
@@ -1974,103 +2060,3 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
"""An alias which does not point to the room raises a SynapseError."""
self._set_canonical_alias({"alias": "@unknown:test"}, expected_code=400)
self._set_canonical_alias({"alt_aliases": ["@unknown:test"]}, expected_code=400)
class ShadowBannedTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
directory.register_servlets,
login.register_servlets,
room.register_servlets,
]
def prepare(self, reactor, clock, homeserver):
self.banned_user_id = self.register_user("banned", "test")
self.banned_access_token = self.login("banned", "test")
self.store = self.hs.get_datastore()
self.get_success(
self.store.db_pool.simple_update(
table="users",
keyvalues={"name": self.banned_user_id},
updatevalues={"shadow_banned": True},
desc="shadow_ban",
)
)
self.other_user_id = self.register_user("otheruser", "pass")
self.other_access_token = self.login("otheruser", "pass")
def test_invite(self):
"""Invites from shadow-banned users don't actually get sent."""
# The create works fine.
room_id = self.helper.create_room_as(
self.banned_user_id, tok=self.banned_access_token
)
# Inviting the user completes successfully.
self.helper.invite(
room=room_id,
src=self.banned_user_id,
tok=self.banned_access_token,
targ=self.other_user_id,
)
# But the user wasn't actually invited.
invited_rooms = self.get_success(
self.store.get_invited_rooms_for_local_user(self.other_user_id)
)
self.assertEqual(invited_rooms, [])
def test_invite_3pid(self):
"""Ensure that a 3PID invite does not attempt to contact the identity server."""
identity_handler = self.hs.get_handlers().identity_handler
identity_handler.lookup_3pid = Mock(
side_effect=AssertionError("This should not get called")
)
# The create works fine.
room_id = self.helper.create_room_as(
self.banned_user_id, tok=self.banned_access_token
)
# Inviting the user completes successfully.
request, channel = self.make_request(
"POST",
"/rooms/%s/invite" % (room_id,),
{"id_server": "test", "medium": "email", "address": "test@test.test"},
access_token=self.banned_access_token,
)
self.render(request)
self.assertEquals(200, channel.code, channel.result)
# This should have raised an error earlier, but double check this wasn't called.
identity_handler.lookup_3pid.assert_not_called()
def test_create_room(self):
"""Invitations during a room creation should be discarded, but the room still gets created."""
# The room creation is successful.
request, channel = self.make_request(
"POST",
"/_matrix/client/r0/createRoom",
{"visibility": "public", "invite": [self.other_user_id]},
access_token=self.banned_access_token,
)
self.render(request)
self.assertEquals(200, channel.code, channel.result)
room_id = channel.json_body["room_id"]
# But the user wasn't actually invited.
invited_rooms = self.get_success(
self.store.get_invited_rooms_for_local_user(self.other_user_id)
)
self.assertEqual(invited_rooms, [])
# Since a real room was created, the other user should be able to join it.
self.helper.join(room_id, self.other_user_id, tok=self.other_access_token)
# Both users should be in the room.
users = self.get_success(self.store.get_users_in_room(room_id))
self.assertCountEqual(users, ["@banned:test", "@otheruser:test"])
+7 -3
View File
@@ -39,7 +39,9 @@ class RestHelper(object):
resource = attr.ib()
auth_user_id = attr.ib()
def create_room_as(self, room_creator=None, is_public=True, tok=None):
def create_room_as(
self, room_creator=None, is_public=True, tok=None, expect_code=200,
):
temp_id = self.auth_user_id
self.auth_user_id = room_creator
path = "/_matrix/client/r0/createRoom"
@@ -54,9 +56,11 @@ class RestHelper(object):
)
render(request, self.resource, self.hs.get_reactor())
assert channel.result["code"] == b"200", channel.result
assert channel.result["code"] == b"%d" % expect_code, channel.result
self.auth_user_id = temp_id
return channel.json_body["room_id"]
if expect_code == 200:
return channel.json_body["room_id"]
def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None):
self.change_membership(
+17 -11
View File
@@ -97,8 +97,10 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.rowcount = 1
self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)]))
value = yield self.datastore.db_pool.simple_select_one_onecol(
table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol"
value = yield defer.ensureDeferred(
self.datastore.db_pool.simple_select_one_onecol(
table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol"
)
)
self.assertEquals("Value", value)
@@ -111,10 +113,12 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.rowcount = 1
self.mock_txn.fetchone.return_value = (1, 2, 3)
ret = yield self.datastore.db_pool.simple_select_one(
table="tablename",
keyvalues={"keycol": "TheKey"},
retcols=["colA", "colB", "colC"],
ret = yield defer.ensureDeferred(
self.datastore.db_pool.simple_select_one(
table="tablename",
keyvalues={"keycol": "TheKey"},
retcols=["colA", "colB", "colC"],
)
)
self.assertEquals({"colA": 1, "colB": 2, "colC": 3}, ret)
@@ -127,11 +131,13 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.rowcount = 0
self.mock_txn.fetchone.return_value = None
ret = yield self.datastore.db_pool.simple_select_one(
table="tablename",
keyvalues={"keycol": "Not here"},
retcols=["colA"],
allow_none=True,
ret = yield defer.ensureDeferred(
self.datastore.db_pool.simple_select_one(
table="tablename",
keyvalues={"keycol": "Not here"},
retcols=["colA"],
allow_none=True,
)
)
self.assertFalse(ret)
+4 -4
View File
@@ -38,7 +38,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
self.store.store_device("user_id", "device_id", "display_name")
)
res = yield self.store.get_device("user_id", "device_id")
res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
self.assertDictContainsSubset(
{
"user_id": "user_id",
@@ -111,12 +111,12 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
self.store.store_device("user_id", "device_id", "display_name 1")
)
res = yield self.store.get_device("user_id", "device_id")
res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
self.assertEqual("display_name 1", res["display_name"])
# do a no-op first
yield defer.ensureDeferred(self.store.update_device("user_id", "device_id"))
res = yield self.store.get_device("user_id", "device_id")
res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
self.assertEqual("display_name 1", res["display_name"])
# do the update
@@ -127,7 +127,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
)
# check it worked
res = yield self.store.get_device("user_id", "device_id")
res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
self.assertEqual("display_name 2", res["display_name"])
@defer.inlineCallbacks
+36
View File
@@ -182,3 +182,39 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen.get_positions(), {"master": 8})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
def test_get_persisted_upto_position(self):
"""Test that `get_persisted_upto_position` correctly tracks updates to
positions.
"""
self._insert_rows("first", 3)
self._insert_rows("second", 5)
id_gen = self._create_id_generator("first")
# Min is 3 and there is a gap between 5, so we expect it to be 3.
self.assertEqual(id_gen.get_persisted_upto_position(), 3)
# We advance "first" straight to 6. Min is now 5 but there is no gap so
# we expect it to be 6
id_gen.advance("first", 6)
self.assertEqual(id_gen.get_persisted_upto_position(), 6)
# No gap, so we expect 7.
id_gen.advance("second", 7)
self.assertEqual(id_gen.get_persisted_upto_position(), 7)
# We haven't seen 8 yet, so we expect 7 still.
id_gen.advance("second", 9)
self.assertEqual(id_gen.get_persisted_upto_position(), 7)
# Now that we've seen 7, 8 and 9 we can got straight to 9.
id_gen.advance("first", 8)
self.assertEqual(id_gen.get_persisted_upto_position(), 9)
# Jump forward with gaps. The minimum is 11, even though we haven't seen
# 10 we know that everything before 11 must be persisted.
id_gen.advance("first", 11)
id_gen.advance("second", 15)
self.assertEqual(id_gen.get_persisted_upto_position(), 11)
+18 -5
View File
@@ -35,21 +35,34 @@ class ProfileStoreTestCase(unittest.TestCase):
def test_displayname(self):
yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
yield self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
yield defer.ensureDeferred(
self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
)
self.assertEquals(
"Frank", (yield self.store.get_profile_displayname(self.u_frank.localpart))
"Frank",
(
yield defer.ensureDeferred(
self.store.get_profile_displayname(self.u_frank.localpart)
)
),
)
@defer.inlineCallbacks
def test_avatar_url(self):
yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
yield self.store.set_profile_avatar_url(
self.u_frank.localpart, "http://my.site/here"
yield defer.ensureDeferred(
self.store.set_profile_avatar_url(
self.u_frank.localpart, "http://my.site/here"
)
)
self.assertEquals(
"http://my.site/here",
(yield self.store.get_profile_avatar_url(self.u_frank.localpart)),
(
yield defer.ensureDeferred(
self.store.get_profile_avatar_url(self.u_frank.localpart)
)
),
)
+32 -1
View File
@@ -17,6 +17,7 @@
from twisted.internet import defer
from synapse.api.constants import UserTypes
from synapse.api.errors import ThreepidValidationError
from tests import unittest
from tests.utils import setup_test_homeserver
@@ -52,7 +53,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
"user_type": None,
"deactivated": 0,
},
(yield self.store.get_user_by_id(self.user_id)),
(yield defer.ensureDeferred(self.store.get_user_by_id(self.user_id))),
)
@defer.inlineCallbacks
@@ -122,3 +123,33 @@ class RegistrationStoreTestCase(unittest.TestCase):
)
res = yield self.store.is_support_user(SUPPORT_USER)
self.assertTrue(res)
@defer.inlineCallbacks
def test_3pid_inhibit_invalid_validation_session_error(self):
"""Tests that enabling the configuration option to inhibit 3PID errors on
/requestToken also inhibits validation errors caused by an unknown session ID.
"""
# Check that, with the config setting set to false (the default value), a
# validation error is caused by the unknown session ID.
try:
yield defer.ensureDeferred(
self.store.validate_threepid_session(
"fake_sid", "fake_client_secret", "fake_token", 0,
)
)
except ThreepidValidationError as e:
self.assertEquals(e.msg, "Unknown session_id", e)
# Set the config setting to true.
self.store._ignore_unknown_session_error = True
# Check that now the validation error is caused by the token not matching.
try:
yield defer.ensureDeferred(
self.store.validate_threepid_session(
"fake_sid", "fake_client_secret", "fake_token", 0,
)
)
except ThreepidValidationError as e:
self.assertEquals(e.msg, "Validation token not found or has expired", e)
+16 -4
View File
@@ -54,12 +54,14 @@ class RoomStoreTestCase(unittest.TestCase):
"creator": self.u_creator.to_string(),
"is_public": True,
},
(yield self.store.get_room(self.room.to_string())),
(yield defer.ensureDeferred(self.store.get_room(self.room.to_string()))),
)
@defer.inlineCallbacks
def test_get_room_unknown_room(self):
self.assertIsNone((yield self.store.get_room("!uknown:test")),)
self.assertIsNone(
(yield defer.ensureDeferred(self.store.get_room("!uknown:test")))
)
@defer.inlineCallbacks
def test_get_room_with_stats(self):
@@ -69,12 +71,22 @@ class RoomStoreTestCase(unittest.TestCase):
"creator": self.u_creator.to_string(),
"public": True,
},
(yield self.store.get_room_with_stats(self.room.to_string())),
(
yield defer.ensureDeferred(
self.store.get_room_with_stats(self.room.to_string())
)
),
)
@defer.inlineCallbacks
def test_get_room_with_stats_unknown_room(self):
self.assertIsNone((yield self.store.get_room_with_stats("!uknown:test")),)
self.assertIsNone(
(
yield defer.ensureDeferred(
self.store.get_room_with_stats("!uknown:test")
)
),
)
class RoomEventsStoreTestCase(unittest.TestCase):
+1
View File
@@ -209,6 +209,7 @@ commands = mypy \
synapse/server.py \
synapse/server_notices \
synapse/spam_checker_api \
synapse/state \
synapse/storage/databases/main/ui_auth.py \
synapse/storage/database.py \
synapse/storage/engines \