1
0

Merge branch 'develop' into anoa/v2_is

This commit is contained in:
Andrew Morgan
2019-08-30 13:18:37 +01:00
113 changed files with 3409 additions and 763 deletions
+3 -2
View File
@@ -6,6 +6,7 @@ services:
image: postgres:9.5
environment:
POSTGRES_PASSWORD: postgres
command: -c fsync=off
testenv:
image: python:3.5
@@ -16,6 +17,6 @@ services:
SYNAPSE_POSTGRES_HOST: postgres
SYNAPSE_POSTGRES_USER: postgres
SYNAPSE_POSTGRES_PASSWORD: postgres
working_dir: /app
working_dir: /src
volumes:
- ..:/app
- ..:/src
+3 -2
View File
@@ -6,6 +6,7 @@ services:
image: postgres:11
environment:
POSTGRES_PASSWORD: postgres
command: -c fsync=off
testenv:
image: python:3.7
@@ -16,6 +17,6 @@ services:
SYNAPSE_POSTGRES_HOST: postgres
SYNAPSE_POSTGRES_USER: postgres
SYNAPSE_POSTGRES_PASSWORD: postgres
working_dir: /app
working_dir: /src
volumes:
- ..:/app
- ..:/src
+3 -2
View File
@@ -6,6 +6,7 @@ services:
image: postgres:9.5
environment:
POSTGRES_PASSWORD: postgres
command: -c fsync=off
testenv:
image: python:3.7
@@ -16,6 +17,6 @@ services:
SYNAPSE_POSTGRES_HOST: postgres
SYNAPSE_POSTGRES_USER: postgres
SYNAPSE_POSTGRES_PASSWORD: postgres
working_dir: /app
working_dir: /src
volumes:
- ..:/app
- ..:/src
+1 -1
View File
@@ -27,7 +27,7 @@ git config --global user.name "A robot"
# Fetch and merge. If it doesn't work, it will raise due to set -e.
git fetch -u origin $GITBASE
git merge --no-edit origin/$GITBASE
git merge --no-edit --no-commit origin/$GITBASE
# Show what we are after.
git --no-pager show -s
+82 -12
View File
@@ -1,8 +1,7 @@
env:
CODECOV_TOKEN: "2dd7eb9b-0eda-45fe-a47c-9b5ac040045f"
COVERALLS_REPO_TOKEN: wsJWOby6j0uCYFiCes3r0XauxO27mx8lD
steps:
- command:
- "python -m pip install tox"
- "tox -e check_codestyle"
@@ -10,6 +9,7 @@ steps:
plugins:
- docker#v3.0.1:
image: "python:3.6"
mount-buildkite-agent: false
- command:
- "python -m pip install tox"
@@ -18,6 +18,7 @@ steps:
plugins:
- docker#v3.0.1:
image: "python:3.6"
mount-buildkite-agent: false
- command:
- "python -m pip install tox"
@@ -26,6 +27,7 @@ steps:
plugins:
- docker#v3.0.1:
image: "python:3.6"
mount-buildkite-agent: false
- command:
- "python -m pip install tox"
@@ -36,6 +38,7 @@ steps:
- docker#v3.0.1:
image: "python:3.6"
propagate-environment: true
mount-buildkite-agent: false
- command:
- "python -m pip install tox"
@@ -44,21 +47,35 @@ steps:
plugins:
- docker#v3.0.1:
image: "python:3.6"
mount-buildkite-agent: false
- command:
- "python -m pip install tox"
- "tox -e mypy"
label: ":mypy: mypy"
plugins:
- docker#v3.0.1:
image: "python:3.5"
mount-buildkite-agent: false
- wait
- command:
- "apt-get update && apt-get install -y python3.5 python3.5-dev python3-pip libxml2-dev libxslt-dev zlib1g-dev"
- "python3.5 -m pip install tox"
- "tox -e py35-old,codecov"
- "tox -e py35-old,combine"
label: ":python: 3.5 / SQLite / Old Deps"
env:
TRIAL_FLAGS: "-j 2"
LANG: "C.UTF-8"
plugins:
- docker#v3.0.1:
image: "ubuntu:xenial" # We use xenail to get an old sqlite and python
image: "ubuntu:xenial" # We use xenial to get an old sqlite and python
workdir: "/src"
mount-buildkite-agent: false
propagate-environment: true
- matrix-org/coveralls#v1.0:
parallel: "true"
retry:
automatic:
- exit_status: -1
@@ -68,14 +85,18 @@ steps:
- command:
- "python -m pip install tox"
- "tox -e py35,codecov"
- "tox -e py35,combine"
label: ":python: 3.5 / SQLite"
env:
TRIAL_FLAGS: "-j 2"
plugins:
- docker#v3.0.1:
image: "python:3.5"
workdir: "/src"
mount-buildkite-agent: false
propagate-environment: true
- matrix-org/coveralls#v1.0:
parallel: "true"
retry:
automatic:
- exit_status: -1
@@ -85,14 +106,18 @@ steps:
- command:
- "python -m pip install tox"
- "tox -e py36,codecov"
- "tox -e py36,combine"
label: ":python: 3.6 / SQLite"
env:
TRIAL_FLAGS: "-j 2"
plugins:
- docker#v3.0.1:
image: "python:3.6"
workdir: "/src"
mount-buildkite-agent: false
propagate-environment: true
- matrix-org/coveralls#v1.0:
parallel: "true"
retry:
automatic:
- exit_status: -1
@@ -102,14 +127,18 @@ steps:
- command:
- "python -m pip install tox"
- "tox -e py37,codecov"
- "tox -e py37,combine"
label: ":python: 3.7 / SQLite"
env:
TRIAL_FLAGS: "-j 2"
plugins:
- docker#v3.0.1:
image: "python:3.7"
workdir: "/src"
mount-buildkite-agent: false
propagate-environment: true
- matrix-org/coveralls#v1.0:
parallel: "true"
retry:
automatic:
- exit_status: -1
@@ -123,12 +152,14 @@ steps:
env:
TRIAL_FLAGS: "-j 8"
command:
- "bash -c 'python -m pip install tox && python -m tox -e py35-postgres,codecov'"
- "bash -c 'python -m pip install tox && python -m tox -e py35-postgres,combine'"
plugins:
- docker-compose#v2.1.0:
run: testenv
config:
- .buildkite/docker-compose.py35.pg95.yaml
- matrix-org/coveralls#v1.0:
parallel: "true"
retry:
automatic:
- exit_status: -1
@@ -142,12 +173,14 @@ steps:
env:
TRIAL_FLAGS: "-j 8"
command:
- "bash -c 'python -m pip install tox && python -m tox -e py37-postgres,codecov'"
- "bash -c 'python -m pip install tox && python -m tox -e py37-postgres,combine'"
plugins:
- docker-compose#v2.1.0:
run: testenv
config:
- .buildkite/docker-compose.py37.pg95.yaml
- matrix-org/coveralls#v1.0:
parallel: "true"
retry:
automatic:
- exit_status: -1
@@ -161,12 +194,14 @@ steps:
env:
TRIAL_FLAGS: "-j 8"
command:
- "bash -c 'python -m pip install tox && python -m tox -e py37-postgres,codecov'"
- "bash -c 'python -m pip install tox && python -m tox -e py37-postgres,combine'"
plugins:
- docker-compose#v2.1.0:
run: testenv
config:
- .buildkite/docker-compose.py37.pg11.yaml
- matrix-org/coveralls#v1.0:
parallel: "true"
retry:
automatic:
- exit_status: -1
@@ -174,7 +209,6 @@ steps:
- exit_status: 2
limit: 2
- label: "SyTest - :python: 3.5 / SQLite / Monolith"
agents:
queue: "medium"
@@ -187,6 +221,16 @@ steps:
propagate-environment: true
always-pull: true
workdir: "/src"
entrypoint: ["/bin/sh", "-e", "-c"]
mount-buildkite-agent: false
volumes: ["./logs:/logs"]
- artifacts#v1.2.0:
upload: [ "logs/**/*.log", "logs/**/*.log.*", "logs/coverage.xml" ]
- matrix-org/annotate:
path: "logs/annotate.md"
style: "error"
- matrix-org/coveralls#v1.0:
parallel: "true"
retry:
automatic:
- exit_status: -1
@@ -208,6 +252,16 @@ steps:
propagate-environment: true
always-pull: true
workdir: "/src"
entrypoint: ["/bin/sh", "-e", "-c"]
mount-buildkite-agent: false
volumes: ["./logs:/logs"]
- artifacts#v1.2.0:
upload: [ "logs/**/*.log", "logs/**/*.log.*", "logs/coverage.xml" ]
- matrix-org/annotate:
path: "logs/annotate.md"
style: "error"
- matrix-org/coveralls#v1.0:
parallel: "true"
retry:
automatic:
- exit_status: -1
@@ -232,9 +286,25 @@ steps:
propagate-environment: true
always-pull: true
workdir: "/src"
entrypoint: ["/bin/sh", "-e", "-c"]
mount-buildkite-agent: false
volumes: ["./logs:/logs"]
- artifacts#v1.2.0:
upload: [ "logs/**/*.log", "logs/**/*.log.*", "logs/coverage.xml" ]
- matrix-org/annotate:
path: "logs/annotate.md"
style: "error"
- matrix-org/coveralls#v1.0:
parallel: "true"
retry:
automatic:
- exit_status: -1
limit: 2
- exit_status: 2
limit: 2
- wait: ~
continue_on_failure: true
- label: Trigger webhook
command: "curl -k https://coveralls.io/webhook?repo_token=$COVERALLS_REPO_TOKEN -d \"payload[build_num]=$BUILDKITE_BUILD_NUMBER&payload[status]=done\""
+2 -1
View File
@@ -1,7 +1,8 @@
[run]
branch = True
parallel = True
include = synapse/*
include=$TOP/synapse/*
data_file = $TOP/.coverage
[report]
precision = 2
+3 -2
View File
@@ -20,6 +20,7 @@ _trial_temp*/
/*.signing.key
/env/
/homeserver*.yaml
/logs
/media_store/
/uploads
@@ -29,8 +30,9 @@ _trial_temp*/
/.vscode/
# build products
/.coverage*
!/.coveragerc
/.coverage*
/.mypy_cache/
/.tox
/build/
/coverage.*
@@ -38,4 +40,3 @@ _trial_temp*/
/docs/build/
/htmlcov
/pip-wheel-metadata/
+1
View File
@@ -0,0 +1 @@
Lay the groundwork for structured logging output.
+1
View File
@@ -0,0 +1 @@
Make Opentracing work in worker mode.
+1
View File
@@ -0,0 +1 @@
Add an admin API to purge old rooms from the database.
+1
View File
@@ -0,0 +1 @@
Add retry to well-known lookups if we have recently seen a valid well-known record for the server.
+1
View File
@@ -0,0 +1 @@
Pass opentracing contexts between servers when transmitting EDUs.
+1
View File
@@ -0,0 +1 @@
Opentracing for room and e2e keys.
+1
View File
@@ -0,0 +1 @@
Add unstable support for MSC2197 (filtered search requests over federation), in order to allow upcoming room directory query performance improvements.
+1
View File
@@ -0,0 +1 @@
Correctly retry all hosts returned from SRV when we fail to connect.
+1
View File
@@ -0,0 +1 @@
Remove shared secret registration from client/r0/register endpoint. Contributed by Awesome Technologies Innovationslabor GmbH.
+1
View File
@@ -0,0 +1 @@
Add admin API endpoint for setting whether or not a user is a server administrator.
+1
View File
@@ -0,0 +1 @@
Add missing index on users_in_public_rooms to improve the performance of directory queries.
+1
View File
@@ -0,0 +1 @@
Add config option to sign remote key query responses with a separate key.
+1
View File
@@ -0,0 +1 @@
Improve the logging when we have an error when fetching signing keys.
+1
View File
@@ -0,0 +1 @@
Add support for config templating.
+1
View File
@@ -0,0 +1 @@
Users with the type of "support" or "bot" are no longer required to consent.
+1
View File
@@ -0,0 +1 @@
Let synctl accept a directory of config files.
+1
View File
@@ -0,0 +1 @@
Increase max display name size to 256.
+1
View File
@@ -0,0 +1 @@
Fix error message which referred to public_base_url instead of public_baseurl. Thanks to @aaronraimist for the fix!
+1
View File
@@ -0,0 +1 @@
Add support for database engine-specific schema deltas, based on file extension.
+1
View File
@@ -0,0 +1 @@
Add admin API endpoint for getting whether or not a user is a server administrator.
+1
View File
@@ -0,0 +1 @@
Fix a cache-invalidation bug for worker-based deployments.
+1
View File
@@ -0,0 +1 @@
Update Buildkite pipeline to use plugins instead of buildkite-agent commands.
+1
View File
@@ -0,0 +1 @@
Add link in sample config to the logging config schema.
+5 -5
View File
@@ -17,7 +17,7 @@ By default, the image expects a single volume, located at ``/data``, that will h
* the appservices configuration.
You are free to use separate volumes depending on storage endpoints at your
disposal. For instance, ``/data/media`` coud be stored on a large but low
disposal. For instance, ``/data/media`` could be stored on a large but low
performance hdd storage while other files could be stored on high performance
endpoints.
@@ -27,8 +27,8 @@ configuration file there. Multiple application services are supported.
## Generating a configuration file
The first step is to genearte a valid config file. To do this, you can run the
image with the `generate` commandline option.
The first step is to generate a valid config file. To do this, you can run the
image with the `generate` command line option.
You will need to specify values for the `SYNAPSE_SERVER_NAME` and
`SYNAPSE_REPORT_STATS` environment variable, and mount a docker volume to store
@@ -59,7 +59,7 @@ The following environment variables are supported in `generate` mode:
* `SYNAPSE_CONFIG_PATH`: path to the file to be generated. Defaults to
`<SYNAPSE_CONFIG_DIR>/homeserver.yaml`.
* `SYNAPSE_DATA_DIR`: where the generated config will put persistent data
such as the datatase and media store. Defaults to `/data`.
such as the database and media store. Defaults to `/data`.
* `UID`, `GID`: the user id and group id to use for creating the data
directories. Defaults to `991`, `991`.
@@ -115,7 +115,7 @@ not given).
To migrate from a dynamic configuration file to a static one, run the docker
container once with the environment variables set, and `migrate_config`
commandline option. For example:
command line option. For example:
```
docker run -it --rm \
+18
View File
@@ -0,0 +1,18 @@
Purge room API
==============
This API will remove all trace of a room from your database.
All local users must have left the room before it can be removed.
The API is:
```
POST /_synapse/admin/v1/purge_room
{
"room_id": "!room:id"
}
```
You must authenticate using the access token of an admin user.
+39
View File
@@ -84,3 +84,42 @@ with a body of:
}
including an ``access_token`` of a server admin.
Get whether a user is a server administrator or not
===================================================
The api is::
GET /_synapse/admin/v1/users/<user_id>/admin
including an ``access_token`` of a server admin.
A response body like the following is returned:
.. code:: json
{
"admin": true
}
Change whether a user is a server administrator or not
======================================================
Note that you cannot demote yourself.
The api is::
PUT /_synapse/admin/v1/users/<user_id>/admin
with a body of:
.. code:: json
{
"admin": true
}
including an ``access_token`` of a server admin.
+25 -2
View File
@@ -32,7 +32,7 @@ It is up to the remote server to decide what it does with the spans
it creates. This is called the sampling policy and it can be configured
through Jaeger's settings.
For OpenTracing concepts see
For OpenTracing concepts see
https://opentracing.io/docs/overview/what-is-tracing/.
For more information about Jaeger's implementation see
@@ -79,7 +79,7 @@ Homeserver whitelisting
The homeserver whitelist is configured using regular expressions. A list of regular
expressions can be given and their union will be compared when propagating any
spans contexts to another homeserver.
spans contexts to another homeserver.
Though it's mostly safe to send and receive span contexts to and from
untrusted users since span contexts are usually opaque ids it can lead to
@@ -92,6 +92,29 @@ two problems, namely:
but that doesn't prevent another server sending you baggage which will be logged
to OpenTracing's logs.
==========
EDU FORMAT
==========
EDUs can contain tracing data in their content. This is not specced but
it could be of interest for other homeservers.
EDU format (if you're using jaeger):
.. code-block:: json
{
"edu_type": "type",
"content": {
"org.matrix.opentracing_context": {
"uber-trace-id": "fe57cf3e65083289"
}
}
}
Though you don't have to use jaeger you must inject the span context into
`org.matrix.opentracing_context` using the opentracing `Format.TEXT_MAP` inject method.
==================
Configuring Jaeger
==================
+18 -9
View File
@@ -205,9 +205,9 @@ listeners:
#
- port: 8008
tls: false
bind_addresses: ['::1', '127.0.0.1']
type: http
x_forwarded: true
bind_addresses: ['::1', '127.0.0.1']
resources:
- names: [client, federation]
@@ -392,10 +392,10 @@ listeners:
# permission to listen on port 80.
#
acme:
# ACME support is disabled by default. Uncomment the following line
# (and tls_certificate_path and tls_private_key_path above) to enable it.
# ACME support is disabled by default. Set this to `true` and uncomment
# tls_certificate_path and tls_private_key_path above to enable it.
#
#enabled: true
enabled: False
# Endpoint to use to request certificates. If you only want to test,
# use Let's Encrypt's staging url:
@@ -406,17 +406,17 @@ acme:
# Port number to listen on for the HTTP-01 challenge. Change this if
# you are forwarding connections through Apache/Nginx/etc.
#
#port: 80
port: 80
# Local addresses to listen on for incoming connections.
# Again, you may want to change this if you are forwarding connections
# through Apache/Nginx/etc.
#
#bind_addresses: ['::', '0.0.0.0']
bind_addresses: ['::', '0.0.0.0']
# How many days remaining on a certificate before it is renewed.
#
#reprovision_threshold: 30
reprovision_threshold: 30
# The domain that the certificate should be for. Normally this
# should be the same as your Matrix domain (i.e., 'server_name'), but,
@@ -430,7 +430,7 @@ acme:
#
# If not set, defaults to your 'server_name'.
#
#domain: matrix.example.com
domain: matrix.example.com
# file to use for the account key. This will be generated if it doesn't
# exist.
@@ -485,7 +485,8 @@ database:
## Logging ##
# A yaml python logging config file
# A yaml python logging config file as described by
# https://docs.python.org/3.7/library/logging.config.html#configuration-dictionary-schema
#
log_config: "CONFDIR/SERVERNAME.log.config"
@@ -1027,6 +1028,14 @@ signing_key_path: "CONFDIR/SERVERNAME.signing.key"
#
#trusted_key_servers:
# - server_name: "matrix.org"
#
# The signing keys to use when acting as a trusted key server. If not specified
# defaults to the server signing key.
#
# Can contain multiple keys, one per line.
#
#key_server_signing_keys_path: "key_server_signing_keys.key"
# Enable SAML2 for registration and login. Uses pysaml2.
+83
View File
@@ -0,0 +1,83 @@
# Structured Logging
A structured logging system can be useful when your logs are destined for a machine to parse and process. By maintaining its machine-readable characteristics, it enables more efficient searching and aggregations when consumed by software such as the "ELK stack".
Synapse's structured logging system is configured via the file that Synapse's `log_config` config option points to. The file must be YAML and contain `structured: true`. It must contain a list of "drains" (places where logs go to).
A structured logging configuration looks similar to the following:
```yaml
structured: true
loggers:
synapse:
level: INFO
synapse.storage.SQL:
level: WARNING
drains:
console:
type: console
location: stdout
file:
type: file_json
location: homeserver.log
```
The above logging config will set Synapse as 'INFO' logging level by default, with the SQL layer at 'WARNING', and will have two logging drains (to the console and to a file, stored as JSON).
## Drain Types
Drain types can be specified by the `type` key.
### `console`
Outputs human-readable logs to the console.
Arguments:
- `location`: Either `stdout` or `stderr`.
### `console_json`
Outputs machine-readable JSON logs to the console.
Arguments:
- `location`: Either `stdout` or `stderr`.
### `console_json_terse`
Outputs machine-readable JSON logs to the console, separated by newlines. This
format is not designed to be read and re-formatted into human-readable text, but
is optimal for a logging aggregation system.
Arguments:
- `location`: Either `stdout` or `stderr`.
### `file`
Outputs human-readable logs to a file.
Arguments:
- `location`: An absolute path to the file to log to.
### `file_json`
Outputs machine-readable logs to a file.
Arguments:
- `location`: An absolute path to the file to log to.
### `network_json_terse`
Delivers machine-readable JSON logs to a log aggregator over TCP. This is
compatible with LogStash's TCP input with the codec set to `json_lines`.
Arguments:
- `host`: Hostname or IP address of the log aggregator.
- `port`: Numerical port to contact on the host.
+2 -1
View File
@@ -122,7 +122,8 @@ class UserTypes(object):
"""
SUPPORT = "support"
ALL_USER_TYPES = (SUPPORT,)
BOT = "bot"
ALL_USER_TYPES = (SUPPORT, BOT)
class RelationTypes(object):
+7 -5
View File
@@ -36,18 +36,20 @@ from synapse.util.versionstring import get_version_string
logger = logging.getLogger(__name__)
# list of tuples of function, args list, kwargs dict
_sighup_callbacks = []
def register_sighup(func):
def register_sighup(func, *args, **kwargs):
"""
Register a function to be called when a SIGHUP occurs.
Args:
func (function): Function to be called when sent a SIGHUP signal.
Will be called with a single argument, the homeserver.
Will be called with a single default argument, the homeserver.
*args, **kwargs: args and kwargs to be passed to the target function.
"""
_sighup_callbacks.append(func)
_sighup_callbacks.append((func, args, kwargs))
def start_worker_reactor(appname, config, run_command=reactor.run):
@@ -248,8 +250,8 @@ def start(hs, listeners=None):
# we're not using systemd.
sdnotify(b"RELOADING=1")
for i in _sighup_callbacks:
i(hs)
for i, args, kwargs in _sighup_callbacks:
i(hs, *args, **kwargs)
sdnotify(b"READY=1")
+2 -2
View File
@@ -227,8 +227,6 @@ def start(config_options):
config.start_pushers = False
config.send_federation = False
setup_logging(config, use_worker_options=True)
synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts
database_engine = create_engine(config.database_config)
@@ -241,6 +239,8 @@ def start(config_options):
database_engine=database_engine,
)
setup_logging(ss, config, use_worker_options=True)
ss.setup()
# We use task.react as the basic run command as it correctly handles tearing
+2 -2
View File
@@ -141,8 +141,6 @@ def start(config_options):
assert config.worker_app == "synapse.app.appservice"
setup_logging(config, use_worker_options=True)
events.USE_FROZEN_DICTS = config.use_frozen_dicts
database_engine = create_engine(config.database_config)
@@ -167,6 +165,8 @@ def start(config_options):
database_engine=database_engine,
)
setup_logging(ps, config, use_worker_options=True)
ps.setup()
reactor.addSystemEventTrigger(
"before", "startup", _base.start, ps, config.worker_listeners
+2 -2
View File
@@ -179,8 +179,6 @@ def start(config_options):
assert config.worker_app == "synapse.app.client_reader"
setup_logging(config, use_worker_options=True)
events.USE_FROZEN_DICTS = config.use_frozen_dicts
database_engine = create_engine(config.database_config)
@@ -193,6 +191,8 @@ def start(config_options):
database_engine=database_engine,
)
setup_logging(ss, config, use_worker_options=True)
ss.setup()
reactor.addSystemEventTrigger(
"before", "startup", _base.start, ss, config.worker_listeners
+2 -2
View File
@@ -175,8 +175,6 @@ def start(config_options):
assert config.worker_replication_http_port is not None
setup_logging(config, use_worker_options=True)
# This should only be done on the user directory worker or the master
config.update_user_directory = False
@@ -192,6 +190,8 @@ def start(config_options):
database_engine=database_engine,
)
setup_logging(ss, config, use_worker_options=True)
ss.setup()
reactor.addSystemEventTrigger(
"before", "startup", _base.start, ss, config.worker_listeners
+2 -2
View File
@@ -160,8 +160,6 @@ def start(config_options):
assert config.worker_app == "synapse.app.federation_reader"
setup_logging(config, use_worker_options=True)
events.USE_FROZEN_DICTS = config.use_frozen_dicts
database_engine = create_engine(config.database_config)
@@ -174,6 +172,8 @@ def start(config_options):
database_engine=database_engine,
)
setup_logging(ss, config, use_worker_options=True)
ss.setup()
reactor.addSystemEventTrigger(
"before", "startup", _base.start, ss, config.worker_listeners
+2 -2
View File
@@ -171,8 +171,6 @@ def start(config_options):
assert config.worker_app == "synapse.app.federation_sender"
setup_logging(config, use_worker_options=True)
events.USE_FROZEN_DICTS = config.use_frozen_dicts
database_engine = create_engine(config.database_config)
@@ -197,6 +195,8 @@ def start(config_options):
database_engine=database_engine,
)
setup_logging(ss, config, use_worker_options=True)
ss.setup()
reactor.addSystemEventTrigger(
"before", "startup", _base.start, ss, config.worker_listeners
+2 -2
View File
@@ -232,8 +232,6 @@ def start(config_options):
assert config.worker_main_http_uri is not None
setup_logging(config, use_worker_options=True)
events.USE_FROZEN_DICTS = config.use_frozen_dicts
database_engine = create_engine(config.database_config)
@@ -246,6 +244,8 @@ def start(config_options):
database_engine=database_engine,
)
setup_logging(ss, config, use_worker_options=True)
ss.setup()
reactor.addSystemEventTrigger(
"before", "startup", _base.start, ss, config.worker_listeners
+2 -2
View File
@@ -341,8 +341,6 @@ def setup(config_options):
# generating config files and shouldn't try to continue.
sys.exit(0)
synapse.config.logger.setup_logging(config, use_worker_options=False)
events.USE_FROZEN_DICTS = config.use_frozen_dicts
database_engine = create_engine(config.database_config)
@@ -356,6 +354,8 @@ def setup(config_options):
database_engine=database_engine,
)
synapse.config.logger.setup_logging(hs, config, use_worker_options=False)
logger.info("Preparing database: %s...", config.database_config["name"])
try:
+2 -2
View File
@@ -155,8 +155,6 @@ def start(config_options):
"Please add ``enable_media_repo: false`` to the main config\n"
)
setup_logging(config, use_worker_options=True)
events.USE_FROZEN_DICTS = config.use_frozen_dicts
database_engine = create_engine(config.database_config)
@@ -169,6 +167,8 @@ def start(config_options):
database_engine=database_engine,
)
setup_logging(ss, config, use_worker_options=True)
ss.setup()
reactor.addSystemEventTrigger(
"before", "startup", _base.start, ss, config.worker_listeners
+2 -2
View File
@@ -184,8 +184,6 @@ def start(config_options):
assert config.worker_app == "synapse.app.pusher"
setup_logging(config, use_worker_options=True)
events.USE_FROZEN_DICTS = config.use_frozen_dicts
if config.start_pushers:
@@ -210,6 +208,8 @@ def start(config_options):
database_engine=database_engine,
)
setup_logging(ps, config, use_worker_options=True)
ps.setup()
def start():
+2 -2
View File
@@ -435,8 +435,6 @@ def start(config_options):
assert config.worker_app == "synapse.app.synchrotron"
setup_logging(config, use_worker_options=True)
synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts
database_engine = create_engine(config.database_config)
@@ -450,6 +448,8 @@ def start(config_options):
application_service_handler=SynchrotronApplicationService(),
)
setup_logging(ss, config, use_worker_options=True)
ss.setup()
reactor.addSystemEventTrigger(
"before", "startup", _base.start, ss, config.worker_listeners
+2 -2
View File
@@ -197,8 +197,6 @@ def start(config_options):
assert config.worker_app == "synapse.app.user_dir"
setup_logging(config, use_worker_options=True)
events.USE_FROZEN_DICTS = config.use_frozen_dicts
database_engine = create_engine(config.database_config)
@@ -223,6 +221,8 @@ def start(config_options):
database_engine=database_engine,
)
setup_logging(ss, config, use_worker_options=True)
ss.setup()
reactor.addSystemEventTrigger(
"before", "startup", _base.start, ss, config.worker_listeners
+4 -3
View File
@@ -13,8 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import ConfigError
from ._base import ConfigError, find_config_files
# export ConfigError if somebody does import *
# export ConfigError and find_config_files if somebody does
# import *
# this is largely a fudge to stop PEP8 moaning about the import
__all__ = ["ConfigError"]
__all__ = ["ConfigError", "find_config_files"]
+37
View File
@@ -181,6 +181,11 @@ class Config(object):
generate_secrets=False,
report_stats=None,
open_private_ports=False,
listeners=None,
database_conf=None,
tls_certificate_path=None,
tls_private_key_path=None,
acme_domain=None,
):
"""Build a default configuration file
@@ -207,6 +212,33 @@ class Config(object):
open_private_ports (bool): True to leave private ports (such as the non-TLS
HTTP listener) open to the internet.
listeners (list(dict)|None): A list of descriptions of the listeners
synapse should start with each of which specifies a port (str), a list of
resources (list(str)), tls (bool) and type (str). For example:
[{
"port": 8448,
"resources": [{"names": ["federation"]}],
"tls": True,
"type": "http",
},
{
"port": 443,
"resources": [{"names": ["client"]}],
"tls": False,
"type": "http",
}],
database (str|None): The database type to configure, either `psycog2`
or `sqlite3`.
tls_certificate_path (str|None): The path to the tls certificate.
tls_private_key_path (str|None): The path to the tls private key.
acme_domain (str|None): The domain acme will try to validate. If
specified acme will be enabled.
Returns:
str: the yaml config file
"""
@@ -220,6 +252,11 @@ class Config(object):
generate_secrets=generate_secrets,
report_stats=report_stats,
open_private_ports=open_private_ports,
listeners=listeners,
database_conf=database_conf,
tls_certificate_path=tls_certificate_path,
tls_private_key_path=tls_private_key_path,
acme_domain=acme_domain,
)
)
+19 -8
View File
@@ -13,6 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from textwrap import indent
import yaml
from ._base import Config
@@ -38,20 +41,28 @@ class DatabaseConfig(Config):
self.set_databasepath(config.get("database_path"))
def generate_config_section(self, data_dir_path, **kwargs):
database_path = os.path.join(data_dir_path, "homeserver.db")
return (
"""\
## Database ##
database:
# The database engine name
def generate_config_section(self, data_dir_path, database_conf, **kwargs):
if not database_conf:
database_path = os.path.join(data_dir_path, "homeserver.db")
database_conf = (
"""# The database engine name
name: "sqlite3"
# Arguments to pass to the engine
args:
# Path to the database
database: "%(database_path)s"
"""
% locals()
)
else:
database_conf = indent(yaml.dump(database_conf), " " * 10).lstrip()
return (
"""\
## Database ##
database:
%(database_conf)s
# Number of events to cache in memory.
#
#event_cache_size: 10K
+1 -1
View File
@@ -115,7 +115,7 @@ class EmailConfig(Config):
missing.append("email." + k)
if config.get("public_baseurl") is None:
missing.append("public_base_url")
missing.append("public_baseurl")
if len(missing) > 0:
raise RuntimeError(
+30 -4
View File
@@ -76,7 +76,7 @@ class KeyConfig(Config):
config_dir_path, config["server_name"] + ".signing.key"
)
self.signing_key = self.read_signing_key(signing_key_path)
self.signing_key = self.read_signing_keys(signing_key_path, "signing_key")
self.old_signing_keys = self.read_old_signing_keys(
config.get("old_signing_keys", {})
@@ -85,6 +85,14 @@ class KeyConfig(Config):
config.get("key_refresh_interval", "1d")
)
key_server_signing_keys_path = config.get("key_server_signing_keys_path")
if key_server_signing_keys_path:
self.key_server_signing_keys = self.read_signing_keys(
key_server_signing_keys_path, "key_server_signing_keys_path"
)
else:
self.key_server_signing_keys = list(self.signing_key)
# if neither trusted_key_servers nor perspectives are given, use the default.
if "perspectives" not in config and "trusted_key_servers" not in config:
key_servers = [{"server_name": "matrix.org"}]
@@ -210,16 +218,34 @@ class KeyConfig(Config):
#
#trusted_key_servers:
# - server_name: "matrix.org"
#
# The signing keys to use when acting as a trusted key server. If not specified
# defaults to the server signing key.
#
# Can contain multiple keys, one per line.
#
#key_server_signing_keys_path: "key_server_signing_keys.key"
"""
% locals()
)
def read_signing_key(self, signing_key_path):
signing_keys = self.read_file(signing_key_path, "signing_key")
def read_signing_keys(self, signing_key_path, name):
"""Read the signing keys in the given path.
Args:
signing_key_path (str)
name (str): Associated config key name
Returns:
list[SigningKey]
"""
signing_keys = self.read_file(signing_key_path, name)
try:
return read_signing_keys(signing_keys.splitlines(True))
except Exception as e:
raise ConfigError("Error reading signing_key: %s" % (str(e)))
raise ConfigError("Error reading %s: %s" % (name, str(e)))
def read_old_signing_keys(self, old_signing_keys):
keys = {}
+63 -43
View File
@@ -25,6 +25,10 @@ from twisted.logger import STDLibLogObserver, globalLogBeginner
import synapse
from synapse.app import _base as appbase
from synapse.logging._structured import (
reload_structured_logging,
setup_structured_logging,
)
from synapse.logging.context import LoggingContextFilter
from synapse.util.versionstring import get_version_string
@@ -85,7 +89,8 @@ class LoggingConfig(Config):
"""\
## Logging ##
# A yaml python logging config file
# A yaml python logging config file as described by
# https://docs.python.org/3.7/library/logging.config.html#configuration-dictionary-schema
#
log_config: "%(log_config)s"
"""
@@ -119,21 +124,10 @@ class LoggingConfig(Config):
log_config_file.write(DEFAULT_LOG_CONFIG.substitute(log_file=log_file))
def setup_logging(config, use_worker_options=False):
""" Set up python logging
Args:
config (LoggingConfig | synapse.config.workers.WorkerConfig):
configuration data
use_worker_options (bool): True to use the 'worker_log_config' option
instead of 'log_config'.
register_sighup (func | None): Function to call to register a
sighup handler.
def _setup_stdlib_logging(config, log_config):
"""
Set up Python stdlib logging.
"""
log_config = config.worker_log_config if use_worker_options else config.log_config
if log_config is None:
log_format = (
"%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s"
@@ -151,35 +145,10 @@ def setup_logging(config, use_worker_options=False):
handler.addFilter(LoggingContextFilter(request=""))
logger.addHandler(handler)
else:
logging.config.dictConfig(log_config)
def load_log_config():
with open(log_config, "r") as f:
logging.config.dictConfig(yaml.safe_load(f))
def sighup(*args):
# it might be better to use a file watcher or something for this.
load_log_config()
logging.info("Reloaded log config from %s due to SIGHUP", log_config)
load_log_config()
appbase.register_sighup(sighup)
# make sure that the first thing we log is a thing we can grep backwards
# for
logging.warn("***** STARTING SERVER *****")
logging.warn("Server %s version %s", sys.argv[0], get_version_string(synapse))
logging.info("Server hostname: %s", config.server_name)
# It's critical to point twisted's internal logging somewhere, otherwise it
# stacks up and leaks kup to 64K object;
# see: https://twistedmatrix.com/trac/ticket/8164
#
# Routing to the python logging framework could be a performance problem if
# the handlers blocked for a long time as python.logging is a blocking API
# see https://twistedmatrix.com/documents/current/core/howto/logger.html
# filed as https://github.com/matrix-org/synapse/issues/1727
#
# However this may not be too much of a problem if we are just writing to a file.
# Route Twisted's native logging through to the standard library logging
# system.
observer = STDLibLogObserver()
def _log(event):
@@ -201,3 +170,54 @@ def setup_logging(config, use_worker_options=False):
)
if not config.no_redirect_stdio:
print("Redirected stdout/stderr to logs")
def _reload_stdlib_logging(*args, log_config=None):
logger = logging.getLogger("")
if not log_config:
logger.warn("Reloaded a blank config?")
logging.config.dictConfig(log_config)
def setup_logging(hs, config, use_worker_options=False):
"""
Set up the logging subsystem.
Args:
config (LoggingConfig | synapse.config.workers.WorkerConfig):
configuration data
use_worker_options (bool): True to use the 'worker_log_config' option
instead of 'log_config'.
"""
log_config = config.worker_log_config if use_worker_options else config.log_config
def read_config(*args, callback=None):
if log_config is None:
return None
with open(log_config, "rb") as f:
log_config_body = yaml.safe_load(f.read())
if callback:
callback(log_config=log_config_body)
logging.info("Reloaded log config from %s due to SIGHUP", log_config)
return log_config_body
log_config_body = read_config()
if log_config_body and log_config_body.get("structured") is True:
setup_structured_logging(hs, config, log_config_body)
appbase.register_sighup(read_config, callback=reload_structured_logging)
else:
_setup_stdlib_logging(config, log_config_body)
appbase.register_sighup(read_config, callback=_reload_stdlib_logging)
# make sure that the first thing we log is a thing we can grep backwards
# for
logging.warn("***** STARTING SERVER *****")
logging.warn("Server %s version %s", sys.argv[0], get_version_string(synapse))
logging.info("Server hostname: %s", config.server_name)
+67 -17
View File
@@ -17,8 +17,11 @@
import logging
import os.path
import re
from textwrap import indent
import attr
import yaml
from netaddr import IPSet
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
@@ -352,7 +355,7 @@ class ServerConfig(Config):
return any(l["tls"] for l in self.listeners)
def generate_config_section(
self, server_name, data_dir_path, open_private_ports, **kwargs
self, server_name, data_dir_path, open_private_ports, listeners, **kwargs
):
_, bind_port = parse_and_validate_server_name(server_name)
if bind_port is not None:
@@ -366,11 +369,68 @@ class ServerConfig(Config):
# Bring DEFAULT_ROOM_VERSION into the local-scope for use in the
# default config string
default_room_version = DEFAULT_ROOM_VERSION
secure_listeners = []
unsecure_listeners = []
private_addresses = ["::1", "127.0.0.1"]
if listeners:
for listener in listeners:
if listener["tls"]:
secure_listeners.append(listener)
else:
# If we don't want open ports we need to bind the listeners
# to some address other than 0.0.0.0. Here we chose to use
# localhost.
# If the addresses are already bound we won't overwrite them
# however.
if not open_private_ports:
listener.setdefault("bind_addresses", private_addresses)
unsecure_http_binding = "port: %i\n tls: false" % (unsecure_port,)
if not open_private_ports:
unsecure_http_binding += (
"\n bind_addresses: ['::1', '127.0.0.1']"
unsecure_listeners.append(listener)
secure_http_bindings = indent(
yaml.dump(secure_listeners), " " * 10
).lstrip()
unsecure_http_bindings = indent(
yaml.dump(unsecure_listeners), " " * 10
).lstrip()
if not unsecure_listeners:
unsecure_http_bindings = (
"""- port: %(unsecure_port)s
tls: false
type: http
x_forwarded: true"""
% locals()
)
if not open_private_ports:
unsecure_http_bindings += (
"\n bind_addresses: ['::1', '127.0.0.1']"
)
unsecure_http_bindings += """
resources:
- names: [client, federation]
compress: false"""
if listeners:
# comment out this block
unsecure_http_bindings = "#" + re.sub(
"\n {10}",
lambda match: match.group(0) + "#",
unsecure_http_bindings,
)
if not secure_listeners:
secure_http_bindings = (
"""#- port: %(bind_port)s
# type: http
# tls: true
# resources:
# - names: [client, federation]"""
% locals()
)
return (
@@ -556,11 +616,7 @@ class ServerConfig(Config):
# will also need to give Synapse a TLS key and certificate: see the TLS section
# below.)
#
#- port: %(bind_port)s
# type: http
# tls: true
# resources:
# - names: [client, federation]
%(secure_http_bindings)s
# Unsecure HTTP listener: for when matrix traffic passes through a reverse proxy
# that unwraps TLS.
@@ -568,13 +624,7 @@ class ServerConfig(Config):
# If you plan to use a reverse proxy, please see
# https://github.com/matrix-org/synapse/blob/master/docs/reverse_proxy.rst.
#
- %(unsecure_http_binding)s
type: http
x_forwarded: true
resources:
- names: [client, federation]
compress: false
%(unsecure_http_bindings)s
# example additional_resources:
#
+38 -12
View File
@@ -239,12 +239,38 @@ class TlsConfig(Config):
self.tls_fingerprints.append({"sha256": sha256_fingerprint})
def generate_config_section(
self, config_dir_path, server_name, data_dir_path, **kwargs
self,
config_dir_path,
server_name,
data_dir_path,
tls_certificate_path,
tls_private_key_path,
acme_domain,
**kwargs
):
"""If the acme_domain is specified acme will be enabled.
If the TLS paths are not specified the default will be certs in the
config directory"""
base_key_name = os.path.join(config_dir_path, server_name)
tls_certificate_path = base_key_name + ".tls.crt"
tls_private_key_path = base_key_name + ".tls.key"
if bool(tls_certificate_path) != bool(tls_private_key_path):
raise ConfigError(
"Please specify both a cert path and a key path or neither."
)
tls_enabled = (
"" if tls_certificate_path and tls_private_key_path or acme_domain else "#"
)
if not tls_certificate_path:
tls_certificate_path = base_key_name + ".tls.crt"
if not tls_private_key_path:
tls_private_key_path = base_key_name + ".tls.key"
acme_enabled = bool(acme_domain)
acme_domain = "matrix.example.com"
default_acme_account_file = os.path.join(data_dir_path, "acme_account.key")
# this is to avoid the max line length. Sorrynotsorry
@@ -269,11 +295,11 @@ class TlsConfig(Config):
# instance, if using certbot, use `fullchain.pem` as your certificate,
# not `cert.pem`).
#
#tls_certificate_path: "%(tls_certificate_path)s"
%(tls_enabled)stls_certificate_path: "%(tls_certificate_path)s"
# PEM-encoded private key for TLS
#
#tls_private_key_path: "%(tls_private_key_path)s"
%(tls_enabled)stls_private_key_path: "%(tls_private_key_path)s"
# Whether to verify TLS server certificates for outbound federation requests.
#
@@ -340,10 +366,10 @@ class TlsConfig(Config):
# permission to listen on port 80.
#
acme:
# ACME support is disabled by default. Uncomment the following line
# (and tls_certificate_path and tls_private_key_path above) to enable it.
# ACME support is disabled by default. Set this to `true` and uncomment
# tls_certificate_path and tls_private_key_path above to enable it.
#
#enabled: true
enabled: %(acme_enabled)s
# Endpoint to use to request certificates. If you only want to test,
# use Let's Encrypt's staging url:
@@ -354,17 +380,17 @@ class TlsConfig(Config):
# Port number to listen on for the HTTP-01 challenge. Change this if
# you are forwarding connections through Apache/Nginx/etc.
#
#port: 80
port: 80
# Local addresses to listen on for incoming connections.
# Again, you may want to change this if you are forwarding connections
# through Apache/Nginx/etc.
#
#bind_addresses: ['::', '0.0.0.0']
bind_addresses: ['::', '0.0.0.0']
# How many days remaining on a certificate before it is renewed.
#
#reprovision_threshold: 30
reprovision_threshold: 30
# The domain that the certificate should be for. Normally this
# should be the same as your Matrix domain (i.e., 'server_name'), but,
@@ -378,7 +404,7 @@ class TlsConfig(Config):
#
# If not set, defaults to your 'server_name'.
#
#domain: matrix.example.com
domain: %(acme_domain)s
# file to use for the account key. This will be generated if it doesn't
# exist.
+9 -14
View File
@@ -18,7 +18,6 @@ import logging
from collections import defaultdict
import six
from six import raise_from
from six.moves import urllib
import attr
@@ -30,7 +29,6 @@ from signedjson.key import (
from signedjson.sign import (
SignatureVerifyException,
encode_canonical_json,
sign_json,
signature_ids,
verify_signed_json,
)
@@ -540,13 +538,7 @@ class BaseV2KeyFetcher(object):
verify_key=verify_key, valid_until_ts=key_data["expired_ts"]
)
# re-sign the json with our own key, so that it is ready if we are asked to
# give it out as a notary server
signed_key_json = sign_json(
response_json, self.config.server_name, self.config.signing_key[0]
)
signed_key_json_bytes = encode_canonical_json(signed_key_json)
key_json_bytes = encode_canonical_json(response_json)
yield make_deferred_yieldable(
defer.gatherResults(
@@ -558,7 +550,7 @@ class BaseV2KeyFetcher(object):
from_server=from_server,
ts_now_ms=time_added_ms,
ts_expires_ms=ts_valid_until_ms,
key_json_bytes=signed_key_json_bytes,
key_json_bytes=key_json_bytes,
)
for key_id in verify_keys
],
@@ -657,9 +649,10 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
},
)
except (NotRetryingDestination, RequestSendFailed) as e:
raise_from(KeyLookupError("Failed to connect to remote server"), e)
# these both have str() representations which we can't really improve upon
raise KeyLookupError(str(e))
except HttpResponseException as e:
raise_from(KeyLookupError("Remote server returned an error"), e)
raise KeyLookupError("Remote server returned an error: %s" % (e,))
keys = {}
added_keys = []
@@ -821,9 +814,11 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
timeout=10000,
)
except (NotRetryingDestination, RequestSendFailed) as e:
raise_from(KeyLookupError("Failed to connect to remote server"), e)
# these both have str() representations which we can't really improve
# upon
raise KeyLookupError(str(e))
except HttpResponseException as e:
raise_from(KeyLookupError("Remote server returned an error"), e)
raise KeyLookupError("Remote server returned an error: %s" % (e,))
if response["server_name"] != server_name:
raise KeyLookupError(
+10 -6
View File
@@ -43,6 +43,7 @@ from synapse.federation.persistence import TransactionActions
from synapse.federation.units import Edu, Transaction
from synapse.http.endpoint import parse_server_name
from synapse.logging.context import nested_logging_context
from synapse.logging.opentracing import log_kv, start_active_span_from_edu, trace
from synapse.logging.utils import log_function
from synapse.replication.http.federation import (
ReplicationFederationSendEduRestServlet,
@@ -507,6 +508,7 @@ class FederationServer(FederationBase):
def on_query_user_devices(self, origin, user_id):
return self.on_query_request("user_devices", user_id)
@trace
@defer.inlineCallbacks
@log_function
def on_claim_client_keys(self, origin, content):
@@ -515,6 +517,7 @@ class FederationServer(FederationBase):
for device_id, algorithm in device_keys.items():
query.append((user_id, device_id, algorithm))
log_kv({"message": "Claiming one time keys.", "user, device pairs": query})
results = yield self.store.claim_e2e_one_time_keys(query)
json_result = {}
@@ -808,12 +811,13 @@ class FederationHandlerRegistry(object):
if not handler:
logger.warn("No handler registered for EDU type %s", edu_type)
try:
yield handler(origin, content)
except SynapseError as e:
logger.info("Failed to handle edu %r: %r", edu_type, e)
except Exception:
logger.exception("Failed to handle edu %r", edu_type)
with start_active_span_from_edu(content, "handle_edu"):
try:
yield handler(origin, content)
except SynapseError as e:
logger.info("Failed to handle edu %r: %r", edu_type, e)
except Exception:
logger.exception("Failed to handle edu %r", edu_type)
def on_query(self, query_type, args):
handler = self.query_handlers.get(query_type)
@@ -14,11 +14,19 @@
# limitations under the License.
import logging
from canonicaljson import json
from twisted.internet import defer
from synapse.api.errors import HttpResponseException
from synapse.federation.persistence import TransactionActions
from synapse.federation.units import Transaction
from synapse.logging.opentracing import (
extract_text_map,
set_tag,
start_active_span_follows_from,
tags,
)
from synapse.util.metrics import measure_func
logger = logging.getLogger(__name__)
@@ -44,93 +52,109 @@ class TransactionManager(object):
@defer.inlineCallbacks
def send_new_transaction(self, destination, pending_pdus, pending_edus):
# Sort based on the order field
pending_pdus.sort(key=lambda t: t[1])
pdus = [x[0] for x in pending_pdus]
edus = pending_edus
# Make a transaction-sending opentracing span. This span follows on from
# all the edus in that transaction. This needs to be done since there is
# no active span here, so if the edus were not received by the remote the
# span would have no causality and it would be forgotten.
# The span_contexts is a generator so that it won't be evaluated if
# opentracing is disabled. (Yay speed!)
success = True
logger.debug("TX [%s] _attempt_new_transaction", destination)
txn_id = str(self._next_txn_id)
logger.debug(
"TX [%s] {%s} Attempting new transaction" " (pdus: %d, edus: %d)",
destination,
txn_id,
len(pdus),
len(edus),
span_contexts = (
extract_text_map(json.loads(edu.get_context())) for edu in pending_edus
)
transaction = Transaction.create_new(
origin_server_ts=int(self.clock.time_msec()),
transaction_id=txn_id,
origin=self._server_name,
destination=destination,
pdus=pdus,
edus=edus,
)
with start_active_span_follows_from("send_transaction", span_contexts):
self._next_txn_id += 1
# Sort based on the order field
pending_pdus.sort(key=lambda t: t[1])
pdus = [x[0] for x in pending_pdus]
edus = pending_edus
logger.info(
"TX [%s] {%s} Sending transaction [%s]," " (PDUs: %d, EDUs: %d)",
destination,
txn_id,
transaction.transaction_id,
len(pdus),
len(edus),
)
success = True
# Actually send the transaction
logger.debug("TX [%s] _attempt_new_transaction", destination)
# FIXME (erikj): This is a bit of a hack to make the Pdu age
# keys work
def json_data_cb():
data = transaction.get_dict()
now = int(self.clock.time_msec())
if "pdus" in data:
for p in data["pdus"]:
if "age_ts" in p:
unsigned = p.setdefault("unsigned", {})
unsigned["age"] = now - int(p["age_ts"])
del p["age_ts"]
return data
txn_id = str(self._next_txn_id)
try:
response = yield self._transport_layer.send_transaction(
transaction, json_data_cb
logger.debug(
"TX [%s] {%s} Attempting new transaction" " (pdus: %d, edus: %d)",
destination,
txn_id,
len(pdus),
len(edus),
)
code = 200
except HttpResponseException as e:
code = e.code
response = e.response
if e.code in (401, 404, 429) or 500 <= e.code:
logger.info("TX [%s] {%s} got %d response", destination, txn_id, code)
raise e
transaction = Transaction.create_new(
origin_server_ts=int(self.clock.time_msec()),
transaction_id=txn_id,
origin=self._server_name,
destination=destination,
pdus=pdus,
edus=edus,
)
logger.info("TX [%s] {%s} got %d response", destination, txn_id, code)
self._next_txn_id += 1
if code == 200:
for e_id, r in response.get("pdus", {}).items():
if "error" in r:
logger.info(
"TX [%s] {%s} Sending transaction [%s]," " (PDUs: %d, EDUs: %d)",
destination,
txn_id,
transaction.transaction_id,
len(pdus),
len(edus),
)
# Actually send the transaction
# FIXME (erikj): This is a bit of a hack to make the Pdu age
# keys work
def json_data_cb():
data = transaction.get_dict()
now = int(self.clock.time_msec())
if "pdus" in data:
for p in data["pdus"]:
if "age_ts" in p:
unsigned = p.setdefault("unsigned", {})
unsigned["age"] = now - int(p["age_ts"])
del p["age_ts"]
return data
try:
response = yield self._transport_layer.send_transaction(
transaction, json_data_cb
)
code = 200
except HttpResponseException as e:
code = e.code
response = e.response
if e.code in (401, 404, 429) or 500 <= e.code:
logger.info(
"TX [%s] {%s} got %d response", destination, txn_id, code
)
raise e
logger.info("TX [%s] {%s} got %d response", destination, txn_id, code)
if code == 200:
for e_id, r in response.get("pdus", {}).items():
if "error" in r:
logger.warn(
"TX [%s] {%s} Remote returned error for %s: %s",
destination,
txn_id,
e_id,
r,
)
else:
for p in pdus:
logger.warn(
"TX [%s] {%s} Remote returned error for %s: %s",
"TX [%s] {%s} Failed to send event %s",
destination,
txn_id,
e_id,
r,
p.event_id,
)
else:
for p in pdus:
logger.warn(
"TX [%s] {%s} Failed to send event %s",
destination,
txn_id,
p.event_id,
)
success = False
success = False
return success
set_tag(tags.ERROR, not success)
return success
+28 -12
View File
@@ -327,21 +327,37 @@ class TransportLayerClient(object):
include_all_networks=False,
third_party_instance_id=None,
):
path = _create_v1_path("/publicRooms")
if search_filter:
# this uses MSC2197 (Search Filtering over Federation)
path = _create_v1_path("/publicRooms")
args = {"include_all_networks": "true" if include_all_networks else "false"}
if third_party_instance_id:
args["third_party_instance_id"] = (third_party_instance_id,)
if limit:
args["limit"] = [str(limit)]
if since_token:
args["since"] = [since_token]
data = {"include_all_networks": "true" if include_all_networks else "false"}
if third_party_instance_id:
data["third_party_instance_id"] = third_party_instance_id
if limit:
data["limit"] = str(limit)
if since_token:
data["since"] = since_token
# TODO(erikj): Actually send the search_filter across federation.
data["filter"] = search_filter
response = yield self.client.get_json(
destination=remote_server, path=path, args=args, ignore_backoff=True
)
response = yield self.client.post_json(
destination=remote_server, path=path, data=data, ignore_backoff=True
)
else:
path = _create_v1_path("/publicRooms")
args = {"include_all_networks": "true" if include_all_networks else "false"}
if third_party_instance_id:
args["third_party_instance_id"] = (third_party_instance_id,)
if limit:
args["limit"] = [str(limit)]
if since_token:
args["since"] = [since_token]
response = yield self.client.get_json(
destination=remote_server, path=path, args=args, ignore_backoff=True
)
return response
+64 -15
View File
@@ -38,7 +38,12 @@ from synapse.http.servlet import (
parse_string_from_args,
)
from synapse.logging.context import run_in_background
from synapse.logging.opentracing import start_active_span_from_context, tags
from synapse.logging.opentracing import (
start_active_span,
start_active_span_from_request,
tags,
whitelisted_homeserver,
)
from synapse.types import ThirdPartyInstanceID, get_domain_from_id
from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.versionstring import get_version_string
@@ -288,20 +293,28 @@ class BaseFederationServlet(object):
logger.warn("authenticate_request failed: %s", e)
raise
# Start an opentracing span
with start_active_span_from_context(
request.requestHeaders,
"incoming-federation-request",
tags={
"request_id": request.get_request_id(),
tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
tags.HTTP_METHOD: request.get_method(),
tags.HTTP_URL: request.get_redacted_uri(),
tags.PEER_HOST_IPV6: request.getClientIP(),
"authenticated_entity": origin,
"servlet_name": request.request_metrics.name,
},
):
request_tags = {
"request_id": request.get_request_id(),
tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
tags.HTTP_METHOD: request.get_method(),
tags.HTTP_URL: request.get_redacted_uri(),
tags.PEER_HOST_IPV6: request.getClientIP(),
"authenticated_entity": origin,
"servlet_name": request.request_metrics.name,
}
# Only accept the span context if the origin is authenticated
# and whitelisted
if origin and whitelisted_homeserver(origin):
scope = start_active_span_from_request(
request, "incoming-federation-request", tags=request_tags
)
else:
scope = start_active_span(
"incoming-federation-request", tags=request_tags
)
with scope:
if origin:
with ratelimiter.ratelimit(origin) as d:
await d
@@ -757,6 +770,42 @@ class PublicRoomList(BaseFederationServlet):
)
return 200, data
async def on_POST(self, origin, content, query):
# This implements MSC2197 (Search Filtering over Federation)
if not self.allow_access:
raise FederationDeniedError(origin)
limit = int(content.get("limit", 100))
since_token = content.get("since", None)
search_filter = content.get("filter", None)
include_all_networks = content.get("include_all_networks", False)
third_party_instance_id = content.get("third_party_instance_id", None)
if include_all_networks:
network_tuple = None
if third_party_instance_id is not None:
raise SynapseError(
400, "Can't use include_all_networks with an explicit network"
)
elif third_party_instance_id is None:
network_tuple = ThirdPartyInstanceID(None, None)
else:
network_tuple = ThirdPartyInstanceID.from_string(third_party_instance_id)
if search_filter is None:
logger.warning("Nonefilter")
data = await self.handler.get_local_public_room_list(
limit=limit,
since_token=since_token,
search_filter=search_filter,
network_tuple=network_tuple,
from_federation=True,
)
return 200, data
class FederationVersionServlet(BaseFederationServlet):
PATH = "/version"
+3
View File
@@ -38,6 +38,9 @@ class Edu(JsonEncodedObject):
internal_keys = ["origin", "destination"]
def get_context(self):
return getattr(self, "content", {}).get("org.matrix.opentracing_context", "{}")
class Transaction(JsonEncodedObject):
""" A transaction is a list of Pdus and Edus to be sent to a remote home
+19
View File
@@ -94,6 +94,25 @@ class AdminHandler(BaseHandler):
return ret
def get_user_server_admin(self, user):
"""
Get the admin bit on a user.
Args:
user_id (UserID): the (necessarily local) user to manipulate
"""
return self.store.is_server_admin(user)
def set_user_server_admin(self, user, admin):
"""
Set the admin bit on a user.
Args:
user_id (UserID): the (necessarily local) user to manipulate
admin (bool): whether or not the user should be an admin of this server
"""
return self.store.set_server_admin(user, admin)
@defer.inlineCallbacks
def export_user_data(self, user_id, writer):
"""Write all data we have on the user to the given writer.
+21 -6
View File
@@ -15,9 +15,17 @@
import logging
from canonicaljson import json
from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.logging.opentracing import (
get_active_span_text_map,
set_tag,
start_active_span,
whitelisted_homeserver,
)
from synapse.types import UserID, get_domain_from_id
from synapse.util.stringutils import random_string
@@ -100,14 +108,21 @@ class DeviceMessageHandler(object):
message_id = random_string(16)
context = get_active_span_text_map()
remote_edu_contents = {}
for destination, messages in remote_messages.items():
remote_edu_contents[destination] = {
"messages": messages,
"sender": sender_user_id,
"type": message_type,
"message_id": message_id,
}
with start_active_span("to_device_for_user"):
set_tag("destination", destination)
remote_edu_contents[destination] = {
"messages": messages,
"sender": sender_user_id,
"type": message_type,
"message_id": message_id,
"org.matrix.opentracing_context": json.dumps(context)
if whitelisted_homeserver(destination)
else None,
}
stream_id = yield self.store.add_messages_to_device_inbox(
local_messages, remote_edu_contents
+51 -1
View File
@@ -24,6 +24,7 @@ from twisted.internet import defer
from synapse.api.errors import CodeMessageException, SynapseError
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace
from synapse.types import UserID, get_domain_from_id
from synapse.util import unwrapFirstError
from synapse.util.retryutils import NotRetryingDestination
@@ -46,6 +47,7 @@ class E2eKeysHandler(object):
"client_keys", self.on_federation_query_client_keys
)
@trace
@defer.inlineCallbacks
def query_devices(self, query_body, timeout):
""" Handle a device key query from a client
@@ -81,6 +83,9 @@ class E2eKeysHandler(object):
else:
remote_queries[user_id] = device_ids
set_tag("local_key_query", local_query)
set_tag("remote_key_query", remote_queries)
# First get local devices.
failures = {}
results = {}
@@ -121,6 +126,7 @@ class E2eKeysHandler(object):
r[user_id] = remote_queries[user_id]
# Now fetch any devices that we don't have in our cache
@trace
@defer.inlineCallbacks
def do_remote_query(destination):
"""This is called when we are querying the device list of a user on
@@ -185,6 +191,8 @@ class E2eKeysHandler(object):
except Exception as e:
failure = _exception_to_failure(e)
failures[destination] = failure
set_tag("error", True)
set_tag("reason", failure)
yield make_deferred_yieldable(
defer.gatherResults(
@@ -198,6 +206,7 @@ class E2eKeysHandler(object):
return {"device_keys": results, "failures": failures}
@trace
@defer.inlineCallbacks
def query_local_devices(self, query):
"""Get E2E device keys for local users
@@ -210,6 +219,7 @@ class E2eKeysHandler(object):
defer.Deferred: (resolves to dict[string, dict[string, dict]]):
map from user_id -> device_id -> device details
"""
set_tag("local_query", query)
local_query = []
result_dict = {}
@@ -217,6 +227,14 @@ class E2eKeysHandler(object):
# we use UserID.from_string to catch invalid user ids
if not self.is_mine(UserID.from_string(user_id)):
logger.warning("Request for keys for non-local user %s", user_id)
log_kv(
{
"message": "Requested a local key for a user which"
" was not local to the homeserver",
"user_id": user_id,
}
)
set_tag("error", True)
raise SynapseError(400, "Not a user here")
if not device_ids:
@@ -241,6 +259,7 @@ class E2eKeysHandler(object):
r["unsigned"]["device_display_name"] = display_name
result_dict[user_id][device_id] = r
log_kv(results)
return result_dict
@defer.inlineCallbacks
@@ -251,6 +270,7 @@ class E2eKeysHandler(object):
res = yield self.query_local_devices(device_keys_query)
return {"device_keys": res}
@trace
@defer.inlineCallbacks
def claim_one_time_keys(self, query, timeout):
local_query = []
@@ -265,6 +285,9 @@ class E2eKeysHandler(object):
domain = get_domain_from_id(user_id)
remote_queries.setdefault(domain, {})[user_id] = device_keys
set_tag("local_key_query", local_query)
set_tag("remote_key_query", remote_queries)
results = yield self.store.claim_e2e_one_time_keys(local_query)
json_result = {}
@@ -276,8 +299,10 @@ class E2eKeysHandler(object):
key_id: json.loads(json_bytes)
}
@trace
@defer.inlineCallbacks
def claim_client_keys(destination):
set_tag("destination", destination)
device_keys = remote_queries[destination]
try:
remote_result = yield self.federation.claim_client_keys(
@@ -290,6 +315,8 @@ class E2eKeysHandler(object):
except Exception as e:
failure = _exception_to_failure(e)
failures[destination] = failure
set_tag("error", True)
set_tag("reason", failure)
yield make_deferred_yieldable(
defer.gatherResults(
@@ -313,9 +340,11 @@ class E2eKeysHandler(object):
),
)
log_kv({"one_time_keys": json_result, "failures": failures})
return {"one_time_keys": json_result, "failures": failures}
@defer.inlineCallbacks
@tag_args
def upload_keys_for_user(self, user_id, device_id, keys):
time_now = self.clock.time_msec()
@@ -329,6 +358,13 @@ class E2eKeysHandler(object):
user_id,
time_now,
)
log_kv(
{
"message": "Updating device_keys for user.",
"user_id": user_id,
"device_id": device_id,
}
)
# TODO: Sign the JSON with the server key
changed = yield self.store.set_e2e_device_keys(
user_id, device_id, time_now, device_keys
@@ -336,12 +372,24 @@ class E2eKeysHandler(object):
if changed:
# Only notify about device updates *if* the keys actually changed
yield self.device_handler.notify_device_update(user_id, [device_id])
else:
log_kv({"message": "Not updating device_keys for user", "user_id": user_id})
one_time_keys = keys.get("one_time_keys", None)
if one_time_keys:
log_kv(
{
"message": "Updating one_time_keys for device.",
"user_id": user_id,
"device_id": device_id,
}
)
yield self._upload_one_time_keys_for_user(
user_id, device_id, time_now, one_time_keys
)
else:
log_kv(
{"message": "Did not update one_time_keys", "reason": "no keys given"}
)
# the device should have been registered already, but it may have been
# deleted due to a race with a DELETE request. Or we may be using an
@@ -352,6 +400,7 @@ class E2eKeysHandler(object):
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
set_tag("one_time_key_counts", result)
return {"one_time_key_counts": result}
@defer.inlineCallbacks
@@ -395,6 +444,7 @@ class E2eKeysHandler(object):
(algorithm, key_id, encode_canonical_json(key).decode("ascii"))
)
log_kv({"message": "Inserting new one_time_keys.", "keys": new_keys})
yield self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys)
+26 -2
View File
@@ -26,6 +26,7 @@ from synapse.api.errors import (
StoreError,
SynapseError,
)
from synapse.logging.opentracing import log_kv, trace
from synapse.util.async_helpers import Linearizer
logger = logging.getLogger(__name__)
@@ -49,6 +50,7 @@ class E2eRoomKeysHandler(object):
# changed.
self._upload_linearizer = Linearizer("upload_room_keys_lock")
@trace
@defer.inlineCallbacks
def get_room_keys(self, user_id, version, room_id=None, session_id=None):
"""Bulk get the E2E room keys for a given backup, optionally filtered to a given
@@ -84,8 +86,10 @@ class E2eRoomKeysHandler(object):
user_id, version, room_id, session_id
)
log_kv(results)
return results
@trace
@defer.inlineCallbacks
def delete_room_keys(self, user_id, version, room_id=None, session_id=None):
"""Bulk delete the E2E room keys for a given backup, optionally filtered to a given
@@ -107,6 +111,7 @@ class E2eRoomKeysHandler(object):
with (yield self._upload_linearizer.queue(user_id)):
yield self.store.delete_e2e_room_keys(user_id, version, room_id, session_id)
@trace
@defer.inlineCallbacks
def upload_room_keys(self, user_id, version, room_keys):
"""Bulk upload a list of room keys into a given backup version, asserting
@@ -186,7 +191,14 @@ class E2eRoomKeysHandler(object):
session_id(str): the session whose room_key we're setting
room_key(dict): the room_key being set
"""
log_kv(
{
"message": "Trying to upload room key",
"room_id": room_id,
"session_id": session_id,
"user_id": user_id,
}
)
# get the room_key for this particular row
current_room_key = None
try:
@@ -195,14 +207,23 @@ class E2eRoomKeysHandler(object):
)
except StoreError as e:
if e.code == 404:
pass
log_kv(
{
"message": "Room key not found.",
"room_id": room_id,
"user_id": user_id,
}
)
else:
raise
if self._should_replace_room_key(current_room_key, room_key):
log_kv({"message": "Replacing room key."})
yield self.store.set_e2e_room_key(
user_id, version, room_id, session_id, room_key
)
else:
log_kv({"message": "Not replacing room_key."})
@staticmethod
def _should_replace_room_key(current_room_key, room_key):
@@ -236,6 +257,7 @@ class E2eRoomKeysHandler(object):
return False
return True
@trace
@defer.inlineCallbacks
def create_version(self, user_id, version_info):
"""Create a new backup version. This automatically becomes the new
@@ -294,6 +316,7 @@ class E2eRoomKeysHandler(object):
raise
return res
@trace
@defer.inlineCallbacks
def delete_version(self, user_id, version=None):
"""Deletes a given version of the user's e2e_room_keys backup
@@ -314,6 +337,7 @@ class E2eRoomKeysHandler(object):
else:
raise
@trace
@defer.inlineCallbacks
def update_version(self, user_id, version, version_info):
"""Update the info about a given version of the user's backup
+3 -2
View File
@@ -326,8 +326,9 @@ class FederationHandler(BaseHandler):
ours = yield self.store.get_state_groups_ids(room_id, seen)
# state_maps is a list of mappings from (type, state_key) to event_id
# type: list[dict[tuple[str, str], str]]
state_maps = list(ours.values())
state_maps = list(
ours.values()
) # type: list[dict[tuple[str, str], str]]
# we don't need this any more, let's delete it.
del ours
+4 -1
View File
@@ -24,7 +24,7 @@ from twisted.internet import defer
from twisted.internet.defer import succeed
from synapse import event_auth
from synapse.api.constants import EventTypes, Membership, RelationTypes
from synapse.api.constants import EventTypes, Membership, RelationTypes, UserTypes
from synapse.api.errors import (
AuthError,
Codes,
@@ -469,6 +469,9 @@ class EventCreationHandler(object):
u = yield self.store.get_user_by_id(user_id)
assert u is not None
if u["user_type"] in (UserTypes.SUPPORT, UserTypes.BOT):
# support and bot users are not required to consent
return
if u["appservice_id"] is not None:
# users registered by an appservice are exempt
return
+17
View File
@@ -70,6 +70,7 @@ class PaginationHandler(object):
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self._server_name = hs.hostname
self.pagination_lock = ReadWriteLock()
self._purges_in_progress_by_room = set()
@@ -153,6 +154,22 @@ class PaginationHandler(object):
"""
return self._purges_by_id.get(purge_id)
async def purge_room(self, room_id):
"""Purge the given room from the database"""
with (await self.pagination_lock.write(room_id)):
# check we know about the room
await self.store.get_room_version(room_id)
# first check that we have no users in this room
joined = await defer.maybeDeferred(
self.store.is_host_joined, room_id, self._server_name
)
if joined:
raise SynapseError(400, "Users are still joined to this room")
await self.store.purge_room(room_id)
@defer.inlineCallbacks
def get_messages(
self,
+1 -1
View File
@@ -34,7 +34,7 @@ from ._base import BaseHandler
logger = logging.getLogger(__name__)
MAX_DISPLAYNAME_LEN = 100
MAX_DISPLAYNAME_LEN = 256
MAX_AVATAR_URL_LEN = 1000
+28 -1
View File
@@ -25,6 +25,7 @@ from unpaddedbase64 import decode_base64, encode_base64
from twisted.internet import defer
from synapse.api.constants import EventTypes, JoinRules
from synapse.api.errors import Codes, HttpResponseException
from synapse.types import ThirdPartyInstanceID
from synapse.util.async_helpers import concurrently_execute
from synapse.util.caches.descriptors import cachedInlineCallbacks
@@ -485,7 +486,33 @@ class RoomListHandler(BaseHandler):
return {"chunk": [], "total_room_count_estimate": 0}
if search_filter:
# We currently don't support searching across federation, so we have
# Searching across federation is defined in MSC2197.
# However, the remote homeserver may or may not actually support it.
# So we first try an MSC2197 remote-filtered search, then fall back
# to a locally-filtered search if we must.
try:
res = yield self._get_remote_list_cached(
server_name,
limit=limit,
since_token=since_token,
include_all_networks=include_all_networks,
third_party_instance_id=third_party_instance_id,
search_filter=search_filter,
)
return res
except HttpResponseException as hre:
syn_err = hre.to_synapse_error()
if hre.code in (404, 405) or syn_err.errcode in (
Codes.UNRECOGNIZED,
Codes.NOT_FOUND,
):
logger.debug("Falling back to locally-filtered /publicRooms")
else:
raise # Not an error that should trigger a fallback.
# if we reach this point, then we fall back to the situation where
# we currently don't support searching across federation, so we have
# to do it manually without pagination
limit = None
since_token = None
+199 -201
View File
@@ -14,21 +14,21 @@
# limitations under the License.
import logging
import urllib
import attr
from netaddr import IPAddress
from netaddr import AddrFormatError, IPAddress
from zope.interface import implementer
from twisted.internet import defer
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.internet.interfaces import IStreamClientEndpoint
from twisted.web.client import URI, Agent, HTTPConnectionPool
from twisted.web.client import Agent, HTTPConnectionPool
from twisted.web.http_headers import Headers
from twisted.web.iweb import IAgent
from twisted.web.iweb import IAgent, IAgentEndpointFactory
from synapse.http.federation.srv_resolver import SrvResolver, pick_server_from_list
from synapse.http.federation.srv_resolver import Server, SrvResolver
from synapse.http.federation.well_known_resolver import WellKnownResolver
from synapse.logging.context import make_deferred_yieldable
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.util import Clock
logger = logging.getLogger(__name__)
@@ -36,8 +36,9 @@ logger = logging.getLogger(__name__)
@implementer(IAgent)
class MatrixFederationAgent(object):
"""An Agent-like thing which provides a `request` method which will look up a matrix
server and send an HTTP request to it.
"""An Agent-like thing which provides a `request` method which correctly
handles resolving matrix server names when using matrix://. Handles standard
https URIs as normal.
Doesn't implement any retries. (Those are done in MatrixFederationHttpClient.)
@@ -51,9 +52,9 @@ class MatrixFederationAgent(object):
SRVResolver impl to use for looking up SRV records. None to use a default
implementation.
_well_known_cache (TTLCache|None):
TTLCache impl for storing cached well-known lookups. None to use a default
implementation.
_well_known_resolver (WellKnownResolver|None):
WellKnownResolver to use to perform well-known lookups. None to use a
default implementation.
"""
def __init__(
@@ -61,49 +62,49 @@ class MatrixFederationAgent(object):
reactor,
tls_client_options_factory,
_srv_resolver=None,
_well_known_cache=None,
_well_known_resolver=None,
):
self._reactor = reactor
self._clock = Clock(reactor)
self._tls_client_options_factory = tls_client_options_factory
if _srv_resolver is None:
_srv_resolver = SrvResolver()
self._srv_resolver = _srv_resolver
self._pool = HTTPConnectionPool(reactor)
self._pool.retryAutomatically = False
self._pool.maxPersistentPerHost = 5
self._pool.cachedConnectionTimeout = 2 * 60
self._well_known_resolver = WellKnownResolver(
self._agent = Agent.usingEndpointFactory(
self._reactor,
agent=Agent(
self._reactor,
pool=self._pool,
contextFactory=tls_client_options_factory,
MatrixHostnameEndpointFactory(
reactor, tls_client_options_factory, _srv_resolver
),
well_known_cache=_well_known_cache,
pool=self._pool,
)
if _well_known_resolver is None:
_well_known_resolver = WellKnownResolver(
self._reactor,
agent=Agent(
self._reactor,
pool=self._pool,
contextFactory=tls_client_options_factory,
),
)
self._well_known_resolver = _well_known_resolver
@defer.inlineCallbacks
def request(self, method, uri, headers=None, bodyProducer=None):
"""
Args:
method (bytes): HTTP method: GET/POST/etc
uri (bytes): Absolute URI to be retrieved
headers (twisted.web.http_headers.Headers|None):
HTTP headers to send with the request, or None to
send no extra headers.
bodyProducer (twisted.web.iweb.IBodyProducer|None):
An object which can generate bytes to make up the
body of this request (for example, the properly encoded contents of
a file for a file upload). Or None if the request is to have
no body.
Returns:
Deferred[twisted.web.iweb.IResponse]:
fires when the header of the response has been received (regardless of the
@@ -111,210 +112,207 @@ class MatrixFederationAgent(object):
response from being received (including problems that prevent the request
from being sent).
"""
parsed_uri = URI.fromBytes(uri, defaultPort=-1)
res = yield self._route_matrix_uri(parsed_uri)
# We use urlparse as that will set `port` to None if there is no
# explicit port.
parsed_uri = urllib.parse.urlparse(uri)
# set up the TLS connection params
# If this is a matrix:// URI check if the server has delegated matrix
# traffic using well-known delegation.
#
# XXX disabling TLS is really only supported here for the benefit of the
# unit tests. We should make the UTs cope with TLS rather than having to make
# the code support the unit tests.
if self._tls_client_options_factory is None:
tls_options = None
else:
tls_options = self._tls_client_options_factory.get_options(
res.tls_server_name.decode("ascii")
# We have to do this here and not in the endpoint as we need to rewrite
# the host header with the delegated server name.
delegated_server = None
if (
parsed_uri.scheme == b"matrix"
and not _is_ip_literal(parsed_uri.hostname)
and not parsed_uri.port
):
well_known_result = yield self._well_known_resolver.get_well_known(
parsed_uri.hostname
)
delegated_server = well_known_result.delegated_server
# make sure that the Host header is set correctly
if delegated_server:
# Ok, the server has delegated matrix traffic to somewhere else, so
# lets rewrite the URL to replace the server with the delegated
# server name.
uri = urllib.parse.urlunparse(
(
parsed_uri.scheme,
delegated_server,
parsed_uri.path,
parsed_uri.params,
parsed_uri.query,
parsed_uri.fragment,
)
)
parsed_uri = urllib.parse.urlparse(uri)
# We need to make sure the host header is set to the netloc of the
# server.
if headers is None:
headers = Headers()
else:
headers = headers.copy()
if not headers.hasHeader(b"host"):
headers.addRawHeader(b"host", res.host_header)
headers.addRawHeader(b"host", parsed_uri.netloc)
class EndpointFactory(object):
@staticmethod
def endpointForURI(_uri):
ep = LoggingHostnameEndpoint(
self._reactor, res.target_host, res.target_port
)
if tls_options is not None:
ep = wrapClientTLS(tls_options, ep)
return ep
agent = Agent.usingEndpointFactory(self._reactor, EndpointFactory(), self._pool)
res = yield make_deferred_yieldable(
agent.request(method, uri, headers, bodyProducer)
self._agent.request(method, uri, headers, bodyProducer)
)
return res
@defer.inlineCallbacks
def _route_matrix_uri(self, parsed_uri, lookup_well_known=True):
"""Helper for `request`: determine the routing for a Matrix URI
Args:
parsed_uri (twisted.web.client.URI): uri to route. Note that it should be
parsed with URI.fromBytes(uri, defaultPort=-1) to set the `port` to -1
if there is no explicit port given.
@implementer(IAgentEndpointFactory)
class MatrixHostnameEndpointFactory(object):
"""Factory for MatrixHostnameEndpoint for parsing to an Agent.
"""
lookup_well_known (bool): True if we should look up the .well-known file if
there is no SRV record.
def __init__(self, reactor, tls_client_options_factory, srv_resolver):
self._reactor = reactor
self._tls_client_options_factory = tls_client_options_factory
Returns:
Deferred[_RoutingResult]
"""
# check for an IP literal
try:
ip_address = IPAddress(parsed_uri.host.decode("ascii"))
except Exception:
# not an IP address
ip_address = None
if srv_resolver is None:
srv_resolver = SrvResolver()
if ip_address:
port = parsed_uri.port
if port == -1:
port = 8448
return _RoutingResult(
host_header=parsed_uri.netloc,
tls_server_name=parsed_uri.host,
target_host=parsed_uri.host,
target_port=port,
)
self._srv_resolver = srv_resolver
if parsed_uri.port != -1:
# there is an explicit port
return _RoutingResult(
host_header=parsed_uri.netloc,
tls_server_name=parsed_uri.host,
target_host=parsed_uri.host,
target_port=parsed_uri.port,
)
if lookup_well_known:
# try a .well-known lookup
well_known_result = yield self._well_known_resolver.get_well_known(
parsed_uri.host
)
well_known_server = well_known_result.delegated_server
if well_known_server:
# if we found a .well-known, start again, but don't do another
# .well-known lookup.
# parse the server name in the .well-known response into host/port.
# (This code is lifted from twisted.web.client.URI.fromBytes).
if b":" in well_known_server:
well_known_host, well_known_port = well_known_server.rsplit(b":", 1)
try:
well_known_port = int(well_known_port)
except ValueError:
# the part after the colon could not be parsed as an int
# - we assume it is an IPv6 literal with no port (the closing
# ']' stops it being parsed as an int)
well_known_host, well_known_port = well_known_server, -1
else:
well_known_host, well_known_port = well_known_server, -1
new_uri = URI(
scheme=parsed_uri.scheme,
netloc=well_known_server,
host=well_known_host,
port=well_known_port,
path=parsed_uri.path,
params=parsed_uri.params,
query=parsed_uri.query,
fragment=parsed_uri.fragment,
)
res = yield self._route_matrix_uri(new_uri, lookup_well_known=False)
return res
# try a SRV lookup
service_name = b"_matrix._tcp.%s" % (parsed_uri.host,)
server_list = yield self._srv_resolver.resolve_service(service_name)
if not server_list:
target_host = parsed_uri.host
port = 8448
logger.debug(
"No SRV record for %s, using %s:%i",
parsed_uri.host.decode("ascii"),
target_host.decode("ascii"),
port,
)
else:
target_host, port = pick_server_from_list(server_list)
logger.debug(
"Picked %s:%i from SRV records for %s",
target_host.decode("ascii"),
port,
parsed_uri.host.decode("ascii"),
)
return _RoutingResult(
host_header=parsed_uri.netloc,
tls_server_name=parsed_uri.host,
target_host=target_host,
target_port=port,
def endpointForURI(self, parsed_uri):
return MatrixHostnameEndpoint(
self._reactor,
self._tls_client_options_factory,
self._srv_resolver,
parsed_uri,
)
@implementer(IStreamClientEndpoint)
class LoggingHostnameEndpoint(object):
"""A wrapper for HostnameEndpint which logs when it connects"""
class MatrixHostnameEndpoint(object):
"""An endpoint that resolves matrix:// URLs using Matrix server name
resolution (i.e. via SRV). Does not check for well-known delegation.
def __init__(self, reactor, host, port, *args, **kwargs):
self.host = host
self.port = port
self.ep = HostnameEndpoint(reactor, host, port, *args, **kwargs)
Args:
reactor (IReactor)
tls_client_options_factory (ClientTLSOptionsFactory|None):
factory to use for fetching client tls options, or none to disable TLS.
srv_resolver (SrvResolver): The SRV resolver to use
parsed_uri (twisted.web.client.URI): The parsed URI that we're wanting
to connect to.
"""
def __init__(self, reactor, tls_client_options_factory, srv_resolver, parsed_uri):
self._reactor = reactor
self._parsed_uri = parsed_uri
# set up the TLS connection params
#
# XXX disabling TLS is really only supported here for the benefit of the
# unit tests. We should make the UTs cope with TLS rather than having to make
# the code support the unit tests.
if tls_client_options_factory is None:
self._tls_options = None
else:
self._tls_options = tls_client_options_factory.get_options(
self._parsed_uri.host.decode("ascii")
)
self._srv_resolver = srv_resolver
def connect(self, protocol_factory):
logger.info("Connecting to %s:%i", self.host.decode("ascii"), self.port)
return self.ep.connect(protocol_factory)
"""Implements IStreamClientEndpoint interface
"""
return run_in_background(self._do_connect, protocol_factory)
@defer.inlineCallbacks
def _do_connect(self, protocol_factory):
first_exception = None
server_list = yield self._resolve_server()
for server in server_list:
host = server.host
port = server.port
try:
logger.info("Connecting to %s:%i", host.decode("ascii"), port)
endpoint = HostnameEndpoint(self._reactor, host, port)
if self._tls_options:
endpoint = wrapClientTLS(self._tls_options, endpoint)
result = yield make_deferred_yieldable(
endpoint.connect(protocol_factory)
)
return result
except Exception as e:
logger.info(
"Failed to connect to %s:%i: %s", host.decode("ascii"), port, e
)
if not first_exception:
first_exception = e
# We return the first failure because that's probably the most interesting.
if first_exception:
raise first_exception
# This shouldn't happen as we should always have at least one host/port
# to try and if that doesn't work then we'll have an exception.
raise Exception("Failed to resolve server %r" % (self._parsed_uri.netloc,))
@defer.inlineCallbacks
def _resolve_server(self):
"""Resolves the server name to a list of hosts and ports to attempt to
connect to.
Returns:
Deferred[list[Server]]
"""
if self._parsed_uri.scheme != b"matrix":
return [Server(host=self._parsed_uri.host, port=self._parsed_uri.port)]
# Note: We don't do well-known lookup as that needs to have happened
# before now, due to needing to rewrite the Host header of the HTTP
# request.
# We reparse the URI so that defaultPort is -1 rather than 80
parsed_uri = urllib.parse.urlparse(self._parsed_uri.toBytes())
host = parsed_uri.hostname
port = parsed_uri.port
# If there is an explicit port or the host is an IP address we bypass
# SRV lookups and just use the given host/port.
if port or _is_ip_literal(host):
return [Server(host, port or 8448)]
server_list = yield self._srv_resolver.resolve_service(b"_matrix._tcp." + host)
if server_list:
return server_list
# No SRV records, so we fallback to host and 8448
return [Server(host, 8448)]
@attr.s
class _RoutingResult(object):
"""The result returned by `_route_matrix_uri`.
def _is_ip_literal(host):
"""Test if the given host name is either an IPv4 or IPv6 literal.
Contains the parameters needed to direct a federation connection to a particular
server.
Args:
host (bytes)
Where a SRV record points to several servers, this object contains a single server
chosen from the list.
Returns:
bool
"""
host_header = attr.ib()
"""
The value we should assign to the Host header (host:port from the matrix
URI, or .well-known).
host = host.decode("ascii")
:type: bytes
"""
tls_server_name = attr.ib()
"""
The server name we should set in the SNI (typically host, without port, from the
matrix URI or .well-known)
:type: bytes
"""
target_host = attr.ib()
"""
The hostname (or IP literal) we should route the TCP connection to (the target of the
SRV record, or the hostname from the URL/.well-known)
:type: bytes
"""
target_port = attr.ib()
"""
The port we should route the TCP connection to (the target of the SRV record, or
the port from the URL/.well-known, or 8448)
:type: int
"""
try:
IPAddress(host)
return True
except AddrFormatError:
return False
+38 -25
View File
@@ -32,7 +32,7 @@ logger = logging.getLogger(__name__)
SERVER_CACHE = {}
@attr.s
@attr.s(slots=True, frozen=True)
class Server(object):
"""
Our record of an individual server which can be tried to reach a destination.
@@ -53,34 +53,47 @@ class Server(object):
expires = attr.ib(default=0)
def pick_server_from_list(server_list):
"""Randomly choose a server from the server list
Args:
server_list (list[Server]): list of candidate servers
Returns:
Tuple[bytes, int]: (host, port) pair for the chosen server
def _sort_server_list(server_list):
"""Given a list of SRV records sort them into priority order and shuffle
each priority with the given weight.
"""
if not server_list:
raise RuntimeError("pick_server_from_list called with empty list")
priority_map = {}
# TODO: currently we only use the lowest-priority servers. We should maintain a
# cache of servers known to be "down" and filter them out
for server in server_list:
priority_map.setdefault(server.priority, []).append(server)
min_priority = min(s.priority for s in server_list)
eligible_servers = list(s for s in server_list if s.priority == min_priority)
total_weight = sum(s.weight for s in eligible_servers)
target_weight = random.randint(0, total_weight)
results = []
for priority in sorted(priority_map):
servers = priority_map[priority]
for s in eligible_servers:
target_weight -= s.weight
# This algorithms roughly follows the algorithm described in RFC2782,
# changed to remove an off-by-one error.
#
# N.B. Weights can be zero, which means that they should be picked
# rarely.
if target_weight <= 0:
return s.host, s.port
total_weight = sum(s.weight for s in servers)
# this should be impossible.
raise RuntimeError("pick_server_from_list got to end of eligible server list.")
# Total weight can become zero if there are only zero weight servers
# left, which we handle by just shuffling and appending to the results.
while servers and total_weight:
target_weight = random.randint(1, total_weight)
for s in servers:
target_weight -= s.weight
if target_weight <= 0:
break
results.append(s)
servers.remove(s)
total_weight -= s.weight
if servers:
random.shuffle(servers)
results.extend(servers)
return results
class SrvResolver(object):
@@ -120,7 +133,7 @@ class SrvResolver(object):
if cache_entry:
if all(s.expires > now for s in cache_entry):
servers = list(cache_entry)
return servers
return _sort_server_list(servers)
try:
answers, _, _ = yield make_deferred_yieldable(
@@ -169,4 +182,4 @@ class SrvResolver(object):
)
self._cache[service_name] = list(servers)
return servers
return _sort_server_list(servers)
+101 -25
View File
@@ -32,12 +32,19 @@ from synapse.util.metrics import Measure
# period to cache .well-known results for by default
WELL_KNOWN_DEFAULT_CACHE_PERIOD = 24 * 3600
# jitter to add to the .well-known default cache ttl
WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER = 10 * 60
# jitter factor to add to the .well-known default cache ttls
WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER = 0.1
# period to cache failure to fetch .well-known for
WELL_KNOWN_INVALID_CACHE_PERIOD = 1 * 3600
# period to cache failure to fetch .well-known if there has recently been a
# valid well-known for that domain.
WELL_KNOWN_DOWN_CACHE_PERIOD = 2 * 60
# period to remember there was a valid well-known after valid record expires
WELL_KNOWN_REMEMBER_DOMAIN_HAD_VALID = 2 * 3600
# cap for .well-known cache period
WELL_KNOWN_MAX_CACHE_PERIOD = 48 * 3600
@@ -49,11 +56,16 @@ WELL_KNOWN_MIN_CACHE_PERIOD = 5 * 60
# we'll start trying to refetch 1 minute before it expires.
WELL_KNOWN_GRACE_PERIOD_FACTOR = 0.2
# Number of times we retry fetching a well-known for a domain we know recently
# had a valid entry.
WELL_KNOWN_RETRY_ATTEMPTS = 3
logger = logging.getLogger(__name__)
_well_known_cache = TTLCache("well-known")
_had_valid_well_known_cache = TTLCache("had-valid-well-known")
@attr.s(slots=True, frozen=True)
@@ -65,14 +77,20 @@ class WellKnownResolver(object):
"""Handles well-known lookups for matrix servers.
"""
def __init__(self, reactor, agent, well_known_cache=None):
def __init__(
self, reactor, agent, well_known_cache=None, had_well_known_cache=None
):
self._reactor = reactor
self._clock = Clock(reactor)
if well_known_cache is None:
well_known_cache = _well_known_cache
if had_well_known_cache is None:
had_well_known_cache = _had_valid_well_known_cache
self._well_known_cache = well_known_cache
self._had_valid_well_known_cache = had_well_known_cache
self._well_known_agent = RedirectAgent(agent)
@defer.inlineCallbacks
@@ -100,7 +118,7 @@ class WellKnownResolver(object):
# requests for the same server in parallel?
try:
with Measure(self._clock, "get_well_known"):
result, cache_period = yield self._do_get_well_known(server_name)
result, cache_period = yield self._fetch_well_known(server_name)
except _FetchWellKnownFailure as e:
if prev_result and e.temporary:
@@ -111,10 +129,18 @@ class WellKnownResolver(object):
result = None
# add some randomness to the TTL to avoid a stampeding herd every hour
# after startup
cache_period = WELL_KNOWN_INVALID_CACHE_PERIOD
cache_period += random.uniform(0, WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER)
if self._had_valid_well_known_cache.get(server_name, False):
# We have recently seen a valid well-known record for this
# server, so we cache the lack of well-known for a shorter time.
cache_period = WELL_KNOWN_DOWN_CACHE_PERIOD
else:
cache_period = WELL_KNOWN_INVALID_CACHE_PERIOD
# add some randomness to the TTL to avoid a stampeding herd
cache_period *= random.uniform(
1 - WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER,
1 + WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER,
)
if cache_period > 0:
self._well_known_cache.set(server_name, result, cache_period)
@@ -122,7 +148,7 @@ class WellKnownResolver(object):
return WellKnownLookupResult(delegated_server=result)
@defer.inlineCallbacks
def _do_get_well_known(self, server_name):
def _fetch_well_known(self, server_name):
"""Actually fetch and parse a .well-known, without checking the cache
Args:
@@ -134,24 +160,15 @@ class WellKnownResolver(object):
Returns:
Deferred[Tuple[bytes,int]]: The lookup result and cache period.
"""
uri = b"https://%s/.well-known/matrix/server" % (server_name,)
uri_str = uri.decode("ascii")
logger.info("Fetching %s", uri_str)
had_valid_well_known = self._had_valid_well_known_cache.get(server_name, False)
# We do this in two steps to differentiate between possibly transient
# errors (e.g. can't connect to host, 503 response) and more permenant
# errors (such as getting a 404 response).
try:
response = yield make_deferred_yieldable(
self._well_known_agent.request(b"GET", uri)
)
body = yield make_deferred_yieldable(readBody(response))
if 500 <= response.code < 600:
raise Exception("Non-200 response %s" % (response.code,))
except Exception as e:
logger.info("Error fetching %s: %s", uri_str, e)
raise _FetchWellKnownFailure(temporary=True)
response, body = yield self._make_well_known_request(
server_name, retry=had_valid_well_known
)
try:
if response.code != 200:
@@ -161,8 +178,11 @@ class WellKnownResolver(object):
logger.info("Response from .well-known: %s", parsed_body)
result = parsed_body["m.server"].encode("ascii")
except defer.CancelledError:
# Bail if we've been cancelled
raise
except Exception as e:
logger.info("Error fetching %s: %s", uri_str, e)
logger.info("Error parsing well-known for %s: %s", server_name, e)
raise _FetchWellKnownFailure(temporary=False)
cache_period = _cache_period_from_headers(
@@ -172,13 +192,69 @@ class WellKnownResolver(object):
cache_period = WELL_KNOWN_DEFAULT_CACHE_PERIOD
# add some randomness to the TTL to avoid a stampeding herd every 24 hours
# after startup
cache_period += random.uniform(0, WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER)
cache_period *= random.uniform(
1 - WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER,
1 + WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER,
)
else:
cache_period = min(cache_period, WELL_KNOWN_MAX_CACHE_PERIOD)
cache_period = max(cache_period, WELL_KNOWN_MIN_CACHE_PERIOD)
# We got a success, mark as such in the cache
self._had_valid_well_known_cache.set(
server_name,
bool(result),
cache_period + WELL_KNOWN_REMEMBER_DOMAIN_HAD_VALID,
)
return (result, cache_period)
@defer.inlineCallbacks
def _make_well_known_request(self, server_name, retry):
"""Make the well known request.
This will retry the request if requested and it fails (with unable
to connect or receives a 5xx error).
Args:
server_name (bytes)
retry (bool): Whether to retry the request if it fails.
Returns:
Deferred[tuple[IResponse, bytes]] Returns the response object and
body. Response may be a non-200 response.
"""
uri = b"https://%s/.well-known/matrix/server" % (server_name,)
uri_str = uri.decode("ascii")
i = 0
while True:
i += 1
logger.info("Fetching %s", uri_str)
try:
response = yield make_deferred_yieldable(
self._well_known_agent.request(b"GET", uri)
)
body = yield make_deferred_yieldable(readBody(response))
if 500 <= response.code < 600:
raise Exception("Non-200 response %s" % (response.code,))
return response, body
except defer.CancelledError:
# Bail if we've been cancelled
raise
except Exception as e:
if not retry or i >= WELL_KNOWN_RETRY_ATTEMPTS:
logger.info("Error fetching %s: %s", uri_str, e)
raise _FetchWellKnownFailure(temporary=True)
logger.info("Error fetching %s: %s. Retrying", uri_str, e)
# Sleep briefly in the hopes that they come back up
yield self._clock.sleep(0.5)
def _cache_period_from_headers(headers, time_now=time.time):
cache_controls = _parse_cache_control(headers)
+1 -1
View File
@@ -300,7 +300,7 @@ class RestServlet(object):
http_server.register_paths(
method,
patterns,
trace_servlet(servlet_classname, method_handler),
trace_servlet(servlet_classname)(method_handler),
servlet_classname,
)
+374
View File
@@ -0,0 +1,374 @@
# -*- coding: utf-8 -*-
# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os.path
import sys
import typing
import warnings
import attr
from constantly import NamedConstant, Names, ValueConstant, Values
from zope.interface import implementer
from twisted.logger import (
FileLogObserver,
FilteringLogObserver,
ILogObserver,
LogBeginner,
Logger,
LogLevel,
LogLevelFilterPredicate,
LogPublisher,
eventAsText,
globalLogBeginner,
jsonFileLogObserver,
)
from synapse.config._base import ConfigError
from synapse.logging._terse_json import (
TerseJSONToConsoleLogObserver,
TerseJSONToTCPLogObserver,
)
from synapse.logging.context import LoggingContext
def stdlib_log_level_to_twisted(level: str) -> LogLevel:
"""
Convert a stdlib log level to Twisted's log level.
"""
lvl = level.lower().replace("warning", "warn")
return LogLevel.levelWithName(lvl)
@attr.s
@implementer(ILogObserver)
class LogContextObserver(object):
"""
An ILogObserver which adds Synapse-specific log context information.
Attributes:
observer (ILogObserver): The target parent observer.
"""
observer = attr.ib()
def __call__(self, event: dict) -> None:
"""
Consume a log event and emit it to the parent observer after filtering
and adding log context information.
Args:
event (dict)
"""
# Filter out some useless events that Twisted outputs
if "log_text" in event:
if event["log_text"].startswith("DNSDatagramProtocol starting on "):
return
if event["log_text"].startswith("(UDP Port "):
return
if event["log_text"].startswith("Timing out client") or event[
"log_format"
].startswith("Timing out client"):
return
context = LoggingContext.current_context()
# Copy the context information to the log event.
if context is not None:
context.copy_to_twisted_log_entry(event)
else:
# If there's no logging context, not even the root one, we might be
# starting up or it might be from non-Synapse code. Log it as if it
# came from the root logger.
event["request"] = None
event["scope"] = None
self.observer(event)
class PythonStdlibToTwistedLogger(logging.Handler):
"""
Transform a Python stdlib log message into a Twisted one.
"""
def __init__(self, observer, *args, **kwargs):
"""
Args:
observer (ILogObserver): A Twisted logging observer.
*args, **kwargs: Args/kwargs to be passed to logging.Handler.
"""
self.observer = observer
super().__init__(*args, **kwargs)
def emit(self, record: logging.LogRecord) -> None:
"""
Emit a record to Twisted's observer.
Args:
record (logging.LogRecord)
"""
self.observer(
{
"log_time": record.created,
"log_text": record.getMessage(),
"log_format": "{log_text}",
"log_namespace": record.name,
"log_level": stdlib_log_level_to_twisted(record.levelname),
}
)
def SynapseFileLogObserver(outFile: typing.io.TextIO) -> FileLogObserver:
"""
A log observer that formats events like the traditional log formatter and
sends them to `outFile`.
Args:
outFile (file object): The file object to write to.
"""
def formatEvent(_event: dict) -> str:
event = dict(_event)
event["log_level"] = event["log_level"].name.upper()
event["log_format"] = "- {log_namespace} - {log_level} - {request} - " + (
event.get("log_format", "{log_text}") or "{log_text}"
)
return eventAsText(event, includeSystem=False) + "\n"
return FileLogObserver(outFile, formatEvent)
class DrainType(Names):
CONSOLE = NamedConstant()
CONSOLE_JSON = NamedConstant()
CONSOLE_JSON_TERSE = NamedConstant()
FILE = NamedConstant()
FILE_JSON = NamedConstant()
NETWORK_JSON_TERSE = NamedConstant()
class OutputPipeType(Values):
stdout = ValueConstant(sys.__stdout__)
stderr = ValueConstant(sys.__stderr__)
@attr.s
class DrainConfiguration(object):
name = attr.ib()
type = attr.ib()
location = attr.ib()
options = attr.ib(default=None)
@attr.s
class NetworkJSONTerseOptions(object):
maximum_buffer = attr.ib(type=int)
DEFAULT_LOGGERS = {"synapse": {"level": "INFO"}}
def parse_drain_configs(
drains: dict
) -> typing.Generator[DrainConfiguration, None, None]:
"""
Parse the drain configurations.
Args:
drains (dict): A list of drain configurations.
Yields:
DrainConfiguration instances.
Raises:
ConfigError: If any of the drain configuration items are invalid.
"""
for name, config in drains.items():
if "type" not in config:
raise ConfigError("Logging drains require a 'type' key.")
try:
logging_type = DrainType.lookupByName(config["type"].upper())
except ValueError:
raise ConfigError(
"%s is not a known logging drain type." % (config["type"],)
)
if logging_type in [
DrainType.CONSOLE,
DrainType.CONSOLE_JSON,
DrainType.CONSOLE_JSON_TERSE,
]:
location = config.get("location")
if location is None or location not in ["stdout", "stderr"]:
raise ConfigError(
(
"The %s drain needs the 'location' key set to "
"either 'stdout' or 'stderr'."
)
% (logging_type,)
)
pipe = OutputPipeType.lookupByName(location).value
yield DrainConfiguration(name=name, type=logging_type, location=pipe)
elif logging_type in [DrainType.FILE, DrainType.FILE_JSON]:
if "location" not in config:
raise ConfigError(
"The %s drain needs the 'location' key set." % (logging_type,)
)
location = config.get("location")
if os.path.abspath(location) != location:
raise ConfigError(
"File paths need to be absolute, '%s' is a relative path"
% (location,)
)
yield DrainConfiguration(name=name, type=logging_type, location=location)
elif logging_type in [DrainType.NETWORK_JSON_TERSE]:
host = config.get("host")
port = config.get("port")
maximum_buffer = config.get("maximum_buffer", 1000)
yield DrainConfiguration(
name=name,
type=logging_type,
location=(host, port),
options=NetworkJSONTerseOptions(maximum_buffer=maximum_buffer),
)
else:
raise ConfigError(
"The %s drain type is currently not implemented."
% (config["type"].upper(),)
)
def setup_structured_logging(
hs,
config,
log_config: dict,
logBeginner: LogBeginner = globalLogBeginner,
redirect_stdlib_logging: bool = True,
) -> LogPublisher:
"""
Set up Twisted's structured logging system.
Args:
hs: The homeserver to use.
config (HomeserverConfig): The configuration of the Synapse homeserver.
log_config (dict): The log configuration to use.
"""
if config.no_redirect_stdio:
raise ConfigError(
"no_redirect_stdio cannot be defined using structured logging."
)
logger = Logger()
if "drains" not in log_config:
raise ConfigError("The logging configuration requires a list of drains.")
observers = []
for observer in parse_drain_configs(log_config["drains"]):
# Pipe drains
if observer.type == DrainType.CONSOLE:
logger.debug(
"Starting up the {name} console logger drain", name=observer.name
)
observers.append(SynapseFileLogObserver(observer.location))
elif observer.type == DrainType.CONSOLE_JSON:
logger.debug(
"Starting up the {name} JSON console logger drain", name=observer.name
)
observers.append(jsonFileLogObserver(observer.location))
elif observer.type == DrainType.CONSOLE_JSON_TERSE:
logger.debug(
"Starting up the {name} terse JSON console logger drain",
name=observer.name,
)
observers.append(
TerseJSONToConsoleLogObserver(observer.location, metadata={})
)
# File drains
elif observer.type == DrainType.FILE:
logger.debug("Starting up the {name} file logger drain", name=observer.name)
log_file = open(observer.location, "at", buffering=1, encoding="utf8")
observers.append(SynapseFileLogObserver(log_file))
elif observer.type == DrainType.FILE_JSON:
logger.debug(
"Starting up the {name} JSON file logger drain", name=observer.name
)
log_file = open(observer.location, "at", buffering=1, encoding="utf8")
observers.append(jsonFileLogObserver(log_file))
elif observer.type == DrainType.NETWORK_JSON_TERSE:
metadata = {"server_name": hs.config.server_name}
log_observer = TerseJSONToTCPLogObserver(
hs=hs,
host=observer.location[0],
port=observer.location[1],
metadata=metadata,
maximum_buffer=observer.options.maximum_buffer,
)
log_observer.start()
observers.append(log_observer)
else:
# We should never get here, but, just in case, throw an error.
raise ConfigError("%s drain type cannot be configured" % (observer.type,))
publisher = LogPublisher(*observers)
log_filter = LogLevelFilterPredicate()
for namespace, namespace_config in log_config.get(
"loggers", DEFAULT_LOGGERS
).items():
# Set the log level for twisted.logger.Logger namespaces
log_filter.setLogLevelForNamespace(
namespace,
stdlib_log_level_to_twisted(namespace_config.get("level", "INFO")),
)
# Also set the log levels for the stdlib logger namespaces, to prevent
# them getting to PythonStdlibToTwistedLogger and having to be formatted
if "level" in namespace_config:
logging.getLogger(namespace).setLevel(namespace_config.get("level"))
f = FilteringLogObserver(publisher, [log_filter])
lco = LogContextObserver(f)
if redirect_stdlib_logging:
stuff_into_twisted = PythonStdlibToTwistedLogger(lco)
stdliblogger = logging.getLogger()
stdliblogger.addHandler(stuff_into_twisted)
# Always redirect standard I/O, otherwise other logging outputs might miss
# it.
logBeginner.beginLoggingTo([lco], redirectStandardIO=True)
return publisher
def reload_structured_logging(*args, log_config=None) -> None:
warnings.warn(
"Currently the structured logging system can not be reloaded, doing nothing"
)
+278
View File
@@ -0,0 +1,278 @@
# -*- coding: utf-8 -*-
# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Log formatters that output terse JSON.
"""
import sys
from collections import deque
from ipaddress import IPv4Address, IPv6Address, ip_address
from math import floor
from typing.io import TextIO
import attr
from simplejson import dumps
from twisted.application.internet import ClientService
from twisted.internet.endpoints import (
HostnameEndpoint,
TCP4ClientEndpoint,
TCP6ClientEndpoint,
)
from twisted.internet.protocol import Factory, Protocol
from twisted.logger import FileLogObserver, Logger
from twisted.python.failure import Failure
def flatten_event(event: dict, metadata: dict, include_time: bool = False):
"""
Flatten a Twisted logging event to an dictionary capable of being sent
as a log event to a logging aggregation system.
The format is vastly simplified and is not designed to be a "human readable
string" in the sense that traditional logs are. Instead, the structure is
optimised for searchability and filtering, with human-understandable log
keys.
Args:
event (dict): The Twisted logging event we are flattening.
metadata (dict): Additional data to include with each log message. This
can be information like the server name. Since the target log
consumer does not know who we are other than by host IP, this
allows us to forward through static information.
include_time (bool): Should we include the `time` key? If False, the
event time is stripped from the event.
"""
new_event = {}
# If it's a failure, make the new event's log_failure be the traceback text.
if "log_failure" in event:
new_event["log_failure"] = event["log_failure"].getTraceback()
# If it's a warning, copy over a string representation of the warning.
if "warning" in event:
new_event["warning"] = str(event["warning"])
# Stdlib logging events have "log_text" as their human-readable portion,
# Twisted ones have "log_format". For now, include the log_format, so that
# context only given in the log format (e.g. what is being logged) is
# available.
if "log_text" in event:
new_event["log"] = event["log_text"]
else:
new_event["log"] = event["log_format"]
# We want to include the timestamp when forwarding over the network, but
# exclude it when we are writing to stdout. This is because the log ingester
# (e.g. logstash, fluentd) can add its own timestamp.
if include_time:
new_event["time"] = round(event["log_time"], 2)
# Convert the log level to a textual representation.
new_event["level"] = event["log_level"].name.upper()
# Ignore these keys, and do not transfer them over to the new log object.
# They are either useless (isError), transferred manually above (log_time,
# log_level, etc), or contain Python objects which are not useful for output
# (log_logger, log_source).
keys_to_delete = [
"isError",
"log_failure",
"log_format",
"log_level",
"log_logger",
"log_source",
"log_system",
"log_time",
"log_text",
"observer",
"warning",
]
# If it's from the Twisted legacy logger (twisted.python.log), it adds some
# more keys we want to purge.
if event.get("log_namespace") == "log_legacy":
keys_to_delete.extend(["message", "system", "time"])
# Rather than modify the dictionary in place, construct a new one with only
# the content we want. The original event should be considered 'frozen'.
for key in event.keys():
if key in keys_to_delete:
continue
if isinstance(event[key], (str, int, bool, float)) or event[key] is None:
# If it's a plain type, include it as is.
new_event[key] = event[key]
else:
# If it's not one of those basic types, write out a string
# representation. This should probably be a warning in development,
# so that we are sure we are only outputting useful data.
new_event[key] = str(event[key])
# Add the metadata information to the event (e.g. the server_name).
new_event.update(metadata)
return new_event
def TerseJSONToConsoleLogObserver(outFile: TextIO, metadata: dict) -> FileLogObserver:
"""
A log observer that formats events to a flattened JSON representation.
Args:
outFile: The file object to write to.
metadata: Metadata to be added to each log object.
"""
def formatEvent(_event: dict) -> str:
flattened = flatten_event(_event, metadata)
return dumps(flattened, ensure_ascii=False, separators=(",", ":")) + "\n"
return FileLogObserver(outFile, formatEvent)
@attr.s
class TerseJSONToTCPLogObserver(object):
"""
An IObserver that writes JSON logs to a TCP target.
Args:
hs (HomeServer): The Homeserver that is being logged for.
host: The host of the logging target.
port: The logging target's port.
metadata: Metadata to be added to each log entry.
"""
hs = attr.ib()
host = attr.ib(type=str)
port = attr.ib(type=int)
metadata = attr.ib(type=dict)
maximum_buffer = attr.ib(type=int)
_buffer = attr.ib(default=attr.Factory(deque), type=deque)
_writer = attr.ib(default=None)
_logger = attr.ib(default=attr.Factory(Logger))
def start(self) -> None:
# Connect without DNS lookups if it's a direct IP.
try:
ip = ip_address(self.host)
if isinstance(ip, IPv4Address):
endpoint = TCP4ClientEndpoint(
self.hs.get_reactor(), self.host, self.port
)
elif isinstance(ip, IPv6Address):
endpoint = TCP6ClientEndpoint(
self.hs.get_reactor(), self.host, self.port
)
except ValueError:
endpoint = HostnameEndpoint(self.hs.get_reactor(), self.host, self.port)
factory = Factory.forProtocol(Protocol)
self._service = ClientService(endpoint, factory, clock=self.hs.get_reactor())
self._service.startService()
def _write_loop(self) -> None:
"""
Implement the write loop.
"""
if self._writer:
return
self._writer = self._service.whenConnected()
@self._writer.addBoth
def writer(r):
if isinstance(r, Failure):
r.printTraceback(file=sys.__stderr__)
self._writer = None
self.hs.get_reactor().callLater(1, self._write_loop)
return
try:
for event in self._buffer:
r.transport.write(
dumps(event, ensure_ascii=False, separators=(",", ":")).encode(
"utf8"
)
)
r.transport.write(b"\n")
self._buffer.clear()
except Exception as e:
sys.__stderr__.write("Failed writing out logs with %s\n" % (str(e),))
self._writer = False
self.hs.get_reactor().callLater(1, self._write_loop)
def _handle_pressure(self) -> None:
"""
Handle backpressure by shedding events.
The buffer will, in this order, until the buffer is below the maximum:
- Shed DEBUG events
- Shed INFO events
- Shed the middle 50% of the events.
"""
if len(self._buffer) <= self.maximum_buffer:
return
# Strip out DEBUGs
self._buffer = deque(
filter(lambda event: event["level"] != "DEBUG", self._buffer)
)
if len(self._buffer) <= self.maximum_buffer:
return
# Strip out INFOs
self._buffer = deque(
filter(lambda event: event["level"] != "INFO", self._buffer)
)
if len(self._buffer) <= self.maximum_buffer:
return
# Cut the middle entries out
buffer_split = floor(self.maximum_buffer / 2)
old_buffer = self._buffer
self._buffer = deque()
for i in range(buffer_split):
self._buffer.append(old_buffer.popleft())
end_buffer = []
for i in range(buffer_split):
end_buffer.append(old_buffer.pop())
self._buffer.extend(reversed(end_buffer))
def __call__(self, event: dict) -> None:
flattened = flatten_event(event, self.metadata, include_time=True)
self._buffer.append(flattened)
# Handle backpressure, if it exists.
try:
self._handle_pressure()
except Exception:
# If handling backpressure fails,clear the buffer and log the
# exception.
self._buffer.clear()
self._logger.failure("Failed clearing backpressure")
# Try and write immediately.
self._write_loop()
+13 -1
View File
@@ -25,6 +25,7 @@ See doc/log_contexts.rst for details on how this works.
import logging
import threading
import types
from typing import Any, List
from twisted.internet import defer, threads
@@ -194,7 +195,7 @@ class LoggingContext(object):
class Sentinel(object):
"""Sentinel to represent the root context"""
__slots__ = []
__slots__ = [] # type: List[Any]
def __str__(self):
return "sentinel"
@@ -202,6 +203,10 @@ class LoggingContext(object):
def copy_to(self, record):
pass
def copy_to_twisted_log_entry(self, record):
record["request"] = None
record["scope"] = None
def start(self):
pass
@@ -330,6 +335,13 @@ class LoggingContext(object):
# we also track the current scope:
record.scope = self.scope
def copy_to_twisted_log_entry(self, record):
"""
Copy logging fields from this context to a Twisted log record.
"""
record["request"] = self.request
record["scope"] = self.scope
def start(self):
if get_thread_id() != self.main_thread:
logger.warning("Started logcontext %s on different thread", self)
+105 -65
View File
@@ -149,6 +149,9 @@ unchartered waters will require the enforcement of the whitelist.
``logging/opentracing.py`` has a ``whitelisted_homeserver`` method which takes
in a destination and compares it to the whitelist.
Most injection methods take a 'destination' arg. The context will only be injected
if the destination matches the whitelist or the destination is None.
=======
Gotchas
=======
@@ -174,10 +177,48 @@ from twisted.internet import defer
from synapse.config import ConfigError
# Helper class
class _DummyTagNames(object):
"""wrapper of opentracings tags. We need to have them if we
want to reference them without opentracing around. Clearly they
should never actually show up in a trace. `set_tags` overwrites
these with the correct ones."""
INVALID_TAG = "invalid-tag"
COMPONENT = INVALID_TAG
DATABASE_INSTANCE = INVALID_TAG
DATABASE_STATEMENT = INVALID_TAG
DATABASE_TYPE = INVALID_TAG
DATABASE_USER = INVALID_TAG
ERROR = INVALID_TAG
HTTP_METHOD = INVALID_TAG
HTTP_STATUS_CODE = INVALID_TAG
HTTP_URL = INVALID_TAG
MESSAGE_BUS_DESTINATION = INVALID_TAG
PEER_ADDRESS = INVALID_TAG
PEER_HOSTNAME = INVALID_TAG
PEER_HOST_IPV4 = INVALID_TAG
PEER_HOST_IPV6 = INVALID_TAG
PEER_PORT = INVALID_TAG
PEER_SERVICE = INVALID_TAG
SAMPLING_PRIORITY = INVALID_TAG
SERVICE = INVALID_TAG
SPAN_KIND = INVALID_TAG
SPAN_KIND_CONSUMER = INVALID_TAG
SPAN_KIND_PRODUCER = INVALID_TAG
SPAN_KIND_RPC_CLIENT = INVALID_TAG
SPAN_KIND_RPC_SERVER = INVALID_TAG
try:
import opentracing
tags = opentracing.tags
except ImportError:
opentracing = None
tags = _DummyTagNames
try:
from jaeger_client import Config as JaegerConfig
from synapse.logging.scopecontextmanager import LogContextScopeManager
@@ -252,10 +293,6 @@ def init_tracer(config):
scope_manager=LogContextScopeManager(config),
).initialize_tracer()
# Set up tags to be opentracing's tags
global tags
tags = opentracing.tags
# Whitelisting
@@ -334,8 +371,8 @@ def start_active_span_follows_from(operation_name, contexts):
return scope
def start_active_span_from_context(
headers,
def start_active_span_from_request(
request,
operation_name,
references=None,
tags=None,
@@ -344,9 +381,9 @@ def start_active_span_from_context(
finish_on_close=True,
):
"""
Extracts a span context from Twisted Headers.
Extracts a span context from a Twisted Request.
args:
headers (twisted.web.http_headers.Headers)
headers (twisted.web.http.Request)
For the other args see opentracing.tracer
@@ -360,7 +397,9 @@ def start_active_span_from_context(
if opentracing is None:
return _noop_context_manager()
header_dict = {k.decode(): v[0].decode() for k, v in headers.getAllRawHeaders()}
header_dict = {
k.decode(): v[0].decode() for k, v in request.requestHeaders.getAllRawHeaders()
}
context = opentracing.tracer.extract(opentracing.Format.HTTP_HEADERS, header_dict)
return opentracing.tracer.start_active_span(
@@ -448,7 +487,7 @@ def set_operation_name(operation_name):
@only_if_tracing
def inject_active_span_twisted_headers(headers, destination):
def inject_active_span_twisted_headers(headers, destination, check_destination=True):
"""
Injects a span context into twisted headers in-place
@@ -467,7 +506,7 @@ def inject_active_span_twisted_headers(headers, destination):
https://github.com/jaegertracing/jaeger-client-python/blob/master/jaeger_client/constants.py
"""
if not whitelisted_homeserver(destination):
if check_destination and not whitelisted_homeserver(destination):
return
span = opentracing.tracer.active_span
@@ -479,7 +518,7 @@ def inject_active_span_twisted_headers(headers, destination):
@only_if_tracing
def inject_active_span_byte_dict(headers, destination):
def inject_active_span_byte_dict(headers, destination, check_destination=True):
"""
Injects a span context into a dict where the headers are encoded as byte
strings
@@ -511,7 +550,7 @@ def inject_active_span_byte_dict(headers, destination):
@only_if_tracing
def inject_active_span_text_map(carrier, destination=None):
def inject_active_span_text_map(carrier, destination, check_destination=True):
"""
Injects a span context into a dict
@@ -532,7 +571,7 @@ def inject_active_span_text_map(carrier, destination=None):
https://github.com/jaegertracing/jaeger-client-python/blob/master/jaeger_client/constants.py
"""
if destination and not whitelisted_homeserver(destination):
if check_destination and not whitelisted_homeserver(destination):
return
opentracing.tracer.inject(
@@ -540,6 +579,29 @@ def inject_active_span_text_map(carrier, destination=None):
)
def get_active_span_text_map(destination=None):
"""
Gets a span context as a dict. This can be used instead of manually
injecting a span into an empty carrier.
Args:
destination (str): the name of the remote server.
Returns:
dict: the active span's context if opentracing is enabled, otherwise empty.
"""
if not opentracing or (destination and not whitelisted_homeserver(destination)):
return {}
carrier = {}
opentracing.tracer.inject(
opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier
)
return carrier
def active_span_context_as_string():
"""
Returns:
@@ -689,65 +751,43 @@ def tag_args(func):
return _tag_args_inner
def trace_servlet(servlet_name, func):
def trace_servlet(servlet_name, extract_context=False):
"""Decorator which traces a serlet. It starts a span with some servlet specific
tags such as the servlet_name and request information"""
if not opentracing:
return func
tags such as the servlet_name and request information
@wraps(func)
@defer.inlineCallbacks
def _trace_servlet_inner(request, *args, **kwargs):
with start_active_span(
"incoming-client-request",
tags={
Args:
servlet_name (str): The name to be used for the span's operation_name
extract_context (bool): Whether to attempt to extract the opentracing
context from the request the servlet is handling.
"""
def _trace_servlet_inner_1(func):
if not opentracing:
return func
@wraps(func)
@defer.inlineCallbacks
def _trace_servlet_inner(request, *args, **kwargs):
request_tags = {
"request_id": request.get_request_id(),
tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
tags.HTTP_METHOD: request.get_method(),
tags.HTTP_URL: request.get_redacted_uri(),
tags.PEER_HOST_IPV6: request.getClientIP(),
"servlet_name": servlet_name,
},
):
result = yield defer.maybeDeferred(func, request, *args, **kwargs)
return result
}
return _trace_servlet_inner
if extract_context:
scope = start_active_span_from_request(
request, servlet_name, tags=request_tags
)
else:
scope = start_active_span(servlet_name, tags=request_tags)
with scope:
result = yield defer.maybeDeferred(func, request, *args, **kwargs)
return result
# Helper class
return _trace_servlet_inner
class _DummyTagNames(object):
"""wrapper of opentracings tags. We need to have them if we
want to reference them without opentracing around. Clearly they
should never actually show up in a trace. `set_tags` overwrites
these with the correct ones."""
INVALID_TAG = "invalid-tag"
COMPONENT = INVALID_TAG
DATABASE_INSTANCE = INVALID_TAG
DATABASE_STATEMENT = INVALID_TAG
DATABASE_TYPE = INVALID_TAG
DATABASE_USER = INVALID_TAG
ERROR = INVALID_TAG
HTTP_METHOD = INVALID_TAG
HTTP_STATUS_CODE = INVALID_TAG
HTTP_URL = INVALID_TAG
MESSAGE_BUS_DESTINATION = INVALID_TAG
PEER_ADDRESS = INVALID_TAG
PEER_HOSTNAME = INVALID_TAG
PEER_HOST_IPV4 = INVALID_TAG
PEER_HOST_IPV6 = INVALID_TAG
PEER_PORT = INVALID_TAG
PEER_SERVICE = INVALID_TAG
SAMPLING_PRIORITY = INVALID_TAG
SERVICE = INVALID_TAG
SPAN_KIND = INVALID_TAG
SPAN_KIND_CONSUMER = INVALID_TAG
SPAN_KIND_PRODUCER = INVALID_TAG
SPAN_KIND_RPC_CLIENT = INVALID_TAG
SPAN_KIND_RPC_SERVER = INVALID_TAG
tags = _DummyTagNames
return _trace_servlet_inner_1
+3 -3
View File
@@ -47,9 +47,9 @@ REQUIREMENTS = [
"idna>=2.5",
# validating SSL certs for IP addresses requires service_identity 18.1.
"service_identity>=18.1.0",
# our logcontext handling relies on the ability to cancel inlineCallbacks
# (https://twistedmatrix.com/trac/ticket/4632) which landed in Twisted 18.7.
"Twisted>=18.7.0",
# Twisted 18.9 introduces some logger improvements that the structured
# logger utilises
"Twisted>=18.9.0",
"treq>=15.1",
# Twisted has required pyopenssl 16.0 since about Twisted 16.6.
"pyopenssl>=16.0.0",
+14 -2
View File
@@ -22,6 +22,7 @@ from six.moves import urllib
from twisted.internet import defer
import synapse.logging.opentracing as opentracing
from synapse.api.errors import (
CodeMessageException,
HttpResponseException,
@@ -165,8 +166,12 @@ class ReplicationEndpoint(object):
# have a good idea that the request has either succeeded or failed on
# the master, and so whether we should clean up or not.
while True:
headers = {}
opentracing.inject_active_span_byte_dict(
headers, None, check_destination=False
)
try:
result = yield request_func(uri, data)
result = yield request_func(uri, data, headers=headers)
break
except CodeMessageException as e:
if e.code != 504 or not cls.RETRY_ON_TIMEOUT:
@@ -205,7 +210,14 @@ class ReplicationEndpoint(object):
args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args)
pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args))
http_server.register_paths(method, [pattern], handler, self.__class__.__name__)
http_server.register_paths(
method,
[pattern],
opentracing.trace_servlet(self.__class__.__name__, extract_context=True)(
handler
),
self.__class__.__name__,
)
def _cached_handler(self, request, txn_id, **kwargs):
"""Called on new incoming requests when caching is enabled. Checks
+5 -1
View File
@@ -42,7 +42,9 @@ from synapse.rest.admin._base import (
historical_admin_path_patterns,
)
from synapse.rest.admin.media import register_servlets_for_media_repo
from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet
from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet
from synapse.rest.admin.users import UserAdminServlet
from synapse.types import UserID, create_requester
from synapse.util.versionstring import get_version_string
@@ -50,7 +52,7 @@ logger = logging.getLogger(__name__)
class UsersRestServlet(RestServlet):
PATTERNS = historical_admin_path_patterns("/users/(?P<user_id>[^/]*)")
PATTERNS = historical_admin_path_patterns("/users/(?P<user_id>[^/]*)$")
def __init__(self, hs):
self.hs = hs
@@ -738,8 +740,10 @@ def register_servlets(hs, http_server):
Register all the admin servlets.
"""
register_servlets_for_client_rest_resource(hs, http_server)
PurgeRoomServlet(hs).register(http_server)
SendServerNoticeServlet(hs).register(http_server)
VersionServlet(hs).register(http_server)
UserAdminServlet(hs).register(http_server)
def register_servlets_for_client_rest_resource(hs, http_server):
+57
View File
@@ -0,0 +1,57 @@
# -*- coding: utf-8 -*-
# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_json_object_from_request,
)
from synapse.rest.admin import assert_requester_is_admin
class PurgeRoomServlet(RestServlet):
"""Servlet which will remove all trace of a room from the database
POST /_synapse/admin/v1/purge_room
{
"room_id": "!room:id"
}
returns:
{}
"""
PATTERNS = (re.compile("^/_synapse/admin/v1/purge_room$"),)
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
self.hs = hs
self.auth = hs.get_auth()
self.pagination_handler = hs.get_pagination_handler()
async def on_POST(self, request):
await assert_requester_is_admin(self.auth, request)
body = parse_json_object_from_request(request)
assert_params_in_dict(body, ("room_id",))
await self.pagination_handler.purge_room(body["room_id"])
return (200, {})
+100
View File
@@ -0,0 +1,100 @@
# -*- coding: utf-8 -*-
# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_json_object_from_request,
)
from synapse.rest.admin import assert_requester_is_admin, assert_user_is_admin
from synapse.types import UserID
class UserAdminServlet(RestServlet):
"""
Get or set whether or not a user is a server administrator.
Note that only local users can be server administrators, and that an
administrator may not demote themselves.
Only server administrators can use this API.
Examples:
* Get
GET /_synapse/admin/v1/users/@nonadmin:example.com/admin
response on success:
{
"admin": false
}
* Set
PUT /_synapse/admin/v1/users/@reivilibre:librepush.net/admin
request body:
{
"admin": true
}
response on success:
{}
"""
PATTERNS = (re.compile("^/_synapse/admin/v1/users/(?P<user_id>@[^/]*)/admin$"),)
def __init__(self, hs):
self.hs = hs
self.auth = hs.get_auth()
self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_GET(self, request, user_id):
yield assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id)
if not self.hs.is_mine(target_user):
raise SynapseError(400, "Only local users can be admins of this homeserver")
is_admin = yield self.handlers.admin_handler.get_user_server_admin(target_user)
is_admin = bool(is_admin)
return (200, {"admin": is_admin})
@defer.inlineCallbacks
def on_PUT(self, request, user_id):
requester = yield self.auth.get_user_by_req(request)
yield assert_user_is_admin(self.auth, requester.user)
auth_user = requester.user
target_user = UserID.from_string(user_id)
body = parse_json_object_from_request(request)
assert_params_in_dict(body, ["admin"])
if not self.hs.is_mine(target_user):
raise SynapseError(400, "Only local users can be admins of this homeserver")
set_admin_to = bool(body["admin"])
if target_user == auth_user and not set_admin_to:
raise SynapseError(400, "You may not demote yourself.")
yield self.handlers.admin_handler.set_user_server_admin(
target_user, set_admin_to
)
return (200, {})
+12 -1
View File
@@ -24,6 +24,7 @@ from synapse.http.servlet import (
parse_json_object_from_request,
parse_string,
)
from synapse.logging.opentracing import log_kv, set_tag, trace_using_operation_name
from synapse.types import StreamToken
from ._base import client_patterns
@@ -68,6 +69,7 @@ class KeyUploadServlet(RestServlet):
self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler()
@trace_using_operation_name("upload_keys")
@defer.inlineCallbacks
def on_POST(self, request, device_id):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
@@ -78,6 +80,14 @@ class KeyUploadServlet(RestServlet):
# passing the device_id here is deprecated; however, we allow it
# for now for compatibility with older clients.
if requester.device_id is not None and device_id != requester.device_id:
set_tag("error", True)
log_kv(
{
"message": "Client uploading keys for a different device",
"logged_in_id": requester.device_id,
"key_being_uploaded": device_id,
}
)
logger.warning(
"Client uploading keys for a different device "
"(logged in as %s, uploading for %s)",
@@ -178,10 +188,11 @@ class KeyChangesServlet(RestServlet):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
from_token_string = parse_string(request, "from")
set_tag("from", from_token_string)
# We want to enforce they do pass us one, but we ignore it and return
# changes after the "to" as well as before.
parse_string(request, "to")
set_tag("to", parse_string(request, "to"))
from_token = StreamToken.from_string(from_token_string)
+4 -53
View File
@@ -16,7 +16,6 @@
import hmac
import logging
from hashlib import sha1
from six import string_types
@@ -239,14 +238,12 @@ class RegisterRestServlet(RestServlet):
# we do basic sanity checks here because the auth layer will store these
# in sessions. Pull out the username/password provided to us.
desired_password = None
if "password" in body:
if (
not isinstance(body["password"], string_types)
or len(body["password"]) > 512
):
raise SynapseError(400, "Invalid password")
desired_password = body["password"]
desired_username = None
if "username" in body:
@@ -261,8 +258,8 @@ class RegisterRestServlet(RestServlet):
if self.auth.has_access_token(request):
appservice = yield self.auth.get_appservice_by_req(request)
# fork off as soon as possible for ASes and shared secret auth which
# have completely different registration flows to normal users
# fork off as soon as possible for ASes which have completely
# different registration flows to normal users
# == Application Service Registration ==
if appservice:
@@ -285,8 +282,8 @@ class RegisterRestServlet(RestServlet):
return (200, result) # we throw for non 200 responses
return
# for either shared secret or regular registration, downcase the
# provided username before attempting to register it. This should mean
# for regular registration, downcase the provided username before
# attempting to register it. This should mean
# that people who try to register with upper-case in their usernames
# don't get a nasty surprise. (Note that we treat username
# case-insenstively in login, so they are free to carry on imagining
@@ -294,16 +291,6 @@ class RegisterRestServlet(RestServlet):
if desired_username is not None:
desired_username = desired_username.lower()
# == Shared Secret Registration == (e.g. create new user scripts)
if "mac" in body:
# FIXME: Should we really be determining if this is shared secret
# auth based purely on the 'mac' key?
result = yield self._do_shared_secret_registration(
desired_username, desired_password, body
)
return (200, result) # we throw for non 200 responses
return
# == Normal User Registration == (everyone else)
if not self.hs.config.enable_registration:
raise SynapseError(403, "Registration has been disabled")
@@ -512,42 +499,6 @@ class RegisterRestServlet(RestServlet):
)
return (yield self._create_registration_details(user_id, body))
@defer.inlineCallbacks
def _do_shared_secret_registration(self, username, password, body):
if not self.hs.config.registration_shared_secret:
raise SynapseError(400, "Shared secret registration is not enabled")
if not username:
raise SynapseError(
400, "username must be specified", errcode=Codes.BAD_JSON
)
# use the username from the original request rather than the
# downcased one in `username` for the mac calculation
user = body["username"].encode("utf-8")
# str() because otherwise hmac complains that 'unicode' does not
# have the buffer interface
got_mac = str(body["mac"])
# FIXME this is different to the /v1/register endpoint, which
# includes the password and admin flag in the hashed text. Why are
# these different?
want_mac = hmac.new(
key=self.hs.config.registration_shared_secret.encode(),
msg=user,
digestmod=sha1,
).hexdigest()
if not compare_digest(want_mac, got_mac):
raise SynapseError(403, "HMAC incorrect")
user_id = yield self.registration_handler.register_user(
localpart=username, password=password
)
result = yield self._create_registration_details(user_id, body)
return result
@defer.inlineCallbacks
def _create_registration_details(self, user_id, params):
"""Complete registration of newly-registered user
+14 -12
View File
@@ -13,7 +13,9 @@
# limitations under the License.
import logging
from io import BytesIO
from canonicaljson import encode_canonical_json, json
from signedjson.sign import sign_json
from twisted.internet import defer
@@ -95,6 +97,7 @@ class RemoteKey(DirectServeResource):
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
self.config = hs.config
@wrap_json_request_handler
async def _async_render_GET(self, request):
@@ -214,15 +217,14 @@ class RemoteKey(DirectServeResource):
yield self.fetcher.get_keys(cache_misses)
yield self.query_keys(request, query, query_remote_on_cache_miss=False)
else:
result_io = BytesIO()
result_io.write(b'{"server_keys":')
sep = b"["
for json_bytes in json_results:
result_io.write(sep)
result_io.write(json_bytes)
sep = b","
if sep == b"[":
result_io.write(sep)
result_io.write(b"]}")
signed_keys = []
for key_json in json_results:
key_json = json.loads(key_json)
for signing_key in self.config.key_server_signing_keys:
key_json = sign_json(key_json, self.config.server_name, signing_key)
respond_with_json_bytes(request, 200, result_io.getvalue())
signed_keys.append(key_json)
results = {"server_keys": signed_keys}
respond_with_json_bytes(request, 200, encode_canonical_json(results))
+1 -1
View File
@@ -34,7 +34,7 @@ class WellKnownBuilder(object):
self._config = hs.config
def get_well_known(self):
# if we don't have a public_base_url, we can't help much here.
# if we don't have a public_baseurl, we can't help much here.
if self._config.public_baseurl is None:
return None
+16 -8
View File
@@ -1395,14 +1395,22 @@ class SQLBaseStore(object):
"""
txn.call_after(self._invalidate_state_caches, room_id, members_changed)
# We need to be careful that the size of the `members_changed` list
# isn't so large that it causes problems sending over replication, so we
# send them in chunks.
# Max line length is 16K, and max user ID length is 255, so 50 should
# be safe.
for chunk in batch_iter(members_changed, 50):
keys = itertools.chain([room_id], chunk)
self._send_invalidation_to_replication(txn, _CURRENT_STATE_CACHE_NAME, keys)
if members_changed:
# We need to be careful that the size of the `members_changed` list
# isn't so large that it causes problems sending over replication, so we
# send them in chunks.
# Max line length is 16K, and max user ID length is 255, so 50 should
# be safe.
for chunk in batch_iter(members_changed, 50):
keys = itertools.chain([room_id], chunk)
self._send_invalidation_to_replication(
txn, _CURRENT_STATE_CACHE_NAME, keys
)
else:
# if no members changed, we still need to invalidate the other caches.
self._send_invalidation_to_replication(
txn, _CURRENT_STATE_CACHE_NAME, [room_id]
)
def _invalidate_state_caches(self, room_id, members_changed):
"""Invalidates caches that are based on the current state, but does
+33 -6
View File
@@ -21,6 +21,11 @@ from canonicaljson import json
from twisted.internet import defer
from synapse.api.errors import StoreError
from synapse.logging.opentracing import (
get_active_span_text_map,
trace,
whitelisted_homeserver,
)
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import Cache, SQLBaseStore, db_to_json
from synapse.storage.background_updates import BackgroundUpdateStore
@@ -73,6 +78,7 @@ class DeviceWorkerStore(SQLBaseStore):
return {d["device_id"]: d for d in devices}
@trace
@defer.inlineCallbacks
def get_devices_by_remote(self, destination, from_stream_id, limit):
"""Get stream of updates to send to remote servers
@@ -127,8 +133,15 @@ class DeviceWorkerStore(SQLBaseStore):
# (user_id, device_id) entries into a map, with the value being
# the max stream_id across each set of duplicate entries
#
# maps (user_id, device_id) -> stream_id
# maps (user_id, device_id) -> (stream_id, opentracing_context)
# as long as their stream_id does not match that of the last row
#
# opentracing_context contains the opentracing metadata for the request
# that created the poke
#
# The most recent request's opentracing_context is used as the
# context which created the Edu.
query_map = {}
for update in updates:
if stream_id_cutoff is not None and update[2] >= stream_id_cutoff:
@@ -136,7 +149,14 @@ class DeviceWorkerStore(SQLBaseStore):
break
key = (update[0], update[1])
query_map[key] = max(query_map.get(key, 0), update[2])
update_context = update[3]
update_stream_id = update[2]
previous_update_stream_id, _ = query_map.get(key, (0, None))
if update_stream_id > previous_update_stream_id:
query_map[key] = (update_stream_id, update_context)
# If we didn't find any updates with a stream_id lower than the cutoff, it
# means that there are more than limit updates all of which have the same
@@ -171,7 +191,7 @@ class DeviceWorkerStore(SQLBaseStore):
List: List of device updates
"""
sql = """
SELECT user_id, device_id, stream_id FROM device_lists_outbound_pokes
SELECT user_id, device_id, stream_id, opentracing_context FROM device_lists_outbound_pokes
WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
ORDER BY stream_id
LIMIT ?
@@ -187,8 +207,9 @@ class DeviceWorkerStore(SQLBaseStore):
Args:
destination (str): The host the device updates are intended for
from_stream_id (int): The minimum stream_id to filter updates by, exclusive
query_map (Dict[(str, str): int]): Dictionary mapping
user_id/device_id to update stream_id
query_map (Dict[(str, str): (int, str|None)]): Dictionary mapping
user_id/device_id to update stream_id and the relevent json-encoded
opentracing context
Returns:
List[Dict]: List of objects representing an device update EDU
@@ -210,12 +231,13 @@ class DeviceWorkerStore(SQLBaseStore):
destination, user_id, from_stream_id
)
for device_id, device in iteritems(user_devices):
stream_id = query_map[(user_id, device_id)]
stream_id, opentracing_context = query_map[(user_id, device_id)]
result = {
"user_id": user_id,
"device_id": device_id,
"prev_id": [prev_id] if prev_id else [],
"stream_id": stream_id,
"org.matrix.opentracing_context": opentracing_context,
}
prev_id = stream_id
@@ -814,6 +836,8 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
],
)
context = get_active_span_text_map()
self._simple_insert_many_txn(
txn,
table="device_lists_outbound_pokes",
@@ -825,6 +849,9 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
"device_id": device_id,
"sent": False,
"ts": now,
"opentracing_context": json.dumps(context)
if whitelisted_homeserver(destination)
else None,
}
for destination in hosts
for device_id in device_ids
+14
View File
@@ -18,6 +18,7 @@ import json
from twisted.internet import defer
from synapse.api.errors import StoreError
from synapse.logging.opentracing import log_kv, trace
from ._base import SQLBaseStore
@@ -94,7 +95,16 @@ class EndToEndRoomKeyStore(SQLBaseStore):
},
lock=False,
)
log_kv(
{
"message": "Set room key",
"room_id": room_id,
"session_id": session_id,
"room_key": room_key,
}
)
@trace
@defer.inlineCallbacks
def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
"""Bulk get the E2E room keys for a given backup, optionally filtered to a given
@@ -153,6 +163,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
return sessions
@trace
@defer.inlineCallbacks
def delete_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
"""Bulk delete the E2E room keys for a given backup, optionally filtered to a given
@@ -236,6 +247,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
"get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn
)
@trace
def create_e2e_room_keys_version(self, user_id, info):
"""Atomically creates a new version of this user's e2e_room_keys store
with the given version info.
@@ -276,6 +288,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
"create_e2e_room_keys_version_txn", _create_e2e_room_keys_version_txn
)
@trace
def update_e2e_room_keys_version(self, user_id, version, info):
"""Update a given backup version
@@ -292,6 +305,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
desc="update_e2e_room_keys_version",
)
@trace
def delete_e2e_room_keys_version(self, user_id, version=None):
"""Delete a given backup version of the user's room keys.
Doesn't delete their actual key data.
+35 -3
View File
@@ -18,12 +18,14 @@ from canonicaljson import encode_canonical_json
from twisted.internet import defer
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.util.caches.descriptors import cached
from ._base import SQLBaseStore, db_to_json
class EndToEndKeyWorkerStore(SQLBaseStore):
@trace
@defer.inlineCallbacks
def get_e2e_device_keys(
self, query_list, include_all_devices=False, include_deleted_devices=False
@@ -40,6 +42,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
Dict mapping from user-id to dict mapping from device_id to
dict containing "key_json", "device_display_name".
"""
set_tag("query_list", query_list)
if not query_list:
return {}
@@ -57,9 +60,13 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
return results
@trace
def _get_e2e_device_keys_txn(
self, txn, query_list, include_all_devices=False, include_deleted_devices=False
):
set_tag("include_all_devices", include_all_devices)
set_tag("include_deleted_devices", include_deleted_devices)
query_clauses = []
query_params = []
@@ -104,6 +111,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
for user_id, device_id in deleted_devices:
result.setdefault(user_id, {})[device_id] = None
log_kv(result)
return result
@defer.inlineCallbacks
@@ -129,8 +137,9 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
keyvalues={"user_id": user_id, "device_id": device_id},
desc="add_e2e_one_time_keys_check",
)
return {(row["algorithm"], row["key_id"]): row["key_json"] for row in rows}
result = {(row["algorithm"], row["key_id"]): row["key_json"] for row in rows}
log_kv({"message": "Fetched one time keys for user", "one_time_keys": result})
return result
@defer.inlineCallbacks
def add_e2e_one_time_keys(self, user_id, device_id, time_now, new_keys):
@@ -146,6 +155,9 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
"""
def _add_e2e_one_time_keys(txn):
set_tag("user_id", user_id)
set_tag("device_id", device_id)
set_tag("new_keys", new_keys)
# We are protected from race between lookup and insertion due to
# a unique constraint. If there is a race of two calls to
# `add_e2e_one_time_keys` then they'll conflict and we will only
@@ -202,6 +214,11 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
"""
def _set_e2e_device_keys_txn(txn):
set_tag("user_id", user_id)
set_tag("device_id", device_id)
set_tag("time_now", time_now)
set_tag("device_keys", device_keys)
old_key_json = self._simple_select_one_onecol_txn(
txn,
table="e2e_device_keys_json",
@@ -215,6 +232,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
new_key_json = encode_canonical_json(device_keys).decode("utf-8")
if old_key_json == new_key_json:
log_kv({"Message": "Device key already stored."})
return False
self._simple_upsert_txn(
@@ -223,7 +241,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
keyvalues={"user_id": user_id, "device_id": device_id},
values={"ts_added_ms": time_now, "key_json": new_key_json},
)
log_kv({"message": "Device keys stored."})
return True
return self.runInteraction("set_e2e_device_keys", _set_e2e_device_keys_txn)
@@ -231,6 +249,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
def claim_e2e_one_time_keys(self, query_list):
"""Take a list of one time keys out of the database"""
@trace
def _claim_e2e_one_time_keys(txn):
sql = (
"SELECT key_id, key_json FROM e2e_one_time_keys_json"
@@ -252,7 +271,13 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
" AND key_id = ?"
)
for user_id, device_id, algorithm, key_id in delete:
log_kv(
{
"message": "Executing claim e2e_one_time_keys transaction on database."
}
)
txn.execute(sql, (user_id, device_id, algorithm, key_id))
log_kv({"message": "finished executing and invalidating cache"})
self._invalidate_cache_and_stream(
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
@@ -262,6 +287,13 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
def delete_e2e_keys_by_device(self, user_id, device_id):
def delete_e2e_keys_by_device_txn(txn):
log_kv(
{
"message": "Deleting keys for device",
"device_id": device_id,
"user_id": user_id,
}
)
self._simple_delete_txn(
txn,
table="e2e_device_keys_json",
+137
View File
@@ -2181,6 +2181,143 @@ class EventsStore(
return to_delete, to_dedelta
def purge_room(self, room_id):
"""Deletes all record of a room
Args:
room_id (str):
"""
return self.runInteraction("purge_room", self._purge_room_txn, room_id)
def _purge_room_txn(self, txn, room_id):
# first we have to delete the state groups states
logger.info("[purge] removing %s from state_groups_state", room_id)
txn.execute(
"""
DELETE FROM state_groups_state WHERE state_group IN (
SELECT state_group FROM events JOIN event_to_state_groups USING(event_id)
WHERE events.room_id=?
)
""",
(room_id,),
)
# ... and the state group edges
logger.info("[purge] removing %s from state_group_edges", room_id)
txn.execute(
"""
DELETE FROM state_group_edges WHERE state_group IN (
SELECT state_group FROM events JOIN event_to_state_groups USING(event_id)
WHERE events.room_id=?
)
""",
(room_id,),
)
# ... and the state groups
logger.info("[purge] removing %s from state_groups", room_id)
txn.execute(
"""
DELETE FROM state_groups WHERE id IN (
SELECT state_group FROM events JOIN event_to_state_groups USING(event_id)
WHERE events.room_id=?
)
""",
(room_id,),
)
# and then tables which lack an index on room_id but have one on event_id
for table in (
"event_auth",
"event_edges",
"event_push_actions_staging",
"event_reference_hashes",
"event_relations",
"event_to_state_groups",
"redactions",
"rejections",
"state_events",
):
logger.info("[purge] removing %s from %s", room_id, table)
txn.execute(
"""
DELETE FROM %s WHERE event_id IN (
SELECT event_id FROM events WHERE room_id=?
)
"""
% (table,),
(room_id,),
)
# and finally, the tables with an index on room_id (or no useful index)
for table in (
"current_state_events",
"event_backward_extremities",
"event_forward_extremities",
"event_json",
"event_push_actions",
"event_search",
"events",
"group_rooms",
"public_room_list_stream",
"receipts_graph",
"receipts_linearized",
"room_aliases",
"room_depth",
"room_memberships",
"room_state",
"room_stats",
"room_stats_earliest_token",
"rooms",
"stream_ordering_to_exterm",
"topics",
"users_in_public_rooms",
"users_who_share_private_rooms",
# no useful index, but let's clear them anyway
"appservice_room_list",
"e2e_room_keys",
"event_push_summary",
"pusher_throttle",
"group_summary_rooms",
"local_invites",
"room_account_data",
"room_tags",
):
logger.info("[purge] removing %s from %s", room_id, table)
txn.execute("DELETE FROM %s WHERE room_id=?" % (table,), (room_id,))
# Other tables we do NOT need to clear out:
#
# - blocked_rooms
# This is important, to make sure that we don't accidentally rejoin a blocked
# room after it was purged
#
# - user_directory
# This has a room_id column, but it is unused
#
# Other tables that we might want to consider clearing out include:
#
# - event_reports
# Given that these are intended for abuse management my initial
# inclination is to leave them in place.
#
# - current_state_delta_stream
# - ex_outlier_stream
# - room_tags_revisions
# The problem with these is that they are largeish and there is no room_id
# index on them. In any case we should be clearing out 'stream' tables
# periodically anyway (#5888)
# TODO: we could probably usefully do a bunch of cache invalidation here
logger.info("[purge] done")
@defer.inlineCallbacks
def is_event_after(self, event_id1, event_id2):
"""Returns True if event_id1 is after event_id2 in the stream
+19 -5
View File
@@ -238,6 +238,13 @@ def _upgrade_existing_database(
logger.debug("applied_delta_files: %s", applied_delta_files)
if isinstance(database_engine, PostgresEngine):
specific_engine_extension = ".postgres"
else:
specific_engine_extension = ".sqlite"
specific_engine_extensions = (".sqlite", ".postgres")
for v in range(start_ver, SCHEMA_VERSION + 1):
logger.info("Upgrading schema to v%d", v)
@@ -274,15 +281,22 @@ def _upgrade_existing_database(
# Sometimes .pyc files turn up anyway even though we've
# disabled their generation; e.g. from distribution package
# installers. Silently skip it
pass
continue
elif ext == ".sql":
# A plain old .sql file, just read and execute it
logger.info("Applying schema %s", relative_path)
executescript(cur, absolute_path)
elif ext == specific_engine_extension and root_name.endswith(".sql"):
# A .sql file specific to our engine; just read and execute it
logger.info("Applying engine-specific schema %s", relative_path)
executescript(cur, absolute_path)
elif ext in specific_engine_extensions and root_name.endswith(".sql"):
# A .sql file for a different engine; skip it.
continue
else:
# Not a valid delta file.
logger.warn(
"Found directory entry that did not end in .py or" " .sql: %s",
logger.warning(
"Found directory entry that did not end in .py or .sql: %s",
relative_path,
)
continue
@@ -290,7 +304,7 @@ def _upgrade_existing_database(
# Mark as done.
cur.execute(
database_engine.convert_param_style(
"INSERT INTO applied_schema_deltas (version, file)" " VALUES (?,?)"
"INSERT INTO applied_schema_deltas (version, file) VALUES (?,?)"
),
(v, relative_path),
)
@@ -298,7 +312,7 @@ def _upgrade_existing_database(
cur.execute("DELETE FROM schema_version")
cur.execute(
database_engine.convert_param_style(
"INSERT INTO schema_version (version, upgraded)" " VALUES (?,?)"
"INSERT INTO schema_version (version, upgraded) VALUES (?,?)"
),
(v, True),
)
+24
View File
@@ -56,6 +56,7 @@ class RegistrationWorkerStore(SQLBaseStore):
"consent_server_notice_sent",
"appservice_id",
"creation_ts",
"user_type",
],
allow_none=True,
desc="get_user_by_id",
@@ -272,6 +273,14 @@ class RegistrationWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def is_server_admin(self, user):
"""Determines if a user is an admin of this homeserver.
Args:
user (UserID): user ID of the user to test
Returns (bool):
true iff the user is a server admin, false otherwise.
"""
res = yield self._simple_select_one_onecol(
table="users",
keyvalues={"name": user.to_string()},
@@ -282,6 +291,21 @@ class RegistrationWorkerStore(SQLBaseStore):
return res if res else False
def set_server_admin(self, user, admin):
"""Sets whether a user is an admin of this homeserver.
Args:
user (UserID): user ID of the user to test
admin (bool): true iff the user is to be a server admin,
false otherwise.
"""
return self._simple_update_one(
table="users",
keyvalues={"name": user.to_string()},
updatevalues={"admin": 1 if admin else 0},
desc="set_server_admin",
)
def _query_for_auth(self, txn, token):
sql = (
"SELECT users.name, users.is_guest, access_tokens.id as token_id,"
@@ -0,0 +1,20 @@
/* Copyright 2019 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* Opentracing context data for inclusion in the device_list_update EDUs, as a
* json-encoded dictionary. NULL if opentracing is disabled (or not enabled for this destination).
*/
ALTER TABLE device_lists_outbound_pokes ADD opentracing_context TEXT;
@@ -0,0 +1,17 @@
/* Copyright 2019 Matrix.org Foundation CIC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- this was apparently forgotten when the table was created back in delta 53.
CREATE INDEX users_in_public_rooms_r_idx ON users_in_public_rooms(room_id);
+10 -3
View File
@@ -30,6 +30,8 @@ from six import iteritems
import yaml
from synapse.config import find_config_files
SYNAPSE = [sys.executable, "-B", "-m", "synapse.app.homeserver"]
GREEN = "\x1b[1;32m"
@@ -135,7 +137,8 @@ def main():
"configfile",
nargs="?",
default="homeserver.yaml",
help="the homeserver config file, defaults to homeserver.yaml",
help="the homeserver config file. Defaults to homeserver.yaml. May also be"
" a directory with *.yaml files",
)
parser.add_argument(
"-w", "--worker", metavar="WORKERCONFIG", help="start or stop a single worker"
@@ -176,8 +179,12 @@ def main():
)
sys.exit(1)
with open(configfile) as stream:
config = yaml.safe_load(stream)
config_files = find_config_files([configfile])
config = {}
for config_file in config_files:
with open(config_file) as file_stream:
yaml_config = yaml.safe_load(file_stream)
config.update(yaml_config)
pidfile = config["pid_file"]
cache_factor = config.get("synctl_cache_factor")

Some files were not shown because too many files have changed in this diff Show More