
Compare commits


50 Commits

Author SHA1 Message Date
Erik Johnston
2ca8af3f06 Bump version 2019-03-08 11:35:51 +00:00
Erik Johnston
b050a10871 Merge pull request #4699 from matrix-org/erikj/stop_fed_not_in_room
Stop backpaginating when events not visible
2019-03-05 09:32:33 +00:00
Erik Johnston
9e8bca5667 Merge pull request #4799 from matrix-org/rav/clean_up_replication_code
Clean ups in replication notifier
2019-03-05 09:19:48 +00:00
Erik Johnston
aa06d26ae0 clarify comments 2019-03-05 09:16:35 +00:00
Erik Johnston
c3c542bb4a Merge pull request #4796 from matrix-org/erikj/factor_out_e2e_keys
Allow /keys/{changes,query} API to run on worker
2019-03-05 09:06:25 +00:00
Richard van der Hoff
48583cef7e Merge pull request #4798 from matrix-org/rav/rr_debug
Add some debug about processing read receipts.
2019-03-04 19:04:05 +00:00
Richard van der Hoff
cd7110c869 Merge pull request #4797 from matrix-org/rav/inline_rr_send
Clean up read-receipt handling.
2019-03-04 19:03:40 +00:00
Richard van der Hoff
eaa9f43603 changelog 2019-03-04 19:01:19 +00:00
Richard van der Hoff
c7325776a7 Remove redundant PreserveLoggingContext
Both (!) things that register as replication listeners do the right thing wrt
logcontexts, so this is redundant.
2019-03-04 18:31:18 +00:00
Erik Johnston
00b0e8b7df Newsfile 2019-03-04 18:30:07 +00:00
Erik Johnston
bfa7d46a10 Allow /keys/{changes,query} API to run on worker 2019-03-04 18:30:01 +00:00
Erik Johnston
157e5a8f27 Split DeviceHandler into master and worker 2019-03-04 18:29:26 +00:00
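
A rough sketch of the master/worker split described in this commit, using simplified stand-in classes rather than the real Synapse handlers: read-only queries live on a worker-safe base class, while mutating operations stay on a master-only subclass.

# Illustrative stand-ins, not the real Synapse handlers.
class DeviceWorkerHandler:
    """Read-only device operations; safe to run on any worker process."""

    def __init__(self, store):
        self.store = store

    def get_devices_by_user(self, user_id):
        return list(self.store.get(user_id, []))


class DeviceHandler(DeviceWorkerHandler):
    """Adds the mutating operations that only the master process performs."""

    def delete_device(self, user_id, device_id):
        devices = self.store.get(user_id, [])
        self.store[user_id] = [d for d in devices if d != device_id]


store = {"@alice:example.org": ["DEVICE1", "DEVICE2"]}
worker = DeviceWorkerHandler(store)
master = DeviceHandler(store)

print(worker.get_devices_by_user("@alice:example.org"))  # ['DEVICE1', 'DEVICE2']
master.delete_device("@alice:example.org", "DEVICE1")
print(worker.get_devices_by_user("@alice:example.org"))  # ['DEVICE2']
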
Richard van der Hoff
daa10e3e66 Remove unused wait_for_replication method
I guess this was used once? It's not now, anyway.
2019-03-04 18:27:32 +00:00
Richard van der Hoff
2db49ea476 Add some debug about processing read receipts.
I'm hoping to establish which rooms are having lots of RRs sent for them, and
how old the events are when they are sent.
2019-03-04 18:19:40 +00:00
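
The kind of debug logging this commit hints at might look roughly like the sketch below (illustrative names, not the actual Synapse code): log the room and the age of the referenced event whenever a read receipt is processed.

# Illustrative sketch; not the actual Synapse debug logging.
import logging
import time

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger("receipts.debug")

def log_receipt_age(room_id, event_origin_server_ts_ms):
    # How old is the event being acknowledged, in milliseconds?
    age_ms = int(time.time() * 1000) - event_origin_server_ts_ms
    logger.debug("Received receipt in %s for event %dms old", room_id, age_ms)

# Example: a receipt for an event sent 42 seconds ago.
log_receipt_age("!room:example.org", int(time.time() * 1000) - 42000)
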
Richard van der Hoff
b29693a30b Clean up read-receipt handling.
Remove a call to run_as_background_process: there is no need to run this as a
background process, because build_and_send_edu does not block.

We may as well inline the whole of _push_remotes.
2019-03-04 18:16:43 +00:00
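
A toy model of the reasoning above (names simplified, not the real classes): because build_and_send_edu only queues the EDU and returns, the receipts handler can call it inline instead of spinning up a background process.

class FakeFederationSender:
    """Stand-in for the federation sender: queuing an EDU never blocks."""

    def __init__(self):
        self.queued = []

    def build_and_send_edu(self, destination, edu_type, content, key=None):
        self.queued.append((destination, edu_type, content, key))


sender = FakeFederationSender()
# Called directly from the receipts handler -- no run_as_background_process.
sender.build_and_send_edu(
    destination="remote.example.org",
    edu_type="m.receipt",
    content={"!room:example.org": {"m.read": {"@alice:example.org": {}}}},
    key=("!room:example.org", "m.read", "@alice:example.org"),
)
print(len(sender.queued))  # 1
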
Erik Johnston
a84b8d56c2 Fixup slave stores 2019-03-04 18:04:57 +00:00
Richard van der Hoff
8e28bc5eee Include a default configuration file in the 'docs' directory. (#4791) 2019-03-04 17:14:58 +00:00
Erik Johnston
0d2d046709 Fix missing null guard 2019-03-04 16:04:04 +00:00
Erik Johnston
d1523aed6b Only check history visibility when filtering
When filtering events to send to a server we check more than just history
visibility. However, when deciding whether to backfill or not we only
care about history visibility.
2019-03-04 14:43:42 +00:00
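
A toy illustration of the distinction drawn here, with made-up helpers rather than Synapse's real filter_events_for_server: full filtering also drops events from erased users, which would wrongly make the backfill decision think nothing is visible, so that decision checks history visibility alone.

# Toy stand-in for Synapse's filter_events_for_server; illustrative only.
def filter_events_for_server(events, visibility, erased_senders,
                             check_history_visibility_only=False):
    visible = []
    for ev in events:
        # History visibility is always checked.
        if visibility not in ("shared", "world_readable"):
            continue
        # Erasure (and similar per-user checks) only apply when serving events.
        if not check_history_visibility_only and ev["sender"] in erased_senders:
            continue
        visible.append(ev)
    return visible


events = [{"event_id": "$abc", "sender": "@erased:example.org"}]
erased = {"@erased:example.org"}

print(filter_events_for_server(events, "shared", erased))
# []  -- when actually sending, the erased user's event is dropped
print(filter_events_for_server(events, "shared", erased,
                               check_history_visibility_only=True))
# [{'event_id': '$abc', 'sender': '@erased:example.org'}]  -- still counts as visible history
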
Seebi
aba5eeabd5 Fix v4v6 option in HAProxy example config (#4790)
The v4v6 option is only useful on an IPv6 socket: https://serverfault.com/q/747895

Signed-off-by: Flakebi <flakebi@t-online.de>
2019-03-04 13:19:41 +00:00
Richard van der Hoff
856c83f5f8 Avoid rebuilding Edu objects in worker mode (#4770)
In worker mode, on the federation sender, when we receive an edu for sending
over the replication socket, it is parsed into an Edu object. There is no point
extracting the contents of it so that we can then immediately build another Edu.
2019-03-04 12:57:44 +00:00
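
A simplified sketch of the optimisation (the Edu and queue classes here are illustrative stand-ins): the federation sender can queue the Edu it already parsed off the replication stream instead of unpacking it and constructing an identical one.

# Illustrative stand-ins for Edu / TransactionQueue.
from dataclasses import dataclass, field

@dataclass
class Edu:
    origin: str
    destination: str
    edu_type: str
    content: dict

@dataclass
class TransactionQueue:
    pending: list = field(default_factory=list)

    def send_edu(self, edu, key=None):
        # The already-parsed Edu goes straight onto the queue; nothing is rebuilt.
        self.pending.append((key, edu))


queue = TransactionQueue()
parsed = Edu("hs.example.org", "remote.example.org", "m.receipt", {"$ev": {}})
queue.send_edu(parsed)
print(queue.pending[0][1].edu_type)  # m.receipt
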
Erik Johnston
8b63fe4c26 s/get_forward_events/get_successor_events/ 2019-03-04 11:56:03 +00:00
Erik Johnston
fbc047f2a5 Merge branch 'develop' of github.com:matrix-org/synapse into erikj/stop_fed_not_in_room 2019-03-04 11:54:58 +00:00
Richard van der Hoff
2c3548d9d8 Update test_typing to use HomeserverTestCase. (#4771) 2019-03-04 10:05:39 +00:00
Richard van der Hoff
3064952939 Fix incorrect log about not persisting duplicate state event. (#4776)
We were logging this when it was not true.
2019-03-01 16:47:12 +00:00
Richard van der Hoff
1beebe916f Merge branch 'master' into develop 2019-03-01 10:58:39 +00:00
Andrew Morgan
ac61b45a75 Minor docstring fixes for MatrixFederationAgent (#4765) 2019-02-28 16:24:01 +00:00
Amber Brown
b131cc77df Make 'event_id' a required parameter in federated state requests (#4741)
* make 'event_id' a required parameter in federated state requests

As per the spec: https://matrix.org/docs/spec/server_server/r0.1.1.html#id40

Signed-off-by: Joseph Weston <joseph@weston.cloud>

* add changelog entry for bugfix

Signed-off-by: Joseph Weston <joseph@weston.cloud>

* Update server.py
2019-02-27 14:35:47 -08:00
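
A self-contained sketch of the required-parameter behaviour (the helper and error class below are simplified stand-ins, not Synapse's parse_string_from_args or SynapseError): a missing event_id now yields a 400 rather than silently defaulting to None.

# Simplified stand-ins; not Synapse's parse_string_from_args / SynapseError.
from urllib.parse import parse_qs

class RequestError(Exception):
    def __init__(self, code, msg):
        super().__init__(msg)
        self.code = code

def parse_string_from_query(query_string, name, default=None, required=False):
    args = parse_qs(query_string)
    if name not in args:
        if required:
            raise RequestError(400, "Missing required query parameter: %s" % (name,))
        return default
    return args[name][0]


print(parse_string_from_query("event_id=$abc:example.org", "event_id", required=True))
try:
    parse_string_from_query("", "event_id", required=True)
except RequestError as e:
    print(e.code, e)  # 400 Missing required query parameter: event_id
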
Richard van der Hoff
68f47d6744 Fix parsing of Content-Disposition headers (#4763)
* Fix parsing of Content-Disposition headers

TIL: filenames in content-disposition headers can contain semicolons, and aren't
%-encoded.

* fix python2 incompatibility

* Fix docstrings
2019-02-27 14:29:10 -08:00
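
A small, self-contained demonstration of the pitfall (the header value is invented; this is not Synapse's parser): a quoted filename may legally contain a semicolon, so naively splitting the header on ";" truncates it, while quote-aware parsing keeps it intact.

# Invented header value; not Synapse's parser.
from email.message import Message

header = 'inline; filename="out; flow.png"'

# Naive approach: split on ";" and the filename is cut at the embedded semicolon.
naive = {}
for part in header.split(";")[1:]:
    if "=" in part:
        k, v = part.strip().split("=", 1)
        naive[k] = v
print(naive["filename"])  # '"out' -- wrong

# Quote-aware parsing via the stdlib keeps the whole filename.
msg = Message()
msg["Content-Disposition"] = header
print(msg.get_filename())  # 'out; flow.png'
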
Amber Brown
f2a753ea38 Move from TravisCI to BuildKite (#4752) 2019-02-27 13:03:14 -08:00
Erik Johnston
76550c58d2 Merge pull request #4759 from matrix-org/erikj/3pid_client_reader
Move /account/3pid to client_reader
2019-02-27 16:11:21 +00:00
Erik Johnston
8267034a63 Merge pull request #4758 from matrix-org/erikj/use_presence_replication
When presence is disabled don't send over replication
2019-02-27 15:46:26 +00:00
Richard van der Hoff
3134964054 Update changelog.d/4759.feature
Co-Authored-By: erikjohnston <erikj@jki.re>
2019-02-27 15:32:21 +00:00
Erik Johnston
46b0151524 Merge pull request #4757 from matrix-org/erikj/key_api_fed_readae
Move server key queries to federation reader
2019-02-27 15:30:40 +00:00
Erik Johnston
95840d84d4 Newsfile 2019-02-27 14:28:52 +00:00
Erik Johnston
54f9ce11a7 Move /account/3pid to client_reader 2019-02-27 14:26:08 +00:00
Erik Johnston
d4dc527a1a Fix unit tests 2019-02-27 14:24:45 +00:00
Erik Johnston
1b2940b3bd Newsfile 2019-02-27 13:54:45 +00:00
Erik Johnston
1e315017d3 When presence is disabled don't send over replication 2019-02-27 13:53:46 +00:00
Erik Johnston
b5c13df0c4 Newsfile 2019-02-27 13:46:29 +00:00
Erik Johnston
4cff9376f7 Move server key queries to federation reader 2019-02-27 13:43:53 +00:00
Erik Johnston
71ef5fc411 Update newsfile to have a full stop 2019-02-27 13:22:23 +00:00
Erik Johnston
b183fef9ac Update comments 2019-02-27 13:06:10 +00:00
Erik Johnston
7590e9fa28 Merge pull request #4749 from matrix-org/erikj/replication_connection_backoff
Fix tightloop over connecting to replication server
2019-02-27 11:00:59 +00:00
Erik Johnston
6870fc496f Move connecting logic into ClientReplicationStreamProtocol 2019-02-27 10:23:51 +00:00
Erik Johnston
09fc34c935 Newsfile 2019-02-26 15:13:55 +00:00
Erik Johnston
25814921f1 Increase the max delay between retry attempts
Otherwise, if you have many workers they can easily take out the master
with their connection attempts
2019-02-26 15:12:33 +00:00
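
The effect of raising the cap can be seen with a small back-of-the-envelope model (simplified: Twisted's ReconnectingClientFactory also applies a growth factor and jitter): once workers have backed off, a 30s cap means they hit a struggling master far less often than a 5s cap.

def retry_delays(max_delay, initial=1.0, factor=2.0, attempts=8):
    """Exponential backoff capped at max_delay (a simplified model)."""
    delays, delay = [], initial
    for _ in range(attempts):
        delays.append(min(delay, max_delay))
        delay *= factor
    return delays

print(retry_delays(max_delay=5))   # [1.0, 2.0, 4.0, 5, 5, 5, 5, 5]
print(retry_delays(max_delay=30))  # [1.0, 2.0, 4.0, 8.0, 16.0, 30, 30, 30]
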
Erik Johnston
313987187e Fix tightloop over connecting to replication server
If the client failed to process incoming commands during the initial set
up of the replication connection it would immediately disconnect and
reconnect, resulting in a tightloop.

This can happen, for example, when subscribing to a stream that has a
row that is too long in the backlog.

The fix here is to not consider the connection successfully set up until
the client has successfully subscribed and caught up with the streams.
This ensures that the retry logic timers aren't reset until then,
meaning that if an error does happen during start up the client will
continue backing off before retrying again.
2019-02-26 15:05:41 +00:00
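
A condensed model of the fix (classes and names are simplified, mirroring rather than copying the real protocol code): the connection only counts as established, and the retry delay is only reset, once every subscribed stream has reported its POSITION.

# Simplified model of the real client protocol class.
class ReplicationClient:
    """Tracks which streams are still catching up after (re)connecting."""

    def __init__(self, streams):
        self.streams_connecting = set(streams)
        self.backoff_reset = False

    def on_position(self, stream_name):
        # POSITION: all missing updates for this stream have now been sent.
        self.streams_connecting.discard(stream_name)
        if not self.streams_connecting:
            self.finished_connecting()

    def finished_connecting(self):
        # Only now is it safe to reset the reconnection backoff; resetting it
        # earlier is what allowed the tight reconnect loop.
        self.backoff_reset = True


client = ReplicationClient(["events", "receipts"])
client.on_position("events")
print(client.backoff_reset)   # False -- still waiting for "receipts"
client.on_position("receipts")
print(client.backoff_reset)   # True  -- fully caught up
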
Erik Johnston
56f4ece778 Newsfile 2019-02-20 18:15:13 +00:00
Erik Johnston
71b625d808 Stop backpaginating when events not visible 2019-02-20 18:14:12 +00:00
77 changed files with 2847 additions and 1293 deletions

.buildkite/.env (new file)

@@ -0,0 +1,13 @@
CI
BUILDKITE
BUILDKITE_BUILD_NUMBER
BUILDKITE_BRANCH
BUILDKITE_BUILD_NUMBER
BUILDKITE_JOB_ID
BUILDKITE_BUILD_URL
BUILDKITE_PROJECT_SLUG
BUILDKITE_COMMIT
BUILDKITE_PULL_REQUEST
BUILDKITE_TAG
CODECOV_TOKEN
TRIAL_FLAGS


@@ -0,0 +1,21 @@
version: '3.1'
services:
postgres:
image: postgres:9.4
environment:
POSTGRES_PASSWORD: postgres
testenv:
image: python:2.7
depends_on:
- postgres
env_file: .env
environment:
SYNAPSE_POSTGRES_HOST: postgres
SYNAPSE_POSTGRES_USER: postgres
SYNAPSE_POSTGRES_PASSWORD: postgres
working_dir: /app
volumes:
- ..:/app


@@ -0,0 +1,21 @@
version: '3.1'
services:
postgres:
image: postgres:9.5
environment:
POSTGRES_PASSWORD: postgres
testenv:
image: python:2.7
depends_on:
- postgres
env_file: .env
environment:
SYNAPSE_POSTGRES_HOST: postgres
SYNAPSE_POSTGRES_USER: postgres
SYNAPSE_POSTGRES_PASSWORD: postgres
working_dir: /app
volumes:
- ..:/app


@@ -0,0 +1,21 @@
version: '3.1'
services:
postgres:
image: postgres:9.4
environment:
POSTGRES_PASSWORD: postgres
testenv:
image: python:3.5
depends_on:
- postgres
env_file: .env
environment:
SYNAPSE_POSTGRES_HOST: postgres
SYNAPSE_POSTGRES_USER: postgres
SYNAPSE_POSTGRES_PASSWORD: postgres
working_dir: /app
volumes:
- ..:/app


@@ -0,0 +1,21 @@
version: '3.1'
services:
postgres:
image: postgres:9.5
environment:
POSTGRES_PASSWORD: postgres
testenv:
image: python:3.5
depends_on:
- postgres
env_file: .env
environment:
SYNAPSE_POSTGRES_HOST: postgres
SYNAPSE_POSTGRES_USER: postgres
SYNAPSE_POSTGRES_PASSWORD: postgres
working_dir: /app
volumes:
- ..:/app


@@ -0,0 +1,21 @@
version: '3.1'
services:
postgres:
image: postgres:11
environment:
POSTGRES_PASSWORD: postgres
testenv:
image: python:3.7
depends_on:
- postgres
env_file: .env
environment:
SYNAPSE_POSTGRES_HOST: postgres
SYNAPSE_POSTGRES_USER: postgres
SYNAPSE_POSTGRES_PASSWORD: postgres
working_dir: /app
volumes:
- ..:/app


@@ -0,0 +1,21 @@
version: '3.1'
services:
postgres:
image: postgres:9.5
environment:
POSTGRES_PASSWORD: postgres
testenv:
image: python:3.7
depends_on:
- postgres
env_file: .env
environment:
SYNAPSE_POSTGRES_HOST: postgres
SYNAPSE_POSTGRES_USER: postgres
SYNAPSE_POSTGRES_PASSWORD: postgres
working_dir: /app
volumes:
- ..:/app

.buildkite/pipeline.yml (new file)

@@ -0,0 +1,157 @@
env:
CODECOV_TOKEN: "2dd7eb9b-0eda-45fe-a47c-9b5ac040045f"
steps:
- command:
- "python -m pip install tox"
- "tox -e pep8"
label: "\U0001F9F9 PEP-8"
plugins:
- docker#v3.0.1:
image: "python:3.6"
- command:
- "python -m pip install tox"
- "tox -e packaging"
label: "\U0001F9F9 packaging"
plugins:
- docker#v3.0.1:
image: "python:3.6"
- command:
- "python -m pip install tox"
- "tox -e check_isort"
label: "\U0001F9F9 isort"
plugins:
- docker#v3.0.1:
image: "python:3.6"
- command:
- "python -m pip install tox"
- "scripts-dev/check-newsfragment"
label: ":newspaper: Newsfile"
branches: "!master !develop !release-*"
plugins:
- docker#v3.0.1:
image: "python:3.6"
propagate-environment: true
- wait
- command:
- "python -m pip install tox"
- "tox -e check-sampleconfig"
label: "\U0001F9F9 check-sample-config"
plugins:
- docker#v3.0.1:
image: "python:3.6"
- command:
- "python -m pip install tox"
- "tox -e py27,codecov"
label: ":python: 2.7 / SQLite"
env:
TRIAL_FLAGS: "-j 2"
plugins:
- docker#v3.0.1:
image: "python:2.7"
propagate-environment: true
- command:
- "python -m pip install tox"
- "tox -e py35,codecov"
label: ":python: 3.5 / SQLite"
env:
TRIAL_FLAGS: "-j 2"
plugins:
- docker#v3.0.1:
image: "python:3.5"
propagate-environment: true
- command:
- "python -m pip install tox"
- "tox -e py36,codecov"
label: ":python: 3.6 / SQLite"
env:
TRIAL_FLAGS: "-j 2"
plugins:
- docker#v3.0.1:
image: "python:3.6"
propagate-environment: true
- command:
- "python -m pip install tox"
- "tox -e py37,codecov"
label: ":python: 3.7 / SQLite"
env:
TRIAL_FLAGS: "-j 2"
plugins:
- docker#v3.0.1:
image: "python:3.7"
propagate-environment: true
- label: ":python: 2.7 / :postgres: 9.4"
env:
TRIAL_FLAGS: "-j 4"
command:
- "bash -c 'python -m pip install tox && python -m tox -e py27-postgres,codecov'"
plugins:
- docker-compose#v2.1.0:
run: testenv
config:
- .buildkite/docker-compose.py27.pg94.yaml
- label: ":python: 2.7 / :postgres: 9.5"
env:
TRIAL_FLAGS: "-j 4"
command:
- "bash -c 'python -m pip install tox && python -m tox -e py27-postgres,codecov'"
plugins:
- docker-compose#v2.1.0:
run: testenv
config:
- .buildkite/docker-compose.py27.pg95.yaml
- label: ":python: 3.5 / :postgres: 9.4"
env:
TRIAL_FLAGS: "-j 4"
command:
- "bash -c 'python -m pip install tox && python -m tox -e py35-postgres,codecov'"
plugins:
- docker-compose#v2.1.0:
run: testenv
config:
- .buildkite/docker-compose.py35.pg94.yaml
- label: ":python: 3.5 / :postgres: 9.5"
env:
TRIAL_FLAGS: "-j 4"
command:
- "bash -c 'python -m pip install tox && python -m tox -e py35-postgres,codecov'"
plugins:
- docker-compose#v2.1.0:
run: testenv
config:
- .buildkite/docker-compose.py35.pg95.yaml
- label: ":python: 3.7 / :postgres: 9.5"
env:
TRIAL_FLAGS: "-j 4"
command:
- "bash -c 'python -m pip install tox && python -m tox -e py37-postgres,codecov'"
plugins:
- docker-compose#v2.1.0:
run: testenv
config:
- .buildkite/docker-compose.py37.pg95.yaml
- label: ":python: 3.7 / :postgres: 11"
env:
TRIAL_FLAGS: "-j 4"
command:
- "bash -c 'python -m pip install tox && python -m tox -e py37-postgres,codecov'"
plugins:
- docker-compose#v2.1.0:
run: testenv
config:
- .buildkite/docker-compose.py37.pg11.yaml


@@ -1,97 +0,0 @@
dist: xenial
language: python
cache:
directories:
# we only bother to cache the wheels; parts of the http cache get
# invalidated every build (because they get served with a max-age of 600
# seconds), which means that we end up re-uploading the whole cache for
# every build, which is time-consuming In any case, it's not obvious that
# downloading the cache from S3 would be much faster than downloading the
# originals from pypi.
#
- $HOME/.cache/pip/wheels
# don't clone the whole repo history, one commit will do
git:
depth: 1
# only build branches we care about (PRs are built seperately)
branches:
only:
- master
- develop
- /^release-v/
- rav/pg95
# When running the tox environments that call Twisted Trial, we can pass the -j
# flag to run the tests concurrently. We set this to 2 for CPU bound tests
# (SQLite) and 4 for I/O bound tests (PostgreSQL).
matrix:
fast_finish: true
include:
- name: "pep8"
python: 3.6
env: TOX_ENV="pep8,check_isort,packaging"
- name: "py2.7 / sqlite"
python: 2.7
env: TOX_ENV=py27,codecov TRIAL_FLAGS="-j 2"
- name: "py2.7 / sqlite / olddeps"
python: 2.7
env: TOX_ENV=py27-old TRIAL_FLAGS="-j 2"
- name: "py2.7 / postgres9.5"
python: 2.7
addons:
postgresql: "9.5"
env: TOX_ENV=py27-postgres,codecov TRIAL_FLAGS="-j 4"
services:
- postgresql
- name: "py3.5 / sqlite"
python: 3.5
env: TOX_ENV=py35,codecov TRIAL_FLAGS="-j 2"
- name: "py3.7 / sqlite"
python: 3.7
env: TOX_ENV=py37,codecov TRIAL_FLAGS="-j 2"
- name: "py3.7 / postgres9.4"
python: 3.7
addons:
postgresql: "9.4"
env: TOX_ENV=py37-postgres TRIAL_FLAGS="-j 4"
services:
- postgresql
- name: "py3.7 / postgres9.5"
python: 3.7
addons:
postgresql: "9.5"
env: TOX_ENV=py37-postgres,codecov TRIAL_FLAGS="-j 4"
services:
- postgresql
- # we only need to check for the newsfragment if it's a PR build
if: type = pull_request
name: "check-newsfragment"
python: 3.6
script: scripts-dev/check-newsfragment
install:
# this just logs the postgres version we will be testing against (if any)
- psql -At -U postgres -c 'select version();' || true
- pip install tox
# if we don't have python3.6 in this environment, travis unhelpfully gives us
# a `python3.6` on our path which does nothing but spit out a warning. Tox
# tries to run it (even if we're not running a py36 env), so the build logs
# then have warnings which look like errors. To reduce the noise, remove the
# non-functional python3.6.
- ( ! command -v python3.6 || python3.6 --version ) &>/dev/null || rm -f $(command -v python3.6)
script:
- tox -e $TOX_ENV


@@ -39,6 +39,7 @@ prune .circleci
prune .coveragerc
prune debian
prune .codecov.yml
prune .buildkite
exclude jenkins*
recursive-exclude jenkins *.sh

changelog.d/4699.bugfix (new file)

@@ -0,0 +1 @@
Fix attempting to paginate in rooms where server cannot see any events, to avoid unnecessarily pulling in lots of redacted events.

changelog.d/4740.bugfix (new file)

@@ -0,0 +1 @@
'event_id' is now a required parameter in federated state requests, as per the matrix spec.

changelog.d/4749.bugfix (new file)

@@ -0,0 +1 @@
Fix tightloop over connecting to replication server.

changelog.d/4752.misc (new file)

@@ -0,0 +1 @@
Change from TravisCI to Buildkite for CI.

changelog.d/4757.feature (new file)

@@ -0,0 +1 @@
Move server key queries to federation reader.

changelog.d/4757.misc (new file)

@@ -0,0 +1 @@
When presence is disabled don't send over replication.

changelog.d/4759.feature (new file)

@@ -0,0 +1 @@
Add support for /account/3pid REST endpoint to client_reader worker.

changelog.d/4763.bugfix (new file)

@@ -0,0 +1 @@
Fix parsing of Content-Disposition headers on remote media requests and URL previews.

changelog.d/4765.misc (new file)

@@ -0,0 +1 @@
Minor docstring fixes for MatrixFederationAgent.

changelog.d/4770.misc (new file)

@@ -0,0 +1 @@
Optimise EDU transmission for the federation_sender worker.

changelog.d/4771.misc (new file)

@@ -0,0 +1 @@
Update test_typing to use HomeserverTestCase.

changelog.d/4776.bugfix (new file)

@@ -0,0 +1 @@
Fix incorrect log about not persisting duplicate state event.

changelog.d/4790.bugfix (new file)

@@ -0,0 +1 @@
Fix v4v6 option in HAProxy example config. Contributed by Flakebi.

changelog.d/4791.feature (new file)

@@ -0,0 +1 @@
Include a default configuration file in the 'docs' directory.

changelog.d/4796.feature (new file)

@@ -0,0 +1 @@
Add support for /keys/query and /keys/changes REST endpoints to client_reader worker.

changelog.d/4797.misc (new file)

@@ -0,0 +1 @@
Clean up read-receipt handling.

changelog.d/4798.misc (new file)

@@ -0,0 +1 @@
Add some debug about processing read receipts.

changelog.d/4799.misc (new file)

@@ -0,0 +1 @@
Clean up some replication code.


@@ -0,0 +1,7 @@
# This file is a reference to the configuration options which can be set in
# homeserver.yaml.
#
# Note that it is not quite ready to be used as-is. If you are starting from
# scratch, it is easier to generate the config files following the instructions
# in INSTALL.md.


@@ -88,18 +88,16 @@ Let's assume that we expect clients to connect to our server at
* HAProxy::
frontend https
bind 0.0.0.0:443 v4v6 ssl crt /etc/ssl/haproxy/ strict-sni alpn h2,http/1.1
bind :::443 ssl crt /etc/ssl/haproxy/ strict-sni alpn h2,http/1.1
bind :::443 v4v6 ssl crt /etc/ssl/haproxy/ strict-sni alpn h2,http/1.1
# Matrix client traffic
acl matrix hdr(host) -i matrix.example.com
use_backend matrix if matrix
frontend matrix-federation
bind 0.0.0.0:8448 v4v6 ssl crt /etc/ssl/haproxy/synapse.pem alpn h2,http/1.1
bind :::8448 ssl crt /etc/ssl/haproxy/synapse.pem alpn h2,http/1.1
bind :::8448 v4v6 ssl crt /etc/ssl/haproxy/synapse.pem alpn h2,http/1.1
default_backend matrix
backend matrix
server matrix 127.0.0.1:8008

docs/sample_config.yaml (new file, 1041 lines; diff suppressed because it is too large)


@@ -188,7 +188,9 @@ RDATA (S)
A single update in a stream
POSITION (S)
The position of the stream has been updated
The position of the stream has been updated. Sent to the client after all
missing updates for a stream have been sent to the client and they're now
up to date.
ERROR (S, C)
There was an error


@@ -182,6 +182,7 @@ endpoints matching the following regular expressions::
^/_matrix/federation/v1/event_auth/
^/_matrix/federation/v1/exchange_third_party_invite/
^/_matrix/federation/v1/send/
^/_matrix/key/v2/query
The above endpoints should all be routed to the federation_reader worker by the
reverse-proxy configuration.
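
A quick illustrative check (not part of Synapse) that a request path matches the newly added federation_reader pattern; the client_reader additions further down can be checked the same way.

# Illustrative check, not part of Synapse.
import re

FEDERATION_READER_PATTERNS = [
    r"^/_matrix/federation/v1/send/",
    r"^/_matrix/key/v2/query",          # newly routed to the federation_reader
]

def routed_to_federation_reader(path):
    return any(re.match(pattern, path) for pattern in FEDERATION_READER_PATTERNS)

print(routed_to_federation_reader("/_matrix/key/v2/query/example.org"))  # True
print(routed_to_federation_reader("/_matrix/media/r0/download/x/y"))     # False
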
@@ -223,6 +224,9 @@ following regular expressions::
^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/members$
^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/state$
^/_matrix/client/(api/v1|r0|unstable)/login$
^/_matrix/client/(api/v1|r0|unstable)/account/3pid$
^/_matrix/client/(api/v1|r0|unstable)/keys/query$
^/_matrix/client/(api/v1|r0|unstable)/keys/changes$
Additionally, the following REST endpoints can be handled, but all requests must
be routed to the same instance::


@@ -0,0 +1,18 @@
#!/bin/bash
#
# Update/check the docs/sample_config.yaml
set -e
cd `dirname $0`/..
SAMPLE_CONFIG="docs/sample_config.yaml"
if [ "$1" == "--check" ]; then
diff -u "$SAMPLE_CONFIG" <(./scripts/generate_config --header-file docs/.sample_config_header.yaml) >/dev/null || {
echo -e "\e[1m\e[31m$SAMPLE_CONFIG is not up-to-date. Regenerate it with \`scripts-dev/generate_sample_config\`.\e[0m" >&2
exit 1
}
else
./scripts/generate_config --header-file docs/.sample_config_header.yaml -o "$SAMPLE_CONFIG"
fi


@@ -1,6 +1,7 @@
#!/usr/bin/env python
import argparse
import shutil
import sys
from synapse.config.homeserver import HomeServerConfig
@@ -50,6 +51,13 @@ if __name__ == "__main__":
help="File to write the configuration to. Default: stdout",
)
parser.add_argument(
"--header-file",
type=argparse.FileType('r'),
help="File from which to read a header, which will be printed before the "
"generated config.",
)
args = parser.parse_args()
report_stats = args.report_stats
@@ -64,4 +72,7 @@ if __name__ == "__main__":
report_stats=report_stats,
)
if args.header_file:
shutil.copyfileobj(args.header_file, args.output_file)
args.output_file.write(conf)


@@ -27,4 +27,4 @@ try:
except ImportError:
pass
__version__ = "0.99.2"
__version__ = "0.99.2.post1"


@@ -33,9 +33,13 @@ from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
from synapse.replication.slave.storage.devices import SlavedDeviceStore
from synapse.replication.slave.storage.directory import DirectoryStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.keys import SlavedKeyStore
from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.slave.storage.room import RoomStore
from synapse.replication.slave.storage.transactions import SlavedTransactionStore
@@ -48,6 +52,8 @@ from synapse.rest.client.v1.room import (
RoomMemberListRestServlet,
RoomStateRestServlet,
)
from synapse.rest.client.v2_alpha.account import ThreepidRestServlet
from synapse.rest.client.v2_alpha.keys import KeyChangesServlet, KeyQueryServlet
from synapse.rest.client.v2_alpha.register import RegisterRestServlet
from synapse.server import HomeServer
from synapse.storage.engines import create_engine
@@ -60,6 +66,10 @@ logger = logging.getLogger("synapse.app.client_reader")
class ClientReaderSlavedStore(
SlavedDeviceInboxStore,
SlavedDeviceStore,
SlavedReceiptsStore,
SlavedPushRuleStore,
SlavedAccountDataStore,
SlavedEventStore,
SlavedKeyStore,
@@ -96,6 +106,9 @@ class ClientReaderServer(HomeServer):
RoomEventContextServlet(self).register(resource)
RegisterRestServlet(self).register(resource)
LoginRestServlet(self).register(resource)
ThreepidRestServlet(self).register(resource)
KeyQueryServlet(self).register(resource)
KeyChangesServlet(self).register(resource)
resources.update({
"/_matrix/client/r0": resource,


@@ -21,7 +21,7 @@ from twisted.web.resource import NoResource
import synapse
from synapse import events
from synapse.api.urls import FEDERATION_PREFIX
from synapse.api.urls import FEDERATION_PREFIX, SERVER_KEY_V2_PREFIX
from synapse.app import _base
from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
@@ -44,6 +44,7 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto
from synapse.replication.slave.storage.room import RoomStore
from synapse.replication.slave.storage.transactions import SlavedTransactionStore
from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.rest.key.v2 import KeyApiV2Resource
from synapse.server import HomeServer
from synapse.storage.engines import create_engine
from synapse.util.httpresourcetree import create_resource_tree
@@ -99,6 +100,9 @@ class FederationReaderServer(HomeServer):
),
})
if name in ["keys", "federation"]:
resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self)
root_resource = create_resource_tree(resources, NoResource())
_base.listen_tcp(


@@ -180,9 +180,7 @@ class Config(object):
Returns:
str: the yaml config file
"""
default_config = "# vim:ft=yaml\n"
default_config += "\n\n".join(
default_config = "\n\n".join(
dedent(conf)
for conf in self.invoke_all(
"default_config",
@@ -297,19 +295,26 @@ class Config(object):
"Must specify a server_name to a generate config for."
" Pass -H server.name."
)
config_str = obj.generate_config(
config_dir_path=config_dir_path,
data_dir_path=os.getcwd(),
server_name=server_name,
report_stats=(config_args.report_stats == "yes"),
generate_secrets=True,
)
if not cls.path_exists(config_dir_path):
os.makedirs(config_dir_path)
with open(config_path, "w") as config_file:
config_str = obj.generate_config(
config_dir_path=config_dir_path,
data_dir_path=os.getcwd(),
server_name=server_name,
report_stats=(config_args.report_stats == "yes"),
generate_secrets=True,
config_file.write(
"# vim:ft=yaml\n\n"
)
config = yaml.load(config_str)
obj.invoke_all("generate_files", config)
config_file.write(config_str)
config = yaml.load(config_str)
obj.invoke_all("generate_files", config)
print(
(
"A config file has been generated in %r for server name"


@@ -49,7 +49,8 @@ class DatabaseConfig(Config):
def default_config(self, data_dir_path, **kwargs):
database_path = os.path.join(data_dir_path, "homeserver.db")
return """\
# Database configuration
## Database ##
database:
# The database engine name
name: "sqlite3"


@@ -81,7 +81,9 @@ class LoggingConfig(Config):
def default_config(self, config_dir_path, server_name, **kwargs):
log_config = os.path.join(config_dir_path, server_name + ".log.config")
return """
return """\
## Logging ##
# A yaml python logging config file
#
log_config: "%(log_config)s"


@@ -260,9 +260,11 @@ class ServerConfig(Config):
# This is used by remote servers to connect to this server,
# e.g. matrix.org, localhost:8080, etc.
# This is also the last part of your UserID.
#
server_name: "%(server_name)s"
# When running as a daemon, the file to store the pid in
#
pid_file: %(pid_file)s
# CPU affinity mask. Setting this restricts the CPUs on which the
@@ -304,9 +306,11 @@ class ServerConfig(Config):
# Set the soft limit on the number of file descriptors synapse can use
# Zero is used to indicate synapse should set the soft limit to the
# hard limit.
#
soft_file_limit: 0
# Set to false to disable presence tracking on this homeserver.
#
use_presence: true
# The GC threshold parameters to pass to `gc.set_threshold`, if defined


@@ -886,6 +886,9 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry):
def on_edu(self, edu_type, origin, content):
"""Overrides FederationHandlerRegistry
"""
if not self.config.use_presence and edu_type == "m.presence":
return
handler = self.edu_handlers.get(edu_type)
if handler:
return super(ReplicationFederationHandlerRegistry, self).on_edu(
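
A minimal, runnable illustration of the guard added above (the registry class is a simplified stand-in for the real handler registry): when use_presence is off, incoming m.presence EDUs are dropped before any handler runs.

# Simplified stand-in for ReplicationFederationHandlerRegistry.
class EduRegistry:
    def __init__(self, use_presence):
        self.use_presence = use_presence
        self.handled = []

    def on_edu(self, edu_type, origin, content):
        if not self.use_presence and edu_type == "m.presence":
            return  # drop presence EDUs when presence is disabled
        self.handled.append((edu_type, origin))


registry = EduRegistry(use_presence=False)
registry.on_edu("m.presence", "remote.example.org", {})
registry.on_edu("m.receipt", "remote.example.org", {})
print(registry.handled)  # [('m.receipt', 'remote.example.org')]
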


@@ -159,8 +159,12 @@ class FederationRemoteSendQueue(object):
# stream.
pass
def send_edu(self, destination, edu_type, content, key=None):
def build_and_send_edu(self, destination, edu_type, content, key=None):
"""As per TransactionQueue"""
if destination == self.server_name:
logger.info("Not sending EDU to ourselves")
return
pos = self._next_pos()
edu = Edu(
@@ -465,15 +469,11 @@ def process_rows_for_federation(transaction_queue, rows):
for destination, edu_map in iteritems(buff.keyed_edus):
for key, edu in edu_map.items():
transaction_queue.send_edu(
edu.destination, edu.edu_type, edu.content, key=key,
)
transaction_queue.send_edu(edu, key)
for destination, edu_list in iteritems(buff.edus):
for edu in edu_list:
transaction_queue.send_edu(
edu.destination, edu.edu_type, edu.content, key=None,
)
transaction_queue.send_edu(edu, None)
for destination in buff.device_destinations:
transaction_queue.send_device_messages(destination)


@@ -361,7 +361,19 @@ class TransactionQueue(object):
self._attempt_new_transaction(destination)
def send_edu(self, destination, edu_type, content, key=None):
def build_and_send_edu(self, destination, edu_type, content, key=None):
"""Construct an Edu object, and queue it for sending
Args:
destination (str): name of server to send to
edu_type (str): type of EDU to send
content (dict): content of EDU
key (Any|None): clobbering key for this edu
"""
if destination == self.server_name:
logger.info("Not sending EDU to ourselves")
return
edu = Edu(
origin=self.server_name,
destination=destination,
@@ -369,18 +381,23 @@ class TransactionQueue(object):
content=content,
)
if destination == self.server_name:
logger.info("Not sending EDU to ourselves")
return
self.send_edu(edu, key)
def send_edu(self, edu, key):
"""Queue an EDU for sending
Args:
edu (Edu): edu to send
key (Any|None): clobbering key for this edu
"""
if key:
self.pending_edus_keyed_by_dest.setdefault(
destination, {}
edu.destination, {}
)[(edu.edu_type, key)] = edu
else:
self.pending_edus_by_dest.setdefault(destination, []).append(edu)
self.pending_edus_by_dest.setdefault(edu.destination, []).append(edu)
self._attempt_new_transaction(destination)
self._attempt_new_transaction(edu.destination)
def send_device_messages(self, destination):
if destination == self.server_name:


@@ -393,7 +393,7 @@ class FederationStateServlet(BaseFederationServlet):
return self.handler.on_context_state_request(
origin,
context,
parse_string_from_args(query, "event_id", None),
parse_string_from_args(query, "event_id", None, required=True),
)
@@ -404,7 +404,7 @@ class FederationStateIdsServlet(BaseFederationServlet):
return self.handler.on_state_ids_request(
origin,
room_id,
parse_string_from_args(query, "event_id", None),
parse_string_from_args(query, "event_id", None, required=True),
)


@@ -37,13 +37,185 @@ from ._base import BaseHandler
logger = logging.getLogger(__name__)
class DeviceHandler(BaseHandler):
class DeviceWorkerHandler(BaseHandler):
def __init__(self, hs):
super(DeviceHandler, self).__init__(hs)
super(DeviceWorkerHandler, self).__init__(hs)
self.hs = hs
self.state = hs.get_state_handler()
self._auth_handler = hs.get_auth_handler()
@defer.inlineCallbacks
def get_devices_by_user(self, user_id):
"""
Retrieve the given user's devices
Args:
user_id (str):
Returns:
defer.Deferred: list[dict[str, X]]: info on each device
"""
device_map = yield self.store.get_devices_by_user(user_id)
ips = yield self.store.get_last_client_ip_by_device(
user_id, device_id=None
)
devices = list(device_map.values())
for device in devices:
_update_device_from_client_ips(device, ips)
defer.returnValue(devices)
@defer.inlineCallbacks
def get_device(self, user_id, device_id):
""" Retrieve the given device
Args:
user_id (str):
device_id (str):
Returns:
defer.Deferred: dict[str, X]: info on the device
Raises:
errors.NotFoundError: if the device was not found
"""
try:
device = yield self.store.get_device(user_id, device_id)
except errors.StoreError:
raise errors.NotFoundError
ips = yield self.store.get_last_client_ip_by_device(
user_id, device_id,
)
_update_device_from_client_ips(device, ips)
defer.returnValue(device)
@measure_func("device.get_user_ids_changed")
@defer.inlineCallbacks
def get_user_ids_changed(self, user_id, from_token):
"""Get list of users that have had the devices updated, or have newly
joined a room, that `user_id` may be interested in.
Args:
user_id (str)
from_token (StreamToken)
"""
now_room_key = yield self.store.get_room_events_max_id()
room_ids = yield self.store.get_rooms_for_user(user_id)
# First we check if any devices have changed
changed = yield self.store.get_user_whose_devices_changed(
from_token.device_list_key
)
# Then work out if any users have since joined
rooms_changed = self.store.get_rooms_that_changed(room_ids, from_token.room_key)
member_events = yield self.store.get_membership_changes_for_user(
user_id, from_token.room_key, now_room_key,
)
rooms_changed.update(event.room_id for event in member_events)
stream_ordering = RoomStreamToken.parse_stream_token(
from_token.room_key
).stream
possibly_changed = set(changed)
possibly_left = set()
for room_id in rooms_changed:
current_state_ids = yield self.store.get_current_state_ids(room_id)
# The user may have left the room
# TODO: Check if they actually did or if we were just invited.
if room_id not in room_ids:
for key, event_id in iteritems(current_state_ids):
etype, state_key = key
if etype != EventTypes.Member:
continue
possibly_left.add(state_key)
continue
# Fetch the current state at the time.
try:
event_ids = yield self.store.get_forward_extremeties_for_room(
room_id, stream_ordering=stream_ordering
)
except errors.StoreError:
# we have purged the stream_ordering index since the stream
# ordering: treat it the same as a new room
event_ids = []
# special-case for an empty prev state: include all members
# in the changed list
if not event_ids:
for key, event_id in iteritems(current_state_ids):
etype, state_key = key
if etype != EventTypes.Member:
continue
possibly_changed.add(state_key)
continue
current_member_id = current_state_ids.get((EventTypes.Member, user_id))
if not current_member_id:
continue
# mapping from event_id -> state_dict
prev_state_ids = yield self.store.get_state_ids_for_events(event_ids)
# Check if we've joined the room? If so we just blindly add all the users to
# the "possibly changed" users.
for state_dict in itervalues(prev_state_ids):
member_event = state_dict.get((EventTypes.Member, user_id), None)
if not member_event or member_event != current_member_id:
for key, event_id in iteritems(current_state_ids):
etype, state_key = key
if etype != EventTypes.Member:
continue
possibly_changed.add(state_key)
break
# If there has been any change in membership, include them in the
# possibly changed list. We'll check if they are joined below,
# and we're not toooo worried about spuriously adding users.
for key, event_id in iteritems(current_state_ids):
etype, state_key = key
if etype != EventTypes.Member:
continue
# check if this member has changed since any of the extremities
# at the stream_ordering, and add them to the list if so.
for state_dict in itervalues(prev_state_ids):
prev_event_id = state_dict.get(key, None)
if not prev_event_id or prev_event_id != event_id:
if state_key != user_id:
possibly_changed.add(state_key)
break
if possibly_changed or possibly_left:
users_who_share_room = yield self.store.get_users_who_share_room_with_user(
user_id
)
# Take the intersection of the users whose devices may have changed
# and those that actually still share a room with the user
possibly_joined = possibly_changed & users_who_share_room
possibly_left = (possibly_changed | possibly_left) - users_who_share_room
else:
possibly_joined = []
possibly_left = []
defer.returnValue({
"changed": list(possibly_joined),
"left": list(possibly_left),
})
class DeviceHandler(DeviceWorkerHandler):
def __init__(self, hs):
super(DeviceHandler, self).__init__(hs)
self.federation_sender = hs.get_federation_sender()
self._edu_updater = DeviceListEduUpdater(hs, self)
@@ -103,52 +275,6 @@ class DeviceHandler(BaseHandler):
raise errors.StoreError(500, "Couldn't generate a device ID.")
@defer.inlineCallbacks
def get_devices_by_user(self, user_id):
"""
Retrieve the given user's devices
Args:
user_id (str):
Returns:
defer.Deferred: list[dict[str, X]]: info on each device
"""
device_map = yield self.store.get_devices_by_user(user_id)
ips = yield self.store.get_last_client_ip_by_device(
user_id, device_id=None
)
devices = list(device_map.values())
for device in devices:
_update_device_from_client_ips(device, ips)
defer.returnValue(devices)
@defer.inlineCallbacks
def get_device(self, user_id, device_id):
""" Retrieve the given device
Args:
user_id (str):
device_id (str):
Returns:
defer.Deferred: dict[str, X]: info on the device
Raises:
errors.NotFoundError: if the device was not found
"""
try:
device = yield self.store.get_device(user_id, device_id)
except errors.StoreError:
raise errors.NotFoundError
ips = yield self.store.get_last_client_ip_by_device(
user_id, device_id,
)
_update_device_from_client_ips(device, ips)
defer.returnValue(device)
@defer.inlineCallbacks
def delete_device(self, user_id, device_id):
""" Delete the given device
@@ -287,126 +413,6 @@ class DeviceHandler(BaseHandler):
for host in hosts:
self.federation_sender.send_device_messages(host)
@measure_func("device.get_user_ids_changed")
@defer.inlineCallbacks
def get_user_ids_changed(self, user_id, from_token):
"""Get list of users that have had the devices updated, or have newly
joined a room, that `user_id` may be interested in.
Args:
user_id (str)
from_token (StreamToken)
"""
now_token = yield self.hs.get_event_sources().get_current_token()
room_ids = yield self.store.get_rooms_for_user(user_id)
# First we check if any devices have changed
changed = yield self.store.get_user_whose_devices_changed(
from_token.device_list_key
)
# Then work out if any users have since joined
rooms_changed = self.store.get_rooms_that_changed(room_ids, from_token.room_key)
member_events = yield self.store.get_membership_changes_for_user(
user_id, from_token.room_key, now_token.room_key
)
rooms_changed.update(event.room_id for event in member_events)
stream_ordering = RoomStreamToken.parse_stream_token(
from_token.room_key
).stream
possibly_changed = set(changed)
possibly_left = set()
for room_id in rooms_changed:
current_state_ids = yield self.store.get_current_state_ids(room_id)
# The user may have left the room
# TODO: Check if they actually did or if we were just invited.
if room_id not in room_ids:
for key, event_id in iteritems(current_state_ids):
etype, state_key = key
if etype != EventTypes.Member:
continue
possibly_left.add(state_key)
continue
# Fetch the current state at the time.
try:
event_ids = yield self.store.get_forward_extremeties_for_room(
room_id, stream_ordering=stream_ordering
)
except errors.StoreError:
# we have purged the stream_ordering index since the stream
# ordering: treat it the same as a new room
event_ids = []
# special-case for an empty prev state: include all members
# in the changed list
if not event_ids:
for key, event_id in iteritems(current_state_ids):
etype, state_key = key
if etype != EventTypes.Member:
continue
possibly_changed.add(state_key)
continue
current_member_id = current_state_ids.get((EventTypes.Member, user_id))
if not current_member_id:
continue
# mapping from event_id -> state_dict
prev_state_ids = yield self.store.get_state_ids_for_events(event_ids)
# Check if we've joined the room? If so we just blindly add all the users to
# the "possibly changed" users.
for state_dict in itervalues(prev_state_ids):
member_event = state_dict.get((EventTypes.Member, user_id), None)
if not member_event or member_event != current_member_id:
for key, event_id in iteritems(current_state_ids):
etype, state_key = key
if etype != EventTypes.Member:
continue
possibly_changed.add(state_key)
break
# If there has been any change in membership, include them in the
# possibly changed list. We'll check if they are joined below,
# and we're not toooo worried about spuriously adding users.
for key, event_id in iteritems(current_state_ids):
etype, state_key = key
if etype != EventTypes.Member:
continue
# check if this member has changed since any of the extremities
# at the stream_ordering, and add them to the list if so.
for state_dict in itervalues(prev_state_ids):
prev_event_id = state_dict.get(key, None)
if not prev_event_id or prev_event_id != event_id:
if state_key != user_id:
possibly_changed.add(state_key)
break
if possibly_changed or possibly_left:
users_who_share_room = yield self.store.get_users_who_share_room_with_user(
user_id
)
# Take the intersection of the users whose devices may have changed
# and those that actually still share a room with the user
possibly_joined = possibly_changed & users_who_share_room
possibly_left = (possibly_changed | possibly_left) - users_who_share_room
else:
possibly_joined = []
possibly_left = []
defer.returnValue({
"changed": list(possibly_joined),
"left": list(possibly_left),
})
@defer.inlineCallbacks
def on_federation_query_user_devices(self, user_id):
stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id)


@@ -858,6 +858,52 @@ class FederationHandler(BaseHandler):
logger.debug("Not backfilling as no extremeties found.")
return
# We only want to paginate if we can actually see the events we'll get,
# as otherwise we'll just spend a lot of resources to get redacted
# events.
#
# We do this by filtering all the backwards extremities and seeing if
# any remain. Given we don't have the extremity events themselves, we
# need to actually check the events that reference them.
#
# *Note*: the spec wants us to keep backfilling until we reach the start
# of the room in case we are allowed to see some of the history. However
# in practice that causes more issues than its worth, as a) its
# relatively rare for there to be any visible history and b) even when
# there is its often sufficiently long ago that clients would stop
# attempting to paginate before backfill reached the visible history.
#
# TODO: If we do do a backfill then we should filter the backwards
# extremities to only include those that point to visible portions of
# history.
#
# TODO: Correctly handle the case where we are allowed to see the
# forward event but not the backward extremity, e.g. in the case of
# initial join of the server where we are allowed to see the join
# event but not anything before it. This would require looking at the
# state *before* the event, ignoring the special casing certain event
# types have.
forward_events = yield self.store.get_successor_events(
list(extremities),
)
extremities_events = yield self.store.get_events(
forward_events,
check_redacted=False,
get_prev_content=False,
)
# We set `check_history_visibility_only` as we might otherwise get false
# positives from users having been erased.
filtered_extremities = yield filter_events_for_server(
self.store, self.server_name, list(extremities_events.values()),
redact=False, check_history_visibility_only=True,
)
if not filtered_extremities:
defer.returnValue(False)
# Check if we reached a point where we should start backfilling.
sorted_extremeties_tuple = sorted(
extremities.items(),


@@ -436,10 +436,11 @@ class EventCreationHandler(object):
if event.is_state():
prev_state = yield self.deduplicate_state_event(event, context)
logger.info(
"Not bothering to persist duplicate state event %s", event.event_id,
)
if prev_state is not None:
logger.info(
"Not bothering to persist state event %s duplicated by %s",
event.event_id, prev_state.event_id,
)
defer.returnValue(prev_state)
yield self.handle_new_client_event(


@@ -816,7 +816,7 @@ class PresenceHandler(object):
if self.is_mine(observed_user):
yield self.invite_presence(observed_user, observer_user)
else:
yield self.federation.send_edu(
yield self.federation.build_and_send_edu(
destination=observed_user.domain,
edu_type="m.presence_invite",
content={
@@ -836,7 +836,7 @@ class PresenceHandler(object):
if self.is_mine(observer_user):
yield self.accept_presence(observed_user, observer_user)
else:
self.federation.send_edu(
self.federation.build_and_send_edu(
destination=observer_user.domain,
edu_type="m.presence_accept",
content={
@@ -848,7 +848,7 @@ class PresenceHandler(object):
state_dict = yield self.get_state(observed_user, as_event=False)
state_dict = format_user_presence_state(state_dict, self.clock.time_msec())
self.federation.send_edu(
self.federation.build_and_send_edu(
destination=observer_user.domain,
edu_type="m.presence",
content={


@@ -16,7 +16,6 @@ import logging
from twisted.internet import defer
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import get_domain_from_id
from ._base import BaseHandler
@@ -38,31 +37,6 @@ class ReceiptsHandler(BaseHandler):
self.clock = self.hs.get_clock()
self.state = hs.get_state_handler()
@defer.inlineCallbacks
def received_client_receipt(self, room_id, receipt_type, user_id,
event_id):
"""Called when a client tells us a local user has read up to the given
event_id in the room.
"""
receipt = {
"room_id": room_id,
"receipt_type": receipt_type,
"user_id": user_id,
"event_ids": [event_id],
"data": {
"ts": int(self.clock.time_msec()),
}
}
is_new = yield self._handle_new_receipts([receipt])
if is_new:
# fire off a process in the background to send the receipt to
# remote servers
run_as_background_process(
'push_receipts_to_remotes', self._push_remotes, receipt
)
@defer.inlineCallbacks
def _received_remote_receipt(self, origin, content):
"""Called when we receive an EDU of type m.receipt from a remote HS.
@@ -128,43 +102,54 @@ class ReceiptsHandler(BaseHandler):
defer.returnValue(True)
@defer.inlineCallbacks
def _push_remotes(self, receipt):
"""Given a receipt, works out which remote servers should be
poked and pokes them.
def received_client_receipt(self, room_id, receipt_type, user_id,
event_id):
"""Called when a client tells us a local user has read up to the given
event_id in the room.
"""
try:
# TODO: optimise this to move some of the work to the workers.
room_id = receipt["room_id"]
receipt_type = receipt["receipt_type"]
user_id = receipt["user_id"]
event_ids = receipt["event_ids"]
data = receipt["data"]
receipt = {
"room_id": room_id,
"receipt_type": receipt_type,
"user_id": user_id,
"event_ids": [event_id],
"data": {
"ts": int(self.clock.time_msec()),
}
}
users = yield self.state.get_current_user_in_room(room_id)
remotedomains = set(get_domain_from_id(u) for u in users)
remotedomains = remotedomains.copy()
remotedomains.discard(self.server_name)
is_new = yield self._handle_new_receipts([receipt])
if not is_new:
return
logger.debug("Sending receipt to: %r", remotedomains)
# Work out which remote servers should be poked and poke them.
for domain in remotedomains:
self.federation.send_edu(
destination=domain,
edu_type="m.receipt",
content={
room_id: {
receipt_type: {
user_id: {
"event_ids": event_ids,
"data": data,
}
# TODO: optimise this to move some of the work to the workers.
data = receipt["data"]
# XXX why does this not use state.get_current_hosts_in_room() ?
users = yield self.state.get_current_user_in_room(room_id)
remotedomains = set(get_domain_from_id(u) for u in users)
remotedomains = remotedomains.copy()
remotedomains.discard(self.server_name)
logger.debug("Sending receipt to: %r", remotedomains)
for domain in remotedomains:
self.federation.build_and_send_edu(
destination=domain,
edu_type="m.receipt",
content={
room_id: {
receipt_type: {
user_id: {
"event_ids": [event_id],
"data": data,
}
},
}
},
key=(room_id, receipt_type, user_id),
)
except Exception:
logger.exception("Error pushing receipts to remote servers")
},
key=(room_id, receipt_type, user_id),
)
@defer.inlineCallbacks
def get_receipts_for_room(self, room_id, to_key):


@@ -231,7 +231,7 @@ class TypingHandler(object):
for domain in set(get_domain_from_id(u) for u in users):
if domain != self.server_name:
logger.debug("sending typing update to %s", domain)
self.federation.send_edu(
self.federation.build_and_send_edu(
destination=domain,
edu_type="m.typing",
content={


@@ -68,9 +68,13 @@ class MatrixFederationAgent(object):
TLS policy to use for fetching .well-known files. None to use a default
(browser-like) implementation.
srv_resolver (SrvResolver|None):
_srv_resolver (SrvResolver|None):
SRVResolver impl to use for looking up SRV records. None to use a default
implementation.
_well_known_cache (TTLCache|None):
TTLCache impl for storing cached well-known lookups. None to use a default
implementation.
"""
def __init__(


@@ -178,8 +178,6 @@ class Notifier(object):
self.remove_expired_streams, self.UNUSED_STREAM_EXPIRY_MS
)
self.replication_deferred = ObservableDeferred(defer.Deferred())
# This is not a very cheap test to perform, but it's only executed
# when rendering the metrics page, which is likely once per minute at
# most when scraping it.
@@ -205,7 +203,9 @@ class Notifier(object):
def add_replication_callback(self, cb):
"""Add a callback that will be called when some new data is available.
Callback is not given any arguments.
Callback is not given any arguments. It should *not* return a Deferred - if
it needs to do any asynchronous work, a background thread should be started and
wrapped with run_as_background_process.
"""
self.replication_callbacks.append(cb)
@@ -517,60 +517,5 @@ class Notifier(object):
def notify_replication(self):
"""Notify the any replication listeners that there's a new event"""
with PreserveLoggingContext():
deferred = self.replication_deferred
self.replication_deferred = ObservableDeferred(defer.Deferred())
deferred.callback(None)
# the callbacks may well outlast the current request, so we run
# them in the sentinel logcontext.
#
# (ideally it would be up to the callbacks to know if they were
# starting off background processes and drop the logcontext
# accordingly, but that requires more changes)
for cb in self.replication_callbacks:
cb()
@defer.inlineCallbacks
def wait_for_replication(self, callback, timeout):
"""Wait for an event to happen.
Args:
callback: Gets called whenever an event happens. If this returns a
truthy value then ``wait_for_replication`` returns, otherwise
it waits for another event.
timeout: How many milliseconds to wait for callback return a truthy
value.
Returns:
A deferred that resolves with the value returned by the callback.
"""
listener = _NotificationListener(None)
end_time = self.clock.time_msec() + timeout
while True:
listener.deferred = self.replication_deferred.observe()
result = yield callback()
if result:
break
now = self.clock.time_msec()
if end_time <= now:
break
listener.deferred = timeout_deferred(
listener.deferred,
timeout=(end_time - now) / 1000.,
reactor=self.hs.get_reactor(),
)
try:
with PreserveLoggingContext():
yield listener.deferred
except defer.TimeoutError:
break
except defer.CancelledError:
break
defer.returnValue(result)
for cb in self.replication_callbacks:
cb()
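
A simplified sketch of the callback contract spelled out in the docstring above (plain Python stand-ins; the real code wraps its work in run_as_background_process): callbacks take no arguments, return nothing, and merely kick off background work.

# Plain-Python stand-in for the notifier's replication callbacks.
class Notifier:
    def __init__(self):
        self.replication_callbacks = []

    def add_replication_callback(self, cb):
        self.replication_callbacks.append(cb)

    def notify_replication(self):
        for cb in self.replication_callbacks:
            cb()  # fire-and-forget: any returned value is ignored


scheduled = []

def poke_replication():
    # Stand-in for run_as_background_process("replication_notifier", ...)
    scheduled.append("stream-new-replication-rows")

notifier = Notifier()
notifier.add_replication_callback(poke_replication)
notifier.notify_replication()
print(scheduled)  # ['stream-new-replication-rows']
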


@@ -13,15 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.storage import DataStore
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.storage.deviceinbox import DeviceInboxWorkerStore
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.caches.stream_change_cache import StreamChangeCache
from ._base import BaseSlavedStore, __func__
from ._slaved_id_tracker import SlavedIdTracker
class SlavedDeviceInboxStore(BaseSlavedStore):
class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedDeviceInboxStore, self).__init__(db_conn, hs)
self._device_inbox_id_gen = SlavedIdTracker(
@@ -43,12 +42,6 @@ class SlavedDeviceInboxStore(BaseSlavedStore):
expiry_ms=30 * 60 * 1000,
)
get_to_device_stream_token = __func__(DataStore.get_to_device_stream_token)
get_new_messages_for_device = __func__(DataStore.get_new_messages_for_device)
get_new_device_msgs_for_remote = __func__(DataStore.get_new_device_msgs_for_remote)
delete_messages_for_device = __func__(DataStore.delete_messages_for_device)
delete_device_msgs_for_remote = __func__(DataStore.delete_device_msgs_for_remote)
def stream_positions(self):
result = super(SlavedDeviceInboxStore, self).stream_positions()
result["to_device"] = self._device_inbox_id_gen.get_current_token()


@@ -13,15 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.storage import DataStore
from synapse.storage.end_to_end_keys import EndToEndKeyStore
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.storage.devices import DeviceWorkerStore
from synapse.storage.end_to_end_keys import EndToEndKeyWorkerStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
from ._base import BaseSlavedStore, __func__
from ._slaved_id_tracker import SlavedIdTracker
class SlavedDeviceStore(BaseSlavedStore):
class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedDeviceStore, self).__init__(db_conn, hs)
@@ -38,17 +37,6 @@ class SlavedDeviceStore(BaseSlavedStore):
"DeviceListFederationStreamChangeCache", device_list_max,
)
get_device_stream_token = __func__(DataStore.get_device_stream_token)
get_user_whose_devices_changed = __func__(DataStore.get_user_whose_devices_changed)
get_devices_by_remote = __func__(DataStore.get_devices_by_remote)
_get_devices_by_remote_txn = __func__(DataStore._get_devices_by_remote_txn)
_get_e2e_device_keys_txn = __func__(DataStore._get_e2e_device_keys_txn)
mark_as_sent_devices_by_remote = __func__(DataStore.mark_as_sent_devices_by_remote)
_mark_as_sent_devices_by_remote_txn = (
__func__(DataStore._mark_as_sent_devices_by_remote_txn)
)
count_e2e_one_time_keys = EndToEndKeyStore.__dict__["count_e2e_one_time_keys"]
def stream_positions(self):
result = super(SlavedDeviceStore, self).stream_positions()
result["device_lists"] = self._device_list_id_gen.get_current_token()
@@ -58,14 +46,23 @@ class SlavedDeviceStore(BaseSlavedStore):
if stream_name == "device_lists":
self._device_list_id_gen.advance(token)
for row in rows:
self._device_list_stream_cache.entity_has_changed(
row.user_id, token
self._invalidate_caches_for_devices(
token, row.user_id, row.destination,
)
if row.destination:
self._device_list_federation_stream_cache.entity_has_changed(
row.destination, token
)
return super(SlavedDeviceStore, self).process_replication_rows(
stream_name, token, rows
)
def _invalidate_caches_for_devices(self, token, user_id, destination):
self._device_list_stream_cache.entity_has_changed(
user_id, token
)
if destination:
self._device_list_federation_stream_cache.entity_has_changed(
destination, token
)
self._get_cached_devices_for_user.invalidate((user_id,))
self._get_cached_user_device.invalidate_many((user_id,))
self.get_device_list_last_stream_id_for_remote.invalidate((user_id,))


@@ -54,8 +54,11 @@ class SlavedPresenceStore(BaseSlavedStore):
def stream_positions(self):
result = super(SlavedPresenceStore, self).stream_positions()
position = self._presence_id_gen.get_current_token()
result["presence"] = position
if self.hs.config.use_presence:
position = self._presence_id_gen.get_current_token()
result["presence"] = position
return result
def process_replication_rows(self, stream_name, token, rows):


@@ -20,7 +20,7 @@ from ._slaved_id_tracker import SlavedIdTracker
from .events import SlavedEventStore
class SlavedPushRuleStore(PushRulesWorkerStore, SlavedEventStore):
class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
def __init__(self, db_conn, hs):
self._push_rules_stream_id_gen = SlavedIdTracker(
db_conn, "push_rules_stream", "stream_id",


@@ -39,7 +39,7 @@ class ReplicationClientFactory(ReconnectingClientFactory):
Accepts a handler that will be called when new data is available or data
is required.
"""
maxDelay = 5 # Try at least once every N seconds
maxDelay = 30 # Try at least once every N seconds
def __init__(self, hs, client_name, handler):
self.client_name = client_name
@@ -54,7 +54,6 @@ class ReplicationClientFactory(ReconnectingClientFactory):
def buildProtocol(self, addr):
logger.info("Connected to replication: %r", addr)
self.resetDelay()
return ClientReplicationStreamProtocol(
self.client_name, self.server_name, self._clock, self.handler
)
@@ -90,15 +89,18 @@ class ReplicationClientHandler(object):
# Used for tests.
self.awaiting_syncs = {}
# The factory used to create connections.
self.factory = None
def start_replication(self, hs):
"""Helper method to start a replication connection to the remote server
using TCP.
"""
client_name = hs.config.worker_name
factory = ReplicationClientFactory(hs, client_name, self)
self.factory = ReplicationClientFactory(hs, client_name, self)
host = hs.config.worker_replication_host
port = hs.config.worker_replication_port
hs.get_reactor().connectTCP(host, port, factory)
hs.get_reactor().connectTCP(host, port, self.factory)
def on_rdata(self, stream_name, token, rows):
"""Called when we get new replication data. By default this just pokes
@@ -140,6 +142,7 @@ class ReplicationClientHandler(object):
args["account_data"] = user_account_data
elif room_account_data:
args["account_data"] = room_account_data
return args
def get_currently_syncing_users(self):
@@ -204,3 +207,14 @@ class ReplicationClientHandler(object):
for cmd in self.pending_commands:
connection.send_command(cmd)
self.pending_commands = []
def finished_connecting(self):
"""Called when we have successfully subscribed and caught up to all
streams we're interested in.
"""
logger.info("Finished connecting to server")
# We don't reset the delay any earlier: otherwise, if there is a
# problem during start up, we'll end up tight-looping while reconnecting
# to the server.
self.factory.resetDelay()
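A toy model of why resetDelay() is only called from finished_connecting() (this is not Twisted's ReconnectingClientFactory, just a sketch of the exponential backoff it implements): if the delay were reset as soon as the TCP connection was made, a server that accepts connections but fails during the initial sync would be retried at the minimum interval forever.

class Backoff(object):
    """Minimal exponential backoff, capped at max_delay (cf. maxDelay = 30)."""
    def __init__(self, max_delay=30.0):
        self.delay = 1.0
        self.max_delay = max_delay

    def connection_failed(self):
        self.delay = min(self.delay * 2, self.max_delay)
        return self.delay

    def reset(self):
        # Only call this once we are fully caught up (finished_connecting),
        # not merely connected.
        self.delay = 1.0

b = Backoff()
assert b.connection_failed() == 2.0
assert b.connection_failed() == 4.0
b.reset()
assert b.connection_failed() == 2.0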

View File

@@ -127,8 +127,11 @@ class RdataCommand(Command):
class PositionCommand(Command):
"""Sent by the client to tell the client the stream postition without
"""Sent by the server to tell the client the stream postition without
needing to send an RDATA.
Sent to the client after all missing updates for a stream have been sent,
at which point the client is up to date.
"""
NAME = "POSITION"
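For context, replication commands are framed one per line as the command name, a space, then the payload (per Synapse's TCP replication docs); the stream name and token below are made-up values.

def parse_command(line):
    # Minimal parser for the line-based framing: "NAME <payload>".
    name, _, payload = line.partition(" ")
    return name, payload

# Hypothetical line sent by the server once the "events" stream is caught up:
assert parse_command("POSITION events 153") == ("POSITION", "events 153")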

View File

@@ -526,6 +526,11 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.server_name = server_name
self.handler = handler
# Set of stream names that have been subscribed to, but haven't yet
# caught up with. This is used to track when the client has been fully
# connected to the remote.
self.streams_connecting = set()
# Map of stream to batched updates. See RdataCommand for info on how
# batching works.
self.pending_batches = {}
@@ -548,6 +553,10 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
# We've now finished connecting, so inform the client handler
self.handler.update_connection(self)
# This will happen if we don't actually subscribe to any streams
if not self.streams_connecting:
self.handler.finished_connecting()
def on_SERVER(self, cmd):
if cmd.data != self.server_name:
logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
@@ -577,6 +586,12 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
return self.handler.on_rdata(stream_name, cmd.token, rows)
def on_POSITION(self, cmd):
# When we get a `POSITION` command it means we've finished getting
# missing updates for the given stream, and are now up to date.
self.streams_connecting.discard(cmd.stream_name)
if not self.streams_connecting:
self.handler.finished_connecting()
return self.handler.on_position(cmd.stream_name, cmd.token)
def on_SYNC(self, cmd):
@@ -593,6 +608,8 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.id(), stream_name, token
)
self.streams_connecting.add(stream_name)
self.send_command(ReplicateCommand(stream_name, token))
def on_connection_closed(self):
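A minimal sketch (plain Python, independent of Twisted) of the connection-tracking logic added above: each REPLICATE adds the stream to a pending set, each POSITION removes it, and finished_connecting fires once the set is empty, or immediately if nothing was subscribed.

class ConnectingTracker(object):
    def __init__(self, on_finished):
        self.streams_connecting = set()
        self._on_finished = on_finished

    def subscribe(self, stream_name):
        # Called when we send REPLICATE for a stream.
        self.streams_connecting.add(stream_name)

    def connection_established(self):
        # Nothing subscribed at all: we are trivially caught up.
        if not self.streams_connecting:
            self._on_finished()

    def on_position(self, stream_name):
        # POSITION means we have all missing updates for this stream.
        self.streams_connecting.discard(stream_name)
        if not self.streams_connecting:
            self._on_finished()

finished = []
t = ConnectingTracker(lambda: finished.append(True))
t.subscribe("events")
t.subscribe("device_lists")
t.connection_established()      # still waiting on two streams
t.on_position("events")
t.on_position("device_lists")
assert finished == [True]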

View File

@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2019 New Vector Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -213,8 +214,7 @@ def get_filename_from_headers(headers):
Content-Disposition HTTP header.
Args:
headers (twisted.web.http_headers.Headers): The HTTP
request headers.
headers (dict[bytes, list[bytes]]): The HTTP request headers.
Returns:
A Unicode string of the filename, or None.
@@ -225,23 +225,12 @@ def get_filename_from_headers(headers):
if not content_disposition[0]:
return
# dict of unicode: bytes, corresponding to the key value sections of the
# Content-Disposition header.
params = {}
parts = content_disposition[0].split(b";")
for i in parts:
# Split into key-value pairs, if able
# We don't care about things like `inline`, so throw it out
if b"=" not in i:
continue
key, value = i.strip().split(b"=")
params[key.decode('ascii')] = value
_, params = _parse_header(content_disposition[0])
upload_name = None
# First check if there is a valid UTF-8 filename
upload_name_utf8 = params.get("filename*", None)
upload_name_utf8 = params.get(b"filename*", None)
if upload_name_utf8:
if upload_name_utf8.lower().startswith(b"utf-8''"):
upload_name_utf8 = upload_name_utf8[7:]
@@ -267,12 +256,68 @@ def get_filename_from_headers(headers):
# If there isn't, check for an ASCII name.
if not upload_name:
upload_name_ascii = params.get("filename", None)
upload_name_ascii = params.get(b"filename", None)
if upload_name_ascii and is_ascii(upload_name_ascii):
# Make sure there's no %-quoted bytes. If there is, reject it as
# non-valid ASCII.
if b"%" not in upload_name_ascii:
upload_name = upload_name_ascii.decode('ascii')
upload_name = upload_name_ascii.decode('ascii')
# This may be None here, indicating we did not find a matching name.
return upload_name
def _parse_header(line):
"""Parse a Content-type like header.
Cargo-culted from `cgi`, but works on bytes rather than strings.
Args:
line (bytes): header to be parsed
Returns:
Tuple[bytes, dict[bytes, bytes]]:
the main content-type, followed by the parameter dictionary
"""
parts = _parseparam(b';' + line)
key = next(parts)
pdict = {}
for p in parts:
i = p.find(b'=')
if i >= 0:
name = p[:i].strip().lower()
value = p[i + 1:].strip()
# strip double-quotes
if len(value) >= 2 and value[0:1] == value[-1:] == b'"':
value = value[1:-1]
value = value.replace(b'\\\\', b'\\').replace(b'\\"', b'"')
pdict[name] = value
return key, pdict
def _parseparam(s):
"""Generator which splits the input on ;, respecting double-quoted sequences
Cargo-culted from `cgi`, but works on bytes rather than strings.
Args:
s (bytes): header to be parsed
Returns:
Iterable[bytes]: the split input
"""
while s[:1] == b';':
s = s[1:]
# look for the next ;
end = s.find(b';')
# if there is an odd number of " marks between here and the next ;, skip to the
# next ; instead
while end > 0 and (s.count(b'"', 0, end) - s.count(b'\\"', 0, end)) % 2:
end = s.find(b';', end + 1)
if end < 0:
end = len(s)
f = s[:end]
yield f.strip()
s = s[end:]
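The two helpers above are byte-oriented adaptations of the stdlib cgi parser; for intuition, the str-based stdlib function they were cargo-culted from behaves like this:

import cgi

# cgi.parse_header is the str-based equivalent of _parse_header above
# (the Synapse copy works on bytes so non-ASCII header values survive).
ctype, params = cgi.parse_header('attachment; filename="strawberry.png"')
assert ctype == "attachment"
assert params == {"filename": "strawberry.png"}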

View File

@@ -51,7 +51,7 @@ from synapse.handlers.acme import AcmeHandler
from synapse.handlers.appservice import ApplicationServicesHandler
from synapse.handlers.auth import AuthHandler, MacaroonGenerator
from synapse.handlers.deactivate_account import DeactivateAccountHandler
from synapse.handlers.device import DeviceHandler
from synapse.handlers.device import DeviceHandler, DeviceWorkerHandler
from synapse.handlers.devicemessage import DeviceMessageHandler
from synapse.handlers.e2e_keys import E2eKeysHandler
from synapse.handlers.e2e_room_keys import E2eRoomKeysHandler
@@ -307,7 +307,10 @@ class HomeServer(object):
return MacaroonGenerator(self)
def build_device_handler(self):
return DeviceHandler(self)
if self.config.worker_app:
return DeviceWorkerHandler(self)
else:
return DeviceHandler(self)
def build_device_message_handler(self):
return DeviceMessageHandler(self)

View File

@@ -7,9 +7,9 @@ import synapse.handlers.auth
import synapse.handlers.deactivate_account
import synapse.handlers.device
import synapse.handlers.e2e_keys
import synapse.handlers.message
import synapse.handlers.room
import synapse.handlers.room_member
import synapse.handlers.message
import synapse.handlers.set_password
import synapse.rest.media.v1.media_repository
import synapse.server_notices.server_notices_manager

View File

@@ -19,14 +19,174 @@ from canonicaljson import json
from twisted.internet import defer
from synapse.storage._base import SQLBaseStore
from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.util.caches.expiringcache import ExpiringCache
from .background_updates import BackgroundUpdateStore
logger = logging.getLogger(__name__)
class DeviceInboxStore(BackgroundUpdateStore):
class DeviceInboxWorkerStore(SQLBaseStore):
def get_to_device_stream_token(self):
return self._device_inbox_id_gen.get_current_token()
def get_new_messages_for_device(
self, user_id, device_id, last_stream_id, current_stream_id, limit=100
):
"""
Args:
user_id(str): The recipient user_id.
device_id(str): The recipient device_id.
current_stream_id(int): The current position of the to device
message stream.
Returns:
Deferred ([dict], int): List of messages for the device and where
in the stream the messages got to.
"""
has_changed = self._device_inbox_stream_cache.has_entity_changed(
user_id, last_stream_id
)
if not has_changed:
return defer.succeed(([], current_stream_id))
def get_new_messages_for_device_txn(txn):
sql = (
"SELECT stream_id, message_json FROM device_inbox"
" WHERE user_id = ? AND device_id = ?"
" AND ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC"
" LIMIT ?"
)
txn.execute(sql, (
user_id, device_id, last_stream_id, current_stream_id, limit
))
messages = []
for row in txn:
stream_pos = row[0]
messages.append(json.loads(row[1]))
if len(messages) < limit:
stream_pos = current_stream_id
return (messages, stream_pos)
return self.runInteraction(
"get_new_messages_for_device", get_new_messages_for_device_txn,
)
@defer.inlineCallbacks
def delete_messages_for_device(self, user_id, device_id, up_to_stream_id):
"""
Args:
user_id(str): The recipient user_id.
device_id(str): The recipient device_id.
up_to_stream_id(int): Where to delete messages up to.
Returns:
A deferred that resolves to the number of messages deleted.
"""
# If we have cached the last stream id we've deleted up to, we can
# check if there is likely to be anything that needs deleting
last_deleted_stream_id = self._last_device_delete_cache.get(
(user_id, device_id), None
)
if last_deleted_stream_id:
has_changed = self._device_inbox_stream_cache.has_entity_changed(
user_id, last_deleted_stream_id
)
if not has_changed:
defer.returnValue(0)
def delete_messages_for_device_txn(txn):
sql = (
"DELETE FROM device_inbox"
" WHERE user_id = ? AND device_id = ?"
" AND stream_id <= ?"
)
txn.execute(sql, (user_id, device_id, up_to_stream_id))
return txn.rowcount
count = yield self.runInteraction(
"delete_messages_for_device", delete_messages_for_device_txn
)
# Update the cache, ensuring that we only ever increase the value
last_deleted_stream_id = self._last_device_delete_cache.get(
(user_id, device_id), 0
)
self._last_device_delete_cache[(user_id, device_id)] = max(
last_deleted_stream_id, up_to_stream_id
)
defer.returnValue(count)
def get_new_device_msgs_for_remote(
self, destination, last_stream_id, current_stream_id, limit=100
):
"""
Args:
destination(str): The name of the remote server.
last_stream_id(int|long): The last position of the device message stream
that the server sent up to.
current_stream_id(int|long): The current position of the device
message stream.
Returns:
Deferred ([dict], int|long): List of messages for the device and where
in the stream the messages got to.
"""
has_changed = self._device_federation_outbox_stream_cache.has_entity_changed(
destination, last_stream_id
)
if not has_changed or last_stream_id == current_stream_id:
return defer.succeed(([], current_stream_id))
def get_new_messages_for_remote_destination_txn(txn):
sql = (
"SELECT stream_id, messages_json FROM device_federation_outbox"
" WHERE destination = ?"
" AND ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC"
" LIMIT ?"
)
txn.execute(sql, (
destination, last_stream_id, current_stream_id, limit
))
messages = []
for row in txn:
stream_pos = row[0]
messages.append(json.loads(row[1]))
if len(messages) < limit:
stream_pos = current_stream_id
return (messages, stream_pos)
return self.runInteraction(
"get_new_device_msgs_for_remote",
get_new_messages_for_remote_destination_txn,
)
def delete_device_msgs_for_remote(self, destination, up_to_stream_id):
"""Used to delete messages when the remote destination acknowledges
their receipt.
Args:
destination(str): The destination server_name
up_to_stream_id(int): Where to delete messages up to.
Returns:
A deferred that resolves when the messages have been deleted.
"""
def delete_messages_for_remote_destination_txn(txn):
sql = (
"DELETE FROM device_federation_outbox"
" WHERE destination = ?"
" AND stream_id <= ?"
)
txn.execute(sql, (destination, up_to_stream_id))
return self.runInteraction(
"delete_device_msgs_for_remote",
delete_messages_for_remote_destination_txn
)
class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
def __init__(self, db_conn, hs):
@@ -220,93 +380,6 @@ class DeviceInboxStore(BackgroundUpdateStore):
txn.executemany(sql, rows)
def get_new_messages_for_device(
self, user_id, device_id, last_stream_id, current_stream_id, limit=100
):
"""
Args:
user_id(str): The recipient user_id.
device_id(str): The recipient device_id.
current_stream_id(int): The current position of the to device
message stream.
Returns:
Deferred ([dict], int): List of messages for the device and where
in the stream the messages got to.
"""
has_changed = self._device_inbox_stream_cache.has_entity_changed(
user_id, last_stream_id
)
if not has_changed:
return defer.succeed(([], current_stream_id))
def get_new_messages_for_device_txn(txn):
sql = (
"SELECT stream_id, message_json FROM device_inbox"
" WHERE user_id = ? AND device_id = ?"
" AND ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC"
" LIMIT ?"
)
txn.execute(sql, (
user_id, device_id, last_stream_id, current_stream_id, limit
))
messages = []
for row in txn:
stream_pos = row[0]
messages.append(json.loads(row[1]))
if len(messages) < limit:
stream_pos = current_stream_id
return (messages, stream_pos)
return self.runInteraction(
"get_new_messages_for_device", get_new_messages_for_device_txn,
)
@defer.inlineCallbacks
def delete_messages_for_device(self, user_id, device_id, up_to_stream_id):
"""
Args:
user_id(str): The recipient user_id.
device_id(str): The recipient device_id.
up_to_stream_id(int): Where to delete messages up to.
Returns:
A deferred that resolves to the number of messages deleted.
"""
# If we have cached the last stream id we've deleted up to, we can
# check if there is likely to be anything that needs deleting
last_deleted_stream_id = self._last_device_delete_cache.get(
(user_id, device_id), None
)
if last_deleted_stream_id:
has_changed = self._device_inbox_stream_cache.has_entity_changed(
user_id, last_deleted_stream_id
)
if not has_changed:
defer.returnValue(0)
def delete_messages_for_device_txn(txn):
sql = (
"DELETE FROM device_inbox"
" WHERE user_id = ? AND device_id = ?"
" AND stream_id <= ?"
)
txn.execute(sql, (user_id, device_id, up_to_stream_id))
return txn.rowcount
count = yield self.runInteraction(
"delete_messages_for_device", delete_messages_for_device_txn
)
# Update the cache, ensuring that we only ever increase the value
last_deleted_stream_id = self._last_device_delete_cache.get(
(user_id, device_id), 0
)
self._last_device_delete_cache[(user_id, device_id)] = max(
last_deleted_stream_id, up_to_stream_id
)
defer.returnValue(count)
def get_all_new_device_messages(self, last_pos, current_pos, limit):
"""
Args:
@@ -351,77 +424,6 @@ class DeviceInboxStore(BackgroundUpdateStore):
"get_all_new_device_messages", get_all_new_device_messages_txn
)
def get_to_device_stream_token(self):
return self._device_inbox_id_gen.get_current_token()
def get_new_device_msgs_for_remote(
self, destination, last_stream_id, current_stream_id, limit=100
):
"""
Args:
destination(str): The name of the remote server.
last_stream_id(int|long): The last position of the device message stream
that the server sent up to.
current_stream_id(int|long): The current position of the device
message stream.
Returns:
Deferred ([dict], int|long): List of messages for the device and where
in the stream the messages got to.
"""
has_changed = self._device_federation_outbox_stream_cache.has_entity_changed(
destination, last_stream_id
)
if not has_changed or last_stream_id == current_stream_id:
return defer.succeed(([], current_stream_id))
def get_new_messages_for_remote_destination_txn(txn):
sql = (
"SELECT stream_id, messages_json FROM device_federation_outbox"
" WHERE destination = ?"
" AND ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC"
" LIMIT ?"
)
txn.execute(sql, (
destination, last_stream_id, current_stream_id, limit
))
messages = []
for row in txn:
stream_pos = row[0]
messages.append(json.loads(row[1]))
if len(messages) < limit:
stream_pos = current_stream_id
return (messages, stream_pos)
return self.runInteraction(
"get_new_device_msgs_for_remote",
get_new_messages_for_remote_destination_txn,
)
def delete_device_msgs_for_remote(self, destination, up_to_stream_id):
"""Used to delete messages when the remote destination acknowledges
their receipt.
Args:
destination(str): The destination server_name
up_to_stream_id(int): Where to delete messages up to.
Returns:
A deferred that resolves when the messages have been deleted.
"""
def delete_messages_for_remote_destination_txn(txn):
sql = (
"DELETE FROM device_federation_outbox"
" WHERE destination = ?"
" AND stream_id <= ?"
)
txn.execute(sql, (destination, up_to_stream_id))
return self.runInteraction(
"delete_device_msgs_for_remote",
delete_messages_for_remote_destination_txn
)
@defer.inlineCallbacks
def _background_drop_index_device_inbox(self, progress, batch_size):
def reindex_txn(conn):
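The has_entity_changed short-circuit used in get_new_messages_for_device and delete_messages_for_device above is worth spelling out; this sketch uses Synapse's StreamChangeCache with the method names seen in this diff (treat the exact constructor signature as an assumption):

from synapse.util.caches.stream_change_cache import StreamChangeCache

cache = StreamChangeCache("ExampleDeviceInboxCache", 0)

# A to-device message for @alice arrives at stream position 5.
cache.entity_has_changed("@alice:example.com", 5)

# A reader already caught up to position 10 can skip the database query...
assert not cache.has_entity_changed("@alice:example.com", 10)
# ...but one that has only seen up to position 3 must hit the database.
assert cache.has_entity_changed("@alice:example.com", 3)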

View File

@@ -22,11 +22,10 @@ from twisted.internet import defer
from synapse.api.errors import StoreError
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import Cache, SQLBaseStore, db_to_json
from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
from ._base import Cache, db_to_json
logger = logging.getLogger(__name__)
DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = (
@@ -34,7 +33,343 @@ DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = (
)
class DeviceStore(BackgroundUpdateStore):
class DeviceWorkerStore(SQLBaseStore):
def get_device(self, user_id, device_id):
"""Retrieve a device.
Args:
user_id (str): The ID of the user which owns the device
device_id (str): The ID of the device to retrieve
Returns:
defer.Deferred for a dict containing the device information
Raises:
StoreError: if the device is not found
"""
return self._simple_select_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id},
retcols=("user_id", "device_id", "display_name"),
desc="get_device",
)
@defer.inlineCallbacks
def get_devices_by_user(self, user_id):
"""Retrieve all of a user's registered devices.
Args:
user_id (str):
Returns:
defer.Deferred: resolves to a dict from device_id to a dict
containing "device_id", "user_id" and "display_name" for each
device.
"""
devices = yield self._simple_select_list(
table="devices",
keyvalues={"user_id": user_id},
retcols=("user_id", "device_id", "display_name"),
desc="get_devices_by_user"
)
defer.returnValue({d["device_id"]: d for d in devices})
def get_devices_by_remote(self, destination, from_stream_id):
"""Get stream of updates to send to remote servers
Returns:
(int, list[dict]): current stream id and list of updates
"""
now_stream_id = self._device_list_id_gen.get_current_token()
has_changed = self._device_list_federation_stream_cache.has_entity_changed(
destination, int(from_stream_id)
)
if not has_changed:
return (now_stream_id, [])
return self.runInteraction(
"get_devices_by_remote", self._get_devices_by_remote_txn,
destination, from_stream_id, now_stream_id,
)
def _get_devices_by_remote_txn(self, txn, destination, from_stream_id,
now_stream_id):
sql = """
SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes
WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
GROUP BY user_id, device_id
LIMIT 20
"""
txn.execute(
sql, (destination, from_stream_id, now_stream_id, False)
)
# maps (user_id, device_id) -> stream_id
query_map = {(r[0], r[1]): r[2] for r in txn}
if not query_map:
return (now_stream_id, [])
if len(query_map) >= 20:
now_stream_id = max(stream_id for stream_id in itervalues(query_map))
devices = self._get_e2e_device_keys_txn(
txn, query_map.keys(), include_all_devices=True, include_deleted_devices=True
)
prev_sent_id_sql = """
SELECT coalesce(max(stream_id), 0) as stream_id
FROM device_lists_outbound_last_success
WHERE destination = ? AND user_id = ? AND stream_id <= ?
"""
results = []
for user_id, user_devices in iteritems(devices):
# The prev_id for the first row is always the last row before
# `from_stream_id`
txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id))
rows = txn.fetchall()
prev_id = rows[0][0]
for device_id, device in iteritems(user_devices):
stream_id = query_map[(user_id, device_id)]
result = {
"user_id": user_id,
"device_id": device_id,
"prev_id": [prev_id] if prev_id else [],
"stream_id": stream_id,
}
prev_id = stream_id
if device is not None:
key_json = device.get("key_json", None)
if key_json:
result["keys"] = db_to_json(key_json)
device_display_name = device.get("device_display_name", None)
if device_display_name:
result["device_display_name"] = device_display_name
else:
result["deleted"] = True
results.append(result)
return (now_stream_id, results)
def mark_as_sent_devices_by_remote(self, destination, stream_id):
"""Mark that updates have successfully been sent to the destination.
"""
return self.runInteraction(
"mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn,
destination, stream_id,
)
def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id):
# We update the device_lists_outbound_last_success with the successfully
# poked users. We do the join to see which users need to be inserted and
# which updated.
sql = """
SELECT user_id, coalesce(max(o.stream_id), 0), (max(s.stream_id) IS NOT NULL)
FROM device_lists_outbound_pokes as o
LEFT JOIN device_lists_outbound_last_success as s
USING (destination, user_id)
WHERE destination = ? AND o.stream_id <= ?
GROUP BY user_id
"""
txn.execute(sql, (destination, stream_id,))
rows = txn.fetchall()
sql = """
UPDATE device_lists_outbound_last_success
SET stream_id = ?
WHERE destination = ? AND user_id = ?
"""
txn.executemany(
sql, ((row[1], destination, row[0],) for row in rows if row[2])
)
sql = """
INSERT INTO device_lists_outbound_last_success
(destination, user_id, stream_id) VALUES (?, ?, ?)
"""
txn.executemany(
sql, ((destination, row[0], row[1],) for row in rows if not row[2])
)
# Delete all sent outbound pokes
sql = """
DELETE FROM device_lists_outbound_pokes
WHERE destination = ? AND stream_id <= ?
"""
txn.execute(sql, (destination, stream_id,))
def get_device_stream_token(self):
return self._device_list_id_gen.get_current_token()
@defer.inlineCallbacks
def get_user_devices_from_cache(self, query_list):
"""Get the devices (and keys if any) for remote users from the cache.
Args:
query_list(list): List of (user_id, device_ids), if device_ids is
falsey then return all device ids for that user.
Returns:
(user_ids_not_in_cache, results_map), where user_ids_not_in_cache is
a set of user_ids and results_map is a mapping of
user_id -> device_id -> device_info
"""
user_ids = set(user_id for user_id, _ in query_list)
user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids))
user_ids_in_cache = set(
user_id for user_id, stream_id in user_map.items() if stream_id
)
user_ids_not_in_cache = user_ids - user_ids_in_cache
results = {}
for user_id, device_id in query_list:
if user_id not in user_ids_in_cache:
continue
if device_id:
device = yield self._get_cached_user_device(user_id, device_id)
results.setdefault(user_id, {})[device_id] = device
else:
results[user_id] = yield self._get_cached_devices_for_user(user_id)
defer.returnValue((user_ids_not_in_cache, results))
@cachedInlineCallbacks(num_args=2, tree=True)
def _get_cached_user_device(self, user_id, device_id):
content = yield self._simple_select_one_onecol(
table="device_lists_remote_cache",
keyvalues={
"user_id": user_id,
"device_id": device_id,
},
retcol="content",
desc="_get_cached_user_device",
)
defer.returnValue(db_to_json(content))
@cachedInlineCallbacks()
def _get_cached_devices_for_user(self, user_id):
devices = yield self._simple_select_list(
table="device_lists_remote_cache",
keyvalues={
"user_id": user_id,
},
retcols=("device_id", "content"),
desc="_get_cached_devices_for_user",
)
defer.returnValue({
device["device_id"]: db_to_json(device["content"])
for device in devices
})
def get_devices_with_keys_by_user(self, user_id):
"""Get all devices (with any device keys) for a user
Returns:
(stream_id, devices)
"""
return self.runInteraction(
"get_devices_with_keys_by_user",
self._get_devices_with_keys_by_user_txn, user_id,
)
def _get_devices_with_keys_by_user_txn(self, txn, user_id):
now_stream_id = self._device_list_id_gen.get_current_token()
devices = self._get_e2e_device_keys_txn(
txn, [(user_id, None)], include_all_devices=True
)
if devices:
user_devices = devices[user_id]
results = []
for device_id, device in iteritems(user_devices):
result = {
"device_id": device_id,
}
key_json = device.get("key_json", None)
if key_json:
result["keys"] = db_to_json(key_json)
device_display_name = device.get("device_display_name", None)
if device_display_name:
result["device_display_name"] = device_display_name
results.append(result)
return now_stream_id, results
return now_stream_id, []
@defer.inlineCallbacks
def get_user_whose_devices_changed(self, from_key):
"""Get set of users whose devices have changed since `from_key`.
"""
from_key = int(from_key)
changed = self._device_list_stream_cache.get_all_entities_changed(from_key)
if changed is not None:
defer.returnValue(set(changed))
sql = """
SELECT DISTINCT user_id FROM device_lists_stream WHERE stream_id > ?
"""
rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key)
defer.returnValue(set(row[0] for row in rows))
def get_all_device_list_changes_for_remotes(self, from_key, to_key):
"""Return a list of `(stream_id, user_id, destination)` which is the
combined list of changes to devices, and which destinations need to be
poked. `destination` may be None if no destinations need to be poked.
"""
# We do a group by here as there can be a large number of duplicate
# entries, since we throw away device IDs.
sql = """
SELECT MAX(stream_id) AS stream_id, user_id, destination
FROM device_lists_stream
LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id)
WHERE ? < stream_id AND stream_id <= ?
GROUP BY user_id, destination
"""
return self._execute(
"get_all_device_list_changes_for_remotes", None,
sql, from_key, to_key
)
@cached(max_entries=10000)
def get_device_list_last_stream_id_for_remote(self, user_id):
"""Get the last stream_id we got for a user. May be None if we haven't
got any information for them.
"""
return self._simple_select_one_onecol(
table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id},
retcol="stream_id",
desc="get_device_list_last_stream_id_for_remote",
allow_none=True,
)
@cachedList(cached_method_name="get_device_list_last_stream_id_for_remote",
list_name="user_ids", inlineCallbacks=True)
def get_device_list_last_stream_id_for_remotes(self, user_ids):
rows = yield self._simple_select_many_batch(
table="device_lists_remote_extremeties",
column="user_id",
iterable=user_ids,
retcols=("user_id", "stream_id",),
desc="get_device_list_last_stream_id_for_remotes",
)
results = {user_id: None for user_id in user_ids}
results.update({
row["user_id"]: row["stream_id"] for row in rows
})
defer.returnValue(results)
class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
def __init__(self, db_conn, hs):
super(DeviceStore, self).__init__(db_conn, hs)
@@ -121,24 +456,6 @@ class DeviceStore(BackgroundUpdateStore):
initial_device_display_name, e)
raise StoreError(500, "Problem storing device.")
def get_device(self, user_id, device_id):
"""Retrieve a device.
Args:
user_id (str): The ID of the user which owns the device
device_id (str): The ID of the device to retrieve
Returns:
defer.Deferred for a dict containing the device information
Raises:
StoreError: if the device is not found
"""
return self._simple_select_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id},
retcols=("user_id", "device_id", "display_name"),
desc="get_device",
)
@defer.inlineCallbacks
def delete_device(self, user_id, device_id):
"""Delete a device.
@@ -202,57 +519,6 @@ class DeviceStore(BackgroundUpdateStore):
desc="update_device",
)
@defer.inlineCallbacks
def get_devices_by_user(self, user_id):
"""Retrieve all of a user's registered devices.
Args:
user_id (str):
Returns:
defer.Deferred: resolves to a dict from device_id to a dict
containing "device_id", "user_id" and "display_name" for each
device.
"""
devices = yield self._simple_select_list(
table="devices",
keyvalues={"user_id": user_id},
retcols=("user_id", "device_id", "display_name"),
desc="get_devices_by_user"
)
defer.returnValue({d["device_id"]: d for d in devices})
@cached(max_entries=10000)
def get_device_list_last_stream_id_for_remote(self, user_id):
"""Get the last stream_id we got for a user. May be None if we haven't
got any information for them.
"""
return self._simple_select_one_onecol(
table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id},
retcol="stream_id",
desc="get_device_list_remote_extremity",
allow_none=True,
)
@cachedList(cached_method_name="get_device_list_last_stream_id_for_remote",
list_name="user_ids", inlineCallbacks=True)
def get_device_list_last_stream_id_for_remotes(self, user_ids):
rows = yield self._simple_select_many_batch(
table="device_lists_remote_extremeties",
column="user_id",
iterable=user_ids,
retcols=("user_id", "stream_id",),
desc="get_user_devices_from_cache",
)
results = {user_id: None for user_id in user_ids}
results.update({
row["user_id"]: row["stream_id"] for row in rows
})
defer.returnValue(results)
@defer.inlineCallbacks
def mark_remote_user_device_list_as_unsubscribed(self, user_id):
"""Mark that we no longer track device lists for remote user.
@@ -405,268 +671,6 @@ class DeviceStore(BackgroundUpdateStore):
lock=False,
)
def get_devices_by_remote(self, destination, from_stream_id):
"""Get stream of updates to send to remote servers
Returns:
(int, list[dict]): current stream id and list of updates
"""
now_stream_id = self._device_list_id_gen.get_current_token()
has_changed = self._device_list_federation_stream_cache.has_entity_changed(
destination, int(from_stream_id)
)
if not has_changed:
return (now_stream_id, [])
return self.runInteraction(
"get_devices_by_remote", self._get_devices_by_remote_txn,
destination, from_stream_id, now_stream_id,
)
def _get_devices_by_remote_txn(self, txn, destination, from_stream_id,
now_stream_id):
sql = """
SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes
WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
GROUP BY user_id, device_id
LIMIT 20
"""
txn.execute(
sql, (destination, from_stream_id, now_stream_id, False)
)
# maps (user_id, device_id) -> stream_id
query_map = {(r[0], r[1]): r[2] for r in txn}
if not query_map:
return (now_stream_id, [])
if len(query_map) >= 20:
now_stream_id = max(stream_id for stream_id in itervalues(query_map))
devices = self._get_e2e_device_keys_txn(
txn, query_map.keys(), include_all_devices=True, include_deleted_devices=True
)
prev_sent_id_sql = """
SELECT coalesce(max(stream_id), 0) as stream_id
FROM device_lists_outbound_last_success
WHERE destination = ? AND user_id = ? AND stream_id <= ?
"""
results = []
for user_id, user_devices in iteritems(devices):
# The prev_id for the first row is always the last row before
# `from_stream_id`
txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id))
rows = txn.fetchall()
prev_id = rows[0][0]
for device_id, device in iteritems(user_devices):
stream_id = query_map[(user_id, device_id)]
result = {
"user_id": user_id,
"device_id": device_id,
"prev_id": [prev_id] if prev_id else [],
"stream_id": stream_id,
}
prev_id = stream_id
if device is not None:
key_json = device.get("key_json", None)
if key_json:
result["keys"] = db_to_json(key_json)
device_display_name = device.get("device_display_name", None)
if device_display_name:
result["device_display_name"] = device_display_name
else:
result["deleted"] = True
results.append(result)
return (now_stream_id, results)
@defer.inlineCallbacks
def get_user_devices_from_cache(self, query_list):
"""Get the devices (and keys if any) for remote users from the cache.
Args:
query_list(list): List of (user_id, device_ids), if device_ids is
falsey then return all device ids for that user.
Returns:
(user_ids_not_in_cache, results_map), where user_ids_not_in_cache is
a set of user_ids and results_map is a mapping of
user_id -> device_id -> device_info
"""
user_ids = set(user_id for user_id, _ in query_list)
user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids))
user_ids_in_cache = set(
user_id for user_id, stream_id in user_map.items() if stream_id
)
user_ids_not_in_cache = user_ids - user_ids_in_cache
results = {}
for user_id, device_id in query_list:
if user_id not in user_ids_in_cache:
continue
if device_id:
device = yield self._get_cached_user_device(user_id, device_id)
results.setdefault(user_id, {})[device_id] = device
else:
results[user_id] = yield self._get_cached_devices_for_user(user_id)
defer.returnValue((user_ids_not_in_cache, results))
@cachedInlineCallbacks(num_args=2, tree=True)
def _get_cached_user_device(self, user_id, device_id):
content = yield self._simple_select_one_onecol(
table="device_lists_remote_cache",
keyvalues={
"user_id": user_id,
"device_id": device_id,
},
retcol="content",
desc="_get_cached_user_device",
)
defer.returnValue(db_to_json(content))
@cachedInlineCallbacks()
def _get_cached_devices_for_user(self, user_id):
devices = yield self._simple_select_list(
table="device_lists_remote_cache",
keyvalues={
"user_id": user_id,
},
retcols=("device_id", "content"),
desc="_get_cached_devices_for_user",
)
defer.returnValue({
device["device_id"]: db_to_json(device["content"])
for device in devices
})
def get_devices_with_keys_by_user(self, user_id):
"""Get all devices (with any device keys) for a user
Returns:
(stream_id, devices)
"""
return self.runInteraction(
"get_devices_with_keys_by_user",
self._get_devices_with_keys_by_user_txn, user_id,
)
def _get_devices_with_keys_by_user_txn(self, txn, user_id):
now_stream_id = self._device_list_id_gen.get_current_token()
devices = self._get_e2e_device_keys_txn(
txn, [(user_id, None)], include_all_devices=True
)
if devices:
user_devices = devices[user_id]
results = []
for device_id, device in iteritems(user_devices):
result = {
"device_id": device_id,
}
key_json = device.get("key_json", None)
if key_json:
result["keys"] = db_to_json(key_json)
device_display_name = device.get("device_display_name", None)
if device_display_name:
result["device_display_name"] = device_display_name
results.append(result)
return now_stream_id, results
return now_stream_id, []
def mark_as_sent_devices_by_remote(self, destination, stream_id):
"""Mark that updates have successfully been sent to the destination.
"""
return self.runInteraction(
"mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn,
destination, stream_id,
)
def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id):
# We update the device_lists_outbound_last_success with the successfully
# poked users. We do the join to see which users need to be inserted and
# which updated.
sql = """
SELECT user_id, coalesce(max(o.stream_id), 0), (max(s.stream_id) IS NOT NULL)
FROM device_lists_outbound_pokes as o
LEFT JOIN device_lists_outbound_last_success as s
USING (destination, user_id)
WHERE destination = ? AND o.stream_id <= ?
GROUP BY user_id
"""
txn.execute(sql, (destination, stream_id,))
rows = txn.fetchall()
sql = """
UPDATE device_lists_outbound_last_success
SET stream_id = ?
WHERE destination = ? AND user_id = ?
"""
txn.executemany(
sql, ((row[1], destination, row[0],) for row in rows if row[2])
)
sql = """
INSERT INTO device_lists_outbound_last_success
(destination, user_id, stream_id) VALUES (?, ?, ?)
"""
txn.executemany(
sql, ((destination, row[0], row[1],) for row in rows if not row[2])
)
# Delete all sent outbound pokes
sql = """
DELETE FROM device_lists_outbound_pokes
WHERE destination = ? AND stream_id <= ?
"""
txn.execute(sql, (destination, stream_id,))
@defer.inlineCallbacks
def get_user_whose_devices_changed(self, from_key):
"""Get set of users whose devices have changed since `from_key`.
"""
from_key = int(from_key)
changed = self._device_list_stream_cache.get_all_entities_changed(from_key)
if changed is not None:
defer.returnValue(set(changed))
sql = """
SELECT DISTINCT user_id FROM device_lists_stream WHERE stream_id > ?
"""
rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key)
defer.returnValue(set(row[0] for row in rows))
def get_all_device_list_changes_for_remotes(self, from_key, to_key):
"""Return a list of `(stream_id, user_id, destination)` which is the
combined list of changes to devices, and which destinations need to be
poked. `destination` may be None if no destinations need to be poked.
"""
# We do a group by here as there can be a large number of duplicate
# entries, since we throw away device IDs.
sql = """
SELECT MAX(stream_id) AS stream_id, user_id, destination
FROM device_lists_stream
LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id)
WHERE ? < stream_id AND stream_id <= ?
GROUP BY user_id, destination
"""
return self._execute(
"get_all_device_list_changes_for_remotes", None,
sql, from_key, to_key
)
@defer.inlineCallbacks
def add_device_change_to_streams(self, user_id, device_ids, hosts):
"""Persist that a user's devices have been updated, and which hosts
@@ -732,9 +736,6 @@ class DeviceStore(BackgroundUpdateStore):
]
)
def get_device_stream_token(self):
return self._device_list_id_gen.get_current_token()
def _prune_old_outbound_device_pokes(self):
"""Delete old entries out of the device_lists_outbound_pokes to ensure
that we don't fill up due to dead servers. We keep one entry per
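To make the prev_id bookkeeping in _get_devices_by_remote_txn easier to follow, here is the chaining rule in isolation (toy data; in the real code the starting prev_id comes from device_lists_outbound_last_success):

def chain_updates(last_acked_stream_id, updates):
    """updates: list of (device_id, stream_id) for one user, oldest first.

    Each update's prev_id points at the previous update's stream_id so the
    receiving server can detect gaps; the first one points at the last
    stream_id that server has already acknowledged.
    """
    results = []
    prev_id = last_acked_stream_id
    for device_id, stream_id in updates:
        results.append({
            "device_id": device_id,
            "prev_id": [prev_id] if prev_id else [],
            "stream_id": stream_id,
        })
        prev_id = stream_id
    return results

out = chain_updates(7, [("DEVICE1", 9), ("DEVICE2", 12)])
assert out[0]["prev_id"] == [7]
assert out[1]["prev_id"] == [9]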

View File

@@ -23,49 +23,7 @@ from synapse.util.caches.descriptors import cached
from ._base import SQLBaseStore, db_to_json
class EndToEndKeyStore(SQLBaseStore):
def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
"""Stores device keys for a device. Returns whether there was a change
or the keys were already in the database.
"""
def _set_e2e_device_keys_txn(txn):
old_key_json = self._simple_select_one_onecol_txn(
txn,
table="e2e_device_keys_json",
keyvalues={
"user_id": user_id,
"device_id": device_id,
},
retcol="key_json",
allow_none=True,
)
# In py3 we need old_key_json to match new_key_json type. The DB
# returns unicode while encode_canonical_json returns bytes.
new_key_json = encode_canonical_json(device_keys).decode("utf-8")
if old_key_json == new_key_json:
return False
self._simple_upsert_txn(
txn,
table="e2e_device_keys_json",
keyvalues={
"user_id": user_id,
"device_id": device_id,
},
values={
"ts_added_ms": time_now,
"key_json": new_key_json,
}
)
return True
return self.runInteraction(
"set_e2e_device_keys", _set_e2e_device_keys_txn
)
class EndToEndKeyWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def get_e2e_device_keys(
self, query_list, include_all_devices=False,
@@ -238,6 +196,50 @@ class EndToEndKeyStore(SQLBaseStore):
"count_e2e_one_time_keys", _count_e2e_one_time_keys
)
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
"""Stores device keys for a device. Returns whether there was a change
or the keys were already in the database.
"""
def _set_e2e_device_keys_txn(txn):
old_key_json = self._simple_select_one_onecol_txn(
txn,
table="e2e_device_keys_json",
keyvalues={
"user_id": user_id,
"device_id": device_id,
},
retcol="key_json",
allow_none=True,
)
# In py3 we need old_key_json to match new_key_json type. The DB
# returns unicode while encode_canonical_json returns bytes.
new_key_json = encode_canonical_json(device_keys).decode("utf-8")
if old_key_json == new_key_json:
return False
self._simple_upsert_txn(
txn,
table="e2e_device_keys_json",
keyvalues={
"user_id": user_id,
"device_id": device_id,
},
values={
"ts_added_ms": time_now,
"key_json": new_key_json,
}
)
return True
return self.runInteraction(
"set_e2e_device_keys", _set_e2e_device_keys_txn
)
def claim_e2e_one_time_keys(self, query_list):
"""Take a list of one time keys out of the database"""
def _claim_e2e_one_time_keys(txn):
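set_e2e_device_keys above detects no-op key uploads by comparing canonical JSON, which is stable under dict ordering:

from canonicaljson import encode_canonical_json

# Canonical JSON sorts keys and strips whitespace, so two uploads of the
# same device keys serialise identically regardless of dict order.
a = encode_canonical_json({"algorithms": ["m.olm.v1"], "device_id": "DEVICE1"})
b = encode_canonical_json({"device_id": "DEVICE1", "algorithms": ["m.olm.v1"]})
assert a == b
assert a.decode("utf-8") == '{"algorithms":["m.olm.v1"],"device_id":"DEVICE1"}'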

View File

@@ -442,6 +442,28 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
event_results.reverse()
return event_results
@defer.inlineCallbacks
def get_successor_events(self, event_ids):
"""Fetch all events that have the given events as a prev event
Args:
event_ids (iterable[str])
Returns:
Deferred[list[str]]
"""
rows = yield self._simple_select_many_batch(
table="event_edges",
column="prev_event_id",
iterable=event_ids,
retcols=("event_id",),
desc="get_successor_events"
)
defer.returnValue([
row["event_id"] for row in rows
])
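A toy model of what get_successor_events computes over event_edges (in-memory pairs instead of SQL): given some events, return the events that list them as a prev event, which is what the backfill check uses to see whether the local DAG already continues forwards.

def get_successor_events(event_ids, edges):
    # edges: (event_id, prev_event_id) pairs, i.e. rows of event_edges.
    wanted = set(event_ids)
    return [event_id for event_id, prev_event_id in edges if prev_event_id in wanted]

# A DAG A <- B <- C (B lists A as a prev event, C lists B):
edges = [("$B", "$A"), ("$C", "$B")]
assert get_successor_events(["$A"], edges) == ["$B"]
assert get_successor_events(["$B"], edges) == ["$C"]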
class EventFederationStore(EventFederationWorkerStore):
""" Responsible for storing and serving up the various graphs associated

View File

@@ -346,15 +346,23 @@ class ReceiptsStore(ReceiptsWorkerStore):
def insert_linearized_receipt_txn(self, txn, room_id, receipt_type,
user_id, event_id, data, stream_id):
"""Inserts a read-receipt into the database if it's newer than the current RR
Returns: int|None
None if the RR is older than the current RR
otherwise, the rx timestamp of the event that the RR corresponds to
(or 0 if the event is unknown)
"""
res = self._simple_select_one_txn(
txn,
table="events",
retcols=["topological_ordering", "stream_ordering"],
retcols=["stream_ordering", "received_ts"],
keyvalues={"event_id": event_id},
allow_none=True
)
stream_ordering = int(res["stream_ordering"]) if res else None
rx_ts = res["received_ts"] if res else 0
# We don't want to clobber receipts for more recent events, so we
# have to compare orderings of existing receipts
@@ -373,7 +381,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
"one for later event %s",
event_id, eid,
)
return False
return None
txn.call_after(
self.get_receipts_for_room.invalidate, (room_id, receipt_type)
@@ -429,7 +437,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
stream_ordering=stream_ordering,
)
return True
return rx_ts
@defer.inlineCallbacks
def insert_receipt(self, room_id, receipt_type, user_id, event_ids, data):
@@ -466,7 +474,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
stream_id_manager = self._receipts_id_gen.get_next()
with stream_id_manager as stream_id:
have_persisted = yield self.runInteraction(
event_ts = yield self.runInteraction(
"insert_linearized_receipt",
self.insert_linearized_receipt_txn,
room_id, receipt_type, user_id, linearized_event_id,
@@ -474,8 +482,14 @@ class ReceiptsStore(ReceiptsWorkerStore):
stream_id=stream_id,
)
if not have_persisted:
defer.returnValue(None)
if event_ts is None:
defer.returnValue(None)
now = self._clock.time_msec()
logger.debug(
"RR for event %s in %s (%i ms old)",
linearized_event_id, room_id, now - event_ts,
)
yield self.insert_graph_receipt(
room_id, receipt_type, user_id, event_ids, data
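The new debug line simply reports how stale the receipt's target event is: the current clock minus the event's received_ts, in milliseconds (received_ts is treated as 0 when the event is unknown locally):

def receipt_age_ms(now_ms, event_received_ts_ms):
    # received_ts may be 0 for an unknown event, in which case the "age"
    # is just the current timestamp.
    return now_ms - (event_received_ts_ms or 0)

assert receipt_age_ms(1000500, 1000000) == 500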

View File

@@ -295,6 +295,39 @@ class RegistrationWorkerStore(SQLBaseStore):
return ret['user_id']
return None
@defer.inlineCallbacks
def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
yield self._simple_upsert("user_threepids", {
"medium": medium,
"address": address,
}, {
"user_id": user_id,
"validated_at": validated_at,
"added_at": added_at,
})
@defer.inlineCallbacks
def user_get_threepids(self, user_id):
ret = yield self._simple_select_list(
"user_threepids", {
"user_id": user_id
},
['medium', 'address', 'validated_at', 'added_at'],
'user_get_threepids'
)
defer.returnValue(ret)
def user_delete_threepid(self, user_id, medium, address):
return self._simple_delete(
"user_threepids",
keyvalues={
"user_id": user_id,
"medium": medium,
"address": address,
},
desc="user_delete_threepids",
)
class RegistrationStore(RegistrationWorkerStore,
background_updates.BackgroundUpdateStore):
@@ -632,39 +665,6 @@ class RegistrationStore(RegistrationWorkerStore,
defer.returnValue(res if res else False)
@defer.inlineCallbacks
def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
yield self._simple_upsert("user_threepids", {
"medium": medium,
"address": address,
}, {
"user_id": user_id,
"validated_at": validated_at,
"added_at": added_at,
})
@defer.inlineCallbacks
def user_get_threepids(self, user_id):
ret = yield self._simple_select_list(
"user_threepids", {
"user_id": user_id
},
['medium', 'address', 'validated_at', 'added_at'],
'user_get_threepids'
)
defer.returnValue(ret)
def user_delete_threepid(self, user_id, medium, address):
return self._simple_delete(
"user_threepids",
keyvalues={
"user_id": user_id,
"medium": medium,
"address": address,
},
desc="user_delete_threepids",
)
@defer.inlineCallbacks
def save_or_get_3pid_guest_access_token(
self, medium, address, access_token, inviter_user_id

View File

@@ -216,28 +216,36 @@ def filter_events_for_client(store, user_id, events, is_peeking=False,
@defer.inlineCallbacks
def filter_events_for_server(store, server_name, events):
# Whatever else we do, we need to check for senders which have requested
# erasure of their data.
erased_senders = yield store.are_users_erased(
(e.sender for e in events),
)
def filter_events_for_server(store, server_name, events, redact=True,
check_history_visibility_only=False):
"""Filter a list of events based on whether given server is allowed to
see them.
def redact_disallowed(event, state):
# if the sender has been gdpr17ed, always return a redacted
# copy of the event.
if erased_senders[event.sender]:
Args:
store (DataStore)
server_name (str)
events (iterable[FrozenEvent])
redact (bool): Whether to return a redacted version of any event that
would otherwise be filtered out, or to drop such events entirely.
check_history_visibility_only (bool): Whether to only check the
history visibility, rather than things like whether the sender has been
erased. This is used e.g. during pagination to decide whether to
backfill or not.
Returns
Deferred[list[FrozenEvent]]
"""
def is_sender_erased(event, erased_senders):
if erased_senders and erased_senders[event.sender]:
logger.info(
"Sender of %s has been erased, redacting",
event.event_id,
)
return prune_event(event)
# state will be None if we decided we didn't need to filter by
# room membership.
if not state:
return event
return True
return False
def check_event_is_visible(event, state):
history = state.get((EventTypes.RoomHistoryVisibility, ''), None)
if history:
visibility = history.content.get("history_visibility", "shared")
@@ -259,17 +267,17 @@ def filter_events_for_server(store, server_name, events):
memtype = ev.membership
if memtype == Membership.JOIN:
return event
return True
elif memtype == Membership.INVITE:
if visibility == "invited":
return event
return True
else:
# server has no users in the room: redact
return prune_event(event)
return False
return event
return True
# Next, let's check to see if all the events have a history visibility
# Let's check to see if all the events have a history visibility
# of "shared" or "world_readable". If that's the case then we don't
# need to check membership (as we know the server is in the room).
event_to_state_ids = yield store.get_state_ids_for_events(
@@ -296,16 +304,31 @@ def filter_events_for_server(store, server_name, events):
for e in itervalues(event_map)
)
if not check_history_visibility_only:
erased_senders = yield store.are_users_erased(
(e.sender for e in events),
)
else:
# We don't want to check whether users are erased, which is equivalent
# to no users having been erased.
erased_senders = {}
if all_open:
# all the history_visibility state affecting these events is open, so
# we don't need to filter by membership state. We *do* need to check
# for user erasure, though.
if erased_senders:
events = [
redact_disallowed(e, None)
for e in events
]
to_return = []
for e in events:
if not is_sender_erased(e, erased_senders):
to_return.append(e)
elif redact:
to_return.append(prune_event(e))
defer.returnValue(to_return)
# If there are no erased users then we can just return the given list
# of events without having to copy it.
defer.returnValue(events)
# Ok, so we're dealing with events that have non-trivial visibility
@@ -361,7 +384,13 @@ def filter_events_for_server(store, server_name, events):
for e_id, key_to_eid in iteritems(event_to_state_ids)
}
defer.returnValue([
redact_disallowed(e, event_to_state[e.event_id])
for e in events
])
to_return = []
for e in events:
erased = is_sender_erased(e, erased_senders)
visible = check_event_is_visible(e, event_to_state[e.event_id])
if visible and not erased:
to_return.append(e)
elif redact:
to_return.append(prune_event(e))
defer.returnValue(to_return)
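The combined effect of the erasure and history-visibility checks, including the new redact and check_history_visibility_only flags, can be summarised as a small decision function (toy sketch; the string prefix stands in for prune_event):

def filter_for_server(events, erased_senders, visible, redact=True,
                      check_history_visibility_only=False):
    # events: (event_id, sender) pairs; erased_senders: set of user ids;
    # visible: event_id -> bool result of the history-visibility check.
    to_return = []
    for event_id, sender in events:
        erased = (not check_history_visibility_only) and sender in erased_senders
        if visible[event_id] and not erased:
            to_return.append(event_id)
        elif redact:
            to_return.append("redacted:" + event_id)
    return to_return

events = [("$1", "@erased:hs"), ("$2", "@ok:hs")]
assert filter_for_server(events, {"@erased:hs"}, {"$1": True, "$2": True}) == \
    ["redacted:$1", "$2"]
# Backfill check: only history visibility matters, and nothing is redacted.
assert filter_for_server(events, {"@erased:hs"}, {"$1": True, "$2": False},
                         redact=False, check_history_visibility_only=True) == ["$1"]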

View File

@@ -24,13 +24,17 @@ from synapse.api.errors import AuthError
from synapse.types import UserID
from tests import unittest
from tests.utils import register_federation_servlets
from ..utils import (
DeferredMockCallable,
MockClock,
MockHttpResource,
setup_test_homeserver,
)
# Some local users to test with
U_APPLE = UserID.from_string("@apple:test")
U_BANANA = UserID.from_string("@banana:test")
# Remote user
U_ONION = UserID.from_string("@onion:farm")
# Test room id
ROOM_ID = "a-room"
def _expect_edu_transaction(edu_type, content, origin="test"):
@@ -46,30 +50,21 @@ def _make_edu_transaction_json(edu_type, content):
return json.dumps(_expect_edu_transaction(edu_type, content)).encode('utf8')
class TypingNotificationsTestCase(unittest.TestCase):
"""Tests typing notifications to rooms."""
class TypingNotificationsTestCase(unittest.HomeserverTestCase):
servlets = [register_federation_servlets]
@defer.inlineCallbacks
def setUp(self):
self.clock = MockClock()
def make_homeserver(self, reactor, clock):
# we mock out the keyring so as to skip the authentication check on the
# federation API call.
mock_keyring = Mock(spec=["verify_json_for_server"])
mock_keyring.verify_json_for_server.return_value = defer.succeed(True)
self.mock_http_client = Mock(spec=[])
self.mock_http_client.put_json = DeferredMockCallable()
# we mock out the federation client too
mock_federation_client = Mock(spec=["put_json"])
mock_federation_client.put_json.return_value = defer.succeed((200, "OK"))
self.mock_federation_resource = MockHttpResource()
mock_notifier = Mock()
self.on_new_event = mock_notifier.on_new_event
self.auth = Mock(spec=[])
self.state_handler = Mock()
hs = yield setup_test_homeserver(
self.addCleanup,
"test",
auth=self.auth,
clock=self.clock,
datastore=Mock(
hs = self.setup_test_homeserver(
datastore=(Mock(
spec=[
# Bits that Federation needs
"prep_send_transaction",
@@ -82,16 +77,21 @@ class TypingNotificationsTestCase(unittest.TestCase):
"get_user_directory_stream_pos",
"get_current_state_deltas",
]
),
state_handler=self.state_handler,
handlers=Mock(),
notifier=mock_notifier,
resource_for_client=Mock(),
resource_for_federation=self.mock_federation_resource,
http_client=self.mock_http_client,
keyring=Mock(),
)),
notifier=Mock(),
http_client=mock_federation_client,
keyring=mock_keyring,
)
return hs
def prepare(self, reactor, clock, hs):
# the tests assume that we are starting at unix time 1000
reactor.pump((1000, ))
mock_notifier = hs.get_notifier()
self.on_new_event = mock_notifier.on_new_event
self.handler = hs.get_typing_handler()
self.event_source = hs.get_event_sources().sources["typing"]
@@ -109,13 +109,12 @@ class TypingNotificationsTestCase(unittest.TestCase):
self.datastore.get_received_txn_response = get_received_txn_response
self.room_id = "a-room"
self.room_members = []
def check_joined_room(room_id, user_id):
if user_id not in [u.to_string() for u in self.room_members]:
raise AuthError(401, "User is not in the room")
hs.get_auth().check_joined_room = check_joined_room
def get_joined_hosts_for_room(room_id):
return set(member.domain for member in self.room_members)
@@ -124,8 +123,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
def get_current_user_in_room(room_id):
return set(str(u) for u in self.room_members)
self.state_handler.get_current_user_in_room = get_current_user_in_room
hs.get_state_handler().get_current_user_in_room = get_current_user_in_room
self.datastore.get_user_directory_stream_pos.return_value = (
# we deliberately return a non-None stream pos to avoid doing an initial_spam
@@ -134,230 +132,208 @@ class TypingNotificationsTestCase(unittest.TestCase):
self.datastore.get_current_state_deltas.return_value = None
self.auth.check_joined_room = check_joined_room
self.datastore.get_to_device_stream_token = lambda: 0
self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: ([], 0)
self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None
# Some local users to test with
self.u_apple = UserID.from_string("@apple:test")
self.u_banana = UserID.from_string("@banana:test")
# Remote user
self.u_onion = UserID.from_string("@onion:farm")
@defer.inlineCallbacks
def test_started_typing_local(self):
self.room_members = [self.u_apple, self.u_banana]
self.room_members = [U_APPLE, U_BANANA]
self.assertEquals(self.event_source.get_current_key(), 0)
yield self.handler.started_typing(
target_user=self.u_apple,
auth_user=self.u_apple,
room_id=self.room_id,
self.successResultOf(self.handler.started_typing(
target_user=U_APPLE,
auth_user=U_APPLE,
room_id=ROOM_ID,
timeout=20000,
)
))
self.on_new_event.assert_has_calls(
[call('typing_key', 1, rooms=[self.room_id])]
[call('typing_key', 1, rooms=[ROOM_ID])]
)
self.assertEquals(self.event_source.get_current_key(), 1)
events = yield self.event_source.get_new_events(
room_ids=[self.room_id], from_key=0
events = self.event_source.get_new_events(
room_ids=[ROOM_ID], from_key=0
)
self.assertEquals(
events[0],
[
{
"type": "m.typing",
"room_id": self.room_id,
"content": {"user_ids": [self.u_apple.to_string()]},
"room_id": ROOM_ID,
"content": {"user_ids": [U_APPLE.to_string()]},
}
],
)
@defer.inlineCallbacks
def test_started_typing_remote_send(self):
self.room_members = [self.u_apple, self.u_onion]
self.room_members = [U_APPLE, U_ONION]
put_json = self.mock_http_client.put_json
put_json.expect_call_and_return(
call(
"farm",
path="/_matrix/federation/v1/send/1000000/",
data=_expect_edu_transaction(
"m.typing",
content={
"room_id": self.room_id,
"user_id": self.u_apple.to_string(),
"typing": True,
},
),
json_data_callback=ANY,
long_retries=True,
backoff_on_404=True,
),
defer.succeed((200, "OK")),
)
yield self.handler.started_typing(
target_user=self.u_apple,
auth_user=self.u_apple,
room_id=self.room_id,
self.successResultOf(self.handler.started_typing(
target_user=U_APPLE,
auth_user=U_APPLE,
room_id=ROOM_ID,
timeout=20000,
))
put_json = self.hs.get_http_client().put_json
put_json.assert_called_once_with(
"farm",
path="/_matrix/federation/v1/send/1000000/",
data=_expect_edu_transaction(
"m.typing",
content={
"room_id": ROOM_ID,
"user_id": U_APPLE.to_string(),
"typing": True,
},
),
json_data_callback=ANY,
long_retries=True,
backoff_on_404=True,
)
yield put_json.await_calls()
@defer.inlineCallbacks
def test_started_typing_remote_recv(self):
self.room_members = [self.u_apple, self.u_onion]
self.room_members = [U_APPLE, U_ONION]
self.assertEquals(self.event_source.get_current_key(), 0)
(code, response) = yield self.mock_federation_resource.trigger(
(request, channel) = self.make_request(
"PUT",
"/_matrix/federation/v1/send/1000000/",
_make_edu_transaction_json(
"m.typing",
content={
"room_id": self.room_id,
"user_id": self.u_onion.to_string(),
"room_id": ROOM_ID,
"user_id": U_ONION.to_string(),
"typing": True,
},
),
federation_auth_origin=b'farm',
)
self.render(request)
self.assertEqual(channel.code, 200)
self.on_new_event.assert_has_calls(
[call('typing_key', 1, rooms=[self.room_id])]
[call('typing_key', 1, rooms=[ROOM_ID])]
)
self.assertEquals(self.event_source.get_current_key(), 1)
events = yield self.event_source.get_new_events(
room_ids=[self.room_id], from_key=0
events = self.event_source.get_new_events(
room_ids=[ROOM_ID], from_key=0
)
self.assertEquals(
events[0],
[
{
"type": "m.typing",
"room_id": self.room_id,
"content": {"user_ids": [self.u_onion.to_string()]},
"room_id": ROOM_ID,
"content": {"user_ids": [U_ONION.to_string()]},
}
],
)
@defer.inlineCallbacks
def test_stopped_typing(self):
self.room_members = [self.u_apple, self.u_banana, self.u_onion]
put_json = self.mock_http_client.put_json
put_json.expect_call_and_return(
call(
"farm",
path="/_matrix/federation/v1/send/1000000/",
data=_expect_edu_transaction(
"m.typing",
content={
"room_id": self.room_id,
"user_id": self.u_apple.to_string(),
"typing": False,
},
),
json_data_callback=ANY,
long_retries=True,
backoff_on_404=True,
),
defer.succeed((200, "OK")),
)
self.room_members = [U_APPLE, U_BANANA, U_ONION]
# Gut-wrenching
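# (reach directly into the typing handler's internal state so that there is an
# active "apple is typing" notification for stopped_typing() to clear)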
from synapse.handlers.typing import RoomMember
member = RoomMember(self.room_id, self.u_apple.to_string())
member = RoomMember(ROOM_ID, U_APPLE.to_string())
self.handler._member_typing_until[member] = 1002000
self.handler._room_typing[self.room_id] = set([self.u_apple.to_string()])
self.handler._room_typing[ROOM_ID] = set([U_APPLE.to_string()])
self.assertEquals(self.event_source.get_current_key(), 0)
yield self.handler.stopped_typing(
target_user=self.u_apple, auth_user=self.u_apple, room_id=self.room_id
)
self.successResultOf(self.handler.stopped_typing(
target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID
))
self.on_new_event.assert_has_calls(
[call('typing_key', 1, rooms=[self.room_id])]
[call('typing_key', 1, rooms=[ROOM_ID])]
)
yield put_json.await_calls()
put_json = self.hs.get_http_client().put_json
put_json.assert_called_once_with(
"farm",
path="/_matrix/federation/v1/send/1000000/",
data=_expect_edu_transaction(
"m.typing",
content={
"room_id": ROOM_ID,
"user_id": U_APPLE.to_string(),
"typing": False,
},
),
json_data_callback=ANY,
long_retries=True,
backoff_on_404=True,
)
self.assertEquals(self.event_source.get_current_key(), 1)
events = yield self.event_source.get_new_events(
room_ids=[self.room_id], from_key=0
events = self.event_source.get_new_events(
room_ids=[ROOM_ID], from_key=0
)
self.assertEquals(
events[0],
[
{
"type": "m.typing",
"room_id": self.room_id,
"room_id": ROOM_ID,
"content": {"user_ids": []},
}
],
)
@defer.inlineCallbacks
def test_typing_timeout(self):
self.room_members = [self.u_apple, self.u_banana]
self.room_members = [U_APPLE, U_BANANA]
self.assertEquals(self.event_source.get_current_key(), 0)
yield self.handler.started_typing(
target_user=self.u_apple,
auth_user=self.u_apple,
room_id=self.room_id,
self.successResultOf(self.handler.started_typing(
target_user=U_APPLE,
auth_user=U_APPLE,
room_id=ROOM_ID,
timeout=10000,
)
))
self.on_new_event.assert_has_calls(
[call('typing_key', 1, rooms=[self.room_id])]
[call('typing_key', 1, rooms=[ROOM_ID])]
)
self.on_new_event.reset_mock()
self.assertEquals(self.event_source.get_current_key(), 1)
events = yield self.event_source.get_new_events(
room_ids=[self.room_id], from_key=0
events = self.event_source.get_new_events(
room_ids=[ROOM_ID], from_key=0
)
self.assertEquals(
events[0],
[
{
"type": "m.typing",
"room_id": self.room_id,
"content": {"user_ids": [self.u_apple.to_string()]},
"room_id": ROOM_ID,
"content": {"user_ids": [U_APPLE.to_string()]},
}
],
)
self.clock.advance_time(16)
self.reactor.pump([16, ])
self.on_new_event.assert_has_calls(
[call('typing_key', 2, rooms=[self.room_id])]
[call('typing_key', 2, rooms=[ROOM_ID])]
)
self.assertEquals(self.event_source.get_current_key(), 2)
events = yield self.event_source.get_new_events(
room_ids=[self.room_id], from_key=1
events = self.event_source.get_new_events(
room_ids=[ROOM_ID], from_key=1
)
self.assertEquals(
events[0],
[
{
"type": "m.typing",
"room_id": self.room_id,
"room_id": ROOM_ID,
"content": {"user_ids": []},
}
],
@@ -365,29 +341,29 @@ class TypingNotificationsTestCase(unittest.TestCase):
# SYN-230 - see if we can still set after timeout
yield self.handler.started_typing(
target_user=self.u_apple,
auth_user=self.u_apple,
room_id=self.room_id,
self.successResultOf(self.handler.started_typing(
target_user=U_APPLE,
auth_user=U_APPLE,
room_id=ROOM_ID,
timeout=10000,
)
))
self.on_new_event.assert_has_calls(
[call('typing_key', 3, rooms=[self.room_id])]
[call('typing_key', 3, rooms=[ROOM_ID])]
)
self.on_new_event.reset_mock()
self.assertEquals(self.event_source.get_current_key(), 3)
events = yield self.event_source.get_new_events(
room_ids=[self.room_id], from_key=0
events = self.event_source.get_new_events(
room_ids=[ROOM_ID], from_key=0
)
self.assertEquals(
events[0],
[
{
"type": "m.typing",
"room_id": self.room_id,
"content": {"user_ids": [self.u_apple.to_string()]},
"room_id": ROOM_ID,
"content": {"user_ids": [U_APPLE.to_string()]},
}
],
)

View File

@@ -0,0 +1,45 @@
# -*- coding: utf-8 -*-
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.rest.media.v1._base import get_filename_from_headers
from tests import unittest
class GetFileNameFromHeadersTests(unittest.TestCase):
# input -> expected result
TEST_CASES = {
b"inline; filename=abc.txt": u"abc.txt",
b'inline; filename="azerty"': u"azerty",
b'inline; filename="aze%20rty"': u"aze%20rty",
b'inline; filename="aze\"rty"': u'aze"rty',
b'inline; filename="azer;ty"': u"azer;ty",
b"inline; filename*=utf-8''foo%C2%A3bar": u"foo£bar",
}
def tests(self):
for hdr, expected in self.TEST_CASES.items():
res = get_filename_from_headers(
{
b'Content-Disposition': [hdr],
},
)
self.assertEqual(
res, expected,
"expected output for %s to be %s but was %s" % (
hdr, expected, res,
)
)
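(Outside the test loop the helper is called the same way; a minimal sketch reusing one of the cases from the table above, assuming nothing beyond it:)

    filename = get_filename_from_headers(
        {b"Content-Disposition": [b"inline; filename*=utf-8''foo%C2%A3bar"]},
    )
    # per the table above, the RFC 5987 filename* form decodes to u"foo£bar"
    assert filename == u"foo\u00a3bar"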

View File

@@ -137,6 +137,7 @@ def make_request(
access_token=None,
request=SynapseRequest,
shorthand=True,
federation_auth_origin=None,
):
"""
Make a web request using the given method and path, feed it the
@@ -150,9 +151,11 @@ def make_request(
a dict.
shorthand: Whether to try and be helpful and prefix the given URL
with the usual REST API path, if it doesn't contain it.
federation_auth_origin (bytes|None): if not None, we will add a fake
Authorization header pretending to be the given server name.
Returns:
A synapse.http.site.SynapseRequest.
Tuple[synapse.http.site.SynapseRequest, channel]
"""
if not isinstance(method, bytes):
method = method.encode('ascii')
@@ -184,6 +187,11 @@ def make_request(
b"Authorization", b"Bearer " + access_token.encode('ascii')
)
if federation_auth_origin is not None:
req.requestHeaders.addRawHeader(
b"Authorization", b"X-Matrix origin=%s,key=,sig=" % (federation_auth_origin,)
)
if content:
req.requestHeaders.addRawHeader(b"Content-Type", b"application/json")
@@ -288,9 +296,6 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
**kwargs
)
pool.runWithConnection = runWithConnection
pool.runInteraction = runInteraction
class ThreadPool:
"""
Threadless thread pool.
@@ -316,8 +321,12 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
return d
clock.threadpool = ThreadPool()
pool.threadpool = ThreadPool()
pool.running = True
if pool:
pool.runWithConnection = runWithConnection
pool.runInteraction = runInteraction
pool.threadpool = ThreadPool()
pool.running = True
return d

View File

@@ -262,6 +262,7 @@ class HomeserverTestCase(TestCase):
access_token=None,
request=SynapseRequest,
shorthand=True,
federation_auth_origin=None,
):
"""
Create a SynapseRequest at the path using the method and containing the
@@ -275,15 +276,18 @@ class HomeserverTestCase(TestCase):
a dict.
shorthand: Whether to try and be helpful and prefix the given URL
with the usual REST API path, if it doesn't contain it.
federation_auth_origin (bytes|None): if not None, we will add a fake
Authorization header pretending to be the given server name.
Returns:
A synapse.http.site.SynapseRequest.
Tuple[synapse.http.site.SynapseRequest, channel]
"""
if isinstance(content, dict):
content = json.dumps(content).encode('utf8')
return make_request(
self.reactor, method, path, content, access_token, request, shorthand
self.reactor, method, path, content, access_token, request, shorthand,
federation_auth_origin,
)
def render(self, request):

View File

@@ -29,7 +29,7 @@ from twisted.internet import defer, reactor
from synapse.api.constants import EventTypes, RoomVersions
from synapse.api.errors import CodeMessageException, cs_error
from synapse.config.server import ServerConfig
from synapse.federation.transport import server
from synapse.federation.transport import server as federation_server
from synapse.http.server import HttpServer
from synapse.server import HomeServer
from synapse.storage import DataStore
@@ -45,7 +45,9 @@ from synapse.util.ratelimitutils import FederationRateLimiter
# set this to True to run the tests against postgres instead of sqlite.
USE_POSTGRES_FOR_TESTS = os.environ.get("SYNAPSE_POSTGRES", False)
LEAVE_DB = os.environ.get("SYNAPSE_LEAVE_DB", False)
POSTGRES_USER = os.environ.get("SYNAPSE_POSTGRES_USER", "postgres")
POSTGRES_USER = os.environ.get("SYNAPSE_POSTGRES_USER", None)
POSTGRES_HOST = os.environ.get("SYNAPSE_POSTGRES_HOST", None)
POSTGRES_PASSWORD = os.environ.get("SYNAPSE_POSTGRES_PASSWORD", None)
POSTGRES_BASE_DB = "_synapse_unit_tests_base_%s" % (os.getpid(),)
@@ -58,6 +60,8 @@ def setupdb():
"args": {
"database": POSTGRES_BASE_DB,
"user": POSTGRES_USER,
"host": POSTGRES_HOST,
"password": POSTGRES_PASSWORD,
"cp_min": 1,
"cp_max": 5,
},
@@ -66,7 +70,9 @@ def setupdb():
config.password_providers = []
config.database_config = pgconfig
db_engine = create_engine(pgconfig)
db_conn = db_engine.module.connect(user=POSTGRES_USER)
db_conn = db_engine.module.connect(
user=POSTGRES_USER, host=POSTGRES_HOST, password=POSTGRES_PASSWORD
)
db_conn.autocommit = True
cur = db_conn.cursor()
cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,))
@@ -76,7 +82,10 @@ def setupdb():
# Set up in the db
db_conn = db_engine.module.connect(
database=POSTGRES_BASE_DB, user=POSTGRES_USER
database=POSTGRES_BASE_DB,
user=POSTGRES_USER,
host=POSTGRES_HOST,
password=POSTGRES_PASSWORD,
)
cur = db_conn.cursor()
_get_or_create_schema_state(cur, db_engine)
@@ -86,7 +95,9 @@ def setupdb():
db_conn.close()
def _cleanup():
db_conn = db_engine.module.connect(user=POSTGRES_USER)
db_conn = db_engine.module.connect(
user=POSTGRES_USER, host=POSTGRES_HOST, password=POSTGRES_PASSWORD
)
db_conn.autocommit = True
cur = db_conn.cursor()
cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,))
@@ -142,6 +153,9 @@ def default_config(name):
config.saml2_enabled = False
config.public_baseurl = None
config.default_identity_server = None
config.key_refresh_interval = 24 * 60 * 60 * 1000
config.old_signing_keys = {}
config.tls_fingerprints = []
config.use_frozen_dicts = False
@@ -186,6 +200,9 @@ def setup_test_homeserver(
Args:
cleanup_func : The function used to register a cleanup routine for
after the test.
Calling this method directly is deprecated: you should instead derive from
HomeserverTestCase.
"""
if reactor is None:
from twisted.internet import reactor
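(The deprecation note above points at HomeserverTestCase; a minimal sketch of that pattern, using only make_request/render/channel as shown in the earlier hunks of this comparison. The endpoint and the expected status code are assumptions, not something this diff registers.)

    class ExampleTestCase(HomeserverTestCase):
        def test_example_request(self):
            # make_request now returns (request, channel), per the updated docstring
            request, channel = self.make_request("GET", "/_matrix/client/versions")
            self.render(request)
            # assumed: a servlet serving this path has been registered by the test case
            self.assertEqual(channel.code, 200)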
@@ -203,7 +220,14 @@ def setup_test_homeserver(
config.database_config = {
"name": "psycopg2",
"args": {"database": test_db, "cp_min": 1, "cp_max": 5},
"args": {
"database": test_db,
"host": POSTGRES_HOST,
"password": POSTGRES_PASSWORD,
"user": POSTGRES_USER,
"cp_min": 1,
"cp_max": 5,
},
}
else:
config.database_config = {
@@ -217,7 +241,10 @@ def setup_test_homeserver(
# the template database we generate in setupdb()
if datastore is None and isinstance(db_engine, PostgresEngine):
db_conn = db_engine.module.connect(
database=POSTGRES_BASE_DB, user=POSTGRES_USER
database=POSTGRES_BASE_DB,
user=POSTGRES_USER,
host=POSTGRES_HOST,
password=POSTGRES_PASSWORD,
)
db_conn.autocommit = True
cur = db_conn.cursor()
@@ -267,7 +294,10 @@ def setup_test_homeserver(
# Drop the test database
db_conn = db_engine.module.connect(
database=POSTGRES_BASE_DB, user=POSTGRES_USER
database=POSTGRES_BASE_DB,
user=POSTGRES_USER,
host=POSTGRES_HOST,
password=POSTGRES_PASSWORD,
)
db_conn.autocommit = True
cur = db_conn.cursor()
@@ -324,23 +354,27 @@ def setup_test_homeserver(
fed = kargs.get("resource_for_federation", None)
if fed:
server.register_servlets(
hs,
resource=fed,
authenticator=server.Authenticator(hs),
ratelimiter=FederationRateLimiter(
hs.get_clock(),
window_size=hs.config.federation_rc_window_size,
sleep_limit=hs.config.federation_rc_sleep_limit,
sleep_msec=hs.config.federation_rc_sleep_delay,
reject_limit=hs.config.federation_rc_reject_limit,
concurrent_requests=hs.config.federation_rc_concurrent,
),
)
register_federation_servlets(hs, fed)
defer.returnValue(hs)
def register_federation_servlets(hs, resource):
federation_server.register_servlets(
hs,
resource=resource,
authenticator=federation_server.Authenticator(hs),
ratelimiter=FederationRateLimiter(
hs.get_clock(),
window_size=hs.config.federation_rc_window_size,
sleep_limit=hs.config.federation_rc_sleep_limit,
sleep_msec=hs.config.federation_rc_sleep_delay,
reject_limit=hs.config.federation_rc_reject_limit,
concurrent_requests=hs.config.federation_rc_concurrent,
),
)
def get_mock_call_args(pattern_func, mock_func):
""" Return the arguments the mock function was called with interpreted
by the pattern functions argument list.
@@ -457,6 +491,9 @@ class MockKey(object):
def verify(self, message, sig):
assert sig == b"\x9a\x87$"
def encode(self):
return b"<fake_encoded_key>"
class MockClock(object):
now = 1000
@@ -486,7 +523,7 @@ class MockClock(object):
return t
def looping_call(self, function, interval):
self.loopers.append([function, interval / 1000., self.now])
self.loopers.append([function, interval / 1000.0, self.now])
def cancel_call_later(self, timer, ignore_errs=False):
if timer[2]:
@@ -522,7 +559,7 @@ class MockClock(object):
looped[2] = self.now
def advance_time_msec(self, ms):
self.advance_time(ms / 1000.)
self.advance_time(ms / 1000.0)
def time_bound_deferred(self, d, *args, **kwargs):
# We don't bother timing things out for now.
@@ -631,7 +668,7 @@ def create_room(hs, room_id, creator_id):
"sender": creator_id,
"room_id": room_id,
"content": {},
}
},
)
event, context = yield event_creation_handler.create_new_client_event(builder)

View File

@@ -118,6 +118,9 @@ commands =
python -m towncrier.check --compare-with=origin/develop
basepython = python3.6
[testenv:check-sampleconfig]
commands = {toxinidir}/scripts-dev/generate_sample_config --check
[testenv:codecov]
skip_install = True
deps =