
Compare commits


53 Commits

Author SHA1 Message Date
Andrew Morgan
7affcd01c7 Merge branch 'develop' of github.com:matrix-org/synapse into anoa/user_param_ui_auth
* 'develop' of github.com:matrix-org/synapse: (369 commits)
  Add functions to `MultiWriterIdGen` used by events stream (#8164)
  Do not allow send_nonmember_event to be called with shadow-banned users. (#8158)
  Changelog fixes
  1.19.1rc1
  Make StreamIdGen `get_next` and `get_next_mult` async  (#8161)
  Wording fixes to 'name' user admin api filter (#8163)
  Fix missing double-backtick in RST document
  Search in columns 'name' and 'displayname' in the admin users endpoint (#7377)
  Add type hints for state. (#8140)
  Stop shadow-banned users from sending non-member events. (#8142)
  Allow capping a room's retention policy (#8104)
  Add healthcheck for default localhost 8008 port on /health endpoint. (#8147)
  Fix flaky shadow-ban tests. (#8152)
  Fix join ratelimiter breaking profile updates and idempotency (#8153)
  Do not apply ratelimiting on joins to appservices (#8139)
  Don't fail /submit_token requests on incorrect session ID if request_token_inhibit_3pid_errors is turned on (#7991)
  Do not apply ratelimiting on joins to appservices (#8139)
  Micro-optimisations to get_auth_chain_ids (#8132)
  Allow denying or shadow banning registrations via the spam checker (#8034)
  Stop shadow-banned users from sending invites. (#8095)
  ...
2020-08-26 12:22:25 +01:00
Erik Johnston
eba98fb024 Add functions to MultiWriterIdGen used by events stream (#8164) 2020-08-25 17:32:30 +01:00
Patrick Cloke
5099bd68da Do not allow send_nonmember_event to be called with shadow-banned users. (#8158) 2020-08-25 10:52:15 -04:00
Brendan Abolivier
6e1c64a668 Merge tag 'v1.19.1rc1' into develop
Synapse 1.19.1rc1 (2020-08-25)
==============================

Bugfixes
--------

- Fix a bug introduced in v1.19.0 where appservices with ratelimiting disabled would still be ratelimited when joining rooms. ([\#8139](https://github.com/matrix-org/synapse/issues/8139))
- Fix a bug introduced in v1.19.0 that would cause e.g. profile updates to fail due to incorrect application of rate limits on join requests. ([\#8153](https://github.com/matrix-org/synapse/issues/8153))
2020-08-25 15:48:11 +01:00
Brendan Abolivier
0a4e541dc5 Changelog fixes 2020-08-25 15:29:57 +01:00
Brendan Abolivier
b79d69796c 1.19.1rc1 2020-08-25 15:24:39 +01:00
Erik Johnston
2231dffee6 Make StreamIdGen get_next and get_next_mult async (#8161)
This is mainly so that `StreamIdGenerator` and `MultiWriterIdGenerator`
will have the same interface, allowing them to be used interchangeably.
2020-08-25 15:10:08 +01:00
Andrew Morgan
74bf8d4d06 Wording fixes to 'name' user admin api filter (#8163)
Some fixes to wording I noticed after merging #7377.
2020-08-25 15:03:24 +01:00
Andrew Morgan
79ac619403 Fix missing double-backtick in RST document 2020-08-25 14:24:06 +01:00
Manuel Stahl
97962ad17b Search in columns 'name' and 'displayname' in the admin users endpoint (#7377)
* Search in columns 'name' and 'displayname' in the admin users endpoint

Signed-off-by: Manuel Stahl <manuel.stahl@awesome-technologies.de>
2020-08-25 14:18:14 +01:00
Patrick Cloke
5758dcf30c Add type hints for state. (#8140) 2020-08-24 14:25:27 -04:00
Patrick Cloke
cbd8d83da7 Stop shadow-banned users from sending non-member events. (#8142) 2020-08-24 13:58:56 -04:00
Brendan Abolivier
420484a334 Allow capping a room's retention policy (#8104) 2020-08-24 18:21:04 +01:00
Christopher May-Townsend
64e8a4697a Add healthcheck for default localhost 8008 port on /health endpoint. (#8147) 2020-08-24 18:15:18 +01:00
Patrick Cloke
3f8f96be00 Fix flaky shadow-ban tests. (#8152) 2020-08-24 13:08:33 -04:00
Brendan Abolivier
393a811a41 Fix join ratelimiter breaking profile updates and idempotency (#8153) 2020-08-24 18:06:04 +01:00
Will Hunt
2df82ae451 Do not apply ratelimiting on joins to appservices (#8139)
Add new method ratelimiter.can_requester_do_action and ensure that appservices are exempt from being ratelimited.

Co-authored-by: Patrick Cloke <clokep@users.noreply.github.com>
Co-authored-by: Erik Johnston <erik@matrix.org>
2020-08-24 14:53:53 +01:00
Brendan Abolivier
3f49f74610 Don't fail /submit_token requests on incorrect session ID if request_token_inhibit_3pid_errors is turned on (#7991)
* Don't raise session_id errors on submit_token if request_token_inhibit_3pid_errors is set

* Changelog

* Also wait some time before responding to /requestToken

* Incorporate review

* Update synapse/storage/databases/main/registration.py

Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com>

* Incorporate review

Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com>
2020-08-24 11:33:55 +01:00
Will Hunt
cbbf9126cb Do not apply ratelimiting on joins to appservices (#8139)
Add new method ratelimiter.can_requester_do_action and ensure that appservices are exempt from being ratelimited.

Co-authored-by: Patrick Cloke <clokep@users.noreply.github.com>
Co-authored-by: Erik Johnston <erik@matrix.org>
2020-08-21 15:07:56 +01:00
Richard van der Hoff
09fd0eda81 Micro-optimisations to get_auth_chain_ids (#8132) 2020-08-21 10:06:45 +01:00
Patrick Cloke
3f91638da6 Allow denying or shadow banning registrations via the spam checker (#8034) 2020-08-20 15:42:58 -04:00
Patrick Cloke
e259d63f73 Stop shadow-banned users from sending invites. (#8095) 2020-08-20 15:07:42 -04:00
Richard van der Hoff
318f4e738e Be more tolerant of membership events in unknown rooms (#8110)
It turns out that not all out-of-band membership events are labelled as such,
so we need to be more accepting here.
2020-08-20 16:42:12 +01:00
Patrick Cloke
592cdf73be Improve the error code when trying to register using a name reserved for guests. (#8135) 2020-08-20 10:39:41 -04:00
Patrick Cloke
dbc630a628 Use the JSON encoder without whitespace in more places. (#8124) 2020-08-20 10:32:33 -04:00
Patrick Cloke
5eac0b7e76 Add more types to synapse.storage.database. (#8127) 2020-08-20 09:00:59 -04:00
Patrick Cloke
731dfff347 Ensure a group ID is valid before trying to get rooms for it. (#8129) 2020-08-20 06:41:32 -04:00
Patrick Cloke
76c43f086a Do not assume calls to runInteraction return Deferreds. (#8133) 2020-08-20 06:39:55 -04:00
Richard van der Hoff
12aebdfa5a Close the database connection we create during startup (#8131)
... otherwise it gets leaked.
2020-08-19 20:41:53 +01:00
Erik Johnston
c9c544cda5 Remove ChainedIdGenerator. (#8123)
It's just a thin wrapper around two ID gens to make `get_current_token`
and `get_next` return tuples. This can easily be replaced by calling the
appropriate methods on the underlying ID gens directly.
2020-08-19 13:41:51 +01:00
Patrick Cloke
f594e434c3 Switch the JSON byte producer from a pull to a push producer. (#8116) 2020-08-19 08:07:57 -04:00
Ryan Cole
cfeb37f039 Updated docs: Added note about missing 308 redirect support. (#8120)
* Updated docs: Added note about missing 308 redirect support.

* Added changelog
2020-08-19 12:26:50 +01:00
Patrick Cloke
eebf52be06 Be stricter about JSON that is accepted by Synapse (#8106) 2020-08-19 07:26:03 -04:00
Patrick Cloke
d89692ea84 Convert runWithConnection to async. (#8121) 2020-08-19 07:09:24 -04:00
Patrick Cloke
d294f0e7e1 Remove the unused inlineCallbacks code-paths in the caching code (#8119) 2020-08-19 07:09:07 -04:00
Erik Johnston
76d21d14a0 Separate get_current_token into two. (#8113)
The function is used for two purposes: 1) for subscribers of streams to
get a token they can use to get further updates with, and 2) for
replication to track position of the writers of the stream.

For streams with a single writer the two scenarios produce the same
result, however the situation becomes complicated for streams with
multiple writers. The current `MultiWriterIdGenerator` does not
correctly handle the first case (which is not an issue as its only used
for the `caches` stream which nothing subscribes to outside of
replication).
2020-08-19 10:39:31 +01:00
Patrick Cloke
f40645e60b Convert events worker database to async/await. (#8071) 2020-08-18 16:20:49 -04:00
Andrew Morgan
af21fbb338 Simplify medium and address assignment 2020-06-25 11:05:52 +01:00
Andrew Morgan
cb272bcfe8 Explain why we rate-limit using a threepid 2020-06-25 11:03:10 +01:00
Andrew Morgan
d9277e94f3 Don't lowercase medium in this PR 2020-06-16 12:00:57 +01:00
Andrew Morgan
b1c0eb3178 Docstring spacing 2020-06-16 11:39:19 +01:00
Andrew Morgan
53981c31e9 Change SynapseError comment 2020-06-16 11:33:16 +01:00
Andrew Morgan
efb5670845 Update synapse/handlers/auth.py
Co-authored-by: Patrick Cloke <clokep@users.noreply.github.com>
2020-06-16 11:33:16 +01:00
Andrew Morgan
b8f4b0c27c Use assert_param_in_dict 2020-06-16 11:33:13 +01:00
Andrew Morgan
187623517b pop() instead of pull then del 2020-06-16 11:09:16 +01:00
Andrew Morgan
7184c16f95 Change login_id_phone_to_thirdparty to return a dict again 2020-06-16 11:03:49 +01:00
Andrew Morgan
699904c9d8 Changelog 2020-06-12 14:42:58 +01:00
Andrew Morgan
358e51be86 Add some tests for m.id.phone and m.id.thirdparty 2020-06-12 14:42:56 +01:00
Andrew Morgan
18071156e4 Remove placeholders/dummy classes for supporting identifiers in existing tests 2020-06-12 14:42:21 +01:00
Andrew Morgan
cb64c956f0 Comment cleanups, log on KeyError during login 2020-06-12 14:42:21 +01:00
Andrew Morgan
f240a8d182 Reconfigure m.login.password authdict checker to process identifiers 2020-06-12 14:42:21 +01:00
Andrew Morgan
7044c1f4fb Factor out identifier -> username conversion into its own method
We then use this in both login and authhandler, the latter being where we process m.login.password
User Interactive Authentication responses, which can now include identifiers
2020-06-12 14:42:21 +01:00
Andrew Morgan
b674bb8500 Move utility methods from login handler to auth handler 2020-06-12 14:42:18 +01:00
143 changed files with 2978 additions and 1344 deletions


@@ -12,6 +12,16 @@ from Synapse as most users have updated their client. Further context can be
found at [\#6766](https://github.com/matrix-org/synapse/issues/6766).
Synapse 1.19.1rc1 (2020-08-25)
==============================

Bugfixes
--------

- Fix a bug introduced in v1.19.0 where appservices with ratelimiting disabled would still be ratelimited when joining rooms. ([\#8139](https://github.com/matrix-org/synapse/issues/8139))
- Fix a bug introduced in v1.19.0 that would cause e.g. profile updates to fail due to incorrect application of rate limits on join requests. ([\#8153](https://github.com/matrix-org/synapse/issues/8153))

Synapse 1.19.0 (2020-08-17)
===========================

changelog.d/7377.misc Normal file
@@ -0,0 +1 @@
Add filter `name` to the `/users` admin API, which filters by user ID or displayname. Contributed by Awesome Technologies Innovationslabor GmbH.

changelog.d/7438.feature Normal file
@@ -0,0 +1 @@
Support `identifier` dictionary fields in User-Interactive Authentication flows. Relax requirement of the `user` parameter.

changelog.d/7991.misc Normal file
@@ -0,0 +1 @@
Don't fail `/submit_token` requests on incorrect session ID if `request_token_inhibit_3pid_errors` is turned on.

changelog.d/8034.feature Normal file
@@ -0,0 +1 @@
Add support for shadow-banning users (ignoring any message send requests).

changelog.d/8071.misc Normal file
@@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

changelog.d/8095.feature Normal file
@@ -0,0 +1 @@
Add support for shadow-banning users (ignoring any message send requests).

changelog.d/8104.bugfix Normal file
@@ -0,0 +1 @@
Fix a bug introduced in v1.7.2 impacting message retention policies that would allow federated homeservers to dictate a retention period that's lower than the configured minimum allowed duration in the configuration file.

changelog.d/8106.bugfix Normal file
@@ -0,0 +1 @@
Fix a long-standing bug where invalid JSON would be accepted by Synapse.

changelog.d/8110.bugfix Normal file
@@ -0,0 +1 @@
Fix a bug introduced in Synapse 1.12.0 which could cause `/sync` requests to fail with a 404 if you had a very old outstanding room invite.

changelog.d/8113.misc Normal file
@@ -0,0 +1 @@
Separate `get_current_token` into two since there are two different use cases for it.

changelog.d/8116.feature Normal file
@@ -0,0 +1 @@
Iteratively encode JSON to avoid blocking the reactor.

changelog.d/8119.misc Normal file
@@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

changelog.d/8120.doc Normal file
@@ -0,0 +1 @@
Updated documentation to note that Synapse does not follow `HTTP 308` redirects due to an upstream library not supporting them. Contributed by Ryan Cole.

changelog.d/8121.misc Normal file
@@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

changelog.d/8123.misc Normal file
@@ -0,0 +1 @@
Remove `ChainedIdGenerator`.

changelog.d/8124.misc Normal file
@@ -0,0 +1 @@
Reduce the amount of whitespace in JSON stored and sent in responses.

changelog.d/8127.misc Normal file
@@ -0,0 +1 @@
Add type hints to `synapse.storage.database`.

changelog.d/8129.bugfix Normal file
@@ -0,0 +1 @@
Return a proper error code when the rooms of an invalid group are requested.

changelog.d/8131.bugfix Normal file
@@ -0,0 +1 @@
Fix a bug which could cause a leaked postgres connection if synapse was set to daemonize.

changelog.d/8132.misc Normal file
@@ -0,0 +1 @@
Micro-optimisations to get_auth_chain_ids.

changelog.d/8133.misc Normal file
@@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

changelog.d/8135.bugfix Normal file
@@ -0,0 +1 @@
Clarify the error code if a user tries to register with a numeric ID. This bug was introduced in v1.15.0.

changelog.d/8139.bugfix Normal file
@@ -0,0 +1 @@
Fixes a bug where appservices with ratelimiting disabled would still be ratelimited when joining rooms. This bug was introduced in v1.19.0.

changelog.d/8140.misc Normal file
@@ -0,0 +1 @@
Add type hints to `synapse.state`.

changelog.d/8142.feature Normal file
@@ -0,0 +1 @@
Add support for shadow-banning users (ignoring any message send requests).

changelog.d/8147.docker Normal file
@@ -0,0 +1 @@
Added curl for healthcheck support and readme updates for the change. Contributed by @maquis196.

changelog.d/8152.feature Normal file
@@ -0,0 +1 @@
Add support for shadow-banning users (ignoring any message send requests).

changelog.d/8158.feature Normal file
@@ -0,0 +1 @@
Add support for shadow-banning users (ignoring any message send requests).

changelog.d/8161.misc Normal file
@@ -0,0 +1 @@
Refactor `StreamIdGenerator` and `MultiWriterIdGenerator` to have the same interface.

changelog.d/8163.misc Normal file
@@ -0,0 +1 @@
Add filter `name` to the `/users` admin API, which filters by user ID or displayname. Contributed by Awesome Technologies Innovationslabor GmbH.

changelog.d/8164.misc Normal file
@@ -0,0 +1 @@
Add functions to `MultiWriterIdGen` used by events stream.


@@ -55,6 +55,7 @@ RUN pip install --prefix="/install" --no-warn-script-location \
FROM docker.io/python:${PYTHON_VERSION}-slim
RUN apt-get update && apt-get install -y \
curl \
libpq5 \
xmlsec1 \
gosu \
@@ -69,3 +70,6 @@ VOLUME ["/data"]
EXPOSE 8008/tcp 8009/tcp 8448/tcp
ENTRYPOINT ["/start.py"]
HEALTHCHECK --interval=1m --timeout=5s \
CMD curl -fSs http://localhost:8008/health || exit 1


@@ -162,3 +162,32 @@ docker build -t matrixdotorg/synapse -f docker/Dockerfile .
You can choose to build a different docker image by changing the value of the `-f` flag to
point to another Dockerfile.

## Disabling the healthcheck

If you are using a non-standard port or TLS inside docker, you can disable the healthcheck
whilst running the above `docker run` commands.
```
--no-healthcheck
```

## Setting a custom healthcheck on docker run

If you wish to point the healthcheck at a different port with the docker command, add the following:
```
--health-cmd 'curl -fSs http://localhost:1234/health'
```

## Setting the healthcheck in a docker-compose file

You can add the following to set a custom healthcheck in a docker-compose file.
You will need compose file format version 2.1 or later for this to work.
```
healthcheck:
test: ["CMD", "curl", "-fSs", "http://localhost:8008/health"]
interval: 1m
timeout: 10s
retries: 3
```
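
For reference, the healthcheck is nothing more than an HTTP GET against the `/health` endpoint. A minimal Python equivalent of the `curl` command used above (illustrative only, assuming the default port):

```
import requests

# Equivalent of `curl -fSs http://localhost:8008/health`: any non-2xx
# response or timeout would mark the container unhealthy.
resp = requests.get("http://localhost:8008/health", timeout=5)
resp.raise_for_status()
```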


@@ -108,7 +108,7 @@ The api is::
GET /_synapse/admin/v2/users?from=0&limit=10&guests=false
To use it, you will need to authenticate by providing an `access_token` for a
To use it, you will need to authenticate by providing an ``access_token`` for a
server admin: see `README.rst <README.rst>`_.
The parameter ``from`` is optional but used for pagination, denoting the
@@ -119,8 +119,11 @@ from a previous call.
The parameter ``limit`` is optional but is used for pagination, denoting the
maximum number of items to return in this call. Defaults to ``100``.
The parameter ``user_id`` is optional and filters to only users with user IDs
that contain this value.
The parameter ``user_id`` is optional and filters to only return users with user IDs
that contain this value. This parameter is ignored when using the ``name`` parameter.
The parameter ``name`` is optional and filters to only return users with user ID localparts
**or** displaynames that contain this value.
The parameter ``guests`` is optional and if ``false`` will **exclude** guest users.
Defaults to ``true`` to include guest users.
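
As an illustrative sketch of combining these parameters (the server URL and token are placeholders, and passing the token via the ``Authorization`` header is an assumption about transport, not part of this document):

import requests

# List up to 10 non-guest users whose user ID localpart or displayname
# contains "alice". The access token must belong to a server admin.
resp = requests.get(
    "http://localhost:8008/_synapse/admin/v2/users",
    params={"from": 0, "limit": 10, "guests": "false", "name": "alice"},
    headers={"Authorization": "Bearer YOUR_ACCESS_TOKEN"},
)
resp.raise_for_status()
for user in resp.json().get("users", []):
    print(user["name"], user.get("displayname"))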


@@ -47,6 +47,18 @@ you invite them to. This can be caused by an incorrectly-configured reverse
proxy: see [reverse_proxy.md](<reverse_proxy.md>) for instructions on how to correctly
configure a reverse proxy.
### Known issues
**HTTP `308 Permanent Redirect` redirects are not followed**: Due to missing features
in the HTTP library used by Synapse, 308 redirects are currently not followed by
federating servers, which can cause `M_UNKNOWN` or `401 Unauthorized` errors. This
may affect users who are redirecting apex-to-www (e.g. `example.com` -> `www.example.com`),
and especially users of the Kubernetes *Nginx Ingress* module, which uses 308 redirect
codes by default. For those Kubernetes users, [this Stackoverflow post](https://stackoverflow.com/a/52617528/5096871)
might be helpful. For other users, switching to a `301 Moved Permanently` code may be
an option. 308 redirect codes will be supported properly in a future
release of Synapse.
## Running a demo federation of Synapses
If you want to get up and running quickly with a trio of homeservers in a


@@ -378,11 +378,10 @@ retention:
# min_lifetime: 1d
# max_lifetime: 1y
# Retention policy limits. If set, a user won't be able to send a
# 'm.room.retention' event which features a 'min_lifetime' or a 'max_lifetime'
# that's not within this range. This is especially useful in closed federations,
# in which server admins can make sure every federating server applies the same
# rules.
# Retention policy limits. If set, and the state of a room contains a
# 'm.room.retention' event in its state which contains a 'min_lifetime' or a
# 'max_lifetime' that's out of these bounds, Synapse will cap the room's policy
# to these limits when running purge jobs.
#
#allowed_lifetime_min: 1d
#allowed_lifetime_max: 1y
@@ -408,12 +407,19 @@ retention:
# (e.g. every 12h), but not want that purge to be performed by a job that's
# iterating over every room it knows, which could be heavy on the server.
#
# If any purge job is configured, it is strongly recommended to have at least
# a single job with neither 'shortest_max_lifetime' nor 'longest_max_lifetime'
# set, or one job without 'shortest_max_lifetime' and one job without
# 'longest_max_lifetime' set. Otherwise some rooms might be ignored, even if
# 'allowed_lifetime_min' and 'allowed_lifetime_max' are set, because capping a
# room's policy to these values is done after the policies are retrieved from
# Synapse's database (which is done using the range specified in a purge job's
# configuration).
#
#purge_jobs:
# - shortest_max_lifetime: 1d
# longest_max_lifetime: 3d
# - longest_max_lifetime: 3d
# interval: 12h
# - shortest_max_lifetime: 3d
# longest_max_lifetime: 1y
# interval: 1d
# Inhibits the /requestToken endpoints from returning an error that might leak

stubs/frozendict.pyi Normal file

@@ -0,0 +1,47 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Stub for frozendict.
from typing import (
Any,
Hashable,
Iterable,
Iterator,
Mapping,
overload,
Tuple,
TypeVar,
)
_KT = TypeVar("_KT", bound=Hashable) # Key type.
_VT = TypeVar("_VT") # Value type.
class frozendict(Mapping[_KT, _VT]):
@overload
def __init__(self, **kwargs: _VT) -> None: ...
@overload
def __init__(self, __map: Mapping[_KT, _VT], **kwargs: _VT) -> None: ...
@overload
def __init__(
self, __iterable: Iterable[Tuple[_KT, _VT]], **kwargs: _VT
) -> None: ...
def __getitem__(self, key: _KT) -> _VT: ...
def __contains__(self, key: Any) -> bool: ...
def copy(self, **add_or_replace: Any) -> frozendict: ...
def __iter__(self) -> Iterator[_KT]: ...
def __len__(self) -> int: ...
def __repr__(self) -> str: ...
def __hash__(self) -> int: ...


@@ -48,7 +48,7 @@ try:
except ImportError:
pass
__version__ = "1.19.0"
__version__ = "1.19.1rc1"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when


@@ -21,10 +21,10 @@ import typing
from http import HTTPStatus
from typing import Dict, List, Optional, Union
from canonicaljson import json
from twisted.web import http
from synapse.util import json_decoder
if typing.TYPE_CHECKING:
from synapse.types import JsonDict
@@ -593,7 +593,7 @@ class HttpResponseException(CodeMessageException):
# try to parse the body as json, to get better errcode/msg, but
# default to M_UNKNOWN with the HTTP status as the error text
try:
j = json.loads(self.response.decode("utf-8"))
j = json_decoder.decode(self.response.decode("utf-8"))
except ValueError:
j = {}
@@ -604,3 +604,11 @@ class HttpResponseException(CodeMessageException):
errmsg = j.pop("error", self.msg)
return ProxiedRequestError(self.code, errmsg, errcode, j)
class ShadowBanError(Exception):
"""
Raised when a shadow-banned user attempts to perform an action.
This should be caught and a proper "fake" success response sent to the user.
"""

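A hedged sketch of the catch-and-fake pattern this docstring calls for (`create_and_send_nonmember_event` and its `ShadowBanError` behaviour are shown later in this diff; the fake event-ID shape is an assumption, not Synapse's exact code):

import random
import string

from synapse.api.errors import ShadowBanError

async def send_message(handler, requester, event_dict):
    try:
        event, _stream_id = await handler.create_and_send_nonmember_event(
            requester, event_dict
        )
        return {"event_id": event.event_id}
    except ShadowBanError:
        # Pretend the send worked so the shadow-banned user can't tell.
        fake_id = "$" + "".join(random.choices(string.ascii_letters, k=43))
        return {"event_id": fake_id}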

@@ -17,6 +17,7 @@ from collections import OrderedDict
from typing import Any, Optional, Tuple
from synapse.api.errors import LimitExceededError
from synapse.types import Requester
from synapse.util import Clock
@@ -43,6 +44,42 @@ class Ratelimiter(object):
# * The rate_hz of this particular entry. This can vary per request
self.actions = OrderedDict() # type: OrderedDict[Any, Tuple[float, int, float]]
def can_requester_do_action(
self,
requester: Requester,
rate_hz: Optional[float] = None,
burst_count: Optional[int] = None,
update: bool = True,
_time_now_s: Optional[int] = None,
) -> Tuple[bool, float]:
"""Can the requester perform the action?
Args:
requester: The requester to key off when rate limiting. The user property
will be used.
rate_hz: The long term number of actions that can be performed in a second.
Overrides the value set during instantiation if set.
burst_count: How many actions that can be performed before being limited.
Overrides the value set during instantiation if set.
update: Whether to count this check as performing the action
_time_now_s: The current time. Optional, defaults to the current time according
to self.clock. Only used by tests.
Returns:
A tuple containing:
* A bool indicating if they can perform the action now
* The reactor timestamp for when the action can be performed next.
-1 if rate_hz is less than or equal to zero
"""
# Disable rate limiting of users belonging to any AS that is configured
# not to be rate limited in its registration file (rate_limited: true|false).
if requester.app_service and not requester.app_service.is_rate_limited():
return True, -1.0
return self.can_do_action(
requester.user.to_string(), rate_hz, burst_count, update, _time_now_s
)
def can_do_action(
self,
key: Any,

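A hedged usage sketch for the new method (the `Ratelimiter` instance, requester and clock value are assumed to exist; the rate numbers are illustrative):

from synapse.api.errors import LimitExceededError

def check_can_join(ratelimiter, requester, time_now_s):
    # Appservices registered with rate_limited: false are exempted inside
    # can_requester_do_action itself, so no special-casing is needed here.
    allowed, time_allowed = ratelimiter.can_requester_do_action(
        requester, rate_hz=0.17, burst_count=3
    )
    if not allowed:
        raise LimitExceededError(
            retry_after_ms=int(1000 * (time_allowed - time_now_s))
        )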

@@ -961,11 +961,10 @@ class ServerConfig(Config):
# min_lifetime: 1d
# max_lifetime: 1y
# Retention policy limits. If set, a user won't be able to send a
# 'm.room.retention' event which features a 'min_lifetime' or a 'max_lifetime'
# that's not within this range. This is especially useful in closed federations,
# in which server admins can make sure every federating server applies the same
# rules.
# Retention policy limits. If set, and the state of a room contains a
# 'm.room.retention' event in its state which contains a 'min_lifetime' or a
# 'max_lifetime' that's out of these bounds, Synapse will cap the room's policy
# to these limits when running purge jobs.
#
#allowed_lifetime_min: 1d
#allowed_lifetime_max: 1y
@@ -991,12 +990,19 @@ class ServerConfig(Config):
# (e.g. every 12h), but not want that purge to be performed by a job that's
# iterating over every room it knows, which could be heavy on the server.
#
# If any purge job is configured, it is strongly recommended to have at least
# a single job with neither 'shortest_max_lifetime' nor 'longest_max_lifetime'
# set, or one job without 'shortest_max_lifetime' and one job without
# 'longest_max_lifetime' set. Otherwise some rooms might be ignored, even if
# 'allowed_lifetime_min' and 'allowed_lifetime_max' are set, because capping a
# room's policy to these values is done after the policies are retrieved from
# Synapse's database (which is done using the range specified in a purge job's
# configuration).
#
#purge_jobs:
# - shortest_max_lifetime: 1d
# longest_max_lifetime: 3d
# - longest_max_lifetime: 3d
# interval: 12h
# - shortest_max_lifetime: 3d
# longest_max_lifetime: 1y
# interval: 1d
# Inhibits the /requestToken endpoints from returning an error that might leak


@@ -757,9 +757,8 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
except Exception:
logger.exception("Error getting keys %s from %s", key_ids, server_name)
return await yieldable_gather_results(
get_key, keys_to_fetch.items()
).addCallback(lambda _: results)
await yieldable_gather_results(get_key, keys_to_fetch.items())
return results
async def get_server_verify_key_v2_direct(self, server_name, key_ids):
"""
@@ -769,7 +768,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
key_ids (iterable[str]):
Returns:
Deferred[dict[str, FetchKeyResult]]: map from key ID to lookup result
dict[str, FetchKeyResult]: map from key ID to lookup result
Raises:
KeyLookupError if there was a problem making the lookup


@@ -47,7 +47,7 @@ def check(
Args:
room_version_obj: the version of the room
event: the event being checked.
auth_events (dict: event-key -> event): the existing room state.
auth_events: the existing room state.
Raises:
AuthError if the checks fail


@@ -133,6 +133,8 @@ class _EventInternalMetadata(object):
rejection. This is needed as those events are marked as outliers, but
they still need to be processed as if they're new events (e.g. updating
invite state in the database, relaying to clients, etc).
(Added in synapse 0.99.0, so may be unreliable for events received before that)
"""
return self._dict.get("out_of_band_membership", False)


@@ -15,9 +15,10 @@
# limitations under the License.
import inspect
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional, Tuple
from synapse.spam_checker_api import SpamCheckerApi
from synapse.spam_checker_api import RegistrationBehaviour, SpamCheckerApi
from synapse.types import Collection
MYPY = False
if MYPY:
@@ -160,3 +161,33 @@ class SpamChecker(object):
return True
return False
def check_registration_for_spam(
self,
email_threepid: Optional[dict],
username: Optional[str],
request_info: Collection[Tuple[str, str]],
) -> RegistrationBehaviour:
"""Checks if we should allow the given registration request.
Args:
email_threepid: The email threepid used for registering, if any
username: The request user name, if any
request_info: List of tuples of user agent and IP that
were used during the registration process.
Returns:
Enum for how the request should be handled
"""
for spam_checker in self.spam_checkers:
# For backwards compatibility, only run if the method exists on the
# spam checker
checker = getattr(spam_checker, "check_registration_for_spam", None)
if checker:
behaviour = checker(email_threepid, username, request_info)
assert isinstance(behaviour, RegistrationBehaviour)
if behaviour != RegistrationBehaviour.ALLOW:
return behaviour
return RegistrationBehaviour.ALLOW
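A hedged sketch of a third-party spam-checker module implementing the new callback (`BlockDisposableEmails` and the email domain are made up; `RegistrationBehaviour.SHADOW_BAN` is assumed from the enum added in #8034):

from typing import Collection, Optional, Tuple

from synapse.spam_checker_api import RegistrationBehaviour

class BlockDisposableEmails:
    def check_registration_for_spam(
        self,
        email_threepid: Optional[dict],
        username: Optional[str],
        request_info: Collection[Tuple[str, str]],
    ) -> RegistrationBehaviour:
        # Shadow-ban anyone registering with a throwaway email domain.
        if email_threepid and email_threepid.get("address", "").endswith(
            "@disposable.example"
        ):
            return RegistrationBehaviour.SHADOW_BAN
        return RegistrationBehaviour.ALLOW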


@@ -74,15 +74,14 @@ class EventValidator(object):
)
if event.type == EventTypes.Retention:
self._validate_retention(event, config)
self._validate_retention(event)
def _validate_retention(self, event, config):
def _validate_retention(self, event):
"""Checks that an event that defines the retention policy for a room respects the
boundaries imposed by the server's administrator.
format enforced by the spec.
Args:
event (FrozenEvent): The event to validate.
config (Config): The homeserver's configuration.
"""
min_lifetime = event.content.get("min_lifetime")
max_lifetime = event.content.get("max_lifetime")
@@ -95,32 +94,6 @@ class EventValidator(object):
errcode=Codes.BAD_JSON,
)
if (
config.retention_allowed_lifetime_min is not None
and min_lifetime < config.retention_allowed_lifetime_min
):
raise SynapseError(
code=400,
msg=(
"'min_lifetime' can't be lower than the minimum allowed"
" value enforced by the server's administrator"
),
errcode=Codes.BAD_JSON,
)
if (
config.retention_allowed_lifetime_max is not None
and min_lifetime > config.retention_allowed_lifetime_max
):
raise SynapseError(
code=400,
msg=(
"'min_lifetime' can't be greater than the maximum allowed"
" value enforced by the server's administrator"
),
errcode=Codes.BAD_JSON,
)
if max_lifetime is not None:
if not isinstance(max_lifetime, int):
raise SynapseError(
@@ -129,32 +102,6 @@ class EventValidator(object):
errcode=Codes.BAD_JSON,
)
if (
config.retention_allowed_lifetime_min is not None
and max_lifetime < config.retention_allowed_lifetime_min
):
raise SynapseError(
code=400,
msg=(
"'max_lifetime' can't be lower than the minimum allowed value"
" enforced by the server's administrator"
),
errcode=Codes.BAD_JSON,
)
if (
config.retention_allowed_lifetime_max is not None
and max_lifetime > config.retention_allowed_lifetime_max
):
raise SynapseError(
code=400,
msg=(
"'max_lifetime' can't be greater than the maximum allowed"
" value enforced by the server's administrator"
),
errcode=Codes.BAD_JSON,
)
if (
min_lifetime is not None
and max_lifetime is not None


@@ -28,7 +28,6 @@ from typing import (
Union,
)
from canonicaljson import json
from prometheus_client import Counter, Histogram
from twisted.internet import defer
@@ -63,7 +62,7 @@ from synapse.replication.http.federation import (
ReplicationGetQueryRestServlet,
)
from synapse.types import JsonDict, get_domain_from_id
from synapse.util import glob_to_regex, unwrapFirstError
from synapse.util import glob_to_regex, json_decoder, unwrapFirstError
from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.caches.response_cache import ResponseCache
@@ -551,7 +550,7 @@ class FederationServer(FederationBase):
for device_id, keys in device_keys.items():
for key_id, json_str in keys.items():
json_result.setdefault(user_id, {})[device_id] = {
key_id: json.loads(json_str)
key_id: json_decoder.decode(json_str)
}
logger.info(


@@ -329,10 +329,10 @@ class FederationSender(object):
room_id = receipt.room_id
# Work out which remote servers should be poked and poke them.
domains = await self.state.get_current_hosts_in_room(room_id)
domains_set = await self.state.get_current_hosts_in_room(room_id)
domains = [
d
for d in domains
for d in domains_set
if d != self.server_name
and self._federation_shard_config.should_handle(self._instance_name, d)
]


@@ -15,8 +15,6 @@
import logging
from typing import TYPE_CHECKING, List, Tuple
from canonicaljson import json
from synapse.api.errors import HttpResponseException
from synapse.events import EventBase
from synapse.federation.persistence import TransactionActions
@@ -28,6 +26,7 @@ from synapse.logging.opentracing import (
tags,
whitelisted_homeserver,
)
from synapse.util import json_decoder
from synapse.util.metrics import measure_func
if TYPE_CHECKING:
@@ -71,7 +70,7 @@ class TransactionManager(object):
for edu in pending_edus:
context = edu.get_context()
if context:
span_contexts.append(extract_text_map(json.loads(context)))
span_contexts.append(extract_text_map(json_decoder.decode(context)))
if keep_destination:
edu.strip_context()


@@ -38,12 +38,14 @@ from synapse.api.ratelimiting import Ratelimiter
from synapse.handlers.ui_auth import INTERACTIVE_AUTH_CHECKERS
from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
from synapse.http.server import finish_request, respond_with_html
from synapse.http.servlet import assert_params_in_dict
from synapse.http.site import SynapseRequest
from synapse.logging.context import defer_to_thread
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.module_api import ModuleApi
from synapse.types import Requester, UserID
from synapse.util import stringutils as stringutils
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.threepids import canonicalise_email
from ._base import BaseHandler
@@ -51,6 +53,82 @@ from ._base import BaseHandler
logger = logging.getLogger(__name__)
def client_dict_convert_legacy_fields_to_identifier(
submission: Dict[str, Union[str, Dict]]
):
"""
Convert a legacy-formatted login submission to an identifier dict.
Legacy login submissions (used in both login and user-interactive authentication)
provide user-identifying information at the top level instead of in an `identifier`
property. This is now deprecated and replaced with identifiers:
https://matrix.org/docs/spec/client_server/r0.6.1#identifier-types
Args:
submission: The client dict to convert. Passed by reference and modified
Raises:
SynapseError: If the format of the client dict is invalid
"""
if "user" in submission:
submission["identifier"] = {"type": "m.id.user", "user": submission.pop("user")}
if "medium" in submission and "address" in submission:
submission["identifier"] = {
"type": "m.id.thirdparty",
"medium": submission.pop("medium"),
"address": submission.pop("address"),
}
# We've converted valid, legacy login submissions to an identifier. If the
# dict still doesn't have an identifier, it's invalid
assert_params_in_dict(submission, required=["identifier"])
# Ensure the identifier has a type
if "type" not in submission["identifier"]:
raise SynapseError(
400, "'identifier' dict has no key 'type'", errcode=Codes.MISSING_PARAM,
)
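An illustrative before/after for this in-place conversion (the values are made up):

submission = {"type": "m.login.password", "user": "alice", "password": "hunter2"}
client_dict_convert_legacy_fields_to_identifier(submission)
# submission now carries an identifier dict instead of a top-level "user" key:
# {"type": "m.login.password",
#  "identifier": {"type": "m.id.user", "user": "alice"},
#  "password": "hunter2"}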
def login_id_phone_to_thirdparty(identifier: Dict[str, str]) -> Dict[str, str]:
"""Convert a phone login identifier type to a generic threepid identifier.
Args:
identifier: Login identifier dict of type 'm.id.phone'
Returns:
An equivalent m.id.thirdparty identifier dict.
"""
if "type" not in identifier:
raise SynapseError(
400, "Invalid phone-type identifier", errcode=Codes.MISSING_PARAM
)
if "country" not in identifier or (
# XXX: We used to require `number` instead of `phone`. The spec
# defines `phone`. So accept both
"phone" not in identifier
and "number" not in identifier
):
raise SynapseError(
400, "Invalid phone-type identifier", errcode=Codes.INVALID_PARAM
)
# Accept both "phone" and "number" as valid keys in m.id.phone
phone_number = identifier.get("phone", identifier.get("number"))
# Convert user-provided phone number to a consistent representation
msisdn = phone_number_to_msisdn(identifier["country"], phone_number)
# Return the new dictionary
return {
"type": "m.id.thirdparty",
"medium": "msisdn",
"address": msisdn,
}
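For illustration, with a fictional UK test number (the msisdn form comes from `phone_number_to_msisdn`):

identifier = {"type": "m.id.phone", "country": "GB", "phone": "07700900123"}
login_id_phone_to_thirdparty(identifier)
# -> {"type": "m.id.thirdparty", "medium": "msisdn", "address": "447700900123"}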
class AuthHandler(BaseHandler):
SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
@@ -319,7 +397,7 @@ class AuthHandler(BaseHandler):
# otherwise use whatever was last provided.
#
# This was designed to allow the client to omit the parameters
# and just supply the session in subsequent calls so it split
# and just supply the session in subsequent calls. So it splits
# auth between devices by just sharing the session, (eg. so you
# could continue registration from your phone having clicked the
# email auth link on there). It's probably too open to abuse
@@ -364,6 +442,14 @@ class AuthHandler(BaseHandler):
# authentication flow.
await self.store.set_ui_auth_clientdict(sid, clientdict)
user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[
0
].decode("ascii", "surrogateescape")
await self.store.add_user_agent_ip_to_ui_auth_session(
session.session_id, user_agent, clientip
)
if not authdict:
raise InteractiveAuthIncompleteError(
session.session_id, self._auth_dict_for_flows(flows, session.session_id)
@@ -516,16 +602,129 @@ class AuthHandler(BaseHandler):
res = await checker.check_auth(authdict, clientip=clientip)
return res
# build a v1-login-style dict out of the authdict and fall back to the
# v1 code
user_id = authdict.get("user")
# We don't have a checker for the auth type provided by the client
# Assume that it is `m.login.password`.
if login_type != LoginType.PASSWORD:
raise SynapseError(
400, "Unknown authentication type", errcode=Codes.INVALID_PARAM,
)
if user_id is None:
raise SynapseError(400, "", Codes.MISSING_PARAM)
password = authdict.get("password")
if password is None:
raise SynapseError(
400,
"Missing parameter for m.login.password dict: 'password'",
errcode=Codes.INVALID_PARAM,
)
# Retrieve the user ID using details provided in the authdict
# Deprecation notice: Clients used to be able to simply provide a
# `user` field which pointed to a user_id or localpart. This has
# been deprecated in favour of an `identifier` key, which is a
# dictionary providing information on how to identify a single
# user.
# https://matrix.org/docs/spec/client_server/r0.6.1#identifier-types
#
# We convert old-style dicts to new ones here
client_dict_convert_legacy_fields_to_identifier(authdict)
# Extract a user ID from the values in the identifier
username = await self.username_from_identifier(authdict["identifier"], password)
if username is None:
raise SynapseError(400, "Valid username not found")
# Now that we've found the username, validate that the password is correct
canonical_id, _ = await self.validate_login(username, authdict)
(canonical_id, callback) = await self.validate_login(user_id, authdict)
return canonical_id
async def username_from_identifier(
self, identifier: Dict[str, str], password: Optional[str] = None
) -> Optional[str]:
"""Given a dictionary containing an identifier from a client, extract the
possibly unqualified username of the user that it identifies. Does *not*
guarantee that the user exists.
If this identifier dict contains a threepid, we attempt to ask password
auth providers about it or, failing that, look up an associated user in
the database.
Args:
identifier: The identifier dictionary provided by the client
password: The user provided password if one exists. Used for asking
password auth providers for usernames from 3pid+password combos.
Returns:
A username if one was found, or None otherwise
Raises:
SynapseError: If the identifier dict is invalid
"""
# Convert phone type identifiers to generic threepid identifiers, which
# will be handled in the next step
if identifier["type"] == "m.id.phone":
identifier = login_id_phone_to_thirdparty(identifier)
# Convert a threepid identifier to a user identifier
if identifier["type"] == "m.id.thirdparty":
address = identifier.get("address")
medium = identifier.get("medium")
if not medium or not address:
# An error would've already been raised in
# `login_id_phone_to_thirdparty` if the original submission
# was a phone identifier
raise SynapseError(
400, "Invalid thirdparty identifier", errcode=Codes.INVALID_PARAM,
)
if medium == "email":
# For emails, transform the address to lowercase.
# We store all email addresses as lowercase in the DB.
# (See add_threepid in synapse/handlers/auth.py)
address = address.lower()
# Check for auth providers that support 3pid login types
if password is not None:
canonical_user_id, _ = await self.check_password_provider_3pid(
medium, address, password,
)
if canonical_user_id:
# Authentication through password provider and 3pid succeeded
return canonical_user_id
# Check local store
user_id = await self.hs.get_datastore().get_user_id_by_threepid(
medium, address
)
if not user_id:
# We were unable to find a user_id that belonged to the threepid returned
# by the password auth provider
return None
identifier = {"type": "m.id.user", "user": user_id}
# By this point, the identifier should be a `m.id.user`: if it's anything
# else, we haven't understood it.
if identifier["type"] != "m.id.user":
raise SynapseError(
400, "Unknown login identifier type", errcode=Codes.INVALID_PARAM,
)
# User identifiers have a "user" key
user = identifier.get("user")
if user is None:
raise SynapseError(
400,
"User identifier is missing 'user' key",
errcode=Codes.INVALID_PARAM,
)
return user
def _get_params_recaptcha(self) -> dict:
return {"public_key": self.hs.config.recaptcha_public_key}
@@ -690,7 +889,8 @@ class AuthHandler(BaseHandler):
m.login.password auth types.
Args:
username: username supplied by the user
username: a localpart or fully qualified user ID - what is provided by the
client
login_submission: the whole of the login submission
(including 'type' and other relevant fields)
Returns:
@@ -702,10 +902,10 @@ class AuthHandler(BaseHandler):
LoginError if there was an authentication problem.
"""
if username.startswith("@"):
qualified_user_id = username
else:
qualified_user_id = UserID(username, self.hs.hostname).to_string()
# We need a fully qualified User ID for some method calls here
qualified_user_id = username
if not qualified_user_id.startswith("@"):
qualified_user_id = UserID(qualified_user_id, self.hs.hostname).to_string()
login_type = login_submission.get("type")
known_login_type = False


@@ -35,6 +35,7 @@ class CasHandler:
"""
def __init__(self, hs):
self.hs = hs
self._hostname = hs.hostname
self._auth_handler = hs.get_auth_handler()
self._registration_handler = hs.get_registration_handler()
@@ -210,8 +211,16 @@ class CasHandler:
else:
if not registered_user_id:
# Pull out the user-agent and IP from the request.
user_agent = request.requestHeaders.getRawHeaders(
b"User-Agent", default=[b""]
)[0].decode("ascii", "surrogateescape")
ip_address = self.hs.get_ip_from_request(request)
registered_user_id = await self._registration_handler.register_user(
localpart=localpart, default_display_name=user_display_name
localpart=localpart,
default_display_name=user_display_name,
user_agent_ips=(user_agent, ip_address),
)
await self._auth_handler.complete_sso_login(


@@ -16,8 +16,6 @@
import logging
from typing import Any, Dict
from canonicaljson import json
from synapse.api.errors import SynapseError
from synapse.logging.context import run_in_background
from synapse.logging.opentracing import (
@@ -27,6 +25,7 @@ from synapse.logging.opentracing import (
start_active_span,
)
from synapse.types import UserID, get_domain_from_id
from synapse.util import json_encoder
from synapse.util.stringutils import random_string
logger = logging.getLogger(__name__)
@@ -174,7 +173,7 @@ class DeviceMessageHandler(object):
"sender": sender_user_id,
"type": message_type,
"message_id": message_id,
"org.matrix.opentracing_context": json.dumps(context),
"org.matrix.opentracing_context": json_encoder.encode(context),
}
log_kv({"local_messages": local_messages})


@@ -23,6 +23,7 @@ from synapse.api.errors import (
CodeMessageException,
Codes,
NotFoundError,
ShadowBanError,
StoreError,
SynapseError,
)
@@ -199,6 +200,8 @@ class DirectoryHandler(BaseHandler):
try:
await self._update_canonical_alias(requester, user_id, room_id, room_alias)
except ShadowBanError as e:
logger.info("Failed to update alias events due to shadow-ban: %s", e)
except AuthError as e:
logger.info("Failed to update alias events: %s", e)
@@ -292,6 +295,9 @@ class DirectoryHandler(BaseHandler):
"""
Send an updated canonical alias event if the removed alias was set as
the canonical alias or listed in the alt_aliases field.
Raises:
ShadowBanError if the requester has been shadow-banned.
"""
alias_event = await self.state.get_current_state(
room_id, EventTypes.CanonicalAlias, ""


@@ -19,7 +19,7 @@ import logging
from typing import Dict, List, Optional, Tuple
import attr
from canonicaljson import encode_canonical_json, json
from canonicaljson import encode_canonical_json
from signedjson.key import VerifyKey, decode_verify_key_bytes
from signedjson.sign import SignatureVerifyException, verify_signed_json
from unpaddedbase64 import decode_base64
@@ -35,7 +35,7 @@ from synapse.types import (
get_domain_from_id,
get_verify_key_from_cross_signing_key,
)
from synapse.util import unwrapFirstError
from synapse.util import json_decoder, unwrapFirstError
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.retryutils import NotRetryingDestination
@@ -404,7 +404,7 @@ class E2eKeysHandler(object):
for device_id, keys in device_keys.items():
for key_id, json_bytes in keys.items():
json_result.setdefault(user_id, {})[device_id] = {
key_id: json.loads(json_bytes)
key_id: json_decoder.decode(json_bytes)
}
@trace
@@ -1186,7 +1186,7 @@ def _exception_to_failure(e):
def _one_time_keys_match(old_key_json, new_key):
old_key = json.loads(old_key_json)
old_key = json_decoder.decode(old_key_json)
# if either is a string rather than an object, they must match exactly
if not isinstance(old_key, dict) or not isinstance(new_key, dict):


@@ -1777,9 +1777,7 @@ class FederationHandler(BaseHandler):
"""Returns the state at the event. i.e. not including said event.
"""
event = await self.store.get_event(
event_id, allow_none=False, check_room_id=room_id
)
event = await self.store.get_event(event_id, check_room_id=room_id)
state_groups = await self.state_store.get_state_groups(room_id, [event_id])
@@ -1805,9 +1803,7 @@ class FederationHandler(BaseHandler):
async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]:
"""Returns the state at the event. i.e. not including said event.
"""
event = await self.store.get_event(
event_id, allow_none=False, check_room_id=room_id
)
event = await self.store.get_event(event_id, check_room_id=room_id)
state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id])
@@ -2138,10 +2134,10 @@ class FederationHandler(BaseHandler):
)
state_sets = list(state_sets.values())
state_sets.append(state)
current_state_ids = await self.state_handler.resolve_events(
current_states = await self.state_handler.resolve_events(
room_version, state_sets, event
)
current_state_ids = {k: e.event_id for k, e in current_state_ids.items()}
current_state_ids = {k: e.event_id for k, e in current_states.items()}
else:
current_state_ids = await self.state_handler.get_current_state_ids(
event.room_id, latest_event_ids=extrem_ids
@@ -2153,11 +2149,13 @@ class FederationHandler(BaseHandler):
# Now check if event pass auth against said current state
auth_types = auth_types_for_event(event)
current_state_ids = [e for k, e in current_state_ids.items() if k in auth_types]
current_state_ids_list = [
e for k, e in current_state_ids.items() if k in auth_types
]
current_auth_events = await self.store.get_events(current_state_ids)
auth_events_map = await self.store.get_events(current_state_ids_list)
current_auth_events = {
(e.type, e.state_key): e for e in current_auth_events.values()
(e.type, e.state_key): e for e in auth_events_map.values()
}
try:
@@ -2173,9 +2171,7 @@ class FederationHandler(BaseHandler):
if not in_room:
raise AuthError(403, "Host not in room.")
event = await self.store.get_event(
event_id, allow_none=False, check_room_id=room_id
)
event = await self.store.get_event(event_id, check_room_id=room_id)
# Just go through and process each event in `remote_auth_chain`. We
# don't want to fall into the trap of `missing` being wrong.


@@ -21,8 +21,6 @@ import logging
import urllib.parse
from typing import Awaitable, Callable, Dict, List, Optional, Tuple
from canonicaljson import json
from twisted.internet.error import TimeoutError
from synapse.api.errors import (
@@ -34,6 +32,7 @@ from synapse.api.errors import (
from synapse.config.emailconfig import ThreepidBehaviour
from synapse.http.client import SimpleHttpClient
from synapse.types import JsonDict, Requester
from synapse.util import json_decoder
from synapse.util.hash import sha256_and_url_safe_base64
from synapse.util.stringutils import assert_valid_client_secret, random_string
@@ -177,7 +176,7 @@ class IdentityHandler(BaseHandler):
except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server")
except CodeMessageException as e:
data = json.loads(e.msg) # XXX WAT?
data = json_decoder.decode(e.msg) # XXX WAT?
return data
logger.info("Got 404 when POSTing JSON %s, falling back to v1 URL", bind_url)


@@ -15,9 +15,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from canonicaljson import encode_canonical_json, json
from canonicaljson import encode_canonical_json
from twisted.internet.interfaces import IDelayedCall
@@ -34,6 +35,7 @@ from synapse.api.errors import (
Codes,
ConsentNotGivenError,
NotFoundError,
ShadowBanError,
SynapseError,
)
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
@@ -55,6 +57,7 @@ from synapse.types import (
UserID,
create_requester,
)
from synapse.util import json_decoder
from synapse.util.async_helpers import Linearizer
from synapse.util.frozenutils import frozendict_json_encoder
from synapse.util.metrics import measure_func
@@ -644,24 +647,35 @@ class EventCreationHandler(object):
event: EventBase,
context: EventContext,
ratelimit: bool = True,
ignore_shadow_ban: bool = False,
) -> int:
"""
Persists and notifies local clients and federation of an event.
Args:
requester
event the event to send.
context: the context of the event.
requester: The requester sending the event.
event: The event to send.
context: The context of the event.
ratelimit: Whether to rate limit this send.
ignore_shadow_ban: True if shadow-banned users should be allowed to
send this event.
Return:
The stream_id of the persisted event.
Raises:
ShadowBanError if the requester has been shadow-banned.
"""
if event.type == EventTypes.Member:
raise SynapseError(
500, "Tried to send member event through non-member codepath"
)
if not ignore_shadow_ban and requester.shadow_banned:
# We randomly sleep a bit just to annoy the requester.
await self.clock.sleep(random.randint(1, 10))
raise ShadowBanError()
user = UserID.from_string(event.sender)
assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
@@ -715,12 +729,28 @@ class EventCreationHandler(object):
event_dict: dict,
ratelimit: bool = True,
txn_id: Optional[str] = None,
ignore_shadow_ban: bool = False,
) -> Tuple[EventBase, int]:
"""
Creates an event, then sends it.
See self.create_event and self.send_nonmember_event.
Args:
requester: The requester sending the event.
event_dict: An entire event.
ratelimit: Whether to rate limit this send.
txn_id: The transaction ID.
ignore_shadow_ban: True if shadow-banned users should be allowed to
send this event.
Raises:
ShadowBanError if the requester has been shadow-banned.
"""
if not ignore_shadow_ban and requester.shadow_banned:
# We randomly sleep a bit just to annoy the requester.
await self.clock.sleep(random.randint(1, 10))
raise ShadowBanError()
# We limit the number of concurrent event sends in a room so that we
# don't fork the DAG too much. If we don't limit then we can end up in
@@ -739,7 +769,11 @@ class EventCreationHandler(object):
raise SynapseError(403, spam_error, Codes.FORBIDDEN)
stream_id = await self.send_nonmember_event(
requester, event, context, ratelimit=ratelimit
requester,
event,
context,
ratelimit=ratelimit,
ignore_shadow_ban=ignore_shadow_ban,
)
return event, stream_id
@@ -864,7 +898,7 @@ class EventCreationHandler(object):
# Ensure that we can round trip before trying to persist in db
try:
dump = frozendict_json_encoder.encode(event.content)
json.loads(dump)
json_decoder.decode(dump)
except Exception:
logger.exception("Failed to encode content: %r", event.content)
raise
@@ -960,7 +994,7 @@ class EventCreationHandler(object):
allow_none=True,
)
is_admin_redaction = (
is_admin_redaction = bool(
original_event and event.sender != original_event.sender
)
@@ -1080,8 +1114,8 @@ class EventCreationHandler(object):
auth_events_ids = self.auth.compute_auth_events(
event, prev_state_ids, for_verification=True
)
auth_events = await self.store.get_events(auth_events_ids)
auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
auth_events_map = await self.store.get_events(auth_events_ids)
auth_events = {(e.type, e.state_key): e for e in auth_events_map.values()}
room_version = await self.store.get_room_version_id(event.room_id)
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
@@ -1179,8 +1213,14 @@ class EventCreationHandler(object):
event.internal_metadata.proactively_send = False
# Since this is a dummy-event it is OK if it is sent by a
# shadow-banned user.
await self.send_nonmember_event(
requester, event, context, ratelimit=False
requester,
event,
context,
ratelimit=False,
ignore_shadow_ban=True,
)
dummy_event_sent = True
break


@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, TypeVar
from urllib.parse import urlencode
@@ -39,6 +38,7 @@ from synapse.http.server import respond_with_html
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
from synapse.types import UserID, map_username_to_mxid_localpart
from synapse.util import json_decoder
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -93,6 +93,7 @@ class OidcHandler:
"""
def __init__(self, hs: "HomeServer"):
self.hs = hs
self._callback_url = hs.config.oidc_callback_url # type: str
self._scopes = hs.config.oidc_scopes # type: List[str]
self._client_auth = ClientAuth(
@@ -367,7 +368,7 @@ class OidcHandler:
# and check for an error field. If not, we respond with a generic
# error message.
try:
resp = json.loads(resp_body.decode("utf-8"))
resp = json_decoder.decode(resp_body.decode("utf-8"))
error = resp["error"]
description = resp.get("error_description", error)
except (ValueError, KeyError):
@@ -384,7 +385,7 @@ class OidcHandler:
# Since it is a not a 5xx code, body should be a valid JSON. It will
# raise if not.
resp = json.loads(resp_body.decode("utf-8"))
resp = json_decoder.decode(resp_body.decode("utf-8"))
if "error" in resp:
error = resp["error"]
@@ -689,9 +690,17 @@ class OidcHandler:
self._render_error(request, "invalid_token", str(e))
return
# Pull out the user-agent and IP from the request.
user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[
0
].decode("ascii", "surrogateescape")
ip_address = self.hs.get_ip_from_request(request)
# Call the mapper to register/login the user
try:
user_id = await self._map_userinfo_to_user(userinfo, token)
user_id = await self._map_userinfo_to_user(
userinfo, token, user_agent, ip_address
)
except MappingException as e:
logger.exception("Could not map user")
self._render_error(request, "mapping_error", str(e))
@@ -828,7 +837,9 @@ class OidcHandler:
now = self._clock.time_msec()
return now < expiry
async def _map_userinfo_to_user(self, userinfo: UserInfo, token: Token) -> str:
async def _map_userinfo_to_user(
self, userinfo: UserInfo, token: Token, user_agent: str, ip_address: str
) -> str:
"""Maps a UserInfo object to a mxid.
UserInfo should have a claim that uniquely identifies users. This claim
@@ -843,6 +854,8 @@ class OidcHandler:
Args:
userinfo: an object representing the user
token: a dict with the tokens obtained from the provider
user_agent: The user agent of the client making the request.
ip_address: The IP address of the client making the request.
Raises:
MappingException: if there was an error while mapping some properties
@@ -899,7 +912,9 @@ class OidcHandler:
# It's the first time this user is logging in and the mapped mxid was
# not taken, register the user
registered_user_id = await self._registration_handler.register_user(
localpart=localpart, default_display_name=attributes["display_name"],
localpart=localpart,
default_display_name=attributes["display_name"],
user_agent_ips=(user_agent, ip_address),
)
await self._datastore.record_user_external_id(

View File

@@ -82,6 +82,9 @@ class PaginationHandler(object):
self._retention_default_max_lifetime = hs.config.retention_default_max_lifetime
self._retention_allowed_lifetime_min = hs.config.retention_allowed_lifetime_min
self._retention_allowed_lifetime_max = hs.config.retention_allowed_lifetime_max
if hs.config.retention_enabled:
# Run the purge jobs described in the configuration file.
for job in hs.config.retention_purge_jobs:
@@ -111,7 +114,7 @@ class PaginationHandler(object):
the range to handle (inclusive). If None, it means that the range has no
upper limit.
"""
# We want the storage layer to to include rooms with no retention policy in its
# We want the storage layer to include rooms with no retention policy in its
# return value only if a default retention policy is defined in the server's
# configuration and that policy's 'max_lifetime' is either lower (or equal) than
# max_ms or higher than min_ms (or both).
@@ -152,13 +155,32 @@ class PaginationHandler(object):
)
continue
max_lifetime = retention_policy["max_lifetime"]
# If max_lifetime is None, it means that the room has no retention policy.
# Given we only retrieve such rooms when there's a default retention policy
# defined in the server's configuration, we can safely assume that's the
# case and use it for this room.
max_lifetime = (
retention_policy["max_lifetime"] or self._retention_default_max_lifetime
)
if max_lifetime is None:
# If max_lifetime is None, it means that include_null equals True,
# therefore we can safely assume that there is a default policy defined
# in the server's configuration.
max_lifetime = self._retention_default_max_lifetime
# Cap the effective max_lifetime to be within the range allowed in the
# config.
# We do this in two steps:
# 1. Make sure it's higher than or equal to the minimum allowed value, and
# if it's not, replace it with that value. This is because the server
# operator can be required not to delete information before a given
# time, e.g. to comply with freedom of information laws.
# 2. Make sure the resulting value is lower than or equal to the maximum
# allowed value, and if it's not, replace it with that value. This is
# because the server operator can be required to delete any data after a
# specific amount of time.
if self._retention_allowed_lifetime_min is not None:
max_lifetime = max(self._retention_allowed_lifetime_min, max_lifetime)
if self._retention_allowed_lifetime_max is not None:
max_lifetime = min(max_lifetime, self._retention_allowed_lifetime_max)
logger.debug("[purge] max_lifetime for room %s: %s", room_id, max_lifetime)
# Figure out what token we should start purging at.
ts = self.clock.time_msec() - max_lifetime
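The two-step clamp is an order-sensitive pair of `max`/`min` calls: lift the value to the configured floor first, then cut it to the ceiling, so the ceiling wins if an operator misconfigures the two. A standalone sketch with hypothetical config values:

    from typing import Optional

    def cap_lifetime(
        max_lifetime: int,
        allowed_min: Optional[int],
        allowed_max: Optional[int],
    ) -> int:
        """Clamp a room's retention lifetime to the server's allowed range."""
        if allowed_min is not None:
            max_lifetime = max(allowed_min, max_lifetime)
        if allowed_max is not None:
            max_lifetime = min(max_lifetime, allowed_max)
        return max_lifetime

    DAY_MS, YEAR_MS = 86_400_000, 31_536_000_000  # hypothetical config values
    assert cap_lifetime(1_000, DAY_MS, YEAR_MS) == DAY_MS
    assert cap_lifetime(10 * YEAR_MS, DAY_MS, YEAR_MS) == YEAR_MS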

View File

@@ -40,7 +40,7 @@ from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.state import StateHandler
from synapse.storage.databases.main import DataStore
from synapse.types import JsonDict, UserID, get_domain_from_id
from synapse.types import Collection, JsonDict, UserID, get_domain_from_id
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.descriptors import cached
from synapse.util.metrics import Measure
@@ -1318,7 +1318,7 @@ async def get_interested_parties(
async def get_interested_remotes(
store: DataStore, states: List[UserPresenceState], state_handler: StateHandler
) -> List[Tuple[List[str], List[UserPresenceState]]]:
) -> List[Tuple[Collection[str], List[UserPresenceState]]]:
"""Given a list of presence states figure out which remote servers
should be sent which.
@@ -1334,7 +1334,7 @@ async def get_interested_remotes(
each tuple the list of UserPresenceState should be sent to each
destination
"""
hosts_and_states = []
hosts_and_states = [] # type: List[Tuple[Collection[str], List[UserPresenceState]]]
# First we look up the rooms each user is in (as well as any explicit
# subscriptions), then for each distinct room we look up the remote

View File

@@ -26,6 +26,7 @@ from synapse.replication.http.register import (
ReplicationPostRegisterActionsServlet,
ReplicationRegisterServlet,
)
from synapse.spam_checker_api import RegistrationBehaviour
from synapse.storage.state import StateFilter
from synapse.types import RoomAlias, UserID, create_requester
@@ -52,6 +53,8 @@ class RegistrationHandler(BaseHandler):
self.macaroon_gen = hs.get_macaroon_generator()
self._server_notices_mxid = hs.config.server_notices_mxid
self.spam_checker = hs.get_spam_checker()
if hs.config.worker_app:
self._register_client = ReplicationRegisterServlet.make_client(hs)
self._register_device_client = RegisterDeviceReplicationServlet.make_client(
@@ -124,7 +127,9 @@ class RegistrationHandler(BaseHandler):
try:
int(localpart)
raise SynapseError(
400, "Numeric user IDs are reserved for guest users."
400,
"Numeric user IDs are reserved for guest users.",
errcode=Codes.INVALID_USERNAME,
)
except ValueError:
pass
@@ -142,7 +147,7 @@ class RegistrationHandler(BaseHandler):
address=None,
bind_emails=[],
by_admin=False,
shadow_banned=False,
user_agent_ips=None,
):
"""Registers a new client on the server.
@@ -160,7 +165,8 @@ class RegistrationHandler(BaseHandler):
bind_emails (List[str]): list of emails to bind to this account.
by_admin (bool): True if this registration is being made via the
admin api, otherwise False.
shadow_banned (bool): Shadow-ban the created user.
user_agent_ips (List[(str, str)]): Tuples of user-agents and IP addresses used
during the registration process.
Returns:
str: user_id
Raises:
@@ -168,6 +174,24 @@ class RegistrationHandler(BaseHandler):
"""
self.check_registration_ratelimit(address)
result = self.spam_checker.check_registration_for_spam(
threepid, localpart, user_agent_ips or [],
)
if result == RegistrationBehaviour.DENY:
logger.info(
"Blocked registration of %r", localpart,
)
# We return a 429 to make it not obvious that they've been
# denied.
raise SynapseError(429, "Rate limited")
shadow_banned = result == RegistrationBehaviour.SHADOW_BAN
if shadow_banned:
logger.info(
"Shadow banning registration of %r", localpart,
)
# do not check_auth_blocking if the call is coming through the Admin API
if not by_admin:
await self.auth.check_auth_blocking(threepid=threepid)
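The hook consulted above lets a spam-checker module decide, per registration, between allowing, denying, and shadow-banning. A sketch of what such a module method might look like, reusing the `RegistrationBehaviour` enum added later in this diff; the blocklist and heuristics are entirely hypothetical:

    from enum import Enum
    from typing import List, Optional, Tuple

    class RegistrationBehaviour(Enum):
        ALLOW = "allow"
        SHADOW_BAN = "shadow_ban"
        DENY = "deny"

    BANNED_AGENT_PREFIXES = ("curl/",)  # hypothetical blocklist

    def check_registration_for_spam(
        email_threepid: Optional[dict],
        username: Optional[str],
        request_info: List[Tuple[str, str]],  # (user_agent, ip_address) pairs
    ) -> RegistrationBehaviour:
        """Deny obvious scripted sign-ups; shadow-ban merely suspicious ones."""
        for user_agent, _ip in request_info:
            if user_agent.startswith(BANNED_AGENT_PREFIXES):
                return RegistrationBehaviour.DENY
        if username and username.startswith("spam"):
            return RegistrationBehaviour.SHADOW_BAN
        return RegistrationBehaviour.ALLOW

Note that the handler deliberately maps DENY to a generic 429, so a blocked client cannot distinguish denial from ordinary rate limiting.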

View File

@@ -20,6 +20,7 @@
import itertools
import logging
import math
import random
import string
from collections import OrderedDict
from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple
@@ -135,6 +136,9 @@ class RoomCreationHandler(BaseHandler):
Returns:
the new room id
Raises:
ShadowBanError if the requester is shadow-banned.
"""
await self.ratelimit(requester)
@@ -170,6 +174,15 @@ class RoomCreationHandler(BaseHandler):
async def _upgrade_room(
self, requester: Requester, old_room_id: str, new_version: RoomVersion
):
"""
Args:
requester: the user requesting the upgrade
old_room_id: the id of the room to be replaced
new_version: the version to upgrade the room to
Raises:
ShadowBanError if the requester is shadow-banned.
"""
user_id = requester.user.to_string()
# start by allocating a new room id
@@ -256,6 +269,9 @@ class RoomCreationHandler(BaseHandler):
old_room_id: the id of the room to be replaced
new_room_id: the id of the replacement room
old_room_state: the state map for the old room
Raises:
ShadowBanError if the requester is shadow-banned.
"""
old_room_pl_event_id = old_room_state.get((EventTypes.PowerLevels, ""))
@@ -626,6 +642,7 @@ class RoomCreationHandler(BaseHandler):
if mapping:
raise SynapseError(400, "Room alias already taken", Codes.ROOM_IN_USE)
invite_3pid_list = config.get("invite_3pid", [])
invite_list = config.get("invite", [])
for i in invite_list:
try:
@@ -634,6 +651,14 @@ class RoomCreationHandler(BaseHandler):
except Exception:
raise SynapseError(400, "Invalid user_id: %s" % (i,))
if (invite_list or invite_3pid_list) and requester.shadow_banned:
# We randomly sleep a bit just to annoy the requester.
await self.clock.sleep(random.randint(1, 10))
# Allow the request to go through, but remove any associated invites.
invite_3pid_list = []
invite_list = []
await self.event_creation_handler.assert_accepted_privacy_policy(requester)
power_level_content_override = config.get("power_level_content_override")
@@ -648,8 +673,6 @@ class RoomCreationHandler(BaseHandler):
% (user_id,),
)
invite_3pid_list = config.get("invite_3pid", [])
visibility = config.get("visibility", None)
is_public = visibility == "public"
@@ -744,6 +767,8 @@ class RoomCreationHandler(BaseHandler):
if is_direct:
content["is_direct"] = is_direct
# Note that update_membership with an action of "invite" can raise a
# ShadowBanError, but this was handled above by emptying invite_list.
_, last_stream_id = await self.room_member_handler.update_membership(
requester,
UserID.from_string(invitee),
@@ -758,6 +783,8 @@ class RoomCreationHandler(BaseHandler):
id_access_token = invite_3pid.get("id_access_token") # optional
address = invite_3pid["address"]
medium = invite_3pid["medium"]
# Note that do_3pid_invite can raise a ShadowBanError, but this was
# handled above by emptying invite_3pid_list.
last_stream_id = await self.hs.get_room_member_handler().do_3pid_invite(
room_id,
requester.user,
@@ -817,11 +844,13 @@ class RoomCreationHandler(BaseHandler):
async def send(etype: str, content: JsonDict, **kwargs) -> int:
event = create(etype, content, **kwargs)
logger.debug("Sending %s in new room", etype)
# Allow these events to be sent even if the user is shadow-banned to
# allow the room creation to complete.
(
_,
last_stream_id,
) = await self.event_creation_handler.create_and_send_nonmember_event(
creator, event, ratelimit=False
creator, event, ratelimit=False, ignore_shadow_ban=True,
)
return last_stream_id
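Room creation handles shadow-banned users softly: the request succeeds, but only after a random delay and with every invite silently dropped. The pattern in miniature (a sketch; `asyncio.sleep` stands in for Synapse's `clock.sleep`):

    import asyncio
    import random
    from typing import List, Tuple

    async def filter_invites_for_shadow_ban(
        shadow_banned: bool,
        invite_list: List[str],
        invite_3pid_list: List[dict],
    ) -> Tuple[List[str], List[dict]]:
        if (invite_list or invite_3pid_list) and shadow_banned:
            # Random 1-10s delay, "just to annoy the requester" per the
            # comment in the diff.
            await asyncio.sleep(random.randint(1, 10))
            # Let the room be created, but drop all invites on the floor.
            return [], []
        return invite_list, invite_3pid_list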

View File

@@ -15,14 +15,21 @@
import abc
import logging
import random
from http import HTTPStatus
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union
from unpaddedbase64 import encode_base64
from synapse import types
from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError
from synapse.api.errors import (
AuthError,
Codes,
LimitExceededError,
ShadowBanError,
SynapseError,
)
from synapse.api.ratelimiting import Ratelimiter
from synapse.api.room_versions import EventFormatVersions
from synapse.crypto.event_signing import compute_event_reference_hash
@@ -31,7 +38,15 @@ from synapse.events.builder import create_local_event_from_event_dict
from synapse.events.snapshot import EventContext
from synapse.events.validator import EventValidator
from synapse.storage.roommember import RoomsForUser
from synapse.types import Collection, JsonDict, Requester, RoomAlias, RoomID, UserID
from synapse.types import (
Collection,
JsonDict,
Requester,
RoomAlias,
RoomID,
StateMap,
UserID,
)
from synapse.util.async_helpers import Linearizer
from synapse.util.distributor import user_joined_room, user_left_room
@@ -210,24 +225,40 @@ class RoomMemberHandler(object):
_, stream_id = await self.store.get_event_ordering(duplicate.event_id)
return duplicate.event_id, stream_id
stream_id = await self.event_creation_handler.handle_new_client_event(
requester, event, context, extra_users=[target], ratelimit=ratelimit,
)
prev_state_ids = await context.get_prev_state_ids()
prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
newly_joined = False
if event.membership == Membership.JOIN:
# Only fire user_joined_room if the user has actually joined the
# room. Don't bother if the user is just changing their profile
# info.
newly_joined = True
if prev_member_event_id:
prev_member_event = await self.store.get_event(prev_member_event_id)
newly_joined = prev_member_event.membership != Membership.JOIN
# Only rate-limit if the user actually joined the room, otherwise we'll end
# up blocking profile updates.
if newly_joined:
await self._user_joined_room(target, room_id)
time_now_s = self.clock.time()
(
allowed,
time_allowed,
) = self._join_rate_limiter_local.can_requester_do_action(requester)
if not allowed:
raise LimitExceededError(
retry_after_ms=int(1000 * (time_allowed - time_now_s))
)
stream_id = await self.event_creation_handler.handle_new_client_event(
requester, event, context, extra_users=[target], ratelimit=ratelimit,
)
if event.membership == Membership.JOIN and newly_joined:
# Only fire user_joined_room if the user has actually joined the
# room. Don't bother if the user is just changing their profile
# info.
await self._user_joined_room(target, room_id)
elif event.membership == Membership.LEAVE:
if prev_member_event_id:
prev_member_event = await self.store.get_event(prev_member_event_id)
@@ -285,6 +316,31 @@ class RoomMemberHandler(object):
content: Optional[dict] = None,
require_consent: bool = True,
) -> Tuple[str, int]:
"""Update a user's membership in a room.
Args:
requester: The user who is performing the update.
target: The user whose membership is being updated.
room_id: The room ID whose membership is being updated.
action: The membership change, see synapse.api.constants.Membership.
txn_id: The transaction ID, if given.
remote_room_hosts: Remote servers to send the update to.
third_party_signed: Information from a 3PID invite.
ratelimit: Whether to rate limit the request.
content: The content of the created event.
require_consent: Whether consent is required.
Returns:
A tuple of the new event ID and stream ID.
Raises:
ShadowBanError if a shadow-banned requester attempts to send an invite.
"""
if action == Membership.INVITE and requester.shadow_banned:
# We randomly sleep a bit just to annoy the requester.
await self.clock.sleep(random.randint(1, 10))
raise ShadowBanError()
key = (room_id,)
with (await self.member_linearizer.queue(key)):
@@ -457,22 +513,12 @@ class RoomMemberHandler(object):
# so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed")
if is_host_in_room:
if not is_host_in_room:
time_now_s = self.clock.time()
allowed, time_allowed = self._join_rate_limiter_local.can_do_action(
requester.user.to_string(),
)
if not allowed:
raise LimitExceededError(
retry_after_ms=int(1000 * (time_allowed - time_now_s))
)
else:
time_now_s = self.clock.time()
allowed, time_allowed = self._join_rate_limiter_remote.can_do_action(
requester.user.to_string(),
)
(
allowed,
time_allowed,
) = self._join_rate_limiter_remote.can_requester_do_action(requester,)
if not allowed:
raise LimitExceededError(
@@ -704,9 +750,7 @@ class RoomMemberHandler(object):
if prev_member_event.membership == Membership.JOIN:
await self._user_left_room(target_user, room_id)
async def _can_guest_join(
self, current_state_ids: Dict[Tuple[str, str], str]
) -> bool:
async def _can_guest_join(self, current_state_ids: StateMap[str]) -> bool:
"""
Returns whether a guest can join a room based on its current state.
"""
@@ -716,7 +760,7 @@ class RoomMemberHandler(object):
guest_access = await self.store.get_event(guest_access_id)
return (
return bool(
guest_access
and guest_access.content
and "guest_access" in guest_access.content
@@ -773,6 +817,25 @@ class RoomMemberHandler(object):
txn_id: Optional[str],
id_access_token: Optional[str] = None,
) -> int:
"""Invite a 3PID to a room.
Args:
room_id: The room to invite the 3PID to.
inviter: The user sending the invite.
medium: The 3PID's medium.
address: The 3PID's address.
id_server: The identity server to use.
requester: The user making the request.
txn_id: The transaction ID this is part of, or None if this is not
part of a transaction.
id_access_token: The optional identity server access token.
Returns:
The new stream ID.
Raises:
ShadowBanError if the requester has been shadow-banned.
"""
if self.config.block_non_admin_invites:
is_requester_admin = await self.auth.is_server_admin(requester.user)
if not is_requester_admin:
@@ -780,6 +843,11 @@ class RoomMemberHandler(object):
403, "Invites have been disabled on this server", Codes.FORBIDDEN
)
if requester.shadow_banned:
# We randomly sleep a bit just to annoy the requester.
await self.clock.sleep(random.randint(1, 10))
raise ShadowBanError()
# We need to rate limit *before* we send out any 3PID invites, so we
# can't just rely on the standard ratelimiting of events.
await self.base_handler.ratelimit(requester)
@@ -804,6 +872,8 @@ class RoomMemberHandler(object):
)
if invitee:
# Note that update_membership with an action of "invite" can raise
# a ShadowBanError, but this was done above already.
_, stream_id = await self.update_membership(
requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id
)
@@ -909,9 +979,7 @@ class RoomMemberHandler(object):
)
return stream_id
async def _is_host_in_room(
self, current_state_ids: Dict[Tuple[str, str], str]
) -> bool:
async def _is_host_in_room(self, current_state_ids: StateMap[str]) -> bool:
# Have we just created the room, and is this about to be the very
# first member event?
create_event_id = current_state_ids.get(("m.room.create", ""))
@@ -1042,7 +1110,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
return event_id, stream_id
# The room is too large. Leave.
requester = types.create_requester(user, None, False, None)
requester = types.create_requester(user, None, False, False, None)
await self.update_membership(
requester=requester, target=user, room_id=room_id, action="leave"
)
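Membership updates take the opposite approach from room creation: an invite from a shadow-banned requester raises `ShadowBanError` (again after a random delay), and the REST layer is expected to catch it and fake success. The raising side, condensed to a sketch with `asyncio.sleep` standing in for `clock.sleep`:

    import asyncio
    import random

    class ShadowBanError(Exception):
        """Raised when a shadow-banned user attempts a disallowed action."""

    INVITE = "invite"

    async def check_shadow_ban(action: str, shadow_banned: bool) -> None:
        if action == INVITE and shadow_banned:
            await asyncio.sleep(random.randint(1, 10))
            raise ShadowBanError()

The catching side appears further down, in the room REST servlets, which swallow the error and return a plausible-looking random event ID.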

View File

@@ -54,6 +54,7 @@ class Saml2SessionData:
class SamlHandler:
def __init__(self, hs: "synapse.server.HomeServer"):
self.hs = hs
self._saml_client = Saml2Client(hs.config.saml2_sp_config)
self._auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler()
@@ -133,8 +134,14 @@ class SamlHandler:
# the dict.
self.expire_sessions()
# Pull out the user-agent and IP from the request.
user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[
0
].decode("ascii", "surrogateescape")
ip_address = self.hs.get_ip_from_request(request)
user_id, current_session = await self._map_saml_response_to_user(
resp_bytes, relay_state
resp_bytes, relay_state, user_agent, ip_address
)
# Complete the interactive auth session or the login.
@@ -147,7 +154,11 @@ class SamlHandler:
await self._auth_handler.complete_sso_login(user_id, request, relay_state)
async def _map_saml_response_to_user(
self, resp_bytes: str, client_redirect_url: str
self,
resp_bytes: str,
client_redirect_url: str,
user_agent: str,
ip_address: str,
) -> Tuple[str, Optional[Saml2SessionData]]:
"""
Given a SAML response, retrieve the cached session and user for it.
@@ -155,6 +166,8 @@ class SamlHandler:
Args:
resp_bytes: The SAML response.
client_redirect_url: The redirect URL passed in by the client.
user_agent: The user agent of the client making the request.
ip_address: The IP address of the client making the request.
Returns:
Tuple of the user ID and SAML session associated with this response.
@@ -291,6 +304,7 @@ class SamlHandler:
localpart=localpart,
default_display_name=displayname,
bind_emails=emails,
user_agent_ips=(user_agent, ip_address),
)
await self._datastore.record_user_external_id(

View File

@@ -16,13 +16,12 @@
import logging
from typing import Any
from canonicaljson import json
from twisted.web.client import PartialDownloadError
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.config.emailconfig import ThreepidBehaviour
from synapse.util import json_decoder
logger = logging.getLogger(__name__)
@@ -117,7 +116,7 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker):
except PartialDownloadError as pde:
# Twisted is silly
data = pde.response
resp_body = json.loads(data.decode("utf-8"))
resp_body = json_decoder.decode(data.decode("utf-8"))
if "success" in resp_body:
# Note that we do NOT check the hostname here: we explicitly

View File

@@ -19,7 +19,7 @@ import urllib
from io import BytesIO
import treq
from canonicaljson import encode_canonical_json, json
from canonicaljson import encode_canonical_json
from netaddr import IPAddress
from prometheus_client import Counter
from zope.interface import implementer, provider
@@ -47,6 +47,7 @@ from synapse.http import (
from synapse.http.proxyagent import ProxyAgent
from synapse.logging.context import make_deferred_yieldable
from synapse.logging.opentracing import set_tag, start_active_span, tags
from synapse.util import json_decoder
from synapse.util.async_helpers import timeout_deferred
logger = logging.getLogger(__name__)
@@ -391,7 +392,7 @@ class SimpleHttpClient(object):
body = await make_deferred_yieldable(readBody(response))
if 200 <= response.code < 300:
return json.loads(body.decode("utf-8"))
return json_decoder.decode(body.decode("utf-8"))
else:
raise HttpResponseException(
response.code, response.phrase.decode("ascii", errors="replace"), body
@@ -433,7 +434,7 @@ class SimpleHttpClient(object):
body = await make_deferred_yieldable(readBody(response))
if 200 <= response.code < 300:
return json.loads(body.decode("utf-8"))
return json_decoder.decode(body.decode("utf-8"))
else:
raise HttpResponseException(
response.code, response.phrase.decode("ascii", errors="replace"), body
@@ -463,7 +464,7 @@ class SimpleHttpClient(object):
actual_headers.update(headers)
body = await self.get_raw(uri, args, headers=headers)
return json.loads(body.decode("utf-8"))
return json_decoder.decode(body.decode("utf-8"))
async def put_json(self, uri, json_body, args={}, headers=None):
""" Puts some json to the given URI.
@@ -506,7 +507,7 @@ class SimpleHttpClient(object):
body = await make_deferred_yieldable(readBody(response))
if 200 <= response.code < 300:
return json.loads(body.decode("utf-8"))
return json_decoder.decode(body.decode("utf-8"))
else:
raise HttpResponseException(
response.code, response.phrase.decode("ascii", errors="replace"), body
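Every one of these call sites swaps an ad-hoc `json.loads` for a shared module-level decoder. Presumably `synapse.util` now exposes single `json_encoder`/`json_decoder` instances so options (compact separators, frozendict support) are configured once and reused; a minimal sketch of that pattern, under that assumption:

    import json

    # Configured once; call sites import these rather than passing the same
    # keyword arguments to json.dumps/json.loads everywhere.
    json_encoder = json.JSONEncoder(separators=(",", ":"))
    json_decoder = json.JSONDecoder()

    def parse_body(body: bytes) -> dict:
        return json_decoder.decode(body.decode("utf-8"))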

View File

@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import random
import time
@@ -26,7 +25,7 @@ from twisted.web.http import stringToDatetime
from twisted.web.http_headers import Headers
from synapse.logging.context import make_deferred_yieldable
from synapse.util import Clock
from synapse.util import Clock, json_decoder
from synapse.util.caches.ttlcache import TTLCache
from synapse.util.metrics import Measure
@@ -181,7 +180,7 @@ class WellKnownResolver(object):
if response.code != 200:
raise Exception("Non-200 response %s" % (response.code,))
parsed_body = json.loads(body.decode("utf-8"))
parsed_body = json_decoder.decode(body.decode("utf-8"))
logger.info("Response from .well-known: %s", parsed_body)
result = parsed_body["m.server"].encode("ascii")

View File

@@ -500,7 +500,7 @@ class RootOptionsRedirectResource(OptionsResource, RootRedirect):
pass
@implementer(interfaces.IPullProducer)
@implementer(interfaces.IPushProducer)
class _ByteProducer:
"""
Iteratively write bytes to the request.
@@ -515,52 +515,64 @@ class _ByteProducer:
):
self._request = request
self._iterator = iterator
self._paused = False
def start(self) -> None:
self._request.registerProducer(self, False)
# Register the producer and start producing data.
self._request.registerProducer(self, True)
self.resumeProducing()
def _send_data(self, data: List[bytes]) -> None:
"""
Send a list of strings as a response to the request.
Send a list of bytes as a chunk of a response.
"""
if not data:
return
self._request.write(b"".join(data))
def pauseProducing(self) -> None:
self._paused = True
def resumeProducing(self) -> None:
# We've stopped producing in the meantime (note that this might be
# re-entrant after calling write).
if not self._request:
return
# Get the next chunk and write it to the request.
#
# The output of the JSON encoder is coalesced until min_chunk_size is
# reached. (This is because JSON encoders produce a very small output
# per iteration.)
#
# Note that buffer stores a list of bytes (instead of appending to
# bytes) to hopefully avoid many allocations.
buffer = []
buffered_bytes = 0
while buffered_bytes < self.min_chunk_size:
try:
data = next(self._iterator)
buffer.append(data)
buffered_bytes += len(data)
except StopIteration:
# The entire JSON object has been serialized, write any
# remaining data, finalize the producer and the request, and
# clean-up any references.
self._send_data(buffer)
self._request.unregisterProducer()
self._request.finish()
self.stopProducing()
return
self._paused = False
self._send_data(buffer)
# Write until there's backpressure telling us to stop.
while not self._paused:
# Get the next chunk and write it to the request.
#
# The output of the JSON encoder is buffered and coalesced until
# min_chunk_size is reached. This is because JSON encoders produce
# very small output per iteration and the Request object converts
# each call to write() to a separate chunk. Without this there would
# be an explosion in bytes written (e.g. b"{" becoming "1\r\n{\r\n").
#
# Note that buffer stores a list of bytes (instead of appending to
# bytes) to hopefully avoid many allocations.
buffer = []
buffered_bytes = 0
while buffered_bytes < self.min_chunk_size:
try:
data = next(self._iterator)
buffer.append(data)
buffered_bytes += len(data)
except StopIteration:
# The entire JSON object has been serialized, write any
# remaining data, finalize the producer and the request, and
# clean-up any references.
self._send_data(buffer)
self._request.unregisterProducer()
self._request.finish()
self.stopProducing()
return
self._send_data(buffer)
def stopProducing(self) -> None:
# Clear a circular reference.
self._request = None
@@ -620,8 +632,7 @@ def respond_with_json(
if send_cors:
set_cors_headers(request)
producer = _ByteProducer(request, encoder(json_object))
producer.start()
_ByteProducer(request, encoder(json_object))
return NOT_DONE_YET
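The switch from `IPullProducer` to `IPushProducer` inverts control: rather than the transport requesting one chunk at a time, the producer keeps writing until `pauseProducing` signals backpressure, coalescing the JSON encoder's many tiny outputs into `min_chunk_size` chunks as it goes. The core loop, reduced to a Twisted-free sketch (`write` stands in for `request.write`):

    from typing import Callable, Iterator, List

    class ChunkingProducer:
        """Push-style producer: write coalesced chunks until paused."""

        min_chunk_size = 4096

        def __init__(
            self, iterator: Iterator[bytes], write: Callable[[bytes], None]
        ) -> None:
            self._iterator = iterator
            self._write = write
            self._paused = False

        def pauseProducing(self) -> None:
            # Backpressure from the transport: stop after the current chunk.
            self._paused = True

        def resumeProducing(self) -> None:
            self._paused = False
            while not self._paused:
                buffer: List[bytes] = []
                buffered_bytes = 0
                # Coalesce tiny encoder outputs into one reasonably sized write.
                while buffered_bytes < self.min_chunk_size:
                    try:
                        data = next(self._iterator)
                    except StopIteration:
                        # Everything is serialized: flush what's left and stop.
                        self._write(b"".join(buffer))
                        return
                    buffer.append(data)
                    buffered_bytes += len(data)
                self._write(b"".join(buffer))

Registering with `registerProducer(producer, True)` (the `True` marks it as a push producer) and then calling `resumeProducing()` once, which is what the new `_ByteProducer` start-up does, kicks off the loop.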

View File

@@ -17,9 +17,8 @@
import logging
from canonicaljson import json
from synapse.api.errors import Codes, SynapseError
from synapse.util import json_decoder
logger = logging.getLogger(__name__)
@@ -215,7 +214,7 @@ def parse_json_value_from_request(request, allow_empty_body=False):
return None
try:
content = json.loads(content_bytes.decode("utf-8"))
content = json_decoder.decode(content_bytes.decode("utf-8"))
except Exception as e:
logger.warning("Unable to parse JSON: %s", e)
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)

View File

@@ -172,11 +172,11 @@ from functools import wraps
from typing import TYPE_CHECKING, Dict, Optional, Type
import attr
from canonicaljson import json
from twisted.internet import defer
from synapse.config import ConfigError
from synapse.util import json_decoder, json_encoder
if TYPE_CHECKING:
from synapse.http.site import SynapseRequest
@@ -499,7 +499,9 @@ def start_active_span_from_edu(
if opentracing is None:
return _noop_context_manager()
carrier = json.loads(edu_content.get("context", "{}")).get("opentracing", {})
carrier = json_decoder.decode(edu_content.get("context", "{}")).get(
"opentracing", {}
)
context = opentracing.tracer.extract(opentracing.Format.TEXT_MAP, carrier)
_references = [
opentracing.child_of(span_context_from_string(x))
@@ -690,7 +692,7 @@ def active_span_context_as_string():
opentracing.tracer.inject(
opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier
)
return json.dumps(carrier)
return json_encoder.encode(carrier)
@only_if_tracing
@@ -699,7 +701,7 @@ def span_context_from_string(carrier):
Returns:
The active span context decoded from a string.
"""
carrier = json.loads(carrier)
carrier = json_decoder.decode(carrier)
return opentracing.tracer.extract(opentracing.Format.TEXT_MAP, carrier)

View File

@@ -175,7 +175,7 @@ def run_as_background_process(desc: str, func, *args, **kwargs):
It returns a Deferred which completes when the function completes, but it doesn't
follow the synapse logcontext rules, which makes it appropriate for passing to
clock.looping_call and friends (or for firing-and-forgetting in the middle of a
normal synapse inlineCallbacks function).
normal synapse async function).
Args:
desc: a description for this background process type

View File

@@ -167,8 +167,10 @@ class ModuleApi(object):
external_id: id on that system
user_id: complete mxid that it is mapped to
"""
return self._store.record_user_external_id(
auth_provider_id, remote_user_id, registered_user_id
return defer.ensureDeferred(
self._store.record_user_external_id(
auth_provider_id, remote_user_id, registered_user_id
)
)
def generate_short_term_login_token(
@@ -223,7 +225,9 @@ class ModuleApi(object):
Returns:
Deferred[object]: result of func
"""
return self._store.db_pool.runInteraction(desc, func, *args, **kwargs)
return defer.ensureDeferred(
self._store.db_pool.runInteraction(desc, func, *args, **kwargs)
)
def complete_sso_login(
self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str

View File

@@ -33,3 +33,11 @@ class SlavedIdTracker(object):
int
"""
return self._current
def get_current_token_for_writer(self, instance_name: str) -> int:
"""Returns the position of the given writer.
For streams with single writers this is equivalent to
`get_current_token`.
"""
return self.get_current_token()

View File

@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import PushRulesStream
from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
@@ -21,16 +22,13 @@ from .events import SlavedEventStore
class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
def get_push_rules_stream_token(self):
return (
self._push_rules_stream_id_gen.get_current_token(),
self._stream_id_gen.get_current_token(),
)
def get_max_push_rules_stream_id(self):
return self._push_rules_stream_id_gen.get_current_token()
def process_replication_rows(self, stream_name, instance_name, token, rows):
# We assert this for the benefit of mypy
assert isinstance(self._push_rules_stream_id_gen, SlavedIdTracker)
if stream_name == PushRulesStream.NAME:
self._push_rules_stream_id_gen.advance(token)
for row in rows:

View File

@@ -21,9 +21,7 @@ import abc
import logging
from typing import Tuple, Type
from canonicaljson import json
from synapse.util import json_encoder as _json_encoder
from synapse.util import json_decoder, json_encoder
logger = logging.getLogger(__name__)
@@ -125,7 +123,7 @@ class RdataCommand(Command):
stream_name,
instance_name,
None if token == "batch" else int(token),
json.loads(row_json),
json_decoder.decode(row_json),
)
def to_line(self):
@@ -134,7 +132,7 @@ class RdataCommand(Command):
self.stream_name,
self.instance_name,
str(self.token) if self.token is not None else "batch",
_json_encoder.encode(self.row),
json_encoder.encode(self.row),
)
)
@@ -359,7 +357,7 @@ class UserIpCommand(Command):
def from_line(cls, line):
user_id, jsn = line.split(" ", 1)
access_token, ip, user_agent, device_id, last_seen = json.loads(jsn)
access_token, ip, user_agent, device_id, last_seen = json_decoder.decode(jsn)
return cls(user_id, access_token, ip, user_agent, device_id, last_seen)
@@ -367,7 +365,7 @@ class UserIpCommand(Command):
return (
self.user_id
+ " "
+ _json_encoder.encode(
+ json_encoder.encode(
(
self.access_token,
self.ip,

View File

@@ -352,7 +352,7 @@ class PushRulesStream(Stream):
)
def _current_token(self, instance_name: str) -> int:
push_rules_token, _ = self.store.get_push_rules_stream_token()
push_rules_token = self.store.get_max_push_rules_stream_id()
return push_rules_token
@@ -405,7 +405,7 @@ class CachesStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_cache_stream_token,
store.get_cache_stream_token_for_writer,
store.get_all_updated_caches,
)

View File

@@ -316,6 +316,9 @@ class JoinRoomAliasServlet(RestServlet):
join_rules_event = room_state.get((EventTypes.JoinRules, ""))
if join_rules_event:
if not (join_rules_event.content.get("join_rule") == JoinRules.PUBLIC):
# update_membership with an action of "invite" can raise a
# ShadowBanError. This is not handled since it is assumed that
# an admin isn't going to call this API with a shadow-banned user.
await self.room_member_handler.update_membership(
requester=requester,
target=fake_requester.user,

View File

@@ -73,6 +73,7 @@ class UsersRestServletV2(RestServlet):
The parameters `from` and `limit` are required only for pagination.
By default, a `limit` of 100 is used.
The parameter `user_id` can be used to filter by user id.
The parameter `name` can be used to filter by user id or display name.
The parameter `guests` can be used to exclude guest users.
The parameter `deactivated` can be used to include deactivated users.
"""
@@ -89,11 +90,12 @@ class UsersRestServletV2(RestServlet):
start = parse_integer(request, "from", default=0)
limit = parse_integer(request, "limit", default=100)
user_id = parse_string(request, "user_id", default=None)
name = parse_string(request, "name", default=None)
guests = parse_boolean(request, "guests", default=True)
deactivated = parse_boolean(request, "deactivated", default=False)
users, total = await self.store.get_users_paginate(
start, limit, user_id, guests, deactivated
start, limit, user_id, name, guests, deactivated
)
ret = {"users": users, "total": total}
if len(users) >= limit:

View File

@@ -18,6 +18,7 @@ from typing import Awaitable, Callable, Dict, Optional
from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.api.ratelimiting import Ratelimiter
from synapse.handlers.auth import client_dict_convert_legacy_fields_to_identifier
from synapse.http.server import finish_request
from synapse.http.servlet import (
RestServlet,
@@ -28,56 +29,11 @@ from synapse.http.site import SynapseRequest
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.rest.well_known import WellKnownBuilder
from synapse.types import JsonDict, UserID
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.threepids import canonicalise_email
logger = logging.getLogger(__name__)
def login_submission_legacy_convert(submission):
"""
If the input login submission is an old style object
(ie. with top-level user / medium / address) convert it
to a typed object.
"""
if "user" in submission:
submission["identifier"] = {"type": "m.id.user", "user": submission["user"]}
del submission["user"]
if "medium" in submission and "address" in submission:
submission["identifier"] = {
"type": "m.id.thirdparty",
"medium": submission["medium"],
"address": submission["address"],
}
del submission["medium"]
del submission["address"]
def login_id_thirdparty_from_phone(identifier):
"""
Convert a phone login identifier type to a generic threepid identifier
Args:
identifier(dict): Login identifier dict of type 'm.id.phone'
Returns: Login identifier dict of type 'm.id.threepid'
"""
if "country" not in identifier or (
# The specification requires a "phone" field, while Synapse used to require a "number"
# field. Accept both for backwards compatibility.
"phone" not in identifier
and "number" not in identifier
):
raise SynapseError(400, "Invalid phone-type identifier")
# Accept both "phone" and "number" as valid keys in m.id.phone
phone_number = identifier.get("phone", identifier["number"])
msisdn = phone_number_to_msisdn(identifier["country"], phone_number)
return {"type": "m.id.thirdparty", "medium": "msisdn", "address": msisdn}
class LoginRestServlet(RestServlet):
PATTERNS = client_patterns("/login$", v1=True)
CAS_TYPE = "m.login.cas"
@@ -167,7 +123,8 @@ class LoginRestServlet(RestServlet):
result = await self._do_token_login(login_submission)
else:
result = await self._do_other_login(login_submission)
except KeyError:
except KeyError as e:
logger.debug("KeyError during login: %s", e)
raise SynapseError(400, "Missing JSON keys.")
well_known_data = self._well_known_builder.get_well_known()
@@ -194,27 +151,14 @@ class LoginRestServlet(RestServlet):
login_submission.get("address"),
login_submission.get("user"),
)
login_submission_legacy_convert(login_submission)
if "identifier" not in login_submission:
raise SynapseError(400, "Missing param: identifier")
identifier = login_submission["identifier"]
if "type" not in identifier:
raise SynapseError(400, "Login identifier has no type")
# convert phone type identifiers to generic threepids
if identifier["type"] == "m.id.phone":
identifier = login_id_thirdparty_from_phone(identifier)
# convert threepid identifiers to user IDs
if identifier["type"] == "m.id.thirdparty":
address = identifier.get("address")
medium = identifier.get("medium")
if medium is None or address is None:
raise SynapseError(400, "Invalid thirdparty identifier")
# Convert deprecated authdict formats to the current scheme
client_dict_convert_legacy_fields_to_identifier(login_submission)
# Check whether this attempt uses a threepid, if so, check if our failed attempt
# ratelimiter allows another attempt at this time
medium = login_submission.get("medium")
address = login_submission.get("address")
if medium and address:
# For emails, canonicalise the address.
# We store all email addresses canonicalised in the DB.
# (See add_threepid in synapse/handlers/auth.py)
@@ -224,74 +168,41 @@ class LoginRestServlet(RestServlet):
except ValueError as e:
raise SynapseError(400, str(e))
# We also apply account rate limiting using the 3PID as a key, as
# otherwise using 3PID bypasses the ratelimiting based on user ID.
self._failed_attempts_ratelimiter.ratelimit((medium, address), update=False)
# Check for login providers that support 3pid login types
(
canonical_user_id,
callback_3pid,
) = await self.auth_handler.check_password_provider_3pid(
medium, address, login_submission["password"]
)
if canonical_user_id:
# Authentication through password provider and 3pid succeeded
result = await self._complete_login(
canonical_user_id, login_submission, callback_3pid
)
return result
# No password providers were able to handle this 3pid
# Check local store
user_id = await self.hs.get_datastore().get_user_id_by_threepid(
medium, address
)
if not user_id:
logger.warning(
"unknown 3pid identifier medium %s, address %r", medium, address
)
# We mark that we've failed to log in here, as
# `check_password_provider_3pid` might have returned `None` due
# to an incorrect password, rather than the account not
# existing.
#
# If it returned None but the 3PID was bound then we won't hit
# this code path, which is fine as then the per-user ratelimit
# will kick in below.
self._failed_attempts_ratelimiter.can_do_action((medium, address))
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
identifier = {"type": "m.id.user", "user": user_id}
# by this point, the identifier should be an m.id.user: if it's anything
# else, we haven't understood it.
if identifier["type"] != "m.id.user":
raise SynapseError(400, "Unknown login identifier type")
if "user" not in identifier:
raise SynapseError(400, "User identifier is missing 'user' key")
if identifier["user"].startswith("@"):
qualified_user_id = identifier["user"]
else:
qualified_user_id = UserID(identifier["user"], self.hs.hostname).to_string()
# Check if we've hit the failed ratelimit (but don't update it)
self._failed_attempts_ratelimiter.ratelimit(
qualified_user_id.lower(), update=False
# Extract a localpart or user ID from the values in the identifier
username = await self.auth_handler.username_from_identifier(
login_submission["identifier"], login_submission.get("password")
)
if not username:
if medium and address:
# The user attempted to login via threepid and failed
# Record this failed attempt using the threepid as a key, as otherwise
# the user could bypass the ratelimiter by not providing a username
self._failed_attempts_ratelimiter.can_do_action(
(medium, address.lower())
)
raise LoginError(403, "Unauthorized threepid", errcode=Codes.FORBIDDEN)
# The login failed for another reason
raise LoginError(403, "Invalid login", errcode=Codes.FORBIDDEN)
# We were able to extract a username successfully
# Check if we've hit the failed ratelimit for this user ID
self._failed_attempts_ratelimiter.ratelimit(username.lower(), update=False)
try:
canonical_user_id, callback = await self.auth_handler.validate_login(
identifier["user"], login_submission
username, login_submission
)
except LoginError:
# The user has failed to log in, so we need to update the rate
# limiter. Using `can_do_action` avoids us raising a ratelimit
# exception and masking the LoginError. The actual ratelimiting
# should have happened above.
self._failed_attempts_ratelimiter.can_do_action(qualified_user_id.lower())
# exception and masking the LoginError. This just records the attempt.
# The actual rate-limiting happens above
self._failed_attempts_ratelimiter.can_do_action(username.lower())
raise
result = await self._complete_login(
@@ -309,7 +220,7 @@ class LoginRestServlet(RestServlet):
create_non_existent_users: bool = False,
) -> Dict[str, str]:
"""Called when we've successfully authed the user and now need to
actually login them in (e.g. create devices). This gets called on
actually log them in (e.g. create devices). This gets called on
all successful logins.
Applies the ratelimiting for successful login attempts against an

View File

@@ -159,7 +159,7 @@ class PushRuleRestServlet(RestServlet):
return 200, {}
def notify_user(self, user_id):
stream_id, _ = self.store.get_push_rules_stream_token()
stream_id = self.store.get_max_push_rules_stream_id()
self.notifier.on_new_event("push_rules_key", stream_id, users=[user_id])
async def set_rule_attr(self, user_id, spec, val):

View File

@@ -21,14 +21,13 @@ import re
from typing import List, Optional
from urllib import parse as urlparse
from canonicaljson import json
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import (
AuthError,
Codes,
HttpResponseException,
InvalidClientCredentialsError,
ShadowBanError,
SynapseError,
)
from synapse.api.filtering import Filter
@@ -46,6 +45,8 @@ from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.storage.state import StateFilter
from synapse.streams.config import PaginationConfig
from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID
from synapse.util import json_decoder
from synapse.util.stringutils import random_string
MYPY = False
if MYPY:
@@ -200,23 +201,26 @@ class RoomStateEventRestServlet(TransactionRestServlet):
if state_key is not None:
event_dict["state_key"] = state_key
if event_type == EventTypes.Member:
membership = content.get("membership", None)
event_id, _ = await self.room_member_handler.update_membership(
requester,
target=UserID.from_string(state_key),
room_id=room_id,
action=membership,
content=content,
)
else:
(
event,
_,
) = await self.event_creation_handler.create_and_send_nonmember_event(
requester, event_dict, txn_id=txn_id
)
event_id = event.event_id
try:
if event_type == EventTypes.Member:
membership = content.get("membership", None)
event_id, _ = await self.room_member_handler.update_membership(
requester,
target=UserID.from_string(state_key),
room_id=room_id,
action=membership,
content=content,
)
else:
(
event,
_,
) = await self.event_creation_handler.create_and_send_nonmember_event(
requester, event_dict, txn_id=txn_id
)
event_id = event.event_id
except ShadowBanError:
event_id = "$" + random_string(43)
set_tag("event_id", event_id)
ret = {"event_id": event_id}
@@ -249,12 +253,19 @@ class RoomSendEventRestServlet(TransactionRestServlet):
if b"ts" in request.args and requester.app_service:
event_dict["origin_server_ts"] = parse_integer(request, "ts", 0)
event, _ = await self.event_creation_handler.create_and_send_nonmember_event(
requester, event_dict, txn_id=txn_id
)
try:
(
event,
_,
) = await self.event_creation_handler.create_and_send_nonmember_event(
requester, event_dict, txn_id=txn_id
)
event_id = event.event_id
except ShadowBanError:
event_id = "$" + random_string(43)
set_tag("event_id", event.event_id)
return 200, {"event_id": event.event_id}
set_tag("event_id", event_id)
return 200, {"event_id": event_id}
def on_GET(self, request, room_id, event_type, txn_id):
return 200, "Not implemented"
@@ -519,7 +530,9 @@ class RoomMessageListRestServlet(RestServlet):
filter_str = parse_string(request, b"filter", encoding="utf-8")
if filter_str:
filter_json = urlparse.unquote(filter_str)
event_filter = Filter(json.loads(filter_json)) # type: Optional[Filter]
event_filter = Filter(
json_decoder.decode(filter_json)
) # type: Optional[Filter]
if (
event_filter
and event_filter.filter_json.get("event_format", "client")
@@ -631,7 +644,9 @@ class RoomEventContextServlet(RestServlet):
filter_str = parse_string(request, b"filter", encoding="utf-8")
if filter_str:
filter_json = urlparse.unquote(filter_str)
event_filter = Filter(json.loads(filter_json)) # type: Optional[Filter]
event_filter = Filter(
json_decoder.decode(filter_json)
) # type: Optional[Filter]
else:
event_filter = None
@@ -716,16 +731,20 @@ class RoomMembershipRestServlet(TransactionRestServlet):
content = {}
if membership_action == "invite" and self._has_3pid_invite_keys(content):
await self.room_member_handler.do_3pid_invite(
room_id,
requester.user,
content["medium"],
content["address"],
content["id_server"],
requester,
txn_id,
content.get("id_access_token"),
)
try:
await self.room_member_handler.do_3pid_invite(
room_id,
requester.user,
content["medium"],
content["address"],
content["id_server"],
requester,
txn_id,
content.get("id_access_token"),
)
except ShadowBanError:
# Pretend the request succeeded.
pass
return 200, {}
target = requester.user
@@ -737,15 +756,19 @@ class RoomMembershipRestServlet(TransactionRestServlet):
if "reason" in content:
event_content = {"reason": content["reason"]}
await self.room_member_handler.update_membership(
requester=requester,
target=target,
room_id=room_id,
action=membership_action,
txn_id=txn_id,
third_party_signed=content.get("third_party_signed", None),
content=event_content,
)
try:
await self.room_member_handler.update_membership(
requester=requester,
target=target,
room_id=room_id,
action=membership_action,
txn_id=txn_id,
third_party_signed=content.get("third_party_signed", None),
content=event_content,
)
except ShadowBanError:
# Pretend the request succeeded.
pass
return_value = {}
@@ -783,20 +806,27 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
requester = await self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request)
event, _ = await self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.Redaction,
"content": content,
"room_id": room_id,
"sender": requester.user.to_string(),
"redacts": event_id,
},
txn_id=txn_id,
)
try:
(
event,
_,
) = await self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.Redaction,
"content": content,
"room_id": room_id,
"sender": requester.user.to_string(),
"redacts": event_id,
},
txn_id=txn_id,
)
event_id = event.event_id
except ShadowBanError:
event_id = "$" + random_string(43)
set_tag("event_id", event.event_id)
return 200, {"event_id": event.event_id}
set_tag("event_id", event_id)
return 200, {"event_id": event_id}
def on_PUT(self, request, room_id, event_id, txn_id):
set_tag("txn_id", txn_id)

View File

@@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
from http import HTTPStatus
from synapse.api.constants import LoginType
@@ -109,6 +110,9 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
if self.config.request_token_inhibit_3pid_errors:
# Make the client think the operation succeeded. See the rationale in the
# comments for request_token_inhibit_3pid_errors.
# Also wait for some random amount of time between 100ms and 1s to make it
# look like we did something.
await self.hs.clock.sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)}
raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND)
@@ -448,6 +452,9 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
if self.config.request_token_inhibit_3pid_errors:
# Make the client think the operation succeeded. See the rationale in the
# comments for request_token_inhibit_3pid_errors.
# Also wait for some random amount of time between 100ms and 1s to make it
# look like we did something.
await self.hs.clock.sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)}
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
@@ -516,6 +523,9 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
if self.hs.config.request_token_inhibit_3pid_errors:
# Make the client think the operation succeeded. See the rationale in the
# comments for request_token_inhibit_3pid_errors.
# Also wait for some random amount of time between 100ms and 1s to make it
# look like we did something.
await self.hs.clock.sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)}
raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE)

View File

@@ -16,6 +16,7 @@
import logging
from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.types import GroupID
@@ -325,6 +326,9 @@ class GroupRoomServlet(RestServlet):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
if not GroupID.is_valid(group_id):
raise SynapseError(400, "%s was not legal group ID" % (group_id,))
result = await self.groups_handler.get_rooms_in_group(
group_id, requester_user_id
)

View File

@@ -16,6 +16,7 @@
import hmac
import logging
import random
from typing import List, Union
import synapse
@@ -131,6 +132,9 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
if self.hs.config.request_token_inhibit_3pid_errors:
# Make the client think the operation succeeded. See the rationale in the
# comments for request_token_inhibit_3pid_errors.
# Also wait for some random amount of time between 100ms and 1s to make it
# look like we did something.
await self.hs.clock.sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)}
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
@@ -203,6 +207,9 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
if self.hs.config.request_token_inhibit_3pid_errors:
# Make the client think the operation succeeded. See the rationale in the
# comments for request_token_inhibit_3pid_errors.
# Also wait for some random amount of time between 100ms and 1s to make it
# look like we did something.
await self.hs.clock.sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)}
raise SynapseError(
@@ -591,12 +598,17 @@ class RegisterRestServlet(RestServlet):
Codes.THREEPID_IN_USE,
)
entries = await self.store.get_user_agents_ips_to_ui_auth_session(
session_id
)
registered_user_id = await self.registration_handler.register_user(
localpart=desired_username,
password_hash=password_hash,
guest_access_token=guest_access_token,
threepid=threepid,
address=client_addr,
user_agent_ips=entries,
)
# Necessary due to auth checks prior to the threepid being
# written to the db

View File

@@ -22,7 +22,7 @@ any time to reflect changes in the MSC.
import logging
from synapse.api.constants import EventTypes, RelationTypes
from synapse.api.errors import SynapseError
from synapse.api.errors import ShadowBanError, SynapseError
from synapse.http.servlet import (
RestServlet,
parse_integer,
@@ -35,6 +35,7 @@ from synapse.storage.relations import (
PaginationChunk,
RelationPaginationToken,
)
from synapse.util.stringutils import random_string
from ._base import client_patterns
@@ -111,11 +112,18 @@ class RelationSendServlet(RestServlet):
"sender": requester.user.to_string(),
}
event, _ = await self.event_creation_handler.create_and_send_nonmember_event(
requester, event_dict=event_dict, txn_id=txn_id
)
try:
(
event,
_,
) = await self.event_creation_handler.create_and_send_nonmember_event(
requester, event_dict=event_dict, txn_id=txn_id
)
event_id = event.event_id
except ShadowBanError:
event_id = "$" + random_string(43)
return 200, {"event_id": event.event_id}
return 200, {"event_id": event_id}
class RelationPaginationServlet(RestServlet):

View File

@@ -15,13 +15,14 @@
import logging
from synapse.api.errors import Codes, SynapseError
from synapse.api.errors import Codes, ShadowBanError, SynapseError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_json_object_from_request,
)
from synapse.util import stringutils
from ._base import client_patterns
@@ -62,7 +63,6 @@ class RoomUpgradeRestServlet(RestServlet):
content = parse_json_object_from_request(request)
assert_params_in_dict(content, ("new_version",))
new_version = content["new_version"]
new_version = KNOWN_ROOM_VERSIONS.get(content["new_version"])
if new_version is None:
@@ -72,9 +72,13 @@ class RoomUpgradeRestServlet(RestServlet):
Codes.UNSUPPORTED_ROOM_VERSION,
)
new_room_id = await self._room_creation_handler.upgrade_room(
requester, room_id, new_version
)
try:
new_room_id = await self._room_creation_handler.upgrade_room(
requester, room_id, new_version
)
except ShadowBanError:
# Generate a random room ID.
new_room_id = stringutils.random_string(18)
ret = {"replacement_room": new_room_id}

View File

@@ -16,8 +16,6 @@
import itertools
import logging
from canonicaljson import json
from synapse.api.constants import PresenceState
from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection
@@ -29,6 +27,7 @@ from synapse.handlers.presence import format_user_presence_state
from synapse.handlers.sync import SyncConfig
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
from synapse.types import StreamToken
from synapse.util import json_decoder
from ._base import client_patterns, set_timeline_upper_limit
@@ -125,7 +124,7 @@ class SyncRestServlet(RestServlet):
filter_collection = DEFAULT_FILTER_COLLECTION
elif filter_id.startswith("{"):
try:
filter_object = json.loads(filter_id)
filter_object = json_decoder.decode(filter_id)
set_timeline_upper_limit(
filter_object, self.hs.config.filter_timeline_limit
)

View File

@@ -15,19 +15,19 @@
import logging
from typing import Dict, Set
from canonicaljson import json
from signedjson.sign import sign_json
from synapse.api.errors import Codes, SynapseError
from synapse.crypto.keyring import ServerKeyFetcher
from synapse.http.server import DirectServeJsonResource, respond_with_json
from synapse.http.servlet import parse_integer, parse_json_object_from_request
from synapse.util import json_decoder
logger = logging.getLogger(__name__)
class RemoteKey(DirectServeJsonResource):
"""HTTP resource for retreiving the TLS certificate and NACL signature
"""HTTP resource for retrieving the TLS certificate and NACL signature
verification keys for a collection of servers. Checks that the reported
X.509 TLS certificate matches the one used in the HTTPS connection. Checks
that the NACL signature for the remote server is valid. Returns a dict of
@@ -209,13 +209,15 @@ class RemoteKey(DirectServeJsonResource):
# Cast to bytes since postgresql returns a memoryview.
json_results.add(bytes(result["key_json"]))
# If there is a cache miss, request the missing keys, then recurse (and
# ensure the result is sent).
if cache_misses and query_remote_on_cache_miss:
await self.fetcher.get_keys(cache_misses)
await self.query_keys(request, query, query_remote_on_cache_miss=False)
else:
signed_keys = []
for key_json in json_results:
key_json = json.loads(key_json.decode("utf-8"))
key_json = json_decoder.decode(key_json.decode("utf-8"))
for signing_key in self.config.key_server_signing_keys:
key_json = sign_json(key_json, self.config.server_name, signing_key)


@@ -13,12 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
from twisted.web.resource import Resource
from synapse.http.server import set_cors_headers
from synapse.util import json_encoder
logger = logging.getLogger(__name__)
@@ -67,4 +67,4 @@ class WellKnownResource(Resource):
logger.debug("returning: %s", r)
request.setHeader(b"Content-Type", b"application/json")
return json.dumps(r).encode("utf-8")
return json_encoder.encode(r).encode("utf-8")


@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from enum import Enum
from twisted.internet import defer
@@ -25,6 +26,16 @@ if MYPY:
logger = logging.getLogger(__name__)
class RegistrationBehaviour(Enum):
"""
Enum to define whether a registration request should be allowed, denied, or shadow-banned.
"""
ALLOW = "allow"
SHADOW_BAN = "shadow_ban"
DENY = "deny"
class SpamCheckerApi(object):
"""A proxy object that gets passed to spam checkers so they can get
access to rooms and other relevant information.
@@ -48,8 +59,10 @@ class SpamCheckerApi(object):
twisted.internet.defer.Deferred[list(synapse.events.FrozenEvent)]:
The filtered state events in the room.
"""
state_ids = yield self._store.get_filtered_current_state_ids(
room_id=room_id, state_filter=StateFilter.from_types(types)
state_ids = yield defer.ensureDeferred(
self._store.get_filtered_current_state_ids(
room_id=room_id, state_filter=StateFilter.from_types(types)
)
)
state = yield self._store.get_events(state_ids.values())
state = yield defer.ensureDeferred(self._store.get_events(state_ids.values()))
return state.values()
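
The defer.ensureDeferred wrapping is needed because the store methods are now native coroutines, while this method is still a Twisted generator that can only yield Deferreds. A minimal runnable illustration of the same wrapping:

from twisted.internet import defer, task

async def fetch_value():
    # Stand-in for a store method that is now a native coroutine.
    return 42

@defer.inlineCallbacks
def use_value():
    # Generator-based code cannot yield a coroutine directly; wrap it.
    value = yield defer.ensureDeferred(fetch_value())
    return value * 2

def main(reactor):
    d = use_value()
    d.addCallback(print)  # -> 84
    return d

task.react(main)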


@@ -16,11 +16,22 @@
import logging
from collections import namedtuple
from typing import Awaitable, Dict, Iterable, List, Optional, Set
from typing import (
Awaitable,
Dict,
Iterable,
List,
Optional,
Sequence,
Set,
Union,
overload,
)
import attr
from frozendict import frozendict
from prometheus_client import Histogram
from typing_extensions import Literal
from synapse.api.constants import EventTypes
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions
@@ -30,7 +41,7 @@ from synapse.logging.utils import log_function
from synapse.state import v1, v2
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.roommember import ProfileInfo
from synapse.types import StateMap
from synapse.types import Collection, StateMap
from synapse.util import Clock
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache
@@ -68,8 +79,14 @@ def _gen_state_id():
class _StateCacheEntry(object):
__slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"]
def __init__(self, state, state_group, prev_group=None, delta_ids=None):
# dict[(str, str), str] map from (type, state_key) to event_id
def __init__(
self,
state: StateMap[str],
state_group: Optional[int],
prev_group: Optional[int] = None,
delta_ids: Optional[StateMap[str]] = None,
):
# A map from (type, state_key) to event_id.
self.state = frozendict(state)
# the ID of a state group if one and only one is involved.
@@ -107,24 +124,49 @@ class StateHandler(object):
self.hs = hs
self._state_resolution_handler = hs.get_state_resolution_handler()
@overload
async def get_current_state(
self, room_id, event_type=None, state_key="", latest_event_ids=None
):
""" Retrieves the current state for the room. This is done by
self,
room_id: str,
event_type: Literal[None] = None,
state_key: str = "",
latest_event_ids: Optional[List[str]] = None,
) -> StateMap[EventBase]:
...
@overload
async def get_current_state(
self,
room_id: str,
event_type: str,
state_key: str = "",
latest_event_ids: Optional[List[str]] = None,
) -> Optional[EventBase]:
...
async def get_current_state(
self,
room_id: str,
event_type: Optional[str] = None,
state_key: str = "",
latest_event_ids: Optional[List[str]] = None,
) -> Union[Optional[EventBase], StateMap[EventBase]]:
"""Retrieves the current state for the room. This is done by
calling `get_latest_events_in_room` to get the leading edges of the
event graph and then resolving any of the state conflicts.
This is equivalent to getting the state of an event that we were about to
send next, before receiving any new events.
If `event_type` is specified, then the method returns only the one
event (or None) with that `event_type` and `state_key`.
Returns:
map from (type, state_key) to event
If `event_type` is specified, then the method returns only the one
event (or None) with that `event_type` and `state_key`.
Otherwise, a map from (type, state_key) to event.
"""
if not latest_event_ids:
latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
assert latest_event_ids is not None
logger.debug("calling resolve_state_groups from get_current_state")
ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
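
The pair of @overload declarations lets mypy give callers a precise return type: calling without event_type yields the full state map, while passing a str yields a single optional event. The same pattern in miniature:

from typing import Dict, Optional, Union, overload
from typing_extensions import Literal

STATE = {"m.room.name": "$name_event", "m.room.topic": "$topic_event"}

@overload
def get_state(event_type: Literal[None] = None) -> Dict[str, str]: ...
@overload
def get_state(event_type: str) -> Optional[str]: ...

def get_state(event_type: Optional[str] = None) -> Union[Optional[str], Dict[str, str]]:
    if event_type is None:
        return dict(STATE)        # the whole state map
    return STATE.get(event_type)  # one event ID, or None

full_state = get_state()                # mypy infers Dict[str, str]
name_event = get_state("m.room.name")   # mypy infers Optional[str]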
@@ -140,34 +182,30 @@ class StateHandler(object):
state_map = await self.store.get_events(
list(state.values()), get_prev_content=False
)
state = {
return {
key: state_map[e_id] for key, e_id in state.items() if e_id in state_map
}
return state
async def get_current_state_ids(self, room_id, latest_event_ids=None):
async def get_current_state_ids(
self, room_id: str, latest_event_ids: Optional[Iterable[str]] = None
) -> StateMap[str]:
"""Get the current state, or the state at a set of events, for a room
Args:
room_id (str):
latest_event_ids (iterable[str]|None): if given, the forward
extremities to resolve. If None, we look them up from the
database (via a cache)
room_id:
latest_event_ids: if given, the forward extremities to resolve. If
None, we look them up from the database (via a cache).
Returns:
Deferred[dict[(str, str), str)]]: the state dict, mapping from
(event_type, state_key) -> event_id
the state dict, mapping from (event_type, state_key) -> event_id
"""
if not latest_event_ids:
latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
assert latest_event_ids is not None
logger.debug("calling resolve_state_groups from get_current_state_ids")
ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
state = ret.state
return state
return dict(ret.state)
async def get_current_users_in_room(
self, room_id: str, latest_event_ids: Optional[List[str]] = None
@@ -183,32 +221,34 @@ class StateHandler(object):
"""
if not latest_event_ids:
latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
assert latest_event_ids is not None
logger.debug("calling resolve_state_groups from get_current_users_in_room")
entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
joined_users = await self.store.get_joined_users_from_state(room_id, entry)
return joined_users
return await self.store.get_joined_users_from_state(room_id, entry)
async def get_current_hosts_in_room(self, room_id):
async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
event_ids = await self.store.get_latest_event_ids_in_room(room_id)
return await self.get_hosts_in_room_at_events(room_id, event_ids)
async def get_hosts_in_room_at_events(self, room_id, event_ids):
async def get_hosts_in_room_at_events(
self, room_id: str, event_ids: List[str]
) -> Set[str]:
"""Get the hosts that were in a room at the given event ids
Args:
room_id (str):
event_ids (list[str]):
room_id:
event_ids:
Returns:
Deferred[list[str]]: the hosts in the room at the given events
The hosts in the room at the given events
"""
entry = await self.resolve_state_groups_for_events(room_id, event_ids)
joined_hosts = await self.store.get_joined_hosts(room_id, entry)
return joined_hosts
return await self.store.get_joined_hosts(room_id, entry)
async def compute_event_context(
self, event: EventBase, old_state: Optional[Iterable[EventBase]] = None
):
) -> EventContext:
"""Build an EventContext structure for the event.
This works out what the current state should be for the event, and
@@ -221,7 +261,7 @@ class StateHandler(object):
when receiving an event from federation where we don't have the
prev events for, e.g. when backfilling.
Returns:
synapse.events.snapshot.EventContext:
The event context.
"""
if event.internal_metadata.is_outlier():
@@ -275,7 +315,7 @@ class StateHandler(object):
event.room_id, event.prev_event_ids()
)
state_ids_before_event = entry.state
state_ids_before_event = dict(entry.state)
state_group_before_event = entry.state_group
state_group_before_event_prev_group = entry.prev_group
deltas_to_state_group_before_event = entry.delta_ids
@@ -346,19 +386,18 @@ class StateHandler(object):
)
@measure_func()
async def resolve_state_groups_for_events(self, room_id, event_ids):
async def resolve_state_groups_for_events(
self, room_id: str, event_ids: Iterable[str]
) -> _StateCacheEntry:
""" Given a list of event_ids this method fetches the state at each
event, resolves conflicts between them and returns them.
Args:
room_id (str)
event_ids (list[str])
explicit_room_version (str|None): If set uses the the given room
version to choose the resolution algorithm. If None, then
checks the database for room version.
room_id
event_ids
Returns:
Deferred[_StateCacheEntry]: resolved state
The resolved state
"""
logger.debug("resolve_state_groups event_ids %s", event_ids)
@@ -394,7 +433,12 @@ class StateHandler(object):
)
return result
async def resolve_events(self, room_version, state_sets, event):
async def resolve_events(
self,
room_version: str,
state_sets: Collection[Iterable[EventBase]],
event: EventBase,
) -> StateMap[EventBase]:
logger.info(
"Resolving state for %s with %d groups", event.room_id, len(state_sets)
)
@@ -414,9 +458,7 @@ class StateHandler(object):
state_res_store=StateResolutionStore(self.store),
)
new_state = {key: state_map[ev_id] for key, ev_id in new_state.items()}
return new_state
return {key: state_map[ev_id] for key, ev_id in new_state.items()}
class StateResolutionHandler(object):
@@ -444,7 +486,12 @@ class StateResolutionHandler(object):
@log_function
async def resolve_state_groups(
self, room_id, room_version, state_groups_ids, event_map, state_res_store
self,
room_id: str,
room_version: str,
state_groups_ids: Dict[int, StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
state_res_store: "StateResolutionStore",
):
"""Resolves conflicts between a set of state groups
@@ -452,13 +499,13 @@ class StateResolutionHandler(object):
not be called for a single state group
Args:
room_id (str): room we are resolving for (used for logging and sanity checks)
room_version (str): version of the room
state_groups_ids (dict[int, dict[(str, str), str]]):
map from state group id to the state in that state group
room_id: room we are resolving for (used for logging and sanity checks)
room_version: version of the room
state_groups_ids:
A map from state group id to the state in that state group
(where 'state' is a map from state key to event id)
event_map(dict[str,FrozenEvent]|None):
event_map:
a dict from event_id to event, for any events that we happen to
have in flight (e.g., those currently being persisted). This will be
used as a starting point for finding the state we need; any missing
@@ -466,10 +513,10 @@ class StateResolutionHandler(object):
If None, all events will be fetched via state_res_store.
state_res_store (StateResolutionStore)
state_res_store
Returns:
_StateCacheEntry: resolved state
The resolved state
"""
logger.debug("resolve_state_groups state_groups %s", state_groups_ids.keys())
@@ -530,21 +577,22 @@ class StateResolutionHandler(object):
return cache
def _make_state_cache_entry(new_state, state_groups_ids):
def _make_state_cache_entry(
new_state: StateMap[str], state_groups_ids: Dict[int, StateMap[str]]
) -> _StateCacheEntry:
"""Given a resolved state, and a set of input state groups, pick one to base
a new state group on (if any), and return an appropriately-constructed
_StateCacheEntry.
Args:
new_state (dict[(str, str), str]): resolved state map (mapping from
(type, state_key) to event_id)
new_state: resolved state map (mapping from (type, state_key) to event_id)
state_groups_ids (dict[int, dict[(str, str), str]]):
map from state group id to the state in that state group
(where 'state' is a map from state key to event id)
state_groups_ids:
map from state group id to the state in that state group (where
'state' is a map from state key to event id)
Returns:
_StateCacheEntry
The cache entry.
"""
# if the new state matches any of the input state groups, we can
# use that state group again. Otherwise we will generate a state_id
@@ -585,7 +633,7 @@ def resolve_events_with_store(
clock: Clock,
room_id: str,
room_version: str,
state_sets: List[StateMap[str]],
state_sets: Sequence[StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
state_res_store: "StateResolutionStore",
) -> Awaitable[StateMap[str]]:
@@ -633,15 +681,17 @@ class StateResolutionStore(object):
store = attr.ib()
def get_events(self, event_ids, allow_rejected=False):
def get_events(
self, event_ids: Iterable[str], allow_rejected: bool = False
) -> Awaitable[Dict[str, EventBase]]:
"""Get events from the database
Args:
event_ids (list): The event_ids of the events to fetch
allow_rejected (bool): If True return rejected events.
event_ids: The event_ids of the events to fetch
allow_rejected: If True return rejected events.
Returns:
Deferred[dict[str, FrozenEvent]]: Dict from event_id to event.
An awaitable which resolves to a dict from event_id to event.
"""
return self.store.get_events(
@@ -651,7 +701,9 @@ class StateResolutionStore(object):
allow_rejected=allow_rejected,
)
def get_auth_chain_difference(self, state_sets: List[Set[str]]):
def get_auth_chain_difference(
self, state_sets: List[Set[str]]
) -> Awaitable[Set[str]]:
"""Given sets of state events figure out the auth chain difference (as
per state res v2 algorithm).
@@ -660,7 +712,7 @@ class StateResolutionStore(object):
chain.
Returns:
Deferred[Set[str]]: Set of event IDs.
An awaitable that resolves to a set of event IDs.
"""
return self.store.get_auth_chain_difference(state_sets)


@@ -15,7 +15,17 @@
import hashlib
import logging
from typing import Awaitable, Callable, Dict, List, Optional
from typing import (
Awaitable,
Callable,
Dict,
Iterable,
List,
Optional,
Sequence,
Set,
Tuple,
)
from synapse import event_auth
from synapse.api.constants import EventTypes
@@ -32,10 +42,10 @@ POWER_KEY = (EventTypes.PowerLevels, "")
async def resolve_events_with_store(
room_id: str,
state_sets: List[StateMap[str]],
state_sets: Sequence[StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
state_map_factory: Callable[[List[str]], Awaitable],
):
state_map_factory: Callable[[Iterable[str]], Awaitable[Dict[str, EventBase]]],
) -> StateMap[str]:
"""
Args:
room_id: the room we are working in
@@ -56,8 +66,7 @@ async def resolve_events_with_store(
an Awaitable that resolves to a dict of event_id to event.
Returns:
Deferred[dict[(str, str), str]]:
a map from (type, state_key) to event_id.
A map from (type, state_key) to event_id.
"""
if len(state_sets) == 1:
return state_sets[0]
@@ -75,8 +84,8 @@ async def resolve_events_with_store(
"Asking for %d/%d conflicted events", len(needed_events), needed_event_count
)
# dict[str, FrozenEvent]: a map from state event id to event. Only includes
# the state events which are in conflict (and those in event_map)
# A map from state event id to event. Only includes the state events which
# are in conflict (and those in event_map).
state_map = await state_map_factory(needed_events)
if event_map is not None:
state_map.update(event_map)
@@ -91,8 +100,6 @@ async def resolve_events_with_store(
# get the ids of the auth events which allow us to authenticate the
# conflicted state, picking only from the unconflicting state.
#
# dict[(str, str), str]: a map from state key to event id
auth_events = _create_auth_events_from_maps(
unconflicted_state, conflicted_state, state_map
)
@@ -122,29 +129,30 @@ async def resolve_events_with_store(
)
def _seperate(state_sets):
def _seperate(
state_sets: Iterable[StateMap[str]],
) -> Tuple[StateMap[str], StateMap[Set[str]]]:
"""Takes the state_sets and figures out which keys are conflicted and
which aren't, i.e. which have multiple different event_ids associated
with them in different state sets.
Args:
state_sets(iterable[dict[(str, str), str]]):
state_sets:
List of dicts of (type, state_key) -> event_id, which are the
different state groups to resolve.
Returns:
(dict[(str, str), str], dict[(str, str), set[str]]):
A tuple of (unconflicted_state, conflicted_state), where:
A tuple of (unconflicted_state, conflicted_state), where:
unconflicted_state is a dict mapping (type, state_key)->event_id
for unconflicted state keys.
unconflicted_state is a dict mapping (type, state_key)->event_id
for unconflicted state keys.
conflicted_state is a dict mapping (type, state_key) to a set of
event ids for conflicted state keys.
conflicted_state is a dict mapping (type, state_key) to a set of
event ids for conflicted state keys.
"""
state_set_iterator = iter(state_sets)
unconflicted_state = dict(next(state_set_iterator))
conflicted_state = {}
conflicted_state = {} # type: StateMap[Set[str]]
for state_set in state_set_iterator:
for key, value in state_set.items():
@@ -171,7 +179,21 @@ def _seperate(state_sets):
return unconflicted_state, conflicted_state
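
A tiny worked example of the split computed here (v2's variant, further below, differs in that a key missing from one of the sets also counts as conflicted):

state_set_1 = {("m.room.name", ""): "$A", ("m.room.topic", ""): "$T"}
state_set_2 = {("m.room.name", ""): "$B", ("m.room.topic", ""): "$T"}

unconflicted = {}
conflicted = {}
for key in set(state_set_1) | set(state_set_2):
    values = {s[key] for s in (state_set_1, state_set_2) if key in s}
    if len(values) == 1:
        unconflicted[key] = values.pop()   # every set agrees on this key
    else:
        conflicted[key] = values           # disagreement: keep all candidates

assert unconflicted == {("m.room.topic", ""): "$T"}
assert conflicted == {("m.room.name", ""): {"$A", "$B"}}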
def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map):
def _create_auth_events_from_maps(
unconflicted_state: StateMap[str],
conflicted_state: StateMap[Set[str]],
state_map: Dict[str, EventBase],
) -> StateMap[str]:
"""
Args:
unconflicted_state: The unconflicted state map.
conflicted_state: The conflicted state map.
state_map:
Returns:
A map from state key to event id.
"""
auth_events = {}
for event_ids in conflicted_state.values():
for event_id in event_ids:
@@ -179,14 +201,17 @@ def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_ma
keys = event_auth.auth_types_for_event(state_map[event_id])
for key in keys:
if key not in auth_events:
event_id = unconflicted_state.get(key, None)
if event_id:
auth_events[key] = event_id
auth_event_id = unconflicted_state.get(key, None)
if auth_event_id:
auth_events[key] = auth_event_id
return auth_events
def _resolve_with_state(
unconflicted_state_ids, conflicted_state_ids, auth_event_ids, state_map
unconflicted_state_ids: StateMap[str],
conflicted_state_ids: StateMap[Set[str]],
auth_event_ids: StateMap[str],
state_map: Dict[str, EventBase],
):
conflicted_state = {}
for key, event_ids in conflicted_state_ids.items():
@@ -215,7 +240,9 @@ def _resolve_with_state(
return new_state
def _resolve_state_events(conflicted_state, auth_events):
def _resolve_state_events(
conflicted_state: StateMap[List[EventBase]], auth_events: StateMap[EventBase]
) -> StateMap[EventBase]:
""" This is where we actually decide which of the conflicted state to
use.
@@ -255,7 +282,9 @@ def _resolve_state_events(conflicted_state, auth_events):
return resolved_state
def _resolve_auth_events(events, auth_events):
def _resolve_auth_events(
events: List[EventBase], auth_events: StateMap[EventBase]
) -> EventBase:
reverse = list(reversed(_ordered_events(events)))
auth_keys = {
@@ -289,7 +318,9 @@ def _resolve_auth_events(events, auth_events):
return event
def _resolve_normal_events(events, auth_events):
def _resolve_normal_events(
events: List[EventBase], auth_events: StateMap[EventBase]
) -> EventBase:
for event in _ordered_events(events):
try:
# The signatures have already been checked at this point
@@ -309,7 +340,7 @@ def _resolve_normal_events(events, auth_events):
return event
def _ordered_events(events):
def _ordered_events(events: Iterable[EventBase]) -> List[EventBase]:
def key_func(e):
# we have to use utf-8 rather than ascii here because it turns out we allow
# people to send us events with non-ascii event IDs :/


@@ -16,7 +16,21 @@
import heapq
import itertools
import logging
from typing import Dict, List, Optional
from typing import (
Any,
Callable,
Dict,
Generator,
Iterable,
List,
Optional,
Sequence,
Set,
Tuple,
overload,
)
from typing_extensions import Literal
import synapse.state
from synapse import event_auth
@@ -40,10 +54,10 @@ async def resolve_events_with_store(
clock: Clock,
room_id: str,
room_version: str,
state_sets: List[StateMap[str]],
state_sets: Sequence[StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
state_res_store: "synapse.state.StateResolutionStore",
):
) -> StateMap[str]:
"""Resolves the state using the v2 state resolution algorithm
Args:
@@ -63,8 +77,7 @@ async def resolve_events_with_store(
state_res_store:
Returns:
Deferred[dict[(str, str), str]]:
a map from (type, state_key) to event_id.
A map from (type, state_key) to event_id.
"""
logger.debug("Computing conflicted state")
@@ -171,18 +184,23 @@ async def resolve_events_with_store(
return resolved_state
async def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store):
async def _get_power_level_for_sender(
room_id: str,
event_id: str,
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
) -> int:
"""Return the power level of the sender of the given event according to
their auth events.
Args:
room_id (str)
event_id (str)
event_map (dict[str,FrozenEvent])
state_res_store (StateResolutionStore)
room_id
event_id
event_map
state_res_store
Returns:
Deferred[int]
The power level.
"""
event = await _get_event(room_id, event_id, event_map, state_res_store)
@@ -217,17 +235,21 @@ async def _get_power_level_for_sender(room_id, event_id, event_map, state_res_st
return int(level)
async def _get_auth_chain_difference(state_sets, event_map, state_res_store):
async def _get_auth_chain_difference(
state_sets: Sequence[StateMap[str]],
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
) -> Set[str]:
"""Compare the auth chains of each state set and return the set of events
that only appear in some but not all of the auth chains.
Args:
state_sets (list)
event_map (dict[str,FrozenEvent])
state_res_store (StateResolutionStore)
state_sets
event_map
state_res_store
Returns:
Deferred[set[str]]: Set of event IDs
Set of event IDs
"""
difference = await state_res_store.get_auth_chain_difference(
@@ -237,17 +259,19 @@ async def _get_auth_chain_difference(state_sets, event_map, state_res_store):
return difference
def _seperate(state_sets):
def _seperate(
state_sets: Iterable[StateMap[str]],
) -> Tuple[StateMap[str], StateMap[Set[str]]]:
"""Return the unconflicted and conflicted state. This is different than in
the original algorithm, as this defines a key to be conflicted if one of
the state sets doesn't have that key.
Args:
state_sets (list)
state_sets
Returns:
tuple[dict, dict]: A tuple of unconflicted and conflicted state. The
conflicted state dict is a map from type/state_key to set of event IDs
A tuple of unconflicted and conflicted state. The conflicted state dict
is a map from type/state_key to set of event IDs
"""
unconflicted_state = {}
conflicted_state = {}
@@ -260,18 +284,20 @@ def _seperate(state_sets):
event_ids.discard(None)
conflicted_state[key] = event_ids
return unconflicted_state, conflicted_state
# mypy doesn't understand that discarding None above means that conflicted
# state is StateMap[Set[str]], not StateMap[Set[Optional[Str]]].
return unconflicted_state, conflicted_state # type: ignore
def _is_power_event(event):
def _is_power_event(event: EventBase) -> bool:
"""Return whether or not the event is a "power event", as defined by the
v2 state resolution algorithm
Args:
event (FrozenEvent)
event
Returns:
boolean
True if the event is a power event.
"""
if (event.type, event.state_key) in (
(EventTypes.PowerLevels, ""),
@@ -288,19 +314,23 @@ def _is_power_event(event):
async def _add_event_and_auth_chain_to_graph(
graph, room_id, event_id, event_map, state_res_store, auth_diff
):
graph: Dict[str, Set[str]],
room_id: str,
event_id: str,
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
auth_diff: Set[str],
) -> None:
"""Helper function for _reverse_topological_power_sort that add the event
and its auth chain (that is in the auth diff) to the graph
Args:
graph (dict[str, set[str]]): A map from event ID to the events auth
event IDs
room_id (str): the room we are working in
event_id (str): Event to add to the graph
event_map (dict[str,FrozenEvent])
state_res_store (StateResolutionStore)
auth_diff (set[str]): Set of event IDs that are in the auth difference.
graph: A map from event ID to the events auth event IDs
room_id: the room we are working in
event_id: Event to add to the graph
event_map
state_res_store
auth_diff: Set of event IDs that are in the auth difference.
"""
state = [event_id]
@@ -318,24 +348,29 @@ async def _add_event_and_auth_chain_to_graph(
async def _reverse_topological_power_sort(
clock, room_id, event_ids, event_map, state_res_store, auth_diff
):
clock: Clock,
room_id: str,
event_ids: Iterable[str],
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
auth_diff: Set[str],
) -> List[str]:
"""Returns a list of the event_ids sorted by reverse topological ordering,
and then by power level and origin_server_ts
Args:
clock (Clock)
room_id (str): the room we are working in
event_ids (list[str]): The events to sort
event_map (dict[str,FrozenEvent])
state_res_store (StateResolutionStore)
auth_diff (set[str]): Set of event IDs that are in the auth difference.
clock
room_id: the room we are working in
event_ids: The events to sort
event_map
state_res_store
auth_diff: Set of event IDs that are in the auth difference.
Returns:
Deferred[list[str]]: The sorted list
The sorted list
"""
graph = {}
graph = {} # type: Dict[str, Set[str]]
for idx, event_id in enumerate(event_ids, start=1):
await _add_event_and_auth_chain_to_graph(
graph, room_id, event_id, event_map, state_res_store, auth_diff
@@ -372,22 +407,28 @@ async def _reverse_topological_power_sort(
async def _iterative_auth_checks(
clock, room_id, room_version, event_ids, base_state, event_map, state_res_store
):
clock: Clock,
room_id: str,
room_version: str,
event_ids: List[str],
base_state: StateMap[str],
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
) -> StateMap[str]:
"""Sequentially apply auth checks to each event in given list, updating the
state as it goes along.
Args:
clock (Clock)
room_id (str)
room_version (str)
event_ids (list[str]): Ordered list of events to apply auth checks to
base_state (StateMap[str]): The set of state to start with
event_map (dict[str,FrozenEvent])
state_res_store (StateResolutionStore)
clock
room_id
room_version
event_ids: Ordered list of events to apply auth checks to
base_state: The set of state to start with
event_map
state_res_store
Returns:
Deferred[StateMap[str]]: Returns the final updated state
Returns the final updated state
"""
resolved_state = base_state.copy()
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
@@ -439,21 +480,26 @@ async def _iterative_auth_checks(
async def _mainline_sort(
clock, room_id, event_ids, resolved_power_event_id, event_map, state_res_store
):
clock: Clock,
room_id: str,
event_ids: List[str],
resolved_power_event_id: Optional[str],
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
) -> List[str]:
"""Returns a sorted list of event_ids sorted by mainline ordering based on
the given event resolved_power_event_id
Args:
clock (Clock)
room_id (str): room we're working in
event_ids (list[str]): Events to sort
resolved_power_event_id (str): The final resolved power level event ID
event_map (dict[str,FrozenEvent])
state_res_store (StateResolutionStore)
clock
room_id: room we're working in
event_ids: Events to sort
resolved_power_event_id: The final resolved power level event ID
event_map
state_res_store
Returns:
Deferred[list[str]]: The sorted list
The sorted list
"""
if not event_ids:
# It's possible for there to be no event IDs here to sort, so we can
@@ -505,59 +551,90 @@ async def _mainline_sort(
async def _get_mainline_depth_for_event(
event, mainline_map, event_map, state_res_store
):
event: EventBase,
mainline_map: Dict[str, int],
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
) -> int:
"""Get the mainline depths for the given event based on the mainline map
Args:
event (FrozenEvent)
mainline_map (dict[str, int]): Map from event_id to mainline depth for
events in the mainline.
event_map (dict[str,FrozenEvent])
state_res_store (StateResolutionStore)
event
mainline_map: Map from event_id to mainline depth for events in the mainline.
event_map
state_res_store
Returns:
Deferred[int]
The mainline depth
"""
room_id = event.room_id
tmp_event = event # type: Optional[EventBase]
# We do an iterative search, replacing `event` with the power level in its
# auth events (if any)
while event:
while tmp_event:
depth = mainline_map.get(tmp_event.event_id)
if depth is not None:
return depth
auth_events = event.auth_event_ids()
event = None
auth_events = tmp_event.auth_event_ids()
tmp_event = None
for aid in auth_events:
aev = await _get_event(
room_id, aid, event_map, state_res_store, allow_none=True
)
if aev and (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""):
event = aev
tmp_event = aev
break
# Didn't find a power level auth event, so we just return 0
return 0
async def _get_event(room_id, event_id, event_map, state_res_store, allow_none=False):
@overload
async def _get_event(
room_id: str,
event_id: str,
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
allow_none: Literal[False] = False,
) -> EventBase:
...
@overload
async def _get_event(
room_id: str,
event_id: str,
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
allow_none: Literal[True],
) -> Optional[EventBase]:
...
async def _get_event(
room_id: str,
event_id: str,
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
allow_none: bool = False,
) -> Optional[EventBase]:
"""Helper function to look up event in event_map, falling back to looking
it up in the store
Args:
room_id (str)
event_id (str)
event_map (dict[str,FrozenEvent])
state_res_store (StateResolutionStore)
allow_none (bool): if the event is not found, return None rather than raising
room_id
event_id
event_map
state_res_store
allow_none: if the event is not found, return None rather than raising
an exception
Returns:
Deferred[Optional[FrozenEvent]]
The event, or None if the event does not exist (and allow_none is True).
"""
if event_id not in event_map:
events = await state_res_store.get_events([event_id], allow_rejected=True)
@@ -577,7 +654,9 @@ async def _get_event(room_id, event_id, event_map, state_res_store, allow_none=F
return event
def lexicographical_topological_sort(graph, key):
def lexicographical_topological_sort(
graph: Dict[str, Set[str]], key: Callable[[str], Any]
) -> Generator[str, None, None]:
"""Performs a lexicographic reverse topological sort on the graph.
This returns a reverse topological sort (i.e. if node A references B then B
@@ -587,20 +666,20 @@ def lexicographical_topological_sort(graph, key):
NOTE: `graph` is modified during the sort.
Args:
graph (dict[str, set[str]]): A representation of the graph where each
node is a key in the dict and its value are the nodes edges.
key (func): A function that takes a node and returns a value that is
comparable and used to order nodes
graph: A representation of the graph where each node is a key in the
dict and its value are the nodes edges.
key: A function that takes a node and returns a value that is comparable
and used to order nodes
Yields:
str: The next node in the topological sort
The next node in the topological sort
"""
# Note, this is basically Kahn's algorithm except we look at nodes with no
# outgoing edges, c.f.
# https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm
outdegree_map = graph
reverse_graph = {}
reverse_graph = {} # type: Dict[str, Set[str]]
# Lists of nodes with zero out degree. Is actually a tuple of
# `(key(node), node)` so that sorting does the right thing
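
A compact runnable rendition of this sort may help; it is simplified from the function above but keeps the same idea (Kahn's algorithm over nodes with no outgoing edges, ties broken by key):

import heapq
from typing import Any, Callable, Dict, Generator, Set

def lexicographical_topological_sort(
    graph: Dict[str, Set[str]], key: Callable[[str], Any]
) -> Generator[str, None, None]:
    # Build the reverse graph: which nodes reference each node.
    reverse_graph = {}  # type: Dict[str, Set[str]]
    for node, edges in graph.items():
        reverse_graph.setdefault(node, set())
        for edge in edges:
            reverse_graph.setdefault(edge, set()).add(node)

    zero_outdegree = [(key(n), n) for n, edges in graph.items() if not edges]
    heapq.heapify(zero_outdegree)

    while zero_outdegree:
        _, node = heapq.heappop(zero_outdegree)
        yield node
        for parent in reverse_graph[node]:
            graph[parent].discard(node)  # NB: mutates `graph`, as noted above
            if not graph[parent]:
                heapq.heappush(zero_outdegree, (key(parent), parent))

# "$B" and "$C" both reference "$A", so "$A" is yielded first:
print(list(lexicographical_topological_sort(
    {"$A": set(), "$B": {"$A"}, "$C": {"$A"}}, key=lambda n: n)))
# -> ['$A', '$B', '$C']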


@@ -19,12 +19,11 @@ import random
from abc import ABCMeta
from typing import Any, Optional
from canonicaljson import json
from synapse.storage.database import LoggingTransaction # noqa: F401
from synapse.storage.database import make_in_list_sql_clause # noqa: F401
from synapse.storage.database import DatabasePool
from synapse.types import Collection, get_domain_from_id
from synapse.util import json_decoder
logger = logging.getLogger(__name__)
@@ -99,13 +98,13 @@ def db_to_json(db_content):
if isinstance(db_content, memoryview):
db_content = db_content.tobytes()
# Decode it to a Unicode string before feeding it to json.loads, since
# Decode it to a Unicode string before feeding it to the JSON decoder, since
# Python 3.5 does not support deserializing bytes.
if isinstance(db_content, (bytes, bytearray)):
db_content = db_content.decode("utf8")
try:
return json.loads(db_content)
return json_decoder.decode(db_content)
except Exception:
logging.warning("Tried to decode '%r' as JSON and failed", db_content)
raise
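
A short usage sketch of the helper above (condensed and self-contained; the shared decoder definition is an assumption): postgres hands back a memoryview, sqlite a str, and both normalise to the same decoded value.

import json
import logging

json_decoder = json.JSONDecoder()  # assumed shape of synapse.util's decoder

def db_to_json(db_content):
    # Condensed copy of the helper above, for demonstration.
    if isinstance(db_content, memoryview):
        db_content = db_content.tobytes()
    if isinstance(db_content, (bytes, bytearray)):
        db_content = db_content.decode("utf8")
    try:
        return json_decoder.decode(db_content)
    except Exception:
        logging.warning("Tried to decode '%r' as JSON and failed", db_content)
        raise

assert db_to_json('{"a": 1}') == {"a": 1}               # sqlite: str
assert db_to_json(memoryview(b'{"a": 1}')) == {"a": 1}  # postgres: memoryview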


@@ -16,9 +16,8 @@
import logging
from typing import Optional
from canonicaljson import json
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util import json_encoder
from . import engines
@@ -457,7 +456,7 @@ class BackgroundUpdater(object):
progress(dict): The progress of the update.
"""
progress_json = json.dumps(progress)
progress_json = json_encoder.encode(progress)
self.db_pool.simple_update_one_txn(
txn,

File diff suppressed because it is too large


@@ -87,12 +87,21 @@ class Databases(object):
logger.info("Database %r prepared", db_name)
# Closing the context manager doesn't close the connection.
# psycopg will close the connection when the object gets GCed, but *only*
# if the PID is the same as when the connection was opened [1], and
# it may not be if we fork in the meantime.
#
# [1]: https://github.com/psycopg/psycopg2/blob/2_8_5/psycopg/connection_type.c#L1378
db_conn.close()
# Sanity check that we have actually configured all the required stores.
if not main:
raise Exception("No 'main' data store configured")
if not state:
raise Exception("No 'main' data store configured")
raise Exception("No 'state' data store configured")
# We use local variables here to ensure that the databases do not have
# optional types.


@@ -498,7 +498,7 @@ class DataStore(
)
def get_users_paginate(
self, start, limit, name=None, guests=True, deactivated=False
self, start, limit, user_id=None, name=None, guests=True, deactivated=False
):
"""Function to retrieve a paginated list of users from
users list. This will return a json list of users and the
@@ -507,7 +507,8 @@ class DataStore(
Args:
start (int): start number to begin the query from
limit (int): number of rows to retrieve
name (string): filter for user names
user_id (string): search for user_id. ignored if name is not None
name (string): search for local part of user_id or display name
guests (bool): whether to in include guest users
deactivated (bool): whether to include deactivated users
Returns:
@@ -516,11 +517,14 @@ class DataStore(
def get_users_paginate_txn(txn):
filters = []
args = []
args = [self.hs.config.server_name]
if name:
filters.append("(name LIKE ? OR displayname LIKE ?)")
args.extend(["@%" + name + "%:%", "%" + name + "%"])
elif user_id:
filters.append("name LIKE ?")
args.append("%" + name + "%")
args.extend(["%" + user_id + "%"])
if not guests:
filters.append("is_guest = 0")
@@ -530,20 +534,23 @@ class DataStore(
where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else ""
sql = "SELECT COUNT(*) as total_users FROM users %s" % (where_clause)
txn.execute(sql, args)
count = txn.fetchone()[0]
args = [self.hs.config.server_name] + args + [limit, start]
sql = """
SELECT name, user_type, is_guest, admin, deactivated, displayname, avatar_url
sql_base = """
FROM users as u
LEFT JOIN profiles AS p ON u.name = '@' || p.user_id || ':' || ?
{}
ORDER BY u.name LIMIT ? OFFSET ?
""".format(
where_clause
)
sql = "SELECT COUNT(*) as total_users " + sql_base
txn.execute(sql, args)
count = txn.fetchone()[0]
sql = (
"SELECT name, user_type, is_guest, admin, deactivated, displayname, avatar_url "
+ sql_base
+ " ORDER BY u.name LIMIT ? OFFSET ?"
)
args += [limit, start]
txn.execute(sql, args)
users = self.db_pool.cursor_to_dict(txn)
return users, count
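
The rework builds the FROM/WHERE base once and reuses it for both the COUNT query and the page query, so the two can never disagree about which rows match. A self-contained sketch of the pattern against sqlite (the schema and filters are simplified stand-ins for the real tables):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE users (name TEXT, displayname TEXT);
    INSERT INTO users VALUES ('@alice:hs', 'Alice'), ('@bob:hs', 'Bob');
    """
)

def get_users_paginate(txn, start, limit, name=None):
    filters, args = [], []
    if name:
        filters.append("(name LIKE ? OR displayname LIKE ?)")
        args.extend(["@%" + name + "%:%", "%" + name + "%"])
    where_clause = "WHERE " + " AND ".join(filters) if filters else ""
    sql_base = " FROM users " + where_clause

    txn.execute("SELECT COUNT(*) AS total_users" + sql_base, args)
    count = txn.fetchone()[0]

    txn.execute(
        "SELECT name, displayname" + sql_base + " ORDER BY name LIMIT ? OFFSET ?",
        args + [limit, start],
    )
    return txn.fetchall(), count

print(get_users_paginate(conn.cursor(), start=0, limit=10, name="ali"))
# -> ([('@alice:hs', 'Alice')], 1)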


@@ -336,7 +336,7 @@ class AccountDataStore(AccountDataWorkerStore):
"""
content_json = json_encoder.encode(content)
with self._account_data_id_gen.get_next() as next_id:
with await self._account_data_id_gen.get_next() as next_id:
# no need to lock here as room_account_data has a unique constraint
# on (user_id, room_id, account_data_type) so simple_upsert will
# retry if there is a conflict.
@@ -384,7 +384,7 @@ class AccountDataStore(AccountDataWorkerStore):
"""
content_json = json_encoder.encode(content)
with self._account_data_id_gen.get_next() as next_id:
with await self._account_data_id_gen.get_next() as next_id:
# no need to lock here as account_data has a unique constraint on
# (user_id, account_data_type) so simple_upsert will retry if
# there is a conflict.
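
These call-site changes follow from get_next becoming async: the call is awaited first, and the awaited result is the context manager that allocates a stream ID. A minimal stand-in generator showing the calling convention:

import asyncio
from contextlib import contextmanager

class StreamIdGenerator:
    # Minimal stand-in for Synapse's ID generators once get_next() is a
    # coroutine resolving to a context manager that allocates an ID.
    def __init__(self):
        self._current = 0

    async def get_next(self):
        @contextmanager
        def manager():
            self._current += 1
            yield self._current
        return manager()

async def main():
    id_gen = StreamIdGenerator()
    with await id_gen.get_next() as next_id:
        print("persisting with stream id", next_id)  # -> 1

asyncio.run(main())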

Some files were not shown because too many files have changed in this diff