Compare commits
38 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 311c15dd4f | |||
| 6d174fda89 | |||
| 75da9f7a8e | |||
| 7078251569 | |||
| 29df3d0e9f | |||
| 8ccb7f08d9 | |||
| 43726783e4 | |||
| 38e1fac886 | |||
| 53ee214f2f | |||
| 8ca39bd2c3 | |||
| 08c5181a8d | |||
| 8fa7fdd4cb | |||
| 2ab0b021f1 | |||
| 67593b1728 | |||
| ef5ed5292b | |||
| e7efd8f827 | |||
| ff0680f69d | |||
| e0c0129693 | |||
| 59ddcd790b | |||
| 96bb01d8ec | |||
| 76dbd7b8d6 | |||
| 67d7756fcf | |||
| d378c3da78 | |||
| 2a266f4511 | |||
| 6d687ebba1 | |||
| 57feeab364 | |||
| 4e118742ca | |||
| 62b1ce8539 | |||
| 5cdca53aa0 | |||
| 21a212f8e5 | |||
| 8097659f6e | |||
| f3e0f16240 | |||
| 4d978d7db4 | |||
| e5808c4cfb | |||
| e866512367 | |||
| f01e2ca039 | |||
| a6eae69ffe | |||
| 244dbb04f7 |
+54
-3
@@ -1,9 +1,13 @@
|
||||
Synapse 1.17.0 (2020-07-13)
|
||||
===========================
|
||||
|
||||
Synapse 1.17.0 is identical to 1.17.0rc1, with the addition of the fix that was included in 1.16.1.
|
||||
|
||||
|
||||
Synapse 1.16.1 (2020-07-10)
|
||||
===========================
|
||||
|
||||
In some distributions of Synapse 1.16.0, we incorrectly included a database
|
||||
migration which added a new, unused table. This release removes the redundant
|
||||
table.
|
||||
In some distributions of Synapse 1.16.0, we incorrectly included a database migration which added a new, unused table. This release removes the redundant table.
|
||||
|
||||
Bugfixes
|
||||
--------
|
||||
@@ -11,6 +15,53 @@ Bugfixes
|
||||
- Drop table `local_rejections_stream` which was incorrectly added in Synapse 1.16.0. ([\#7816](https://github.com/matrix-org/synapse/issues/7816), [b1beb3ff5](https://github.com/matrix-org/synapse/commit/b1beb3ff5))
|
||||
|
||||
|
||||
Synapse 1.17.0rc1 (2020-07-09)
|
||||
==============================
|
||||
|
||||
Bugfixes
|
||||
--------
|
||||
|
||||
- Fix inconsistent handling of upper and lower case in email addresses when used as identifiers for login, etc. Contributed by @dklimpel. ([\#7021](https://github.com/matrix-org/synapse/issues/7021))
|
||||
- Fix "Tried to close a non-active scope!" error messages when opentracing is enabled. ([\#7732](https://github.com/matrix-org/synapse/issues/7732))
|
||||
- Fix incorrect error message when database CTYPE was set incorrectly. ([\#7760](https://github.com/matrix-org/synapse/issues/7760))
|
||||
- Fix to not ignore `set_tweak` actions in Push Rules that have no `value`, as permitted by the specification. ([\#7766](https://github.com/matrix-org/synapse/issues/7766))
|
||||
- Fix synctl to handle empty config files correctly. Contributed by @kotovalexarian. ([\#7779](https://github.com/matrix-org/synapse/issues/7779))
|
||||
- Fixes a long standing bug in worker mode where worker information was saved in the devices table instead of the original IP address and user agent. ([\#7797](https://github.com/matrix-org/synapse/issues/7797))
|
||||
- Fix 'stuck invites' which happen when we are unable to reject a room invite received over federation. ([\#7804](https://github.com/matrix-org/synapse/issues/7804), [\#7809](https://github.com/matrix-org/synapse/issues/7809), [\#7810](https://github.com/matrix-org/synapse/issues/7810))
|
||||
|
||||
|
||||
Updates to the Docker image
|
||||
---------------------------
|
||||
|
||||
- Include libwebp in the Docker file to properly handle webp image uploads. ([\#7791](https://github.com/matrix-org/synapse/issues/7791))
|
||||
|
||||
|
||||
Improved Documentation
|
||||
----------------------
|
||||
|
||||
- Improve the documentation of the non-standard JSON web token login type. ([\#7776](https://github.com/matrix-org/synapse/issues/7776))
|
||||
- Update doc links for caddy. Contributed by Nicolai Søborg. ([\#7789](https://github.com/matrix-org/synapse/issues/7789))
|
||||
|
||||
|
||||
Internal Changes
|
||||
----------------
|
||||
|
||||
- Refactor getting replication updates from database. ([\#7740](https://github.com/matrix-org/synapse/issues/7740))
|
||||
- Send push notifications with a high or low priority depending upon whether they may generate user-observable effects. ([\#7765](https://github.com/matrix-org/synapse/issues/7765))
|
||||
- Use symbolic names for replication stream names. ([\#7768](https://github.com/matrix-org/synapse/issues/7768))
|
||||
- Add early returns to `_check_for_soft_fail`. ([\#7769](https://github.com/matrix-org/synapse/issues/7769))
|
||||
- Fix up `synapse.handlers.federation` to pass mypy. ([\#7770](https://github.com/matrix-org/synapse/issues/7770))
|
||||
- Convert the appserver handler to async/await. ([\#7775](https://github.com/matrix-org/synapse/issues/7775))
|
||||
- Allow to use higher versions of prometheus_client <0.9.0 which are expected to introduce no breaking changes. Contributed by Oliver Kurz. ([\#7780](https://github.com/matrix-org/synapse/issues/7780))
|
||||
- Update linting scripts and codebase to be compatible with `isort` v5. ([\#7786](https://github.com/matrix-org/synapse/issues/7786))
|
||||
- Stop populating unused table `local_invites`. ([\#7793](https://github.com/matrix-org/synapse/issues/7793))
|
||||
- Ensure that strings (not bytes) are passed into JSON serialization. ([\#7799](https://github.com/matrix-org/synapse/issues/7799))
|
||||
- Switch from simplejson to the standard library json. ([\#7800](https://github.com/matrix-org/synapse/issues/7800))
|
||||
- Add `signing_key` property to `HomeServer` to save code duplication. ([\#7805](https://github.com/matrix-org/synapse/issues/7805))
|
||||
- Improve stacktraces from exceptions in background processes. ([\#7808](https://github.com/matrix-org/synapse/issues/7808))
|
||||
- Fix various spelling errors in comments and log lines. ([\#7811](https://github.com/matrix-org/synapse/issues/7811))
|
||||
|
||||
|
||||
Synapse 1.16.0 (2020-07-08)
|
||||
===========================
|
||||
|
||||
|
||||
+1
-1
@@ -215,7 +215,7 @@ Using a reverse proxy with Synapse
|
||||
It is recommended to put a reverse proxy such as
|
||||
`nginx <https://nginx.org/en/docs/http/ngx_http_proxy_module.html>`_,
|
||||
`Apache <https://httpd.apache.org/docs/current/mod/mod_proxy_http.html>`_,
|
||||
`Caddy <https://caddyserver.com/docs/proxy>`_ or
|
||||
`Caddy <https://caddyserver.com/docs/quick-starts/reverse-proxy>`_ or
|
||||
`HAProxy <https://www.haproxy.org/>`_ in front of Synapse. One advantage of
|
||||
doing so is that it means that you can expose the default https port (443) to
|
||||
Matrix clients without needing to run Synapse with root privileges.
|
||||
|
||||
Vendored
+12
@@ -1,9 +1,21 @@
|
||||
matrix-synapse-py3 (1.17.0) stable; urgency=medium
|
||||
|
||||
* New synapse release 1.17.0.
|
||||
|
||||
-- Synapse Packaging team <packages@matrix.org> Mon, 13 Jul 2020 10:20:31 +0100
|
||||
|
||||
matrix-synapse-py3 (1.16.1) stable; urgency=medium
|
||||
|
||||
* New synapse release 1.16.1.
|
||||
|
||||
-- Synapse Packaging team <packages@matrix.org> Fri, 10 Jul 2020 12:09:24 +0100
|
||||
|
||||
matrix-synapse-py3 (1.17.0rc1) stable; urgency=medium
|
||||
|
||||
* New synapse release 1.17.0rc1.
|
||||
|
||||
-- Synapse Packaging team <packages@matrix.org> Thu, 09 Jul 2020 16:53:12 +0100
|
||||
|
||||
matrix-synapse-py3 (1.16.0) stable; urgency=medium
|
||||
|
||||
* New synapse release 1.16.0.
|
||||
|
||||
@@ -24,6 +24,7 @@ RUN apk add \
|
||||
build-base \
|
||||
libffi-dev \
|
||||
libjpeg-turbo-dev \
|
||||
libwebp-dev \
|
||||
libressl-dev \
|
||||
libxslt-dev \
|
||||
linux-headers \
|
||||
@@ -61,6 +62,7 @@ FROM docker.io/python:${PYTHON_VERSION}-alpine3.11
|
||||
RUN apk add --no-cache --virtual .runtime_deps \
|
||||
libffi \
|
||||
libjpeg-turbo \
|
||||
libwebp \
|
||||
libressl \
|
||||
libxslt \
|
||||
libpq \
|
||||
|
||||
+90
@@ -0,0 +1,90 @@
|
||||
# JWT Login Type
|
||||
|
||||
Synapse comes with a non-standard login type to support
|
||||
[JSON Web Tokens](https://en.wikipedia.org/wiki/JSON_Web_Token). In general the
|
||||
documentation for
|
||||
[the login endpoint](https://matrix.org/docs/spec/client_server/r0.6.1#login)
|
||||
is still valid (and the mechanism works similarly to the
|
||||
[token based login](https://matrix.org/docs/spec/client_server/r0.6.1#token-based)).
|
||||
|
||||
To log in using a JSON Web Token, clients should submit a `/login` request as
|
||||
follows:
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "org.matrix.login.jwt",
|
||||
"token": "<jwt>"
|
||||
}
|
||||
```
|
||||
|
||||
Note that the login type of `m.login.jwt` is supported, but is deprecated. This
|
||||
will be removed in a future version of Synapse.
|
||||
|
||||
The `jwt` should encode the local part of the user ID as the standard `sub`
|
||||
claim. In the case that the token is not valid, the homeserver must respond with
|
||||
`401 Unauthorized` and an error code of `M_UNAUTHORIZED`.
|
||||
|
||||
(Note that this differs from the token based logins which return a
|
||||
`403 Forbidden` and an error code of `M_FORBIDDEN` if an error occurs.)
|
||||
|
||||
As with other login types, there are additional fields (e.g. `device_id` and
|
||||
`initial_device_display_name`) which can be included in the above request.
|
||||
|
||||
## Preparing Synapse
|
||||
|
||||
The JSON Web Token integration in Synapse uses the
|
||||
[`PyJWT`](https://pypi.org/project/pyjwt/) library, which must be installed
|
||||
as follows:
|
||||
|
||||
* The relevant libraries are included in the Docker images and Debian packages
|
||||
provided by `matrix.org` so no further action is needed.
|
||||
|
||||
* If you installed Synapse into a virtualenv, run `/path/to/env/bin/pip
|
||||
install synapse[pyjwt]` to install the necessary dependencies.
|
||||
|
||||
* For other installation mechanisms, see the documentation provided by the
|
||||
maintainer.
|
||||
|
||||
To enable the JSON web token integration, you should then add an `jwt_config` section
|
||||
to your configuration file (or uncomment the `enabled: true` line in the
|
||||
existing section). See [sample_config.yaml](./sample_config.yaml) for some
|
||||
sample settings.
|
||||
|
||||
## How to test JWT as a developer
|
||||
|
||||
Although JSON Web Tokens are typically generated from an external server, the
|
||||
examples below use [PyJWT](https://pyjwt.readthedocs.io/en/latest/) directly.
|
||||
|
||||
1. Configure Synapse with JWT logins:
|
||||
|
||||
```yaml
|
||||
jwt_config:
|
||||
enabled: true
|
||||
secret: "my-secret-token"
|
||||
algorithm: "HS256"
|
||||
```
|
||||
2. Generate a JSON web token:
|
||||
|
||||
```bash
|
||||
$ pyjwt --key=my-secret-token --alg=HS256 encode sub=test-user
|
||||
eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJ0ZXN0LXVzZXIifQ.Ag71GT8v01UO3w80aqRPTeuVPBIBZkYhNTJJ-_-zQIc
|
||||
```
|
||||
3. Query for the login types and ensure `org.matrix.login.jwt` is there:
|
||||
|
||||
```bash
|
||||
curl http://localhost:8080/_matrix/client/r0/login
|
||||
```
|
||||
4. Login used the generated JSON web token from above:
|
||||
|
||||
```bash
|
||||
$ curl http://localhost:8082/_matrix/client/r0/login -X POST \
|
||||
--data '{"type":"org.matrix.login.jwt","token":"eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJ0ZXN0LXVzZXIifQ.Ag71GT8v01UO3w80aqRPTeuVPBIBZkYhNTJJ-_-zQIc"}'
|
||||
{
|
||||
"access_token": "<access token>",
|
||||
"device_id": "ACBDEFGHI",
|
||||
"home_server": "localhost:8080",
|
||||
"user_id": "@test-user:localhost:8480"
|
||||
}
|
||||
```
|
||||
|
||||
You should now be able to use the returned access token to query the client API.
|
||||
@@ -3,7 +3,7 @@
|
||||
It is recommended to put a reverse proxy such as
|
||||
[nginx](https://nginx.org/en/docs/http/ngx_http_proxy_module.html),
|
||||
[Apache](https://httpd.apache.org/docs/current/mod/mod_proxy_http.html),
|
||||
[Caddy](https://caddyserver.com/docs/proxy) or
|
||||
[Caddy](https://caddyserver.com/docs/quick-starts/reverse-proxy) or
|
||||
[HAProxy](https://www.haproxy.org/) in front of Synapse. One advantage
|
||||
of doing so is that it means that you can expose the default https port
|
||||
(443) to Matrix clients without needing to run Synapse with root
|
||||
|
||||
+31
-4
@@ -1804,12 +1804,39 @@ sso:
|
||||
#template_dir: "res/templates"
|
||||
|
||||
|
||||
# The JWT needs to contain a globally unique "sub" (subject) claim.
|
||||
# JSON web token integration. The following settings can be used to make
|
||||
# Synapse JSON web tokens for authentication, instead of its internal
|
||||
# password database.
|
||||
#
|
||||
# Each JSON Web Token needs to contain a "sub" (subject) claim, which is
|
||||
# used as the localpart of the mxid.
|
||||
#
|
||||
# Note that this is a non-standard login type and client support is
|
||||
# expected to be non-existant.
|
||||
#
|
||||
# See https://github.com/matrix-org/synapse/blob/master/docs/jwt.md.
|
||||
#
|
||||
#jwt_config:
|
||||
# enabled: true
|
||||
# secret: "a secret"
|
||||
# algorithm: "HS256"
|
||||
# Uncomment the following to enable authorization using JSON web
|
||||
# tokens. Defaults to false.
|
||||
#
|
||||
#enabled: true
|
||||
|
||||
# This is either the private shared secret or the public key used to
|
||||
# decode the contents of the JSON web token.
|
||||
#
|
||||
# Required if 'enabled' is true.
|
||||
#
|
||||
#secret: "provided-by-your-issuer"
|
||||
|
||||
# The algorithm used to sign the JSON web token.
|
||||
#
|
||||
# Supported algorithms are listed at
|
||||
# https://pyjwt.readthedocs.io/en/latest/algorithms.html
|
||||
#
|
||||
# Required if 'enabled' is true.
|
||||
#
|
||||
#algorithm: "provided-by-your-issuer"
|
||||
|
||||
|
||||
password_config:
|
||||
|
||||
@@ -2,9 +2,9 @@ import argparse
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
import urllib2
|
||||
|
||||
import dns.resolver
|
||||
import urllib2
|
||||
from signedjson.key import decode_verify_key_bytes, write_signing_keys
|
||||
from signedjson.sign import verify_signed_json
|
||||
from unpaddedbase64 import decode_base64
|
||||
|
||||
+1
-1
@@ -15,7 +15,7 @@ else
|
||||
fi
|
||||
|
||||
echo "Linting these locations: $files"
|
||||
isort -y -rc $files
|
||||
isort $files
|
||||
python3 -m black $files
|
||||
./scripts-dev/config-lint.sh
|
||||
flake8 $files
|
||||
|
||||
@@ -26,7 +26,6 @@ ignore=W503,W504,E203,E731,E501
|
||||
|
||||
[isort]
|
||||
line_length = 88
|
||||
not_skip = __init__.py
|
||||
sections=FUTURE,STDLIB,COMPAT,THIRDPARTY,TWISTED,FIRSTPARTY,TESTS,LOCALFOLDER
|
||||
default_section=THIRDPARTY
|
||||
known_first_party = synapse
|
||||
|
||||
+1
-1
@@ -36,7 +36,7 @@ try:
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
__version__ = "1.16.1"
|
||||
__version__ = "1.17.0"
|
||||
|
||||
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
|
||||
# We import here so that we don't have to install a bunch of deps when
|
||||
|
||||
+2
-3
@@ -12,7 +12,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
@@ -22,7 +21,6 @@ from netaddr import IPAddress
|
||||
from twisted.internet import defer
|
||||
from twisted.web.server import Request
|
||||
|
||||
import synapse.logging.opentracing as opentracing
|
||||
import synapse.types
|
||||
from synapse import event_auth
|
||||
from synapse.api.auth_blocking import AuthBlocking
|
||||
@@ -35,6 +33,7 @@ from synapse.api.errors import (
|
||||
)
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
||||
from synapse.events import EventBase
|
||||
from synapse.logging import opentracing as opentracing
|
||||
from synapse.types import StateMap, UserID
|
||||
from synapse.util.caches import register_cache
|
||||
from synapse.util.caches.lrucache import LruCache
|
||||
@@ -538,7 +537,7 @@ class Auth(object):
|
||||
# Currently we ignore the `for_verification` flag even though there are
|
||||
# some situations where we can drop particular auth events when adding
|
||||
# to the event's `auth_events` (e.g. joins pointing to previous joins
|
||||
# when room is publically joinable). Dropping event IDs has the
|
||||
# when room is publicly joinable). Dropping event IDs has the
|
||||
# advantage that the auth chain for the room grows slower, but we use
|
||||
# the auth chain in state resolution v2 to order events, which means
|
||||
# care must be taken if dropping events to ensure that it doesn't
|
||||
|
||||
@@ -21,7 +21,7 @@ from typing import Dict, Iterable, Optional, Set
|
||||
|
||||
from typing_extensions import ContextManager
|
||||
|
||||
from twisted.internet import defer, reactor
|
||||
from twisted.internet import address, defer, reactor
|
||||
|
||||
import synapse
|
||||
import synapse.events
|
||||
@@ -206,10 +206,30 @@ class KeyUploadServlet(RestServlet):
|
||||
|
||||
if body:
|
||||
# They're actually trying to upload something, proxy to main synapse.
|
||||
# Pass through the auth headers, if any, in case the access token
|
||||
# is there.
|
||||
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization", [])
|
||||
headers = {"Authorization": auth_headers}
|
||||
|
||||
# Proxy headers from the original request, such as the auth headers
|
||||
# (in case the access token is there) and the original IP /
|
||||
# User-Agent of the request.
|
||||
headers = {
|
||||
header: request.requestHeaders.getRawHeaders(header, [])
|
||||
for header in (b"Authorization", b"User-Agent")
|
||||
}
|
||||
# Add the previous hop the the X-Forwarded-For header.
|
||||
x_forwarded_for = request.requestHeaders.getRawHeaders(
|
||||
b"X-Forwarded-For", []
|
||||
)
|
||||
if isinstance(request.client, (address.IPv4Address, address.IPv6Address)):
|
||||
previous_host = request.client.host.encode("ascii")
|
||||
# If the header exists, add to the comma-separated list of the first
|
||||
# instance of the header. Otherwise, generate a new header.
|
||||
if x_forwarded_for:
|
||||
x_forwarded_for = [
|
||||
x_forwarded_for[0] + b", " + previous_host
|
||||
] + x_forwarded_for[1:]
|
||||
else:
|
||||
x_forwarded_for = [previous_host]
|
||||
headers[b"X-Forwarded-For"] = x_forwarded_for
|
||||
|
||||
try:
|
||||
result = await self.http_client.post_json_get_json(
|
||||
self.main_uri + request.uri.decode("ascii"), body, headers=headers
|
||||
|
||||
@@ -98,7 +98,6 @@ class ApplicationServiceApi(SimpleHttpClient):
|
||||
if service.url is None:
|
||||
return False
|
||||
uri = service.url + ("/users/%s" % urllib.parse.quote(user_id))
|
||||
response = None
|
||||
try:
|
||||
response = yield self.get_json(uri, {"access_token": service.hs_token})
|
||||
if response is not None: # just an empty json object
|
||||
|
||||
@@ -16,6 +16,7 @@ from synapse.config._base import ConfigError
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
from synapse.config.homeserver import HomeServerConfig
|
||||
|
||||
action = sys.argv[1]
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
# This file can't be called email.py because if it is, we cannot:
|
||||
@@ -73,7 +72,7 @@ class EmailConfig(Config):
|
||||
|
||||
template_dir = email_config.get("template_dir")
|
||||
# we need an absolute path, because we change directory after starting (and
|
||||
# we don't yet know what auxilliary templates like mail.css we will need).
|
||||
# we don't yet know what auxiliary templates like mail.css we will need).
|
||||
# (Note that loading as package_resources with jinja.PackageLoader doesn't
|
||||
# work for the same reason.)
|
||||
if not template_dir:
|
||||
@@ -145,8 +144,8 @@ class EmailConfig(Config):
|
||||
or self.threepid_behaviour_email == ThreepidBehaviour.LOCAL
|
||||
):
|
||||
# make sure we can import the required deps
|
||||
import jinja2
|
||||
import bleach
|
||||
import jinja2
|
||||
|
||||
# prevent unused warnings
|
||||
jinja2
|
||||
|
||||
@@ -45,10 +45,37 @@ class JWTConfig(Config):
|
||||
|
||||
def generate_config_section(self, **kwargs):
|
||||
return """\
|
||||
# The JWT needs to contain a globally unique "sub" (subject) claim.
|
||||
# JSON web token integration. The following settings can be used to make
|
||||
# Synapse JSON web tokens for authentication, instead of its internal
|
||||
# password database.
|
||||
#
|
||||
# Each JSON Web Token needs to contain a "sub" (subject) claim, which is
|
||||
# used as the localpart of the mxid.
|
||||
#
|
||||
# Note that this is a non-standard login type and client support is
|
||||
# expected to be non-existant.
|
||||
#
|
||||
# See https://github.com/matrix-org/synapse/blob/master/docs/jwt.md.
|
||||
#
|
||||
#jwt_config:
|
||||
# enabled: true
|
||||
# secret: "a secret"
|
||||
# algorithm: "HS256"
|
||||
# Uncomment the following to enable authorization using JSON web
|
||||
# tokens. Defaults to false.
|
||||
#
|
||||
#enabled: true
|
||||
|
||||
# This is either the private shared secret or the public key used to
|
||||
# decode the contents of the JSON web token.
|
||||
#
|
||||
# Required if 'enabled' is true.
|
||||
#
|
||||
#secret: "provided-by-your-issuer"
|
||||
|
||||
# The algorithm used to sign the JSON web token.
|
||||
#
|
||||
# Supported algorithms are listed at
|
||||
# https://pyjwt.readthedocs.io/en/latest/algorithms.html
|
||||
#
|
||||
# Required if 'enabled' is true.
|
||||
#
|
||||
#algorithm: "provided-by-your-issuer"
|
||||
"""
|
||||
|
||||
@@ -162,7 +162,7 @@ class EventBuilderFactory(object):
|
||||
def __init__(self, hs):
|
||||
self.clock = hs.get_clock()
|
||||
self.hostname = hs.hostname
|
||||
self.signing_key = hs.config.signing_key[0]
|
||||
self.signing_key = hs.signing_key
|
||||
|
||||
self.store = hs.get_datastore()
|
||||
self.state = hs.get_state_handler()
|
||||
|
||||
@@ -87,7 +87,7 @@ class FederationClient(FederationBase):
|
||||
self.transport_layer = hs.get_federation_transport_client()
|
||||
|
||||
self.hostname = hs.hostname
|
||||
self.signing_key = hs.config.signing_key[0]
|
||||
self.signing_key = hs.signing_key
|
||||
|
||||
self._get_pdu_cache = ExpiringCache(
|
||||
cache_name="get_pdu_cache",
|
||||
@@ -245,7 +245,7 @@ class FederationClient(FederationBase):
|
||||
event_id: event to fetch
|
||||
room_version: version of the room
|
||||
outlier: Indicates whether the PDU is an `outlier`, i.e. if
|
||||
it's from an arbitary point in the context as opposed to part
|
||||
it's from an arbitrary point in the context as opposed to part
|
||||
of the current block of PDUs. Defaults to `False`
|
||||
timeout: How long to try (in ms) each destination for before
|
||||
moving to the next destination. None indicates no timeout.
|
||||
@@ -351,7 +351,7 @@ class FederationClient(FederationBase):
|
||||
outlier: bool = False,
|
||||
include_none: bool = False,
|
||||
) -> List[EventBase]:
|
||||
"""Takes a list of PDUs and checks the signatures and hashs of each
|
||||
"""Takes a list of PDUs and checks the signatures and hashes of each
|
||||
one. If a PDU fails its signature check then we check if we have it in
|
||||
the database and if not then request if from the originating server of
|
||||
that PDU.
|
||||
|
||||
@@ -95,6 +95,9 @@ class FederationServer(FederationBase):
|
||||
# We cache responses to state queries, as they take a while and often
|
||||
# come in waves.
|
||||
self._state_resp_cache = ResponseCache(hs, "state_resp", timeout_ms=30000)
|
||||
self._state_ids_resp_cache = ResponseCache(
|
||||
hs, "state_ids_resp", timeout_ms=30000
|
||||
)
|
||||
|
||||
async def on_backfill_request(
|
||||
self, origin: str, room_id: str, versions: List[str], limit: int
|
||||
@@ -362,10 +365,16 @@ class FederationServer(FederationBase):
|
||||
if not in_room:
|
||||
raise AuthError(403, "Host not in room.")
|
||||
|
||||
resp = await self._state_ids_resp_cache.wrap(
|
||||
(room_id, event_id), self._on_state_ids_request_compute, room_id, event_id,
|
||||
)
|
||||
|
||||
return 200, resp
|
||||
|
||||
async def _on_state_ids_request_compute(self, room_id, event_id):
|
||||
state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id)
|
||||
auth_chain_ids = await self.store.get_auth_chain_ids(state_ids)
|
||||
|
||||
return 200, {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}
|
||||
return {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}
|
||||
|
||||
async def _on_context_state_request_compute(
|
||||
self, room_id: str, event_id: str
|
||||
@@ -717,7 +726,7 @@ def server_matches_acl_event(server_name: str, acl_event: EventBase) -> bool:
|
||||
# server name is a literal IP
|
||||
allow_ip_literals = acl_event.content.get("allow_ip_literals", True)
|
||||
if not isinstance(allow_ip_literals, bool):
|
||||
logger.warning("Ignorning non-bool allow_ip_literals flag")
|
||||
logger.warning("Ignoring non-bool allow_ip_literals flag")
|
||||
allow_ip_literals = True
|
||||
if not allow_ip_literals:
|
||||
# check for ipv6 literals. These start with '['.
|
||||
@@ -731,7 +740,7 @@ def server_matches_acl_event(server_name: str, acl_event: EventBase) -> bool:
|
||||
# next, check the deny list
|
||||
deny = acl_event.content.get("deny", [])
|
||||
if not isinstance(deny, (list, tuple)):
|
||||
logger.warning("Ignorning non-list deny ACL %s", deny)
|
||||
logger.warning("Ignoring non-list deny ACL %s", deny)
|
||||
deny = []
|
||||
for e in deny:
|
||||
if _acl_entry_matches(server_name, e):
|
||||
@@ -741,7 +750,7 @@ def server_matches_acl_event(server_name: str, acl_event: EventBase) -> bool:
|
||||
# then the allow list.
|
||||
allow = acl_event.content.get("allow", [])
|
||||
if not isinstance(allow, (list, tuple)):
|
||||
logger.warning("Ignorning non-list allow ACL %s", allow)
|
||||
logger.warning("Ignoring non-list allow ACL %s", allow)
|
||||
allow = []
|
||||
for e in allow:
|
||||
if _acl_entry_matches(server_name, e):
|
||||
|
||||
@@ -359,7 +359,7 @@ class BaseFederationRow(object):
|
||||
Specifies how to identify, serialize and deserialize the different types.
|
||||
"""
|
||||
|
||||
TypeId = "" # Unique string that ids the type. Must be overriden in sub classes.
|
||||
TypeId = "" # Unique string that ids the type. Must be overridden in sub classes.
|
||||
|
||||
@staticmethod
|
||||
def from_data(data):
|
||||
|
||||
@@ -119,7 +119,7 @@ class PerDestinationQueue(object):
|
||||
)
|
||||
|
||||
def send_pdu(self, pdu: EventBase, order: int) -> None:
|
||||
"""Add a PDU to the queue, and start the transmission loop if neccessary
|
||||
"""Add a PDU to the queue, and start the transmission loop if necessary
|
||||
|
||||
Args:
|
||||
pdu: pdu to send
|
||||
@@ -129,7 +129,7 @@ class PerDestinationQueue(object):
|
||||
self.attempt_new_transaction()
|
||||
|
||||
def send_presence(self, states: Iterable[UserPresenceState]) -> None:
|
||||
"""Add presence updates to the queue. Start the transmission loop if neccessary.
|
||||
"""Add presence updates to the queue. Start the transmission loop if necessary.
|
||||
|
||||
Args:
|
||||
states: presence to send
|
||||
|
||||
@@ -746,7 +746,7 @@ class TransportLayerClient(object):
|
||||
def remove_user_from_group(
|
||||
self, destination, group_id, requester_user_id, user_id, content
|
||||
):
|
||||
"""Remove a user fron a group
|
||||
"""Remove a user from a group
|
||||
"""
|
||||
path = _create_v1_path("/groups/%s/users/%s/remove", group_id, user_id)
|
||||
|
||||
|
||||
@@ -109,7 +109,7 @@ class Authenticator(object):
|
||||
self.server_name = hs.hostname
|
||||
self.store = hs.get_datastore()
|
||||
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
|
||||
self.notifer = hs.get_notifier()
|
||||
self.notifier = hs.get_notifier()
|
||||
|
||||
self.replication_client = None
|
||||
if hs.config.worker.worker_app:
|
||||
@@ -175,7 +175,7 @@ class Authenticator(object):
|
||||
await self.store.set_destination_retry_timings(origin, None, 0, 0)
|
||||
|
||||
# Inform the relevant places that the remote server is back up.
|
||||
self.notifer.notify_remote_server_up(origin)
|
||||
self.notifier.notify_remote_server_up(origin)
|
||||
if self.replication_client:
|
||||
# If we're on a worker we try and inform master about this. The
|
||||
# replication client doesn't hook into the notifier to avoid
|
||||
@@ -340,6 +340,12 @@ class BaseFederationServlet(object):
|
||||
if origin:
|
||||
with ratelimiter.ratelimit(origin) as d:
|
||||
await d
|
||||
if request._disconnected:
|
||||
logger.warning(
|
||||
"client disconnected before we started processing "
|
||||
"request"
|
||||
)
|
||||
return -1, None
|
||||
response = await func(
|
||||
origin, content, request.args, *args, **kwargs
|
||||
)
|
||||
@@ -361,11 +367,7 @@ class BaseFederationServlet(object):
|
||||
continue
|
||||
|
||||
server.register_paths(
|
||||
method,
|
||||
(pattern,),
|
||||
self._wrap(code),
|
||||
self.__class__.__name__,
|
||||
trace=False,
|
||||
method, (pattern,), self._wrap(code), self.__class__.__name__,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -70,7 +70,7 @@ class GroupAttestationSigning(object):
|
||||
self.keyring = hs.get_keyring()
|
||||
self.clock = hs.get_clock()
|
||||
self.server_name = hs.hostname
|
||||
self.signing_key = hs.config.signing_key[0]
|
||||
self.signing_key = hs.signing_key
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def verify_attestation(self, attestation, group_id, user_id, server_name=None):
|
||||
|
||||
@@ -41,7 +41,7 @@ class GroupsServerWorkerHandler(object):
|
||||
self.clock = hs.get_clock()
|
||||
self.keyring = hs.get_keyring()
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
self.signing_key = hs.config.signing_key[0]
|
||||
self.signing_key = hs.signing_key
|
||||
self.server_name = hs.hostname
|
||||
self.attestations = hs.get_groups_attestation_signing()
|
||||
self.transport_client = hs.get_federation_transport_client()
|
||||
|
||||
@@ -48,8 +48,7 @@ class ApplicationServicesHandler(object):
|
||||
self.current_max = 0
|
||||
self.is_processing = False
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def notify_interested_services(self, current_id):
|
||||
async def notify_interested_services(self, current_id):
|
||||
"""Notifies (pushes) all application services interested in this event.
|
||||
|
||||
Pushing is done asynchronously, so this method won't block for any
|
||||
@@ -74,7 +73,7 @@ class ApplicationServicesHandler(object):
|
||||
(
|
||||
upper_bound,
|
||||
events,
|
||||
) = yield self.store.get_new_events_for_appservice(
|
||||
) = await self.store.get_new_events_for_appservice(
|
||||
self.current_max, limit
|
||||
)
|
||||
|
||||
@@ -85,10 +84,9 @@ class ApplicationServicesHandler(object):
|
||||
for event in events:
|
||||
events_by_room.setdefault(event.room_id, []).append(event)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def handle_event(event):
|
||||
async def handle_event(event):
|
||||
# Gather interested services
|
||||
services = yield self._get_services_for_event(event)
|
||||
services = await self._get_services_for_event(event)
|
||||
if len(services) == 0:
|
||||
return # no services need notifying
|
||||
|
||||
@@ -96,9 +94,9 @@ class ApplicationServicesHandler(object):
|
||||
# query API for all services which match that user regex.
|
||||
# This needs to block as these user queries need to be
|
||||
# made BEFORE pushing the event.
|
||||
yield self._check_user_exists(event.sender)
|
||||
await self._check_user_exists(event.sender)
|
||||
if event.type == EventTypes.Member:
|
||||
yield self._check_user_exists(event.state_key)
|
||||
await self._check_user_exists(event.state_key)
|
||||
|
||||
if not self.started_scheduler:
|
||||
|
||||
@@ -115,17 +113,16 @@ class ApplicationServicesHandler(object):
|
||||
self.scheduler.submit_event_for_as(service, event)
|
||||
|
||||
now = self.clock.time_msec()
|
||||
ts = yield self.store.get_received_ts(event.event_id)
|
||||
ts = await self.store.get_received_ts(event.event_id)
|
||||
synapse.metrics.event_processing_lag_by_event.labels(
|
||||
"appservice_sender"
|
||||
).observe((now - ts) / 1000)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def handle_room_events(events):
|
||||
async def handle_room_events(events):
|
||||
for event in events:
|
||||
yield handle_event(event)
|
||||
await handle_event(event)
|
||||
|
||||
yield make_deferred_yieldable(
|
||||
await make_deferred_yieldable(
|
||||
defer.gatherResults(
|
||||
[
|
||||
run_in_background(handle_room_events, evs)
|
||||
@@ -135,10 +132,10 @@ class ApplicationServicesHandler(object):
|
||||
)
|
||||
)
|
||||
|
||||
yield self.store.set_appservice_last_pos(upper_bound)
|
||||
await self.store.set_appservice_last_pos(upper_bound)
|
||||
|
||||
now = self.clock.time_msec()
|
||||
ts = yield self.store.get_received_ts(events[-1].event_id)
|
||||
ts = await self.store.get_received_ts(events[-1].event_id)
|
||||
|
||||
synapse.metrics.event_processing_positions.labels(
|
||||
"appservice_sender"
|
||||
@@ -161,8 +158,7 @@ class ApplicationServicesHandler(object):
|
||||
finally:
|
||||
self.is_processing = False
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def query_user_exists(self, user_id):
|
||||
async def query_user_exists(self, user_id):
|
||||
"""Check if any application service knows this user_id exists.
|
||||
|
||||
Args:
|
||||
@@ -170,15 +166,14 @@ class ApplicationServicesHandler(object):
|
||||
Returns:
|
||||
True if this user exists on at least one application service.
|
||||
"""
|
||||
user_query_services = yield self._get_services_for_user(user_id=user_id)
|
||||
user_query_services = self._get_services_for_user(user_id=user_id)
|
||||
for user_service in user_query_services:
|
||||
is_known_user = yield self.appservice_api.query_user(user_service, user_id)
|
||||
is_known_user = await self.appservice_api.query_user(user_service, user_id)
|
||||
if is_known_user:
|
||||
return True
|
||||
return False
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def query_room_alias_exists(self, room_alias):
|
||||
async def query_room_alias_exists(self, room_alias):
|
||||
"""Check if an application service knows this room alias exists.
|
||||
|
||||
Args:
|
||||
@@ -193,19 +188,18 @@ class ApplicationServicesHandler(object):
|
||||
s for s in services if (s.is_interested_in_alias(room_alias_str))
|
||||
]
|
||||
for alias_service in alias_query_services:
|
||||
is_known_alias = yield self.appservice_api.query_alias(
|
||||
is_known_alias = await self.appservice_api.query_alias(
|
||||
alias_service, room_alias_str
|
||||
)
|
||||
if is_known_alias:
|
||||
# the alias exists now so don't query more ASes.
|
||||
result = yield self.store.get_association_from_room_alias(room_alias)
|
||||
result = await self.store.get_association_from_room_alias(room_alias)
|
||||
return result
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def query_3pe(self, kind, protocol, fields):
|
||||
services = yield self._get_services_for_3pn(protocol)
|
||||
async def query_3pe(self, kind, protocol, fields):
|
||||
services = self._get_services_for_3pn(protocol)
|
||||
|
||||
results = yield make_deferred_yieldable(
|
||||
results = await make_deferred_yieldable(
|
||||
defer.DeferredList(
|
||||
[
|
||||
run_in_background(
|
||||
@@ -224,8 +218,7 @@ class ApplicationServicesHandler(object):
|
||||
|
||||
return ret
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_3pe_protocols(self, only_protocol=None):
|
||||
async def get_3pe_protocols(self, only_protocol=None):
|
||||
services = self.store.get_app_services()
|
||||
protocols = {}
|
||||
|
||||
@@ -238,7 +231,7 @@ class ApplicationServicesHandler(object):
|
||||
if p not in protocols:
|
||||
protocols[p] = []
|
||||
|
||||
info = yield self.appservice_api.get_3pe_protocol(s, p)
|
||||
info = await self.appservice_api.get_3pe_protocol(s, p)
|
||||
|
||||
if info is not None:
|
||||
protocols[p].append(info)
|
||||
@@ -263,8 +256,7 @@ class ApplicationServicesHandler(object):
|
||||
|
||||
return protocols
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_services_for_event(self, event):
|
||||
async def _get_services_for_event(self, event):
|
||||
"""Retrieve a list of application services interested in this event.
|
||||
|
||||
Args:
|
||||
@@ -280,7 +272,7 @@ class ApplicationServicesHandler(object):
|
||||
# inside of a list comprehension anymore.
|
||||
interested_list = []
|
||||
for s in services:
|
||||
if (yield s.is_interested(event, self.store)):
|
||||
if await s.is_interested(event, self.store):
|
||||
interested_list.append(s)
|
||||
|
||||
return interested_list
|
||||
@@ -288,21 +280,20 @@ class ApplicationServicesHandler(object):
|
||||
def _get_services_for_user(self, user_id):
|
||||
services = self.store.get_app_services()
|
||||
interested_list = [s for s in services if (s.is_interested_in_user(user_id))]
|
||||
return defer.succeed(interested_list)
|
||||
return interested_list
|
||||
|
||||
def _get_services_for_3pn(self, protocol):
|
||||
services = self.store.get_app_services()
|
||||
interested_list = [s for s in services if s.is_interested_in_protocol(protocol)]
|
||||
return defer.succeed(interested_list)
|
||||
return interested_list
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _is_unknown_user(self, user_id):
|
||||
async def _is_unknown_user(self, user_id):
|
||||
if not self.is_mine_id(user_id):
|
||||
# we don't know if they are unknown or not since it isn't one of our
|
||||
# users. We can't poke ASes.
|
||||
return False
|
||||
|
||||
user_info = yield self.store.get_user_by_id(user_id)
|
||||
user_info = await self.store.get_user_by_id(user_id)
|
||||
if user_info:
|
||||
return False
|
||||
|
||||
@@ -311,10 +302,9 @@ class ApplicationServicesHandler(object):
|
||||
service_list = [s for s in services if s.sender == user_id]
|
||||
return len(service_list) == 0
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _check_user_exists(self, user_id):
|
||||
unknown_user = yield self._is_unknown_user(user_id)
|
||||
async def _check_user_exists(self, user_id):
|
||||
unknown_user = await self._is_unknown_user(user_id)
|
||||
if unknown_user:
|
||||
exists = yield self.query_user_exists(user_id)
|
||||
exists = await self.query_user_exists(user_id)
|
||||
return exists
|
||||
return True
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import time
|
||||
import unicodedata
|
||||
@@ -24,7 +23,6 @@ import attr
|
||||
import bcrypt # type: ignore[import]
|
||||
import pymacaroons
|
||||
|
||||
import synapse.util.stringutils as stringutils
|
||||
from synapse.api.constants import LoginType
|
||||
from synapse.api.errors import (
|
||||
AuthError,
|
||||
@@ -45,6 +43,8 @@ from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.module_api import ModuleApi
|
||||
from synapse.push.mailer import load_jinja2_templates
|
||||
from synapse.types import Requester, UserID
|
||||
from synapse.util import stringutils as stringutils
|
||||
from synapse.util.threepids import canonicalise_email
|
||||
|
||||
from ._base import BaseHandler
|
||||
|
||||
@@ -928,7 +928,7 @@ class AuthHandler(BaseHandler):
|
||||
# for the presence of an email address during password reset was
|
||||
# case sensitive).
|
||||
if medium == "email":
|
||||
address = address.lower()
|
||||
address = canonicalise_email(address)
|
||||
|
||||
await self.store.user_add_threepid(
|
||||
user_id, medium, address, validated_at, self.hs.get_clock().time_msec()
|
||||
@@ -956,7 +956,7 @@ class AuthHandler(BaseHandler):
|
||||
|
||||
# 'Canonicalise' email addresses as per above
|
||||
if medium == "email":
|
||||
address = address.lower()
|
||||
address = canonicalise_email(address)
|
||||
|
||||
identity_handler = self.hs.get_handlers().identity_handler
|
||||
result = await identity_handler.try_unbind_threepid(
|
||||
|
||||
@@ -12,11 +12,10 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import urllib
|
||||
import xml.etree.ElementTree as ET
|
||||
from typing import Dict, Optional, Tuple
|
||||
from xml.etree import ElementTree as ET
|
||||
|
||||
from twisted.web.client import PartialDownloadError
|
||||
|
||||
|
||||
@@ -19,8 +19,9 @@
|
||||
|
||||
import itertools
|
||||
import logging
|
||||
from collections import Container
|
||||
from http import HTTPStatus
|
||||
from typing import Dict, Iterable, List, Optional, Sequence, Tuple
|
||||
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import attr
|
||||
from signedjson.key import decode_verify_key_bytes
|
||||
@@ -742,6 +743,9 @@ class FederationHandler(BaseHandler):
|
||||
# device and recognize the algorithm then we can work out the
|
||||
# exact key to expect. Otherwise check it matches any key we
|
||||
# have for that device.
|
||||
|
||||
current_keys = [] # type: Container[str]
|
||||
|
||||
if device:
|
||||
keys = device.get("keys", {}).get("keys", {})
|
||||
|
||||
@@ -758,15 +762,15 @@ class FederationHandler(BaseHandler):
|
||||
current_keys = keys.values()
|
||||
elif device_id:
|
||||
# We don't have any keys for the device ID.
|
||||
current_keys = []
|
||||
pass
|
||||
else:
|
||||
# The event didn't include a device ID, so we just look for
|
||||
# keys across all devices.
|
||||
current_keys = (
|
||||
current_keys = [
|
||||
key
|
||||
for device in cached_devices
|
||||
for key in device.get("keys", {}).get("keys", {}).values()
|
||||
)
|
||||
]
|
||||
|
||||
# We now check that the sender key matches (one of) the expected
|
||||
# keys.
|
||||
@@ -1011,7 +1015,7 @@ class FederationHandler(BaseHandler):
|
||||
if e_type == EventTypes.Member and event.membership == Membership.JOIN
|
||||
]
|
||||
|
||||
joined_domains = {}
|
||||
joined_domains = {} # type: Dict[str, int]
|
||||
for u, d in joined_users:
|
||||
try:
|
||||
dom = get_domain_from_id(u)
|
||||
@@ -1277,14 +1281,15 @@ class FederationHandler(BaseHandler):
|
||||
try:
|
||||
# Try the host we successfully got a response to /make_join/
|
||||
# request first.
|
||||
host_list = list(target_hosts)
|
||||
try:
|
||||
target_hosts.remove(origin)
|
||||
target_hosts.insert(0, origin)
|
||||
host_list.remove(origin)
|
||||
host_list.insert(0, origin)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
ret = await self.federation_client.send_join(
|
||||
target_hosts, event, room_version_obj
|
||||
host_list, event, room_version_obj
|
||||
)
|
||||
|
||||
origin = ret["origin"]
|
||||
@@ -1562,7 +1567,7 @@ class FederationHandler(BaseHandler):
|
||||
room_version,
|
||||
event.get_pdu_json(),
|
||||
self.hs.hostname,
|
||||
self.hs.config.signing_key[0],
|
||||
self.hs.signing_key,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1584,13 +1589,14 @@ class FederationHandler(BaseHandler):
|
||||
|
||||
# Try the host that we succesfully called /make_leave/ on first for
|
||||
# the /send_leave/ request.
|
||||
host_list = list(target_hosts)
|
||||
try:
|
||||
target_hosts.remove(origin)
|
||||
target_hosts.insert(0, origin)
|
||||
host_list.remove(origin)
|
||||
host_list.insert(0, origin)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
await self.federation_client.send_leave(target_hosts, event)
|
||||
await self.federation_client.send_leave(host_list, event)
|
||||
|
||||
context = await self.state_handler.compute_event_context(event)
|
||||
stream_id = await self.persist_events_and_notify([(event, context)])
|
||||
@@ -1604,7 +1610,7 @@ class FederationHandler(BaseHandler):
|
||||
user_id: str,
|
||||
membership: str,
|
||||
content: JsonDict = {},
|
||||
params: Optional[Dict[str, str]] = None,
|
||||
params: Optional[Dict[str, Union[str, Iterable[str]]]] = None,
|
||||
) -> Tuple[str, EventBase, RoomVersion]:
|
||||
(
|
||||
origin,
|
||||
@@ -2018,8 +2024,8 @@ class FederationHandler(BaseHandler):
|
||||
auth_events_ids = await self.auth.compute_auth_events(
|
||||
event, prev_state_ids, for_verification=True
|
||||
)
|
||||
auth_events = await self.store.get_events(auth_events_ids)
|
||||
auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
|
||||
auth_events_x = await self.store.get_events(auth_events_ids)
|
||||
auth_events = {(e.type, e.state_key): e for e in auth_events_x.values()}
|
||||
|
||||
# This is a hack to fix some old rooms where the initial join event
|
||||
# didn't reference the create event in its auth events.
|
||||
@@ -2055,76 +2061,67 @@ class FederationHandler(BaseHandler):
|
||||
# For new (non-backfilled and non-outlier) events we check if the event
|
||||
# passes auth based on the current state. If it doesn't then we
|
||||
# "soft-fail" the event.
|
||||
do_soft_fail_check = not backfilled and not event.internal_metadata.is_outlier()
|
||||
if do_soft_fail_check:
|
||||
extrem_ids = await self.store.get_latest_event_ids_in_room(event.room_id)
|
||||
if backfilled or event.internal_metadata.is_outlier():
|
||||
return
|
||||
|
||||
extrem_ids = set(extrem_ids)
|
||||
prev_event_ids = set(event.prev_event_ids())
|
||||
extrem_ids = await self.store.get_latest_event_ids_in_room(event.room_id)
|
||||
extrem_ids = set(extrem_ids)
|
||||
prev_event_ids = set(event.prev_event_ids())
|
||||
|
||||
if extrem_ids == prev_event_ids:
|
||||
# If they're the same then the current state is the same as the
|
||||
# state at the event, so no point rechecking auth for soft fail.
|
||||
do_soft_fail_check = False
|
||||
if extrem_ids == prev_event_ids:
|
||||
# If they're the same then the current state is the same as the
|
||||
# state at the event, so no point rechecking auth for soft fail.
|
||||
return
|
||||
|
||||
if do_soft_fail_check:
|
||||
room_version = await self.store.get_room_version_id(event.room_id)
|
||||
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
|
||||
room_version = await self.store.get_room_version_id(event.room_id)
|
||||
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
|
||||
|
||||
# Calculate the "current state".
|
||||
if state is not None:
|
||||
# If we're explicitly given the state then we won't have all the
|
||||
# prev events, and so we have a gap in the graph. In this case
|
||||
# we want to be a little careful as we might have been down for
|
||||
# a while and have an incorrect view of the current state,
|
||||
# however we still want to do checks as gaps are easy to
|
||||
# maliciously manufacture.
|
||||
#
|
||||
# So we use a "current state" that is actually a state
|
||||
# resolution across the current forward extremities and the
|
||||
# given state at the event. This should correctly handle cases
|
||||
# like bans, especially with state res v2.
|
||||
# Calculate the "current state".
|
||||
if state is not None:
|
||||
# If we're explicitly given the state then we won't have all the
|
||||
# prev events, and so we have a gap in the graph. In this case
|
||||
# we want to be a little careful as we might have been down for
|
||||
# a while and have an incorrect view of the current state,
|
||||
# however we still want to do checks as gaps are easy to
|
||||
# maliciously manufacture.
|
||||
#
|
||||
# So we use a "current state" that is actually a state
|
||||
# resolution across the current forward extremities and the
|
||||
# given state at the event. This should correctly handle cases
|
||||
# like bans, especially with state res v2.
|
||||
|
||||
state_sets = await self.state_store.get_state_groups(
|
||||
event.room_id, extrem_ids
|
||||
)
|
||||
state_sets = list(state_sets.values())
|
||||
state_sets.append(state)
|
||||
current_state_ids = await self.state_handler.resolve_events(
|
||||
room_version, state_sets, event
|
||||
)
|
||||
current_state_ids = {
|
||||
k: e.event_id for k, e in current_state_ids.items()
|
||||
}
|
||||
else:
|
||||
current_state_ids = await self.state_handler.get_current_state_ids(
|
||||
event.room_id, latest_event_ids=extrem_ids
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"Doing soft-fail check for %s: state %s",
|
||||
event.event_id,
|
||||
current_state_ids,
|
||||
state_sets = await self.state_store.get_state_groups(
|
||||
event.room_id, extrem_ids
|
||||
)
|
||||
state_sets = list(state_sets.values())
|
||||
state_sets.append(state)
|
||||
current_state_ids = await self.state_handler.resolve_events(
|
||||
room_version, state_sets, event
|
||||
)
|
||||
current_state_ids = {k: e.event_id for k, e in current_state_ids.items()}
|
||||
else:
|
||||
current_state_ids = await self.state_handler.get_current_state_ids(
|
||||
event.room_id, latest_event_ids=extrem_ids
|
||||
)
|
||||
|
||||
# Now check if event pass auth against said current state
|
||||
auth_types = auth_types_for_event(event)
|
||||
current_state_ids = [
|
||||
e for k, e in current_state_ids.items() if k in auth_types
|
||||
]
|
||||
logger.debug(
|
||||
"Doing soft-fail check for %s: state %s", event.event_id, current_state_ids,
|
||||
)
|
||||
|
||||
current_auth_events = await self.store.get_events(current_state_ids)
|
||||
current_auth_events = {
|
||||
(e.type, e.state_key): e for e in current_auth_events.values()
|
||||
}
|
||||
# Now check if event pass auth against said current state
|
||||
auth_types = auth_types_for_event(event)
|
||||
current_state_ids = [e for k, e in current_state_ids.items() if k in auth_types]
|
||||
|
||||
try:
|
||||
event_auth.check(
|
||||
room_version_obj, event, auth_events=current_auth_events
|
||||
)
|
||||
except AuthError as e:
|
||||
logger.warning("Soft-failing %r because %s", event, e)
|
||||
event.internal_metadata.soft_failed = True
|
||||
current_auth_events = await self.store.get_events(current_state_ids)
|
||||
current_auth_events = {
|
||||
(e.type, e.state_key): e for e in current_auth_events.values()
|
||||
}
|
||||
|
||||
try:
|
||||
event_auth.check(room_version_obj, event, auth_events=current_auth_events)
|
||||
except AuthError as e:
|
||||
logger.warning("Soft-failing %r because %s", event, e)
|
||||
event.internal_metadata.soft_failed = True
|
||||
|
||||
async def on_query_auth(
|
||||
self, origin, event_id, room_id, remote_auth_chain, rejects, missing
|
||||
@@ -2293,10 +2290,10 @@ class FederationHandler(BaseHandler):
|
||||
remote_auth_chain = await self.federation_client.get_event_auth(
|
||||
origin, event.room_id, event.event_id
|
||||
)
|
||||
except RequestSendFailed as e:
|
||||
except RequestSendFailed as e1:
|
||||
# The other side isn't around or doesn't implement the
|
||||
# endpoint, so lets just bail out.
|
||||
logger.info("Failed to get event auth from remote: %s", e)
|
||||
logger.info("Failed to get event auth from remote: %s", e1)
|
||||
return context
|
||||
|
||||
seen_remotes = await self.store.have_seen_events(
|
||||
@@ -2774,7 +2771,8 @@ class FederationHandler(BaseHandler):
|
||||
|
||||
logger.debug("Checking auth on event %r", event.content)
|
||||
|
||||
last_exception = None
|
||||
last_exception = None # type: Optional[Exception]
|
||||
|
||||
# for each public key in the 3pid invite event
|
||||
for public_key_object in self.hs.get_auth().get_public_keys(invite_event):
|
||||
try:
|
||||
@@ -2828,6 +2826,12 @@ class FederationHandler(BaseHandler):
|
||||
return
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
|
||||
if last_exception is None:
|
||||
# we can only get here if get_public_keys() returned an empty list
|
||||
# TODO: make this better
|
||||
raise RuntimeError("no public key in invite event")
|
||||
|
||||
raise last_exception
|
||||
|
||||
async def _check_key_revocation(self, public_key, url):
|
||||
|
||||
@@ -70,7 +70,7 @@ class GroupsLocalWorkerHandler(object):
|
||||
self.clock = hs.get_clock()
|
||||
self.keyring = hs.get_keyring()
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
self.signing_key = hs.config.signing_key[0]
|
||||
self.signing_key = hs.signing_key
|
||||
self.server_name = hs.hostname
|
||||
self.notifier = hs.get_notifier()
|
||||
self.attestations = hs.get_groups_attestation_signing()
|
||||
|
||||
@@ -251,10 +251,10 @@ class IdentityHandler(BaseHandler):
|
||||
# 'browser-like' HTTPS.
|
||||
auth_headers = self.federation_http_client.build_auth_headers(
|
||||
destination=None,
|
||||
method="POST",
|
||||
method=b"POST",
|
||||
url_bytes=url_bytes,
|
||||
content=content,
|
||||
destination_is=id_server,
|
||||
destination_is=id_server.encode("ascii"),
|
||||
)
|
||||
headers = {b"Authorization": auth_headers}
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from typing import Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Optional, Tuple
|
||||
|
||||
from canonicaljson import encode_canonical_json, json
|
||||
|
||||
@@ -55,6 +55,9 @@ from synapse.visibility import filter_events_for_client
|
||||
|
||||
from ._base import BaseHandler
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -349,7 +352,7 @@ _DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY = 7 * 24 * 60 * 60 * 1000
|
||||
|
||||
|
||||
class EventCreationHandler(object):
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastore()
|
||||
@@ -814,11 +817,17 @@ class EventCreationHandler(object):
|
||||
403, "This event is not allowed in this context", Codes.FORBIDDEN
|
||||
)
|
||||
|
||||
try:
|
||||
await self.auth.check_from_context(room_version, event, context)
|
||||
except AuthError as err:
|
||||
logger.warning("Denying new event %r because %s", event, err)
|
||||
raise err
|
||||
if event.internal_metadata.is_out_of_band_membership():
|
||||
# the only sort of out-of-band-membership events we expect to see here
|
||||
# are invite rejections we have generated ourselves.
|
||||
assert event.type == EventTypes.Member
|
||||
assert event.content["membership"] == Membership.LEAVE
|
||||
else:
|
||||
try:
|
||||
await self.auth.check_from_context(room_version, event, context)
|
||||
except AuthError as err:
|
||||
logger.warning("Denying new event %r because %s", event, err)
|
||||
raise err
|
||||
|
||||
# Ensure that we can round trip before trying to persist in db
|
||||
try:
|
||||
|
||||
+137
-65
@@ -1,7 +1,5 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2016 OpenMarket Ltd
|
||||
# Copyright 2018 New Vector Ltd
|
||||
# Copyright 2019 The Matrix.org Foundation C.I.C.
|
||||
# Copyright 2016-2020 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -18,17 +16,21 @@
|
||||
import abc
|
||||
import logging
|
||||
from http import HTTPStatus
|
||||
from typing import Dict, Iterable, List, Optional, Tuple
|
||||
from typing import Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
from unpaddedbase64 import encode_base64
|
||||
|
||||
from synapse import types
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
|
||||
from synapse.api.errors import AuthError, Codes, SynapseError
|
||||
from synapse.api.room_versions import EventFormatVersions
|
||||
from synapse.crypto.event_signing import compute_event_reference_hash
|
||||
from synapse.events import EventBase
|
||||
from synapse.events.builder import create_local_event_from_event_dict
|
||||
from synapse.events.snapshot import EventContext
|
||||
from synapse.replication.http.membership import (
|
||||
ReplicationLocallyRejectInviteRestServlet,
|
||||
)
|
||||
from synapse.types import Collection, Requester, RoomAlias, RoomID, UserID
|
||||
from synapse.events.validator import EventValidator
|
||||
from synapse.storage.roommember import RoomsForUser
|
||||
from synapse.types import Collection, JsonDict, Requester, RoomAlias, RoomID, UserID
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
from synapse.util.distributor import user_joined_room, user_left_room
|
||||
|
||||
@@ -74,10 +76,6 @@ class RoomMemberHandler(object):
|
||||
)
|
||||
if self._is_on_event_persistence_instance:
|
||||
self.persist_event_storage = hs.get_storage().persistence
|
||||
else:
|
||||
self._locally_reject_client = ReplicationLocallyRejectInviteRestServlet.make_client(
|
||||
hs
|
||||
)
|
||||
|
||||
# This is only used to get at ratelimit function, and
|
||||
# maybe_kick_guest_users. It's fine there are multiple of these as
|
||||
@@ -105,46 +103,28 @@ class RoomMemberHandler(object):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
async def _remote_reject_invite(
|
||||
async def remote_reject_invite(
|
||||
self,
|
||||
invite_event_id: str,
|
||||
txn_id: Optional[str],
|
||||
requester: Requester,
|
||||
remote_room_hosts: List[str],
|
||||
room_id: str,
|
||||
target: UserID,
|
||||
content: dict,
|
||||
) -> Tuple[Optional[str], int]:
|
||||
"""Attempt to reject an invite for a room this server is not in. If we
|
||||
fail to do so we locally mark the invite as rejected.
|
||||
content: JsonDict,
|
||||
) -> Tuple[str, int]:
|
||||
"""
|
||||
Rejects an out-of-band invite we have received from a remote server
|
||||
|
||||
Args:
|
||||
requester
|
||||
remote_room_hosts: List of servers to use to try and reject invite
|
||||
room_id
|
||||
target: The user rejecting the invite
|
||||
content: The content for the rejection event
|
||||
invite_event_id: ID of the invite to be rejected
|
||||
txn_id: optional transaction ID supplied by the client
|
||||
requester: user making the rejection request, according to the access token
|
||||
content: additional content to include in the rejection event.
|
||||
Normally an empty dict.
|
||||
|
||||
Returns:
|
||||
A dictionary to be returned to the client, may
|
||||
include event_id etc, or nothing if we locally rejected
|
||||
event id, stream_id of the leave event
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
async def locally_reject_invite(self, user_id: str, room_id: str) -> int:
|
||||
"""Mark the invite has having been rejected even though we failed to
|
||||
create a leave event for it.
|
||||
"""
|
||||
if self._is_on_event_persistence_instance:
|
||||
return await self.persist_event_storage.locally_reject_invite(
|
||||
user_id, room_id
|
||||
)
|
||||
else:
|
||||
result = await self._locally_reject_client(
|
||||
instance_name=self._event_stream_writer_instance,
|
||||
user_id=user_id,
|
||||
room_id=room_id,
|
||||
)
|
||||
return result["stream_id"]
|
||||
|
||||
@abc.abstractmethod
|
||||
async def _user_joined_room(self, target: UserID, room_id: str) -> None:
|
||||
"""Notifies distributor on master process that the user has joined the
|
||||
@@ -288,7 +268,7 @@ class RoomMemberHandler(object):
|
||||
ratelimit: bool = True,
|
||||
content: Optional[dict] = None,
|
||||
require_consent: bool = True,
|
||||
) -> Tuple[Optional[str], int]:
|
||||
) -> Tuple[str, int]:
|
||||
key = (room_id,)
|
||||
|
||||
with (await self.member_linearizer.queue(key)):
|
||||
@@ -319,7 +299,7 @@ class RoomMemberHandler(object):
|
||||
ratelimit: bool = True,
|
||||
content: Optional[dict] = None,
|
||||
require_consent: bool = True,
|
||||
) -> Tuple[Optional[str], int]:
|
||||
) -> Tuple[str, int]:
|
||||
content_specified = bool(content)
|
||||
if content is None:
|
||||
content = {}
|
||||
@@ -485,11 +465,17 @@ class RoomMemberHandler(object):
|
||||
elif effective_membership_state == Membership.LEAVE:
|
||||
if not is_host_in_room:
|
||||
# perhaps we've been invited
|
||||
inviter = await self._get_inviter(target.to_string(), room_id)
|
||||
if not inviter:
|
||||
invite = await self.store.get_invite_for_local_user_in_room(
|
||||
user_id=target.to_string(), room_id=room_id
|
||||
) # type: Optional[RoomsForUser]
|
||||
if not invite:
|
||||
raise SynapseError(404, "Not a known room")
|
||||
|
||||
if self.hs.is_mine(inviter):
|
||||
logger.info(
|
||||
"%s rejects invite to %s from %s", target, room_id, invite.sender
|
||||
)
|
||||
|
||||
if self.hs.is_mine_id(invite.sender):
|
||||
# the inviter was on our server, but has now left. Carry on
|
||||
# with the normal rejection codepath.
|
||||
#
|
||||
@@ -497,10 +483,10 @@ class RoomMemberHandler(object):
|
||||
# active on other servers.
|
||||
pass
|
||||
else:
|
||||
# send the rejection to the inviter's HS.
|
||||
remote_room_hosts = remote_room_hosts + [inviter.domain]
|
||||
return await self._remote_reject_invite(
|
||||
requester, remote_room_hosts, room_id, target, content,
|
||||
# send the rejection to the inviter's HS (with fallback to
|
||||
# local event)
|
||||
return await self.remote_reject_invite(
|
||||
invite.event_id, txn_id, requester, content,
|
||||
)
|
||||
|
||||
return await self._local_membership_update(
|
||||
@@ -1014,33 +1000,119 @@ class RoomMemberMasterHandler(RoomMemberHandler):
|
||||
|
||||
return event_id, stream_id
|
||||
|
||||
async def _remote_reject_invite(
|
||||
async def remote_reject_invite(
|
||||
self,
|
||||
invite_event_id: str,
|
||||
txn_id: Optional[str],
|
||||
requester: Requester,
|
||||
remote_room_hosts: List[str],
|
||||
room_id: str,
|
||||
target: UserID,
|
||||
content: dict,
|
||||
) -> Tuple[Optional[str], int]:
|
||||
"""Implements RoomMemberHandler._remote_reject_invite
|
||||
content: JsonDict,
|
||||
) -> Tuple[str, int]:
|
||||
"""
|
||||
Rejects an out-of-band invite received from a remote user
|
||||
|
||||
Implements RoomMemberHandler.remote_reject_invite
|
||||
"""
|
||||
invite_event = await self.store.get_event(invite_event_id)
|
||||
room_id = invite_event.room_id
|
||||
target_user = invite_event.state_key
|
||||
|
||||
# first of all, try doing a rejection via the inviting server
|
||||
fed_handler = self.federation_handler
|
||||
try:
|
||||
inviter_id = UserID.from_string(invite_event.sender)
|
||||
event, stream_id = await fed_handler.do_remotely_reject_invite(
|
||||
remote_room_hosts, room_id, target.to_string(), content=content,
|
||||
[inviter_id.domain], room_id, target_user, content=content
|
||||
)
|
||||
return event.event_id, stream_id
|
||||
except Exception as e:
|
||||
# if we were unable to reject the exception, just mark
|
||||
# it as rejected on our end and plough ahead.
|
||||
# if we were unable to reject the invite, we will generate our own
|
||||
# leave event.
|
||||
#
|
||||
# The 'except' clause is very broad, but we need to
|
||||
# capture everything from DNS failures upwards
|
||||
#
|
||||
logger.warning("Failed to reject invite: %s", e)
|
||||
|
||||
stream_id = await self.locally_reject_invite(target.to_string(), room_id)
|
||||
return None, stream_id
|
||||
return await self._locally_reject_invite(
|
||||
invite_event, txn_id, requester, content
|
||||
)
|
||||
|
||||
async def _locally_reject_invite(
|
||||
self,
|
||||
invite_event: EventBase,
|
||||
txn_id: Optional[str],
|
||||
requester: Requester,
|
||||
content: JsonDict,
|
||||
) -> Tuple[str, int]:
|
||||
"""Generate a local invite rejection
|
||||
|
||||
This is called after we fail to reject an invite via a remote server. It
|
||||
generates an out-of-band membership event locally.
|
||||
|
||||
Args:
|
||||
invite_event: the invite to be rejected
|
||||
txn_id: optional transaction ID supplied by the client
|
||||
requester: user making the rejection request, according to the access token
|
||||
content: additional content to include in the rejection event.
|
||||
Normally an empty dict.
|
||||
"""
|
||||
|
||||
room_id = invite_event.room_id
|
||||
target_user = invite_event.state_key
|
||||
room_version = await self.store.get_room_version(room_id)
|
||||
|
||||
content["membership"] = Membership.LEAVE
|
||||
|
||||
# the auth events for the new event are the same as that of the invite, plus
|
||||
# the invite itself.
|
||||
#
|
||||
# the prev_events are just the invite.
|
||||
invite_hash = invite_event.event_id # type: Union[str, Tuple]
|
||||
if room_version.event_format == EventFormatVersions.V1:
|
||||
alg, h = compute_event_reference_hash(invite_event)
|
||||
invite_hash = (invite_event.event_id, {alg: encode_base64(h)})
|
||||
|
||||
auth_events = tuple(invite_event.auth_events) + (invite_hash,)
|
||||
prev_events = (invite_hash,)
|
||||
|
||||
# we cap depth of generated events, to ensure that they are not
|
||||
# rejected by other servers (and so that they can be persisted in
|
||||
# the db)
|
||||
depth = min(invite_event.depth + 1, MAX_DEPTH)
|
||||
|
||||
event_dict = {
|
||||
"depth": depth,
|
||||
"auth_events": auth_events,
|
||||
"prev_events": prev_events,
|
||||
"type": EventTypes.Member,
|
||||
"room_id": room_id,
|
||||
"sender": target_user,
|
||||
"content": content,
|
||||
"state_key": target_user,
|
||||
}
|
||||
|
||||
event = create_local_event_from_event_dict(
|
||||
clock=self.clock,
|
||||
hostname=self.hs.hostname,
|
||||
signing_key=self.hs.signing_key,
|
||||
room_version=room_version,
|
||||
event_dict=event_dict,
|
||||
)
|
||||
event.internal_metadata.outlier = True
|
||||
event.internal_metadata.out_of_band_membership = True
|
||||
if txn_id is not None:
|
||||
event.internal_metadata.txn_id = txn_id
|
||||
if requester.access_token_id is not None:
|
||||
event.internal_metadata.token_id = requester.access_token_id
|
||||
|
||||
EventValidator().validate_new(event, self.config)
|
||||
|
||||
context = await self.state_handler.compute_event_context(event)
|
||||
context.app_service = requester.app_service
|
||||
stream_id = await self.event_creation_handler.handle_new_client_event(
|
||||
requester, event, context, extra_users=[UserID.from_string(target_user)],
|
||||
)
|
||||
return event.event_id, stream_id
|
||||
|
||||
async def _user_joined_room(self, target: UserID, room_id: str) -> None:
|
||||
"""Implements RoomMemberHandler._user_joined_room
|
||||
|
||||
@@ -61,21 +61,22 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
|
||||
|
||||
return ret["event_id"], ret["stream_id"]
|
||||
|
||||
async def _remote_reject_invite(
|
||||
async def remote_reject_invite(
|
||||
self,
|
||||
invite_event_id: str,
|
||||
txn_id: Optional[str],
|
||||
requester: Requester,
|
||||
remote_room_hosts: List[str],
|
||||
room_id: str,
|
||||
target: UserID,
|
||||
content: dict,
|
||||
) -> Tuple[Optional[str], int]:
|
||||
"""Implements RoomMemberHandler._remote_reject_invite
|
||||
) -> Tuple[str, int]:
|
||||
"""
|
||||
Rejects an out-of-band invite received from a remote user
|
||||
|
||||
Implements RoomMemberHandler.remote_reject_invite
|
||||
"""
|
||||
ret = await self._remote_reject_client(
|
||||
invite_event_id=invite_event_id,
|
||||
txn_id=txn_id,
|
||||
requester=requester,
|
||||
remote_room_hosts=remote_room_hosts,
|
||||
room_id=room_id,
|
||||
user_id=target.to_string(),
|
||||
content=content,
|
||||
)
|
||||
return ret["event_id"], ret["stream_id"]
|
||||
|
||||
@@ -294,6 +294,9 @@ class TypingHandler(object):
|
||||
rows.sort()
|
||||
|
||||
limited = False
|
||||
# We, unusually, use a strict limit here as we have all the rows in
|
||||
# memory rather than pulling them out of the database with a `LIMIT ?`
|
||||
# clause.
|
||||
if len(rows) > limit:
|
||||
rows = rows[:limit]
|
||||
current_id = rows[-1][0]
|
||||
|
||||
@@ -13,13 +13,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.web.resource import Resource
|
||||
from twisted.web.server import NOT_DONE_YET
|
||||
|
||||
from synapse.http.server import wrap_json_request_handler
|
||||
from synapse.http.server import DirectServeJsonResource
|
||||
|
||||
|
||||
class AdditionalResource(Resource):
|
||||
class AdditionalResource(DirectServeJsonResource):
|
||||
"""Resource wrapper for additional_resources
|
||||
|
||||
If the user has configured additional_resources, we need to wrap the
|
||||
@@ -41,16 +38,10 @@ class AdditionalResource(Resource):
|
||||
handler ((twisted.web.server.Request) -> twisted.internet.defer.Deferred):
|
||||
function to be called to handle the request.
|
||||
"""
|
||||
Resource.__init__(self)
|
||||
super().__init__()
|
||||
self._handler = handler
|
||||
|
||||
# required by the request_handler wrapper
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
def render(self, request):
|
||||
self._async_render(request)
|
||||
return NOT_DONE_YET
|
||||
|
||||
@wrap_json_request_handler
|
||||
def _async_render(self, request):
|
||||
# Cheekily pass the result straight through, so we don't need to worry
|
||||
# if its an awaitable or not.
|
||||
return self._handler(request)
|
||||
|
||||
@@ -176,7 +176,7 @@ class MatrixFederationHttpClient(object):
|
||||
|
||||
def __init__(self, hs, tls_client_options_factory):
|
||||
self.hs = hs
|
||||
self.signing_key = hs.config.signing_key[0]
|
||||
self.signing_key = hs.signing_key
|
||||
self.server_name = hs.hostname
|
||||
|
||||
real_reactor = hs.get_reactor()
|
||||
@@ -562,13 +562,17 @@ class MatrixFederationHttpClient(object):
|
||||
Returns:
|
||||
list[bytes]: a list of headers to be added as "Authorization:" headers
|
||||
"""
|
||||
request = {"method": method, "uri": url_bytes, "origin": self.server_name}
|
||||
request = {
|
||||
"method": method.decode("ascii"),
|
||||
"uri": url_bytes.decode("ascii"),
|
||||
"origin": self.server_name,
|
||||
}
|
||||
|
||||
if destination is not None:
|
||||
request["destination"] = destination
|
||||
request["destination"] = destination.decode("ascii")
|
||||
|
||||
if destination_is is not None:
|
||||
request["destination_is"] = destination_is
|
||||
request["destination_is"] = destination_is.decode("ascii")
|
||||
|
||||
if content is not None:
|
||||
request["content"] = content
|
||||
|
||||
+198
-183
@@ -14,6 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import abc
|
||||
import collections
|
||||
import html
|
||||
import logging
|
||||
@@ -21,7 +22,7 @@ import types
|
||||
import urllib
|
||||
from http import HTTPStatus
|
||||
from io import BytesIO
|
||||
from typing import Awaitable, Callable, TypeVar, Union
|
||||
from typing import Any, Callable, Dict, Tuple, Union
|
||||
|
||||
import jinja2
|
||||
from canonicaljson import encode_canonical_json, encode_pretty_printed_json, json
|
||||
@@ -62,99 +63,43 @@ HTML_ERROR_TEMPLATE = """<!DOCTYPE html>
|
||||
"""
|
||||
|
||||
|
||||
def wrap_json_request_handler(h):
|
||||
"""Wraps a request handler method with exception handling.
|
||||
|
||||
Also does the wrapping with request.processing as per wrap_async_request_handler.
|
||||
|
||||
The handler method must have a signature of "handle_foo(self, request)",
|
||||
where "request" must be a SynapseRequest.
|
||||
|
||||
The handler must return a deferred or a coroutine. If the deferred succeeds
|
||||
we assume that a response has been sent. If the deferred fails with a SynapseError we use
|
||||
it to send a JSON response with the appropriate HTTP reponse code. If the
|
||||
deferred fails with any other type of error we send a 500 reponse.
|
||||
def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
|
||||
"""Sends a JSON error response to clients.
|
||||
"""
|
||||
|
||||
async def wrapped_request_handler(self, request):
|
||||
try:
|
||||
await h(self, request)
|
||||
except SynapseError as e:
|
||||
code = e.code
|
||||
logger.info("%s SynapseError: %s - %s", request, code, e.msg)
|
||||
if f.check(SynapseError):
|
||||
error_code = f.value.code
|
||||
error_dict = f.value.error_dict()
|
||||
|
||||
# Only respond with an error response if we haven't already started
|
||||
# writing, otherwise lets just kill the connection
|
||||
if request.startedWriting:
|
||||
if request.transport:
|
||||
try:
|
||||
request.transport.abortConnection()
|
||||
except Exception:
|
||||
# abortConnection throws if the connection is already closed
|
||||
pass
|
||||
else:
|
||||
respond_with_json(
|
||||
request,
|
||||
code,
|
||||
e.error_dict(),
|
||||
send_cors=True,
|
||||
pretty_print=_request_user_agent_is_curl(request),
|
||||
)
|
||||
logger.info("%s SynapseError: %s - %s", request, error_code, f.value.msg)
|
||||
else:
|
||||
error_code = 500
|
||||
error_dict = {"error": "Internal server error", "errcode": Codes.UNKNOWN}
|
||||
|
||||
except Exception:
|
||||
# failure.Failure() fishes the original Failure out
|
||||
# of our stack, and thus gives us a sensible stack
|
||||
# trace.
|
||||
f = failure.Failure()
|
||||
logger.error(
|
||||
"Failed handle request via %r: %r",
|
||||
request.request_metrics.name,
|
||||
request,
|
||||
exc_info=(f.type, f.value, f.getTracebackObject()),
|
||||
)
|
||||
# Only respond with an error response if we haven't already started
|
||||
# writing, otherwise lets just kill the connection
|
||||
if request.startedWriting:
|
||||
if request.transport:
|
||||
try:
|
||||
request.transport.abortConnection()
|
||||
except Exception:
|
||||
# abortConnection throws if the connection is already closed
|
||||
pass
|
||||
else:
|
||||
respond_with_json(
|
||||
request,
|
||||
500,
|
||||
{"error": "Internal server error", "errcode": Codes.UNKNOWN},
|
||||
send_cors=True,
|
||||
pretty_print=_request_user_agent_is_curl(request),
|
||||
)
|
||||
logger.error(
|
||||
"Failed handle request via %r: %r",
|
||||
request.request_metrics.name,
|
||||
request,
|
||||
exc_info=(f.type, f.value, f.getTracebackObject()),
|
||||
)
|
||||
|
||||
return wrap_async_request_handler(wrapped_request_handler)
|
||||
|
||||
|
||||
TV = TypeVar("TV")
|
||||
|
||||
|
||||
def wrap_html_request_handler(
|
||||
h: Callable[[TV, SynapseRequest], Awaitable]
|
||||
) -> Callable[[TV, SynapseRequest], Awaitable[None]]:
|
||||
"""Wraps a request handler method with exception handling.
|
||||
|
||||
Also does the wrapping with request.processing as per wrap_async_request_handler.
|
||||
|
||||
The handler method must have a signature of "handle_foo(self, request)",
|
||||
where "request" must be a SynapseRequest.
|
||||
"""
|
||||
|
||||
async def wrapped_request_handler(self, request):
|
||||
try:
|
||||
await h(self, request)
|
||||
except Exception:
|
||||
f = failure.Failure()
|
||||
return_html_error(f, request, HTML_ERROR_TEMPLATE)
|
||||
|
||||
return wrap_async_request_handler(wrapped_request_handler)
|
||||
# Only respond with an error response if we haven't already started writing,
|
||||
# otherwise lets just kill the connection
|
||||
if request.startedWriting:
|
||||
if request.transport:
|
||||
try:
|
||||
request.transport.abortConnection()
|
||||
except Exception:
|
||||
# abortConnection throws if the connection is already closed
|
||||
pass
|
||||
else:
|
||||
respond_with_json(
|
||||
request,
|
||||
error_code,
|
||||
error_dict,
|
||||
send_cors=True,
|
||||
pretty_print=_request_user_agent_is_curl(request),
|
||||
)
|
||||
|
||||
|
||||
def return_html_error(
|
||||
@@ -249,7 +194,113 @@ class HttpServer(object):
|
||||
pass
|
||||
|
||||
|
||||
class JsonResource(HttpServer, resource.Resource):
|
||||
class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
|
||||
"""Base class for resources that have async handlers.
|
||||
|
||||
Sub classes can either implement `_async_render_<METHOD>` to handle
|
||||
requests by method, or override `_async_render` to handle all requests.
|
||||
|
||||
Args:
|
||||
extract_context: Whether to attempt to extract the opentracing
|
||||
context from the request the servlet is handling.
|
||||
"""
|
||||
|
||||
def __init__(self, extract_context=False):
|
||||
super().__init__()
|
||||
|
||||
self._extract_context = extract_context
|
||||
|
||||
def render(self, request):
|
||||
""" This gets called by twisted every time someone sends us a request.
|
||||
"""
|
||||
defer.ensureDeferred(self._async_render_wrapper(request))
|
||||
return NOT_DONE_YET
|
||||
|
||||
@wrap_async_request_handler
|
||||
async def _async_render_wrapper(self, request):
|
||||
"""This is a wrapper that delegates to `_async_render` and handles
|
||||
exceptions, return values, metrics, etc.
|
||||
"""
|
||||
try:
|
||||
request.request_metrics.name = self.__class__.__name__
|
||||
|
||||
with trace_servlet(request, self._extract_context):
|
||||
callback_return = await self._async_render(request)
|
||||
|
||||
if callback_return is not None:
|
||||
code, response = callback_return
|
||||
self._send_response(request, code, response)
|
||||
except Exception:
|
||||
# failure.Failure() fishes the original Failure out
|
||||
# of our stack, and thus gives us a sensible stack
|
||||
# trace.
|
||||
f = failure.Failure()
|
||||
self._send_error_response(f, request)
|
||||
|
||||
async def _async_render(self, request):
|
||||
"""Delegates to `_async_render_<METHOD>` methods, or returns a 400 if
|
||||
no appropriate method exists. Can be overriden in sub classes for
|
||||
different routing.
|
||||
"""
|
||||
|
||||
method_handler = getattr(
|
||||
self, "_async_render_%s" % (request.method.decode("ascii"),), None
|
||||
)
|
||||
if method_handler:
|
||||
raw_callback_return = method_handler(request)
|
||||
|
||||
# Is it synchronous? We'll allow this for now.
|
||||
if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)):
|
||||
callback_return = await raw_callback_return
|
||||
else:
|
||||
callback_return = raw_callback_return
|
||||
|
||||
return callback_return
|
||||
|
||||
_unrecognised_request_handler(request)
|
||||
|
||||
@abc.abstractmethod
|
||||
def _send_response(
|
||||
self, request: SynapseRequest, code: int, response_object: Any,
|
||||
) -> None:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def _send_error_response(
|
||||
self, f: failure.Failure, request: SynapseRequest,
|
||||
) -> None:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class DirectServeJsonResource(_AsyncResource):
|
||||
"""A resource that will call `self._async_on_<METHOD>` on new requests,
|
||||
formatting responses and errors as JSON.
|
||||
"""
|
||||
|
||||
def _send_response(
|
||||
self, request, code, response_object,
|
||||
):
|
||||
"""Implements _AsyncResource._send_response
|
||||
"""
|
||||
# TODO: Only enable CORS for the requests that need it.
|
||||
respond_with_json(
|
||||
request,
|
||||
code,
|
||||
response_object,
|
||||
send_cors=True,
|
||||
pretty_print=_request_user_agent_is_curl(request),
|
||||
canonical_json=self.canonical_json,
|
||||
)
|
||||
|
||||
def _send_error_response(
|
||||
self, f: failure.Failure, request: SynapseRequest,
|
||||
) -> None:
|
||||
"""Implements _AsyncResource._send_error_response
|
||||
"""
|
||||
return_json_error(f, request)
|
||||
|
||||
|
||||
class JsonResource(DirectServeJsonResource):
|
||||
""" This implements the HttpServer interface and provides JSON support for
|
||||
Resources.
|
||||
|
||||
@@ -269,17 +320,15 @@ class JsonResource(HttpServer, resource.Resource):
|
||||
"_PathEntry", ["pattern", "callback", "servlet_classname"]
|
||||
)
|
||||
|
||||
def __init__(self, hs, canonical_json=True):
|
||||
resource.Resource.__init__(self)
|
||||
def __init__(self, hs, canonical_json=True, extract_context=False):
|
||||
super().__init__(extract_context)
|
||||
|
||||
self.canonical_json = canonical_json
|
||||
self.clock = hs.get_clock()
|
||||
self.path_regexs = {}
|
||||
self.hs = hs
|
||||
|
||||
def register_paths(
|
||||
self, method, path_patterns, callback, servlet_classname, trace=True
|
||||
):
|
||||
def register_paths(self, method, path_patterns, callback, servlet_classname):
|
||||
"""
|
||||
Registers a request handler against a regular expression. Later request URLs are
|
||||
checked against these regular expressions in order to identify an appropriate
|
||||
@@ -295,74 +344,23 @@ class JsonResource(HttpServer, resource.Resource):
|
||||
|
||||
servlet_classname (str): The name of the handler to be used in prometheus
|
||||
and opentracing logs.
|
||||
|
||||
trace (bool): Whether we should start a span to trace the servlet.
|
||||
"""
|
||||
method = method.encode("utf-8") # method is bytes on py3
|
||||
|
||||
if trace:
|
||||
# We don't extract the context from the servlet because we can't
|
||||
# trust the sender
|
||||
callback = trace_servlet(servlet_classname)(callback)
|
||||
|
||||
for path_pattern in path_patterns:
|
||||
logger.debug("Registering for %s %s", method, path_pattern.pattern)
|
||||
self.path_regexs.setdefault(method, []).append(
|
||||
self._PathEntry(path_pattern, callback, servlet_classname)
|
||||
)
|
||||
|
||||
def render(self, request):
|
||||
""" This gets called by twisted every time someone sends us a request.
|
||||
"""
|
||||
defer.ensureDeferred(self._async_render(request))
|
||||
return NOT_DONE_YET
|
||||
|
||||
@wrap_json_request_handler
|
||||
async def _async_render(self, request):
|
||||
""" This gets called from render() every time someone sends us a request.
|
||||
This checks if anyone has registered a callback for that method and
|
||||
path.
|
||||
"""
|
||||
callback, servlet_classname, group_dict = self._get_handler_for_request(request)
|
||||
|
||||
# Make sure we have a name for this handler in prometheus.
|
||||
request.request_metrics.name = servlet_classname
|
||||
|
||||
# Now trigger the callback. If it returns a response, we send it
|
||||
# here. If it throws an exception, that is handled by the wrapper
|
||||
# installed by @request_handler.
|
||||
kwargs = intern_dict(
|
||||
{
|
||||
name: urllib.parse.unquote(value) if value else value
|
||||
for name, value in group_dict.items()
|
||||
}
|
||||
)
|
||||
|
||||
callback_return = callback(request, **kwargs)
|
||||
|
||||
# Is it synchronous? We'll allow this for now.
|
||||
if isinstance(callback_return, (defer.Deferred, types.CoroutineType)):
|
||||
callback_return = await callback_return
|
||||
|
||||
if callback_return is not None:
|
||||
code, response = callback_return
|
||||
self._send_response(request, code, response)
|
||||
|
||||
def _get_handler_for_request(self, request):
|
||||
"""Finds a callback method to handle the given request
|
||||
|
||||
Args:
|
||||
request (twisted.web.http.Request):
|
||||
def _get_handler_for_request(
|
||||
self, request: SynapseRequest
|
||||
) -> Tuple[Callable, str, Dict[str, str]]:
|
||||
"""Finds a callback method to handle the given request.
|
||||
|
||||
Returns:
|
||||
Tuple[Callable, str, dict[unicode, unicode]]: callback method, the
|
||||
label to use for that method in prometheus metrics, and the
|
||||
dict mapping keys to path components as specified in the
|
||||
handler's path match regexp.
|
||||
|
||||
The callback will normally be a method registered via
|
||||
register_paths, so will return (possibly via Deferred) either
|
||||
None, or a tuple of (http code, response body).
|
||||
A tuple of the callback to use, the name of the servlet, and the
|
||||
key word arguments to pass to the callback
|
||||
"""
|
||||
request_path = request.path.decode("ascii")
|
||||
|
||||
@@ -377,42 +375,59 @@ class JsonResource(HttpServer, resource.Resource):
|
||||
# Huh. No one wanted to handle that? Fiiiiiine. Send 400.
|
||||
return _unrecognised_request_handler, "unrecognised_request_handler", {}
|
||||
|
||||
def _send_response(
|
||||
self, request, code, response_json_object, response_code_message=None
|
||||
):
|
||||
# TODO: Only enable CORS for the requests that need it.
|
||||
respond_with_json(
|
||||
request,
|
||||
code,
|
||||
response_json_object,
|
||||
send_cors=True,
|
||||
response_code_message=response_code_message,
|
||||
pretty_print=_request_user_agent_is_curl(request),
|
||||
canonical_json=self.canonical_json,
|
||||
async def _async_render(self, request):
|
||||
callback, servlet_classname, group_dict = self._get_handler_for_request(request)
|
||||
|
||||
# Make sure we have an appopriate name for this handler in prometheus
|
||||
# (rather than the default of JsonResource).
|
||||
request.request_metrics.name = servlet_classname
|
||||
|
||||
# Now trigger the callback. If it returns a response, we send it
|
||||
# here. If it throws an exception, that is handled by the wrapper
|
||||
# installed by @request_handler.
|
||||
kwargs = intern_dict(
|
||||
{
|
||||
name: urllib.parse.unquote(value) if value else value
|
||||
for name, value in group_dict.items()
|
||||
}
|
||||
)
|
||||
|
||||
raw_callback_return = callback(request, **kwargs)
|
||||
|
||||
class DirectServeResource(resource.Resource):
|
||||
def render(self, request):
|
||||
# Is it synchronous? We'll allow this for now.
|
||||
if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)):
|
||||
callback_return = await raw_callback_return
|
||||
else:
|
||||
callback_return = raw_callback_return
|
||||
|
||||
return callback_return
|
||||
|
||||
|
||||
class DirectServeHtmlResource(_AsyncResource):
|
||||
"""A resource that will call `self._async_on_<METHOD>` on new requests,
|
||||
formatting responses and errors as HTML.
|
||||
"""
|
||||
|
||||
# The error template to use for this resource
|
||||
ERROR_TEMPLATE = HTML_ERROR_TEMPLATE
|
||||
|
||||
def _send_response(
|
||||
self, request: SynapseRequest, code: int, response_object: Any,
|
||||
):
|
||||
"""Implements _AsyncResource._send_response
|
||||
"""
|
||||
Render the request, using an asynchronous render handler if it exists.
|
||||
# We expect to get bytes for us to write
|
||||
assert isinstance(response_object, bytes)
|
||||
html_bytes = response_object
|
||||
|
||||
respond_with_html_bytes(request, 200, html_bytes)
|
||||
|
||||
def _send_error_response(
|
||||
self, f: failure.Failure, request: SynapseRequest,
|
||||
) -> None:
|
||||
"""Implements _AsyncResource._send_error_response
|
||||
"""
|
||||
async_render_callback_name = "_async_render_" + request.method.decode("ascii")
|
||||
|
||||
# Try and get the async renderer
|
||||
callback = getattr(self, async_render_callback_name, None)
|
||||
|
||||
# No async renderer for this request method.
|
||||
if not callback:
|
||||
return super().render(request)
|
||||
|
||||
resp = trace_servlet(self.__class__.__name__)(callback)(request)
|
||||
|
||||
# If it's a coroutine, turn it into a Deferred
|
||||
if isinstance(resp, types.CoroutineType):
|
||||
defer.ensureDeferred(resp)
|
||||
|
||||
return NOT_DONE_YET
|
||||
return_html_error(f, request, self.ERROR_TEMPLATE)
|
||||
|
||||
|
||||
class StaticResource(File):
|
||||
|
||||
@@ -164,12 +164,10 @@ Gotchas
|
||||
than one caller? Will all of those calling functions have be in a context
|
||||
with an active span?
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import inspect
|
||||
import logging
|
||||
import re
|
||||
import types
|
||||
from functools import wraps
|
||||
from typing import TYPE_CHECKING, Dict, Optional, Type
|
||||
|
||||
@@ -181,6 +179,7 @@ from twisted.internet import defer
|
||||
from synapse.config import ConfigError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.server import HomeServer
|
||||
|
||||
# Helper class
|
||||
@@ -227,6 +226,7 @@ except ImportError:
|
||||
tags = _DummyTagNames
|
||||
try:
|
||||
from jaeger_client import Config as JaegerConfig
|
||||
|
||||
from synapse.logging.scopecontextmanager import LogContextScopeManager
|
||||
except ImportError:
|
||||
JaegerConfig = None # type: ignore
|
||||
@@ -793,48 +793,42 @@ def tag_args(func):
|
||||
return _tag_args_inner
|
||||
|
||||
|
||||
def trace_servlet(servlet_name, extract_context=False):
|
||||
"""Decorator which traces a serlet. It starts a span with some servlet specific
|
||||
tags such as the servlet_name and request information
|
||||
@contextlib.contextmanager
|
||||
def trace_servlet(request: "SynapseRequest", extract_context: bool = False):
|
||||
"""Returns a context manager which traces a request. It starts a span
|
||||
with some servlet specific tags such as the request metrics name and
|
||||
request information.
|
||||
|
||||
Args:
|
||||
servlet_name (str): The name to be used for the span's operation_name
|
||||
extract_context (bool): Whether to attempt to extract the opentracing
|
||||
request
|
||||
extract_context: Whether to attempt to extract the opentracing
|
||||
context from the request the servlet is handling.
|
||||
|
||||
"""
|
||||
|
||||
def _trace_servlet_inner_1(func):
|
||||
if not opentracing:
|
||||
return func
|
||||
if opentracing is None:
|
||||
yield
|
||||
return
|
||||
|
||||
@wraps(func)
|
||||
async def _trace_servlet_inner(request, *args, **kwargs):
|
||||
request_tags = {
|
||||
"request_id": request.get_request_id(),
|
||||
tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
|
||||
tags.HTTP_METHOD: request.get_method(),
|
||||
tags.HTTP_URL: request.get_redacted_uri(),
|
||||
tags.PEER_HOST_IPV6: request.getClientIP(),
|
||||
}
|
||||
request_tags = {
|
||||
"request_id": request.get_request_id(),
|
||||
tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
|
||||
tags.HTTP_METHOD: request.get_method(),
|
||||
tags.HTTP_URL: request.get_redacted_uri(),
|
||||
tags.PEER_HOST_IPV6: request.getClientIP(),
|
||||
}
|
||||
|
||||
if extract_context:
|
||||
scope = start_active_span_from_request(
|
||||
request, servlet_name, tags=request_tags
|
||||
)
|
||||
else:
|
||||
scope = start_active_span(servlet_name, tags=request_tags)
|
||||
request_name = request.request_metrics.name
|
||||
if extract_context:
|
||||
scope = start_active_span_from_request(request, request_name, tags=request_tags)
|
||||
else:
|
||||
scope = start_active_span(request_name, tags=request_tags)
|
||||
|
||||
with scope:
|
||||
result = func(request, *args, **kwargs)
|
||||
with scope:
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
# We set the operation name again in case its changed (which happens
|
||||
# with JsonResource).
|
||||
scope.span.set_operation_name(request.request_metrics.name)
|
||||
|
||||
if not isinstance(result, (types.CoroutineType, defer.Deferred)):
|
||||
# Some servlets aren't async and just return results
|
||||
# directly, so we handle that here.
|
||||
return result
|
||||
|
||||
return await result
|
||||
|
||||
return _trace_servlet_inner
|
||||
|
||||
return _trace_servlet_inner_1
|
||||
scope.span.set_tag("request_tag", request.request_metrics.start_context.tag)
|
||||
|
||||
@@ -22,6 +22,7 @@ from typing import TYPE_CHECKING, Dict, Optional, Set
|
||||
from prometheus_client.core import REGISTRY, Counter, Gauge
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.python.failure import Failure
|
||||
|
||||
from synapse.logging.context import LoggingContext, PreserveLoggingContext
|
||||
|
||||
@@ -212,7 +213,14 @@ def run_as_background_process(desc, func, *args, **kwargs):
|
||||
|
||||
return (yield result)
|
||||
except Exception:
|
||||
logger.exception("Background process '%s' threw an exception", desc)
|
||||
# failure.Failure() fishes the original Failure out of our stack, and
|
||||
# thus gives us a sensible stack trace.
|
||||
f = Failure()
|
||||
logger.error(
|
||||
"Background process '%s' threw an exception",
|
||||
desc,
|
||||
exc_info=(f.type, f.value, f.getTracebackObject()),
|
||||
)
|
||||
finally:
|
||||
_background_process_in_flight_count.labels(desc).dec()
|
||||
|
||||
|
||||
+1
-1
@@ -83,7 +83,7 @@ class _NotifierUserStream(object):
|
||||
self.current_token = current_token
|
||||
|
||||
# The last token for which we should wake up any streams that have a
|
||||
# token that comes before it. This gets updated everytime we get poked.
|
||||
# token that comes before it. This gets updated every time we get poked.
|
||||
# We start it at the current token since if we get any streams
|
||||
# that have a token from before we have no idea whether they should be
|
||||
# woken up or not, so lets just wake them up.
|
||||
|
||||
@@ -20,6 +20,7 @@ from prometheus_client import Counter
|
||||
from twisted.internet import defer
|
||||
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
|
||||
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.logging import opentracing
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.push import PusherConfigException
|
||||
@@ -305,12 +306,23 @@ class HttpPusher(object):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _build_notification_dict(self, event, tweaks, badge):
|
||||
priority = "low"
|
||||
if (
|
||||
event.type == EventTypes.Encrypted
|
||||
or tweaks.get("highlight")
|
||||
or tweaks.get("sound")
|
||||
):
|
||||
# HACK send our push as high priority only if it generates a sound, highlight
|
||||
# or may do so (i.e. is encrypted so has unknown effects).
|
||||
priority = "high"
|
||||
|
||||
if self.data.get("format") == "event_id_only":
|
||||
d = {
|
||||
"notification": {
|
||||
"event_id": event.event_id,
|
||||
"room_id": event.room_id,
|
||||
"counts": {"unread": badge},
|
||||
"prio": priority,
|
||||
"devices": [
|
||||
{
|
||||
"app_id": self.app_id,
|
||||
@@ -334,9 +346,8 @@ class HttpPusher(object):
|
||||
"room_id": event.room_id,
|
||||
"type": event.type,
|
||||
"sender": event.user_id,
|
||||
"counts": { # -- we don't mark messages as read yet so
|
||||
# we have no way of knowing
|
||||
# Just set the badge to 1 until we have read receipts
|
||||
"prio": priority,
|
||||
"counts": {
|
||||
"unread": badge,
|
||||
# 'missed_calls': 2
|
||||
},
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Pattern
|
||||
from typing import Any, Dict, List, Pattern, Union
|
||||
|
||||
from synapse.events import EventBase
|
||||
from synapse.types import UserID
|
||||
@@ -72,13 +72,36 @@ def _test_ineq_condition(condition, number):
|
||||
return False
|
||||
|
||||
|
||||
def tweaks_for_actions(actions):
|
||||
def tweaks_for_actions(actions: List[Union[str, Dict]]) -> Dict[str, Any]:
|
||||
"""
|
||||
Converts a list of actions into a `tweaks` dict (which can then be passed to
|
||||
the push gateway).
|
||||
|
||||
This function ignores all actions other than `set_tweak` actions, and treats
|
||||
absent `value`s as `True`, which agrees with the only spec-defined treatment
|
||||
of absent `value`s (namely, for `highlight` tweaks).
|
||||
|
||||
Args:
|
||||
actions: list of actions
|
||||
e.g. [
|
||||
{"set_tweak": "a", "value": "AAA"},
|
||||
{"set_tweak": "b", "value": "BBB"},
|
||||
{"set_tweak": "highlight"},
|
||||
"notify"
|
||||
]
|
||||
|
||||
Returns:
|
||||
dictionary of tweaks for those actions
|
||||
e.g. {"a": "AAA", "b": "BBB", "highlight": True}
|
||||
"""
|
||||
tweaks = {}
|
||||
for a in actions:
|
||||
if not isinstance(a, dict):
|
||||
continue
|
||||
if "set_tweak" in a and "value" in a:
|
||||
tweaks[a["set_tweak"]] = a["value"]
|
||||
if "set_tweak" in a:
|
||||
# value is allowed to be absent in which case the value assumed
|
||||
# should be True.
|
||||
tweaks[a["set_tweak"]] = a.get("value", True)
|
||||
return tweaks
|
||||
|
||||
|
||||
|
||||
@@ -66,7 +66,7 @@ REQUIREMENTS = [
|
||||
"pymacaroons>=0.13.0",
|
||||
"msgpack>=0.5.2",
|
||||
"phonenumbers>=8.2.0",
|
||||
"prometheus_client>=0.0.18,<0.8.0",
|
||||
"prometheus_client>=0.0.18,<0.9.0",
|
||||
# we use attr.validators.deep_iterable, which arrived in 19.1.0
|
||||
"attrs>=19.1.0",
|
||||
"netaddr>=0.7.18",
|
||||
|
||||
@@ -30,7 +30,8 @@ REPLICATION_PREFIX = "/_synapse/replication"
|
||||
|
||||
class ReplicationRestResource(JsonResource):
|
||||
def __init__(self, hs):
|
||||
JsonResource.__init__(self, hs, canonical_json=False)
|
||||
# We enable extracting jaeger contexts here as these are internal APIs.
|
||||
super().__init__(hs, canonical_json=False, extract_context=True)
|
||||
self.register_servlets(hs)
|
||||
|
||||
def register_servlets(self, hs):
|
||||
|
||||
@@ -28,11 +28,7 @@ from synapse.api.errors import (
|
||||
RequestSendFailed,
|
||||
SynapseError,
|
||||
)
|
||||
from synapse.logging.opentracing import (
|
||||
inject_active_span_byte_dict,
|
||||
trace,
|
||||
trace_servlet,
|
||||
)
|
||||
from synapse.logging.opentracing import inject_active_span_byte_dict, trace
|
||||
from synapse.util.caches.response_cache import ResponseCache
|
||||
from synapse.util.stringutils import random_string
|
||||
|
||||
@@ -96,11 +92,11 @@ class ReplicationEndpoint(object):
|
||||
# assert here that sub classes don't try and use the name.
|
||||
assert (
|
||||
"instance_name" not in self.PATH_ARGS
|
||||
), "`instance_name` is a reserved paramater name"
|
||||
), "`instance_name` is a reserved parameter name"
|
||||
assert (
|
||||
"instance_name"
|
||||
not in signature(self.__class__._serialize_payload).parameters
|
||||
), "`instance_name` is a reserved paramater name"
|
||||
), "`instance_name` is a reserved parameter name"
|
||||
|
||||
assert self.METHOD in ("PUT", "POST", "GET")
|
||||
|
||||
@@ -240,11 +236,8 @@ class ReplicationEndpoint(object):
|
||||
args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args)
|
||||
pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args))
|
||||
|
||||
handler = trace_servlet(self.__class__.__name__, extract_context=True)(handler)
|
||||
# We don't let register paths trace this servlet using the default tracing
|
||||
# options because we wish to extract the context explicitly.
|
||||
http_server.register_paths(
|
||||
method, [pattern], handler, self.__class__.__name__, trace=False
|
||||
method, [pattern], handler, self.__class__.__name__,
|
||||
)
|
||||
|
||||
def _cached_handler(self, request, txn_id, **kwargs):
|
||||
|
||||
@@ -14,11 +14,11 @@
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from synapse.http.servlet import parse_json_object_from_request
|
||||
from synapse.replication.http._base import ReplicationEndpoint
|
||||
from synapse.types import Requester, UserID
|
||||
from synapse.types import JsonDict, Requester, UserID
|
||||
from synapse.util.distributor import user_joined_room, user_left_room
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -88,49 +88,54 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
|
||||
|
||||
|
||||
class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
|
||||
"""Rejects the invite for the user and room.
|
||||
"""Rejects an out-of-band invite we have received from a remote server
|
||||
|
||||
Request format:
|
||||
|
||||
POST /_synapse/replication/remote_reject_invite/:room_id/:user_id
|
||||
POST /_synapse/replication/remote_reject_invite/:event_id
|
||||
|
||||
{
|
||||
"txn_id": ...,
|
||||
"requester": ...,
|
||||
"remote_room_hosts": [...],
|
||||
"content": { ... }
|
||||
}
|
||||
"""
|
||||
|
||||
NAME = "remote_reject_invite"
|
||||
PATH_ARGS = ("room_id", "user_id")
|
||||
PATH_ARGS = ("invite_event_id",)
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super(ReplicationRemoteRejectInviteRestServlet, self).__init__(hs)
|
||||
|
||||
self.federation_handler = hs.get_handlers().federation_handler
|
||||
self.store = hs.get_datastore()
|
||||
self.clock = hs.get_clock()
|
||||
self.member_handler = hs.get_room_member_handler()
|
||||
|
||||
@staticmethod
|
||||
def _serialize_payload(requester, room_id, user_id, remote_room_hosts, content):
|
||||
def _serialize_payload( # type: ignore
|
||||
invite_event_id: str,
|
||||
txn_id: Optional[str],
|
||||
requester: Requester,
|
||||
content: JsonDict,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
requester(Requester)
|
||||
room_id (str)
|
||||
user_id (str)
|
||||
remote_room_hosts (list[str]): Servers to try and reject via
|
||||
invite_event_id: ID of the invite to be rejected
|
||||
txn_id: optional transaction ID supplied by the client
|
||||
requester: user making the rejection request, according to the access token
|
||||
content: additional content to include in the rejection event.
|
||||
Normally an empty dict.
|
||||
"""
|
||||
return {
|
||||
"txn_id": txn_id,
|
||||
"requester": requester.serialize(),
|
||||
"remote_room_hosts": remote_room_hosts,
|
||||
"content": content,
|
||||
}
|
||||
|
||||
async def _handle_request(self, request, room_id, user_id):
|
||||
async def _handle_request(self, request, invite_event_id):
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
remote_room_hosts = content["remote_room_hosts"]
|
||||
txn_id = content["txn_id"]
|
||||
event_content = content["content"]
|
||||
|
||||
requester = Requester.deserialize(self.store, content["requester"])
|
||||
@@ -138,60 +143,14 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
|
||||
if requester.user:
|
||||
request.authenticated_entity = requester.user.to_string()
|
||||
|
||||
logger.info("remote_reject_invite: %s out of room: %s", user_id, room_id)
|
||||
|
||||
try:
|
||||
event, stream_id = await self.federation_handler.do_remotely_reject_invite(
|
||||
remote_room_hosts, room_id, user_id, event_content,
|
||||
)
|
||||
event_id = event.event_id
|
||||
except Exception as e:
|
||||
# if we were unable to reject the exception, just mark
|
||||
# it as rejected on our end and plough ahead.
|
||||
#
|
||||
# The 'except' clause is very broad, but we need to
|
||||
# capture everything from DNS failures upwards
|
||||
#
|
||||
logger.warning("Failed to reject invite: %s", e)
|
||||
|
||||
stream_id = await self.member_handler.locally_reject_invite(
|
||||
user_id, room_id
|
||||
)
|
||||
event_id = None
|
||||
# hopefully we're now on the master, so this won't recurse!
|
||||
event_id, stream_id = await self.member_handler.remote_reject_invite(
|
||||
invite_event_id, txn_id, requester, event_content,
|
||||
)
|
||||
|
||||
return 200, {"event_id": event_id, "stream_id": stream_id}
|
||||
|
||||
|
||||
class ReplicationLocallyRejectInviteRestServlet(ReplicationEndpoint):
|
||||
"""Rejects the invite for the user and room locally.
|
||||
|
||||
Request format:
|
||||
|
||||
POST /_synapse/replication/locally_reject_invite/:room_id/:user_id
|
||||
|
||||
{}
|
||||
"""
|
||||
|
||||
NAME = "locally_reject_invite"
|
||||
PATH_ARGS = ("room_id", "user_id")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__(hs)
|
||||
|
||||
self.member_handler = hs.get_room_member_handler()
|
||||
|
||||
@staticmethod
|
||||
def _serialize_payload(room_id, user_id):
|
||||
return {}
|
||||
|
||||
async def _handle_request(self, request, room_id, user_id):
|
||||
logger.info("locally_reject_invite: %s out of room: %s", user_id, room_id)
|
||||
|
||||
stream_id = await self.member_handler.locally_reject_invite(user_id, room_id)
|
||||
|
||||
return 200, {"stream_id": stream_id}
|
||||
|
||||
|
||||
class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
|
||||
"""Notifies that a user has joined or left the room
|
||||
|
||||
@@ -245,4 +204,3 @@ def register_servlets(hs, http_server):
|
||||
ReplicationRemoteJoinRestServlet(hs).register(http_server)
|
||||
ReplicationRemoteRejectInviteRestServlet(hs).register(http_server)
|
||||
ReplicationUserJoinedLeftRoomRestServlet(hs).register(http_server)
|
||||
ReplicationLocallyRejectInviteRestServlet(hs).register(http_server)
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
|
||||
from synapse.replication.slave.storage._base import BaseSlavedStore
|
||||
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
|
||||
from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream
|
||||
from synapse.storage.data_stores.main.account_data import AccountDataWorkerStore
|
||||
from synapse.storage.data_stores.main.tags import TagsWorkerStore
|
||||
from synapse.storage.database import Database
|
||||
@@ -39,12 +40,12 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved
|
||||
return self._account_data_id_gen.get_current_token()
|
||||
|
||||
def process_replication_rows(self, stream_name, instance_name, token, rows):
|
||||
if stream_name == "tag_account_data":
|
||||
if stream_name == TagAccountDataStream.NAME:
|
||||
self._account_data_id_gen.advance(token)
|
||||
for row in rows:
|
||||
self.get_tags_for_user.invalidate((row.user_id,))
|
||||
self._account_data_stream_cache.entity_has_changed(row.user_id, token)
|
||||
elif stream_name == "account_data":
|
||||
elif stream_name == AccountDataStream.NAME:
|
||||
self._account_data_id_gen.advance(token)
|
||||
for row in rows:
|
||||
if not row.room_id:
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
|
||||
from synapse.replication.slave.storage._base import BaseSlavedStore
|
||||
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
|
||||
from synapse.replication.tcp.streams import ToDeviceStream
|
||||
from synapse.storage.data_stores.main.deviceinbox import DeviceInboxWorkerStore
|
||||
from synapse.storage.database import Database
|
||||
from synapse.util.caches.expiringcache import ExpiringCache
|
||||
@@ -44,7 +45,7 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
|
||||
)
|
||||
|
||||
def process_replication_rows(self, stream_name, instance_name, token, rows):
|
||||
if stream_name == "to_device":
|
||||
if stream_name == ToDeviceStream.NAME:
|
||||
self._device_inbox_id_gen.advance(token)
|
||||
for row in rows:
|
||||
if row.entity.startswith("@"):
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
|
||||
from synapse.replication.slave.storage._base import BaseSlavedStore
|
||||
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
|
||||
from synapse.replication.tcp.streams import GroupServerStream
|
||||
from synapse.storage.data_stores.main.group_server import GroupServerWorkerStore
|
||||
from synapse.storage.database import Database
|
||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||
@@ -38,7 +39,7 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
|
||||
return self._group_updates_id_gen.get_current_token()
|
||||
|
||||
def process_replication_rows(self, stream_name, instance_name, token, rows):
|
||||
if stream_name == "groups":
|
||||
if stream_name == GroupServerStream.NAME:
|
||||
self._group_updates_id_gen.advance(token)
|
||||
for row in rows:
|
||||
self._group_updates_stream_cache.entity_has_changed(row.user_id, token)
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.replication.tcp.streams import PresenceStream
|
||||
from synapse.storage import DataStore
|
||||
from synapse.storage.data_stores.main.presence import PresenceStore
|
||||
from synapse.storage.database import Database
|
||||
@@ -42,7 +43,7 @@ class SlavedPresenceStore(BaseSlavedStore):
|
||||
return self._presence_id_gen.get_current_token()
|
||||
|
||||
def process_replication_rows(self, stream_name, instance_name, token, rows):
|
||||
if stream_name == "presence":
|
||||
if stream_name == PresenceStream.NAME:
|
||||
self._presence_id_gen.advance(token)
|
||||
for row in rows:
|
||||
self.presence_stream_cache.entity_has_changed(row.user_id, token)
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.replication.tcp.streams import PushRulesStream
|
||||
from synapse.storage.data_stores.main.push_rule import PushRulesWorkerStore
|
||||
|
||||
from .events import SlavedEventStore
|
||||
@@ -30,7 +31,7 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
|
||||
return self._push_rules_stream_id_gen.get_current_token()
|
||||
|
||||
def process_replication_rows(self, stream_name, instance_name, token, rows):
|
||||
if stream_name == "push_rules":
|
||||
if stream_name == PushRulesStream.NAME:
|
||||
self._push_rules_stream_id_gen.advance(token)
|
||||
for row in rows:
|
||||
self.get_push_rules_for_user.invalidate((row.user_id,))
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.replication.tcp.streams import PushersStream
|
||||
from synapse.storage.data_stores.main.pusher import PusherWorkerStore
|
||||
from synapse.storage.database import Database
|
||||
|
||||
@@ -32,6 +33,6 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
|
||||
return self._pushers_id_gen.get_current_token()
|
||||
|
||||
def process_replication_rows(self, stream_name, instance_name, token, rows):
|
||||
if stream_name == "pushers":
|
||||
if stream_name == PushersStream.NAME:
|
||||
self._pushers_id_gen.advance(token)
|
||||
return super().process_replication_rows(stream_name, instance_name, token, rows)
|
||||
|
||||
@@ -14,20 +14,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.replication.tcp.streams import ReceiptsStream
|
||||
from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore
|
||||
from synapse.storage.database import Database
|
||||
|
||||
from ._base import BaseSlavedStore
|
||||
from ._slaved_id_tracker import SlavedIdTracker
|
||||
|
||||
# So, um, we want to borrow a load of functions intended for reading from
|
||||
# a DataStore, but we don't want to take functions that either write to the
|
||||
# DataStore or are cached and don't have cache invalidation logic.
|
||||
#
|
||||
# Rather than write duplicate versions of those functions, or lift them to
|
||||
# a common base class, we going to grab the underlying __func__ object from
|
||||
# the method descriptor on the DataStore and chuck them into our class.
|
||||
|
||||
|
||||
class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
|
||||
def __init__(self, database: Database, db_conn, hs):
|
||||
@@ -52,7 +45,7 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
|
||||
self.get_receipts_for_room.invalidate((room_id, receipt_type))
|
||||
|
||||
def process_replication_rows(self, stream_name, instance_name, token, rows):
|
||||
if stream_name == "receipts":
|
||||
if stream_name == ReceiptsStream.NAME:
|
||||
self._receipts_id_gen.advance(token)
|
||||
for row in rows:
|
||||
self.invalidate_caches_for_receipt(
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.replication.tcp.streams import PublicRoomsStream
|
||||
from synapse.storage.data_stores.main.room import RoomWorkerStore
|
||||
from synapse.storage.database import Database
|
||||
|
||||
@@ -31,7 +32,7 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore):
|
||||
return self._public_room_id_gen.get_current_token()
|
||||
|
||||
def process_replication_rows(self, stream_name, instance_name, token, rows):
|
||||
if stream_name == "public_rooms":
|
||||
if stream_name == PublicRoomsStream.NAME:
|
||||
self._public_room_id_gen.advance(token)
|
||||
|
||||
return super().process_replication_rows(stream_name, instance_name, token, rows)
|
||||
|
||||
@@ -25,7 +25,7 @@ Structure of the module:
|
||||
* command.py - the definitions of all the valid commands
|
||||
* protocol.py - the TCP protocol classes
|
||||
* resource.py - handles streaming stream updates to replications
|
||||
* streams/ - the definitons of all the valid streams
|
||||
* streams/ - the definitions of all the valid streams
|
||||
|
||||
|
||||
The general interaction of the classes are:
|
||||
|
||||
@@ -33,8 +33,8 @@ from synapse.util.async_helpers import timeout_deferred
|
||||
from synapse.util.metrics import Measure
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
from synapse.replication.tcp.handler import ReplicationCommandHandler
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -18,18 +18,11 @@ The VALID_SERVER_COMMANDS and VALID_CLIENT_COMMANDS define which commands are
|
||||
allowed to be sent by which side.
|
||||
"""
|
||||
import abc
|
||||
import json
|
||||
import logging
|
||||
import platform
|
||||
from typing import Tuple, Type
|
||||
|
||||
if platform.python_implementation() == "PyPy":
|
||||
import json
|
||||
|
||||
_json_encoder = json.JSONEncoder()
|
||||
else:
|
||||
import simplejson as json # type: ignore[no-redef] # noqa: F821
|
||||
|
||||
_json_encoder = json.JSONEncoder(namedtuple_as_object=False) # type: ignore[call-arg] # noqa: F821
|
||||
_json_encoder = json.JSONEncoder()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -54,7 +47,7 @@ class Command(metaclass=abc.ABCMeta):
|
||||
|
||||
@abc.abstractmethod
|
||||
def to_line(self) -> str:
|
||||
"""Serialises the comamnd for the wire. Does not include the command
|
||||
"""Serialises the command for the wire. Does not include the command
|
||||
prefix.
|
||||
"""
|
||||
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar
|
||||
|
||||
@@ -149,10 +148,11 @@ class ReplicationCommandHandler:
|
||||
using TCP.
|
||||
"""
|
||||
if hs.config.redis.redis_enabled:
|
||||
import txredisapi
|
||||
|
||||
from synapse.replication.tcp.redis import (
|
||||
RedisDirectTcpReplicationClientFactory,
|
||||
)
|
||||
import txredisapi
|
||||
|
||||
logger.info(
|
||||
"Connecting to redis (host=%r port=%r)",
|
||||
|
||||
@@ -317,7 +317,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
|
||||
def _queue_command(self, cmd):
|
||||
"""Queue the command until the connection is ready to write to again.
|
||||
"""
|
||||
logger.debug("[%s] Queing as conn %r, cmd: %r", self.id(), self.state, cmd)
|
||||
logger.debug("[%s] Queueing as conn %r, cmd: %r", self.id(), self.state, cmd)
|
||||
self.pending_commands.append(cmd)
|
||||
|
||||
if len(self.pending_commands) > self.max_line_buffer:
|
||||
|
||||
@@ -177,7 +177,7 @@ class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
|
||||
Args:
|
||||
hs
|
||||
outbound_redis_connection: A connection to redis that will be used to
|
||||
send outbound commands (this is seperate to the redis connection
|
||||
send outbound commands (this is separate to the redis connection
|
||||
used to subscribe).
|
||||
"""
|
||||
|
||||
|
||||
@@ -198,26 +198,6 @@ def current_token_without_instance(
|
||||
return lambda instance_name: current_token()
|
||||
|
||||
|
||||
def db_query_to_update_function(
|
||||
query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
|
||||
) -> UpdateFunction:
|
||||
"""Wraps a db query function which returns a list of rows to make it
|
||||
suitable for use as an `update_function` for the Stream class
|
||||
"""
|
||||
|
||||
async def update_function(instance_name, from_token, upto_token, limit):
|
||||
rows = await query_function(from_token, upto_token, limit)
|
||||
updates = [(row[0], row[1:]) for row in rows]
|
||||
limited = False
|
||||
if len(updates) >= limit:
|
||||
upto_token = updates[-1][0]
|
||||
limited = True
|
||||
|
||||
return updates, upto_token, limited
|
||||
|
||||
return update_function
|
||||
|
||||
|
||||
def make_http_update_function(hs, stream_name: str) -> UpdateFunction:
|
||||
"""Makes a suitable function for use as an `update_function` that queries
|
||||
the master process for updates.
|
||||
@@ -393,7 +373,7 @@ class PushersStream(Stream):
|
||||
super().__init__(
|
||||
hs.get_instance_name(),
|
||||
current_token_without_instance(store.get_pushers_stream_token),
|
||||
db_query_to_update_function(store.get_all_updated_pushers_rows),
|
||||
store.get_all_updated_pushers_rows,
|
||||
)
|
||||
|
||||
|
||||
@@ -421,27 +401,13 @@ class CachesStream(Stream):
|
||||
ROW_TYPE = CachesStreamRow
|
||||
|
||||
def __init__(self, hs):
|
||||
self.store = hs.get_datastore()
|
||||
store = hs.get_datastore()
|
||||
super().__init__(
|
||||
hs.get_instance_name(),
|
||||
self.store.get_cache_stream_token,
|
||||
self._update_function,
|
||||
store.get_cache_stream_token,
|
||||
store.get_all_updated_caches,
|
||||
)
|
||||
|
||||
async def _update_function(
|
||||
self, instance_name: str, from_token: int, upto_token: int, limit: int
|
||||
):
|
||||
rows = await self.store.get_all_updated_caches(
|
||||
instance_name, from_token, upto_token, limit
|
||||
)
|
||||
updates = [(row[0], row[1:]) for row in rows]
|
||||
limited = False
|
||||
if len(updates) >= limit:
|
||||
upto_token = updates[-1][0]
|
||||
limited = True
|
||||
|
||||
return updates, upto_token, limited
|
||||
|
||||
|
||||
class PublicRoomsStream(Stream):
|
||||
"""The public rooms list changed
|
||||
@@ -465,7 +431,7 @@ class PublicRoomsStream(Stream):
|
||||
super().__init__(
|
||||
hs.get_instance_name(),
|
||||
current_token_without_instance(store.get_current_public_room_stream_id),
|
||||
db_query_to_update_function(store.get_all_new_public_rooms),
|
||||
store.get_all_new_public_rooms,
|
||||
)
|
||||
|
||||
|
||||
@@ -486,7 +452,7 @@ class DeviceListsStream(Stream):
|
||||
super().__init__(
|
||||
hs.get_instance_name(),
|
||||
current_token_without_instance(store.get_device_stream_token),
|
||||
db_query_to_update_function(store.get_all_device_list_changes_for_remotes),
|
||||
store.get_all_device_list_changes_for_remotes,
|
||||
)
|
||||
|
||||
|
||||
@@ -504,7 +470,7 @@ class ToDeviceStream(Stream):
|
||||
super().__init__(
|
||||
hs.get_instance_name(),
|
||||
current_token_without_instance(store.get_to_device_stream_token),
|
||||
db_query_to_update_function(store.get_all_new_device_messages),
|
||||
store.get_all_new_device_messages,
|
||||
)
|
||||
|
||||
|
||||
@@ -524,7 +490,7 @@ class TagAccountDataStream(Stream):
|
||||
super().__init__(
|
||||
hs.get_instance_name(),
|
||||
current_token_without_instance(store.get_max_account_data_stream_id),
|
||||
db_query_to_update_function(store.get_all_updated_tags),
|
||||
store.get_all_updated_tags,
|
||||
)
|
||||
|
||||
|
||||
@@ -612,7 +578,7 @@ class GroupServerStream(Stream):
|
||||
super().__init__(
|
||||
hs.get_instance_name(),
|
||||
current_token_without_instance(store.get_group_stream_token),
|
||||
db_query_to_update_function(store.get_all_groups_changes),
|
||||
store.get_all_groups_changes,
|
||||
)
|
||||
|
||||
|
||||
@@ -630,7 +596,5 @@ class UserSignatureStream(Stream):
|
||||
super().__init__(
|
||||
hs.get_instance_name(),
|
||||
current_token_without_instance(store.get_device_stream_token),
|
||||
db_query_to_update_function(
|
||||
store.get_all_user_signature_changes_for_remotes
|
||||
),
|
||||
store.get_all_user_signature_changes_for_remotes,
|
||||
)
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import heapq
|
||||
from collections import Iterable
|
||||
from typing import List, Tuple, Type
|
||||
@@ -22,7 +21,6 @@ import attr
|
||||
|
||||
from ._base import Stream, StreamUpdateResult, Token, current_token_without_instance
|
||||
|
||||
|
||||
"""Handling of the 'events' replication stream
|
||||
|
||||
This stream contains rows of various types. Each row therefore contains a 'type'
|
||||
@@ -64,7 +62,7 @@ class BaseEventsStreamRow(object):
|
||||
Specifies how to identify, serialize and deserialize the different types.
|
||||
"""
|
||||
|
||||
# Unique string that ids the type. Must be overriden in sub classes.
|
||||
# Unique string that ids the type. Must be overridden in sub classes.
|
||||
TypeId = None # type: str
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import Awaitable, Callable, Dict, Optional
|
||||
|
||||
from synapse.api.errors import Codes, LoginError, SynapseError
|
||||
from synapse.api.ratelimiting import Ratelimiter
|
||||
@@ -26,8 +27,9 @@ from synapse.http.servlet import (
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.rest.client.v2_alpha._base import client_patterns
|
||||
from synapse.rest.well_known import WellKnownBuilder
|
||||
from synapse.types import UserID
|
||||
from synapse.types import JsonDict, UserID
|
||||
from synapse.util.msisdn import phone_number_to_msisdn
|
||||
from synapse.util.threepids import canonicalise_email
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -113,7 +115,7 @@ class LoginRestServlet(RestServlet):
|
||||
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
|
||||
)
|
||||
|
||||
def on_GET(self, request):
|
||||
def on_GET(self, request: SynapseRequest):
|
||||
flows = []
|
||||
if self.jwt_enabled:
|
||||
flows.append({"type": LoginRestServlet.JWT_TYPE})
|
||||
@@ -141,10 +143,10 @@ class LoginRestServlet(RestServlet):
|
||||
|
||||
return 200, {"flows": flows}
|
||||
|
||||
def on_OPTIONS(self, request):
|
||||
def on_OPTIONS(self, request: SynapseRequest):
|
||||
return 200, {}
|
||||
|
||||
async def on_POST(self, request):
|
||||
async def on_POST(self, request: SynapseRequest):
|
||||
self._address_ratelimiter.ratelimit(request.getClientIP())
|
||||
|
||||
login_submission = parse_json_object_from_request(request)
|
||||
@@ -153,9 +155,9 @@ class LoginRestServlet(RestServlet):
|
||||
login_submission["type"] == LoginRestServlet.JWT_TYPE
|
||||
or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED
|
||||
):
|
||||
result = await self.do_jwt_login(login_submission)
|
||||
result = await self._do_jwt_login(login_submission)
|
||||
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
|
||||
result = await self.do_token_login(login_submission)
|
||||
result = await self._do_token_login(login_submission)
|
||||
else:
|
||||
result = await self._do_other_login(login_submission)
|
||||
except KeyError:
|
||||
@@ -166,14 +168,14 @@ class LoginRestServlet(RestServlet):
|
||||
result["well_known"] = well_known_data
|
||||
return 200, result
|
||||
|
||||
async def _do_other_login(self, login_submission):
|
||||
async def _do_other_login(self, login_submission: JsonDict) -> Dict[str, str]:
|
||||
"""Handle non-token/saml/jwt logins
|
||||
|
||||
Args:
|
||||
login_submission:
|
||||
|
||||
Returns:
|
||||
dict: HTTP response
|
||||
HTTP response
|
||||
"""
|
||||
# Log the request we got, but only certain fields to minimise the chance of
|
||||
# logging someone's password (even if they accidentally put it in the wrong
|
||||
@@ -206,11 +208,14 @@ class LoginRestServlet(RestServlet):
|
||||
if medium is None or address is None:
|
||||
raise SynapseError(400, "Invalid thirdparty identifier")
|
||||
|
||||
# For emails, canonicalise the address.
|
||||
# We store all email addresses canonicalised in the DB.
|
||||
# (See add_threepid in synapse/handlers/auth.py)
|
||||
if medium == "email":
|
||||
# For emails, transform the address to lowercase.
|
||||
# We store all email addreses as lowercase in the DB.
|
||||
# (See add_threepid in synapse/handlers/auth.py)
|
||||
address = address.lower()
|
||||
try:
|
||||
address = canonicalise_email(address)
|
||||
except ValueError as e:
|
||||
raise SynapseError(400, str(e))
|
||||
|
||||
# We also apply account rate limiting using the 3PID as a key, as
|
||||
# otherwise using 3PID bypasses the ratelimiting based on user ID.
|
||||
@@ -288,25 +293,30 @@ class LoginRestServlet(RestServlet):
|
||||
return result
|
||||
|
||||
async def _complete_login(
|
||||
self, user_id, login_submission, callback=None, create_non_existent_users=False
|
||||
):
|
||||
self,
|
||||
user_id: str,
|
||||
login_submission: JsonDict,
|
||||
callback: Optional[
|
||||
Callable[[Dict[str, str]], Awaitable[Dict[str, str]]]
|
||||
] = None,
|
||||
create_non_existent_users: bool = False,
|
||||
) -> Dict[str, str]:
|
||||
"""Called when we've successfully authed the user and now need to
|
||||
actually login them in (e.g. create devices). This gets called on
|
||||
all succesful logins.
|
||||
all successful logins.
|
||||
|
||||
Applies the ratelimiting for succesful login attempts against an
|
||||
Applies the ratelimiting for successful login attempts against an
|
||||
account.
|
||||
|
||||
Args:
|
||||
user_id (str): ID of the user to register.
|
||||
login_submission (dict): Dictionary of login information.
|
||||
callback (func|None): Callback function to run after registration.
|
||||
create_non_existent_users (bool): Whether to create the user if
|
||||
they don't exist. Defaults to False.
|
||||
user_id: ID of the user to register.
|
||||
login_submission: Dictionary of login information.
|
||||
callback: Callback function to run after registration.
|
||||
create_non_existent_users: Whether to create the user if they don't
|
||||
exist. Defaults to False.
|
||||
|
||||
Returns:
|
||||
result (Dict[str,str]): Dictionary of account information after
|
||||
successful registration.
|
||||
result: Dictionary of account information after successful registration.
|
||||
"""
|
||||
|
||||
# Before we actually log them in we check if they've already logged in
|
||||
@@ -340,7 +350,7 @@ class LoginRestServlet(RestServlet):
|
||||
|
||||
return result
|
||||
|
||||
async def do_token_login(self, login_submission):
|
||||
async def _do_token_login(self, login_submission: JsonDict) -> Dict[str, str]:
|
||||
token = login_submission["token"]
|
||||
auth_handler = self.auth_handler
|
||||
user_id = await auth_handler.validate_short_term_login_token_and_get_user_id(
|
||||
@@ -350,7 +360,7 @@ class LoginRestServlet(RestServlet):
|
||||
result = await self._complete_login(user_id, login_submission)
|
||||
return result
|
||||
|
||||
async def do_jwt_login(self, login_submission):
|
||||
async def _do_jwt_login(self, login_submission: JsonDict) -> Dict[str, str]:
|
||||
token = login_submission.get("token", None)
|
||||
if token is None:
|
||||
raise LoginError(
|
||||
|
||||
@@ -217,10 +217,8 @@ class RoomStateEventRestServlet(TransactionRestServlet):
|
||||
)
|
||||
event_id = event.event_id
|
||||
|
||||
ret = {} # type: dict
|
||||
if event_id:
|
||||
set_tag("event_id", event_id)
|
||||
ret = {"event_id": event_id}
|
||||
set_tag("event_id", event_id)
|
||||
ret = {"event_id": event_id}
|
||||
return 200, ret
|
||||
|
||||
|
||||
|
||||
@@ -50,7 +50,7 @@ class VoipRestServlet(RestServlet):
|
||||
# We need to use standard padded base64 encoding here
|
||||
# encode_base64 because we need to add the standard padding to get the
|
||||
# same result as the TURN server.
|
||||
password = base64.b64encode(mac.digest())
|
||||
password = base64.b64encode(mac.digest()).decode("ascii")
|
||||
|
||||
elif turnUris and turnUsername and turnPassword and userLifetime:
|
||||
username = turnUsername
|
||||
|
||||
@@ -30,7 +30,7 @@ from synapse.http.servlet import (
|
||||
from synapse.push.mailer import Mailer, load_jinja2_templates
|
||||
from synapse.util.msisdn import phone_number_to_msisdn
|
||||
from synapse.util.stringutils import assert_valid_client_secret, random_string
|
||||
from synapse.util.threepids import check_3pid_allowed
|
||||
from synapse.util.threepids import canonicalise_email, check_3pid_allowed
|
||||
|
||||
from ._base import client_patterns, interactive_auth_handler
|
||||
|
||||
@@ -83,7 +83,15 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
|
||||
client_secret = body["client_secret"]
|
||||
assert_valid_client_secret(client_secret)
|
||||
|
||||
email = body["email"]
|
||||
# Canonicalise the email address. The addresses are all stored canonicalised
|
||||
# in the database. This allows the user to reset his password without having to
|
||||
# know the exact spelling (eg. upper and lower case) of address in the database.
|
||||
# Stored in the database "foo@bar.com"
|
||||
# User requests with "FOO@bar.com" would raise a Not Found error
|
||||
try:
|
||||
email = canonicalise_email(body["email"])
|
||||
except ValueError as e:
|
||||
raise SynapseError(400, str(e))
|
||||
send_attempt = body["send_attempt"]
|
||||
next_link = body.get("next_link") # Optional param
|
||||
|
||||
@@ -94,6 +102,10 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
|
||||
Codes.THREEPID_DENIED,
|
||||
)
|
||||
|
||||
# The email will be sent to the stored address.
|
||||
# This avoids a potential account hijack by requesting a password reset to
|
||||
# an email address which is controlled by the attacker but which, after
|
||||
# canonicalisation, matches the one in our database.
|
||||
existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
|
||||
"email", email
|
||||
)
|
||||
@@ -274,10 +286,13 @@ class PasswordRestServlet(RestServlet):
|
||||
if "medium" not in threepid or "address" not in threepid:
|
||||
raise SynapseError(500, "Malformed threepid")
|
||||
if threepid["medium"] == "email":
|
||||
# For emails, transform the address to lowercase.
|
||||
# We store all email addreses as lowercase in the DB.
|
||||
# For emails, canonicalise the address.
|
||||
# We store all email addresses canonicalised in the DB.
|
||||
# (See add_threepid in synapse/handlers/auth.py)
|
||||
threepid["address"] = threepid["address"].lower()
|
||||
try:
|
||||
threepid["address"] = canonicalise_email(threepid["address"])
|
||||
except ValueError as e:
|
||||
raise SynapseError(400, str(e))
|
||||
# if using email, we must know about the email they're authing with!
|
||||
threepid_user_id = await self.datastore.get_user_id_by_threepid(
|
||||
threepid["medium"], threepid["address"]
|
||||
@@ -392,7 +407,16 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
|
||||
client_secret = body["client_secret"]
|
||||
assert_valid_client_secret(client_secret)
|
||||
|
||||
email = body["email"]
|
||||
# Canonicalise the email address. The addresses are all stored canonicalised
|
||||
# in the database.
|
||||
# This ensures that the validation email is sent to the canonicalised address
|
||||
# as it will later be entered into the database.
|
||||
# Otherwise the email will be sent to "FOO@bar.com" and stored as
|
||||
# "foo@bar.com" in database.
|
||||
try:
|
||||
email = canonicalise_email(body["email"])
|
||||
except ValueError as e:
|
||||
raise SynapseError(400, str(e))
|
||||
send_attempt = body["send_attempt"]
|
||||
next_link = body.get("next_link") # Optional param
|
||||
|
||||
@@ -403,9 +427,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
|
||||
Codes.THREEPID_DENIED,
|
||||
)
|
||||
|
||||
existing_user_id = await self.store.get_user_id_by_threepid(
|
||||
"email", body["email"]
|
||||
)
|
||||
existing_user_id = await self.store.get_user_id_by_threepid("email", email)
|
||||
|
||||
if existing_user_id is not None:
|
||||
if self.config.request_token_inhibit_3pid_errors:
|
||||
|
||||
@@ -47,7 +47,7 @@ from synapse.push.mailer import load_jinja2_templates
|
||||
from synapse.util.msisdn import phone_number_to_msisdn
|
||||
from synapse.util.ratelimitutils import FederationRateLimiter
|
||||
from synapse.util.stringutils import assert_valid_client_secret, random_string
|
||||
from synapse.util.threepids import check_3pid_allowed
|
||||
from synapse.util.threepids import canonicalise_email, check_3pid_allowed
|
||||
|
||||
from ._base import client_patterns, interactive_auth_handler
|
||||
|
||||
@@ -116,7 +116,14 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
|
||||
client_secret = body["client_secret"]
|
||||
assert_valid_client_secret(client_secret)
|
||||
|
||||
email = body["email"]
|
||||
# For emails, canonicalise the address.
|
||||
# We store all email addresses canonicalised in the DB.
|
||||
# (See on_POST in EmailThreepidRequestTokenRestServlet
|
||||
# in synapse/rest/client/v2_alpha/account.py)
|
||||
try:
|
||||
email = canonicalise_email(body["email"])
|
||||
except ValueError as e:
|
||||
raise SynapseError(400, str(e))
|
||||
send_attempt = body["send_attempt"]
|
||||
next_link = body.get("next_link") # Optional param
|
||||
|
||||
@@ -128,7 +135,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
|
||||
)
|
||||
|
||||
existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
|
||||
"email", body["email"]
|
||||
"email", email
|
||||
)
|
||||
|
||||
if existing_user_id is not None:
|
||||
@@ -552,6 +559,15 @@ class RegisterRestServlet(RestServlet):
|
||||
if login_type in auth_result:
|
||||
medium = auth_result[login_type]["medium"]
|
||||
address = auth_result[login_type]["address"]
|
||||
# For emails, canonicalise the address.
|
||||
# We store all email addresses canonicalised in the DB.
|
||||
# (See on_POST in EmailThreepidRequestTokenRestServlet
|
||||
# in synapse/rest/client/v2_alpha/account.py)
|
||||
if medium == "email":
|
||||
try:
|
||||
address = canonicalise_email(address)
|
||||
except ValueError as e:
|
||||
raise SynapseError(400, str(e))
|
||||
|
||||
existing_user_id = await self.store.get_user_id_by_threepid(
|
||||
medium, address
|
||||
|
||||
@@ -26,11 +26,7 @@ from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import NotFoundError, StoreError, SynapseError
|
||||
from synapse.config import ConfigError
|
||||
from synapse.http.server import (
|
||||
DirectServeResource,
|
||||
respond_with_html,
|
||||
wrap_html_request_handler,
|
||||
)
|
||||
from synapse.http.server import DirectServeHtmlResource, respond_with_html
|
||||
from synapse.http.servlet import parse_string
|
||||
from synapse.types import UserID
|
||||
|
||||
@@ -48,7 +44,7 @@ else:
|
||||
return a == b
|
||||
|
||||
|
||||
class ConsentResource(DirectServeResource):
|
||||
class ConsentResource(DirectServeHtmlResource):
|
||||
"""A twisted Resource to display a privacy policy and gather consent to it
|
||||
|
||||
When accessed via GET, returns the privacy policy via a template.
|
||||
@@ -119,7 +115,6 @@ class ConsentResource(DirectServeResource):
|
||||
|
||||
self._hmac_secret = hs.config.form_secret.encode("utf-8")
|
||||
|
||||
@wrap_html_request_handler
|
||||
async def _async_render_GET(self, request):
|
||||
"""
|
||||
Args:
|
||||
@@ -160,7 +155,6 @@ class ConsentResource(DirectServeResource):
|
||||
except TemplateNotFound:
|
||||
raise NotFoundError("Unknown policy version")
|
||||
|
||||
@wrap_html_request_handler
|
||||
async def _async_render_POST(self, request):
|
||||
"""
|
||||
Args:
|
||||
|
||||
@@ -20,17 +20,13 @@ from signedjson.sign import sign_json
|
||||
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.crypto.keyring import ServerKeyFetcher
|
||||
from synapse.http.server import (
|
||||
DirectServeResource,
|
||||
respond_with_json_bytes,
|
||||
wrap_json_request_handler,
|
||||
)
|
||||
from synapse.http.server import DirectServeJsonResource, respond_with_json_bytes
|
||||
from synapse.http.servlet import parse_integer, parse_json_object_from_request
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RemoteKey(DirectServeResource):
|
||||
class RemoteKey(DirectServeJsonResource):
|
||||
"""HTTP resource for retreiving the TLS certificate and NACL signature
|
||||
verification keys for a collection of servers. Checks that the reported
|
||||
X.509 TLS certificate matches the one used in the HTTPS connection. Checks
|
||||
@@ -92,13 +88,14 @@ class RemoteKey(DirectServeResource):
|
||||
isLeaf = True
|
||||
|
||||
def __init__(self, hs):
|
||||
super().__init__()
|
||||
|
||||
self.fetcher = ServerKeyFetcher(hs)
|
||||
self.store = hs.get_datastore()
|
||||
self.clock = hs.get_clock()
|
||||
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
|
||||
self.config = hs.config
|
||||
|
||||
@wrap_json_request_handler
|
||||
async def _async_render_GET(self, request):
|
||||
if len(request.postpath) == 1:
|
||||
(server,) = request.postpath
|
||||
@@ -115,7 +112,6 @@ class RemoteKey(DirectServeResource):
|
||||
|
||||
await self.query_keys(request, query, query_remote_on_cache_miss=True)
|
||||
|
||||
@wrap_json_request_handler
|
||||
async def _async_render_POST(self, request):
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
|
||||
@@ -14,16 +14,10 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from twisted.web.server import NOT_DONE_YET
|
||||
|
||||
from synapse.http.server import (
|
||||
DirectServeResource,
|
||||
respond_with_json,
|
||||
wrap_json_request_handler,
|
||||
)
|
||||
from synapse.http.server import DirectServeJsonResource, respond_with_json
|
||||
|
||||
|
||||
class MediaConfigResource(DirectServeResource):
|
||||
class MediaConfigResource(DirectServeJsonResource):
|
||||
isLeaf = True
|
||||
|
||||
def __init__(self, hs):
|
||||
@@ -33,11 +27,9 @@ class MediaConfigResource(DirectServeResource):
|
||||
self.auth = hs.get_auth()
|
||||
self.limits_dict = {"m.upload.size": config.max_upload_size}
|
||||
|
||||
@wrap_json_request_handler
|
||||
async def _async_render_GET(self, request):
|
||||
await self.auth.get_user_by_req(request)
|
||||
respond_with_json(request, 200, self.limits_dict, send_cors=True)
|
||||
|
||||
def render_OPTIONS(self, request):
|
||||
async def _async_render_OPTIONS(self, request):
|
||||
respond_with_json(request, 200, {}, send_cors=True)
|
||||
return NOT_DONE_YET
|
||||
|
||||
@@ -15,18 +15,14 @@
|
||||
import logging
|
||||
|
||||
import synapse.http.servlet
|
||||
from synapse.http.server import (
|
||||
DirectServeResource,
|
||||
set_cors_headers,
|
||||
wrap_json_request_handler,
|
||||
)
|
||||
from synapse.http.server import DirectServeJsonResource, set_cors_headers
|
||||
|
||||
from ._base import parse_media_id, respond_404
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DownloadResource(DirectServeResource):
|
||||
class DownloadResource(DirectServeJsonResource):
|
||||
isLeaf = True
|
||||
|
||||
def __init__(self, hs, media_repo):
|
||||
@@ -34,10 +30,6 @@ class DownloadResource(DirectServeResource):
|
||||
self.media_repo = media_repo
|
||||
self.server_name = hs.hostname
|
||||
|
||||
# this is expected by @wrap_json_request_handler
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
@wrap_json_request_handler
|
||||
async def _async_render_GET(self, request):
|
||||
set_cors_headers(request)
|
||||
request.setHeader(
|
||||
|
||||
@@ -34,10 +34,9 @@ from twisted.internet.error import DNSLookupError
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.http.client import SimpleHttpClient
|
||||
from synapse.http.server import (
|
||||
DirectServeResource,
|
||||
DirectServeJsonResource,
|
||||
respond_with_json,
|
||||
respond_with_json_bytes,
|
||||
wrap_json_request_handler,
|
||||
)
|
||||
from synapse.http.servlet import parse_integer, parse_string
|
||||
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
||||
@@ -58,7 +57,7 @@ OG_TAG_NAME_MAXLEN = 50
|
||||
OG_TAG_VALUE_MAXLEN = 1000
|
||||
|
||||
|
||||
class PreviewUrlResource(DirectServeResource):
|
||||
class PreviewUrlResource(DirectServeJsonResource):
|
||||
isLeaf = True
|
||||
|
||||
def __init__(self, hs, media_repo, media_storage):
|
||||
@@ -108,11 +107,10 @@ class PreviewUrlResource(DirectServeResource):
|
||||
self._start_expire_url_cache_data, 10 * 1000
|
||||
)
|
||||
|
||||
def render_OPTIONS(self, request):
|
||||
async def _async_render_OPTIONS(self, request):
|
||||
request.setHeader(b"Allow", b"OPTIONS, GET")
|
||||
return respond_with_json(request, 200, {}, send_cors=True)
|
||||
respond_with_json(request, 200, {}, send_cors=True)
|
||||
|
||||
@wrap_json_request_handler
|
||||
async def _async_render_GET(self, request):
|
||||
|
||||
# XXX: if get_user_by_req fails, what should we do in an async render?
|
||||
|
||||
@@ -16,11 +16,7 @@
|
||||
|
||||
import logging
|
||||
|
||||
from synapse.http.server import (
|
||||
DirectServeResource,
|
||||
set_cors_headers,
|
||||
wrap_json_request_handler,
|
||||
)
|
||||
from synapse.http.server import DirectServeJsonResource, set_cors_headers
|
||||
from synapse.http.servlet import parse_integer, parse_string
|
||||
|
||||
from ._base import (
|
||||
@@ -34,7 +30,7 @@ from ._base import (
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ThumbnailResource(DirectServeResource):
|
||||
class ThumbnailResource(DirectServeJsonResource):
|
||||
isLeaf = True
|
||||
|
||||
def __init__(self, hs, media_repo, media_storage):
|
||||
@@ -45,9 +41,7 @@ class ThumbnailResource(DirectServeResource):
|
||||
self.media_storage = media_storage
|
||||
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
|
||||
self.server_name = hs.hostname
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
@wrap_json_request_handler
|
||||
async def _async_render_GET(self, request):
|
||||
set_cors_headers(request)
|
||||
server_name, media_id, _ = parse_media_id(request)
|
||||
|
||||
@@ -12,11 +12,10 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from io import BytesIO
|
||||
|
||||
import PIL.Image as Image
|
||||
from PIL import Image as Image
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -15,20 +15,14 @@
|
||||
|
||||
import logging
|
||||
|
||||
from twisted.web.server import NOT_DONE_YET
|
||||
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.http.server import (
|
||||
DirectServeResource,
|
||||
respond_with_json,
|
||||
wrap_json_request_handler,
|
||||
)
|
||||
from synapse.http.server import DirectServeJsonResource, respond_with_json
|
||||
from synapse.http.servlet import parse_string
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UploadResource(DirectServeResource):
|
||||
class UploadResource(DirectServeJsonResource):
|
||||
isLeaf = True
|
||||
|
||||
def __init__(self, hs, media_repo):
|
||||
@@ -43,11 +37,9 @@ class UploadResource(DirectServeResource):
|
||||
self.max_upload_size = hs.config.max_upload_size
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
def render_OPTIONS(self, request):
|
||||
async def _async_render_OPTIONS(self, request):
|
||||
respond_with_json(request, 200, {}, send_cors=True)
|
||||
return NOT_DONE_YET
|
||||
|
||||
@wrap_json_request_handler
|
||||
async def _async_render_POST(self, request):
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
# TODO: The checks here are a bit late. The content will have
|
||||
|
||||
@@ -14,18 +14,17 @@
|
||||
# limitations under the License.
|
||||
import logging
|
||||
|
||||
from synapse.http.server import DirectServeResource, wrap_html_request_handler
|
||||
from synapse.http.server import DirectServeHtmlResource
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OIDCCallbackResource(DirectServeResource):
|
||||
class OIDCCallbackResource(DirectServeHtmlResource):
|
||||
isLeaf = 1
|
||||
|
||||
def __init__(self, hs):
|
||||
super().__init__()
|
||||
self._oidc_handler = hs.get_oidc_handler()
|
||||
|
||||
@wrap_html_request_handler
|
||||
async def _async_render_GET(self, request):
|
||||
return await self._oidc_handler.handle_oidc_callback(request)
|
||||
await self._oidc_handler.handle_oidc_callback(request)
|
||||
|
||||
@@ -16,10 +16,10 @@
|
||||
from twisted.python import failure
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.http.server import DirectServeResource, return_html_error
|
||||
from synapse.http.server import DirectServeHtmlResource, return_html_error
|
||||
|
||||
|
||||
class SAML2ResponseResource(DirectServeResource):
|
||||
class SAML2ResponseResource(DirectServeHtmlResource):
|
||||
"""A Twisted web resource which handles the SAML response"""
|
||||
|
||||
isLeaf = 1
|
||||
|
||||
+1
-2
@@ -19,7 +19,6 @@ Injectable secrets module for Synapse.
|
||||
See https://docs.python.org/3/library/secrets.html#module-secrets for the API
|
||||
used in Python 3.6, and the API emulated in Python 2.7.
|
||||
"""
|
||||
|
||||
import sys
|
||||
|
||||
# secrets is available since python 3.6
|
||||
@@ -31,8 +30,8 @@ if sys.version_info[0:2] >= (3, 6):
|
||||
|
||||
|
||||
else:
|
||||
import os
|
||||
import binascii
|
||||
import os
|
||||
|
||||
class Secrets(object):
|
||||
def token_bytes(self, nbytes=32):
|
||||
|
||||
@@ -232,6 +232,8 @@ class HomeServer(object):
|
||||
|
||||
self._reactor = reactor
|
||||
self.hostname = hostname
|
||||
# the key we use to sign events and requests
|
||||
self.signing_key = config.key.signing_key[0]
|
||||
self.config = config
|
||||
self._building = {}
|
||||
self._listening_services = []
|
||||
|
||||
@@ -16,10 +16,12 @@
|
||||
|
||||
import itertools
|
||||
import logging
|
||||
from typing import Any, Iterable, Optional, Tuple
|
||||
from typing import Any, Iterable, List, Optional, Tuple
|
||||
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.replication.tcp.streams import BackfillStream, CachesStream
|
||||
from synapse.replication.tcp.streams.events import (
|
||||
EventsStream,
|
||||
EventsStreamCurrentStateRow,
|
||||
EventsStreamEventRow,
|
||||
)
|
||||
@@ -44,13 +46,30 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
|
||||
|
||||
async def get_all_updated_caches(
|
||||
self, instance_name: str, last_id: int, current_id: int, limit: int
|
||||
):
|
||||
"""Fetches cache invalidation rows between the two given IDs written
|
||||
by the given instance. Returns at most `limit` rows.
|
||||
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
|
||||
"""Get updates for caches replication stream.
|
||||
|
||||
Args:
|
||||
instance_name: The writer we want to fetch updates from. Unused
|
||||
here since there is only ever one writer.
|
||||
last_id: The token to fetch updates from. Exclusive.
|
||||
current_id: The token to fetch updates up to. Inclusive.
|
||||
limit: The requested limit for the number of rows to return. The
|
||||
function may return more or fewer rows.
|
||||
|
||||
Returns:
|
||||
A tuple consisting of: the updates, a token to use to fetch
|
||||
subsequent updates, and whether we returned fewer rows than exists
|
||||
between the requested tokens due to the limit.
|
||||
|
||||
The token returned can be used in a subsequent call to this
|
||||
function to get further updatees.
|
||||
|
||||
The updates are a list of 2-tuples of stream ID and the row data
|
||||
"""
|
||||
|
||||
if last_id == current_id:
|
||||
return []
|
||||
return [], current_id, False
|
||||
|
||||
def get_all_updated_caches_txn(txn):
|
||||
# We purposefully don't bound by the current token, as we want to
|
||||
@@ -64,17 +83,24 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
|
||||
LIMIT ?
|
||||
"""
|
||||
txn.execute(sql, (last_id, instance_name, limit))
|
||||
return txn.fetchall()
|
||||
updates = [(row[0], row[1:]) for row in txn]
|
||||
limited = False
|
||||
upto_token = current_id
|
||||
if len(updates) >= limit:
|
||||
upto_token = updates[-1][0]
|
||||
limited = True
|
||||
|
||||
return updates, upto_token, limited
|
||||
|
||||
return await self.db.runInteraction(
|
||||
"get_all_updated_caches", get_all_updated_caches_txn
|
||||
)
|
||||
|
||||
def process_replication_rows(self, stream_name, instance_name, token, rows):
|
||||
if stream_name == "events":
|
||||
if stream_name == EventsStream.NAME:
|
||||
for row in rows:
|
||||
self._process_event_stream_row(token, row)
|
||||
elif stream_name == "backfill":
|
||||
elif stream_name == BackfillStream.NAME:
|
||||
for row in rows:
|
||||
self._invalidate_caches_for_event(
|
||||
-token,
|
||||
@@ -86,7 +112,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
|
||||
row.relates_to,
|
||||
backfilled=True,
|
||||
)
|
||||
elif stream_name == "caches":
|
||||
elif stream_name == CachesStream.NAME:
|
||||
if self._cache_id_gen:
|
||||
self._cache_id_gen.advance(instance_name, token)
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import List, Tuple
|
||||
|
||||
from canonicaljson import json
|
||||
|
||||
@@ -207,31 +208,46 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
||||
"delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn
|
||||
)
|
||||
|
||||
def get_all_new_device_messages(self, last_pos, current_pos, limit):
|
||||
"""
|
||||
async def get_all_new_device_messages(
|
||||
self, instance_name: str, last_id: int, current_id: int, limit: int
|
||||
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
|
||||
"""Get updates for to device replication stream.
|
||||
|
||||
Args:
|
||||
last_pos(int):
|
||||
current_pos(int):
|
||||
limit(int):
|
||||
instance_name: The writer we want to fetch updates from. Unused
|
||||
here since there is only ever one writer.
|
||||
last_id: The token to fetch updates from. Exclusive.
|
||||
current_id: The token to fetch updates up to. Inclusive.
|
||||
limit: The requested limit for the number of rows to return. The
|
||||
function may return more or fewer rows.
|
||||
|
||||
Returns:
|
||||
A deferred list of rows from the device inbox
|
||||
A tuple consisting of: the updates, a token to use to fetch
|
||||
subsequent updates, and whether we returned fewer rows than exists
|
||||
between the requested tokens due to the limit.
|
||||
|
||||
The token returned can be used in a subsequent call to this
|
||||
function to get further updatees.
|
||||
|
||||
The updates are a list of 2-tuples of stream ID and the row data
|
||||
"""
|
||||
if last_pos == current_pos:
|
||||
return defer.succeed([])
|
||||
|
||||
if last_id == current_id:
|
||||
return [], current_id, False
|
||||
|
||||
def get_all_new_device_messages_txn(txn):
|
||||
# We limit like this as we might have multiple rows per stream_id, and
|
||||
# we want to make sure we always get all entries for any stream_id
|
||||
# we return.
|
||||
upper_pos = min(current_pos, last_pos + limit)
|
||||
upper_pos = min(current_id, last_id + limit)
|
||||
sql = (
|
||||
"SELECT max(stream_id), user_id"
|
||||
" FROM device_inbox"
|
||||
" WHERE ? < stream_id AND stream_id <= ?"
|
||||
" GROUP BY user_id"
|
||||
)
|
||||
txn.execute(sql, (last_pos, upper_pos))
|
||||
rows = txn.fetchall()
|
||||
txn.execute(sql, (last_id, upper_pos))
|
||||
updates = [(row[0], row[1:]) for row in txn]
|
||||
|
||||
sql = (
|
||||
"SELECT max(stream_id), destination"
|
||||
@@ -239,15 +255,21 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
||||
" WHERE ? < stream_id AND stream_id <= ?"
|
||||
" GROUP BY destination"
|
||||
)
|
||||
txn.execute(sql, (last_pos, upper_pos))
|
||||
rows.extend(txn)
|
||||
txn.execute(sql, (last_id, upper_pos))
|
||||
updates.extend((row[0], row[1:]) for row in txn)
|
||||
|
||||
# Order by ascending stream ordering
|
||||
rows.sort()
|
||||
updates.sort()
|
||||
|
||||
return rows
|
||||
limited = False
|
||||
upto_token = current_id
|
||||
if len(updates) >= limit:
|
||||
upto_token = updates[-1][0]
|
||||
limited = True
|
||||
|
||||
return self.db.runInteraction(
|
||||
return updates, upto_token, limited
|
||||
|
||||
return await self.db.runInteraction(
|
||||
"get_all_new_device_messages", get_all_new_device_messages_txn
|
||||
)
|
||||
|
||||
|
||||
@@ -582,32 +582,58 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||
return set()
|
||||
|
||||
async def get_all_device_list_changes_for_remotes(
|
||||
self, from_key: int, to_key: int, limit: int,
|
||||
) -> List[Tuple[int, str]]:
|
||||
"""Return a list of `(stream_id, entity)` which is the combined list of
|
||||
changes to devices and which destinations need to be poked. Entity is
|
||||
either a user ID (starting with '@') or a remote destination.
|
||||
self, instance_name: str, last_id: int, current_id: int, limit: int
|
||||
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
|
||||
"""Get updates for device lists replication stream.
|
||||
|
||||
Args:
|
||||
instance_name: The writer we want to fetch updates from. Unused
|
||||
here since there is only ever one writer.
|
||||
last_id: The token to fetch updates from. Exclusive.
|
||||
current_id: The token to fetch updates up to. Inclusive.
|
||||
limit: The requested limit for the number of rows to return. The
|
||||
function may return more or fewer rows.
|
||||
|
||||
Returns:
|
||||
A tuple consisting of: the updates, a token to use to fetch
|
||||
subsequent updates, and whether we returned fewer rows than exists
|
||||
between the requested tokens due to the limit.
|
||||
|
||||
The token returned can be used in a subsequent call to this
|
||||
function to get further updatees.
|
||||
|
||||
The updates are a list of 2-tuples of stream ID and the row data
|
||||
"""
|
||||
|
||||
# This query Does The Right Thing where it'll correctly apply the
|
||||
# bounds to the inner queries.
|
||||
sql = """
|
||||
SELECT stream_id, entity FROM (
|
||||
SELECT stream_id, user_id AS entity FROM device_lists_stream
|
||||
UNION ALL
|
||||
SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
|
||||
) AS e
|
||||
WHERE ? < stream_id AND stream_id <= ?
|
||||
LIMIT ?
|
||||
"""
|
||||
if last_id == current_id:
|
||||
return [], current_id, False
|
||||
|
||||
return await self.db.execute(
|
||||
def _get_all_device_list_changes_for_remotes(txn):
|
||||
# This query Does The Right Thing where it'll correctly apply the
|
||||
# bounds to the inner queries.
|
||||
sql = """
|
||||
SELECT stream_id, entity FROM (
|
||||
SELECT stream_id, user_id AS entity FROM device_lists_stream
|
||||
UNION ALL
|
||||
SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
|
||||
) AS e
|
||||
WHERE ? < stream_id AND stream_id <= ?
|
||||
LIMIT ?
|
||||
"""
|
||||
|
||||
txn.execute(sql, (last_id, current_id, limit))
|
||||
updates = [(row[0], row[1:]) for row in txn]
|
||||
limited = False
|
||||
upto_token = current_id
|
||||
if len(updates) >= limit:
|
||||
upto_token = updates[-1][0]
|
||||
limited = True
|
||||
|
||||
return updates, upto_token, limited
|
||||
|
||||
return await self.db.runInteraction(
|
||||
"get_all_device_list_changes_for_remotes",
|
||||
None,
|
||||
sql,
|
||||
from_key,
|
||||
to_key,
|
||||
limit,
|
||||
_get_all_device_list_changes_for_remotes,
|
||||
)
|
||||
|
||||
@cached(max_entries=10000)
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Dict, List
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from canonicaljson import encode_canonical_json, json
|
||||
|
||||
@@ -479,34 +479,61 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
||||
|
||||
return result
|
||||
|
||||
def get_all_user_signature_changes_for_remotes(self, from_key, to_key, limit):
|
||||
"""Return a list of changes from the user signature stream to notify remotes.
|
||||
async def get_all_user_signature_changes_for_remotes(
|
||||
self, instance_name: str, last_id: int, current_id: int, limit: int
|
||||
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
|
||||
"""Get updates for groups replication stream.
|
||||
|
||||
Note that the user signature stream represents when a user signs their
|
||||
device with their user-signing key, which is not published to other
|
||||
users or servers, so no `destination` is needed in the returned
|
||||
list. However, this is needed to poke workers.
|
||||
|
||||
Args:
|
||||
from_key (int): the stream ID to start at (exclusive)
|
||||
to_key (int): the stream ID to end at (inclusive)
|
||||
instance_name: The writer we want to fetch updates from. Unused
|
||||
here since there is only ever one writer.
|
||||
last_id: The token to fetch updates from. Exclusive.
|
||||
current_id: The token to fetch updates up to. Inclusive.
|
||||
limit: The requested limit for the number of rows to return. The
|
||||
function may return more or fewer rows.
|
||||
|
||||
Returns:
|
||||
Deferred[list[(int,str)]] a list of `(stream_id, user_id)`
|
||||
A tuple consisting of: the updates, a token to use to fetch
|
||||
subsequent updates, and whether we returned fewer rows than exists
|
||||
between the requested tokens due to the limit.
|
||||
|
||||
The token returned can be used in a subsequent call to this
|
||||
function to get further updatees.
|
||||
|
||||
The updates are a list of 2-tuples of stream ID and the row data
|
||||
"""
|
||||
sql = """
|
||||
SELECT stream_id, from_user_id AS user_id
|
||||
FROM user_signature_stream
|
||||
WHERE ? < stream_id AND stream_id <= ?
|
||||
ORDER BY stream_id ASC
|
||||
LIMIT ?
|
||||
"""
|
||||
return self.db.execute(
|
||||
|
||||
if last_id == current_id:
|
||||
return [], current_id, False
|
||||
|
||||
def _get_all_user_signature_changes_for_remotes_txn(txn):
|
||||
sql = """
|
||||
SELECT stream_id, from_user_id AS user_id
|
||||
FROM user_signature_stream
|
||||
WHERE ? < stream_id AND stream_id <= ?
|
||||
ORDER BY stream_id ASC
|
||||
LIMIT ?
|
||||
"""
|
||||
txn.execute(sql, (last_id, current_id, limit))
|
||||
|
||||
updates = [(row[0], (row[1:])) for row in txn]
|
||||
|
||||
limited = False
|
||||
upto_token = current_id
|
||||
if len(updates) >= limit:
|
||||
upto_token = updates[-1][0]
|
||||
limited = True
|
||||
|
||||
return updates, upto_token, limited
|
||||
|
||||
return await self.db.runInteraction(
|
||||
"get_all_user_signature_changes_for_remotes",
|
||||
None,
|
||||
sql,
|
||||
from_key,
|
||||
to_key,
|
||||
limit,
|
||||
_get_all_user_signature_changes_for_remotes_txn,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import itertools
|
||||
import logging
|
||||
from collections import OrderedDict, namedtuple
|
||||
@@ -28,12 +27,7 @@ from prometheus_client import Counter
|
||||
from twisted.internet import defer
|
||||
|
||||
import synapse.metrics
|
||||
from synapse.api.constants import (
|
||||
EventContentFields,
|
||||
EventTypes,
|
||||
Membership,
|
||||
RelationTypes,
|
||||
)
|
||||
from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
|
||||
from synapse.api.room_versions import RoomVersions
|
||||
from synapse.crypto.event_signing import compute_event_reference_hash
|
||||
from synapse.events import EventBase # noqa: F401
|
||||
@@ -48,8 +42,8 @@ from synapse.util.frozenutils import frozendict_json_encoder
|
||||
from synapse.util.iterutils import batch_iter
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.storage.data_stores.main import DataStore
|
||||
from synapse.server import HomeServer
|
||||
from synapse.storage.data_stores.main import DataStore
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -820,7 +814,6 @@ class PersistEventsStore:
|
||||
"event_reference_hashes",
|
||||
"event_search",
|
||||
"event_to_state_groups",
|
||||
"local_invites",
|
||||
"state_events",
|
||||
"rejections",
|
||||
"redactions",
|
||||
@@ -1197,65 +1190,27 @@ class PersistEventsStore:
|
||||
(event.state_key,),
|
||||
)
|
||||
|
||||
# We update the local_invites table only if the event is "current",
|
||||
# i.e., its something that has just happened. If the event is an
|
||||
# outlier it is only current if its an "out of band membership",
|
||||
# like a remote invite or a rejection of a remote invite.
|
||||
is_new_state = not backfilled and (
|
||||
not event.internal_metadata.is_outlier()
|
||||
or event.internal_metadata.is_out_of_band_membership()
|
||||
)
|
||||
is_mine = self.is_mine_id(event.state_key)
|
||||
if is_new_state and is_mine:
|
||||
if event.membership == Membership.INVITE:
|
||||
self.db.simple_insert_txn(
|
||||
txn,
|
||||
table="local_invites",
|
||||
values={
|
||||
"event_id": event.event_id,
|
||||
"invitee": event.state_key,
|
||||
"inviter": event.sender,
|
||||
"room_id": event.room_id,
|
||||
"stream_id": event.internal_metadata.stream_ordering,
|
||||
},
|
||||
)
|
||||
else:
|
||||
sql = (
|
||||
"UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE"
|
||||
" room_id = ? AND invitee = ? AND locally_rejected is NULL"
|
||||
" AND replaced_by is NULL"
|
||||
)
|
||||
|
||||
txn.execute(
|
||||
sql,
|
||||
(
|
||||
event.internal_metadata.stream_ordering,
|
||||
event.event_id,
|
||||
event.room_id,
|
||||
event.state_key,
|
||||
),
|
||||
)
|
||||
|
||||
# We also update the `local_current_membership` table with
|
||||
# latest invite info. This will usually get updated by the
|
||||
# `current_state_events` handling, unless its an outlier.
|
||||
if event.internal_metadata.is_outlier():
|
||||
# This should only happen for out of band memberships, so
|
||||
# we add a paranoia check.
|
||||
assert event.internal_metadata.is_out_of_band_membership()
|
||||
|
||||
self.db.simple_upsert_txn(
|
||||
txn,
|
||||
table="local_current_membership",
|
||||
keyvalues={
|
||||
"room_id": event.room_id,
|
||||
"user_id": event.state_key,
|
||||
},
|
||||
values={
|
||||
"event_id": event.event_id,
|
||||
"membership": event.membership,
|
||||
},
|
||||
)
|
||||
# We update the local_current_membership table only if the event is
|
||||
# "current", i.e., its something that has just happened.
|
||||
#
|
||||
# This will usually get updated by the `current_state_events` handling,
|
||||
# unless its an outlier, and an outlier is only "current" if it's an "out of
|
||||
# band membership", like a remote invite or a rejection of a remote invite.
|
||||
if (
|
||||
self.is_mine_id(event.state_key)
|
||||
and not backfilled
|
||||
and event.internal_metadata.is_outlier()
|
||||
and event.internal_metadata.is_out_of_band_membership()
|
||||
):
|
||||
self.db.simple_upsert_txn(
|
||||
txn,
|
||||
table="local_current_membership",
|
||||
keyvalues={"room_id": event.room_id, "user_id": event.state_key},
|
||||
values={
|
||||
"event_id": event.event_id,
|
||||
"membership": event.membership,
|
||||
},
|
||||
)
|
||||
|
||||
def _handle_event_relations(self, txn, event):
|
||||
"""Handles inserting relation data during peristence of events
|
||||
@@ -1586,31 +1541,3 @@ class PersistEventsStore:
|
||||
if not ev.internal_metadata.is_outlier()
|
||||
],
|
||||
)
|
||||
|
||||
async def locally_reject_invite(self, user_id: str, room_id: str) -> int:
|
||||
"""Mark the invite has having been rejected even though we failed to
|
||||
create a leave event for it.
|
||||
"""
|
||||
|
||||
sql = (
|
||||
"UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
|
||||
" room_id = ? AND invitee = ? AND locally_rejected is NULL"
|
||||
" AND replaced_by is NULL"
|
||||
)
|
||||
|
||||
def f(txn, stream_ordering):
|
||||
txn.execute(sql, (stream_ordering, True, room_id, user_id))
|
||||
|
||||
# We also clear this entry from `local_current_membership`.
|
||||
# Ideally we'd point to a leave event, but we don't have one, so
|
||||
# nevermind.
|
||||
self.db.simple_delete_txn(
|
||||
txn,
|
||||
table="local_current_membership",
|
||||
keyvalues={"room_id": room_id, "user_id": user_id},
|
||||
)
|
||||
|
||||
with self._stream_id_gen.get_next() as stream_ordering:
|
||||
await self.db.runInteraction("locally_reject_invite", f, stream_ordering)
|
||||
|
||||
return stream_ordering
|
||||
|
||||
@@ -38,6 +38,8 @@ from synapse.events.utils import prune_event
|
||||
from synapse.logging.context import PreserveLoggingContext, current_context
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
|
||||
from synapse.replication.tcp.streams import BackfillStream
|
||||
from synapse.replication.tcp.streams.events import EventsStream
|
||||
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
|
||||
from synapse.storage.database import Database
|
||||
from synapse.storage.util.id_generators import StreamIdGenerator
|
||||
@@ -80,10 +82,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
# We are the process in charge of generating stream ids for events,
|
||||
# so instantiate ID generators based on the database
|
||||
self._stream_id_gen = StreamIdGenerator(
|
||||
db_conn,
|
||||
"events",
|
||||
"stream_ordering",
|
||||
extra_tables=[("local_invites", "stream_id")],
|
||||
db_conn, "events", "stream_ordering",
|
||||
)
|
||||
self._backfill_id_gen = StreamIdGenerator(
|
||||
db_conn,
|
||||
@@ -113,9 +112,9 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
self._event_fetch_ongoing = 0
|
||||
|
||||
def process_replication_rows(self, stream_name, instance_name, token, rows):
|
||||
if stream_name == "events":
|
||||
if stream_name == EventsStream.NAME:
|
||||
self._stream_id_gen.advance(token)
|
||||
elif stream_name == "backfill":
|
||||
elif stream_name == BackfillStream.NAME:
|
||||
self._backfill_id_gen.advance(-token)
|
||||
|
||||
super().process_replication_rows(stream_name, instance_name, token, rows)
|
||||
|
||||
@@ -14,6 +14,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Tuple
|
||||
|
||||
from canonicaljson import json
|
||||
|
||||
from twisted.internet import defer
|
||||
@@ -526,13 +528,35 @@ class GroupServerWorkerStore(SQLBaseStore):
|
||||
"get_groups_changes_for_user", _get_groups_changes_for_user_txn
|
||||
)
|
||||
|
||||
def get_all_groups_changes(self, from_token, to_token, limit):
|
||||
from_token = int(from_token)
|
||||
has_changed = self._group_updates_stream_cache.has_any_entity_changed(
|
||||
from_token
|
||||
)
|
||||
async def get_all_groups_changes(
|
||||
self, instance_name: str, last_id: int, current_id: int, limit: int
|
||||
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
|
||||
"""Get updates for groups replication stream.
|
||||
|
||||
Args:
|
||||
instance_name: The writer we want to fetch updates from. Unused
|
||||
here since there is only ever one writer.
|
||||
last_id: The token to fetch updates from. Exclusive.
|
||||
current_id: The token to fetch updates up to. Inclusive.
|
||||
limit: The requested limit for the number of rows to return. The
|
||||
function may return more or fewer rows.
|
||||
|
||||
Returns:
|
||||
A tuple consisting of: the updates, a token to use to fetch
|
||||
subsequent updates, and whether we returned fewer rows than exists
|
||||
between the requested tokens due to the limit.
|
||||
|
||||
The token returned can be used in a subsequent call to this
|
||||
function to get further updatees.
|
||||
|
||||
The updates are a list of 2-tuples of stream ID and the row data
|
||||
"""
|
||||
|
||||
last_id = int(last_id)
|
||||
has_changed = self._group_updates_stream_cache.has_any_entity_changed(last_id)
|
||||
|
||||
if not has_changed:
|
||||
return defer.succeed([])
|
||||
return [], current_id, False
|
||||
|
||||
def _get_all_groups_changes_txn(txn):
|
||||
sql = """
|
||||
@@ -541,13 +565,21 @@ class GroupServerWorkerStore(SQLBaseStore):
|
||||
WHERE ? < stream_id AND stream_id <= ?
|
||||
LIMIT ?
|
||||
"""
|
||||
txn.execute(sql, (from_token, to_token, limit))
|
||||
return [
|
||||
(stream_id, group_id, user_id, gtype, json.loads(content_json))
|
||||
txn.execute(sql, (last_id, current_id, limit))
|
||||
updates = [
|
||||
(stream_id, (group_id, user_id, gtype, json.loads(content_json)))
|
||||
for stream_id, group_id, user_id, gtype, content_json in txn
|
||||
]
|
||||
|
||||
return self.db.runInteraction(
|
||||
limited = False
|
||||
upto_token = current_id
|
||||
if len(updates) >= limit:
|
||||
upto_token = updates[-1][0]
|
||||
limited = True
|
||||
|
||||
return updates, upto_token, limited
|
||||
|
||||
return await self.db.runInteraction(
|
||||
"get_all_groups_changes", _get_all_groups_changes_txn
|
||||
)
|
||||
|
||||
|
||||
@@ -361,7 +361,6 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
|
||||
"event_push_summary",
|
||||
"pusher_throttle",
|
||||
"group_summary_rooms",
|
||||
"local_invites",
|
||||
"room_account_data",
|
||||
"room_tags",
|
||||
"local_current_membership",
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import Iterable, Iterator
|
||||
from typing import Iterable, Iterator, List, Tuple
|
||||
|
||||
from canonicaljson import encode_canonical_json, json
|
||||
|
||||
@@ -98,77 +98,69 @@ class PusherWorkerStore(SQLBaseStore):
|
||||
rows = yield self.db.runInteraction("get_all_pushers", get_pushers)
|
||||
return rows
|
||||
|
||||
def get_all_updated_pushers(self, last_id, current_id, limit):
|
||||
if last_id == current_id:
|
||||
return defer.succeed(([], []))
|
||||
async def get_all_updated_pushers_rows(
|
||||
self, instance_name: str, last_id: int, current_id: int, limit: int
|
||||
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
|
||||
"""Get updates for pushers replication stream.
|
||||
|
||||
def get_all_updated_pushers_txn(txn):
|
||||
sql = (
|
||||
"SELECT id, user_name, access_token, profile_tag, kind,"
|
||||
" app_id, app_display_name, device_display_name, pushkey, ts,"
|
||||
" lang, data"
|
||||
" FROM pushers"
|
||||
" WHERE ? < id AND id <= ?"
|
||||
" ORDER BY id ASC LIMIT ?"
|
||||
)
|
||||
txn.execute(sql, (last_id, current_id, limit))
|
||||
updated = txn.fetchall()
|
||||
|
||||
sql = (
|
||||
"SELECT stream_id, user_id, app_id, pushkey"
|
||||
" FROM deleted_pushers"
|
||||
" WHERE ? < stream_id AND stream_id <= ?"
|
||||
" ORDER BY stream_id ASC LIMIT ?"
|
||||
)
|
||||
txn.execute(sql, (last_id, current_id, limit))
|
||||
deleted = txn.fetchall()
|
||||
|
||||
return updated, deleted
|
||||
|
||||
return self.db.runInteraction(
|
||||
"get_all_updated_pushers", get_all_updated_pushers_txn
|
||||
)
|
||||
|
||||
def get_all_updated_pushers_rows(self, last_id, current_id, limit):
|
||||
"""Get all the pushers that have changed between the given tokens.
|
||||
Args:
|
||||
instance_name: The writer we want to fetch updates from. Unused
|
||||
here since there is only ever one writer.
|
||||
last_id: The token to fetch updates from. Exclusive.
|
||||
current_id: The token to fetch updates up to. Inclusive.
|
||||
limit: The requested limit for the number of rows to return. The
|
||||
function may return more or fewer rows.
|
||||
|
||||
Returns:
|
||||
Deferred(list(tuple)): each tuple consists of:
|
||||
stream_id (str)
|
||||
user_id (str)
|
||||
app_id (str)
|
||||
pushkey (str)
|
||||
was_deleted (bool): whether the pusher was added/updated (False)
|
||||
or deleted (True)
|
||||
A tuple consisting of: the updates, a token to use to fetch
|
||||
subsequent updates, and whether we returned fewer rows than exists
|
||||
between the requested tokens due to the limit.
|
||||
|
||||
The token returned can be used in a subsequent call to this
|
||||
function to get further updatees.
|
||||
|
||||
The updates are a list of 2-tuples of stream ID and the row data
|
||||
"""
|
||||
|
||||
if last_id == current_id:
|
||||
return defer.succeed([])
|
||||
return [], current_id, False
|
||||
|
||||
def get_all_updated_pushers_rows_txn(txn):
|
||||
sql = (
|
||||
"SELECT id, user_name, app_id, pushkey"
|
||||
" FROM pushers"
|
||||
" WHERE ? < id AND id <= ?"
|
||||
" ORDER BY id ASC LIMIT ?"
|
||||
)
|
||||
sql = """
|
||||
SELECT id, user_name, app_id, pushkey
|
||||
FROM pushers
|
||||
WHERE ? < id AND id <= ?
|
||||
ORDER BY id ASC LIMIT ?
|
||||
"""
|
||||
txn.execute(sql, (last_id, current_id, limit))
|
||||
results = [list(row) + [False] for row in txn]
|
||||
updates = [
|
||||
(stream_id, (user_name, app_id, pushkey, False))
|
||||
for stream_id, user_name, app_id, pushkey in txn
|
||||
]
|
||||
|
||||
sql = (
|
||||
"SELECT stream_id, user_id, app_id, pushkey"
|
||||
" FROM deleted_pushers"
|
||||
" WHERE ? < stream_id AND stream_id <= ?"
|
||||
" ORDER BY stream_id ASC LIMIT ?"
|
||||
)
|
||||
sql = """
|
||||
SELECT stream_id, user_id, app_id, pushkey
|
||||
FROM deleted_pushers
|
||||
WHERE ? < stream_id AND stream_id <= ?
|
||||
ORDER BY stream_id ASC LIMIT ?
|
||||
"""
|
||||
txn.execute(sql, (last_id, current_id, limit))
|
||||
updates.extend(
|
||||
(stream_id, (user_name, app_id, pushkey, True))
|
||||
for stream_id, user_name, app_id, pushkey in txn
|
||||
)
|
||||
|
||||
results.extend(list(row) + [True] for row in txn)
|
||||
results.sort() # Sort so that they're ordered by stream id
|
||||
updates.sort() # Sort so that they're ordered by stream id
|
||||
|
||||
return results
|
||||
limited = False
|
||||
upper_bound = current_id
|
||||
if len(updates) >= limit:
|
||||
limited = True
|
||||
upper_bound = updates[-1][0]
|
||||
|
||||
return self.db.runInteraction(
|
||||
return updates, upper_bound, limited
|
||||
|
||||
return await self.db.runInteraction(
|
||||
"get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn
|
||||
)
|
||||
|
||||
|
||||
@@ -803,7 +803,32 @@ class RoomWorkerStore(SQLBaseStore):
|
||||
|
||||
return total_media_quarantined
|
||||
|
||||
def get_all_new_public_rooms(self, prev_id, current_id, limit):
|
||||
async def get_all_new_public_rooms(
|
||||
self, instance_name: str, last_id: int, current_id: int, limit: int
|
||||
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
|
||||
"""Get updates for public rooms replication stream.
|
||||
|
||||
Args:
|
||||
instance_name: The writer we want to fetch updates from. Unused
|
||||
here since there is only ever one writer.
|
||||
last_id: The token to fetch updates from. Exclusive.
|
||||
current_id: The token to fetch updates up to. Inclusive.
|
||||
limit: The requested limit for the number of rows to return. The
|
||||
function may return more or fewer rows.
|
||||
|
||||
Returns:
|
||||
A tuple consisting of: the updates, a token to use to fetch
|
||||
subsequent updates, and whether we returned fewer rows than exists
|
||||
between the requested tokens due to the limit.
|
||||
|
||||
The token returned can be used in a subsequent call to this
|
||||
function to get further updatees.
|
||||
|
||||
The updates are a list of 2-tuples of stream ID and the row data
|
||||
"""
|
||||
if last_id == current_id:
|
||||
return [], current_id, False
|
||||
|
||||
def get_all_new_public_rooms(txn):
|
||||
sql = """
|
||||
SELECT stream_id, room_id, visibility, appservice_id, network_id
|
||||
@@ -813,13 +838,17 @@ class RoomWorkerStore(SQLBaseStore):
|
||||
LIMIT ?
|
||||
"""
|
||||
|
||||
txn.execute(sql, (prev_id, current_id, limit))
|
||||
return txn.fetchall()
|
||||
txn.execute(sql, (last_id, current_id, limit))
|
||||
updates = [(row[0], row[1:]) for row in txn]
|
||||
limited = False
|
||||
upto_token = current_id
|
||||
if len(updates) >= limit:
|
||||
upto_token = updates[-1][0]
|
||||
limited = True
|
||||
|
||||
if prev_id == current_id:
|
||||
return defer.succeed([])
|
||||
return updates, upto_token, limited
|
||||
|
||||
return self.db.runInteraction(
|
||||
return await self.db.runInteraction(
|
||||
"get_all_new_public_rooms", get_all_new_public_rooms
|
||||
)
|
||||
|
||||
|
||||
@@ -11,11 +11,9 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
import simplejson
|
||||
|
||||
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
|
||||
from synapse.storage.prepare_database import get_statements
|
||||
|
||||
@@ -66,7 +64,7 @@ def run_create(cur, database_engine, *args, **kwargs):
|
||||
"max_stream_id_exclusive": max_stream_id + 1,
|
||||
"rows_inserted": 0,
|
||||
}
|
||||
progress_json = simplejson.dumps(progress)
|
||||
progress_json = json.dumps(progress)
|
||||
|
||||
sql = (
|
||||
"INSERT into background_updates (update_name, progress_json)"
|
||||
|
||||
@@ -11,11 +11,9 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
import simplejson
|
||||
|
||||
from synapse.storage.prepare_database import get_statements
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -45,7 +43,7 @@ def run_create(cur, database_engine, *args, **kwargs):
|
||||
"max_stream_id_exclusive": max_stream_id + 1,
|
||||
"rows_inserted": 0,
|
||||
}
|
||||
progress_json = simplejson.dumps(progress)
|
||||
progress_json = json.dumps(progress)
|
||||
|
||||
sql = (
|
||||
"INSERT into background_updates (update_name, progress_json)"
|
||||
|
||||
@@ -11,11 +11,9 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
import simplejson
|
||||
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
from synapse.storage.prepare_database import get_statements
|
||||
|
||||
@@ -50,7 +48,7 @@ def run_create(cur, database_engine, *args, **kwargs):
|
||||
"rows_inserted": 0,
|
||||
"have_added_indexes": False,
|
||||
}
|
||||
progress_json = simplejson.dumps(progress)
|
||||
progress_json = json.dumps(progress)
|
||||
|
||||
sql = (
|
||||
"INSERT into background_updates (update_name, progress_json)"
|
||||
|
||||
@@ -11,11 +11,9 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
import simplejson
|
||||
|
||||
from synapse.storage.prepare_database import get_statements
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -45,7 +43,7 @@ def run_create(cur, database_engine, *args, **kwargs):
|
||||
"max_stream_id_exclusive": max_stream_id + 1,
|
||||
"rows_inserted": 0,
|
||||
}
|
||||
progress_json = simplejson.dumps(progress)
|
||||
progress_json = json.dumps(progress)
|
||||
|
||||
sql = (
|
||||
"INSERT into background_updates (update_name, progress_json)"
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import List, Tuple
|
||||
|
||||
from canonicaljson import json
|
||||
|
||||
@@ -53,18 +54,32 @@ class TagsWorkerStore(AccountDataWorkerStore):
|
||||
|
||||
return deferred
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_all_updated_tags(self, last_id, current_id, limit):
|
||||
"""Get all the client tags that have changed on the server
|
||||
async def get_all_updated_tags(
|
||||
self, instance_name: str, last_id: int, current_id: int, limit: int
|
||||
) -> Tuple[List[Tuple[int, tuple]], int, bool]:
|
||||
"""Get updates for tags replication stream.
|
||||
|
||||
Args:
|
||||
last_id(int): The position to fetch from.
|
||||
current_id(int): The position to fetch up to.
|
||||
instance_name: The writer we want to fetch updates from. Unused
|
||||
here since there is only ever one writer.
|
||||
last_id: The token to fetch updates from. Exclusive.
|
||||
current_id: The token to fetch updates up to. Inclusive.
|
||||
limit: The requested limit for the number of rows to return. The
|
||||
function may return more or fewer rows.
|
||||
|
||||
Returns:
|
||||
A deferred list of tuples of stream_id int, user_id string,
|
||||
room_id string, tag string and content string.
|
||||
A tuple consisting of: the updates, a token to use to fetch
|
||||
subsequent updates, and whether we returned fewer rows than exists
|
||||
between the requested tokens due to the limit.
|
||||
|
||||
The token returned can be used in a subsequent call to this
|
||||
function to get further updatees.
|
||||
|
||||
The updates are a list of 2-tuples of stream ID and the row data
|
||||
"""
|
||||
|
||||
if last_id == current_id:
|
||||
return []
|
||||
return [], current_id, False
|
||||
|
||||
def get_all_updated_tags_txn(txn):
|
||||
sql = (
|
||||
@@ -76,7 +91,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
|
||||
txn.execute(sql, (last_id, current_id, limit))
|
||||
return txn.fetchall()
|
||||
|
||||
tag_ids = yield self.db.runInteraction(
|
||||
tag_ids = await self.db.runInteraction(
|
||||
"get_all_updated_tags", get_all_updated_tags_txn
|
||||
)
|
||||
|
||||
@@ -89,21 +104,27 @@ class TagsWorkerStore(AccountDataWorkerStore):
|
||||
for tag, content in txn:
|
||||
tags.append(json.dumps(tag) + ":" + content)
|
||||
tag_json = "{" + ",".join(tags) + "}"
|
||||
results.append((stream_id, user_id, room_id, tag_json))
|
||||
results.append((stream_id, (user_id, room_id, tag_json)))
|
||||
|
||||
return results
|
||||
|
||||
batch_size = 50
|
||||
results = []
|
||||
for i in range(0, len(tag_ids), batch_size):
|
||||
tags = yield self.db.runInteraction(
|
||||
tags = await self.db.runInteraction(
|
||||
"get_all_updated_tag_content",
|
||||
get_tag_content,
|
||||
tag_ids[i : i + batch_size],
|
||||
)
|
||||
results.extend(tags)
|
||||
|
||||
return results
|
||||
limited = False
|
||||
upto_token = current_id
|
||||
if len(results) >= limit:
|
||||
upto_token = results[-1][0]
|
||||
limited = True
|
||||
|
||||
return results, upto_token, limited
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_updated_tags(self, user_id, stream_id):
|
||||
|
||||
@@ -17,10 +17,10 @@ from typing import Any, Dict, Optional, Union
|
||||
|
||||
import attr
|
||||
|
||||
import synapse.util.stringutils as stringutils
|
||||
from synapse.api.errors import StoreError
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util import stringutils as stringutils
|
||||
|
||||
|
||||
@attr.s
|
||||
|
||||
@@ -92,10 +92,10 @@ class PostgresEngine(BaseDatabaseEngine):
|
||||
errors.append(" - 'COLLATE' is set to %r. Should be 'C'" % (collation,))
|
||||
|
||||
if ctype != "C":
|
||||
errors.append(" - 'CTYPE' is set to %r. Should be 'C'" % (collation,))
|
||||
errors.append(" - 'CTYPE' is set to %r. Should be 'C'" % (ctype,))
|
||||
|
||||
if errors:
|
||||
raise IncorrectDatabaseSetup(
|
||||
logger.warning(
|
||||
"Database is incorrectly configured:\n\n%s\n\n"
|
||||
"See docs/postgres.md for more information." % ("\n".join(errors))
|
||||
)
|
||||
|
||||
@@ -783,9 +783,3 @@ class EventsPersistenceStorage(object):
|
||||
|
||||
for user_id in left_users:
|
||||
await self.main_store.mark_remote_user_device_list_as_unsubscribed(user_id)
|
||||
|
||||
async def locally_reject_invite(self, user_id: str, room_id: str) -> int:
|
||||
"""Mark the invite has having been rejected even though we failed to
|
||||
create a leave event for it.
|
||||
"""
|
||||
return await self.persist_events_store.locally_reject_invite(user_id, room_id)
|
||||
|
||||
@@ -12,12 +12,10 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Iterable, Iterator, List, Tuple
|
||||
|
||||
from typing_extensions import Protocol
|
||||
|
||||
|
||||
"""
|
||||
Some very basic protocol definitions for the DB-API2 classes specified in PEP-249
|
||||
"""
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user