
Compare commits

...

38 Commits

Author SHA1 Message Date
Richard van der Hoff 311c15dd4f put a cache on /state_ids 2020-07-18 22:18:42 +01:00
Richard van der Hoff 6d174fda89 Merge tag 'v1.17.0-mod1' into rav/modular_hacks 2020-07-18 20:47:56 +01:00
Richard van der Hoff 75da9f7a8e Abort requests if the client disconnects early 2020-07-18 20:01:03 +01:00
Jason Robinson 7078251569 Don't fail startup on a non-C locale postgres database
Instead, just warn. This is a temporary workaround to allow
provisioning of new Modular hosts.

See https://github.com/matrix-org/synapse/pull/6734
and https://github.com/matrix-org/synapse/issues/6696#issuecomment-575280941

Signed-off-by: Jason Robinson <jasonr@matrix.org>
2020-07-13 11:31:34 +01:00
Richard van der Hoff 29df3d0e9f 1.17.0 2020-07-13 10:20:36 +01:00
Richard van der Hoff 8ccb7f08d9 Merge branch 'master' into release-v1.17.0 2020-07-10 18:38:18 +01:00
Richard van der Hoff 43726783e4 1.17.0rc1 2020-07-09 16:53:19 +01:00
Patrick Cloke 38e1fac886 Fix some spelling mistakes / typos. (#7811) 2020-07-09 09:52:58 -04:00
Richard van der Hoff 53ee214f2f update_membership declaration: now always returns an event id. (#7809) 2020-07-09 13:01:42 +01:00
Richard van der Hoff 8ca39bd2c3 Improve stacktraces from exceptions in background processes (#7808)
use `Failure()` to fish out the real exception.
2020-07-09 13:01:33 +01:00
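For illustration, a minimal standalone sketch of the `Failure()` technique (assumes Twisted is installed; this snippet is not Synapse code):

```python
# Minimal sketch (assumes Twisted): Failure() called inside an except block
# snapshots the exception currently being handled, together with its
# traceback, so the real stack trace survives into the logs.
from twisted.python.failure import Failure

try:
    raise ValueError("boom in a background process")
except ValueError:
    f = Failure()  # captures the active exception and its traceback
    print(f.getTraceback())
```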
Richard van der Hoff 08c5181a8d Fix can only concatenate list (not "tuple") to list exception (#7810)
It seems auth_events can be either a list or a tuple, depending on Things.
2020-07-09 12:48:15 +01:00
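The failure mode in miniature (illustrative values, not Synapse code):

```python
# list + tuple raises TypeError, so code handling auth_events must
# normalise the type before concatenating.
auth_events = [1, 2]  # ...but sometimes a tuple, depending on Things

# auth_events + (3,)  # TypeError: can only concatenate list (not "tuple") to list
combined = list(auth_events) + [3]  # normalising is safe for either type
assert combined == [1, 2, 3]
```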
Patrick Cloke 8fa7fdd4cb Pass original request headers from workers to the main process. (#7797) 2020-07-09 07:34:46 -04:00
Richard van der Hoff 2ab0b021f1 Generate real events when we reject invites (#7804)
Fixes #2181. 

The basic premise is that, when we
fail to reject an invite via the remote server, we can generate our own
out-of-band leave event and persist it as an outlier, so that we have something
to send to the client.
2020-07-09 10:40:19 +01:00
Richard van der Hoff 67593b1728 Add HomeServer.signing_key property (#7805)
... instead of duplicating `config.signing_key[0]` everywhere
2020-07-08 17:51:56 +01:00
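A hedged sketch of what such a convenience property looks like (the real `HomeServer` class is far larger):

```python
# Expose the first configured signing key once, instead of reaching into
# config.signing_key[0] at every call site.
class HomeServer:
    def __init__(self, config):
        self.config = config

    @property
    def signing_key(self):
        return self.config.signing_key[0]
```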
Richard van der Hoff ef5ed5292b Revert "Update the installation docs on apt-transport-https (#7801)"
This reverts commit e0c0129693.

As discussed at
https://github.com/matrix-org/synapse/pull/7801#pullrequestreview-444652786, I
don't think this is an improvement.
2020-07-08 16:57:10 +01:00
Patrick Cloke e7efd8f827 Do not use simplejson in Synapse. (#7800) 2020-07-08 07:15:08 -04:00
Patrick Cloke ff0680f69d Stop passing bytes when dumping JSON (#7799) 2020-07-08 07:14:56 -04:00
Dirk Heinrichs e0c0129693 Update the installation docs on apt-transport-https (#7801)
* Starting with apt 1.6, https support has moved into the main package and apt-transport-https has become a transitional dummy package.

Signed-off-by: Dirk Heinrichs <dirk.heinrichs@altum.de>
2020-07-08 11:34:13 +01:00
Richard van der Hoff 59ddcd790b Merge branch 'master' into develop 2020-07-08 11:25:34 +01:00
Nicolai Søborg 96bb01d8ec Change Caddy links (old is deprecated) (#7789)
* Change Caddy links

The current links point to Caddy v1, which is deprecated.

Signed-off-by: Nicolai Søborg <git@xn--sb-lka.org>
2020-07-08 10:09:16 +01:00
Richard van der Hoff 76dbd7b8d6 Stop populating unused table local_invites. (#7793)
This table is no longer used, so we may as well stop populating it. Removing it
entirely would prevent people from rolling back to older releases of Synapse, so
that can happen in a future release.
2020-07-07 14:20:40 +01:00
Erik Johnston 67d7756fcf Refactor getting replication updates from database v2. (#7740) 2020-07-07 12:11:35 +01:00
Juho Vanhanen d378c3da78 Add libwebp dependency to Dockerfile (#7791)
* Add libwebp dependency to Dockerfile

Signed-off-by: Juho Vanhanen <juho@vanhanen.io>
2020-07-06 13:37:39 +01:00
Patrick Cloke 2a266f4511 Add documentation for JWT login type and improve sample config. (#7776) 2020-07-06 08:31:51 -04:00
Patrick Cloke 6d687ebba1 Convert the appservice handler to async/await. (#7775) 2020-07-06 07:40:35 -04:00
reivilibre 57feeab364 Don't ignore set_tweak actions with no explicit value. (#7766)
* Fix spec compliance; tweaks without values are valid

(default to True, which is only concretely specified for
`highlight`, but it seems only reasonable to generalise)

* Changelog for 7766.

* Add documentation to `tweaks_for_actions`

May as well tidy up when I'm here.

* Add a test for `tweaks_for_actions`
2020-07-06 11:43:41 +01:00
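A hedged sketch of the corrected behaviour (the real logic lives in Synapse's push-rule evaluator, not in this snippet):

```python
def tweaks_for_actions(actions):
    tweaks = {}
    for action in actions:
        if isinstance(action, dict) and "set_tweak" in action:
            # Per the spec, a missing "value" defaults to True (as for
            # highlight), rather than causing the tweak to be dropped.
            tweaks[action["set_tweak"]] = action.get("value", True)
    return tweaks

assert tweaks_for_actions(["notify", {"set_tweak": "highlight"}]) == {"highlight": True}
```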
Oliver Kurz 4e118742ca Allow use of higher versions of prometheus_client (#7780)
Fixes https://github.com/matrix-org/synapse/issues/7641

The package was pinned to <0.8.0 in 7ad1d7635
(https://github.com/matrix-org/synapse/pull/5636) without an obvious
rationale; the version selection appears simply to exclude the next minor
version number, which might introduce breaking API changes. Excluding the
next minor version is a reasonably conservative choice.

Downstream distributions have already reported success in patching out the
version requirements.

This also fixes the integration of upgraded packages into openSUSE
distributions, e.g. openSUSE Tumbleweed, which already ships
prometheus_client >= 0.8.

Signed-off-by: Oliver Kurz <okurz@suse.de>

Co-authored-by: Richard van der Hoff <1389908+richvdh@users.noreply.github.com>
2020-07-06 10:21:41 +01:00
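Illustratively, the relaxed pin amounts to a requirement string along these lines (the exact dependency declaration is an assumption, not shown in this diff):

```python
# Allow everything below the next minor version, rather than pinning <0.8.0.
REQUIREMENTS = ["prometheus_client>=0.0.18,<0.9.0"]
```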
Will Hunt 62b1ce8539 isort 5 compatibility (#7786)
The CI appears to use the latest version of isort, which is a problem when isort gets a major version bump. Rather than try to pin the version, I've done the necessary work to make isort 5 happy with Synapse.
2020-07-05 16:32:02 +01:00
Erik Johnston 5cdca53aa0 Merge different Resource implementation classes (#7732) 2020-07-03 19:02:19 +01:00
Dirk Klimpel 21a212f8e5 Fix inconsistent handling of upper and lower cases of email addresses. (#7021)
fixes #7016
2020-07-03 14:03:13 +01:00
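A hedged sketch of the canonicalisation involved (the real helper, `synapse.util.threepids.canonicalise_email`, validates more carefully):

```python
def canonicalise_email(address: str) -> str:
    # Strip and lower-case the address so lookups are case-insensitive.
    address = address.strip().lower()
    if address.count("@") != 1:
        raise ValueError("Unable to parse email address %r" % (address,))
    return address

assert canonicalise_email("  Alice@Example.COM ") == "alice@example.com"
```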
Alex Kotov 8097659f6e Allow YAML config file to contain None (#7779)
Useful when the config file is fully commented out

Signed-off-by: Alex Kotov <kotovalexarian@gmail.com>
2020-07-03 13:19:03 +01:00
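A minimal sketch of the failure mode and fix (assumes PyYAML):

```python
import yaml

# A fully commented-out config file parses to None, so fall back to an
# empty mapping instead of crashing.
raw = yaml.safe_load("# every option commented out\n")
config = raw if raw is not None else {}
assert config == {}
```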
Patrick Cloke f3e0f16240 Merge tag 'v1.16.0rc2' into develop
Synapse 1.16.0rc2 (2020-07-02)
==============================

Synapse 1.16.0rc2 includes the security fixes released with Synapse 1.15.2.
Please see [below](https://github.com/matrix-org/synapse/blob/master/CHANGES.md#synapse-1152-2020-07-02) for more details.

Improved Documentation
----------------------

- Update postgres image in example `docker-compose.yaml` to tag `12-alpine`. ([\#7696](https://github.com/matrix-org/synapse/issues/7696))

Internal Changes
----------------

- Add some metrics for inbound and outbound federation latencies: `synapse_federation_server_pdu_process_time` and `synapse_event_processing_lag_by_event`. ([\#7771](https://github.com/matrix-org/synapse/issues/7771))
2020-07-02 11:25:56 -04:00
Patrick Cloke 4d978d7db4 Merge branch 'master' into develop 2020-07-02 10:55:41 -04:00
reivilibre e5808c4cfb Hack to add push priority to push notifications (#7765)
* Remove obsolete comment about ancient temporary code

Signed-off-by: Olivier Wilkinson (reivilibre) <olivier@librepush.net>

* Implement hack to set push priority

based on whether the tweaks indicate the event might cause
effects.

* Changelog for 7765

Signed-off-by: Olivier Wilkinson (reivilibre) <olivier@librepush.net>

* Antilint

* Add tests for push priority

Signed-off-by: Olivier Wilkinson (reivilibre) <olivier@librepush.net>

* Update synapse/push/httppusher.py

Co-authored-by: Brendan Abolivier <babolivier@matrix.org>

* Antilint

* Remove needless invites from tests.

Co-authored-by: Brendan Abolivier <babolivier@matrix.org>
2020-07-01 17:02:31 +01:00
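A hedged sketch of the priority decision described above (the tweak names are illustrative of "user-observable effects"):

```python
def push_priority(tweaks: dict) -> str:
    # Only pushes that will make a sound or highlight need to wake the
    # device, so everything else is sent at low priority.
    return "high" if tweaks.get("sound") or tweaks.get("highlight") else "low"

assert push_priority({"sound": "default"}) == "high"
assert push_priority({}) == "low"
```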
Richard van der Hoff e866512367 Add early returns to _check_for_soft_fail (#7769)
My editor was complaining about unset variables, so let's add some early
returns to fix that and reduce indentation/cognitive load.
2020-07-01 16:41:19 +01:00
Richard van der Hoff f01e2ca039 Use symbolic names for replication stream names (#7768)
This makes it much easier to find where streams are referenced.
2020-07-01 16:35:40 +01:00
Richard van der Hoff a6eae69ffe Type checking for FederationHandler (#7770)
fix a few things to make this pass mypy.
2020-07-01 16:21:02 +01:00
Richard van der Hoff 244dbb04f7 Fix incorrect error message when database CTYPE was set incorrectly. (#7760) 2020-07-01 13:56:16 +01:00
131 changed files with 2136 additions and 1185 deletions
+54 -3
@@ -1,9 +1,13 @@
Synapse 1.17.0 (2020-07-13)
===========================
Synapse 1.17.0 is identical to 1.17.0rc1, with the addition of the fix that was included in 1.16.1.
Synapse 1.16.1 (2020-07-10)
===========================
In some distributions of Synapse 1.16.0, we incorrectly included a database
migration which added a new, unused table. This release removes the redundant
table.
Bugfixes
--------
@@ -11,6 +15,53 @@ Bugfixes
- Drop table `local_rejections_stream` which was incorrectly added in Synapse 1.16.0. ([\#7816](https://github.com/matrix-org/synapse/issues/7816), [b1beb3ff5](https://github.com/matrix-org/synapse/commit/b1beb3ff5))
Synapse 1.17.0rc1 (2020-07-09)
==============================
Bugfixes
--------
- Fix inconsistent handling of upper and lower case in email addresses when used as identifiers for login, etc. Contributed by @dklimpel. ([\#7021](https://github.com/matrix-org/synapse/issues/7021))
- Fix "Tried to close a non-active scope!" error messages when opentracing is enabled. ([\#7732](https://github.com/matrix-org/synapse/issues/7732))
- Fix incorrect error message when database CTYPE was set incorrectly. ([\#7760](https://github.com/matrix-org/synapse/issues/7760))
- Fix to not ignore `set_tweak` actions in Push Rules that have no `value`, as permitted by the specification. ([\#7766](https://github.com/matrix-org/synapse/issues/7766))
- Fix synctl to handle empty config files correctly. Contributed by @kotovalexarian. ([\#7779](https://github.com/matrix-org/synapse/issues/7779))
- Fix a long-standing bug in worker mode where the worker's own information was saved in the devices table instead of the original IP address and user agent of the request. ([\#7797](https://github.com/matrix-org/synapse/issues/7797))
- Fix 'stuck invites' which happen when we are unable to reject a room invite received over federation. ([\#7804](https://github.com/matrix-org/synapse/issues/7804), [\#7809](https://github.com/matrix-org/synapse/issues/7809), [\#7810](https://github.com/matrix-org/synapse/issues/7810))
Updates to the Docker image
---------------------------
- Include libwebp in the Docker file to properly handle webp image uploads. ([\#7791](https://github.com/matrix-org/synapse/issues/7791))
Improved Documentation
----------------------
- Improve the documentation of the non-standard JSON web token login type. ([\#7776](https://github.com/matrix-org/synapse/issues/7776))
- Update doc links for caddy. Contributed by Nicolai Søborg. ([\#7789](https://github.com/matrix-org/synapse/issues/7789))
Internal Changes
----------------
- Refactor getting replication updates from database. ([\#7740](https://github.com/matrix-org/synapse/issues/7740))
- Send push notifications with a high or low priority depending upon whether they may generate user-observable effects. ([\#7765](https://github.com/matrix-org/synapse/issues/7765))
- Use symbolic names for replication stream names. ([\#7768](https://github.com/matrix-org/synapse/issues/7768))
- Add early returns to `_check_for_soft_fail`. ([\#7769](https://github.com/matrix-org/synapse/issues/7769))
- Fix up `synapse.handlers.federation` to pass mypy. ([\#7770](https://github.com/matrix-org/synapse/issues/7770))
- Convert the appservice handler to async/await. ([\#7775](https://github.com/matrix-org/synapse/issues/7775))
- Allow use of higher versions of prometheus_client (<0.9.0), which are expected to introduce no breaking changes. Contributed by Oliver Kurz. ([\#7780](https://github.com/matrix-org/synapse/issues/7780))
- Update linting scripts and codebase to be compatible with `isort` v5. ([\#7786](https://github.com/matrix-org/synapse/issues/7786))
- Stop populating unused table `local_invites`. ([\#7793](https://github.com/matrix-org/synapse/issues/7793))
- Ensure that strings (not bytes) are passed into JSON serialization. ([\#7799](https://github.com/matrix-org/synapse/issues/7799))
- Switch from simplejson to the standard library json. ([\#7800](https://github.com/matrix-org/synapse/issues/7800))
- Add `signing_key` property to `HomeServer` to save code duplication. ([\#7805](https://github.com/matrix-org/synapse/issues/7805))
- Improve stacktraces from exceptions in background processes. ([\#7808](https://github.com/matrix-org/synapse/issues/7808))
- Fix various spelling errors in comments and log lines. ([\#7811](https://github.com/matrix-org/synapse/issues/7811))
Synapse 1.16.0 (2020-07-08)
===========================
+1 -1
@@ -215,7 +215,7 @@ Using a reverse proxy with Synapse
It is recommended to put a reverse proxy such as
`nginx <https://nginx.org/en/docs/http/ngx_http_proxy_module.html>`_,
`Apache <https://httpd.apache.org/docs/current/mod/mod_proxy_http.html>`_,
`Caddy <https://caddyserver.com/docs/proxy>`_ or
`Caddy <https://caddyserver.com/docs/quick-starts/reverse-proxy>`_ or
`HAProxy <https://www.haproxy.org/>`_ in front of Synapse. One advantage of
doing so is that it means that you can expose the default https port (443) to
Matrix clients without needing to run Synapse with root privileges.
+12
@@ -1,9 +1,21 @@
matrix-synapse-py3 (1.17.0) stable; urgency=medium
* New synapse release 1.17.0.
-- Synapse Packaging team <packages@matrix.org> Mon, 13 Jul 2020 10:20:31 +0100
matrix-synapse-py3 (1.16.1) stable; urgency=medium
* New synapse release 1.16.1.
-- Synapse Packaging team <packages@matrix.org> Fri, 10 Jul 2020 12:09:24 +0100
matrix-synapse-py3 (1.17.0rc1) stable; urgency=medium
* New synapse release 1.17.0rc1.
-- Synapse Packaging team <packages@matrix.org> Thu, 09 Jul 2020 16:53:12 +0100
matrix-synapse-py3 (1.16.0) stable; urgency=medium
* New synapse release 1.16.0.
+2
@@ -24,6 +24,7 @@ RUN apk add \
build-base \
libffi-dev \
libjpeg-turbo-dev \
libwebp-dev \
libressl-dev \
libxslt-dev \
linux-headers \
@@ -61,6 +62,7 @@ FROM docker.io/python:${PYTHON_VERSION}-alpine3.11
RUN apk add --no-cache --virtual .runtime_deps \
libffi \
libjpeg-turbo \
libwebp \
libressl \
libxslt \
libpq \
+90
@@ -0,0 +1,90 @@
# JWT Login Type
Synapse comes with a non-standard login type to support
[JSON Web Tokens](https://en.wikipedia.org/wiki/JSON_Web_Token). In general the
documentation for
[the login endpoint](https://matrix.org/docs/spec/client_server/r0.6.1#login)
is still valid (and the mechanism works similarly to the
[token based login](https://matrix.org/docs/spec/client_server/r0.6.1#token-based)).
To log in using a JSON Web Token, clients should submit a `/login` request as
follows:
```json
{
  "type": "org.matrix.login.jwt",
  "token": "<jwt>"
}
```
Note that the login type of `m.login.jwt` is supported, but is deprecated. This
will be removed in a future version of Synapse.
The `jwt` should encode the local part of the user ID as the standard `sub`
claim. In the case that the token is not valid, the homeserver must respond with
`401 Unauthorized` and an error code of `M_UNAUTHORIZED`.
(Note that this differs from the token-based logins, which return a
`403 Forbidden` and an error code of `M_FORBIDDEN` if an error occurs.)
As with other login types, there are additional fields (e.g. `device_id` and
`initial_device_display_name`) which can be included in the above request.
## Preparing Synapse
The JSON Web Token integration in Synapse uses the
[`PyJWT`](https://pypi.org/project/pyjwt/) library, which must be installed
as follows:
* The relevant libraries are included in the Docker images and Debian packages
provided by `matrix.org` so no further action is needed.
* If you installed Synapse into a virtualenv, run `/path/to/env/bin/pip
install synapse[pyjwt]` to install the necessary dependencies.
* For other installation mechanisms, see the documentation provided by the
maintainer.
To enable the JSON web token integration, you should then add a `jwt_config` section
to your configuration file (or uncomment the `enabled: true` line in the
existing section). See [sample_config.yaml](./sample_config.yaml) for some
sample settings.
## How to test JWT as a developer
Although JSON Web Tokens are typically generated from an external server, the
examples below use [PyJWT](https://pyjwt.readthedocs.io/en/latest/) directly.
1. Configure Synapse with JWT logins:
```yaml
jwt_config:
  enabled: true
  secret: "my-secret-token"
  algorithm: "HS256"
```
2. Generate a JSON web token:
```bash
$ pyjwt --key=my-secret-token --alg=HS256 encode sub=test-user
eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJ0ZXN0LXVzZXIifQ.Ag71GT8v01UO3w80aqRPTeuVPBIBZkYhNTJJ-_-zQIc
```
3. Query for the login types and ensure `org.matrix.login.jwt` is there:
```bash
curl http://localhost:8080/_matrix/client/r0/login
```
4. Log in using the generated JSON web token from above:
```bash
$ curl http://localhost:8080/_matrix/client/r0/login -X POST \
--data '{"type":"org.matrix.login.jwt","token":"eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJ0ZXN0LXVzZXIifQ.Ag71GT8v01UO3w80aqRPTeuVPBIBZkYhNTJJ-_-zQIc"}'
{
  "access_token": "<access token>",
  "device_id": "ACBDEFGHI",
  "home_server": "localhost:8080",
  "user_id": "@test-user:localhost:8080"
}
```
You should now be able to use the returned access token to query the client API.
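The whole flow can also be scripted. A hedged end-to-end sketch (assumes PyJWT and `requests` are installed and Synapse is configured as in step 1; the URL and secret are illustrative):

```python
import jwt
import requests

token = jwt.encode({"sub": "test-user"}, "my-secret-token", algorithm="HS256")
if isinstance(token, bytes):  # PyJWT < 2.0 returns bytes
    token = token.decode("ascii")

resp = requests.post(
    "http://localhost:8080/_matrix/client/r0/login",
    json={"type": "org.matrix.login.jwt", "token": token},
)
access_token = resp.json()["access_token"]

# Use the access token to query the client API, e.g. /account/whoami.
whoami = requests.get(
    "http://localhost:8080/_matrix/client/r0/account/whoami",
    headers={"Authorization": "Bearer " + access_token},
)
print(whoami.json())
```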
+1 -1
@@ -3,7 +3,7 @@
It is recommended to put a reverse proxy such as
[nginx](https://nginx.org/en/docs/http/ngx_http_proxy_module.html),
[Apache](https://httpd.apache.org/docs/current/mod/mod_proxy_http.html),
[Caddy](https://caddyserver.com/docs/proxy) or
[Caddy](https://caddyserver.com/docs/quick-starts/reverse-proxy) or
[HAProxy](https://www.haproxy.org/) in front of Synapse. One advantage
of doing so is that it means that you can expose the default https port
(443) to Matrix clients without needing to run Synapse with root
+31 -4
@@ -1804,12 +1804,39 @@ sso:
#template_dir: "res/templates"
# The JWT needs to contain a globally unique "sub" (subject) claim.
# JSON web token integration. The following settings can be used to make
# Synapse use JSON web tokens for authentication, instead of its internal
# password database.
#
# Each JSON Web Token needs to contain a "sub" (subject) claim, which is
# used as the localpart of the mxid.
#
# Note that this is a non-standard login type and client support is
# expected to be non-existent.
#
# See https://github.com/matrix-org/synapse/blob/master/docs/jwt.md.
#
#jwt_config:
# enabled: true
# secret: "a secret"
# algorithm: "HS256"
# Uncomment the following to enable authorization using JSON web
# tokens. Defaults to false.
#
#enabled: true
# This is either the private shared secret or the public key used to
# decode the contents of the JSON web token.
#
# Required if 'enabled' is true.
#
#secret: "provided-by-your-issuer"
# The algorithm used to sign the JSON web token.
#
# Supported algorithms are listed at
# https://pyjwt.readthedocs.io/en/latest/algorithms.html
#
# Required if 'enabled' is true.
#
#algorithm: "provided-by-your-issuer"
password_config:
+1 -1
@@ -2,9 +2,9 @@ import argparse
import json
import logging
import sys
import urllib2
import dns.resolver
import urllib2
from signedjson.key import decode_verify_key_bytes, write_signing_keys
from signedjson.sign import verify_signed_json
from unpaddedbase64 import decode_base64
+1 -1
@@ -15,7 +15,7 @@ else
fi
echo "Linting these locations: $files"
isort -y -rc $files
isort $files
python3 -m black $files
./scripts-dev/config-lint.sh
flake8 $files
-1
@@ -26,7 +26,6 @@ ignore=W503,W504,E203,E731,E501
[isort]
line_length = 88
not_skip = __init__.py
sections=FUTURE,STDLIB,COMPAT,THIRDPARTY,TWISTED,FIRSTPARTY,TESTS,LOCALFOLDER
default_section=THIRDPARTY
known_first_party = synapse
+1 -1
@@ -36,7 +36,7 @@ try:
except ImportError:
pass
__version__ = "1.16.1"
__version__ = "1.17.0"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when
+2 -3
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Optional
@@ -22,7 +21,6 @@ from netaddr import IPAddress
from twisted.internet import defer
from twisted.web.server import Request
import synapse.logging.opentracing as opentracing
import synapse.types
from synapse import event_auth
from synapse.api.auth_blocking import AuthBlocking
@@ -35,6 +33,7 @@ from synapse.api.errors import (
)
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase
from synapse.logging import opentracing as opentracing
from synapse.types import StateMap, UserID
from synapse.util.caches import register_cache
from synapse.util.caches.lrucache import LruCache
@@ -538,7 +537,7 @@ class Auth(object):
# Currently we ignore the `for_verification` flag even though there are
# some situations where we can drop particular auth events when adding
# to the event's `auth_events` (e.g. joins pointing to previous joins
# when room is publically joinable). Dropping event IDs has the
# when room is publicly joinable). Dropping event IDs has the
# advantage that the auth chain for the room grows slower, but we use
# the auth chain in state resolution v2 to order events, which means
# care must be taken if dropping events to ensure that it doesn't
+25 -5
@@ -21,7 +21,7 @@ from typing import Dict, Iterable, Optional, Set
from typing_extensions import ContextManager
from twisted.internet import defer, reactor
from twisted.internet import address, defer, reactor
import synapse
import synapse.events
@@ -206,10 +206,30 @@ class KeyUploadServlet(RestServlet):
if body:
# They're actually trying to upload something, proxy to main synapse.
# Pass through the auth headers, if any, in case the access token
# is there.
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization", [])
headers = {"Authorization": auth_headers}
# Proxy headers from the original request, such as the auth headers
# (in case the access token is there) and the original IP /
# User-Agent of the request.
headers = {
header: request.requestHeaders.getRawHeaders(header, [])
for header in (b"Authorization", b"User-Agent")
}
# Add the previous hop to the X-Forwarded-For header.
x_forwarded_for = request.requestHeaders.getRawHeaders(
b"X-Forwarded-For", []
)
if isinstance(request.client, (address.IPv4Address, address.IPv6Address)):
previous_host = request.client.host.encode("ascii")
# If the header exists, add to the comma-separated list of the first
# instance of the header. Otherwise, generate a new header.
if x_forwarded_for:
x_forwarded_for = [
x_forwarded_for[0] + b", " + previous_host
] + x_forwarded_for[1:]
else:
x_forwarded_for = [previous_host]
headers[b"X-Forwarded-For"] = x_forwarded_for
try:
result = await self.http_client.post_json_get_json(
self.main_uri + request.uri.decode("ascii"), body, headers=headers
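A plain-Python sketch of the `X-Forwarded-For` handling added above (no Twisted types; header values are lists of bytes, as in the real code):

```python
def add_forwarded_for(headers: dict, previous_host: bytes) -> None:
    # Append the previous hop to the first instance of the header, or
    # generate a new header if there is none.
    xff = headers.get(b"X-Forwarded-For", [])
    if xff:
        headers[b"X-Forwarded-For"] = [xff[0] + b", " + previous_host] + xff[1:]
    else:
        headers[b"X-Forwarded-For"] = [previous_host]

h = {b"X-Forwarded-For": [b"10.0.0.1"]}
add_forwarded_for(h, b"10.0.0.2")
assert h[b"X-Forwarded-For"] == [b"10.0.0.1, 10.0.0.2"]
```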
-1
@@ -98,7 +98,6 @@ class ApplicationServiceApi(SimpleHttpClient):
if service.url is None:
return False
uri = service.url + ("/users/%s" % urllib.parse.quote(user_id))
response = None
try:
response = yield self.get_json(uri, {"access_token": service.hs_token})
if response is not None: # just an empty json object
+1
@@ -16,6 +16,7 @@ from synapse.config._base import ConfigError
if __name__ == "__main__":
import sys
from synapse.config.homeserver import HomeServerConfig
action = sys.argv[1]
+2 -3
@@ -14,7 +14,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
# This file can't be called email.py because if it is, we cannot:
@@ -73,7 +72,7 @@ class EmailConfig(Config):
template_dir = email_config.get("template_dir")
# we need an absolute path, because we change directory after starting (and
# we don't yet know what auxilliary templates like mail.css we will need).
# we don't yet know what auxiliary templates like mail.css we will need).
# (Note that loading as package_resources with jinja.PackageLoader doesn't
# work for the same reason.)
if not template_dir:
@@ -145,8 +144,8 @@ class EmailConfig(Config):
or self.threepid_behaviour_email == ThreepidBehaviour.LOCAL
):
# make sure we can import the required deps
import jinja2
import bleach
import jinja2
# prevent unused warnings
jinja2
+31 -4
@@ -45,10 +45,37 @@ class JWTConfig(Config):
def generate_config_section(self, **kwargs):
return """\
# The JWT needs to contain a globally unique "sub" (subject) claim.
# JSON web token integration. The following settings can be used to make
# Synapse use JSON web tokens for authentication, instead of its internal
# password database.
#
# Each JSON Web Token needs to contain a "sub" (subject) claim, which is
# used as the localpart of the mxid.
#
# Note that this is a non-standard login type and client support is
# expected to be non-existent.
#
# See https://github.com/matrix-org/synapse/blob/master/docs/jwt.md.
#
#jwt_config:
# enabled: true
# secret: "a secret"
# algorithm: "HS256"
# Uncomment the following to enable authorization using JSON web
# tokens. Defaults to false.
#
#enabled: true
# This is either the private shared secret or the public key used to
# decode the contents of the JSON web token.
#
# Required if 'enabled' is true.
#
#secret: "provided-by-your-issuer"
# The algorithm used to sign the JSON web token.
#
# Supported algorithms are listed at
# https://pyjwt.readthedocs.io/en/latest/algorithms.html
#
# Required if 'enabled' is true.
#
#algorithm: "provided-by-your-issuer"
"""
+1 -1
@@ -162,7 +162,7 @@ class EventBuilderFactory(object):
def __init__(self, hs):
self.clock = hs.get_clock()
self.hostname = hs.hostname
self.signing_key = hs.config.signing_key[0]
self.signing_key = hs.signing_key
self.store = hs.get_datastore()
self.state = hs.get_state_handler()
+3 -3
@@ -87,7 +87,7 @@ class FederationClient(FederationBase):
self.transport_layer = hs.get_federation_transport_client()
self.hostname = hs.hostname
self.signing_key = hs.config.signing_key[0]
self.signing_key = hs.signing_key
self._get_pdu_cache = ExpiringCache(
cache_name="get_pdu_cache",
@@ -245,7 +245,7 @@ class FederationClient(FederationBase):
event_id: event to fetch
room_version: version of the room
outlier: Indicates whether the PDU is an `outlier`, i.e. if
it's from an arbitary point in the context as opposed to part
it's from an arbitrary point in the context as opposed to part
of the current block of PDUs. Defaults to `False`
timeout: How long to try (in ms) each destination for before
moving to the next destination. None indicates no timeout.
@@ -351,7 +351,7 @@ class FederationClient(FederationBase):
outlier: bool = False,
include_none: bool = False,
) -> List[EventBase]:
"""Takes a list of PDUs and checks the signatures and hashs of each
"""Takes a list of PDUs and checks the signatures and hashes of each
one. If a PDU fails its signature check then we check if we have it in
the database and if not then request if from the originating server of
that PDU.
+14 -5
@@ -95,6 +95,9 @@ class FederationServer(FederationBase):
# We cache responses to state queries, as they take a while and often
# come in waves.
self._state_resp_cache = ResponseCache(hs, "state_resp", timeout_ms=30000)
self._state_ids_resp_cache = ResponseCache(
hs, "state_ids_resp", timeout_ms=30000
)
async def on_backfill_request(
self, origin: str, room_id: str, versions: List[str], limit: int
@@ -362,10 +365,16 @@ class FederationServer(FederationBase):
if not in_room:
raise AuthError(403, "Host not in room.")
resp = await self._state_ids_resp_cache.wrap(
(room_id, event_id), self._on_state_ids_request_compute, room_id, event_id,
)
return 200, resp
async def _on_state_ids_request_compute(self, room_id, event_id):
state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id)
auth_chain_ids = await self.store.get_auth_chain_ids(state_ids)
return 200, {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}
return {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}
async def _on_context_state_request_compute(
self, room_id: str, event_id: str
@@ -717,7 +726,7 @@ def server_matches_acl_event(server_name: str, acl_event: EventBase) -> bool:
# server name is a literal IP
allow_ip_literals = acl_event.content.get("allow_ip_literals", True)
if not isinstance(allow_ip_literals, bool):
logger.warning("Ignorning non-bool allow_ip_literals flag")
logger.warning("Ignoring non-bool allow_ip_literals flag")
allow_ip_literals = True
if not allow_ip_literals:
# check for ipv6 literals. These start with '['.
@@ -731,7 +740,7 @@ def server_matches_acl_event(server_name: str, acl_event: EventBase) -> bool:
# next, check the deny list
deny = acl_event.content.get("deny", [])
if not isinstance(deny, (list, tuple)):
logger.warning("Ignorning non-list deny ACL %s", deny)
logger.warning("Ignoring non-list deny ACL %s", deny)
deny = []
for e in deny:
if _acl_entry_matches(server_name, e):
@@ -741,7 +750,7 @@ def server_matches_acl_event(server_name: str, acl_event: EventBase) -> bool:
# then the allow list.
allow = acl_event.content.get("allow", [])
if not isinstance(allow, (list, tuple)):
logger.warning("Ignorning non-list allow ACL %s", allow)
logger.warning("Ignoring non-list allow ACL %s", allow)
allow = []
for e in allow:
if _acl_entry_matches(server_name, e):
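The `_state_ids_resp_cache` added above deduplicates concurrent `/state_ids` requests for the same room and event. A toy illustration of the wrap-style deduplication (not Synapse's actual `ResponseCache`, which also keeps results around for a timeout):

```python
import asyncio

class MiniResponseCache:
    def __init__(self):
        self._pending = {}

    def wrap(self, key, callback, *args):
        # Concurrent requests for the same key share one in-flight computation.
        if key not in self._pending:
            self._pending[key] = asyncio.ensure_future(callback(*args))
        return self._pending[key]

async def main():
    cache = MiniResponseCache()
    calls = []

    async def compute(room_id, event_id):
        calls.append((room_id, event_id))
        await asyncio.sleep(0.01)
        return {"pdu_ids": [], "auth_chain_ids": []}

    r1, r2 = await asyncio.gather(
        cache.wrap(("!room", "$event"), compute, "!room", "$event"),
        cache.wrap(("!room", "$event"), compute, "!room", "$event"),
    )
    assert r1 == r2
    assert len(calls) == 1  # the second caller reused the first computation

asyncio.run(main())
```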
+1 -1
@@ -359,7 +359,7 @@ class BaseFederationRow(object):
Specifies how to identify, serialize and deserialize the different types.
"""
TypeId = "" # Unique string that ids the type. Must be overriden in sub classes.
TypeId = "" # Unique string that ids the type. Must be overridden in sub classes.
@staticmethod
def from_data(data):
@@ -119,7 +119,7 @@ class PerDestinationQueue(object):
)
def send_pdu(self, pdu: EventBase, order: int) -> None:
"""Add a PDU to the queue, and start the transmission loop if neccessary
"""Add a PDU to the queue, and start the transmission loop if necessary
Args:
pdu: pdu to send
@@ -129,7 +129,7 @@ class PerDestinationQueue(object):
self.attempt_new_transaction()
def send_presence(self, states: Iterable[UserPresenceState]) -> None:
"""Add presence updates to the queue. Start the transmission loop if neccessary.
"""Add presence updates to the queue. Start the transmission loop if necessary.
Args:
states: presence to send
+1 -1
@@ -746,7 +746,7 @@ class TransportLayerClient(object):
def remove_user_from_group(
self, destination, group_id, requester_user_id, user_id, content
):
"""Remove a user fron a group
"""Remove a user from a group
"""
path = _create_v1_path("/groups/%s/users/%s/remove", group_id, user_id)
+9 -7
@@ -109,7 +109,7 @@ class Authenticator(object):
self.server_name = hs.hostname
self.store = hs.get_datastore()
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
self.notifer = hs.get_notifier()
self.notifier = hs.get_notifier()
self.replication_client = None
if hs.config.worker.worker_app:
@@ -175,7 +175,7 @@ class Authenticator(object):
await self.store.set_destination_retry_timings(origin, None, 0, 0)
# Inform the relevant places that the remote server is back up.
self.notifer.notify_remote_server_up(origin)
self.notifier.notify_remote_server_up(origin)
if self.replication_client:
# If we're on a worker we try and inform master about this. The
# replication client doesn't hook into the notifier to avoid
@@ -340,6 +340,12 @@ class BaseFederationServlet(object):
if origin:
with ratelimiter.ratelimit(origin) as d:
await d
if request._disconnected:
logger.warning(
"client disconnected before we started processing "
"request"
)
return -1, None
response = await func(
origin, content, request.args, *args, **kwargs
)
@@ -361,11 +367,7 @@ class BaseFederationServlet(object):
continue
server.register_paths(
method,
(pattern,),
self._wrap(code),
self.__class__.__name__,
trace=False,
method, (pattern,), self._wrap(code), self.__class__.__name__,
)
+1 -1
@@ -70,7 +70,7 @@ class GroupAttestationSigning(object):
self.keyring = hs.get_keyring()
self.clock = hs.get_clock()
self.server_name = hs.hostname
self.signing_key = hs.config.signing_key[0]
self.signing_key = hs.signing_key
@defer.inlineCallbacks
def verify_attestation(self, attestation, group_id, user_id, server_name=None):
+1 -1
@@ -41,7 +41,7 @@ class GroupsServerWorkerHandler(object):
self.clock = hs.get_clock()
self.keyring = hs.get_keyring()
self.is_mine_id = hs.is_mine_id
self.signing_key = hs.config.signing_key[0]
self.signing_key = hs.signing_key
self.server_name = hs.hostname
self.attestations = hs.get_groups_attestation_signing()
self.transport_client = hs.get_federation_transport_client()
+32 -42
@@ -48,8 +48,7 @@ class ApplicationServicesHandler(object):
self.current_max = 0
self.is_processing = False
@defer.inlineCallbacks
def notify_interested_services(self, current_id):
async def notify_interested_services(self, current_id):
"""Notifies (pushes) all application services interested in this event.
Pushing is done asynchronously, so this method won't block for any
@@ -74,7 +73,7 @@ class ApplicationServicesHandler(object):
(
upper_bound,
events,
) = yield self.store.get_new_events_for_appservice(
) = await self.store.get_new_events_for_appservice(
self.current_max, limit
)
@@ -85,10 +84,9 @@ class ApplicationServicesHandler(object):
for event in events:
events_by_room.setdefault(event.room_id, []).append(event)
@defer.inlineCallbacks
def handle_event(event):
async def handle_event(event):
# Gather interested services
services = yield self._get_services_for_event(event)
services = await self._get_services_for_event(event)
if len(services) == 0:
return # no services need notifying
@@ -96,9 +94,9 @@ class ApplicationServicesHandler(object):
# query API for all services which match that user regex.
# This needs to block as these user queries need to be
# made BEFORE pushing the event.
yield self._check_user_exists(event.sender)
await self._check_user_exists(event.sender)
if event.type == EventTypes.Member:
yield self._check_user_exists(event.state_key)
await self._check_user_exists(event.state_key)
if not self.started_scheduler:
@@ -115,17 +113,16 @@ class ApplicationServicesHandler(object):
self.scheduler.submit_event_for_as(service, event)
now = self.clock.time_msec()
ts = yield self.store.get_received_ts(event.event_id)
ts = await self.store.get_received_ts(event.event_id)
synapse.metrics.event_processing_lag_by_event.labels(
"appservice_sender"
).observe((now - ts) / 1000)
@defer.inlineCallbacks
def handle_room_events(events):
async def handle_room_events(events):
for event in events:
yield handle_event(event)
await handle_event(event)
yield make_deferred_yieldable(
await make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(handle_room_events, evs)
@@ -135,10 +132,10 @@ class ApplicationServicesHandler(object):
)
)
yield self.store.set_appservice_last_pos(upper_bound)
await self.store.set_appservice_last_pos(upper_bound)
now = self.clock.time_msec()
ts = yield self.store.get_received_ts(events[-1].event_id)
ts = await self.store.get_received_ts(events[-1].event_id)
synapse.metrics.event_processing_positions.labels(
"appservice_sender"
@@ -161,8 +158,7 @@ class ApplicationServicesHandler(object):
finally:
self.is_processing = False
@defer.inlineCallbacks
def query_user_exists(self, user_id):
async def query_user_exists(self, user_id):
"""Check if any application service knows this user_id exists.
Args:
@@ -170,15 +166,14 @@ class ApplicationServicesHandler(object):
Returns:
True if this user exists on at least one application service.
"""
user_query_services = yield self._get_services_for_user(user_id=user_id)
user_query_services = self._get_services_for_user(user_id=user_id)
for user_service in user_query_services:
is_known_user = yield self.appservice_api.query_user(user_service, user_id)
is_known_user = await self.appservice_api.query_user(user_service, user_id)
if is_known_user:
return True
return False
@defer.inlineCallbacks
def query_room_alias_exists(self, room_alias):
async def query_room_alias_exists(self, room_alias):
"""Check if an application service knows this room alias exists.
Args:
@@ -193,19 +188,18 @@ class ApplicationServicesHandler(object):
s for s in services if (s.is_interested_in_alias(room_alias_str))
]
for alias_service in alias_query_services:
is_known_alias = yield self.appservice_api.query_alias(
is_known_alias = await self.appservice_api.query_alias(
alias_service, room_alias_str
)
if is_known_alias:
# the alias exists now so don't query more ASes.
result = yield self.store.get_association_from_room_alias(room_alias)
result = await self.store.get_association_from_room_alias(room_alias)
return result
@defer.inlineCallbacks
def query_3pe(self, kind, protocol, fields):
services = yield self._get_services_for_3pn(protocol)
async def query_3pe(self, kind, protocol, fields):
services = self._get_services_for_3pn(protocol)
results = yield make_deferred_yieldable(
results = await make_deferred_yieldable(
defer.DeferredList(
[
run_in_background(
@@ -224,8 +218,7 @@ class ApplicationServicesHandler(object):
return ret
@defer.inlineCallbacks
def get_3pe_protocols(self, only_protocol=None):
async def get_3pe_protocols(self, only_protocol=None):
services = self.store.get_app_services()
protocols = {}
@@ -238,7 +231,7 @@ class ApplicationServicesHandler(object):
if p not in protocols:
protocols[p] = []
info = yield self.appservice_api.get_3pe_protocol(s, p)
info = await self.appservice_api.get_3pe_protocol(s, p)
if info is not None:
protocols[p].append(info)
@@ -263,8 +256,7 @@ class ApplicationServicesHandler(object):
return protocols
@defer.inlineCallbacks
def _get_services_for_event(self, event):
async def _get_services_for_event(self, event):
"""Retrieve a list of application services interested in this event.
Args:
@@ -280,7 +272,7 @@ class ApplicationServicesHandler(object):
# inside of a list comprehension anymore.
interested_list = []
for s in services:
if (yield s.is_interested(event, self.store)):
if await s.is_interested(event, self.store):
interested_list.append(s)
return interested_list
@@ -288,21 +280,20 @@ class ApplicationServicesHandler(object):
def _get_services_for_user(self, user_id):
services = self.store.get_app_services()
interested_list = [s for s in services if (s.is_interested_in_user(user_id))]
return defer.succeed(interested_list)
return interested_list
def _get_services_for_3pn(self, protocol):
services = self.store.get_app_services()
interested_list = [s for s in services if s.is_interested_in_protocol(protocol)]
return defer.succeed(interested_list)
return interested_list
@defer.inlineCallbacks
def _is_unknown_user(self, user_id):
async def _is_unknown_user(self, user_id):
if not self.is_mine_id(user_id):
# we don't know if they are unknown or not since it isn't one of our
# users. We can't poke ASes.
return False
user_info = yield self.store.get_user_by_id(user_id)
user_info = await self.store.get_user_by_id(user_id)
if user_info:
return False
@@ -311,10 +302,9 @@ class ApplicationServicesHandler(object):
service_list = [s for s in services if s.sender == user_id]
return len(service_list) == 0
@defer.inlineCallbacks
def _check_user_exists(self, user_id):
unknown_user = yield self._is_unknown_user(user_id)
async def _check_user_exists(self, user_id):
unknown_user = await self._is_unknown_user(user_id)
if unknown_user:
exists = yield self.query_user_exists(user_id)
exists = await self.query_user_exists(user_id)
return exists
return True
+4 -4
@@ -13,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
import unicodedata
@@ -24,7 +23,6 @@ import attr
import bcrypt # type: ignore[import]
import pymacaroons
import synapse.util.stringutils as stringutils
from synapse.api.constants import LoginType
from synapse.api.errors import (
AuthError,
@@ -45,6 +43,8 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.module_api import ModuleApi
from synapse.push.mailer import load_jinja2_templates
from synapse.types import Requester, UserID
from synapse.util import stringutils as stringutils
from synapse.util.threepids import canonicalise_email
from ._base import BaseHandler
@@ -928,7 +928,7 @@ class AuthHandler(BaseHandler):
# for the presence of an email address during password reset was
# case sensitive).
if medium == "email":
address = address.lower()
address = canonicalise_email(address)
await self.store.user_add_threepid(
user_id, medium, address, validated_at, self.hs.get_clock().time_msec()
@@ -956,7 +956,7 @@ class AuthHandler(BaseHandler):
# 'Canonicalise' email addresses as per above
if medium == "email":
address = address.lower()
address = canonicalise_email(address)
identity_handler = self.hs.get_handlers().identity_handler
result = await identity_handler.try_unbind_threepid(
+1 -2
@@ -12,11 +12,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import urllib
import xml.etree.ElementTree as ET
from typing import Dict, Optional, Tuple
from xml.etree import ElementTree as ET
from twisted.web.client import PartialDownloadError
+83 -79
@@ -19,8 +19,9 @@
import itertools
import logging
from collections import Container
from http import HTTPStatus
from typing import Dict, Iterable, List, Optional, Sequence, Tuple
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union
import attr
from signedjson.key import decode_verify_key_bytes
@@ -742,6 +743,9 @@ class FederationHandler(BaseHandler):
# device and recognize the algorithm then we can work out the
# exact key to expect. Otherwise check it matches any key we
# have for that device.
current_keys = [] # type: Container[str]
if device:
keys = device.get("keys", {}).get("keys", {})
@@ -758,15 +762,15 @@ class FederationHandler(BaseHandler):
current_keys = keys.values()
elif device_id:
# We don't have any keys for the device ID.
current_keys = []
pass
else:
# The event didn't include a device ID, so we just look for
# keys across all devices.
current_keys = (
current_keys = [
key
for device in cached_devices
for key in device.get("keys", {}).get("keys", {}).values()
)
]
# We now check that the sender key matches (one of) the expected
# keys.
@@ -1011,7 +1015,7 @@ class FederationHandler(BaseHandler):
if e_type == EventTypes.Member and event.membership == Membership.JOIN
]
joined_domains = {}
joined_domains = {} # type: Dict[str, int]
for u, d in joined_users:
try:
dom = get_domain_from_id(u)
@@ -1277,14 +1281,15 @@ class FederationHandler(BaseHandler):
try:
# Try the host we successfully got a response to /make_join/
# request first.
host_list = list(target_hosts)
try:
target_hosts.remove(origin)
target_hosts.insert(0, origin)
host_list.remove(origin)
host_list.insert(0, origin)
except ValueError:
pass
ret = await self.federation_client.send_join(
target_hosts, event, room_version_obj
host_list, event, room_version_obj
)
origin = ret["origin"]
@@ -1562,7 +1567,7 @@ class FederationHandler(BaseHandler):
room_version,
event.get_pdu_json(),
self.hs.hostname,
self.hs.config.signing_key[0],
self.hs.signing_key,
)
)
@@ -1584,13 +1589,14 @@ class FederationHandler(BaseHandler):
# Try the host that we succesfully called /make_leave/ on first for
# the /send_leave/ request.
host_list = list(target_hosts)
try:
target_hosts.remove(origin)
target_hosts.insert(0, origin)
host_list.remove(origin)
host_list.insert(0, origin)
except ValueError:
pass
await self.federation_client.send_leave(target_hosts, event)
await self.federation_client.send_leave(host_list, event)
context = await self.state_handler.compute_event_context(event)
stream_id = await self.persist_events_and_notify([(event, context)])
@@ -1604,7 +1610,7 @@ class FederationHandler(BaseHandler):
user_id: str,
membership: str,
content: JsonDict = {},
params: Optional[Dict[str, str]] = None,
params: Optional[Dict[str, Union[str, Iterable[str]]]] = None,
) -> Tuple[str, EventBase, RoomVersion]:
(
origin,
@@ -2018,8 +2024,8 @@ class FederationHandler(BaseHandler):
auth_events_ids = await self.auth.compute_auth_events(
event, prev_state_ids, for_verification=True
)
auth_events = await self.store.get_events(auth_events_ids)
auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
auth_events_x = await self.store.get_events(auth_events_ids)
auth_events = {(e.type, e.state_key): e for e in auth_events_x.values()}
# This is a hack to fix some old rooms where the initial join event
# didn't reference the create event in its auth events.
@@ -2055,76 +2061,67 @@ class FederationHandler(BaseHandler):
# For new (non-backfilled and non-outlier) events we check if the event
# passes auth based on the current state. If it doesn't then we
# "soft-fail" the event.
do_soft_fail_check = not backfilled and not event.internal_metadata.is_outlier()
if do_soft_fail_check:
extrem_ids = await self.store.get_latest_event_ids_in_room(event.room_id)
if backfilled or event.internal_metadata.is_outlier():
return
extrem_ids = set(extrem_ids)
prev_event_ids = set(event.prev_event_ids())
extrem_ids = await self.store.get_latest_event_ids_in_room(event.room_id)
extrem_ids = set(extrem_ids)
prev_event_ids = set(event.prev_event_ids())
if extrem_ids == prev_event_ids:
# If they're the same then the current state is the same as the
# state at the event, so no point rechecking auth for soft fail.
do_soft_fail_check = False
if extrem_ids == prev_event_ids:
# If they're the same then the current state is the same as the
# state at the event, so no point rechecking auth for soft fail.
return
if do_soft_fail_check:
room_version = await self.store.get_room_version_id(event.room_id)
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
room_version = await self.store.get_room_version_id(event.room_id)
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
# Calculate the "current state".
if state is not None:
# If we're explicitly given the state then we won't have all the
# prev events, and so we have a gap in the graph. In this case
# we want to be a little careful as we might have been down for
# a while and have an incorrect view of the current state,
# however we still want to do checks as gaps are easy to
# maliciously manufacture.
#
# So we use a "current state" that is actually a state
# resolution across the current forward extremities and the
# given state at the event. This should correctly handle cases
# like bans, especially with state res v2.
# Calculate the "current state".
if state is not None:
# If we're explicitly given the state then we won't have all the
# prev events, and so we have a gap in the graph. In this case
# we want to be a little careful as we might have been down for
# a while and have an incorrect view of the current state,
# however we still want to do checks as gaps are easy to
# maliciously manufacture.
#
# So we use a "current state" that is actually a state
# resolution across the current forward extremities and the
# given state at the event. This should correctly handle cases
# like bans, especially with state res v2.
state_sets = await self.state_store.get_state_groups(
event.room_id, extrem_ids
)
state_sets = list(state_sets.values())
state_sets.append(state)
current_state_ids = await self.state_handler.resolve_events(
room_version, state_sets, event
)
current_state_ids = {
k: e.event_id for k, e in current_state_ids.items()
}
else:
current_state_ids = await self.state_handler.get_current_state_ids(
event.room_id, latest_event_ids=extrem_ids
)
logger.debug(
"Doing soft-fail check for %s: state %s",
event.event_id,
current_state_ids,
state_sets = await self.state_store.get_state_groups(
event.room_id, extrem_ids
)
state_sets = list(state_sets.values())
state_sets.append(state)
current_state_ids = await self.state_handler.resolve_events(
room_version, state_sets, event
)
current_state_ids = {k: e.event_id for k, e in current_state_ids.items()}
else:
current_state_ids = await self.state_handler.get_current_state_ids(
event.room_id, latest_event_ids=extrem_ids
)
# Now check if event pass auth against said current state
auth_types = auth_types_for_event(event)
current_state_ids = [
e for k, e in current_state_ids.items() if k in auth_types
]
logger.debug(
"Doing soft-fail check for %s: state %s", event.event_id, current_state_ids,
)
current_auth_events = await self.store.get_events(current_state_ids)
current_auth_events = {
(e.type, e.state_key): e for e in current_auth_events.values()
}
# Now check if event pass auth against said current state
auth_types = auth_types_for_event(event)
current_state_ids = [e for k, e in current_state_ids.items() if k in auth_types]
try:
event_auth.check(
room_version_obj, event, auth_events=current_auth_events
)
except AuthError as e:
logger.warning("Soft-failing %r because %s", event, e)
event.internal_metadata.soft_failed = True
current_auth_events = await self.store.get_events(current_state_ids)
current_auth_events = {
(e.type, e.state_key): e for e in current_auth_events.values()
}
try:
event_auth.check(room_version_obj, event, auth_events=current_auth_events)
except AuthError as e:
logger.warning("Soft-failing %r because %s", event, e)
event.internal_metadata.soft_failed = True
async def on_query_auth(
self, origin, event_id, room_id, remote_auth_chain, rejects, missing
@@ -2293,10 +2290,10 @@ class FederationHandler(BaseHandler):
remote_auth_chain = await self.federation_client.get_event_auth(
origin, event.room_id, event.event_id
)
except RequestSendFailed as e:
except RequestSendFailed as e1:
# The other side isn't around or doesn't implement the
# endpoint, so lets just bail out.
logger.info("Failed to get event auth from remote: %s", e)
logger.info("Failed to get event auth from remote: %s", e1)
return context
seen_remotes = await self.store.have_seen_events(
@@ -2774,7 +2771,8 @@ class FederationHandler(BaseHandler):
logger.debug("Checking auth on event %r", event.content)
last_exception = None
last_exception = None # type: Optional[Exception]
# for each public key in the 3pid invite event
for public_key_object in self.hs.get_auth().get_public_keys(invite_event):
try:
@@ -2828,6 +2826,12 @@ class FederationHandler(BaseHandler):
return
except Exception as e:
last_exception = e
if last_exception is None:
# we can only get here if get_public_keys() returned an empty list
# TODO: make this better
raise RuntimeError("no public key in invite event")
raise last_exception
async def _check_key_revocation(self, public_key, url):
+1 -1
@@ -70,7 +70,7 @@ class GroupsLocalWorkerHandler(object):
self.clock = hs.get_clock()
self.keyring = hs.get_keyring()
self.is_mine_id = hs.is_mine_id
self.signing_key = hs.config.signing_key[0]
self.signing_key = hs.signing_key
self.server_name = hs.hostname
self.notifier = hs.get_notifier()
self.attestations = hs.get_groups_attestation_signing()
+2 -2
@@ -251,10 +251,10 @@ class IdentityHandler(BaseHandler):
# 'browser-like' HTTPS.
auth_headers = self.federation_http_client.build_auth_headers(
destination=None,
method="POST",
method=b"POST",
url_bytes=url_bytes,
content=content,
destination_is=id_server,
destination_is=id_server.encode("ascii"),
)
headers = {b"Authorization": auth_headers}
+16 -7
@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Optional, Tuple
from typing import TYPE_CHECKING, Optional, Tuple
from canonicaljson import encode_canonical_json, json
@@ -55,6 +55,9 @@ from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -349,7 +352,7 @@ _DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY = 7 * 24 * 60 * 60 * 1000
class EventCreationHandler(object):
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
@@ -814,11 +817,17 @@ class EventCreationHandler(object):
403, "This event is not allowed in this context", Codes.FORBIDDEN
)
try:
await self.auth.check_from_context(room_version, event, context)
except AuthError as err:
logger.warning("Denying new event %r because %s", event, err)
raise err
if event.internal_metadata.is_out_of_band_membership():
# the only sort of out-of-band-membership events we expect to see here
# are invite rejections we have generated ourselves.
assert event.type == EventTypes.Member
assert event.content["membership"] == Membership.LEAVE
else:
try:
await self.auth.check_from_context(room_version, event, context)
except AuthError as err:
logger.warning("Denying new event %r because %s", event, err)
raise err
# Ensure that we can round trip before trying to persist in db
try:
+137 -65
@@ -1,7 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# Copyright 2019 The Matrix.org Foundation C.I.C.
# Copyright 2016-2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,17 +16,21 @@
import abc
import logging
from http import HTTPStatus
from typing import Dict, Iterable, List, Optional, Tuple
from typing import Dict, Iterable, List, Optional, Tuple, Union
from unpaddedbase64 import encode_base64
from synapse import types
from synapse.api.constants import EventTypes, Membership
from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.api.room_versions import EventFormatVersions
from synapse.crypto.event_signing import compute_event_reference_hash
from synapse.events import EventBase
from synapse.events.builder import create_local_event_from_event_dict
from synapse.events.snapshot import EventContext
from synapse.replication.http.membership import (
ReplicationLocallyRejectInviteRestServlet,
)
from synapse.types import Collection, Requester, RoomAlias, RoomID, UserID
from synapse.events.validator import EventValidator
from synapse.storage.roommember import RoomsForUser
from synapse.types import Collection, JsonDict, Requester, RoomAlias, RoomID, UserID
from synapse.util.async_helpers import Linearizer
from synapse.util.distributor import user_joined_room, user_left_room
@@ -74,10 +76,6 @@ class RoomMemberHandler(object):
)
if self._is_on_event_persistence_instance:
self.persist_event_storage = hs.get_storage().persistence
else:
self._locally_reject_client = ReplicationLocallyRejectInviteRestServlet.make_client(
hs
)
# This is only used to get at ratelimit function, and
# maybe_kick_guest_users. It's fine there are multiple of these as
@@ -105,46 +103,28 @@ class RoomMemberHandler(object):
raise NotImplementedError()
@abc.abstractmethod
async def _remote_reject_invite(
async def remote_reject_invite(
self,
invite_event_id: str,
txn_id: Optional[str],
requester: Requester,
remote_room_hosts: List[str],
room_id: str,
target: UserID,
content: dict,
) -> Tuple[Optional[str], int]:
"""Attempt to reject an invite for a room this server is not in. If we
fail to do so we locally mark the invite as rejected.
content: JsonDict,
) -> Tuple[str, int]:
"""
Rejects an out-of-band invite we have received from a remote server
Args:
requester
remote_room_hosts: List of servers to use to try and reject invite
room_id
target: The user rejecting the invite
content: The content for the rejection event
invite_event_id: ID of the invite to be rejected
txn_id: optional transaction ID supplied by the client
requester: user making the rejection request, according to the access token
content: additional content to include in the rejection event.
Normally an empty dict.
Returns:
A dictionary to be returned to the client, may
include event_id etc, or nothing if we locally rejected
event id, stream_id of the leave event
"""
raise NotImplementedError()
async def locally_reject_invite(self, user_id: str, room_id: str) -> int:
"""Mark the invite has having been rejected even though we failed to
create a leave event for it.
"""
if self._is_on_event_persistence_instance:
return await self.persist_event_storage.locally_reject_invite(
user_id, room_id
)
else:
result = await self._locally_reject_client(
instance_name=self._event_stream_writer_instance,
user_id=user_id,
room_id=room_id,
)
return result["stream_id"]
@abc.abstractmethod
async def _user_joined_room(self, target: UserID, room_id: str) -> None:
"""Notifies distributor on master process that the user has joined the
@@ -288,7 +268,7 @@ class RoomMemberHandler(object):
ratelimit: bool = True,
content: Optional[dict] = None,
require_consent: bool = True,
) -> Tuple[Optional[str], int]:
) -> Tuple[str, int]:
key = (room_id,)
with (await self.member_linearizer.queue(key)):
@@ -319,7 +299,7 @@ class RoomMemberHandler(object):
ratelimit: bool = True,
content: Optional[dict] = None,
require_consent: bool = True,
) -> Tuple[Optional[str], int]:
) -> Tuple[str, int]:
content_specified = bool(content)
if content is None:
content = {}
@@ -485,11 +465,17 @@ class RoomMemberHandler(object):
elif effective_membership_state == Membership.LEAVE:
if not is_host_in_room:
# perhaps we've been invited
inviter = await self._get_inviter(target.to_string(), room_id)
if not inviter:
invite = await self.store.get_invite_for_local_user_in_room(
user_id=target.to_string(), room_id=room_id
) # type: Optional[RoomsForUser]
if not invite:
raise SynapseError(404, "Not a known room")
if self.hs.is_mine(inviter):
logger.info(
"%s rejects invite to %s from %s", target, room_id, invite.sender
)
if self.hs.is_mine_id(invite.sender):
# the inviter was on our server, but has now left. Carry on
# with the normal rejection codepath.
#
@@ -497,10 +483,10 @@ class RoomMemberHandler(object):
# active on other servers.
pass
else:
# send the rejection to the inviter's HS.
remote_room_hosts = remote_room_hosts + [inviter.domain]
return await self._remote_reject_invite(
requester, remote_room_hosts, room_id, target, content,
# send the rejection to the inviter's HS (with fallback to
# local event)
return await self.remote_reject_invite(
invite.event_id, txn_id, requester, content,
)
return await self._local_membership_update(
@@ -1014,33 +1000,119 @@ class RoomMemberMasterHandler(RoomMemberHandler):
return event_id, stream_id
async def _remote_reject_invite(
async def remote_reject_invite(
self,
invite_event_id: str,
txn_id: Optional[str],
requester: Requester,
remote_room_hosts: List[str],
room_id: str,
target: UserID,
content: dict,
) -> Tuple[Optional[str], int]:
"""Implements RoomMemberHandler._remote_reject_invite
content: JsonDict,
) -> Tuple[str, int]:
"""
Rejects an out-of-band invite received from a remote user
Implements RoomMemberHandler.remote_reject_invite
"""
invite_event = await self.store.get_event(invite_event_id)
room_id = invite_event.room_id
target_user = invite_event.state_key
# first of all, try doing a rejection via the inviting server
fed_handler = self.federation_handler
try:
inviter_id = UserID.from_string(invite_event.sender)
event, stream_id = await fed_handler.do_remotely_reject_invite(
remote_room_hosts, room_id, target.to_string(), content=content,
[inviter_id.domain], room_id, target_user, content=content
)
return event.event_id, stream_id
except Exception as e:
# if we were unable to reject the exception, just mark
# it as rejected on our end and plough ahead.
# if we were unable to reject the invite, we will generate our own
# leave event.
#
# The 'except' clause is very broad, but we need to
# capture everything from DNS failures upwards
#
logger.warning("Failed to reject invite: %s", e)
stream_id = await self.locally_reject_invite(target.to_string(), room_id)
return None, stream_id
return await self._locally_reject_invite(
invite_event, txn_id, requester, content
)
async def _locally_reject_invite(
self,
invite_event: EventBase,
txn_id: Optional[str],
requester: Requester,
content: JsonDict,
) -> Tuple[str, int]:
"""Generate a local invite rejection
This is called after we fail to reject an invite via a remote server. It
generates an out-of-band membership event locally.
Args:
invite_event: the invite to be rejected
txn_id: optional transaction ID supplied by the client
requester: user making the rejection request, according to the access token
content: additional content to include in the rejection event.
Normally an empty dict.
"""
room_id = invite_event.room_id
target_user = invite_event.state_key
room_version = await self.store.get_room_version(room_id)
content["membership"] = Membership.LEAVE
# the auth events for the new event are the same as that of the invite, plus
# the invite itself.
#
# the prev_events are just the invite.
invite_hash = invite_event.event_id # type: Union[str, Tuple]
if room_version.event_format == EventFormatVersions.V1:
alg, h = compute_event_reference_hash(invite_event)
invite_hash = (invite_event.event_id, {alg: encode_base64(h)})
auth_events = tuple(invite_event.auth_events) + (invite_hash,)
prev_events = (invite_hash,)
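# A minimal illustration (hypothetical IDs and hashes) of the two reference
# formats built above: room-version-1 events reference other events as
# (event_id, {alg: hash}) pairs, while later room versions use bare event IDs:
#   v1_ref = ("$invite123:example.org", {"sha256": "aGFzaG9mdGhlZXZlbnQ"})
#   v5_ref = "$kLQ7t1gMc9rDieSPKu4AcW"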
# we cap depth of generated events, to ensure that they are not
# rejected by other servers (and so that they can be persisted in
# the db)
depth = min(invite_event.depth + 1, MAX_DEPTH)
event_dict = {
"depth": depth,
"auth_events": auth_events,
"prev_events": prev_events,
"type": EventTypes.Member,
"room_id": room_id,
"sender": target_user,
"content": content,
"state_key": target_user,
}
event = create_local_event_from_event_dict(
clock=self.clock,
hostname=self.hs.hostname,
signing_key=self.hs.signing_key,
room_version=room_version,
event_dict=event_dict,
)
event.internal_metadata.outlier = True
event.internal_metadata.out_of_band_membership = True
if txn_id is not None:
event.internal_metadata.txn_id = txn_id
if requester.access_token_id is not None:
event.internal_metadata.token_id = requester.access_token_id
EventValidator().validate_new(event, self.config)
context = await self.state_handler.compute_event_context(event)
context.app_service = requester.app_service
stream_id = await self.event_creation_handler.handle_new_client_event(
requester, event, context, extra_users=[UserID.from_string(target_user)],
)
return event.event_id, stream_id
async def _user_joined_room(self, target: UserID, room_id: str) -> None:
"""Implements RoomMemberHandler._user_joined_room
+10 -9
@@ -61,21 +61,22 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
return ret["event_id"], ret["stream_id"]
async def _remote_reject_invite(
async def remote_reject_invite(
self,
invite_event_id: str,
txn_id: Optional[str],
requester: Requester,
remote_room_hosts: List[str],
room_id: str,
target: UserID,
content: dict,
) -> Tuple[Optional[str], int]:
"""Implements RoomMemberHandler._remote_reject_invite
) -> Tuple[str, int]:
"""
Rejects an out-of-band invite received from a remote user
Implements RoomMemberHandler.remote_reject_invite
"""
ret = await self._remote_reject_client(
invite_event_id=invite_event_id,
txn_id=txn_id,
requester=requester,
remote_room_hosts=remote_room_hosts,
room_id=room_id,
user_id=target.to_string(),
content=content,
)
return ret["event_id"], ret["stream_id"]
+3
@@ -294,6 +294,9 @@ class TypingHandler(object):
rows.sort()
limited = False
# We, unusually, use a strict limit here as we have all the rows in
# memory rather than pulling them out of the database with a `LIMIT ?`
# clause.
if len(rows) > limit:
rows = rows[:limit]
current_id = rows[-1][0]
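# A hedged, self-contained sketch of the strict in-memory limiting described
# in the comment above (names are illustrative, not Synapse's own):
def _truncate_in_memory_rows(rows, limit):
    """Sort the rows, truncate at `limit`, and report the last token returned."""
    rows.sort()
    limited = False
    if len(rows) > limit:
        rows = rows[:limit]
        limited = True
    current_id = rows[-1][0] if rows else None
    return rows, current_id, limited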
+5 -14
@@ -13,13 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
from synapse.http.server import wrap_json_request_handler
from synapse.http.server import DirectServeJsonResource
class AdditionalResource(Resource):
class AdditionalResource(DirectServeJsonResource):
"""Resource wrapper for additional_resources
If the user has configured additional_resources, we need to wrap the
@@ -41,16 +38,10 @@ class AdditionalResource(Resource):
handler ((twisted.web.server.Request) -> twisted.internet.defer.Deferred):
function to be called to handle the request.
"""
Resource.__init__(self)
super().__init__()
self._handler = handler
# required by the request_handler wrapper
self.clock = hs.get_clock()
def render(self, request):
self._async_render(request)
return NOT_DONE_YET
@wrap_json_request_handler
def _async_render(self, request):
# Cheekily pass the result straight through, so we don't need to worry
# if it's an awaitable or not.
return self._handler(request)
+8 -4
@@ -176,7 +176,7 @@ class MatrixFederationHttpClient(object):
def __init__(self, hs, tls_client_options_factory):
self.hs = hs
self.signing_key = hs.config.signing_key[0]
self.signing_key = hs.signing_key
self.server_name = hs.hostname
real_reactor = hs.get_reactor()
@@ -562,13 +562,17 @@ class MatrixFederationHttpClient(object):
Returns:
list[bytes]: a list of headers to be added as "Authorization:" headers
"""
request = {"method": method, "uri": url_bytes, "origin": self.server_name}
request = {
"method": method.decode("ascii"),
"uri": url_bytes.decode("ascii"),
"origin": self.server_name,
}
if destination is not None:
request["destination"] = destination
request["destination"] = destination.decode("ascii")
if destination_is is not None:
request["destination_is"] = destination_is
request["destination_is"] = destination_is.decode("ascii")
if content is not None:
request["content"] = content
+198 -183
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import collections
import html
import logging
@@ -21,7 +22,7 @@ import types
import urllib
from http import HTTPStatus
from io import BytesIO
from typing import Awaitable, Callable, TypeVar, Union
from typing import Any, Callable, Dict, Tuple, Union
import jinja2
from canonicaljson import encode_canonical_json, encode_pretty_printed_json, json
@@ -62,99 +63,43 @@ HTML_ERROR_TEMPLATE = """<!DOCTYPE html>
"""
def wrap_json_request_handler(h):
"""Wraps a request handler method with exception handling.
Also does the wrapping with request.processing as per wrap_async_request_handler.
The handler method must have a signature of "handle_foo(self, request)",
where "request" must be a SynapseRequest.
The handler must return a deferred or a coroutine. If the deferred succeeds
we assume that a response has been sent. If the deferred fails with a SynapseError we use
it to send a JSON response with the appropriate HTTP response code. If the
deferred fails with any other type of error we send a 500 response.
def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
"""Sends a JSON error response to clients.
"""
async def wrapped_request_handler(self, request):
try:
await h(self, request)
except SynapseError as e:
code = e.code
logger.info("%s SynapseError: %s - %s", request, code, e.msg)
if f.check(SynapseError):
error_code = f.value.code
error_dict = f.value.error_dict()
# Only respond with an error response if we haven't already started
# writing, otherwise lets just kill the connection
if request.startedWriting:
if request.transport:
try:
request.transport.abortConnection()
except Exception:
# abortConnection throws if the connection is already closed
pass
else:
respond_with_json(
request,
code,
e.error_dict(),
send_cors=True,
pretty_print=_request_user_agent_is_curl(request),
)
logger.info("%s SynapseError: %s - %s", request, error_code, f.value.msg)
else:
error_code = 500
error_dict = {"error": "Internal server error", "errcode": Codes.UNKNOWN}
except Exception:
# failure.Failure() fishes the original Failure out
# of our stack, and thus gives us a sensible stack
# trace.
f = failure.Failure()
logger.error(
"Failed handle request via %r: %r",
request.request_metrics.name,
request,
exc_info=(f.type, f.value, f.getTracebackObject()),
)
# Only respond with an error response if we haven't already started
# writing, otherwise lets just kill the connection
if request.startedWriting:
if request.transport:
try:
request.transport.abortConnection()
except Exception:
# abortConnection throws if the connection is already closed
pass
else:
respond_with_json(
request,
500,
{"error": "Internal server error", "errcode": Codes.UNKNOWN},
send_cors=True,
pretty_print=_request_user_agent_is_curl(request),
)
logger.error(
"Failed handle request via %r: %r",
request.request_metrics.name,
request,
exc_info=(f.type, f.value, f.getTracebackObject()),
)
return wrap_async_request_handler(wrapped_request_handler)
TV = TypeVar("TV")
def wrap_html_request_handler(
h: Callable[[TV, SynapseRequest], Awaitable]
) -> Callable[[TV, SynapseRequest], Awaitable[None]]:
"""Wraps a request handler method with exception handling.
Also does the wrapping with request.processing as per wrap_async_request_handler.
The handler method must have a signature of "handle_foo(self, request)",
where "request" must be a SynapseRequest.
"""
async def wrapped_request_handler(self, request):
try:
await h(self, request)
except Exception:
f = failure.Failure()
return_html_error(f, request, HTML_ERROR_TEMPLATE)
return wrap_async_request_handler(wrapped_request_handler)
# Only respond with an error response if we haven't already started writing,
# otherwise lets just kill the connection
if request.startedWriting:
if request.transport:
try:
request.transport.abortConnection()
except Exception:
# abortConnection throws if the connection is already closed
pass
else:
respond_with_json(
request,
error_code,
error_dict,
send_cors=True,
pretty_print=_request_user_agent_is_curl(request),
)
def return_html_error(
@@ -249,7 +194,113 @@ class HttpServer(object):
pass
class JsonResource(HttpServer, resource.Resource):
class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
"""Base class for resources that have async handlers.
Sub classes can either implement `_async_render_<METHOD>` to handle
requests by method, or override `_async_render` to handle all requests.
Args:
extract_context: Whether to attempt to extract the opentracing
context from the request the servlet is handling.
"""
def __init__(self, extract_context=False):
super().__init__()
self._extract_context = extract_context
def render(self, request):
""" This gets called by twisted every time someone sends us a request.
"""
defer.ensureDeferred(self._async_render_wrapper(request))
return NOT_DONE_YET
@wrap_async_request_handler
async def _async_render_wrapper(self, request):
"""This is a wrapper that delegates to `_async_render` and handles
exceptions, return values, metrics, etc.
"""
try:
request.request_metrics.name = self.__class__.__name__
with trace_servlet(request, self._extract_context):
callback_return = await self._async_render(request)
if callback_return is not None:
code, response = callback_return
self._send_response(request, code, response)
except Exception:
# failure.Failure() fishes the original Failure out
# of our stack, and thus gives us a sensible stack
# trace.
f = failure.Failure()
self._send_error_response(f, request)
async def _async_render(self, request):
"""Delegates to `_async_render_<METHOD>` methods, or returns a 400 if
no appropriate method exists. Can be overridden in sub classes for
different routing.
"""
method_handler = getattr(
self, "_async_render_%s" % (request.method.decode("ascii"),), None
)
if method_handler:
raw_callback_return = method_handler(request)
# Is it synchronous? We'll allow this for now.
if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)):
callback_return = await raw_callback_return
else:
callback_return = raw_callback_return
return callback_return
_unrecognised_request_handler(request)
@abc.abstractmethod
def _send_response(
self, request: SynapseRequest, code: int, response_object: Any,
) -> None:
raise NotImplementedError()
@abc.abstractmethod
def _send_error_response(
self, f: failure.Failure, request: SynapseRequest,
) -> None:
raise NotImplementedError()
class DirectServeJsonResource(_AsyncResource):
"""A resource that will call `self._async_on_<METHOD>` on new requests,
formatting responses and errors as JSON.
"""
def _send_response(
self, request, code, response_object,
):
"""Implements _AsyncResource._send_response
"""
# TODO: Only enable CORS for the requests that need it.
respond_with_json(
request,
code,
response_object,
send_cors=True,
pretty_print=_request_user_agent_is_curl(request),
canonical_json=self.canonical_json,
)
def _send_error_response(
self, f: failure.Failure, request: SynapseRequest,
) -> None:
"""Implements _AsyncResource._send_error_response
"""
return_json_error(f, request)
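# A hypothetical subclass sketching the contract described above: implement
# `_async_render_<METHOD>`, return a (code, json_object) tuple, and the base
# class serialises it; exceptions are routed through `_send_error_response`.
class PingResource(DirectServeJsonResource):
    async def _async_render_GET(self, request):
        return 200, {"pong": True}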
class JsonResource(DirectServeJsonResource):
""" This implements the HttpServer interface and provides JSON support for
Resources.
@@ -269,17 +320,15 @@ class JsonResource(HttpServer, resource.Resource):
"_PathEntry", ["pattern", "callback", "servlet_classname"]
)
def __init__(self, hs, canonical_json=True):
resource.Resource.__init__(self)
def __init__(self, hs, canonical_json=True, extract_context=False):
super().__init__(extract_context)
self.canonical_json = canonical_json
self.clock = hs.get_clock()
self.path_regexs = {}
self.hs = hs
def register_paths(
self, method, path_patterns, callback, servlet_classname, trace=True
):
def register_paths(self, method, path_patterns, callback, servlet_classname):
"""
Registers a request handler against a regular expression. Later request URLs are
checked against these regular expressions in order to identify an appropriate
@@ -295,74 +344,23 @@ class JsonResource(HttpServer, resource.Resource):
servlet_classname (str): The name of the handler to be used in prometheus
and opentracing logs.
trace (bool): Whether we should start a span to trace the servlet.
"""
method = method.encode("utf-8") # method is bytes on py3
if trace:
# We don't extract the context from the servlet because we can't
# trust the sender
callback = trace_servlet(servlet_classname)(callback)
for path_pattern in path_patterns:
logger.debug("Registering for %s %s", method, path_pattern.pattern)
self.path_regexs.setdefault(method, []).append(
self._PathEntry(path_pattern, callback, servlet_classname)
)
def render(self, request):
""" This gets called by twisted every time someone sends us a request.
"""
defer.ensureDeferred(self._async_render(request))
return NOT_DONE_YET
@wrap_json_request_handler
async def _async_render(self, request):
""" This gets called from render() every time someone sends us a request.
This checks if anyone has registered a callback for that method and
path.
"""
callback, servlet_classname, group_dict = self._get_handler_for_request(request)
# Make sure we have a name for this handler in prometheus.
request.request_metrics.name = servlet_classname
# Now trigger the callback. If it returns a response, we send it
# here. If it throws an exception, that is handled by the wrapper
# installed by @request_handler.
kwargs = intern_dict(
{
name: urllib.parse.unquote(value) if value else value
for name, value in group_dict.items()
}
)
callback_return = callback(request, **kwargs)
# Is it synchronous? We'll allow this for now.
if isinstance(callback_return, (defer.Deferred, types.CoroutineType)):
callback_return = await callback_return
if callback_return is not None:
code, response = callback_return
self._send_response(request, code, response)
def _get_handler_for_request(self, request):
"""Finds a callback method to handle the given request
Args:
request (twisted.web.http.Request):
def _get_handler_for_request(
self, request: SynapseRequest
) -> Tuple[Callable, str, Dict[str, str]]:
"""Finds a callback method to handle the given request.
Returns:
Tuple[Callable, str, dict[unicode, unicode]]: callback method, the
label to use for that method in prometheus metrics, and the
dict mapping keys to path components as specified in the
handler's path match regexp.
The callback will normally be a method registered via
register_paths, so will return (possibly via Deferred) either
None, or a tuple of (http code, response body).
A tuple of the callback to use, the name of the servlet, and the
key word arguments to pass to the callback
"""
request_path = request.path.decode("ascii")
@@ -377,42 +375,59 @@ class JsonResource(HttpServer, resource.Resource):
# Huh. No one wanted to handle that? Fiiiiiine. Send 400.
return _unrecognised_request_handler, "unrecognised_request_handler", {}
def _send_response(
self, request, code, response_json_object, response_code_message=None
):
# TODO: Only enable CORS for the requests that need it.
respond_with_json(
request,
code,
response_json_object,
send_cors=True,
response_code_message=response_code_message,
pretty_print=_request_user_agent_is_curl(request),
canonical_json=self.canonical_json,
async def _async_render(self, request):
callback, servlet_classname, group_dict = self._get_handler_for_request(request)
# Make sure we have an appropriate name for this handler in prometheus
# (rather than the default of JsonResource).
request.request_metrics.name = servlet_classname
# Now trigger the callback. If it returns a response, we send it
# here. If it throws an exception, that is handled by the wrapper
# installed by @request_handler.
kwargs = intern_dict(
{
name: urllib.parse.unquote(value) if value else value
for name, value in group_dict.items()
}
)
raw_callback_return = callback(request, **kwargs)
class DirectServeResource(resource.Resource):
def render(self, request):
# Is it synchronous? We'll allow this for now.
if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)):
callback_return = await raw_callback_return
else:
callback_return = raw_callback_return
return callback_return
class DirectServeHtmlResource(_AsyncResource):
"""A resource that will call `self._async_on_<METHOD>` on new requests,
formatting responses and errors as HTML.
"""
# The error template to use for this resource
ERROR_TEMPLATE = HTML_ERROR_TEMPLATE
def _send_response(
self, request: SynapseRequest, code: int, response_object: Any,
):
"""Implements _AsyncResource._send_response
"""
Render the request, using an asynchronous render handler if it exists.
# We expect to get bytes for us to write
assert isinstance(response_object, bytes)
html_bytes = response_object
respond_with_html_bytes(request, 200, html_bytes)
def _send_error_response(
self, f: failure.Failure, request: SynapseRequest,
) -> None:
"""Implements _AsyncResource._send_error_response
"""
async_render_callback_name = "_async_render_" + request.method.decode("ascii")
# Try and get the async renderer
callback = getattr(self, async_render_callback_name, None)
# No async renderer for this request method.
if not callback:
return super().render(request)
resp = trace_servlet(self.__class__.__name__)(callback)(request)
# If it's a coroutine, turn it into a Deferred
if isinstance(resp, types.CoroutineType):
defer.ensureDeferred(resp)
return NOT_DONE_YET
return_html_error(f, request, self.ERROR_TEMPLATE)
class StaticResource(File):
+32 -38
@@ -164,12 +164,10 @@ Gotchas
than one caller? Will all of those calling functions have to be in a context
with an active span?
"""
import contextlib
import inspect
import logging
import re
import types
from functools import wraps
from typing import TYPE_CHECKING, Dict, Optional, Type
@@ -181,6 +179,7 @@ from twisted.internet import defer
from synapse.config import ConfigError
if TYPE_CHECKING:
from synapse.http.site import SynapseRequest
from synapse.server import HomeServer
# Helper class
@@ -227,6 +226,7 @@ except ImportError:
tags = _DummyTagNames
try:
from jaeger_client import Config as JaegerConfig
from synapse.logging.scopecontextmanager import LogContextScopeManager
except ImportError:
JaegerConfig = None # type: ignore
@@ -793,48 +793,42 @@ def tag_args(func):
return _tag_args_inner
def trace_servlet(servlet_name, extract_context=False):
"""Decorator which traces a serlet. It starts a span with some servlet specific
tags such as the servlet_name and request information
@contextlib.contextmanager
def trace_servlet(request: "SynapseRequest", extract_context: bool = False):
"""Returns a context manager which traces a request. It starts a span
with some servlet specific tags such as the request metrics name and
request information.
Args:
servlet_name (str): The name to be used for the span's operation_name
extract_context (bool): Whether to attempt to extract the opentracing
request
extract_context: Whether to attempt to extract the opentracing
context from the request the servlet is handling.
"""
def _trace_servlet_inner_1(func):
if not opentracing:
return func
if opentracing is None:
yield
return
@wraps(func)
async def _trace_servlet_inner(request, *args, **kwargs):
request_tags = {
"request_id": request.get_request_id(),
tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
tags.HTTP_METHOD: request.get_method(),
tags.HTTP_URL: request.get_redacted_uri(),
tags.PEER_HOST_IPV6: request.getClientIP(),
}
request_tags = {
"request_id": request.get_request_id(),
tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
tags.HTTP_METHOD: request.get_method(),
tags.HTTP_URL: request.get_redacted_uri(),
tags.PEER_HOST_IPV6: request.getClientIP(),
}
if extract_context:
scope = start_active_span_from_request(
request, servlet_name, tags=request_tags
)
else:
scope = start_active_span(servlet_name, tags=request_tags)
request_name = request.request_metrics.name
if extract_context:
scope = start_active_span_from_request(request, request_name, tags=request_tags)
else:
scope = start_active_span(request_name, tags=request_tags)
with scope:
result = func(request, *args, **kwargs)
with scope:
try:
yield
finally:
# We set the operation name again in case it's changed (which happens
# with JsonResource).
scope.span.set_operation_name(request.request_metrics.name)
if not isinstance(result, (types.CoroutineType, defer.Deferred)):
# Some servlets aren't async and just return results
# directly, so we handle that here.
return result
return await result
return _trace_servlet_inner
return _trace_servlet_inner_1
scope.span.set_tag("request_tag", request.request_metrics.start_context.tag)
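# Usage sketch (fragment; `request` and `self._extract_context` as in the
# http/server.py changes earlier in this changeset): callers now wrap the
# whole render path in the returned context manager rather than decorating
# each servlet:
#
#     with trace_servlet(request, self._extract_context):
#         callback_return = await self._async_render(request)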
@@ -22,6 +22,7 @@ from typing import TYPE_CHECKING, Dict, Optional, Set
from prometheus_client.core import REGISTRY, Counter, Gauge
from twisted.internet import defer
from twisted.python.failure import Failure
from synapse.logging.context import LoggingContext, PreserveLoggingContext
@@ -212,7 +213,14 @@ def run_as_background_process(desc, func, *args, **kwargs):
return (yield result)
except Exception:
logger.exception("Background process '%s' threw an exception", desc)
# failure.Failure() fishes the original Failure out of our stack, and
# thus gives us a sensible stack trace.
f = Failure()
logger.error(
"Background process '%s' threw an exception",
desc,
exc_info=(f.type, f.value, f.getTracebackObject()),
)
finally:
_background_process_in_flight_count.labels(desc).dec()
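# A minimal, runnable sketch of the `Failure()` trick used above: building a
# bare Failure inside an `except` block fishes out the exception currently
# being handled, together with its original traceback.
from twisted.python.failure import Failure

try:
    1 / 0
except Exception:
    f = Failure()
    # (f.type, f.value, f.getTracebackObject()) reconstructs exc_info
    print(f.type, f.value)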
+1 -1
@@ -83,7 +83,7 @@ class _NotifierUserStream(object):
self.current_token = current_token
# The last token for which we should wake up any streams that have a
# token that comes before it. This gets updated everytime we get poked.
# token that comes before it. This gets updated every time we get poked.
# We start it at the current token since if we get any streams
# that have a token from before we have no idea whether they should be
# woken up or not, so lets just wake them up.
+14 -3
@@ -20,6 +20,7 @@ from prometheus_client import Counter
from twisted.internet import defer
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
from synapse.api.constants import EventTypes
from synapse.logging import opentracing
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.push import PusherConfigException
@@ -305,12 +306,23 @@ class HttpPusher(object):
@defer.inlineCallbacks
def _build_notification_dict(self, event, tweaks, badge):
priority = "low"
if (
event.type == EventTypes.Encrypted
or tweaks.get("highlight")
or tweaks.get("sound")
):
# HACK send our push as high priority only if it generates a sound, highlight
# or may do so (i.e. is encrypted so has unknown effects).
priority = "high"
if self.data.get("format") == "event_id_only":
d = {
"notification": {
"event_id": event.event_id,
"room_id": event.room_id,
"counts": {"unread": badge},
"prio": priority,
"devices": [
{
"app_id": self.app_id,
@@ -334,9 +346,8 @@ class HttpPusher(object):
"room_id": event.room_id,
"type": event.type,
"sender": event.user_id,
"counts": { # -- we don't mark messages as read yet so
# we have no way of knowing
# Just set the badge to 1 until we have read receipts
"prio": priority,
"counts": {
"unread": badge,
# 'missed_calls': 2
},
+27 -4
@@ -16,7 +16,7 @@
import logging
import re
from typing import Pattern
from typing import Any, Dict, List, Pattern, Union
from synapse.events import EventBase
from synapse.types import UserID
@@ -72,13 +72,36 @@ def _test_ineq_condition(condition, number):
return False
def tweaks_for_actions(actions):
def tweaks_for_actions(actions: List[Union[str, Dict]]) -> Dict[str, Any]:
"""
Converts a list of actions into a `tweaks` dict (which can then be passed to
the push gateway).
This function ignores all actions other than `set_tweak` actions, and treats
absent `value`s as `True`, which agrees with the only spec-defined treatment
of absent `value`s (namely, for `highlight` tweaks).
Args:
actions: list of actions
e.g. [
{"set_tweak": "a", "value": "AAA"},
{"set_tweak": "b", "value": "BBB"},
{"set_tweak": "highlight"},
"notify"
]
Returns:
dictionary of tweaks for those actions
e.g. {"a": "AAA", "b": "BBB", "highlight": True}
"""
tweaks = {}
for a in actions:
if not isinstance(a, dict):
continue
if "set_tweak" in a and "value" in a:
tweaks[a["set_tweak"]] = a["value"]
if "set_tweak" in a:
# value is allowed to be absent, in which case it is assumed
# to be True.
tweaks[a["set_tweak"]] = a.get("value", True)
return tweaks
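# Usage sketch matching the docstring above; note the absent `value` on the
# `highlight` tweak is treated as True:
assert tweaks_for_actions(
    ["notify", {"set_tweak": "sound", "value": "default"}, {"set_tweak": "highlight"}]
) == {"sound": "default", "highlight": True}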
+1 -1
@@ -66,7 +66,7 @@ REQUIREMENTS = [
"pymacaroons>=0.13.0",
"msgpack>=0.5.2",
"phonenumbers>=8.2.0",
"prometheus_client>=0.0.18,<0.8.0",
"prometheus_client>=0.0.18,<0.9.0",
# we use attr.validators.deep_iterable, which arrived in 19.1.0
"attrs>=19.1.0",
"netaddr>=0.7.18",
+2 -1
@@ -30,7 +30,8 @@ REPLICATION_PREFIX = "/_synapse/replication"
class ReplicationRestResource(JsonResource):
def __init__(self, hs):
JsonResource.__init__(self, hs, canonical_json=False)
# We enable extracting jaeger contexts here as these are internal APIs.
super().__init__(hs, canonical_json=False, extract_context=True)
self.register_servlets(hs)
def register_servlets(self, hs):
+4 -11
@@ -28,11 +28,7 @@ from synapse.api.errors import (
RequestSendFailed,
SynapseError,
)
from synapse.logging.opentracing import (
inject_active_span_byte_dict,
trace,
trace_servlet,
)
from synapse.logging.opentracing import inject_active_span_byte_dict, trace
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.stringutils import random_string
@@ -96,11 +92,11 @@ class ReplicationEndpoint(object):
# assert here that sub classes don't try and use the name.
assert (
"instance_name" not in self.PATH_ARGS
), "`instance_name` is a reserved paramater name"
), "`instance_name` is a reserved parameter name"
assert (
"instance_name"
not in signature(self.__class__._serialize_payload).parameters
), "`instance_name` is a reserved paramater name"
), "`instance_name` is a reserved parameter name"
assert self.METHOD in ("PUT", "POST", "GET")
@@ -240,11 +236,8 @@ class ReplicationEndpoint(object):
args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args)
pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args))
handler = trace_servlet(self.__class__.__name__, extract_context=True)(handler)
# We don't let register paths trace this servlet using the default tracing
# options because we wish to extract the context explicitly.
http_server.register_paths(
method, [pattern], handler, self.__class__.__name__, trace=False
method, [pattern], handler, self.__class__.__name__,
)
def _cached_handler(self, request, txn_id, **kwargs):
+25 -67
@@ -14,11 +14,11 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import Requester, UserID
from synapse.types import JsonDict, Requester, UserID
from synapse.util.distributor import user_joined_room, user_left_room
if TYPE_CHECKING:
@@ -88,49 +88,54 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
"""Rejects the invite for the user and room.
"""Rejects an out-of-band invite we have received from a remote server
Request format:
POST /_synapse/replication/remote_reject_invite/:room_id/:user_id
POST /_synapse/replication/remote_reject_invite/:event_id
{
"txn_id": ...,
"requester": ...,
"remote_room_hosts": [...],
"content": { ... }
}
"""
NAME = "remote_reject_invite"
PATH_ARGS = ("room_id", "user_id")
PATH_ARGS = ("invite_event_id",)
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super(ReplicationRemoteRejectInviteRestServlet, self).__init__(hs)
self.federation_handler = hs.get_handlers().federation_handler
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.member_handler = hs.get_room_member_handler()
@staticmethod
def _serialize_payload(requester, room_id, user_id, remote_room_hosts, content):
def _serialize_payload( # type: ignore
invite_event_id: str,
txn_id: Optional[str],
requester: Requester,
content: JsonDict,
):
"""
Args:
requester(Requester)
room_id (str)
user_id (str)
remote_room_hosts (list[str]): Servers to try and reject via
invite_event_id: ID of the invite to be rejected
txn_id: optional transaction ID supplied by the client
requester: user making the rejection request, according to the access token
content: additional content to include in the rejection event.
Normally an empty dict.
"""
return {
"txn_id": txn_id,
"requester": requester.serialize(),
"remote_room_hosts": remote_room_hosts,
"content": content,
}
async def _handle_request(self, request, room_id, user_id):
async def _handle_request(self, request, invite_event_id):
content = parse_json_object_from_request(request)
remote_room_hosts = content["remote_room_hosts"]
txn_id = content["txn_id"]
event_content = content["content"]
requester = Requester.deserialize(self.store, content["requester"])
@@ -138,60 +143,14 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
if requester.user:
request.authenticated_entity = requester.user.to_string()
logger.info("remote_reject_invite: %s out of room: %s", user_id, room_id)
try:
event, stream_id = await self.federation_handler.do_remotely_reject_invite(
remote_room_hosts, room_id, user_id, event_content,
)
event_id = event.event_id
except Exception as e:
# if we were unable to reject the exception, just mark
# it as rejected on our end and plough ahead.
#
# The 'except' clause is very broad, but we need to
# capture everything from DNS failures upwards
#
logger.warning("Failed to reject invite: %s", e)
stream_id = await self.member_handler.locally_reject_invite(
user_id, room_id
)
event_id = None
# hopefully we're now on the master, so this won't recurse!
event_id, stream_id = await self.member_handler.remote_reject_invite(
invite_event_id, txn_id, requester, event_content,
)
return 200, {"event_id": event_id, "stream_id": stream_id}
class ReplicationLocallyRejectInviteRestServlet(ReplicationEndpoint):
"""Rejects the invite for the user and room locally.
Request format:
POST /_synapse/replication/locally_reject_invite/:room_id/:user_id
{}
"""
NAME = "locally_reject_invite"
PATH_ARGS = ("room_id", "user_id")
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.member_handler = hs.get_room_member_handler()
@staticmethod
def _serialize_payload(room_id, user_id):
return {}
async def _handle_request(self, request, room_id, user_id):
logger.info("locally_reject_invite: %s out of room: %s", user_id, room_id)
stream_id = await self.member_handler.locally_reject_invite(user_id, room_id)
return 200, {"stream_id": stream_id}
class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
"""Notifies that a user has joined or left the room
@@ -245,4 +204,3 @@ def register_servlets(hs, http_server):
ReplicationRemoteJoinRestServlet(hs).register(http_server)
ReplicationRemoteRejectInviteRestServlet(hs).register(http_server)
ReplicationUserJoinedLeftRoomRestServlet(hs).register(http_server)
ReplicationLocallyRejectInviteRestServlet(hs).register(http_server)
@@ -16,6 +16,7 @@
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
from synapse.storage.data_stores.main.account_data import AccountDataWorkerStore
from synapse.storage.data_stores.main.tags import TagsWorkerStore
from synapse.storage.database import Database
@@ -39,12 +40,12 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved
return self._account_data_id_gen.get_current_token()
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "tag_account_data":
if stream_name == TagAccountDataStream.NAME:
self._account_data_id_gen.advance(token)
for row in rows:
self.get_tags_for_user.invalidate((row.user_id,))
self._account_data_stream_cache.entity_has_changed(row.user_id, token)
elif stream_name == "account_data":
elif stream_name == AccountDataStream.NAME:
self._account_data_id_gen.advance(token)
for row in rows:
if not row.room_id:
@@ -15,6 +15,7 @@
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import ToDeviceStream
from synapse.storage.data_stores.main.deviceinbox import DeviceInboxWorkerStore
from synapse.storage.database import Database
from synapse.util.caches.expiringcache import ExpiringCache
@@ -44,7 +45,7 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
)
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "to_device":
if stream_name == ToDeviceStream.NAME:
self._device_inbox_id_gen.advance(token)
for row in rows:
if row.entity.startswith("@"):
+2 -1
@@ -15,6 +15,7 @@
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import GroupServerStream
from synapse.storage.data_stores.main.group_server import GroupServerWorkerStore
from synapse.storage.database import Database
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -38,7 +39,7 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
return self._group_updates_id_gen.get_current_token()
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "groups":
if stream_name == GroupServerStream.NAME:
self._group_updates_id_gen.advance(token)
for row in rows:
self._group_updates_stream_cache.entity_has_changed(row.user_id, token)
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.replication.tcp.streams import PresenceStream
from synapse.storage import DataStore
from synapse.storage.data_stores.main.presence import PresenceStore
from synapse.storage.database import Database
@@ -42,7 +43,7 @@ class SlavedPresenceStore(BaseSlavedStore):
return self._presence_id_gen.get_current_token()
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "presence":
if stream_name == PresenceStream.NAME:
self._presence_id_gen.advance(token)
for row in rows:
self.presence_stream_cache.entity_has_changed(row.user_id, token)
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.replication.tcp.streams import PushRulesStream
from synapse.storage.data_stores.main.push_rule import PushRulesWorkerStore
from .events import SlavedEventStore
@@ -30,7 +31,7 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
return self._push_rules_stream_id_gen.get_current_token()
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "push_rules":
if stream_name == PushRulesStream.NAME:
self._push_rules_stream_id_gen.advance(token)
for row in rows:
self.get_push_rules_for_user.invalidate((row.user_id,))
+2 -1
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.replication.tcp.streams import PushersStream
from synapse.storage.data_stores.main.pusher import PusherWorkerStore
from synapse.storage.database import Database
@@ -32,6 +33,6 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
return self._pushers_id_gen.get_current_token()
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "pushers":
if stream_name == PushersStream.NAME:
self._pushers_id_gen.advance(token)
return super().process_replication_rows(stream_name, instance_name, token, rows)
@@ -14,20 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.replication.tcp.streams import ReceiptsStream
from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore
from synapse.storage.database import Database
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
# So, um, we want to borrow a load of functions intended for reading from
# a DataStore, but we don't want to take functions that either write to the
# DataStore or are cached and don't have cache invalidation logic.
#
# Rather than write duplicate versions of those functions, or lift them to
# a common base class, we're going to grab the underlying __func__ object from
# the method descriptor on the DataStore and chuck them into our class.
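# A small sketch of the descriptor trick that comment describes (illustrative
# classes, not Synapse's): the underlying function can simply be attached to
# another class.
class A:
    def hello(self):
        return "hello from %s" % type(self).__name__

class B:
    pass

B.hello = A.hello  # on py3, A.hello is already the plain function
# equivalently, from a bound method: B.hello = A().hello.__func__
assert B().hello() == "hello from B"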
class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
def __init__(self, database: Database, db_conn, hs):
@@ -52,7 +45,7 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
self.get_receipts_for_room.invalidate((room_id, receipt_type))
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "receipts":
if stream_name == ReceiptsStream.NAME:
self._receipts_id_gen.advance(token)
for row in rows:
self.invalidate_caches_for_receipt(
+2 -1
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.replication.tcp.streams import PublicRoomsStream
from synapse.storage.data_stores.main.room import RoomWorkerStore
from synapse.storage.database import Database
@@ -31,7 +32,7 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore):
return self._public_room_id_gen.get_current_token()
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "public_rooms":
if stream_name == PublicRoomsStream.NAME:
self._public_room_id_gen.advance(token)
return super().process_replication_rows(stream_name, instance_name, token, rows)
+1 -1
@@ -25,7 +25,7 @@ Structure of the module:
* command.py - the definitions of all the valid commands
* protocol.py - the TCP protocol classes
* resource.py - handles streaming stream updates to replications
* streams/ - the definitons of all the valid streams
* streams/ - the definitions of all the valid streams
The general interaction of the classes are:
+1 -1
@@ -33,8 +33,8 @@ from synapse.util.async_helpers import timeout_deferred
from synapse.util.metrics import Measure
if TYPE_CHECKING:
from synapse.server import HomeServer
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
+3 -10
@@ -18,18 +18,11 @@ The VALID_SERVER_COMMANDS and VALID_CLIENT_COMMANDS define which commands are
allowed to be sent by which side.
"""
import abc
import json
import logging
import platform
from typing import Tuple, Type
if platform.python_implementation() == "PyPy":
import json
_json_encoder = json.JSONEncoder()
else:
import simplejson as json # type: ignore[no-redef] # noqa: F821
_json_encoder = json.JSONEncoder(namedtuple_as_object=False) # type: ignore[call-arg] # noqa: F821
_json_encoder = json.JSONEncoder()
logger = logging.getLogger(__name__)
@@ -54,7 +47,7 @@ class Command(metaclass=abc.ABCMeta):
@abc.abstractmethod
def to_line(self) -> str:
"""Serialises the comamnd for the wire. Does not include the command
"""Serialises the command for the wire. Does not include the command
prefix.
"""
+2 -2
@@ -13,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar
@@ -149,10 +148,11 @@ class ReplicationCommandHandler:
using TCP.
"""
if hs.config.redis.redis_enabled:
import txredisapi
from synapse.replication.tcp.redis import (
RedisDirectTcpReplicationClientFactory,
)
import txredisapi
logger.info(
"Connecting to redis (host=%r port=%r)",
+1 -1
@@ -317,7 +317,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
def _queue_command(self, cmd):
"""Queue the command until the connection is ready to write to again.
"""
logger.debug("[%s] Queing as conn %r, cmd: %r", self.id(), self.state, cmd)
logger.debug("[%s] Queueing as conn %r, cmd: %r", self.id(), self.state, cmd)
self.pending_commands.append(cmd)
if len(self.pending_commands) > self.max_line_buffer:
+1 -1
@@ -177,7 +177,7 @@ class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
Args:
hs
outbound_redis_connection: A connection to redis that will be used to
send outbound commands (this is seperate to the redis connection
send outbound commands (this is separate to the redis connection
used to subscribe).
"""
+10 -46
@@ -198,26 +198,6 @@ def current_token_without_instance(
return lambda instance_name: current_token()
def db_query_to_update_function(
query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
) -> UpdateFunction:
"""Wraps a db query function which returns a list of rows to make it
suitable for use as an `update_function` for the Stream class
"""
async def update_function(instance_name, from_token, upto_token, limit):
rows = await query_function(from_token, upto_token, limit)
updates = [(row[0], row[1:]) for row in rows]
limited = False
if len(updates) >= limit:
upto_token = updates[-1][0]
limited = True
return updates, upto_token, limited
return update_function
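# The `update_function` contract these streams now satisfy directly (sketch):
# given (instance_name, from_token, upto_token, limit), return
#   (updates, upto_token, limited)
# where `updates` is a list of (token, row) pairs and `limited` is True when
# the result was truncated at `limit`, telling the caller to poll again.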
def make_http_update_function(hs, stream_name: str) -> UpdateFunction:
"""Makes a suitable function for use as an `update_function` that queries
the master process for updates.
@@ -393,7 +373,7 @@ class PushersStream(Stream):
super().__init__(
hs.get_instance_name(),
current_token_without_instance(store.get_pushers_stream_token),
db_query_to_update_function(store.get_all_updated_pushers_rows),
store.get_all_updated_pushers_rows,
)
@@ -421,27 +401,13 @@ class CachesStream(Stream):
ROW_TYPE = CachesStreamRow
def __init__(self, hs):
self.store = hs.get_datastore()
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
self.store.get_cache_stream_token,
self._update_function,
store.get_cache_stream_token,
store.get_all_updated_caches,
)
async def _update_function(
self, instance_name: str, from_token: int, upto_token: int, limit: int
):
rows = await self.store.get_all_updated_caches(
instance_name, from_token, upto_token, limit
)
updates = [(row[0], row[1:]) for row in rows]
limited = False
if len(updates) >= limit:
upto_token = updates[-1][0]
limited = True
return updates, upto_token, limited
class PublicRoomsStream(Stream):
"""The public rooms list changed
@@ -465,7 +431,7 @@ class PublicRoomsStream(Stream):
super().__init__(
hs.get_instance_name(),
current_token_without_instance(store.get_current_public_room_stream_id),
db_query_to_update_function(store.get_all_new_public_rooms),
store.get_all_new_public_rooms,
)
@@ -486,7 +452,7 @@ class DeviceListsStream(Stream):
super().__init__(
hs.get_instance_name(),
current_token_without_instance(store.get_device_stream_token),
db_query_to_update_function(store.get_all_device_list_changes_for_remotes),
store.get_all_device_list_changes_for_remotes,
)
@@ -504,7 +470,7 @@ class ToDeviceStream(Stream):
super().__init__(
hs.get_instance_name(),
current_token_without_instance(store.get_to_device_stream_token),
db_query_to_update_function(store.get_all_new_device_messages),
store.get_all_new_device_messages,
)
@@ -524,7 +490,7 @@ class TagAccountDataStream(Stream):
super().__init__(
hs.get_instance_name(),
current_token_without_instance(store.get_max_account_data_stream_id),
db_query_to_update_function(store.get_all_updated_tags),
store.get_all_updated_tags,
)
@@ -612,7 +578,7 @@ class GroupServerStream(Stream):
super().__init__(
hs.get_instance_name(),
current_token_without_instance(store.get_group_stream_token),
db_query_to_update_function(store.get_all_groups_changes),
store.get_all_groups_changes,
)
@@ -630,7 +596,5 @@ class UserSignatureStream(Stream):
super().__init__(
hs.get_instance_name(),
current_token_without_instance(store.get_device_stream_token),
db_query_to_update_function(
store.get_all_user_signature_changes_for_remotes
),
store.get_all_user_signature_changes_for_remotes,
)
+1 -3
@@ -13,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import heapq
from collections import Iterable
from typing import List, Tuple, Type
@@ -22,7 +21,6 @@ import attr
from ._base import Stream, StreamUpdateResult, Token, current_token_without_instance
"""Handling of the 'events' replication stream
This stream contains rows of various types. Each row therefore contains a 'type'
@@ -64,7 +62,7 @@ class BaseEventsStreamRow(object):
Specifies how to identify, serialize and deserialize the different types.
"""
# Unique string that ids the type. Must be overriden in sub classes.
# Unique string that ids the type. Must be overridden in sub classes.
TypeId = None # type: str
@classmethod
+35 -25
@@ -14,6 +14,7 @@
# limitations under the License.
import logging
from typing import Awaitable, Callable, Dict, Optional
from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.api.ratelimiting import Ratelimiter
@@ -26,8 +27,9 @@ from synapse.http.servlet import (
from synapse.http.site import SynapseRequest
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.rest.well_known import WellKnownBuilder
from synapse.types import UserID
from synapse.types import JsonDict, UserID
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.threepids import canonicalise_email
logger = logging.getLogger(__name__)
@@ -113,7 +115,7 @@ class LoginRestServlet(RestServlet):
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
)
def on_GET(self, request):
def on_GET(self, request: SynapseRequest):
flows = []
if self.jwt_enabled:
flows.append({"type": LoginRestServlet.JWT_TYPE})
@@ -141,10 +143,10 @@ class LoginRestServlet(RestServlet):
return 200, {"flows": flows}
def on_OPTIONS(self, request):
def on_OPTIONS(self, request: SynapseRequest):
return 200, {}
async def on_POST(self, request):
async def on_POST(self, request: SynapseRequest):
self._address_ratelimiter.ratelimit(request.getClientIP())
login_submission = parse_json_object_from_request(request)
@@ -153,9 +155,9 @@ class LoginRestServlet(RestServlet):
login_submission["type"] == LoginRestServlet.JWT_TYPE
or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED
):
result = await self.do_jwt_login(login_submission)
result = await self._do_jwt_login(login_submission)
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
result = await self.do_token_login(login_submission)
result = await self._do_token_login(login_submission)
else:
result = await self._do_other_login(login_submission)
except KeyError:
@@ -166,14 +168,14 @@ class LoginRestServlet(RestServlet):
result["well_known"] = well_known_data
return 200, result
async def _do_other_login(self, login_submission):
async def _do_other_login(self, login_submission: JsonDict) -> Dict[str, str]:
"""Handle non-token/saml/jwt logins
Args:
login_submission:
Returns:
dict: HTTP response
HTTP response
"""
# Log the request we got, but only certain fields to minimise the chance of
# logging someone's password (even if they accidentally put it in the wrong
@@ -206,11 +208,14 @@ class LoginRestServlet(RestServlet):
if medium is None or address is None:
raise SynapseError(400, "Invalid thirdparty identifier")
# For emails, canonicalise the address.
# We store all email addresses canonicalised in the DB.
# (See add_threepid in synapse/handlers/auth.py)
if medium == "email":
# For emails, transform the address to lowercase.
# We store all email addresses as lowercase in the DB.
# (See add_threepid in synapse/handlers/auth.py)
address = address.lower()
try:
address = canonicalise_email(address)
except ValueError as e:
raise SynapseError(400, str(e))
# We also apply account rate limiting using the 3PID as a key, as
# otherwise using 3PID bypasses the ratelimiting based on user ID.
@@ -288,25 +293,30 @@ class LoginRestServlet(RestServlet):
return result
async def _complete_login(
self, user_id, login_submission, callback=None, create_non_existent_users=False
):
self,
user_id: str,
login_submission: JsonDict,
callback: Optional[
Callable[[Dict[str, str]], Awaitable[Dict[str, str]]]
] = None,
create_non_existent_users: bool = False,
) -> Dict[str, str]:
"""Called when we've successfully authed the user and now need to
actually login them in (e.g. create devices). This gets called on
all succesful logins.
all successful logins.
Applies the ratelimiting for succesful login attempts against an
Applies the ratelimiting for successful login attempts against an
account.
Args:
user_id (str): ID of the user to register.
login_submission (dict): Dictionary of login information.
callback (func|None): Callback function to run after registration.
create_non_existent_users (bool): Whether to create the user if
they don't exist. Defaults to False.
user_id: ID of the user to register.
login_submission: Dictionary of login information.
callback: Callback function to run after registration.
create_non_existent_users: Whether to create the user if they don't
exist. Defaults to False.
Returns:
result (Dict[str,str]): Dictionary of account information after
successful registration.
result: Dictionary of account information after successful registration.
"""
# Before we actually log them in we check if they've already logged in
@@ -340,7 +350,7 @@ class LoginRestServlet(RestServlet):
return result
async def do_token_login(self, login_submission):
async def _do_token_login(self, login_submission: JsonDict) -> Dict[str, str]:
token = login_submission["token"]
auth_handler = self.auth_handler
user_id = await auth_handler.validate_short_term_login_token_and_get_user_id(
@@ -350,7 +360,7 @@ class LoginRestServlet(RestServlet):
result = await self._complete_login(user_id, login_submission)
return result
async def do_jwt_login(self, login_submission):
async def _do_jwt_login(self, login_submission: JsonDict) -> Dict[str, str]:
token = login_submission.get("token", None)
if token is None:
raise LoginError(
+2 -4
@@ -217,10 +217,8 @@ class RoomStateEventRestServlet(TransactionRestServlet):
)
event_id = event.event_id
ret = {} # type: dict
if event_id:
set_tag("event_id", event_id)
ret = {"event_id": event_id}
set_tag("event_id", event_id)
ret = {"event_id": event_id}
return 200, ret
+1 -1
@@ -50,7 +50,7 @@ class VoipRestServlet(RestServlet):
# We need to use standard padded base64 encoding here
# encode_base64 because we need to add the standard padding to get the
# same result as the TURN server.
password = base64.b64encode(mac.digest())
password = base64.b64encode(mac.digest()).decode("ascii")
elif turnUris and turnUsername and turnPassword and userLifetime:
username = turnUsername
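The `.decode("ascii")` added here is needed because the stdlib `json` module rejects `bytes` values and `base64.b64encode` returns `bytes`, so the password must be decoded to `str` before it ends up in a JSON response. A quick illustration:

```python
import base64
import json

digest = b"\x01\x02\x03"  # stand-in for mac.digest()
password = base64.b64encode(digest)

# json.dumps({"password": password}) raises TypeError: Object of type
# bytes is not JSON serializable; decode to str first.
body = json.dumps({"password": password.decode("ascii")})
```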
+31 -9
View File
@@ -30,7 +30,7 @@ from synapse.http.servlet import (
from synapse.push.mailer import Mailer, load_jinja2_templates
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.stringutils import assert_valid_client_secret, random_string
from synapse.util.threepids import check_3pid_allowed
from synapse.util.threepids import canonicalise_email, check_3pid_allowed
from ._base import client_patterns, interactive_auth_handler
@@ -83,7 +83,15 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
client_secret = body["client_secret"]
assert_valid_client_secret(client_secret)
email = body["email"]
# Canonicalise the email address. The addresses are all stored canonicalised
# in the database. This allows the user to reset their password without having
# to know the exact spelling (e.g. upper vs lower case) of the address in the
# database. For example, "foo@bar.com" is stored in the database, so a request
# for "FOO@bar.com" would otherwise raise a Not Found error.
try:
email = canonicalise_email(body["email"])
except ValueError as e:
raise SynapseError(400, str(e))
send_attempt = body["send_attempt"]
next_link = body.get("next_link") # Optional param
@@ -94,6 +102,10 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
# The email will be sent to the stored address.
# This avoids a potential account hijack by requesting a password reset to
# an email address which is controlled by the attacker but which, after
# canonicalisation, matches the one in our database.
existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
"email", email
)
@@ -274,10 +286,13 @@ class PasswordRestServlet(RestServlet):
if "medium" not in threepid or "address" not in threepid:
raise SynapseError(500, "Malformed threepid")
if threepid["medium"] == "email":
# For emails, transform the address to lowercase.
# We store all email addreses as lowercase in the DB.
# For emails, canonicalise the address.
# We store all email addresses canonicalised in the DB.
# (See add_threepid in synapse/handlers/auth.py)
threepid["address"] = threepid["address"].lower()
try:
threepid["address"] = canonicalise_email(threepid["address"])
except ValueError as e:
raise SynapseError(400, str(e))
# if using email, we must know about the email they're authing with!
threepid_user_id = await self.datastore.get_user_id_by_threepid(
threepid["medium"], threepid["address"]
@@ -392,7 +407,16 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
client_secret = body["client_secret"]
assert_valid_client_secret(client_secret)
email = body["email"]
# Canonicalise the email address. The addresses are all stored canonicalised
# in the database.
# This ensures that the validation email is sent to the canonicalised address
# as it will later be entered into the database.
# Otherwise the email will be sent to "FOO@bar.com" and stored as
# "foo@bar.com" in database.
try:
email = canonicalise_email(body["email"])
except ValueError as e:
raise SynapseError(400, str(e))
send_attempt = body["send_attempt"]
next_link = body.get("next_link") # Optional param
@@ -403,9 +427,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
existing_user_id = await self.store.get_user_id_by_threepid(
"email", body["email"]
)
existing_user_id = await self.store.get_user_id_by_threepid("email", email)
if existing_user_id is not None:
if self.config.request_token_inhibit_3pid_errors:
+19 -3
View File
@@ -47,7 +47,7 @@ from synapse.push.mailer import load_jinja2_templates
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.stringutils import assert_valid_client_secret, random_string
from synapse.util.threepids import check_3pid_allowed
from synapse.util.threepids import canonicalise_email, check_3pid_allowed
from ._base import client_patterns, interactive_auth_handler
@@ -116,7 +116,14 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
client_secret = body["client_secret"]
assert_valid_client_secret(client_secret)
email = body["email"]
# For emails, canonicalise the address.
# We store all email addresses canonicalised in the DB.
# (See on_POST in EmailThreepidRequestTokenRestServlet
# in synapse/rest/client/v2_alpha/account.py)
try:
email = canonicalise_email(body["email"])
except ValueError as e:
raise SynapseError(400, str(e))
send_attempt = body["send_attempt"]
next_link = body.get("next_link") # Optional param
@@ -128,7 +135,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
)
existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
"email", body["email"]
"email", email
)
if existing_user_id is not None:
@@ -552,6 +559,15 @@ class RegisterRestServlet(RestServlet):
if login_type in auth_result:
medium = auth_result[login_type]["medium"]
address = auth_result[login_type]["address"]
# For emails, canonicalise the address.
# We store all email addresses canonicalised in the DB.
# (See on_POST in EmailThreepidRequestTokenRestServlet
# in synapse/rest/client/v2_alpha/account.py)
if medium == "email":
try:
address = canonicalise_email(address)
except ValueError as e:
raise SynapseError(400, str(e))
existing_user_id = await self.store.get_user_id_by_threepid(
medium, address
+2 -8
View File
@@ -26,11 +26,7 @@ from twisted.internet import defer
from synapse.api.errors import NotFoundError, StoreError, SynapseError
from synapse.config import ConfigError
from synapse.http.server import (
DirectServeResource,
respond_with_html,
wrap_html_request_handler,
)
from synapse.http.server import DirectServeHtmlResource, respond_with_html
from synapse.http.servlet import parse_string
from synapse.types import UserID
@@ -48,7 +44,7 @@ else:
return a == b
class ConsentResource(DirectServeResource):
class ConsentResource(DirectServeHtmlResource):
"""A twisted Resource to display a privacy policy and gather consent to it
When accessed via GET, returns the privacy policy via a template.
@@ -119,7 +115,6 @@ class ConsentResource(DirectServeResource):
self._hmac_secret = hs.config.form_secret.encode("utf-8")
@wrap_html_request_handler
async def _async_render_GET(self, request):
"""
Args:
@@ -160,7 +155,6 @@ class ConsentResource(DirectServeResource):
except TemplateNotFound:
raise NotFoundError("Unknown policy version")
@wrap_html_request_handler
async def _async_render_POST(self, request):
"""
Args:
+4 -8
View File
@@ -20,17 +20,13 @@ from signedjson.sign import sign_json
from synapse.api.errors import Codes, SynapseError
from synapse.crypto.keyring import ServerKeyFetcher
from synapse.http.server import (
DirectServeResource,
respond_with_json_bytes,
wrap_json_request_handler,
)
from synapse.http.server import DirectServeJsonResource, respond_with_json_bytes
from synapse.http.servlet import parse_integer, parse_json_object_from_request
logger = logging.getLogger(__name__)
class RemoteKey(DirectServeResource):
class RemoteKey(DirectServeJsonResource):
"""HTTP resource for retreiving the TLS certificate and NACL signature
verification keys for a collection of servers. Checks that the reported
X.509 TLS certificate matches the one used in the HTTPS connection. Checks
@@ -92,13 +88,14 @@ class RemoteKey(DirectServeResource):
isLeaf = True
def __init__(self, hs):
super().__init__()
self.fetcher = ServerKeyFetcher(hs)
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
self.config = hs.config
@wrap_json_request_handler
async def _async_render_GET(self, request):
if len(request.postpath) == 1:
(server,) = request.postpath
@@ -115,7 +112,6 @@ class RemoteKey(DirectServeResource):
await self.query_keys(request, query, query_remote_on_cache_miss=True)
@wrap_json_request_handler
async def _async_render_POST(self, request):
content = parse_json_object_from_request(request)
+3 -11
View File
@@ -14,16 +14,10 @@
# limitations under the License.
#
from twisted.web.server import NOT_DONE_YET
from synapse.http.server import (
DirectServeResource,
respond_with_json,
wrap_json_request_handler,
)
from synapse.http.server import DirectServeJsonResource, respond_with_json
class MediaConfigResource(DirectServeResource):
class MediaConfigResource(DirectServeJsonResource):
isLeaf = True
def __init__(self, hs):
@@ -33,11 +27,9 @@ class MediaConfigResource(DirectServeResource):
self.auth = hs.get_auth()
self.limits_dict = {"m.upload.size": config.max_upload_size}
@wrap_json_request_handler
async def _async_render_GET(self, request):
await self.auth.get_user_by_req(request)
respond_with_json(request, 200, self.limits_dict, send_cors=True)
def render_OPTIONS(self, request):
async def _async_render_OPTIONS(self, request):
respond_with_json(request, 200, {}, send_cors=True)
return NOT_DONE_YET
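The `DirectServeResource` → `DirectServeJsonResource` / `DirectServeHtmlResource` conversions in this and the following files all have the same shape: the new base classes take over the error handling that `wrap_json_request_handler` / `wrap_html_request_handler` used to add per method, and `render_OPTIONS` becomes an async handler with no `NOT_DONE_YET` bookkeeping. A sketch of the resulting style (class and method bodies are illustrative, not taken from the diff):

```python
from synapse.http.server import DirectServeJsonResource, respond_with_json


class ExampleResource(DirectServeJsonResource):
    """After this change: no decorator and no NOT_DONE_YET; the base
    class wraps each _async_render_* method in JSON error handling."""

    isLeaf = True

    async def _async_render_GET(self, request):
        respond_with_json(request, 200, {"ok": True}, send_cors=True)

    async def _async_render_OPTIONS(self, request):
        respond_with_json(request, 200, {}, send_cors=True)
```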
+2 -10
View File
@@ -15,18 +15,14 @@
import logging
import synapse.http.servlet
from synapse.http.server import (
DirectServeResource,
set_cors_headers,
wrap_json_request_handler,
)
from synapse.http.server import DirectServeJsonResource, set_cors_headers
from ._base import parse_media_id, respond_404
logger = logging.getLogger(__name__)
class DownloadResource(DirectServeResource):
class DownloadResource(DirectServeJsonResource):
isLeaf = True
def __init__(self, hs, media_repo):
@@ -34,10 +30,6 @@ class DownloadResource(DirectServeResource):
self.media_repo = media_repo
self.server_name = hs.hostname
# this is expected by @wrap_json_request_handler
self.clock = hs.get_clock()
@wrap_json_request_handler
async def _async_render_GET(self, request):
set_cors_headers(request)
request.setHeader(
@@ -34,10 +34,9 @@ from twisted.internet.error import DNSLookupError
from synapse.api.errors import Codes, SynapseError
from synapse.http.client import SimpleHttpClient
from synapse.http.server import (
DirectServeResource,
DirectServeJsonResource,
respond_with_json,
respond_with_json_bytes,
wrap_json_request_handler,
)
from synapse.http.servlet import parse_integer, parse_string
from synapse.logging.context import make_deferred_yieldable, run_in_background
@@ -58,7 +57,7 @@ OG_TAG_NAME_MAXLEN = 50
OG_TAG_VALUE_MAXLEN = 1000
class PreviewUrlResource(DirectServeResource):
class PreviewUrlResource(DirectServeJsonResource):
isLeaf = True
def __init__(self, hs, media_repo, media_storage):
@@ -108,11 +107,10 @@ class PreviewUrlResource(DirectServeResource):
self._start_expire_url_cache_data, 10 * 1000
)
def render_OPTIONS(self, request):
async def _async_render_OPTIONS(self, request):
request.setHeader(b"Allow", b"OPTIONS, GET")
return respond_with_json(request, 200, {}, send_cors=True)
respond_with_json(request, 200, {}, send_cors=True)
@wrap_json_request_handler
async def _async_render_GET(self, request):
# XXX: if get_user_by_req fails, what should we do in an async render?
+2 -8
View File
@@ -16,11 +16,7 @@
import logging
from synapse.http.server import (
DirectServeResource,
set_cors_headers,
wrap_json_request_handler,
)
from synapse.http.server import DirectServeJsonResource, set_cors_headers
from synapse.http.servlet import parse_integer, parse_string
from ._base import (
@@ -34,7 +30,7 @@ from ._base import (
logger = logging.getLogger(__name__)
class ThumbnailResource(DirectServeResource):
class ThumbnailResource(DirectServeJsonResource):
isLeaf = True
def __init__(self, hs, media_repo, media_storage):
@@ -45,9 +41,7 @@ class ThumbnailResource(DirectServeResource):
self.media_storage = media_storage
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.server_name = hs.hostname
self.clock = hs.get_clock()
@wrap_json_request_handler
async def _async_render_GET(self, request):
set_cors_headers(request)
server_name, media_id, _ = parse_media_id(request)
+1 -2
View File
@@ -12,11 +12,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from io import BytesIO
import PIL.Image as Image
from PIL import Image as Image
logger = logging.getLogger(__name__)
+3 -11
View File
@@ -15,20 +15,14 @@
import logging
from twisted.web.server import NOT_DONE_YET
from synapse.api.errors import Codes, SynapseError
from synapse.http.server import (
DirectServeResource,
respond_with_json,
wrap_json_request_handler,
)
from synapse.http.server import DirectServeJsonResource, respond_with_json
from synapse.http.servlet import parse_string
logger = logging.getLogger(__name__)
class UploadResource(DirectServeResource):
class UploadResource(DirectServeJsonResource):
isLeaf = True
def __init__(self, hs, media_repo):
@@ -43,11 +37,9 @@ class UploadResource(DirectServeResource):
self.max_upload_size = hs.config.max_upload_size
self.clock = hs.get_clock()
def render_OPTIONS(self, request):
async def _async_render_OPTIONS(self, request):
respond_with_json(request, 200, {}, send_cors=True)
return NOT_DONE_YET
@wrap_json_request_handler
async def _async_render_POST(self, request):
requester = await self.auth.get_user_by_req(request)
# TODO: The checks here are a bit late. The content will have
+3 -4
View File
@@ -14,18 +14,17 @@
# limitations under the License.
import logging
from synapse.http.server import DirectServeResource, wrap_html_request_handler
from synapse.http.server import DirectServeHtmlResource
logger = logging.getLogger(__name__)
class OIDCCallbackResource(DirectServeResource):
class OIDCCallbackResource(DirectServeHtmlResource):
isLeaf = 1
def __init__(self, hs):
super().__init__()
self._oidc_handler = hs.get_oidc_handler()
@wrap_html_request_handler
async def _async_render_GET(self, request):
return await self._oidc_handler.handle_oidc_callback(request)
await self._oidc_handler.handle_oidc_callback(request)
+2 -2
View File
@@ -16,10 +16,10 @@
from twisted.python import failure
from synapse.api.errors import SynapseError
from synapse.http.server import DirectServeResource, return_html_error
from synapse.http.server import DirectServeHtmlResource, return_html_error
class SAML2ResponseResource(DirectServeResource):
class SAML2ResponseResource(DirectServeHtmlResource):
"""A Twisted web resource which handles the SAML response"""
isLeaf = 1
+1 -2
View File
@@ -19,7 +19,6 @@ Injectable secrets module for Synapse.
See https://docs.python.org/3/library/secrets.html#module-secrets for the API
used in Python 3.6, and the API emulated in Python 2.7.
"""
import sys
# secrets is available since python 3.6
@@ -31,8 +30,8 @@ if sys.version_info[0:2] >= (3, 6):
else:
import os
import binascii
import os
class Secrets(object):
def token_bytes(self, nbytes=32):
+2
View File
@@ -232,6 +232,8 @@ class HomeServer(object):
self._reactor = reactor
self.hostname = hostname
# the key we use to sign events and requests
self.signing_key = config.key.signing_key[0]
self.config = config
self._building = {}
self._listening_services = []
+35 -9
View File
@@ -16,10 +16,12 @@
import itertools
import logging
from typing import Any, Iterable, Optional, Tuple
from typing import Any, Iterable, List, Optional, Tuple
from synapse.api.constants import EventTypes
from synapse.replication.tcp.streams import BackfillStream, CachesStream
from synapse.replication.tcp.streams.events import (
EventsStream,
EventsStreamCurrentStateRow,
EventsStreamEventRow,
)
@@ -44,13 +46,30 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
async def get_all_updated_caches(
self, instance_name: str, last_id: int, current_id: int, limit: int
):
"""Fetches cache invalidation rows between the two given IDs written
by the given instance. Returns at most `limit` rows.
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
"""Get updates for caches replication stream.
Args:
instance_name: The writer we want to fetch updates from. Unused
here since there is only ever one writer.
last_id: The token to fetch updates from. Exclusive.
current_id: The token to fetch updates up to. Inclusive.
limit: The requested limit for the number of rows to return. The
function may return more or fewer rows.
Returns:
A tuple consisting of: the updates, a token to use to fetch
subsequent updates, and whether we returned fewer rows than exists
between the requested tokens due to the limit.
The token returned can be used in a subsequent call to this
function to get further updates.
The updates are a list of 2-tuples of stream ID and the row data
"""
if last_id == current_id:
return []
return [], current_id, False
def get_all_updated_caches_txn(txn):
# We purposefully don't bound by the current token, as we want to
@@ -64,17 +83,24 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
LIMIT ?
"""
txn.execute(sql, (last_id, instance_name, limit))
return txn.fetchall()
updates = [(row[0], row[1:]) for row in txn]
limited = False
upto_token = current_id
if len(updates) >= limit:
upto_token = updates[-1][0]
limited = True
return updates, upto_token, limited
return await self.db.runInteraction(
"get_all_updated_caches", get_all_updated_caches_txn
)
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "events":
if stream_name == EventsStream.NAME:
for row in rows:
self._process_event_stream_row(token, row)
elif stream_name == "backfill":
elif stream_name == BackfillStream.NAME:
for row in rows:
self._invalidate_caches_for_event(
-token,
@@ -86,7 +112,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
row.relates_to,
backfilled=True,
)
elif stream_name == "caches":
elif stream_name == CachesStream.NAME:
if self._cache_id_gen:
self._cache_id_gen.advance(instance_name, token)
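Every storage method converted in this diff now returns the same `(updates, upto_token, limited)` triple. The shared tail of those functions, factored out as a standalone sketch (the name `paginate_updates` is illustrative, not a Synapse API):

```python
from typing import List, Tuple


def paginate_updates(
    rows: List[tuple], current_id: int, limit: int
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
    """Convert raw (stream_id, *data) rows into the stream contract."""
    updates = [(row[0], tuple(row[1:])) for row in rows]
    limited = False
    upto_token = current_id
    if len(updates) >= limit:
        # We may have stopped short of current_id, so only advance the
        # token as far as the last row actually returned.
        upto_token = updates[-1][0]
        limited = True
    return updates, upto_token, limited
```

A caller loops, feeding `upto_token` back in as `last_id`, until `limited` comes back False.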
+38 -16
View File
@@ -14,6 +14,7 @@
# limitations under the License.
import logging
from typing import List, Tuple
from canonicaljson import json
@@ -207,31 +208,46 @@ class DeviceInboxWorkerStore(SQLBaseStore):
"delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn
)
def get_all_new_device_messages(self, last_pos, current_pos, limit):
"""
async def get_all_new_device_messages(
self, instance_name: str, last_id: int, current_id: int, limit: int
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
"""Get updates for to device replication stream.
Args:
last_pos(int):
current_pos(int):
limit(int):
instance_name: The writer we want to fetch updates from. Unused
here since there is only ever one writer.
last_id: The token to fetch updates from. Exclusive.
current_id: The token to fetch updates up to. Inclusive.
limit: The requested limit for the number of rows to return. The
function may return more or fewer rows.
Returns:
A deferred list of rows from the device inbox
A tuple consisting of: the updates, a token to use to fetch
subsequent updates, and whether we returned fewer rows than exists
between the requested tokens due to the limit.
The token returned can be used in a subsequent call to this
function to get further updates.
The updates are a list of 2-tuples of stream ID and the row data
"""
if last_pos == current_pos:
return defer.succeed([])
if last_id == current_id:
return [], current_id, False
def get_all_new_device_messages_txn(txn):
# We limit like this as we might have multiple rows per stream_id, and
# we want to make sure we always get all entries for any stream_id
# we return.
upper_pos = min(current_pos, last_pos + limit)
upper_pos = min(current_id, last_id + limit)
sql = (
"SELECT max(stream_id), user_id"
" FROM device_inbox"
" WHERE ? < stream_id AND stream_id <= ?"
" GROUP BY user_id"
)
txn.execute(sql, (last_pos, upper_pos))
rows = txn.fetchall()
txn.execute(sql, (last_id, upper_pos))
updates = [(row[0], row[1:]) for row in txn]
sql = (
"SELECT max(stream_id), destination"
@@ -239,15 +255,21 @@ class DeviceInboxWorkerStore(SQLBaseStore):
" WHERE ? < stream_id AND stream_id <= ?"
" GROUP BY destination"
)
txn.execute(sql, (last_pos, upper_pos))
rows.extend(txn)
txn.execute(sql, (last_id, upper_pos))
updates.extend((row[0], row[1:]) for row in txn)
# Order by ascending stream ordering
rows.sort()
updates.sort()
return rows
limited = False
upto_token = current_id
if len(updates) >= limit:
upto_token = updates[-1][0]
limited = True
return self.db.runInteraction(
return updates, upto_token, limited
return await self.db.runInteraction(
"get_all_new_device_messages", get_all_new_device_messages_txn
)
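Unlike the other streams, the device-inbox query bounds the stream-ID range itself (`upper_pos = min(current_id, last_id + limit)`) rather than relying on a row-count `LIMIT`, because several rows can share one stream ID and a row-count limit could split such a group across fetches. A toy illustration with hypothetical data:

```python
rows = [(1, "a"), (2, "b"), (2, "c"), (3, "d")]  # (stream_id, payload)
last_id, current_id, limit = 0, 3, 2

upper_pos = min(current_id, last_id + limit)  # -> 2
fetched = [r for r in rows if last_id < r[0] <= upper_pos]
# fetched == [(1, "a"), (2, "b"), (2, "c")]: both stream-ID-2 rows
# arrive together, which a bare "LIMIT 2" would not guarantee. This is
# why the docstring says the function may return more or fewer rows
# than the requested limit.
```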
+48 -22
View File
@@ -582,32 +582,58 @@ class DeviceWorkerStore(SQLBaseStore):
return set()
async def get_all_device_list_changes_for_remotes(
self, from_key: int, to_key: int, limit: int,
) -> List[Tuple[int, str]]:
"""Return a list of `(stream_id, entity)` which is the combined list of
changes to devices and which destinations need to be poked. Entity is
either a user ID (starting with '@') or a remote destination.
self, instance_name: str, last_id: int, current_id: int, limit: int
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
"""Get updates for device lists replication stream.
Args:
instance_name: The writer we want to fetch updates from. Unused
here since there is only ever one writer.
last_id: The token to fetch updates from. Exclusive.
current_id: The token to fetch updates up to. Inclusive.
limit: The requested limit for the number of rows to return. The
function may return more or fewer rows.
Returns:
A tuple consisting of: the updates, a token to use to fetch
subsequent updates, and whether we returned fewer rows than exists
between the requested tokens due to the limit.
The token returned can be used in a subsequent call to this
function to get further updates.
The updates are a list of 2-tuples of stream ID and the row data
"""
# This query Does The Right Thing where it'll correctly apply the
# bounds to the inner queries.
sql = """
SELECT stream_id, entity FROM (
SELECT stream_id, user_id AS entity FROM device_lists_stream
UNION ALL
SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
) AS e
WHERE ? < stream_id AND stream_id <= ?
LIMIT ?
"""
if last_id == current_id:
return [], current_id, False
return await self.db.execute(
def _get_all_device_list_changes_for_remotes(txn):
# This query Does The Right Thing where it'll correctly apply the
# bounds to the inner queries.
sql = """
SELECT stream_id, entity FROM (
SELECT stream_id, user_id AS entity FROM device_lists_stream
UNION ALL
SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
) AS e
WHERE ? < stream_id AND stream_id <= ?
LIMIT ?
"""
txn.execute(sql, (last_id, current_id, limit))
updates = [(row[0], row[1:]) for row in txn]
limited = False
upto_token = current_id
if len(updates) >= limit:
upto_token = updates[-1][0]
limited = True
return updates, upto_token, limited
return await self.db.runInteraction(
"get_all_device_list_changes_for_remotes",
None,
sql,
from_key,
to_key,
limit,
_get_all_device_list_changes_for_remotes,
)
@cached(max_entries=10000)
@@ -14,7 +14,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List
from typing import Dict, List, Tuple
from canonicaljson import encode_canonical_json, json
@@ -479,34 +479,61 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
return result
def get_all_user_signature_changes_for_remotes(self, from_key, to_key, limit):
"""Return a list of changes from the user signature stream to notify remotes.
async def get_all_user_signature_changes_for_remotes(
self, instance_name: str, last_id: int, current_id: int, limit: int
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
"""Get updates for groups replication stream.
Note that the user signature stream represents when a user signs their
device with their user-signing key, which is not published to other
users or servers, so no `destination` is needed in the returned
list. However, this is needed to poke workers.
Args:
from_key (int): the stream ID to start at (exclusive)
to_key (int): the stream ID to end at (inclusive)
instance_name: The writer we want to fetch updates from. Unused
here since there is only ever one writer.
last_id: The token to fetch updates from. Exclusive.
current_id: The token to fetch updates up to. Inclusive.
limit: The requested limit for the number of rows to return. The
function may return more or fewer rows.
Returns:
Deferred[list[(int,str)]] a list of `(stream_id, user_id)`
A tuple consisting of: the updates, a token to use to fetch
subsequent updates, and whether we returned fewer rows than exists
between the requested tokens due to the limit.
The token returned can be used in a subsequent call to this
function to get further updates.
The updates are a list of 2-tuples of stream ID and the row data
"""
sql = """
SELECT stream_id, from_user_id AS user_id
FROM user_signature_stream
WHERE ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC
LIMIT ?
"""
return self.db.execute(
if last_id == current_id:
return [], current_id, False
def _get_all_user_signature_changes_for_remotes_txn(txn):
sql = """
SELECT stream_id, from_user_id AS user_id
FROM user_signature_stream
WHERE ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC
LIMIT ?
"""
txn.execute(sql, (last_id, current_id, limit))
updates = [(row[0], (row[1:])) for row in txn]
limited = False
upto_token = current_id
if len(updates) >= limit:
upto_token = updates[-1][0]
limited = True
return updates, upto_token, limited
return await self.db.runInteraction(
"get_all_user_signature_changes_for_remotes",
None,
sql,
from_key,
to_key,
limit,
_get_all_user_signature_changes_for_remotes_txn,
)
+23 -96
View File
@@ -14,7 +14,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import logging
from collections import OrderedDict, namedtuple
@@ -28,12 +27,7 @@ from prometheus_client import Counter
from twisted.internet import defer
import synapse.metrics
from synapse.api.constants import (
EventContentFields,
EventTypes,
Membership,
RelationTypes,
)
from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
from synapse.api.room_versions import RoomVersions
from synapse.crypto.event_signing import compute_event_reference_hash
from synapse.events import EventBase # noqa: F401
@@ -48,8 +42,8 @@ from synapse.util.frozenutils import frozendict_json_encoder
from synapse.util.iterutils import batch_iter
if TYPE_CHECKING:
from synapse.storage.data_stores.main import DataStore
from synapse.server import HomeServer
from synapse.storage.data_stores.main import DataStore
logger = logging.getLogger(__name__)
@@ -820,7 +814,6 @@ class PersistEventsStore:
"event_reference_hashes",
"event_search",
"event_to_state_groups",
"local_invites",
"state_events",
"rejections",
"redactions",
@@ -1197,65 +1190,27 @@ class PersistEventsStore:
(event.state_key,),
)
# We update the local_invites table only if the event is "current",
# i.e., its something that has just happened. If the event is an
# outlier it is only current if its an "out of band membership",
# like a remote invite or a rejection of a remote invite.
is_new_state = not backfilled and (
not event.internal_metadata.is_outlier()
or event.internal_metadata.is_out_of_band_membership()
)
is_mine = self.is_mine_id(event.state_key)
if is_new_state and is_mine:
if event.membership == Membership.INVITE:
self.db.simple_insert_txn(
txn,
table="local_invites",
values={
"event_id": event.event_id,
"invitee": event.state_key,
"inviter": event.sender,
"room_id": event.room_id,
"stream_id": event.internal_metadata.stream_ordering,
},
)
else:
sql = (
"UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE"
" room_id = ? AND invitee = ? AND locally_rejected is NULL"
" AND replaced_by is NULL"
)
txn.execute(
sql,
(
event.internal_metadata.stream_ordering,
event.event_id,
event.room_id,
event.state_key,
),
)
# We also update the `local_current_membership` table with
# latest invite info. This will usually get updated by the
# `current_state_events` handling, unless its an outlier.
if event.internal_metadata.is_outlier():
# This should only happen for out of band memberships, so
# we add a paranoia check.
assert event.internal_metadata.is_out_of_band_membership()
self.db.simple_upsert_txn(
txn,
table="local_current_membership",
keyvalues={
"room_id": event.room_id,
"user_id": event.state_key,
},
values={
"event_id": event.event_id,
"membership": event.membership,
},
)
# We update the local_current_membership table only if the event is
# "current", i.e., it's something that has just happened.
#
# This will usually get updated by the `current_state_events` handling,
# unless it's an outlier, and an outlier is only "current" if it's an "out of
# band membership", like a remote invite or a rejection of a remote invite.
if (
self.is_mine_id(event.state_key)
and not backfilled
and event.internal_metadata.is_outlier()
and event.internal_metadata.is_out_of_band_membership()
):
self.db.simple_upsert_txn(
txn,
table="local_current_membership",
keyvalues={"room_id": event.room_id, "user_id": event.state_key},
values={
"event_id": event.event_id,
"membership": event.membership,
},
)
def _handle_event_relations(self, txn, event):
"""Handles inserting relation data during peristence of events
@@ -1586,31 +1541,3 @@ class PersistEventsStore:
if not ev.internal_metadata.is_outlier()
],
)
async def locally_reject_invite(self, user_id: str, room_id: str) -> int:
"""Mark the invite has having been rejected even though we failed to
create a leave event for it.
"""
sql = (
"UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
" room_id = ? AND invitee = ? AND locally_rejected is NULL"
" AND replaced_by is NULL"
)
def f(txn, stream_ordering):
txn.execute(sql, (stream_ordering, True, room_id, user_id))
# We also clear this entry from `local_current_membership`.
# Ideally we'd point to a leave event, but we don't have one, so
# nevermind.
self.db.simple_delete_txn(
txn,
table="local_current_membership",
keyvalues={"room_id": room_id, "user_id": user_id},
)
with self._stream_id_gen.get_next() as stream_ordering:
await self.db.runInteraction("locally_reject_invite", f, stream_ordering)
return stream_ordering
@@ -38,6 +38,8 @@ from synapse.events.utils import prune_event
from synapse.logging.context import PreserveLoggingContext, current_context
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import BackfillStream
from synapse.replication.tcp.streams.events import EventsStream
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.database import Database
from synapse.storage.util.id_generators import StreamIdGenerator
@@ -80,10 +82,7 @@ class EventsWorkerStore(SQLBaseStore):
# We are the process in charge of generating stream ids for events,
# so instantiate ID generators based on the database
self._stream_id_gen = StreamIdGenerator(
db_conn,
"events",
"stream_ordering",
extra_tables=[("local_invites", "stream_id")],
db_conn, "events", "stream_ordering",
)
self._backfill_id_gen = StreamIdGenerator(
db_conn,
@@ -113,9 +112,9 @@ class EventsWorkerStore(SQLBaseStore):
self._event_fetch_ongoing = 0
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == "events":
if stream_name == EventsStream.NAME:
self._stream_id_gen.advance(token)
elif stream_name == "backfill":
elif stream_name == BackfillStream.NAME:
self._backfill_id_gen.advance(-token)
super().process_replication_rows(stream_name, instance_name, token, rows)
@@ -14,6 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Tuple
from canonicaljson import json
from twisted.internet import defer
@@ -526,13 +528,35 @@ class GroupServerWorkerStore(SQLBaseStore):
"get_groups_changes_for_user", _get_groups_changes_for_user_txn
)
def get_all_groups_changes(self, from_token, to_token, limit):
from_token = int(from_token)
has_changed = self._group_updates_stream_cache.has_any_entity_changed(
from_token
)
async def get_all_groups_changes(
self, instance_name: str, last_id: int, current_id: int, limit: int
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
"""Get updates for groups replication stream.
Args:
instance_name: The writer we want to fetch updates from. Unused
here since there is only ever one writer.
last_id: The token to fetch updates from. Exclusive.
current_id: The token to fetch updates up to. Inclusive.
limit: The requested limit for the number of rows to return. The
function may return more or fewer rows.
Returns:
A tuple consisting of: the updates, a token to use to fetch
subsequent updates, and whether we returned fewer rows than exists
between the requested tokens due to the limit.
The token returned can be used in a subsequent call to this
function to get further updatees.
The updates are a list of 2-tuples of stream ID and the row data
"""
last_id = int(last_id)
has_changed = self._group_updates_stream_cache.has_any_entity_changed(last_id)
if not has_changed:
return defer.succeed([])
return [], current_id, False
def _get_all_groups_changes_txn(txn):
sql = """
@@ -541,13 +565,21 @@ class GroupServerWorkerStore(SQLBaseStore):
WHERE ? < stream_id AND stream_id <= ?
LIMIT ?
"""
txn.execute(sql, (from_token, to_token, limit))
return [
(stream_id, group_id, user_id, gtype, json.loads(content_json))
txn.execute(sql, (last_id, current_id, limit))
updates = [
(stream_id, (group_id, user_id, gtype, json.loads(content_json)))
for stream_id, group_id, user_id, gtype, content_json in txn
]
return self.db.runInteraction(
limited = False
upto_token = current_id
if len(updates) >= limit:
upto_token = updates[-1][0]
limited = True
return updates, upto_token, limited
return await self.db.runInteraction(
"get_all_groups_changes", _get_all_groups_changes_txn
)
@@ -361,7 +361,6 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
"event_push_summary",
"pusher_throttle",
"group_summary_rooms",
"local_invites",
"room_account_data",
"room_tags",
"local_current_membership",
+50 -58
View File
@@ -15,7 +15,7 @@
# limitations under the License.
import logging
from typing import Iterable, Iterator
from typing import Iterable, Iterator, List, Tuple
from canonicaljson import encode_canonical_json, json
@@ -98,77 +98,69 @@ class PusherWorkerStore(SQLBaseStore):
rows = yield self.db.runInteraction("get_all_pushers", get_pushers)
return rows
def get_all_updated_pushers(self, last_id, current_id, limit):
if last_id == current_id:
return defer.succeed(([], []))
async def get_all_updated_pushers_rows(
self, instance_name: str, last_id: int, current_id: int, limit: int
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
"""Get updates for pushers replication stream.
def get_all_updated_pushers_txn(txn):
sql = (
"SELECT id, user_name, access_token, profile_tag, kind,"
" app_id, app_display_name, device_display_name, pushkey, ts,"
" lang, data"
" FROM pushers"
" WHERE ? < id AND id <= ?"
" ORDER BY id ASC LIMIT ?"
)
txn.execute(sql, (last_id, current_id, limit))
updated = txn.fetchall()
sql = (
"SELECT stream_id, user_id, app_id, pushkey"
" FROM deleted_pushers"
" WHERE ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC LIMIT ?"
)
txn.execute(sql, (last_id, current_id, limit))
deleted = txn.fetchall()
return updated, deleted
return self.db.runInteraction(
"get_all_updated_pushers", get_all_updated_pushers_txn
)
def get_all_updated_pushers_rows(self, last_id, current_id, limit):
"""Get all the pushers that have changed between the given tokens.
Args:
instance_name: The writer we want to fetch updates from. Unused
here since there is only ever one writer.
last_id: The token to fetch updates from. Exclusive.
current_id: The token to fetch updates up to. Inclusive.
limit: The requested limit for the number of rows to return. The
function may return more or fewer rows.
Returns:
Deferred(list(tuple)): each tuple consists of:
stream_id (str)
user_id (str)
app_id (str)
pushkey (str)
was_deleted (bool): whether the pusher was added/updated (False)
or deleted (True)
A tuple consisting of: the updates, a token to use to fetch
subsequent updates, and whether we returned fewer rows than exists
between the requested tokens due to the limit.
The token returned can be used in a subsequent call to this
function to get further updates.
The updates are a list of 2-tuples of stream ID and the row data
"""
if last_id == current_id:
return defer.succeed([])
return [], current_id, False
def get_all_updated_pushers_rows_txn(txn):
sql = (
"SELECT id, user_name, app_id, pushkey"
" FROM pushers"
" WHERE ? < id AND id <= ?"
" ORDER BY id ASC LIMIT ?"
)
sql = """
SELECT id, user_name, app_id, pushkey
FROM pushers
WHERE ? < id AND id <= ?
ORDER BY id ASC LIMIT ?
"""
txn.execute(sql, (last_id, current_id, limit))
results = [list(row) + [False] for row in txn]
updates = [
(stream_id, (user_name, app_id, pushkey, False))
for stream_id, user_name, app_id, pushkey in txn
]
sql = (
"SELECT stream_id, user_id, app_id, pushkey"
" FROM deleted_pushers"
" WHERE ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC LIMIT ?"
)
sql = """
SELECT stream_id, user_id, app_id, pushkey
FROM deleted_pushers
WHERE ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC LIMIT ?
"""
txn.execute(sql, (last_id, current_id, limit))
updates.extend(
(stream_id, (user_name, app_id, pushkey, True))
for stream_id, user_name, app_id, pushkey in txn
)
results.extend(list(row) + [True] for row in txn)
results.sort() # Sort so that they're ordered by stream id
updates.sort() # Sort so that they're ordered by stream id
return results
limited = False
upper_bound = current_id
if len(updates) >= limit:
limited = True
upper_bound = updates[-1][0]
return self.db.runInteraction(
return updates, upper_bound, limited
return await self.db.runInteraction(
"get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn
)
+35 -6
View File
@@ -803,7 +803,32 @@ class RoomWorkerStore(SQLBaseStore):
return total_media_quarantined
def get_all_new_public_rooms(self, prev_id, current_id, limit):
async def get_all_new_public_rooms(
self, instance_name: str, last_id: int, current_id: int, limit: int
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
"""Get updates for public rooms replication stream.
Args:
instance_name: The writer we want to fetch updates from. Unused
here since there is only ever one writer.
last_id: The token to fetch updates from. Exclusive.
current_id: The token to fetch updates up to. Inclusive.
limit: The requested limit for the number of rows to return. The
function may return more or fewer rows.
Returns:
A tuple consisting of: the updates, a token to use to fetch
subsequent updates, and whether we returned fewer rows than exists
between the requested tokens due to the limit.
The token returned can be used in a subsequent call to this
function to get further updates.
The updates are a list of 2-tuples of stream ID and the row data
"""
if last_id == current_id:
return [], current_id, False
def get_all_new_public_rooms(txn):
sql = """
SELECT stream_id, room_id, visibility, appservice_id, network_id
@@ -813,13 +838,17 @@ class RoomWorkerStore(SQLBaseStore):
LIMIT ?
"""
txn.execute(sql, (prev_id, current_id, limit))
return txn.fetchall()
txn.execute(sql, (last_id, current_id, limit))
updates = [(row[0], row[1:]) for row in txn]
limited = False
upto_token = current_id
if len(updates) >= limit:
upto_token = updates[-1][0]
limited = True
if prev_id == current_id:
return defer.succeed([])
return updates, upto_token, limited
return self.db.runInteraction(
return await self.db.runInteraction(
"get_all_new_public_rooms", get_all_new_public_rooms
)
@@ -11,11 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import simplejson
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.storage.prepare_database import get_statements
@@ -66,7 +64,7 @@ def run_create(cur, database_engine, *args, **kwargs):
"max_stream_id_exclusive": max_stream_id + 1,
"rows_inserted": 0,
}
progress_json = simplejson.dumps(progress)
progress_json = json.dumps(progress)
sql = (
"INSERT into background_updates (update_name, progress_json)"
@@ -11,11 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import simplejson
from synapse.storage.prepare_database import get_statements
logger = logging.getLogger(__name__)
@@ -45,7 +43,7 @@ def run_create(cur, database_engine, *args, **kwargs):
"max_stream_id_exclusive": max_stream_id + 1,
"rows_inserted": 0,
}
progress_json = simplejson.dumps(progress)
progress_json = json.dumps(progress)
sql = (
"INSERT into background_updates (update_name, progress_json)"
@@ -11,11 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import simplejson
from synapse.storage.engines import PostgresEngine
from synapse.storage.prepare_database import get_statements
@@ -50,7 +48,7 @@ def run_create(cur, database_engine, *args, **kwargs):
"rows_inserted": 0,
"have_added_indexes": False,
}
progress_json = simplejson.dumps(progress)
progress_json = json.dumps(progress)
sql = (
"INSERT into background_updates (update_name, progress_json)"
@@ -11,11 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import simplejson
from synapse.storage.prepare_database import get_statements
logger = logging.getLogger(__name__)
@@ -45,7 +43,7 @@ def run_create(cur, database_engine, *args, **kwargs):
"max_stream_id_exclusive": max_stream_id + 1,
"rows_inserted": 0,
}
progress_json = simplejson.dumps(progress)
progress_json = json.dumps(progress)
sql = (
"INSERT into background_updates (update_name, progress_json)"
+33 -12
View File
@@ -15,6 +15,7 @@
# limitations under the License.
import logging
from typing import List, Tuple
from canonicaljson import json
@@ -53,18 +54,32 @@ class TagsWorkerStore(AccountDataWorkerStore):
return deferred
@defer.inlineCallbacks
def get_all_updated_tags(self, last_id, current_id, limit):
"""Get all the client tags that have changed on the server
async def get_all_updated_tags(
self, instance_name: str, last_id: int, current_id: int, limit: int
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
"""Get updates for tags replication stream.
Args:
last_id(int): The position to fetch from.
current_id(int): The position to fetch up to.
instance_name: The writer we want to fetch updates from. Unused
here since there is only ever one writer.
last_id: The token to fetch updates from. Exclusive.
current_id: The token to fetch updates up to. Inclusive.
limit: The requested limit for the number of rows to return. The
function may return more or fewer rows.
Returns:
A deferred list of tuples of stream_id int, user_id string,
room_id string, tag string and content string.
A tuple consisting of: the updates, a token to use to fetch
subsequent updates, and whether we returned fewer rows than exists
between the requested tokens due to the limit.
The token returned can be used in a subsequent call to this
function to get further updates.
The updates are a list of 2-tuples of stream ID and the row data
"""
if last_id == current_id:
return []
return [], current_id, False
def get_all_updated_tags_txn(txn):
sql = (
@@ -76,7 +91,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
txn.execute(sql, (last_id, current_id, limit))
return txn.fetchall()
tag_ids = yield self.db.runInteraction(
tag_ids = await self.db.runInteraction(
"get_all_updated_tags", get_all_updated_tags_txn
)
@@ -89,21 +104,27 @@ class TagsWorkerStore(AccountDataWorkerStore):
for tag, content in txn:
tags.append(json.dumps(tag) + ":" + content)
tag_json = "{" + ",".join(tags) + "}"
results.append((stream_id, user_id, room_id, tag_json))
results.append((stream_id, (user_id, room_id, tag_json)))
return results
batch_size = 50
results = []
for i in range(0, len(tag_ids), batch_size):
tags = yield self.db.runInteraction(
tags = await self.db.runInteraction(
"get_all_updated_tag_content",
get_tag_content,
tag_ids[i : i + batch_size],
)
results.extend(tags)
return results
limited = False
upto_token = current_id
if len(results) >= limit:
upto_token = results[-1][0]
limited = True
return results, upto_token, limited
@defer.inlineCallbacks
def get_updated_tags(self, user_id, stream_id):
+1 -1
View File
@@ -17,10 +17,10 @@ from typing import Any, Dict, Optional, Union
import attr
import synapse.util.stringutils as stringutils
from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore
from synapse.types import JsonDict
from synapse.util import stringutils as stringutils
@attr.s
+2 -2
View File
@@ -92,10 +92,10 @@ class PostgresEngine(BaseDatabaseEngine):
errors.append(" - 'COLLATE' is set to %r. Should be 'C'" % (collation,))
if ctype != "C":
errors.append(" - 'CTYPE' is set to %r. Should be 'C'" % (collation,))
errors.append(" - 'CTYPE' is set to %r. Should be 'C'" % (ctype,))
if errors:
raise IncorrectDatabaseSetup(
logger.warning(
"Database is incorrectly configured:\n\n%s\n\n"
"See docs/postgres.md for more information." % ("\n".join(errors))
)
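This hunk downgrades a non-C locale from a startup failure (`IncorrectDatabaseSetup`) to a warning, and fixes the copy-paste bug where the CTYPE message interpolated `collation` instead of `ctype`. The values being checked come from PostgreSQL's `pg_database` catalog; a sketch of such a check, assuming a psycopg2-style connection (the DSN is hypothetical):

```python
import psycopg2

conn = psycopg2.connect("dbname=synapse")  # hypothetical DSN
with conn.cursor() as cur:
    cur.execute(
        "SELECT datcollate, datctype FROM pg_database"
        " WHERE datname = current_database()"
    )
    collation, ctype = cur.fetchone()

# Synapse expects both to be 'C'; with this change anything else logs
# a warning instead of raising IncorrectDatabaseSetup at startup.
```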
-6
View File
@@ -783,9 +783,3 @@ class EventsPersistenceStorage(object):
for user_id in left_users:
await self.main_store.mark_remote_user_device_list_as_unsubscribed(user_id)
async def locally_reject_invite(self, user_id: str, room_id: str) -> int:
"""Mark the invite has having been rejected even though we failed to
create a leave event for it.
"""
return await self.persist_events_store.locally_reject_invite(user_id, room_id)
-2
View File
@@ -12,12 +12,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Iterable, Iterator, List, Tuple
from typing_extensions import Protocol
"""
Some very basic protocol definitions for the DB-API2 classes specified in PEP-249
"""

Some files were not shown because too many files have changed in this diff.