Compare commits
50 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2ca8af3f06 | ||
|
|
b050a10871 | ||
|
|
9e8bca5667 | ||
|
|
aa06d26ae0 | ||
|
|
c3c542bb4a | ||
|
|
48583cef7e | ||
|
|
cd7110c869 | ||
|
|
eaa9f43603 | ||
|
|
c7325776a7 | ||
|
|
00b0e8b7df | ||
|
|
bfa7d46a10 | ||
|
|
157e5a8f27 | ||
|
|
daa10e3e66 | ||
|
|
2db49ea476 | ||
|
|
b29693a30b | ||
|
|
a84b8d56c2 | ||
|
|
8e28bc5eee | ||
|
|
0d2d046709 | ||
|
|
d1523aed6b | ||
|
|
aba5eeabd5 | ||
|
|
856c83f5f8 | ||
|
|
8b63fe4c26 | ||
|
|
fbc047f2a5 | ||
|
|
2c3548d9d8 | ||
|
|
3064952939 | ||
|
|
1beebe916f | ||
|
|
ac61b45a75 | ||
|
|
b131cc77df | ||
|
|
68f47d6744 | ||
|
|
f2a753ea38 | ||
|
|
76550c58d2 | ||
|
|
8267034a63 | ||
|
|
3134964054 | ||
|
|
46b0151524 | ||
|
|
95840d84d4 | ||
|
|
54f9ce11a7 | ||
|
|
d4dc527a1a | ||
|
|
1b2940b3bd | ||
|
|
1e315017d3 | ||
|
|
b5c13df0c4 | ||
|
|
4cff9376f7 | ||
|
|
71ef5fc411 | ||
|
|
b183fef9ac | ||
|
|
7590e9fa28 | ||
|
|
6870fc496f | ||
|
|
09fc34c935 | ||
|
|
25814921f1 | ||
|
|
313987187e | ||
|
|
56f4ece778 | ||
|
|
71b625d808 |
13
.buildkite/.env
Normal file
13
.buildkite/.env
Normal file
@@ -0,0 +1,13 @@
|
||||
CI
|
||||
BUILDKITE
|
||||
BUILDKITE_BUILD_NUMBER
|
||||
BUILDKITE_BRANCH
|
||||
BUILDKITE_BUILD_NUMBER
|
||||
BUILDKITE_JOB_ID
|
||||
BUILDKITE_BUILD_URL
|
||||
BUILDKITE_PROJECT_SLUG
|
||||
BUILDKITE_COMMIT
|
||||
BUILDKITE_PULL_REQUEST
|
||||
BUILDKITE_TAG
|
||||
CODECOV_TOKEN
|
||||
TRIAL_FLAGS
|
||||
21
.buildkite/docker-compose.py27.pg94.yaml
Normal file
21
.buildkite/docker-compose.py27.pg94.yaml
Normal file
@@ -0,0 +1,21 @@
|
||||
version: '3.1'
|
||||
|
||||
services:
|
||||
|
||||
postgres:
|
||||
image: postgres:9.4
|
||||
environment:
|
||||
POSTGRES_PASSWORD: postgres
|
||||
|
||||
testenv:
|
||||
image: python:2.7
|
||||
depends_on:
|
||||
- postgres
|
||||
env_file: .env
|
||||
environment:
|
||||
SYNAPSE_POSTGRES_HOST: postgres
|
||||
SYNAPSE_POSTGRES_USER: postgres
|
||||
SYNAPSE_POSTGRES_PASSWORD: postgres
|
||||
working_dir: /app
|
||||
volumes:
|
||||
- ..:/app
|
||||
21
.buildkite/docker-compose.py27.pg95.yaml
Normal file
21
.buildkite/docker-compose.py27.pg95.yaml
Normal file
@@ -0,0 +1,21 @@
|
||||
version: '3.1'
|
||||
|
||||
services:
|
||||
|
||||
postgres:
|
||||
image: postgres:9.5
|
||||
environment:
|
||||
POSTGRES_PASSWORD: postgres
|
||||
|
||||
testenv:
|
||||
image: python:2.7
|
||||
depends_on:
|
||||
- postgres
|
||||
env_file: .env
|
||||
environment:
|
||||
SYNAPSE_POSTGRES_HOST: postgres
|
||||
SYNAPSE_POSTGRES_USER: postgres
|
||||
SYNAPSE_POSTGRES_PASSWORD: postgres
|
||||
working_dir: /app
|
||||
volumes:
|
||||
- ..:/app
|
||||
21
.buildkite/docker-compose.py35.pg94.yaml
Normal file
21
.buildkite/docker-compose.py35.pg94.yaml
Normal file
@@ -0,0 +1,21 @@
|
||||
version: '3.1'
|
||||
|
||||
services:
|
||||
|
||||
postgres:
|
||||
image: postgres:9.4
|
||||
environment:
|
||||
POSTGRES_PASSWORD: postgres
|
||||
|
||||
testenv:
|
||||
image: python:3.5
|
||||
depends_on:
|
||||
- postgres
|
||||
env_file: .env
|
||||
environment:
|
||||
SYNAPSE_POSTGRES_HOST: postgres
|
||||
SYNAPSE_POSTGRES_USER: postgres
|
||||
SYNAPSE_POSTGRES_PASSWORD: postgres
|
||||
working_dir: /app
|
||||
volumes:
|
||||
- ..:/app
|
||||
21
.buildkite/docker-compose.py35.pg95.yaml
Normal file
21
.buildkite/docker-compose.py35.pg95.yaml
Normal file
@@ -0,0 +1,21 @@
|
||||
version: '3.1'
|
||||
|
||||
services:
|
||||
|
||||
postgres:
|
||||
image: postgres:9.5
|
||||
environment:
|
||||
POSTGRES_PASSWORD: postgres
|
||||
|
||||
testenv:
|
||||
image: python:3.5
|
||||
depends_on:
|
||||
- postgres
|
||||
env_file: .env
|
||||
environment:
|
||||
SYNAPSE_POSTGRES_HOST: postgres
|
||||
SYNAPSE_POSTGRES_USER: postgres
|
||||
SYNAPSE_POSTGRES_PASSWORD: postgres
|
||||
working_dir: /app
|
||||
volumes:
|
||||
- ..:/app
|
||||
21
.buildkite/docker-compose.py37.pg11.yaml
Normal file
21
.buildkite/docker-compose.py37.pg11.yaml
Normal file
@@ -0,0 +1,21 @@
|
||||
version: '3.1'
|
||||
|
||||
services:
|
||||
|
||||
postgres:
|
||||
image: postgres:11
|
||||
environment:
|
||||
POSTGRES_PASSWORD: postgres
|
||||
|
||||
testenv:
|
||||
image: python:3.7
|
||||
depends_on:
|
||||
- postgres
|
||||
env_file: .env
|
||||
environment:
|
||||
SYNAPSE_POSTGRES_HOST: postgres
|
||||
SYNAPSE_POSTGRES_USER: postgres
|
||||
SYNAPSE_POSTGRES_PASSWORD: postgres
|
||||
working_dir: /app
|
||||
volumes:
|
||||
- ..:/app
|
||||
21
.buildkite/docker-compose.py37.pg95.yaml
Normal file
21
.buildkite/docker-compose.py37.pg95.yaml
Normal file
@@ -0,0 +1,21 @@
|
||||
version: '3.1'
|
||||
|
||||
services:
|
||||
|
||||
postgres:
|
||||
image: postgres:9.5
|
||||
environment:
|
||||
POSTGRES_PASSWORD: postgres
|
||||
|
||||
testenv:
|
||||
image: python:3.7
|
||||
depends_on:
|
||||
- postgres
|
||||
env_file: .env
|
||||
environment:
|
||||
SYNAPSE_POSTGRES_HOST: postgres
|
||||
SYNAPSE_POSTGRES_USER: postgres
|
||||
SYNAPSE_POSTGRES_PASSWORD: postgres
|
||||
working_dir: /app
|
||||
volumes:
|
||||
- ..:/app
|
||||
157
.buildkite/pipeline.yml
Normal file
157
.buildkite/pipeline.yml
Normal file
@@ -0,0 +1,157 @@
|
||||
env:
|
||||
CODECOV_TOKEN: "2dd7eb9b-0eda-45fe-a47c-9b5ac040045f"
|
||||
|
||||
steps:
|
||||
- command:
|
||||
- "python -m pip install tox"
|
||||
- "tox -e pep8"
|
||||
label: "\U0001F9F9 PEP-8"
|
||||
plugins:
|
||||
- docker#v3.0.1:
|
||||
image: "python:3.6"
|
||||
|
||||
- command:
|
||||
- "python -m pip install tox"
|
||||
- "tox -e packaging"
|
||||
label: "\U0001F9F9 packaging"
|
||||
plugins:
|
||||
- docker#v3.0.1:
|
||||
image: "python:3.6"
|
||||
|
||||
- command:
|
||||
- "python -m pip install tox"
|
||||
- "tox -e check_isort"
|
||||
label: "\U0001F9F9 isort"
|
||||
plugins:
|
||||
- docker#v3.0.1:
|
||||
image: "python:3.6"
|
||||
|
||||
- command:
|
||||
- "python -m pip install tox"
|
||||
- "scripts-dev/check-newsfragment"
|
||||
label: ":newspaper: Newsfile"
|
||||
branches: "!master !develop !release-*"
|
||||
plugins:
|
||||
- docker#v3.0.1:
|
||||
image: "python:3.6"
|
||||
propagate-environment: true
|
||||
|
||||
- wait
|
||||
|
||||
- command:
|
||||
- "python -m pip install tox"
|
||||
- "tox -e check-sampleconfig"
|
||||
label: "\U0001F9F9 check-sample-config"
|
||||
plugins:
|
||||
- docker#v3.0.1:
|
||||
image: "python:3.6"
|
||||
|
||||
- command:
|
||||
- "python -m pip install tox"
|
||||
- "tox -e py27,codecov"
|
||||
label: ":python: 2.7 / SQLite"
|
||||
env:
|
||||
TRIAL_FLAGS: "-j 2"
|
||||
plugins:
|
||||
- docker#v3.0.1:
|
||||
image: "python:2.7"
|
||||
propagate-environment: true
|
||||
|
||||
- command:
|
||||
- "python -m pip install tox"
|
||||
- "tox -e py35,codecov"
|
||||
label: ":python: 3.5 / SQLite"
|
||||
env:
|
||||
TRIAL_FLAGS: "-j 2"
|
||||
plugins:
|
||||
- docker#v3.0.1:
|
||||
image: "python:3.5"
|
||||
propagate-environment: true
|
||||
|
||||
- command:
|
||||
- "python -m pip install tox"
|
||||
- "tox -e py36,codecov"
|
||||
label: ":python: 3.6 / SQLite"
|
||||
env:
|
||||
TRIAL_FLAGS: "-j 2"
|
||||
plugins:
|
||||
- docker#v3.0.1:
|
||||
image: "python:3.6"
|
||||
propagate-environment: true
|
||||
|
||||
- command:
|
||||
- "python -m pip install tox"
|
||||
- "tox -e py37,codecov"
|
||||
label: ":python: 3.7 / SQLite"
|
||||
env:
|
||||
TRIAL_FLAGS: "-j 2"
|
||||
plugins:
|
||||
- docker#v3.0.1:
|
||||
image: "python:3.7"
|
||||
propagate-environment: true
|
||||
|
||||
- label: ":python: 2.7 / :postgres: 9.4"
|
||||
env:
|
||||
TRIAL_FLAGS: "-j 4"
|
||||
command:
|
||||
- "bash -c 'python -m pip install tox && python -m tox -e py27-postgres,codecov'"
|
||||
plugins:
|
||||
- docker-compose#v2.1.0:
|
||||
run: testenv
|
||||
config:
|
||||
- .buildkite/docker-compose.py27.pg94.yaml
|
||||
|
||||
- label: ":python: 2.7 / :postgres: 9.5"
|
||||
env:
|
||||
TRIAL_FLAGS: "-j 4"
|
||||
command:
|
||||
- "bash -c 'python -m pip install tox && python -m tox -e py27-postgres,codecov'"
|
||||
plugins:
|
||||
- docker-compose#v2.1.0:
|
||||
run: testenv
|
||||
config:
|
||||
- .buildkite/docker-compose.py27.pg95.yaml
|
||||
|
||||
- label: ":python: 3.5 / :postgres: 9.4"
|
||||
env:
|
||||
TRIAL_FLAGS: "-j 4"
|
||||
command:
|
||||
- "bash -c 'python -m pip install tox && python -m tox -e py35-postgres,codecov'"
|
||||
plugins:
|
||||
- docker-compose#v2.1.0:
|
||||
run: testenv
|
||||
config:
|
||||
- .buildkite/docker-compose.py35.pg94.yaml
|
||||
|
||||
- label: ":python: 3.5 / :postgres: 9.5"
|
||||
env:
|
||||
TRIAL_FLAGS: "-j 4"
|
||||
command:
|
||||
- "bash -c 'python -m pip install tox && python -m tox -e py35-postgres,codecov'"
|
||||
plugins:
|
||||
- docker-compose#v2.1.0:
|
||||
run: testenv
|
||||
config:
|
||||
- .buildkite/docker-compose.py35.pg95.yaml
|
||||
|
||||
- label: ":python: 3.7 / :postgres: 9.5"
|
||||
env:
|
||||
TRIAL_FLAGS: "-j 4"
|
||||
command:
|
||||
- "bash -c 'python -m pip install tox && python -m tox -e py37-postgres,codecov'"
|
||||
plugins:
|
||||
- docker-compose#v2.1.0:
|
||||
run: testenv
|
||||
config:
|
||||
- .buildkite/docker-compose.py37.pg95.yaml
|
||||
|
||||
- label: ":python: 3.7 / :postgres: 11"
|
||||
env:
|
||||
TRIAL_FLAGS: "-j 4"
|
||||
command:
|
||||
- "bash -c 'python -m pip install tox && python -m tox -e py37-postgres,codecov'"
|
||||
plugins:
|
||||
- docker-compose#v2.1.0:
|
||||
run: testenv
|
||||
config:
|
||||
- .buildkite/docker-compose.py37.pg11.yaml
|
||||
97
.travis.yml
97
.travis.yml
@@ -1,97 +0,0 @@
|
||||
dist: xenial
|
||||
language: python
|
||||
|
||||
cache:
|
||||
directories:
|
||||
# we only bother to cache the wheels; parts of the http cache get
|
||||
# invalidated every build (because they get served with a max-age of 600
|
||||
# seconds), which means that we end up re-uploading the whole cache for
|
||||
# every build, which is time-consuming In any case, it's not obvious that
|
||||
# downloading the cache from S3 would be much faster than downloading the
|
||||
# originals from pypi.
|
||||
#
|
||||
- $HOME/.cache/pip/wheels
|
||||
|
||||
# don't clone the whole repo history, one commit will do
|
||||
git:
|
||||
depth: 1
|
||||
|
||||
# only build branches we care about (PRs are built seperately)
|
||||
branches:
|
||||
only:
|
||||
- master
|
||||
- develop
|
||||
- /^release-v/
|
||||
- rav/pg95
|
||||
|
||||
# When running the tox environments that call Twisted Trial, we can pass the -j
|
||||
# flag to run the tests concurrently. We set this to 2 for CPU bound tests
|
||||
# (SQLite) and 4 for I/O bound tests (PostgreSQL).
|
||||
matrix:
|
||||
fast_finish: true
|
||||
include:
|
||||
- name: "pep8"
|
||||
python: 3.6
|
||||
env: TOX_ENV="pep8,check_isort,packaging"
|
||||
|
||||
- name: "py2.7 / sqlite"
|
||||
python: 2.7
|
||||
env: TOX_ENV=py27,codecov TRIAL_FLAGS="-j 2"
|
||||
|
||||
- name: "py2.7 / sqlite / olddeps"
|
||||
python: 2.7
|
||||
env: TOX_ENV=py27-old TRIAL_FLAGS="-j 2"
|
||||
|
||||
- name: "py2.7 / postgres9.5"
|
||||
python: 2.7
|
||||
addons:
|
||||
postgresql: "9.5"
|
||||
env: TOX_ENV=py27-postgres,codecov TRIAL_FLAGS="-j 4"
|
||||
services:
|
||||
- postgresql
|
||||
|
||||
- name: "py3.5 / sqlite"
|
||||
python: 3.5
|
||||
env: TOX_ENV=py35,codecov TRIAL_FLAGS="-j 2"
|
||||
|
||||
- name: "py3.7 / sqlite"
|
||||
python: 3.7
|
||||
env: TOX_ENV=py37,codecov TRIAL_FLAGS="-j 2"
|
||||
|
||||
- name: "py3.7 / postgres9.4"
|
||||
python: 3.7
|
||||
addons:
|
||||
postgresql: "9.4"
|
||||
env: TOX_ENV=py37-postgres TRIAL_FLAGS="-j 4"
|
||||
services:
|
||||
- postgresql
|
||||
|
||||
- name: "py3.7 / postgres9.5"
|
||||
python: 3.7
|
||||
addons:
|
||||
postgresql: "9.5"
|
||||
env: TOX_ENV=py37-postgres,codecov TRIAL_FLAGS="-j 4"
|
||||
services:
|
||||
- postgresql
|
||||
|
||||
- # we only need to check for the newsfragment if it's a PR build
|
||||
if: type = pull_request
|
||||
name: "check-newsfragment"
|
||||
python: 3.6
|
||||
script: scripts-dev/check-newsfragment
|
||||
|
||||
install:
|
||||
# this just logs the postgres version we will be testing against (if any)
|
||||
- psql -At -U postgres -c 'select version();' || true
|
||||
|
||||
- pip install tox
|
||||
|
||||
# if we don't have python3.6 in this environment, travis unhelpfully gives us
|
||||
# a `python3.6` on our path which does nothing but spit out a warning. Tox
|
||||
# tries to run it (even if we're not running a py36 env), so the build logs
|
||||
# then have warnings which look like errors. To reduce the noise, remove the
|
||||
# non-functional python3.6.
|
||||
- ( ! command -v python3.6 || python3.6 --version ) &>/dev/null || rm -f $(command -v python3.6)
|
||||
|
||||
script:
|
||||
- tox -e $TOX_ENV
|
||||
@@ -39,6 +39,7 @@ prune .circleci
|
||||
prune .coveragerc
|
||||
prune debian
|
||||
prune .codecov.yml
|
||||
prune .buildkite
|
||||
|
||||
exclude jenkins*
|
||||
recursive-exclude jenkins *.sh
|
||||
|
||||
1
changelog.d/4699.bugfix
Normal file
1
changelog.d/4699.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Fix attempting to paginate in rooms where server cannot see any events, to avoid unnecessarily pulling in lots of redacted events.
|
||||
1
changelog.d/4740.bugfix
Normal file
1
changelog.d/4740.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
'event_id' is now a required parameter in federated state requests, as per the matrix spec.
|
||||
1
changelog.d/4749.bugfix
Normal file
1
changelog.d/4749.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Fix tightloop over connecting to replication server.
|
||||
1
changelog.d/4752.misc
Normal file
1
changelog.d/4752.misc
Normal file
@@ -0,0 +1 @@
|
||||
Change from TravisCI to Buildkite for CI.
|
||||
1
changelog.d/4757.feature
Normal file
1
changelog.d/4757.feature
Normal file
@@ -0,0 +1 @@
|
||||
Move server key queries to federation reader.
|
||||
1
changelog.d/4757.misc
Normal file
1
changelog.d/4757.misc
Normal file
@@ -0,0 +1 @@
|
||||
When presence is disabled don't send over replication.
|
||||
1
changelog.d/4759.feature
Normal file
1
changelog.d/4759.feature
Normal file
@@ -0,0 +1 @@
|
||||
Add support for /account/3pid REST endpoint to client_reader worker.
|
||||
1
changelog.d/4763.bugfix
Normal file
1
changelog.d/4763.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Fix parsing of Content-Disposition headers on remote media requests and URL previews.
|
||||
1
changelog.d/4765.misc
Normal file
1
changelog.d/4765.misc
Normal file
@@ -0,0 +1 @@
|
||||
Minor docstring fixes for MatrixFederationAgent.
|
||||
1
changelog.d/4770.misc
Normal file
1
changelog.d/4770.misc
Normal file
@@ -0,0 +1 @@
|
||||
Optimise EDU transmission for the federation_sender worker.
|
||||
1
changelog.d/4771.misc
Normal file
1
changelog.d/4771.misc
Normal file
@@ -0,0 +1 @@
|
||||
Update test_typing to use HomeserverTestCase.
|
||||
1
changelog.d/4776.bugfix
Normal file
1
changelog.d/4776.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Fix incorrect log about not persisting duplicate state event.
|
||||
1
changelog.d/4790.bugfix
Normal file
1
changelog.d/4790.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Fix v4v6 option in HAProxy example config. Contributed by Flakebi.
|
||||
1
changelog.d/4791.feature
Normal file
1
changelog.d/4791.feature
Normal file
@@ -0,0 +1 @@
|
||||
Include a default configuration file in the 'docs' directory.
|
||||
1
changelog.d/4796.feature
Normal file
1
changelog.d/4796.feature
Normal file
@@ -0,0 +1 @@
|
||||
Add support for /keys/query and /keys/changes REST endpoints to client_reader worker.
|
||||
1
changelog.d/4797.misc
Normal file
1
changelog.d/4797.misc
Normal file
@@ -0,0 +1 @@
|
||||
Clean up read-receipt handling.
|
||||
1
changelog.d/4798.misc
Normal file
1
changelog.d/4798.misc
Normal file
@@ -0,0 +1 @@
|
||||
Add some debug about processing read receipts.
|
||||
1
changelog.d/4799.misc
Normal file
1
changelog.d/4799.misc
Normal file
@@ -0,0 +1 @@
|
||||
Clean up some replication code.
|
||||
7
docs/.sample_config_header.yaml
Normal file
7
docs/.sample_config_header.yaml
Normal file
@@ -0,0 +1,7 @@
|
||||
# This file is a reference to the configuration options which can be set in
|
||||
# homeserver.yaml.
|
||||
#
|
||||
# Note that it is not quite ready to be used as-is. If you are starting from
|
||||
# scratch, it is easier to generate the config files following the instructions
|
||||
# in INSTALL.md.
|
||||
|
||||
@@ -88,18 +88,16 @@ Let's assume that we expect clients to connect to our server at
|
||||
* HAProxy::
|
||||
|
||||
frontend https
|
||||
bind 0.0.0.0:443 v4v6 ssl crt /etc/ssl/haproxy/ strict-sni alpn h2,http/1.1
|
||||
bind :::443 ssl crt /etc/ssl/haproxy/ strict-sni alpn h2,http/1.1
|
||||
|
||||
bind :::443 v4v6 ssl crt /etc/ssl/haproxy/ strict-sni alpn h2,http/1.1
|
||||
|
||||
# Matrix client traffic
|
||||
acl matrix hdr(host) -i matrix.example.com
|
||||
use_backend matrix if matrix
|
||||
|
||||
|
||||
frontend matrix-federation
|
||||
bind 0.0.0.0:8448 v4v6 ssl crt /etc/ssl/haproxy/synapse.pem alpn h2,http/1.1
|
||||
bind :::8448 ssl crt /etc/ssl/haproxy/synapse.pem alpn h2,http/1.1
|
||||
bind :::8448 v4v6 ssl crt /etc/ssl/haproxy/synapse.pem alpn h2,http/1.1
|
||||
default_backend matrix
|
||||
|
||||
|
||||
backend matrix
|
||||
server matrix 127.0.0.1:8008
|
||||
|
||||
|
||||
1041
docs/sample_config.yaml
Normal file
1041
docs/sample_config.yaml
Normal file
File diff suppressed because it is too large
Load Diff
@@ -188,7 +188,9 @@ RDATA (S)
|
||||
A single update in a stream
|
||||
|
||||
POSITION (S)
|
||||
The position of the stream has been updated
|
||||
The position of the stream has been updated. Sent to the client after all
|
||||
missing updates for a stream have been sent to the client and they're now
|
||||
up to date.
|
||||
|
||||
ERROR (S, C)
|
||||
There was an error
|
||||
|
||||
@@ -182,6 +182,7 @@ endpoints matching the following regular expressions::
|
||||
^/_matrix/federation/v1/event_auth/
|
||||
^/_matrix/federation/v1/exchange_third_party_invite/
|
||||
^/_matrix/federation/v1/send/
|
||||
^/_matrix/key/v2/query
|
||||
|
||||
The above endpoints should all be routed to the federation_reader worker by the
|
||||
reverse-proxy configuration.
|
||||
@@ -223,6 +224,9 @@ following regular expressions::
|
||||
^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/members$
|
||||
^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/state$
|
||||
^/_matrix/client/(api/v1|r0|unstable)/login$
|
||||
^/_matrix/client/(api/v1|r0|unstable)/account/3pid$
|
||||
^/_matrix/client/(api/v1|r0|unstable)/keys/query$
|
||||
^/_matrix/client/(api/v1|r0|unstable)/keys/changes$
|
||||
|
||||
Additionally, the following REST endpoints can be handled, but all requests must
|
||||
be routed to the same instance::
|
||||
|
||||
18
scripts-dev/generate_sample_config
Executable file
18
scripts-dev/generate_sample_config
Executable file
@@ -0,0 +1,18 @@
|
||||
#!/bin/bash
|
||||
#
|
||||
# Update/check the docs/sample_config.yaml
|
||||
|
||||
set -e
|
||||
|
||||
cd `dirname $0`/..
|
||||
|
||||
SAMPLE_CONFIG="docs/sample_config.yaml"
|
||||
|
||||
if [ "$1" == "--check" ]; then
|
||||
diff -u "$SAMPLE_CONFIG" <(./scripts/generate_config --header-file docs/.sample_config_header.yaml) >/dev/null || {
|
||||
echo -e "\e[1m\e[31m$SAMPLE_CONFIG is not up-to-date. Regenerate it with \`scripts-dev/generate_sample_config\`.\e[0m" >&2
|
||||
exit 1
|
||||
}
|
||||
else
|
||||
./scripts/generate_config --header-file docs/.sample_config_header.yaml -o "$SAMPLE_CONFIG"
|
||||
fi
|
||||
@@ -1,6 +1,7 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import argparse
|
||||
import shutil
|
||||
import sys
|
||||
|
||||
from synapse.config.homeserver import HomeServerConfig
|
||||
@@ -50,6 +51,13 @@ if __name__ == "__main__":
|
||||
help="File to write the configuration to. Default: stdout",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--header-file",
|
||||
type=argparse.FileType('r'),
|
||||
help="File from which to read a header, which will be printed before the "
|
||||
"generated config.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
report_stats = args.report_stats
|
||||
@@ -64,4 +72,7 @@ if __name__ == "__main__":
|
||||
report_stats=report_stats,
|
||||
)
|
||||
|
||||
if args.header_file:
|
||||
shutil.copyfileobj(args.header_file, args.output_file)
|
||||
|
||||
args.output_file.write(conf)
|
||||
|
||||
@@ -27,4 +27,4 @@ try:
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
__version__ = "0.99.2"
|
||||
__version__ = "0.99.2.post1"
|
||||
|
||||
@@ -33,9 +33,13 @@ from synapse.replication.slave.storage._base import BaseSlavedStore
|
||||
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
|
||||
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
|
||||
from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
|
||||
from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
|
||||
from synapse.replication.slave.storage.devices import SlavedDeviceStore
|
||||
from synapse.replication.slave.storage.directory import DirectoryStore
|
||||
from synapse.replication.slave.storage.events import SlavedEventStore
|
||||
from synapse.replication.slave.storage.keys import SlavedKeyStore
|
||||
from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
|
||||
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
|
||||
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
|
||||
from synapse.replication.slave.storage.room import RoomStore
|
||||
from synapse.replication.slave.storage.transactions import SlavedTransactionStore
|
||||
@@ -48,6 +52,8 @@ from synapse.rest.client.v1.room import (
|
||||
RoomMemberListRestServlet,
|
||||
RoomStateRestServlet,
|
||||
)
|
||||
from synapse.rest.client.v2_alpha.account import ThreepidRestServlet
|
||||
from synapse.rest.client.v2_alpha.keys import KeyChangesServlet, KeyQueryServlet
|
||||
from synapse.rest.client.v2_alpha.register import RegisterRestServlet
|
||||
from synapse.server import HomeServer
|
||||
from synapse.storage.engines import create_engine
|
||||
@@ -60,6 +66,10 @@ logger = logging.getLogger("synapse.app.client_reader")
|
||||
|
||||
|
||||
class ClientReaderSlavedStore(
|
||||
SlavedDeviceInboxStore,
|
||||
SlavedDeviceStore,
|
||||
SlavedReceiptsStore,
|
||||
SlavedPushRuleStore,
|
||||
SlavedAccountDataStore,
|
||||
SlavedEventStore,
|
||||
SlavedKeyStore,
|
||||
@@ -96,6 +106,9 @@ class ClientReaderServer(HomeServer):
|
||||
RoomEventContextServlet(self).register(resource)
|
||||
RegisterRestServlet(self).register(resource)
|
||||
LoginRestServlet(self).register(resource)
|
||||
ThreepidRestServlet(self).register(resource)
|
||||
KeyQueryServlet(self).register(resource)
|
||||
KeyChangesServlet(self).register(resource)
|
||||
|
||||
resources.update({
|
||||
"/_matrix/client/r0": resource,
|
||||
|
||||
@@ -21,7 +21,7 @@ from twisted.web.resource import NoResource
|
||||
|
||||
import synapse
|
||||
from synapse import events
|
||||
from synapse.api.urls import FEDERATION_PREFIX
|
||||
from synapse.api.urls import FEDERATION_PREFIX, SERVER_KEY_V2_PREFIX
|
||||
from synapse.app import _base
|
||||
from synapse.config._base import ConfigError
|
||||
from synapse.config.homeserver import HomeServerConfig
|
||||
@@ -44,6 +44,7 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto
|
||||
from synapse.replication.slave.storage.room import RoomStore
|
||||
from synapse.replication.slave.storage.transactions import SlavedTransactionStore
|
||||
from synapse.replication.tcp.client import ReplicationClientHandler
|
||||
from synapse.rest.key.v2 import KeyApiV2Resource
|
||||
from synapse.server import HomeServer
|
||||
from synapse.storage.engines import create_engine
|
||||
from synapse.util.httpresourcetree import create_resource_tree
|
||||
@@ -99,6 +100,9 @@ class FederationReaderServer(HomeServer):
|
||||
),
|
||||
})
|
||||
|
||||
if name in ["keys", "federation"]:
|
||||
resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self)
|
||||
|
||||
root_resource = create_resource_tree(resources, NoResource())
|
||||
|
||||
_base.listen_tcp(
|
||||
|
||||
@@ -180,9 +180,7 @@ class Config(object):
|
||||
Returns:
|
||||
str: the yaml config file
|
||||
"""
|
||||
default_config = "# vim:ft=yaml\n"
|
||||
|
||||
default_config += "\n\n".join(
|
||||
default_config = "\n\n".join(
|
||||
dedent(conf)
|
||||
for conf in self.invoke_all(
|
||||
"default_config",
|
||||
@@ -297,19 +295,26 @@ class Config(object):
|
||||
"Must specify a server_name to a generate config for."
|
||||
" Pass -H server.name."
|
||||
)
|
||||
|
||||
config_str = obj.generate_config(
|
||||
config_dir_path=config_dir_path,
|
||||
data_dir_path=os.getcwd(),
|
||||
server_name=server_name,
|
||||
report_stats=(config_args.report_stats == "yes"),
|
||||
generate_secrets=True,
|
||||
)
|
||||
|
||||
if not cls.path_exists(config_dir_path):
|
||||
os.makedirs(config_dir_path)
|
||||
with open(config_path, "w") as config_file:
|
||||
config_str = obj.generate_config(
|
||||
config_dir_path=config_dir_path,
|
||||
data_dir_path=os.getcwd(),
|
||||
server_name=server_name,
|
||||
report_stats=(config_args.report_stats == "yes"),
|
||||
generate_secrets=True,
|
||||
config_file.write(
|
||||
"# vim:ft=yaml\n\n"
|
||||
)
|
||||
config = yaml.load(config_str)
|
||||
obj.invoke_all("generate_files", config)
|
||||
config_file.write(config_str)
|
||||
|
||||
config = yaml.load(config_str)
|
||||
obj.invoke_all("generate_files", config)
|
||||
|
||||
print(
|
||||
(
|
||||
"A config file has been generated in %r for server name"
|
||||
|
||||
@@ -49,7 +49,8 @@ class DatabaseConfig(Config):
|
||||
def default_config(self, data_dir_path, **kwargs):
|
||||
database_path = os.path.join(data_dir_path, "homeserver.db")
|
||||
return """\
|
||||
# Database configuration
|
||||
## Database ##
|
||||
|
||||
database:
|
||||
# The database engine name
|
||||
name: "sqlite3"
|
||||
|
||||
@@ -81,7 +81,9 @@ class LoggingConfig(Config):
|
||||
|
||||
def default_config(self, config_dir_path, server_name, **kwargs):
|
||||
log_config = os.path.join(config_dir_path, server_name + ".log.config")
|
||||
return """
|
||||
return """\
|
||||
## Logging ##
|
||||
|
||||
# A yaml python logging config file
|
||||
#
|
||||
log_config: "%(log_config)s"
|
||||
|
||||
@@ -260,9 +260,11 @@ class ServerConfig(Config):
|
||||
# This is used by remote servers to connect to this server,
|
||||
# e.g. matrix.org, localhost:8080, etc.
|
||||
# This is also the last part of your UserID.
|
||||
#
|
||||
server_name: "%(server_name)s"
|
||||
|
||||
# When running as a daemon, the file to store the pid in
|
||||
#
|
||||
pid_file: %(pid_file)s
|
||||
|
||||
# CPU affinity mask. Setting this restricts the CPUs on which the
|
||||
@@ -304,9 +306,11 @@ class ServerConfig(Config):
|
||||
# Set the soft limit on the number of file descriptors synapse can use
|
||||
# Zero is used to indicate synapse should set the soft limit to the
|
||||
# hard limit.
|
||||
#
|
||||
soft_file_limit: 0
|
||||
|
||||
# Set to false to disable presence tracking on this homeserver.
|
||||
#
|
||||
use_presence: true
|
||||
|
||||
# The GC threshold parameters to pass to `gc.set_threshold`, if defined
|
||||
|
||||
@@ -886,6 +886,9 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry):
|
||||
def on_edu(self, edu_type, origin, content):
|
||||
"""Overrides FederationHandlerRegistry
|
||||
"""
|
||||
if not self.config.use_presence and edu_type == "m.presence":
|
||||
return
|
||||
|
||||
handler = self.edu_handlers.get(edu_type)
|
||||
if handler:
|
||||
return super(ReplicationFederationHandlerRegistry, self).on_edu(
|
||||
|
||||
@@ -159,8 +159,12 @@ class FederationRemoteSendQueue(object):
|
||||
# stream.
|
||||
pass
|
||||
|
||||
def send_edu(self, destination, edu_type, content, key=None):
|
||||
def build_and_send_edu(self, destination, edu_type, content, key=None):
|
||||
"""As per TransactionQueue"""
|
||||
if destination == self.server_name:
|
||||
logger.info("Not sending EDU to ourselves")
|
||||
return
|
||||
|
||||
pos = self._next_pos()
|
||||
|
||||
edu = Edu(
|
||||
@@ -465,15 +469,11 @@ def process_rows_for_federation(transaction_queue, rows):
|
||||
|
||||
for destination, edu_map in iteritems(buff.keyed_edus):
|
||||
for key, edu in edu_map.items():
|
||||
transaction_queue.send_edu(
|
||||
edu.destination, edu.edu_type, edu.content, key=key,
|
||||
)
|
||||
transaction_queue.send_edu(edu, key)
|
||||
|
||||
for destination, edu_list in iteritems(buff.edus):
|
||||
for edu in edu_list:
|
||||
transaction_queue.send_edu(
|
||||
edu.destination, edu.edu_type, edu.content, key=None,
|
||||
)
|
||||
transaction_queue.send_edu(edu, None)
|
||||
|
||||
for destination in buff.device_destinations:
|
||||
transaction_queue.send_device_messages(destination)
|
||||
|
||||
@@ -361,7 +361,19 @@ class TransactionQueue(object):
|
||||
|
||||
self._attempt_new_transaction(destination)
|
||||
|
||||
def send_edu(self, destination, edu_type, content, key=None):
|
||||
def build_and_send_edu(self, destination, edu_type, content, key=None):
|
||||
"""Construct an Edu object, and queue it for sending
|
||||
|
||||
Args:
|
||||
destination (str): name of server to send to
|
||||
edu_type (str): type of EDU to send
|
||||
content (dict): content of EDU
|
||||
key (Any|None): clobbering key for this edu
|
||||
"""
|
||||
if destination == self.server_name:
|
||||
logger.info("Not sending EDU to ourselves")
|
||||
return
|
||||
|
||||
edu = Edu(
|
||||
origin=self.server_name,
|
||||
destination=destination,
|
||||
@@ -369,18 +381,23 @@ class TransactionQueue(object):
|
||||
content=content,
|
||||
)
|
||||
|
||||
if destination == self.server_name:
|
||||
logger.info("Not sending EDU to ourselves")
|
||||
return
|
||||
self.send_edu(edu, key)
|
||||
|
||||
def send_edu(self, edu, key):
|
||||
"""Queue an EDU for sending
|
||||
|
||||
Args:
|
||||
edu (Edu): edu to send
|
||||
key (Any|None): clobbering key for this edu
|
||||
"""
|
||||
if key:
|
||||
self.pending_edus_keyed_by_dest.setdefault(
|
||||
destination, {}
|
||||
edu.destination, {}
|
||||
)[(edu.edu_type, key)] = edu
|
||||
else:
|
||||
self.pending_edus_by_dest.setdefault(destination, []).append(edu)
|
||||
self.pending_edus_by_dest.setdefault(edu.destination, []).append(edu)
|
||||
|
||||
self._attempt_new_transaction(destination)
|
||||
self._attempt_new_transaction(edu.destination)
|
||||
|
||||
def send_device_messages(self, destination):
|
||||
if destination == self.server_name:
|
||||
|
||||
@@ -393,7 +393,7 @@ class FederationStateServlet(BaseFederationServlet):
|
||||
return self.handler.on_context_state_request(
|
||||
origin,
|
||||
context,
|
||||
parse_string_from_args(query, "event_id", None),
|
||||
parse_string_from_args(query, "event_id", None, required=True),
|
||||
)
|
||||
|
||||
|
||||
@@ -404,7 +404,7 @@ class FederationStateIdsServlet(BaseFederationServlet):
|
||||
return self.handler.on_state_ids_request(
|
||||
origin,
|
||||
room_id,
|
||||
parse_string_from_args(query, "event_id", None),
|
||||
parse_string_from_args(query, "event_id", None, required=True),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -37,13 +37,185 @@ from ._base import BaseHandler
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DeviceHandler(BaseHandler):
|
||||
class DeviceWorkerHandler(BaseHandler):
|
||||
def __init__(self, hs):
|
||||
super(DeviceHandler, self).__init__(hs)
|
||||
super(DeviceWorkerHandler, self).__init__(hs)
|
||||
|
||||
self.hs = hs
|
||||
self.state = hs.get_state_handler()
|
||||
self._auth_handler = hs.get_auth_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_devices_by_user(self, user_id):
|
||||
"""
|
||||
Retrieve the given user's devices
|
||||
|
||||
Args:
|
||||
user_id (str):
|
||||
Returns:
|
||||
defer.Deferred: list[dict[str, X]]: info on each device
|
||||
"""
|
||||
|
||||
device_map = yield self.store.get_devices_by_user(user_id)
|
||||
|
||||
ips = yield self.store.get_last_client_ip_by_device(
|
||||
user_id, device_id=None
|
||||
)
|
||||
|
||||
devices = list(device_map.values())
|
||||
for device in devices:
|
||||
_update_device_from_client_ips(device, ips)
|
||||
|
||||
defer.returnValue(devices)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_device(self, user_id, device_id):
|
||||
""" Retrieve the given device
|
||||
|
||||
Args:
|
||||
user_id (str):
|
||||
device_id (str):
|
||||
|
||||
Returns:
|
||||
defer.Deferred: dict[str, X]: info on the device
|
||||
Raises:
|
||||
errors.NotFoundError: if the device was not found
|
||||
"""
|
||||
try:
|
||||
device = yield self.store.get_device(user_id, device_id)
|
||||
except errors.StoreError:
|
||||
raise errors.NotFoundError
|
||||
ips = yield self.store.get_last_client_ip_by_device(
|
||||
user_id, device_id,
|
||||
)
|
||||
_update_device_from_client_ips(device, ips)
|
||||
defer.returnValue(device)
|
||||
|
||||
@measure_func("device.get_user_ids_changed")
|
||||
@defer.inlineCallbacks
|
||||
def get_user_ids_changed(self, user_id, from_token):
|
||||
"""Get list of users that have had the devices updated, or have newly
|
||||
joined a room, that `user_id` may be interested in.
|
||||
|
||||
Args:
|
||||
user_id (str)
|
||||
from_token (StreamToken)
|
||||
"""
|
||||
now_room_key = yield self.store.get_room_events_max_id()
|
||||
|
||||
room_ids = yield self.store.get_rooms_for_user(user_id)
|
||||
|
||||
# First we check if any devices have changed
|
||||
changed = yield self.store.get_user_whose_devices_changed(
|
||||
from_token.device_list_key
|
||||
)
|
||||
|
||||
# Then work out if any users have since joined
|
||||
rooms_changed = self.store.get_rooms_that_changed(room_ids, from_token.room_key)
|
||||
|
||||
member_events = yield self.store.get_membership_changes_for_user(
|
||||
user_id, from_token.room_key, now_room_key,
|
||||
)
|
||||
rooms_changed.update(event.room_id for event in member_events)
|
||||
|
||||
stream_ordering = RoomStreamToken.parse_stream_token(
|
||||
from_token.room_key
|
||||
).stream
|
||||
|
||||
possibly_changed = set(changed)
|
||||
possibly_left = set()
|
||||
for room_id in rooms_changed:
|
||||
current_state_ids = yield self.store.get_current_state_ids(room_id)
|
||||
|
||||
# The user may have left the room
|
||||
# TODO: Check if they actually did or if we were just invited.
|
||||
if room_id not in room_ids:
|
||||
for key, event_id in iteritems(current_state_ids):
|
||||
etype, state_key = key
|
||||
if etype != EventTypes.Member:
|
||||
continue
|
||||
possibly_left.add(state_key)
|
||||
continue
|
||||
|
||||
# Fetch the current state at the time.
|
||||
try:
|
||||
event_ids = yield self.store.get_forward_extremeties_for_room(
|
||||
room_id, stream_ordering=stream_ordering
|
||||
)
|
||||
except errors.StoreError:
|
||||
# we have purged the stream_ordering index since the stream
|
||||
# ordering: treat it the same as a new room
|
||||
event_ids = []
|
||||
|
||||
# special-case for an empty prev state: include all members
|
||||
# in the changed list
|
||||
if not event_ids:
|
||||
for key, event_id in iteritems(current_state_ids):
|
||||
etype, state_key = key
|
||||
if etype != EventTypes.Member:
|
||||
continue
|
||||
possibly_changed.add(state_key)
|
||||
continue
|
||||
|
||||
current_member_id = current_state_ids.get((EventTypes.Member, user_id))
|
||||
if not current_member_id:
|
||||
continue
|
||||
|
||||
# mapping from event_id -> state_dict
|
||||
prev_state_ids = yield self.store.get_state_ids_for_events(event_ids)
|
||||
|
||||
# Check if we've joined the room? If so we just blindly add all the users to
|
||||
# the "possibly changed" users.
|
||||
for state_dict in itervalues(prev_state_ids):
|
||||
member_event = state_dict.get((EventTypes.Member, user_id), None)
|
||||
if not member_event or member_event != current_member_id:
|
||||
for key, event_id in iteritems(current_state_ids):
|
||||
etype, state_key = key
|
||||
if etype != EventTypes.Member:
|
||||
continue
|
||||
possibly_changed.add(state_key)
|
||||
break
|
||||
|
||||
# If there has been any change in membership, include them in the
|
||||
# possibly changed list. We'll check if they are joined below,
|
||||
# and we're not toooo worried about spuriously adding users.
|
||||
for key, event_id in iteritems(current_state_ids):
|
||||
etype, state_key = key
|
||||
if etype != EventTypes.Member:
|
||||
continue
|
||||
|
||||
# check if this member has changed since any of the extremities
|
||||
# at the stream_ordering, and add them to the list if so.
|
||||
for state_dict in itervalues(prev_state_ids):
|
||||
prev_event_id = state_dict.get(key, None)
|
||||
if not prev_event_id or prev_event_id != event_id:
|
||||
if state_key != user_id:
|
||||
possibly_changed.add(state_key)
|
||||
break
|
||||
|
||||
if possibly_changed or possibly_left:
|
||||
users_who_share_room = yield self.store.get_users_who_share_room_with_user(
|
||||
user_id
|
||||
)
|
||||
|
||||
# Take the intersection of the users whose devices may have changed
|
||||
# and those that actually still share a room with the user
|
||||
possibly_joined = possibly_changed & users_who_share_room
|
||||
possibly_left = (possibly_changed | possibly_left) - users_who_share_room
|
||||
else:
|
||||
possibly_joined = []
|
||||
possibly_left = []
|
||||
|
||||
defer.returnValue({
|
||||
"changed": list(possibly_joined),
|
||||
"left": list(possibly_left),
|
||||
})
|
||||
|
||||
|
||||
class DeviceHandler(DeviceWorkerHandler):
|
||||
def __init__(self, hs):
|
||||
super(DeviceHandler, self).__init__(hs)
|
||||
|
||||
self.federation_sender = hs.get_federation_sender()
|
||||
|
||||
self._edu_updater = DeviceListEduUpdater(hs, self)
|
||||
@@ -103,52 +275,6 @@ class DeviceHandler(BaseHandler):
|
||||
|
||||
raise errors.StoreError(500, "Couldn't generate a device ID.")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_devices_by_user(self, user_id):
|
||||
"""
|
||||
Retrieve the given user's devices
|
||||
|
||||
Args:
|
||||
user_id (str):
|
||||
Returns:
|
||||
defer.Deferred: list[dict[str, X]]: info on each device
|
||||
"""
|
||||
|
||||
device_map = yield self.store.get_devices_by_user(user_id)
|
||||
|
||||
ips = yield self.store.get_last_client_ip_by_device(
|
||||
user_id, device_id=None
|
||||
)
|
||||
|
||||
devices = list(device_map.values())
|
||||
for device in devices:
|
||||
_update_device_from_client_ips(device, ips)
|
||||
|
||||
defer.returnValue(devices)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_device(self, user_id, device_id):
|
||||
""" Retrieve the given device
|
||||
|
||||
Args:
|
||||
user_id (str):
|
||||
device_id (str):
|
||||
|
||||
Returns:
|
||||
defer.Deferred: dict[str, X]: info on the device
|
||||
Raises:
|
||||
errors.NotFoundError: if the device was not found
|
||||
"""
|
||||
try:
|
||||
device = yield self.store.get_device(user_id, device_id)
|
||||
except errors.StoreError:
|
||||
raise errors.NotFoundError
|
||||
ips = yield self.store.get_last_client_ip_by_device(
|
||||
user_id, device_id,
|
||||
)
|
||||
_update_device_from_client_ips(device, ips)
|
||||
defer.returnValue(device)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_device(self, user_id, device_id):
|
||||
""" Delete the given device
|
||||
@@ -287,126 +413,6 @@ class DeviceHandler(BaseHandler):
|
||||
for host in hosts:
|
||||
self.federation_sender.send_device_messages(host)
|
||||
|
||||
@measure_func("device.get_user_ids_changed")
|
||||
@defer.inlineCallbacks
|
||||
def get_user_ids_changed(self, user_id, from_token):
|
||||
"""Get list of users that have had the devices updated, or have newly
|
||||
joined a room, that `user_id` may be interested in.
|
||||
|
||||
Args:
|
||||
user_id (str)
|
||||
from_token (StreamToken)
|
||||
"""
|
||||
now_token = yield self.hs.get_event_sources().get_current_token()
|
||||
|
||||
room_ids = yield self.store.get_rooms_for_user(user_id)
|
||||
|
||||
# First we check if any devices have changed
|
||||
changed = yield self.store.get_user_whose_devices_changed(
|
||||
from_token.device_list_key
|
||||
)
|
||||
|
||||
# Then work out if any users have since joined
|
||||
rooms_changed = self.store.get_rooms_that_changed(room_ids, from_token.room_key)
|
||||
|
||||
member_events = yield self.store.get_membership_changes_for_user(
|
||||
user_id, from_token.room_key, now_token.room_key
|
||||
)
|
||||
rooms_changed.update(event.room_id for event in member_events)
|
||||
|
||||
stream_ordering = RoomStreamToken.parse_stream_token(
|
||||
from_token.room_key
|
||||
).stream
|
||||
|
||||
possibly_changed = set(changed)
|
||||
possibly_left = set()
|
||||
for room_id in rooms_changed:
|
||||
current_state_ids = yield self.store.get_current_state_ids(room_id)
|
||||
|
||||
# The user may have left the room
|
||||
# TODO: Check if they actually did or if we were just invited.
|
||||
if room_id not in room_ids:
|
||||
for key, event_id in iteritems(current_state_ids):
|
||||
etype, state_key = key
|
||||
if etype != EventTypes.Member:
|
||||
continue
|
||||
possibly_left.add(state_key)
|
||||
continue
|
||||
|
||||
# Fetch the current state at the time.
|
||||
try:
|
||||
event_ids = yield self.store.get_forward_extremeties_for_room(
|
||||
room_id, stream_ordering=stream_ordering
|
||||
)
|
||||
except errors.StoreError:
|
||||
# we have purged the stream_ordering index since the stream
|
||||
# ordering: treat it the same as a new room
|
||||
event_ids = []
|
||||
|
||||
# special-case for an empty prev state: include all members
|
||||
# in the changed list
|
||||
if not event_ids:
|
||||
for key, event_id in iteritems(current_state_ids):
|
||||
etype, state_key = key
|
||||
if etype != EventTypes.Member:
|
||||
continue
|
||||
possibly_changed.add(state_key)
|
||||
continue
|
||||
|
||||
current_member_id = current_state_ids.get((EventTypes.Member, user_id))
|
||||
if not current_member_id:
|
||||
continue
|
||||
|
||||
# mapping from event_id -> state_dict
|
||||
prev_state_ids = yield self.store.get_state_ids_for_events(event_ids)
|
||||
|
||||
# Check if we've joined the room? If so we just blindly add all the users to
|
||||
# the "possibly changed" users.
|
||||
for state_dict in itervalues(prev_state_ids):
|
||||
member_event = state_dict.get((EventTypes.Member, user_id), None)
|
||||
if not member_event or member_event != current_member_id:
|
||||
for key, event_id in iteritems(current_state_ids):
|
||||
etype, state_key = key
|
||||
if etype != EventTypes.Member:
|
||||
continue
|
||||
possibly_changed.add(state_key)
|
||||
break
|
||||
|
||||
# If there has been any change in membership, include them in the
|
||||
# possibly changed list. We'll check if they are joined below,
|
||||
# and we're not toooo worried about spuriously adding users.
|
||||
for key, event_id in iteritems(current_state_ids):
|
||||
etype, state_key = key
|
||||
if etype != EventTypes.Member:
|
||||
continue
|
||||
|
||||
# check if this member has changed since any of the extremities
|
||||
# at the stream_ordering, and add them to the list if so.
|
||||
for state_dict in itervalues(prev_state_ids):
|
||||
prev_event_id = state_dict.get(key, None)
|
||||
if not prev_event_id or prev_event_id != event_id:
|
||||
if state_key != user_id:
|
||||
possibly_changed.add(state_key)
|
||||
break
|
||||
|
||||
if possibly_changed or possibly_left:
|
||||
users_who_share_room = yield self.store.get_users_who_share_room_with_user(
|
||||
user_id
|
||||
)
|
||||
|
||||
# Take the intersection of the users whose devices may have changed
|
||||
# and those that actually still share a room with the user
|
||||
possibly_joined = possibly_changed & users_who_share_room
|
||||
possibly_left = (possibly_changed | possibly_left) - users_who_share_room
|
||||
else:
|
||||
possibly_joined = []
|
||||
possibly_left = []
|
||||
|
||||
defer.returnValue({
|
||||
"changed": list(possibly_joined),
|
||||
"left": list(possibly_left),
|
||||
})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_federation_query_user_devices(self, user_id):
|
||||
stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id)
|
||||
|
||||
@@ -858,6 +858,52 @@ class FederationHandler(BaseHandler):
|
||||
logger.debug("Not backfilling as no extremeties found.")
|
||||
return
|
||||
|
||||
# We only want to paginate if we can actually see the events we'll get,
|
||||
# as otherwise we'll just spend a lot of resources to get redacted
|
||||
# events.
|
||||
#
|
||||
# We do this by filtering all the backwards extremities and seeing if
|
||||
# any remain. Given we don't have the extremity events themselves, we
|
||||
# need to actually check the events that reference them.
|
||||
#
|
||||
# *Note*: the spec wants us to keep backfilling until we reach the start
|
||||
# of the room in case we are allowed to see some of the history. However
|
||||
# in practice that causes more issues than its worth, as a) its
|
||||
# relatively rare for there to be any visible history and b) even when
|
||||
# there is its often sufficiently long ago that clients would stop
|
||||
# attempting to paginate before backfill reached the visible history.
|
||||
#
|
||||
# TODO: If we do do a backfill then we should filter the backwards
|
||||
# extremities to only include those that point to visible portions of
|
||||
# history.
|
||||
#
|
||||
# TODO: Correctly handle the case where we are allowed to see the
|
||||
# forward event but not the backward extremity, e.g. in the case of
|
||||
# initial join of the server where we are allowed to see the join
|
||||
# event but not anything before it. This would require looking at the
|
||||
# state *before* the event, ignoring the special casing certain event
|
||||
# types have.
|
||||
|
||||
forward_events = yield self.store.get_successor_events(
|
||||
list(extremities),
|
||||
)
|
||||
|
||||
extremities_events = yield self.store.get_events(
|
||||
forward_events,
|
||||
check_redacted=False,
|
||||
get_prev_content=False,
|
||||
)
|
||||
|
||||
# We set `check_history_visibility_only` as we might otherwise get false
|
||||
# positives from users having been erased.
|
||||
filtered_extremities = yield filter_events_for_server(
|
||||
self.store, self.server_name, list(extremities_events.values()),
|
||||
redact=False, check_history_visibility_only=True,
|
||||
)
|
||||
|
||||
if not filtered_extremities:
|
||||
defer.returnValue(False)
|
||||
|
||||
# Check if we reached a point where we should start backfilling.
|
||||
sorted_extremeties_tuple = sorted(
|
||||
extremities.items(),
|
||||
|
||||
@@ -436,10 +436,11 @@ class EventCreationHandler(object):
|
||||
|
||||
if event.is_state():
|
||||
prev_state = yield self.deduplicate_state_event(event, context)
|
||||
logger.info(
|
||||
"Not bothering to persist duplicate state event %s", event.event_id,
|
||||
)
|
||||
if prev_state is not None:
|
||||
logger.info(
|
||||
"Not bothering to persist state event %s duplicated by %s",
|
||||
event.event_id, prev_state.event_id,
|
||||
)
|
||||
defer.returnValue(prev_state)
|
||||
|
||||
yield self.handle_new_client_event(
|
||||
|
||||
@@ -816,7 +816,7 @@ class PresenceHandler(object):
|
||||
if self.is_mine(observed_user):
|
||||
yield self.invite_presence(observed_user, observer_user)
|
||||
else:
|
||||
yield self.federation.send_edu(
|
||||
yield self.federation.build_and_send_edu(
|
||||
destination=observed_user.domain,
|
||||
edu_type="m.presence_invite",
|
||||
content={
|
||||
@@ -836,7 +836,7 @@ class PresenceHandler(object):
|
||||
if self.is_mine(observer_user):
|
||||
yield self.accept_presence(observed_user, observer_user)
|
||||
else:
|
||||
self.federation.send_edu(
|
||||
self.federation.build_and_send_edu(
|
||||
destination=observer_user.domain,
|
||||
edu_type="m.presence_accept",
|
||||
content={
|
||||
@@ -848,7 +848,7 @@ class PresenceHandler(object):
|
||||
state_dict = yield self.get_state(observed_user, as_event=False)
|
||||
state_dict = format_user_presence_state(state_dict, self.clock.time_msec())
|
||||
|
||||
self.federation.send_edu(
|
||||
self.federation.build_and_send_edu(
|
||||
destination=observer_user.domain,
|
||||
edu_type="m.presence",
|
||||
content={
|
||||
|
||||
@@ -16,7 +16,6 @@ import logging
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.types import get_domain_from_id
|
||||
|
||||
from ._base import BaseHandler
|
||||
@@ -38,31 +37,6 @@ class ReceiptsHandler(BaseHandler):
|
||||
self.clock = self.hs.get_clock()
|
||||
self.state = hs.get_state_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def received_client_receipt(self, room_id, receipt_type, user_id,
|
||||
event_id):
|
||||
"""Called when a client tells us a local user has read up to the given
|
||||
event_id in the room.
|
||||
"""
|
||||
receipt = {
|
||||
"room_id": room_id,
|
||||
"receipt_type": receipt_type,
|
||||
"user_id": user_id,
|
||||
"event_ids": [event_id],
|
||||
"data": {
|
||||
"ts": int(self.clock.time_msec()),
|
||||
}
|
||||
}
|
||||
|
||||
is_new = yield self._handle_new_receipts([receipt])
|
||||
|
||||
if is_new:
|
||||
# fire off a process in the background to send the receipt to
|
||||
# remote servers
|
||||
run_as_background_process(
|
||||
'push_receipts_to_remotes', self._push_remotes, receipt
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _received_remote_receipt(self, origin, content):
|
||||
"""Called when we receive an EDU of type m.receipt from a remote HS.
|
||||
@@ -128,43 +102,54 @@ class ReceiptsHandler(BaseHandler):
|
||||
defer.returnValue(True)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _push_remotes(self, receipt):
|
||||
"""Given a receipt, works out which remote servers should be
|
||||
poked and pokes them.
|
||||
def received_client_receipt(self, room_id, receipt_type, user_id,
|
||||
event_id):
|
||||
"""Called when a client tells us a local user has read up to the given
|
||||
event_id in the room.
|
||||
"""
|
||||
try:
|
||||
# TODO: optimise this to move some of the work to the workers.
|
||||
room_id = receipt["room_id"]
|
||||
receipt_type = receipt["receipt_type"]
|
||||
user_id = receipt["user_id"]
|
||||
event_ids = receipt["event_ids"]
|
||||
data = receipt["data"]
|
||||
receipt = {
|
||||
"room_id": room_id,
|
||||
"receipt_type": receipt_type,
|
||||
"user_id": user_id,
|
||||
"event_ids": [event_id],
|
||||
"data": {
|
||||
"ts": int(self.clock.time_msec()),
|
||||
}
|
||||
}
|
||||
|
||||
users = yield self.state.get_current_user_in_room(room_id)
|
||||
remotedomains = set(get_domain_from_id(u) for u in users)
|
||||
remotedomains = remotedomains.copy()
|
||||
remotedomains.discard(self.server_name)
|
||||
is_new = yield self._handle_new_receipts([receipt])
|
||||
if not is_new:
|
||||
return
|
||||
|
||||
logger.debug("Sending receipt to: %r", remotedomains)
|
||||
# Work out which remote servers should be poked and poke them.
|
||||
|
||||
for domain in remotedomains:
|
||||
self.federation.send_edu(
|
||||
destination=domain,
|
||||
edu_type="m.receipt",
|
||||
content={
|
||||
room_id: {
|
||||
receipt_type: {
|
||||
user_id: {
|
||||
"event_ids": event_ids,
|
||||
"data": data,
|
||||
}
|
||||
# TODO: optimise this to move some of the work to the workers.
|
||||
data = receipt["data"]
|
||||
|
||||
# XXX why does this not use state.get_current_hosts_in_room() ?
|
||||
users = yield self.state.get_current_user_in_room(room_id)
|
||||
remotedomains = set(get_domain_from_id(u) for u in users)
|
||||
remotedomains = remotedomains.copy()
|
||||
remotedomains.discard(self.server_name)
|
||||
|
||||
logger.debug("Sending receipt to: %r", remotedomains)
|
||||
|
||||
for domain in remotedomains:
|
||||
self.federation.build_and_send_edu(
|
||||
destination=domain,
|
||||
edu_type="m.receipt",
|
||||
content={
|
||||
room_id: {
|
||||
receipt_type: {
|
||||
user_id: {
|
||||
"event_ids": [event_id],
|
||||
"data": data,
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
key=(room_id, receipt_type, user_id),
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Error pushing receipts to remote servers")
|
||||
},
|
||||
key=(room_id, receipt_type, user_id),
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_receipts_for_room(self, room_id, to_key):
|
||||
|
||||
@@ -231,7 +231,7 @@ class TypingHandler(object):
|
||||
for domain in set(get_domain_from_id(u) for u in users):
|
||||
if domain != self.server_name:
|
||||
logger.debug("sending typing update to %s", domain)
|
||||
self.federation.send_edu(
|
||||
self.federation.build_and_send_edu(
|
||||
destination=domain,
|
||||
edu_type="m.typing",
|
||||
content={
|
||||
|
||||
@@ -68,9 +68,13 @@ class MatrixFederationAgent(object):
|
||||
TLS policy to use for fetching .well-known files. None to use a default
|
||||
(browser-like) implementation.
|
||||
|
||||
srv_resolver (SrvResolver|None):
|
||||
_srv_resolver (SrvResolver|None):
|
||||
SRVResolver impl to use for looking up SRV records. None to use a default
|
||||
implementation.
|
||||
|
||||
_well_known_cache (TTLCache|None):
|
||||
TTLCache impl for storing cached well-known lookups. None to use a default
|
||||
implementation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -178,8 +178,6 @@ class Notifier(object):
|
||||
self.remove_expired_streams, self.UNUSED_STREAM_EXPIRY_MS
|
||||
)
|
||||
|
||||
self.replication_deferred = ObservableDeferred(defer.Deferred())
|
||||
|
||||
# This is not a very cheap test to perform, but it's only executed
|
||||
# when rendering the metrics page, which is likely once per minute at
|
||||
# most when scraping it.
|
||||
@@ -205,7 +203,9 @@ class Notifier(object):
|
||||
|
||||
def add_replication_callback(self, cb):
|
||||
"""Add a callback that will be called when some new data is available.
|
||||
Callback is not given any arguments.
|
||||
Callback is not given any arguments. It should *not* return a Deferred - if
|
||||
it needs to do any asynchronous work, a background thread should be started and
|
||||
wrapped with run_as_background_process.
|
||||
"""
|
||||
self.replication_callbacks.append(cb)
|
||||
|
||||
@@ -517,60 +517,5 @@ class Notifier(object):
|
||||
|
||||
def notify_replication(self):
|
||||
"""Notify the any replication listeners that there's a new event"""
|
||||
with PreserveLoggingContext():
|
||||
deferred = self.replication_deferred
|
||||
self.replication_deferred = ObservableDeferred(defer.Deferred())
|
||||
deferred.callback(None)
|
||||
|
||||
# the callbacks may well outlast the current request, so we run
|
||||
# them in the sentinel logcontext.
|
||||
#
|
||||
# (ideally it would be up to the callbacks to know if they were
|
||||
# starting off background processes and drop the logcontext
|
||||
# accordingly, but that requires more changes)
|
||||
for cb in self.replication_callbacks:
|
||||
cb()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def wait_for_replication(self, callback, timeout):
|
||||
"""Wait for an event to happen.
|
||||
|
||||
Args:
|
||||
callback: Gets called whenever an event happens. If this returns a
|
||||
truthy value then ``wait_for_replication`` returns, otherwise
|
||||
it waits for another event.
|
||||
timeout: How many milliseconds to wait for callback return a truthy
|
||||
value.
|
||||
|
||||
Returns:
|
||||
A deferred that resolves with the value returned by the callback.
|
||||
"""
|
||||
listener = _NotificationListener(None)
|
||||
|
||||
end_time = self.clock.time_msec() + timeout
|
||||
|
||||
while True:
|
||||
listener.deferred = self.replication_deferred.observe()
|
||||
result = yield callback()
|
||||
if result:
|
||||
break
|
||||
|
||||
now = self.clock.time_msec()
|
||||
if end_time <= now:
|
||||
break
|
||||
|
||||
listener.deferred = timeout_deferred(
|
||||
listener.deferred,
|
||||
timeout=(end_time - now) / 1000.,
|
||||
reactor=self.hs.get_reactor(),
|
||||
)
|
||||
|
||||
try:
|
||||
with PreserveLoggingContext():
|
||||
yield listener.deferred
|
||||
except defer.TimeoutError:
|
||||
break
|
||||
except defer.CancelledError:
|
||||
break
|
||||
|
||||
defer.returnValue(result)
|
||||
for cb in self.replication_callbacks:
|
||||
cb()
|
||||
|
||||
@@ -13,15 +13,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.storage import DataStore
|
||||
from synapse.replication.slave.storage._base import BaseSlavedStore
|
||||
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
|
||||
from synapse.storage.deviceinbox import DeviceInboxWorkerStore
|
||||
from synapse.util.caches.expiringcache import ExpiringCache
|
||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||
|
||||
from ._base import BaseSlavedStore, __func__
|
||||
from ._slaved_id_tracker import SlavedIdTracker
|
||||
|
||||
|
||||
class SlavedDeviceInboxStore(BaseSlavedStore):
|
||||
class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
|
||||
def __init__(self, db_conn, hs):
|
||||
super(SlavedDeviceInboxStore, self).__init__(db_conn, hs)
|
||||
self._device_inbox_id_gen = SlavedIdTracker(
|
||||
@@ -43,12 +42,6 @@ class SlavedDeviceInboxStore(BaseSlavedStore):
|
||||
expiry_ms=30 * 60 * 1000,
|
||||
)
|
||||
|
||||
get_to_device_stream_token = __func__(DataStore.get_to_device_stream_token)
|
||||
get_new_messages_for_device = __func__(DataStore.get_new_messages_for_device)
|
||||
get_new_device_msgs_for_remote = __func__(DataStore.get_new_device_msgs_for_remote)
|
||||
delete_messages_for_device = __func__(DataStore.delete_messages_for_device)
|
||||
delete_device_msgs_for_remote = __func__(DataStore.delete_device_msgs_for_remote)
|
||||
|
||||
def stream_positions(self):
|
||||
result = super(SlavedDeviceInboxStore, self).stream_positions()
|
||||
result["to_device"] = self._device_inbox_id_gen.get_current_token()
|
||||
|
||||
@@ -13,15 +13,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.storage import DataStore
|
||||
from synapse.storage.end_to_end_keys import EndToEndKeyStore
|
||||
from synapse.replication.slave.storage._base import BaseSlavedStore
|
||||
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
|
||||
from synapse.storage.devices import DeviceWorkerStore
|
||||
from synapse.storage.end_to_end_keys import EndToEndKeyWorkerStore
|
||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||
|
||||
from ._base import BaseSlavedStore, __func__
|
||||
from ._slaved_id_tracker import SlavedIdTracker
|
||||
|
||||
|
||||
class SlavedDeviceStore(BaseSlavedStore):
|
||||
class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore):
|
||||
def __init__(self, db_conn, hs):
|
||||
super(SlavedDeviceStore, self).__init__(db_conn, hs)
|
||||
|
||||
@@ -38,17 +37,6 @@ class SlavedDeviceStore(BaseSlavedStore):
|
||||
"DeviceListFederationStreamChangeCache", device_list_max,
|
||||
)
|
||||
|
||||
get_device_stream_token = __func__(DataStore.get_device_stream_token)
|
||||
get_user_whose_devices_changed = __func__(DataStore.get_user_whose_devices_changed)
|
||||
get_devices_by_remote = __func__(DataStore.get_devices_by_remote)
|
||||
_get_devices_by_remote_txn = __func__(DataStore._get_devices_by_remote_txn)
|
||||
_get_e2e_device_keys_txn = __func__(DataStore._get_e2e_device_keys_txn)
|
||||
mark_as_sent_devices_by_remote = __func__(DataStore.mark_as_sent_devices_by_remote)
|
||||
_mark_as_sent_devices_by_remote_txn = (
|
||||
__func__(DataStore._mark_as_sent_devices_by_remote_txn)
|
||||
)
|
||||
count_e2e_one_time_keys = EndToEndKeyStore.__dict__["count_e2e_one_time_keys"]
|
||||
|
||||
def stream_positions(self):
|
||||
result = super(SlavedDeviceStore, self).stream_positions()
|
||||
result["device_lists"] = self._device_list_id_gen.get_current_token()
|
||||
@@ -58,14 +46,23 @@ class SlavedDeviceStore(BaseSlavedStore):
|
||||
if stream_name == "device_lists":
|
||||
self._device_list_id_gen.advance(token)
|
||||
for row in rows:
|
||||
self._device_list_stream_cache.entity_has_changed(
|
||||
row.user_id, token
|
||||
self._invalidate_caches_for_devices(
|
||||
token, row.user_id, row.destination,
|
||||
)
|
||||
|
||||
if row.destination:
|
||||
self._device_list_federation_stream_cache.entity_has_changed(
|
||||
row.destination, token
|
||||
)
|
||||
return super(SlavedDeviceStore, self).process_replication_rows(
|
||||
stream_name, token, rows
|
||||
)
|
||||
|
||||
def _invalidate_caches_for_devices(self, token, user_id, destination):
|
||||
self._device_list_stream_cache.entity_has_changed(
|
||||
user_id, token
|
||||
)
|
||||
|
||||
if destination:
|
||||
self._device_list_federation_stream_cache.entity_has_changed(
|
||||
destination, token
|
||||
)
|
||||
|
||||
self._get_cached_devices_for_user.invalidate((user_id,))
|
||||
self._get_cached_user_device.invalidate_many((user_id,))
|
||||
self.get_device_list_last_stream_id_for_remote.invalidate((user_id,))
|
||||
|
||||
@@ -54,8 +54,11 @@ class SlavedPresenceStore(BaseSlavedStore):
|
||||
|
||||
def stream_positions(self):
|
||||
result = super(SlavedPresenceStore, self).stream_positions()
|
||||
position = self._presence_id_gen.get_current_token()
|
||||
result["presence"] = position
|
||||
|
||||
if self.hs.config.use_presence:
|
||||
position = self._presence_id_gen.get_current_token()
|
||||
result["presence"] = position
|
||||
|
||||
return result
|
||||
|
||||
def process_replication_rows(self, stream_name, token, rows):
|
||||
|
||||
@@ -20,7 +20,7 @@ from ._slaved_id_tracker import SlavedIdTracker
|
||||
from .events import SlavedEventStore
|
||||
|
||||
|
||||
class SlavedPushRuleStore(PushRulesWorkerStore, SlavedEventStore):
|
||||
class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
|
||||
def __init__(self, db_conn, hs):
|
||||
self._push_rules_stream_id_gen = SlavedIdTracker(
|
||||
db_conn, "push_rules_stream", "stream_id",
|
||||
|
||||
@@ -39,7 +39,7 @@ class ReplicationClientFactory(ReconnectingClientFactory):
|
||||
Accepts a handler that will be called when new data is available or data
|
||||
is required.
|
||||
"""
|
||||
maxDelay = 5 # Try at least once every N seconds
|
||||
maxDelay = 30 # Try at least once every N seconds
|
||||
|
||||
def __init__(self, hs, client_name, handler):
|
||||
self.client_name = client_name
|
||||
@@ -54,7 +54,6 @@ class ReplicationClientFactory(ReconnectingClientFactory):
|
||||
|
||||
def buildProtocol(self, addr):
|
||||
logger.info("Connected to replication: %r", addr)
|
||||
self.resetDelay()
|
||||
return ClientReplicationStreamProtocol(
|
||||
self.client_name, self.server_name, self._clock, self.handler
|
||||
)
|
||||
@@ -90,15 +89,18 @@ class ReplicationClientHandler(object):
|
||||
# Used for tests.
|
||||
self.awaiting_syncs = {}
|
||||
|
||||
# The factory used to create connections.
|
||||
self.factory = None
|
||||
|
||||
def start_replication(self, hs):
|
||||
"""Helper method to start a replication connection to the remote server
|
||||
using TCP.
|
||||
"""
|
||||
client_name = hs.config.worker_name
|
||||
factory = ReplicationClientFactory(hs, client_name, self)
|
||||
self.factory = ReplicationClientFactory(hs, client_name, self)
|
||||
host = hs.config.worker_replication_host
|
||||
port = hs.config.worker_replication_port
|
||||
hs.get_reactor().connectTCP(host, port, factory)
|
||||
hs.get_reactor().connectTCP(host, port, self.factory)
|
||||
|
||||
def on_rdata(self, stream_name, token, rows):
|
||||
"""Called when we get new replication data. By default this just pokes
|
||||
@@ -140,6 +142,7 @@ class ReplicationClientHandler(object):
|
||||
args["account_data"] = user_account_data
|
||||
elif room_account_data:
|
||||
args["account_data"] = room_account_data
|
||||
|
||||
return args
|
||||
|
||||
def get_currently_syncing_users(self):
|
||||
@@ -204,3 +207,14 @@ class ReplicationClientHandler(object):
|
||||
for cmd in self.pending_commands:
|
||||
connection.send_command(cmd)
|
||||
self.pending_commands = []
|
||||
|
||||
def finished_connecting(self):
|
||||
"""Called when we have successfully subscribed and caught up to all
|
||||
streams we're interested in.
|
||||
"""
|
||||
logger.info("Finished connecting to server")
|
||||
|
||||
# We don't reset the delay any earlier as otherwise if there is a
|
||||
# problem during start up we'll end up tight looping connecting to the
|
||||
# server.
|
||||
self.factory.resetDelay()
|
||||
|
||||
@@ -127,8 +127,11 @@ class RdataCommand(Command):
|
||||
|
||||
|
||||
class PositionCommand(Command):
|
||||
"""Sent by the client to tell the client the stream postition without
|
||||
"""Sent by the server to tell the client the stream postition without
|
||||
needing to send an RDATA.
|
||||
|
||||
Sent to the client after all missing updates for a stream have been sent
|
||||
to the client and they're now up to date.
|
||||
"""
|
||||
NAME = "POSITION"
|
||||
|
||||
|
||||
@@ -526,6 +526,11 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
||||
self.server_name = server_name
|
||||
self.handler = handler
|
||||
|
||||
# Set of stream names that have been subscribe to, but haven't yet
|
||||
# caught up with. This is used to track when the client has been fully
|
||||
# connected to the remote.
|
||||
self.streams_connecting = set()
|
||||
|
||||
# Map of stream to batched updates. See RdataCommand for info on how
|
||||
# batching works.
|
||||
self.pending_batches = {}
|
||||
@@ -548,6 +553,10 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
||||
# We've now finished connecting to so inform the client handler
|
||||
self.handler.update_connection(self)
|
||||
|
||||
# This will happen if we don't actually subscribe to any streams
|
||||
if not self.streams_connecting:
|
||||
self.handler.finished_connecting()
|
||||
|
||||
def on_SERVER(self, cmd):
|
||||
if cmd.data != self.server_name:
|
||||
logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
|
||||
@@ -577,6 +586,12 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
||||
return self.handler.on_rdata(stream_name, cmd.token, rows)
|
||||
|
||||
def on_POSITION(self, cmd):
|
||||
# When we get a `POSITION` command it means we've finished getting
|
||||
# missing updates for the given stream, and are now up to date.
|
||||
self.streams_connecting.discard(cmd.stream_name)
|
||||
if not self.streams_connecting:
|
||||
self.handler.finished_connecting()
|
||||
|
||||
return self.handler.on_position(cmd.stream_name, cmd.token)
|
||||
|
||||
def on_SYNC(self, cmd):
|
||||
@@ -593,6 +608,8 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
||||
self.id(), stream_name, token
|
||||
)
|
||||
|
||||
self.streams_connecting.add(stream_name)
|
||||
|
||||
self.send_command(ReplicateCommand(stream_name, token))
|
||||
|
||||
def on_connection_closed(self):
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014-2016 OpenMarket Ltd
|
||||
# Copyright 2019 New Vector Ltd.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -213,8 +214,7 @@ def get_filename_from_headers(headers):
|
||||
Content-Disposition HTTP header.
|
||||
|
||||
Args:
|
||||
headers (twisted.web.http_headers.Headers): The HTTP
|
||||
request headers.
|
||||
headers (dict[bytes, list[bytes]]): The HTTP request headers.
|
||||
|
||||
Returns:
|
||||
A Unicode string of the filename, or None.
|
||||
@@ -225,23 +225,12 @@ def get_filename_from_headers(headers):
|
||||
if not content_disposition[0]:
|
||||
return
|
||||
|
||||
# dict of unicode: bytes, corresponding to the key value sections of the
|
||||
# Content-Disposition header.
|
||||
params = {}
|
||||
parts = content_disposition[0].split(b";")
|
||||
for i in parts:
|
||||
# Split into key-value pairs, if able
|
||||
# We don't care about things like `inline`, so throw it out
|
||||
if b"=" not in i:
|
||||
continue
|
||||
|
||||
key, value = i.strip().split(b"=")
|
||||
params[key.decode('ascii')] = value
|
||||
_, params = _parse_header(content_disposition[0])
|
||||
|
||||
upload_name = None
|
||||
|
||||
# First check if there is a valid UTF-8 filename
|
||||
upload_name_utf8 = params.get("filename*", None)
|
||||
upload_name_utf8 = params.get(b"filename*", None)
|
||||
if upload_name_utf8:
|
||||
if upload_name_utf8.lower().startswith(b"utf-8''"):
|
||||
upload_name_utf8 = upload_name_utf8[7:]
|
||||
@@ -267,12 +256,68 @@ def get_filename_from_headers(headers):
|
||||
|
||||
# If there isn't check for an ascii name.
|
||||
if not upload_name:
|
||||
upload_name_ascii = params.get("filename", None)
|
||||
upload_name_ascii = params.get(b"filename", None)
|
||||
if upload_name_ascii and is_ascii(upload_name_ascii):
|
||||
# Make sure there's no %-quoted bytes. If there is, reject it as
|
||||
# non-valid ASCII.
|
||||
if b"%" not in upload_name_ascii:
|
||||
upload_name = upload_name_ascii.decode('ascii')
|
||||
upload_name = upload_name_ascii.decode('ascii')
|
||||
|
||||
# This may be None here, indicating we did not find a matching name.
|
||||
return upload_name
|
||||
|
||||
|
||||
def _parse_header(line):
|
||||
"""Parse a Content-type like header.
|
||||
|
||||
Cargo-culted from `cgi`, but works on bytes rather than strings.
|
||||
|
||||
Args:
|
||||
line (bytes): header to be parsed
|
||||
|
||||
Returns:
|
||||
Tuple[bytes, dict[bytes, bytes]]:
|
||||
the main content-type, followed by the parameter dictionary
|
||||
"""
|
||||
parts = _parseparam(b';' + line)
|
||||
key = next(parts)
|
||||
pdict = {}
|
||||
for p in parts:
|
||||
i = p.find(b'=')
|
||||
if i >= 0:
|
||||
name = p[:i].strip().lower()
|
||||
value = p[i + 1:].strip()
|
||||
|
||||
# strip double-quotes
|
||||
if len(value) >= 2 and value[0:1] == value[-1:] == b'"':
|
||||
value = value[1:-1]
|
||||
value = value.replace(b'\\\\', b'\\').replace(b'\\"', b'"')
|
||||
pdict[name] = value
|
||||
|
||||
return key, pdict
|
||||
|
||||
|
||||
def _parseparam(s):
|
||||
"""Generator which splits the input on ;, respecting double-quoted sequences
|
||||
|
||||
Cargo-culted from `cgi`, but works on bytes rather than strings.
|
||||
|
||||
Args:
|
||||
s (bytes): header to be parsed
|
||||
|
||||
Returns:
|
||||
Iterable[bytes]: the split input
|
||||
"""
|
||||
while s[:1] == b';':
|
||||
s = s[1:]
|
||||
|
||||
# look for the next ;
|
||||
end = s.find(b';')
|
||||
|
||||
# if there is an odd number of " marks between here and the next ;, skip to the
|
||||
# next ; instead
|
||||
while end > 0 and (s.count(b'"', 0, end) - s.count(b'\\"', 0, end)) % 2:
|
||||
end = s.find(b';', end + 1)
|
||||
|
||||
if end < 0:
|
||||
end = len(s)
|
||||
f = s[:end]
|
||||
yield f.strip()
|
||||
s = s[end:]
|
||||
|
||||
@@ -51,7 +51,7 @@ from synapse.handlers.acme import AcmeHandler
|
||||
from synapse.handlers.appservice import ApplicationServicesHandler
|
||||
from synapse.handlers.auth import AuthHandler, MacaroonGenerator
|
||||
from synapse.handlers.deactivate_account import DeactivateAccountHandler
|
||||
from synapse.handlers.device import DeviceHandler
|
||||
from synapse.handlers.device import DeviceHandler, DeviceWorkerHandler
|
||||
from synapse.handlers.devicemessage import DeviceMessageHandler
|
||||
from synapse.handlers.e2e_keys import E2eKeysHandler
|
||||
from synapse.handlers.e2e_room_keys import E2eRoomKeysHandler
|
||||
@@ -307,7 +307,10 @@ class HomeServer(object):
|
||||
return MacaroonGenerator(self)
|
||||
|
||||
def build_device_handler(self):
|
||||
return DeviceHandler(self)
|
||||
if self.config.worker_app:
|
||||
return DeviceWorkerHandler(self)
|
||||
else:
|
||||
return DeviceHandler(self)
|
||||
|
||||
def build_device_message_handler(self):
|
||||
return DeviceMessageHandler(self)
|
||||
|
||||
@@ -7,9 +7,9 @@ import synapse.handlers.auth
|
||||
import synapse.handlers.deactivate_account
|
||||
import synapse.handlers.device
|
||||
import synapse.handlers.e2e_keys
|
||||
import synapse.handlers.message
|
||||
import synapse.handlers.room
|
||||
import synapse.handlers.room_member
|
||||
import synapse.handlers.message
|
||||
import synapse.handlers.set_password
|
||||
import synapse.rest.media.v1.media_repository
|
||||
import synapse.server_notices.server_notices_manager
|
||||
|
||||
@@ -19,14 +19,174 @@ from canonicaljson import json
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
from synapse.storage.background_updates import BackgroundUpdateStore
|
||||
from synapse.util.caches.expiringcache import ExpiringCache
|
||||
|
||||
from .background_updates import BackgroundUpdateStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DeviceInboxStore(BackgroundUpdateStore):
|
||||
class DeviceInboxWorkerStore(SQLBaseStore):
|
||||
def get_to_device_stream_token(self):
|
||||
return self._device_inbox_id_gen.get_current_token()
|
||||
|
||||
def get_new_messages_for_device(
|
||||
self, user_id, device_id, last_stream_id, current_stream_id, limit=100
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
user_id(str): The recipient user_id.
|
||||
device_id(str): The recipient device_id.
|
||||
current_stream_id(int): The current position of the to device
|
||||
message stream.
|
||||
Returns:
|
||||
Deferred ([dict], int): List of messages for the device and where
|
||||
in the stream the messages got to.
|
||||
"""
|
||||
has_changed = self._device_inbox_stream_cache.has_entity_changed(
|
||||
user_id, last_stream_id
|
||||
)
|
||||
if not has_changed:
|
||||
return defer.succeed(([], current_stream_id))
|
||||
|
||||
def get_new_messages_for_device_txn(txn):
|
||||
sql = (
|
||||
"SELECT stream_id, message_json FROM device_inbox"
|
||||
" WHERE user_id = ? AND device_id = ?"
|
||||
" AND ? < stream_id AND stream_id <= ?"
|
||||
" ORDER BY stream_id ASC"
|
||||
" LIMIT ?"
|
||||
)
|
||||
txn.execute(sql, (
|
||||
user_id, device_id, last_stream_id, current_stream_id, limit
|
||||
))
|
||||
messages = []
|
||||
for row in txn:
|
||||
stream_pos = row[0]
|
||||
messages.append(json.loads(row[1]))
|
||||
if len(messages) < limit:
|
||||
stream_pos = current_stream_id
|
||||
return (messages, stream_pos)
|
||||
|
||||
return self.runInteraction(
|
||||
"get_new_messages_for_device", get_new_messages_for_device_txn,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_messages_for_device(self, user_id, device_id, up_to_stream_id):
|
||||
"""
|
||||
Args:
|
||||
user_id(str): The recipient user_id.
|
||||
device_id(str): The recipient device_id.
|
||||
up_to_stream_id(int): Where to delete messages up to.
|
||||
Returns:
|
||||
A deferred that resolves to the number of messages deleted.
|
||||
"""
|
||||
# If we have cached the last stream id we've deleted up to, we can
|
||||
# check if there is likely to be anything that needs deleting
|
||||
last_deleted_stream_id = self._last_device_delete_cache.get(
|
||||
(user_id, device_id), None
|
||||
)
|
||||
if last_deleted_stream_id:
|
||||
has_changed = self._device_inbox_stream_cache.has_entity_changed(
|
||||
user_id, last_deleted_stream_id
|
||||
)
|
||||
if not has_changed:
|
||||
defer.returnValue(0)
|
||||
|
||||
def delete_messages_for_device_txn(txn):
|
||||
sql = (
|
||||
"DELETE FROM device_inbox"
|
||||
" WHERE user_id = ? AND device_id = ?"
|
||||
" AND stream_id <= ?"
|
||||
)
|
||||
txn.execute(sql, (user_id, device_id, up_to_stream_id))
|
||||
return txn.rowcount
|
||||
|
||||
count = yield self.runInteraction(
|
||||
"delete_messages_for_device", delete_messages_for_device_txn
|
||||
)
|
||||
|
||||
# Update the cache, ensuring that we only ever increase the value
|
||||
last_deleted_stream_id = self._last_device_delete_cache.get(
|
||||
(user_id, device_id), 0
|
||||
)
|
||||
self._last_device_delete_cache[(user_id, device_id)] = max(
|
||||
last_deleted_stream_id, up_to_stream_id
|
||||
)
|
||||
|
||||
defer.returnValue(count)
|
||||
|
||||
def get_new_device_msgs_for_remote(
|
||||
self, destination, last_stream_id, current_stream_id, limit=100
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
destination(str): The name of the remote server.
|
||||
last_stream_id(int|long): The last position of the device message stream
|
||||
that the server sent up to.
|
||||
current_stream_id(int|long): The current position of the device
|
||||
message stream.
|
||||
Returns:
|
||||
Deferred ([dict], int|long): List of messages for the device and where
|
||||
in the stream the messages got to.
|
||||
"""
|
||||
|
||||
has_changed = self._device_federation_outbox_stream_cache.has_entity_changed(
|
||||
destination, last_stream_id
|
||||
)
|
||||
if not has_changed or last_stream_id == current_stream_id:
|
||||
return defer.succeed(([], current_stream_id))
|
||||
|
||||
def get_new_messages_for_remote_destination_txn(txn):
|
||||
sql = (
|
||||
"SELECT stream_id, messages_json FROM device_federation_outbox"
|
||||
" WHERE destination = ?"
|
||||
" AND ? < stream_id AND stream_id <= ?"
|
||||
" ORDER BY stream_id ASC"
|
||||
" LIMIT ?"
|
||||
)
|
||||
txn.execute(sql, (
|
||||
destination, last_stream_id, current_stream_id, limit
|
||||
))
|
||||
messages = []
|
||||
for row in txn:
|
||||
stream_pos = row[0]
|
||||
messages.append(json.loads(row[1]))
|
||||
if len(messages) < limit:
|
||||
stream_pos = current_stream_id
|
||||
return (messages, stream_pos)
|
||||
|
||||
return self.runInteraction(
|
||||
"get_new_device_msgs_for_remote",
|
||||
get_new_messages_for_remote_destination_txn,
|
||||
)
|
||||
|
||||
def delete_device_msgs_for_remote(self, destination, up_to_stream_id):
|
||||
"""Used to delete messages when the remote destination acknowledges
|
||||
their receipt.
|
||||
|
||||
Args:
|
||||
destination(str): The destination server_name
|
||||
up_to_stream_id(int): Where to delete messages up to.
|
||||
Returns:
|
||||
A deferred that resolves when the messages have been deleted.
|
||||
"""
|
||||
def delete_messages_for_remote_destination_txn(txn):
|
||||
sql = (
|
||||
"DELETE FROM device_federation_outbox"
|
||||
" WHERE destination = ?"
|
||||
" AND stream_id <= ?"
|
||||
)
|
||||
txn.execute(sql, (destination, up_to_stream_id))
|
||||
|
||||
return self.runInteraction(
|
||||
"delete_device_msgs_for_remote",
|
||||
delete_messages_for_remote_destination_txn
|
||||
)
|
||||
|
||||
|
||||
class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
|
||||
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
|
||||
|
||||
def __init__(self, db_conn, hs):
|
||||
@@ -220,93 +380,6 @@ class DeviceInboxStore(BackgroundUpdateStore):
|
||||
|
||||
txn.executemany(sql, rows)
|
||||
|
||||
def get_new_messages_for_device(
|
||||
self, user_id, device_id, last_stream_id, current_stream_id, limit=100
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
user_id(str): The recipient user_id.
|
||||
device_id(str): The recipient device_id.
|
||||
current_stream_id(int): The current position of the to device
|
||||
message stream.
|
||||
Returns:
|
||||
Deferred ([dict], int): List of messages for the device and where
|
||||
in the stream the messages got to.
|
||||
"""
|
||||
has_changed = self._device_inbox_stream_cache.has_entity_changed(
|
||||
user_id, last_stream_id
|
||||
)
|
||||
if not has_changed:
|
||||
return defer.succeed(([], current_stream_id))
|
||||
|
||||
def get_new_messages_for_device_txn(txn):
|
||||
sql = (
|
||||
"SELECT stream_id, message_json FROM device_inbox"
|
||||
" WHERE user_id = ? AND device_id = ?"
|
||||
" AND ? < stream_id AND stream_id <= ?"
|
||||
" ORDER BY stream_id ASC"
|
||||
" LIMIT ?"
|
||||
)
|
||||
txn.execute(sql, (
|
||||
user_id, device_id, last_stream_id, current_stream_id, limit
|
||||
))
|
||||
messages = []
|
||||
for row in txn:
|
||||
stream_pos = row[0]
|
||||
messages.append(json.loads(row[1]))
|
||||
if len(messages) < limit:
|
||||
stream_pos = current_stream_id
|
||||
return (messages, stream_pos)
|
||||
|
||||
return self.runInteraction(
|
||||
"get_new_messages_for_device", get_new_messages_for_device_txn,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_messages_for_device(self, user_id, device_id, up_to_stream_id):
|
||||
"""
|
||||
Args:
|
||||
user_id(str): The recipient user_id.
|
||||
device_id(str): The recipient device_id.
|
||||
up_to_stream_id(int): Where to delete messages up to.
|
||||
Returns:
|
||||
A deferred that resolves to the number of messages deleted.
|
||||
"""
|
||||
# If we have cached the last stream id we've deleted up to, we can
|
||||
# check if there is likely to be anything that needs deleting
|
||||
last_deleted_stream_id = self._last_device_delete_cache.get(
|
||||
(user_id, device_id), None
|
||||
)
|
||||
if last_deleted_stream_id:
|
||||
has_changed = self._device_inbox_stream_cache.has_entity_changed(
|
||||
user_id, last_deleted_stream_id
|
||||
)
|
||||
if not has_changed:
|
||||
defer.returnValue(0)
|
||||
|
||||
def delete_messages_for_device_txn(txn):
|
||||
sql = (
|
||||
"DELETE FROM device_inbox"
|
||||
" WHERE user_id = ? AND device_id = ?"
|
||||
" AND stream_id <= ?"
|
||||
)
|
||||
txn.execute(sql, (user_id, device_id, up_to_stream_id))
|
||||
return txn.rowcount
|
||||
|
||||
count = yield self.runInteraction(
|
||||
"delete_messages_for_device", delete_messages_for_device_txn
|
||||
)
|
||||
|
||||
# Update the cache, ensuring that we only ever increase the value
|
||||
last_deleted_stream_id = self._last_device_delete_cache.get(
|
||||
(user_id, device_id), 0
|
||||
)
|
||||
self._last_device_delete_cache[(user_id, device_id)] = max(
|
||||
last_deleted_stream_id, up_to_stream_id
|
||||
)
|
||||
|
||||
defer.returnValue(count)
|
||||
|
||||
def get_all_new_device_messages(self, last_pos, current_pos, limit):
|
||||
"""
|
||||
Args:
|
||||
@@ -351,77 +424,6 @@ class DeviceInboxStore(BackgroundUpdateStore):
|
||||
"get_all_new_device_messages", get_all_new_device_messages_txn
|
||||
)
|
||||
|
||||
def get_to_device_stream_token(self):
|
||||
return self._device_inbox_id_gen.get_current_token()
|
||||
|
||||
def get_new_device_msgs_for_remote(
|
||||
self, destination, last_stream_id, current_stream_id, limit=100
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
destination(str): The name of the remote server.
|
||||
last_stream_id(int|long): The last position of the device message stream
|
||||
that the server sent up to.
|
||||
current_stream_id(int|long): The current position of the device
|
||||
message stream.
|
||||
Returns:
|
||||
Deferred ([dict], int|long): List of messages for the device and where
|
||||
in the stream the messages got to.
|
||||
"""
|
||||
|
||||
has_changed = self._device_federation_outbox_stream_cache.has_entity_changed(
|
||||
destination, last_stream_id
|
||||
)
|
||||
if not has_changed or last_stream_id == current_stream_id:
|
||||
return defer.succeed(([], current_stream_id))
|
||||
|
||||
def get_new_messages_for_remote_destination_txn(txn):
|
||||
sql = (
|
||||
"SELECT stream_id, messages_json FROM device_federation_outbox"
|
||||
" WHERE destination = ?"
|
||||
" AND ? < stream_id AND stream_id <= ?"
|
||||
" ORDER BY stream_id ASC"
|
||||
" LIMIT ?"
|
||||
)
|
||||
txn.execute(sql, (
|
||||
destination, last_stream_id, current_stream_id, limit
|
||||
))
|
||||
messages = []
|
||||
for row in txn:
|
||||
stream_pos = row[0]
|
||||
messages.append(json.loads(row[1]))
|
||||
if len(messages) < limit:
|
||||
stream_pos = current_stream_id
|
||||
return (messages, stream_pos)
|
||||
|
||||
return self.runInteraction(
|
||||
"get_new_device_msgs_for_remote",
|
||||
get_new_messages_for_remote_destination_txn,
|
||||
)
|
||||
|
||||
def delete_device_msgs_for_remote(self, destination, up_to_stream_id):
|
||||
"""Used to delete messages when the remote destination acknowledges
|
||||
their receipt.
|
||||
|
||||
Args:
|
||||
destination(str): The destination server_name
|
||||
up_to_stream_id(int): Where to delete messages up to.
|
||||
Returns:
|
||||
A deferred that resolves when the messages have been deleted.
|
||||
"""
|
||||
def delete_messages_for_remote_destination_txn(txn):
|
||||
sql = (
|
||||
"DELETE FROM device_federation_outbox"
|
||||
" WHERE destination = ?"
|
||||
" AND stream_id <= ?"
|
||||
)
|
||||
txn.execute(sql, (destination, up_to_stream_id))
|
||||
|
||||
return self.runInteraction(
|
||||
"delete_device_msgs_for_remote",
|
||||
delete_messages_for_remote_destination_txn
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _background_drop_index_device_inbox(self, progress, batch_size):
|
||||
def reindex_txn(conn):
|
||||
|
||||
@@ -22,11 +22,10 @@ from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import StoreError
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.storage._base import Cache, SQLBaseStore, db_to_json
|
||||
from synapse.storage.background_updates import BackgroundUpdateStore
|
||||
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
|
||||
|
||||
from ._base import Cache, db_to_json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = (
|
||||
@@ -34,7 +33,343 @@ DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = (
|
||||
)
|
||||
|
||||
|
||||
class DeviceStore(BackgroundUpdateStore):
|
||||
class DeviceWorkerStore(SQLBaseStore):
|
||||
def get_device(self, user_id, device_id):
|
||||
"""Retrieve a device.
|
||||
|
||||
Args:
|
||||
user_id (str): The ID of the user which owns the device
|
||||
device_id (str): The ID of the device to retrieve
|
||||
Returns:
|
||||
defer.Deferred for a dict containing the device information
|
||||
Raises:
|
||||
StoreError: if the device is not found
|
||||
"""
|
||||
return self._simple_select_one(
|
||||
table="devices",
|
||||
keyvalues={"user_id": user_id, "device_id": device_id},
|
||||
retcols=("user_id", "device_id", "display_name"),
|
||||
desc="get_device",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_devices_by_user(self, user_id):
|
||||
"""Retrieve all of a user's registered devices.
|
||||
|
||||
Args:
|
||||
user_id (str):
|
||||
Returns:
|
||||
defer.Deferred: resolves to a dict from device_id to a dict
|
||||
containing "device_id", "user_id" and "display_name" for each
|
||||
device.
|
||||
"""
|
||||
devices = yield self._simple_select_list(
|
||||
table="devices",
|
||||
keyvalues={"user_id": user_id},
|
||||
retcols=("user_id", "device_id", "display_name"),
|
||||
desc="get_devices_by_user"
|
||||
)
|
||||
|
||||
defer.returnValue({d["device_id"]: d for d in devices})
|
||||
|
||||
def get_devices_by_remote(self, destination, from_stream_id):
|
||||
"""Get stream of updates to send to remote servers
|
||||
|
||||
Returns:
|
||||
(int, list[dict]): current stream id and list of updates
|
||||
"""
|
||||
now_stream_id = self._device_list_id_gen.get_current_token()
|
||||
|
||||
has_changed = self._device_list_federation_stream_cache.has_entity_changed(
|
||||
destination, int(from_stream_id)
|
||||
)
|
||||
if not has_changed:
|
||||
return (now_stream_id, [])
|
||||
|
||||
return self.runInteraction(
|
||||
"get_devices_by_remote", self._get_devices_by_remote_txn,
|
||||
destination, from_stream_id, now_stream_id,
|
||||
)
|
||||
|
||||
def _get_devices_by_remote_txn(self, txn, destination, from_stream_id,
|
||||
now_stream_id):
|
||||
sql = """
|
||||
SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes
|
||||
WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
|
||||
GROUP BY user_id, device_id
|
||||
LIMIT 20
|
||||
"""
|
||||
txn.execute(
|
||||
sql, (destination, from_stream_id, now_stream_id, False)
|
||||
)
|
||||
|
||||
# maps (user_id, device_id) -> stream_id
|
||||
query_map = {(r[0], r[1]): r[2] for r in txn}
|
||||
if not query_map:
|
||||
return (now_stream_id, [])
|
||||
|
||||
if len(query_map) >= 20:
|
||||
now_stream_id = max(stream_id for stream_id in itervalues(query_map))
|
||||
|
||||
devices = self._get_e2e_device_keys_txn(
|
||||
txn, query_map.keys(), include_all_devices=True, include_deleted_devices=True
|
||||
)
|
||||
|
||||
prev_sent_id_sql = """
|
||||
SELECT coalesce(max(stream_id), 0) as stream_id
|
||||
FROM device_lists_outbound_last_success
|
||||
WHERE destination = ? AND user_id = ? AND stream_id <= ?
|
||||
"""
|
||||
|
||||
results = []
|
||||
for user_id, user_devices in iteritems(devices):
|
||||
# The prev_id for the first row is always the last row before
|
||||
# `from_stream_id`
|
||||
txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id))
|
||||
rows = txn.fetchall()
|
||||
prev_id = rows[0][0]
|
||||
for device_id, device in iteritems(user_devices):
|
||||
stream_id = query_map[(user_id, device_id)]
|
||||
result = {
|
||||
"user_id": user_id,
|
||||
"device_id": device_id,
|
||||
"prev_id": [prev_id] if prev_id else [],
|
||||
"stream_id": stream_id,
|
||||
}
|
||||
|
||||
prev_id = stream_id
|
||||
|
||||
if device is not None:
|
||||
key_json = device.get("key_json", None)
|
||||
if key_json:
|
||||
result["keys"] = db_to_json(key_json)
|
||||
device_display_name = device.get("device_display_name", None)
|
||||
if device_display_name:
|
||||
result["device_display_name"] = device_display_name
|
||||
else:
|
||||
result["deleted"] = True
|
||||
|
||||
results.append(result)
|
||||
|
||||
return (now_stream_id, results)
|
||||
|
||||
def mark_as_sent_devices_by_remote(self, destination, stream_id):
|
||||
"""Mark that updates have successfully been sent to the destination.
|
||||
"""
|
||||
return self.runInteraction(
|
||||
"mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn,
|
||||
destination, stream_id,
|
||||
)
|
||||
|
||||
def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id):
|
||||
# We update the device_lists_outbound_last_success with the successfully
|
||||
# poked users. We do the join to see which users need to be inserted and
|
||||
# which updated.
|
||||
sql = """
|
||||
SELECT user_id, coalesce(max(o.stream_id), 0), (max(s.stream_id) IS NOT NULL)
|
||||
FROM device_lists_outbound_pokes as o
|
||||
LEFT JOIN device_lists_outbound_last_success as s
|
||||
USING (destination, user_id)
|
||||
WHERE destination = ? AND o.stream_id <= ?
|
||||
GROUP BY user_id
|
||||
"""
|
||||
txn.execute(sql, (destination, stream_id,))
|
||||
rows = txn.fetchall()
|
||||
|
||||
sql = """
|
||||
UPDATE device_lists_outbound_last_success
|
||||
SET stream_id = ?
|
||||
WHERE destination = ? AND user_id = ?
|
||||
"""
|
||||
txn.executemany(
|
||||
sql, ((row[1], destination, row[0],) for row in rows if row[2])
|
||||
)
|
||||
|
||||
sql = """
|
||||
INSERT INTO device_lists_outbound_last_success
|
||||
(destination, user_id, stream_id) VALUES (?, ?, ?)
|
||||
"""
|
||||
txn.executemany(
|
||||
sql, ((destination, row[0], row[1],) for row in rows if not row[2])
|
||||
)
|
||||
|
||||
# Delete all sent outbound pokes
|
||||
sql = """
|
||||
DELETE FROM device_lists_outbound_pokes
|
||||
WHERE destination = ? AND stream_id <= ?
|
||||
"""
|
||||
txn.execute(sql, (destination, stream_id,))
|
||||
|
||||
def get_device_stream_token(self):
|
||||
return self._device_list_id_gen.get_current_token()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_user_devices_from_cache(self, query_list):
|
||||
"""Get the devices (and keys if any) for remote users from the cache.
|
||||
|
||||
Args:
|
||||
query_list(list): List of (user_id, device_ids), if device_ids is
|
||||
falsey then return all device ids for that user.
|
||||
|
||||
Returns:
|
||||
(user_ids_not_in_cache, results_map), where user_ids_not_in_cache is
|
||||
a set of user_ids and results_map is a mapping of
|
||||
user_id -> device_id -> device_info
|
||||
"""
|
||||
user_ids = set(user_id for user_id, _ in query_list)
|
||||
user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids))
|
||||
user_ids_in_cache = set(
|
||||
user_id for user_id, stream_id in user_map.items() if stream_id
|
||||
)
|
||||
user_ids_not_in_cache = user_ids - user_ids_in_cache
|
||||
|
||||
results = {}
|
||||
for user_id, device_id in query_list:
|
||||
if user_id not in user_ids_in_cache:
|
||||
continue
|
||||
|
||||
if device_id:
|
||||
device = yield self._get_cached_user_device(user_id, device_id)
|
||||
results.setdefault(user_id, {})[device_id] = device
|
||||
else:
|
||||
results[user_id] = yield self._get_cached_devices_for_user(user_id)
|
||||
|
||||
defer.returnValue((user_ids_not_in_cache, results))
|
||||
|
||||
@cachedInlineCallbacks(num_args=2, tree=True)
|
||||
def _get_cached_user_device(self, user_id, device_id):
|
||||
content = yield self._simple_select_one_onecol(
|
||||
table="device_lists_remote_cache",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
"device_id": device_id,
|
||||
},
|
||||
retcol="content",
|
||||
desc="_get_cached_user_device",
|
||||
)
|
||||
defer.returnValue(db_to_json(content))
|
||||
|
||||
@cachedInlineCallbacks()
|
||||
def _get_cached_devices_for_user(self, user_id):
|
||||
devices = yield self._simple_select_list(
|
||||
table="device_lists_remote_cache",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
},
|
||||
retcols=("device_id", "content"),
|
||||
desc="_get_cached_devices_for_user",
|
||||
)
|
||||
defer.returnValue({
|
||||
device["device_id"]: db_to_json(device["content"])
|
||||
for device in devices
|
||||
})
|
||||
|
||||
def get_devices_with_keys_by_user(self, user_id):
|
||||
"""Get all devices (with any device keys) for a user
|
||||
|
||||
Returns:
|
||||
(stream_id, devices)
|
||||
"""
|
||||
return self.runInteraction(
|
||||
"get_devices_with_keys_by_user",
|
||||
self._get_devices_with_keys_by_user_txn, user_id,
|
||||
)
|
||||
|
||||
def _get_devices_with_keys_by_user_txn(self, txn, user_id):
|
||||
now_stream_id = self._device_list_id_gen.get_current_token()
|
||||
|
||||
devices = self._get_e2e_device_keys_txn(
|
||||
txn, [(user_id, None)], include_all_devices=True
|
||||
)
|
||||
|
||||
if devices:
|
||||
user_devices = devices[user_id]
|
||||
results = []
|
||||
for device_id, device in iteritems(user_devices):
|
||||
result = {
|
||||
"device_id": device_id,
|
||||
}
|
||||
|
||||
key_json = device.get("key_json", None)
|
||||
if key_json:
|
||||
result["keys"] = db_to_json(key_json)
|
||||
device_display_name = device.get("device_display_name", None)
|
||||
if device_display_name:
|
||||
result["device_display_name"] = device_display_name
|
||||
|
||||
results.append(result)
|
||||
|
||||
return now_stream_id, results
|
||||
|
||||
return now_stream_id, []
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_user_whose_devices_changed(self, from_key):
|
||||
"""Get set of users whose devices have changed since `from_key`.
|
||||
"""
|
||||
from_key = int(from_key)
|
||||
changed = self._device_list_stream_cache.get_all_entities_changed(from_key)
|
||||
if changed is not None:
|
||||
defer.returnValue(set(changed))
|
||||
|
||||
sql = """
|
||||
SELECT DISTINCT user_id FROM device_lists_stream WHERE stream_id > ?
|
||||
"""
|
||||
rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key)
|
||||
defer.returnValue(set(row[0] for row in rows))
|
||||
|
||||
def get_all_device_list_changes_for_remotes(self, from_key, to_key):
|
||||
"""Return a list of `(stream_id, user_id, destination)` which is the
|
||||
combined list of changes to devices, and which destinations need to be
|
||||
poked. `destination` may be None if no destinations need to be poked.
|
||||
"""
|
||||
# We do a group by here as there can be a large number of duplicate
|
||||
# entries, since we throw away device IDs.
|
||||
sql = """
|
||||
SELECT MAX(stream_id) AS stream_id, user_id, destination
|
||||
FROM device_lists_stream
|
||||
LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id)
|
||||
WHERE ? < stream_id AND stream_id <= ?
|
||||
GROUP BY user_id, destination
|
||||
"""
|
||||
return self._execute(
|
||||
"get_all_device_list_changes_for_remotes", None,
|
||||
sql, from_key, to_key
|
||||
)
|
||||
|
||||
@cached(max_entries=10000)
|
||||
def get_device_list_last_stream_id_for_remote(self, user_id):
|
||||
"""Get the last stream_id we got for a user. May be None if we haven't
|
||||
got any information for them.
|
||||
"""
|
||||
return self._simple_select_one_onecol(
|
||||
table="device_lists_remote_extremeties",
|
||||
keyvalues={"user_id": user_id},
|
||||
retcol="stream_id",
|
||||
desc="get_device_list_last_stream_id_for_remote",
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
@cachedList(cached_method_name="get_device_list_last_stream_id_for_remote",
|
||||
list_name="user_ids", inlineCallbacks=True)
|
||||
def get_device_list_last_stream_id_for_remotes(self, user_ids):
|
||||
rows = yield self._simple_select_many_batch(
|
||||
table="device_lists_remote_extremeties",
|
||||
column="user_id",
|
||||
iterable=user_ids,
|
||||
retcols=("user_id", "stream_id",),
|
||||
desc="get_device_list_last_stream_id_for_remotes",
|
||||
)
|
||||
|
||||
results = {user_id: None for user_id in user_ids}
|
||||
results.update({
|
||||
row["user_id"]: row["stream_id"] for row in rows
|
||||
})
|
||||
|
||||
defer.returnValue(results)
|
||||
|
||||
|
||||
class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
|
||||
def __init__(self, db_conn, hs):
|
||||
super(DeviceStore, self).__init__(db_conn, hs)
|
||||
|
||||
@@ -121,24 +456,6 @@ class DeviceStore(BackgroundUpdateStore):
|
||||
initial_device_display_name, e)
|
||||
raise StoreError(500, "Problem storing device.")
|
||||
|
||||
def get_device(self, user_id, device_id):
|
||||
"""Retrieve a device.
|
||||
|
||||
Args:
|
||||
user_id (str): The ID of the user which owns the device
|
||||
device_id (str): The ID of the device to retrieve
|
||||
Returns:
|
||||
defer.Deferred for a dict containing the device information
|
||||
Raises:
|
||||
StoreError: if the device is not found
|
||||
"""
|
||||
return self._simple_select_one(
|
||||
table="devices",
|
||||
keyvalues={"user_id": user_id, "device_id": device_id},
|
||||
retcols=("user_id", "device_id", "display_name"),
|
||||
desc="get_device",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_device(self, user_id, device_id):
|
||||
"""Delete a device.
|
||||
@@ -202,57 +519,6 @@ class DeviceStore(BackgroundUpdateStore):
|
||||
desc="update_device",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_devices_by_user(self, user_id):
|
||||
"""Retrieve all of a user's registered devices.
|
||||
|
||||
Args:
|
||||
user_id (str):
|
||||
Returns:
|
||||
defer.Deferred: resolves to a dict from device_id to a dict
|
||||
containing "device_id", "user_id" and "display_name" for each
|
||||
device.
|
||||
"""
|
||||
devices = yield self._simple_select_list(
|
||||
table="devices",
|
||||
keyvalues={"user_id": user_id},
|
||||
retcols=("user_id", "device_id", "display_name"),
|
||||
desc="get_devices_by_user"
|
||||
)
|
||||
|
||||
defer.returnValue({d["device_id"]: d for d in devices})
|
||||
|
||||
@cached(max_entries=10000)
|
||||
def get_device_list_last_stream_id_for_remote(self, user_id):
|
||||
"""Get the last stream_id we got for a user. May be None if we haven't
|
||||
got any information for them.
|
||||
"""
|
||||
return self._simple_select_one_onecol(
|
||||
table="device_lists_remote_extremeties",
|
||||
keyvalues={"user_id": user_id},
|
||||
retcol="stream_id",
|
||||
desc="get_device_list_remote_extremity",
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
@cachedList(cached_method_name="get_device_list_last_stream_id_for_remote",
|
||||
list_name="user_ids", inlineCallbacks=True)
|
||||
def get_device_list_last_stream_id_for_remotes(self, user_ids):
|
||||
rows = yield self._simple_select_many_batch(
|
||||
table="device_lists_remote_extremeties",
|
||||
column="user_id",
|
||||
iterable=user_ids,
|
||||
retcols=("user_id", "stream_id",),
|
||||
desc="get_user_devices_from_cache",
|
||||
)
|
||||
|
||||
results = {user_id: None for user_id in user_ids}
|
||||
results.update({
|
||||
row["user_id"]: row["stream_id"] for row in rows
|
||||
})
|
||||
|
||||
defer.returnValue(results)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def mark_remote_user_device_list_as_unsubscribed(self, user_id):
|
||||
"""Mark that we no longer track device lists for remote user.
|
||||
@@ -405,268 +671,6 @@ class DeviceStore(BackgroundUpdateStore):
|
||||
lock=False,
|
||||
)
|
||||
|
||||
def get_devices_by_remote(self, destination, from_stream_id):
|
||||
"""Get stream of updates to send to remote servers
|
||||
|
||||
Returns:
|
||||
(int, list[dict]): current stream id and list of updates
|
||||
"""
|
||||
now_stream_id = self._device_list_id_gen.get_current_token()
|
||||
|
||||
has_changed = self._device_list_federation_stream_cache.has_entity_changed(
|
||||
destination, int(from_stream_id)
|
||||
)
|
||||
if not has_changed:
|
||||
return (now_stream_id, [])
|
||||
|
||||
return self.runInteraction(
|
||||
"get_devices_by_remote", self._get_devices_by_remote_txn,
|
||||
destination, from_stream_id, now_stream_id,
|
||||
)
|
||||
|
||||
def _get_devices_by_remote_txn(self, txn, destination, from_stream_id,
|
||||
now_stream_id):
|
||||
sql = """
|
||||
SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes
|
||||
WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
|
||||
GROUP BY user_id, device_id
|
||||
LIMIT 20
|
||||
"""
|
||||
txn.execute(
|
||||
sql, (destination, from_stream_id, now_stream_id, False)
|
||||
)
|
||||
|
||||
# maps (user_id, device_id) -> stream_id
|
||||
query_map = {(r[0], r[1]): r[2] for r in txn}
|
||||
if not query_map:
|
||||
return (now_stream_id, [])
|
||||
|
||||
if len(query_map) >= 20:
|
||||
now_stream_id = max(stream_id for stream_id in itervalues(query_map))
|
||||
|
||||
devices = self._get_e2e_device_keys_txn(
|
||||
txn, query_map.keys(), include_all_devices=True, include_deleted_devices=True
|
||||
)
|
||||
|
||||
prev_sent_id_sql = """
|
||||
SELECT coalesce(max(stream_id), 0) as stream_id
|
||||
FROM device_lists_outbound_last_success
|
||||
WHERE destination = ? AND user_id = ? AND stream_id <= ?
|
||||
"""
|
||||
|
||||
results = []
|
||||
for user_id, user_devices in iteritems(devices):
|
||||
# The prev_id for the first row is always the last row before
|
||||
# `from_stream_id`
|
||||
txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id))
|
||||
rows = txn.fetchall()
|
||||
prev_id = rows[0][0]
|
||||
for device_id, device in iteritems(user_devices):
|
||||
stream_id = query_map[(user_id, device_id)]
|
||||
result = {
|
||||
"user_id": user_id,
|
||||
"device_id": device_id,
|
||||
"prev_id": [prev_id] if prev_id else [],
|
||||
"stream_id": stream_id,
|
||||
}
|
||||
|
||||
prev_id = stream_id
|
||||
|
||||
if device is not None:
|
||||
key_json = device.get("key_json", None)
|
||||
if key_json:
|
||||
result["keys"] = db_to_json(key_json)
|
||||
device_display_name = device.get("device_display_name", None)
|
||||
if device_display_name:
|
||||
result["device_display_name"] = device_display_name
|
||||
else:
|
||||
result["deleted"] = True
|
||||
|
||||
results.append(result)
|
||||
|
||||
return (now_stream_id, results)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_user_devices_from_cache(self, query_list):
|
||||
"""Get the devices (and keys if any) for remote users from the cache.
|
||||
|
||||
Args:
|
||||
query_list(list): List of (user_id, device_ids), if device_ids is
|
||||
falsey then return all device ids for that user.
|
||||
|
||||
Returns:
|
||||
(user_ids_not_in_cache, results_map), where user_ids_not_in_cache is
|
||||
a set of user_ids and results_map is a mapping of
|
||||
user_id -> device_id -> device_info
|
||||
"""
|
||||
user_ids = set(user_id for user_id, _ in query_list)
|
||||
user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids))
|
||||
user_ids_in_cache = set(
|
||||
user_id for user_id, stream_id in user_map.items() if stream_id
|
||||
)
|
||||
user_ids_not_in_cache = user_ids - user_ids_in_cache
|
||||
|
||||
results = {}
|
||||
for user_id, device_id in query_list:
|
||||
if user_id not in user_ids_in_cache:
|
||||
continue
|
||||
|
||||
if device_id:
|
||||
device = yield self._get_cached_user_device(user_id, device_id)
|
||||
results.setdefault(user_id, {})[device_id] = device
|
||||
else:
|
||||
results[user_id] = yield self._get_cached_devices_for_user(user_id)
|
||||
|
||||
defer.returnValue((user_ids_not_in_cache, results))
|
||||
|
||||
@cachedInlineCallbacks(num_args=2, tree=True)
|
||||
def _get_cached_user_device(self, user_id, device_id):
|
||||
content = yield self._simple_select_one_onecol(
|
||||
table="device_lists_remote_cache",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
"device_id": device_id,
|
||||
},
|
||||
retcol="content",
|
||||
desc="_get_cached_user_device",
|
||||
)
|
||||
defer.returnValue(db_to_json(content))
|
||||
|
||||
@cachedInlineCallbacks()
|
||||
def _get_cached_devices_for_user(self, user_id):
|
||||
devices = yield self._simple_select_list(
|
||||
table="device_lists_remote_cache",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
},
|
||||
retcols=("device_id", "content"),
|
||||
desc="_get_cached_devices_for_user",
|
||||
)
|
||||
defer.returnValue({
|
||||
device["device_id"]: db_to_json(device["content"])
|
||||
for device in devices
|
||||
})
|
||||
|
||||
def get_devices_with_keys_by_user(self, user_id):
|
||||
"""Get all devices (with any device keys) for a user
|
||||
|
||||
Returns:
|
||||
(stream_id, devices)
|
||||
"""
|
||||
return self.runInteraction(
|
||||
"get_devices_with_keys_by_user",
|
||||
self._get_devices_with_keys_by_user_txn, user_id,
|
||||
)
|
||||
|
||||
def _get_devices_with_keys_by_user_txn(self, txn, user_id):
|
||||
now_stream_id = self._device_list_id_gen.get_current_token()
|
||||
|
||||
devices = self._get_e2e_device_keys_txn(
|
||||
txn, [(user_id, None)], include_all_devices=True
|
||||
)
|
||||
|
||||
if devices:
|
||||
user_devices = devices[user_id]
|
||||
results = []
|
||||
for device_id, device in iteritems(user_devices):
|
||||
result = {
|
||||
"device_id": device_id,
|
||||
}
|
||||
|
||||
key_json = device.get("key_json", None)
|
||||
if key_json:
|
||||
result["keys"] = db_to_json(key_json)
|
||||
device_display_name = device.get("device_display_name", None)
|
||||
if device_display_name:
|
||||
result["device_display_name"] = device_display_name
|
||||
|
||||
results.append(result)
|
||||
|
||||
return now_stream_id, results
|
||||
|
||||
return now_stream_id, []
|
||||
|
||||
def mark_as_sent_devices_by_remote(self, destination, stream_id):
|
||||
"""Mark that updates have successfully been sent to the destination.
|
||||
"""
|
||||
return self.runInteraction(
|
||||
"mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn,
|
||||
destination, stream_id,
|
||||
)
|
||||
|
||||
def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id):
|
||||
# We update the device_lists_outbound_last_success with the successfully
|
||||
# poked users. We do the join to see which users need to be inserted and
|
||||
# which updated.
|
||||
sql = """
|
||||
SELECT user_id, coalesce(max(o.stream_id), 0), (max(s.stream_id) IS NOT NULL)
|
||||
FROM device_lists_outbound_pokes as o
|
||||
LEFT JOIN device_lists_outbound_last_success as s
|
||||
USING (destination, user_id)
|
||||
WHERE destination = ? AND o.stream_id <= ?
|
||||
GROUP BY user_id
|
||||
"""
|
||||
txn.execute(sql, (destination, stream_id,))
|
||||
rows = txn.fetchall()
|
||||
|
||||
sql = """
|
||||
UPDATE device_lists_outbound_last_success
|
||||
SET stream_id = ?
|
||||
WHERE destination = ? AND user_id = ?
|
||||
"""
|
||||
txn.executemany(
|
||||
sql, ((row[1], destination, row[0],) for row in rows if row[2])
|
||||
)
|
||||
|
||||
sql = """
|
||||
INSERT INTO device_lists_outbound_last_success
|
||||
(destination, user_id, stream_id) VALUES (?, ?, ?)
|
||||
"""
|
||||
txn.executemany(
|
||||
sql, ((destination, row[0], row[1],) for row in rows if not row[2])
|
||||
)
|
||||
|
||||
# Delete all sent outbound pokes
|
||||
sql = """
|
||||
DELETE FROM device_lists_outbound_pokes
|
||||
WHERE destination = ? AND stream_id <= ?
|
||||
"""
|
||||
txn.execute(sql, (destination, stream_id,))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_user_whose_devices_changed(self, from_key):
|
||||
"""Get set of users whose devices have changed since `from_key`.
|
||||
"""
|
||||
from_key = int(from_key)
|
||||
changed = self._device_list_stream_cache.get_all_entities_changed(from_key)
|
||||
if changed is not None:
|
||||
defer.returnValue(set(changed))
|
||||
|
||||
sql = """
|
||||
SELECT DISTINCT user_id FROM device_lists_stream WHERE stream_id > ?
|
||||
"""
|
||||
rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key)
|
||||
defer.returnValue(set(row[0] for row in rows))
|
||||
|
||||
def get_all_device_list_changes_for_remotes(self, from_key, to_key):
|
||||
"""Return a list of `(stream_id, user_id, destination)` which is the
|
||||
combined list of changes to devices, and which destinations need to be
|
||||
poked. `destination` may be None if no destinations need to be poked.
|
||||
"""
|
||||
# We do a group by here as there can be a large number of duplicate
|
||||
# entries, since we throw away device IDs.
|
||||
sql = """
|
||||
SELECT MAX(stream_id) AS stream_id, user_id, destination
|
||||
FROM device_lists_stream
|
||||
LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id)
|
||||
WHERE ? < stream_id AND stream_id <= ?
|
||||
GROUP BY user_id, destination
|
||||
"""
|
||||
return self._execute(
|
||||
"get_all_device_list_changes_for_remotes", None,
|
||||
sql, from_key, to_key
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def add_device_change_to_streams(self, user_id, device_ids, hosts):
|
||||
"""Persist that a user's devices have been updated, and which hosts
|
||||
@@ -732,9 +736,6 @@ class DeviceStore(BackgroundUpdateStore):
|
||||
]
|
||||
)
|
||||
|
||||
def get_device_stream_token(self):
|
||||
return self._device_list_id_gen.get_current_token()
|
||||
|
||||
def _prune_old_outbound_device_pokes(self):
|
||||
"""Delete old entries out of the device_lists_outbound_pokes to ensure
|
||||
that we don't fill up due to dead servers. We keep one entry per
|
||||
|
||||
@@ -23,49 +23,7 @@ from synapse.util.caches.descriptors import cached
|
||||
from ._base import SQLBaseStore, db_to_json
|
||||
|
||||
|
||||
class EndToEndKeyStore(SQLBaseStore):
|
||||
def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
|
||||
"""Stores device keys for a device. Returns whether there was a change
|
||||
or the keys were already in the database.
|
||||
"""
|
||||
def _set_e2e_device_keys_txn(txn):
|
||||
old_key_json = self._simple_select_one_onecol_txn(
|
||||
txn,
|
||||
table="e2e_device_keys_json",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
"device_id": device_id,
|
||||
},
|
||||
retcol="key_json",
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
# In py3 we need old_key_json to match new_key_json type. The DB
|
||||
# returns unicode while encode_canonical_json returns bytes.
|
||||
new_key_json = encode_canonical_json(device_keys).decode("utf-8")
|
||||
|
||||
if old_key_json == new_key_json:
|
||||
return False
|
||||
|
||||
self._simple_upsert_txn(
|
||||
txn,
|
||||
table="e2e_device_keys_json",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
"device_id": device_id,
|
||||
},
|
||||
values={
|
||||
"ts_added_ms": time_now,
|
||||
"key_json": new_key_json,
|
||||
}
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
return self.runInteraction(
|
||||
"set_e2e_device_keys", _set_e2e_device_keys_txn
|
||||
)
|
||||
|
||||
class EndToEndKeyWorkerStore(SQLBaseStore):
|
||||
@defer.inlineCallbacks
|
||||
def get_e2e_device_keys(
|
||||
self, query_list, include_all_devices=False,
|
||||
@@ -238,6 +196,50 @@ class EndToEndKeyStore(SQLBaseStore):
|
||||
"count_e2e_one_time_keys", _count_e2e_one_time_keys
|
||||
)
|
||||
|
||||
|
||||
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
||||
def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
|
||||
"""Stores device keys for a device. Returns whether there was a change
|
||||
or the keys were already in the database.
|
||||
"""
|
||||
def _set_e2e_device_keys_txn(txn):
|
||||
old_key_json = self._simple_select_one_onecol_txn(
|
||||
txn,
|
||||
table="e2e_device_keys_json",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
"device_id": device_id,
|
||||
},
|
||||
retcol="key_json",
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
# In py3 we need old_key_json to match new_key_json type. The DB
|
||||
# returns unicode while encode_canonical_json returns bytes.
|
||||
new_key_json = encode_canonical_json(device_keys).decode("utf-8")
|
||||
|
||||
if old_key_json == new_key_json:
|
||||
return False
|
||||
|
||||
self._simple_upsert_txn(
|
||||
txn,
|
||||
table="e2e_device_keys_json",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
"device_id": device_id,
|
||||
},
|
||||
values={
|
||||
"ts_added_ms": time_now,
|
||||
"key_json": new_key_json,
|
||||
}
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
return self.runInteraction(
|
||||
"set_e2e_device_keys", _set_e2e_device_keys_txn
|
||||
)
|
||||
|
||||
def claim_e2e_one_time_keys(self, query_list):
|
||||
"""Take a list of one time keys out of the database"""
|
||||
def _claim_e2e_one_time_keys(txn):
|
||||
|
||||
@@ -442,6 +442,28 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
|
||||
event_results.reverse()
|
||||
return event_results
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_successor_events(self, event_ids):
|
||||
"""Fetch all events that have the given events as a prev event
|
||||
|
||||
Args:
|
||||
event_ids (iterable[str])
|
||||
|
||||
Returns:
|
||||
Deferred[list[str]]
|
||||
"""
|
||||
rows = yield self._simple_select_many_batch(
|
||||
table="event_edges",
|
||||
column="prev_event_id",
|
||||
iterable=event_ids,
|
||||
retcols=("event_id",),
|
||||
desc="get_successor_events"
|
||||
)
|
||||
|
||||
defer.returnValue([
|
||||
row["event_id"] for row in rows
|
||||
])
|
||||
|
||||
|
||||
class EventFederationStore(EventFederationWorkerStore):
|
||||
""" Responsible for storing and serving up the various graphs associated
|
||||
|
||||
@@ -346,15 +346,23 @@ class ReceiptsStore(ReceiptsWorkerStore):
|
||||
|
||||
def insert_linearized_receipt_txn(self, txn, room_id, receipt_type,
|
||||
user_id, event_id, data, stream_id):
|
||||
"""Inserts a read-receipt into the database if it's newer than the current RR
|
||||
|
||||
Returns: int|None
|
||||
None if the RR is older than the current RR
|
||||
otherwise, the rx timestamp of the event that the RR corresponds to
|
||||
(or 0 if the event is unknown)
|
||||
"""
|
||||
res = self._simple_select_one_txn(
|
||||
txn,
|
||||
table="events",
|
||||
retcols=["topological_ordering", "stream_ordering"],
|
||||
retcols=["stream_ordering", "received_ts"],
|
||||
keyvalues={"event_id": event_id},
|
||||
allow_none=True
|
||||
)
|
||||
|
||||
stream_ordering = int(res["stream_ordering"]) if res else None
|
||||
rx_ts = res["received_ts"] if res else 0
|
||||
|
||||
# We don't want to clobber receipts for more recent events, so we
|
||||
# have to compare orderings of existing receipts
|
||||
@@ -373,7 +381,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
|
||||
"one for later event %s",
|
||||
event_id, eid,
|
||||
)
|
||||
return False
|
||||
return None
|
||||
|
||||
txn.call_after(
|
||||
self.get_receipts_for_room.invalidate, (room_id, receipt_type)
|
||||
@@ -429,7 +437,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
|
||||
stream_ordering=stream_ordering,
|
||||
)
|
||||
|
||||
return True
|
||||
return rx_ts
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def insert_receipt(self, room_id, receipt_type, user_id, event_ids, data):
|
||||
@@ -466,7 +474,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
|
||||
|
||||
stream_id_manager = self._receipts_id_gen.get_next()
|
||||
with stream_id_manager as stream_id:
|
||||
have_persisted = yield self.runInteraction(
|
||||
event_ts = yield self.runInteraction(
|
||||
"insert_linearized_receipt",
|
||||
self.insert_linearized_receipt_txn,
|
||||
room_id, receipt_type, user_id, linearized_event_id,
|
||||
@@ -474,8 +482,14 @@ class ReceiptsStore(ReceiptsWorkerStore):
|
||||
stream_id=stream_id,
|
||||
)
|
||||
|
||||
if not have_persisted:
|
||||
defer.returnValue(None)
|
||||
if event_ts is None:
|
||||
defer.returnValue(None)
|
||||
|
||||
now = self._clock.time_msec()
|
||||
logger.debug(
|
||||
"RR for event %s in %s (%i ms old)",
|
||||
linearized_event_id, room_id, now - event_ts,
|
||||
)
|
||||
|
||||
yield self.insert_graph_receipt(
|
||||
room_id, receipt_type, user_id, event_ids, data
|
||||
|
||||
@@ -295,6 +295,39 @@ class RegistrationWorkerStore(SQLBaseStore):
|
||||
return ret['user_id']
|
||||
return None
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
|
||||
yield self._simple_upsert("user_threepids", {
|
||||
"medium": medium,
|
||||
"address": address,
|
||||
}, {
|
||||
"user_id": user_id,
|
||||
"validated_at": validated_at,
|
||||
"added_at": added_at,
|
||||
})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def user_get_threepids(self, user_id):
|
||||
ret = yield self._simple_select_list(
|
||||
"user_threepids", {
|
||||
"user_id": user_id
|
||||
},
|
||||
['medium', 'address', 'validated_at', 'added_at'],
|
||||
'user_get_threepids'
|
||||
)
|
||||
defer.returnValue(ret)
|
||||
|
||||
def user_delete_threepid(self, user_id, medium, address):
|
||||
return self._simple_delete(
|
||||
"user_threepids",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
"medium": medium,
|
||||
"address": address,
|
||||
},
|
||||
desc="user_delete_threepids",
|
||||
)
|
||||
|
||||
|
||||
class RegistrationStore(RegistrationWorkerStore,
|
||||
background_updates.BackgroundUpdateStore):
|
||||
@@ -632,39 +665,6 @@ class RegistrationStore(RegistrationWorkerStore,
|
||||
|
||||
defer.returnValue(res if res else False)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
|
||||
yield self._simple_upsert("user_threepids", {
|
||||
"medium": medium,
|
||||
"address": address,
|
||||
}, {
|
||||
"user_id": user_id,
|
||||
"validated_at": validated_at,
|
||||
"added_at": added_at,
|
||||
})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def user_get_threepids(self, user_id):
|
||||
ret = yield self._simple_select_list(
|
||||
"user_threepids", {
|
||||
"user_id": user_id
|
||||
},
|
||||
['medium', 'address', 'validated_at', 'added_at'],
|
||||
'user_get_threepids'
|
||||
)
|
||||
defer.returnValue(ret)
|
||||
|
||||
def user_delete_threepid(self, user_id, medium, address):
|
||||
return self._simple_delete(
|
||||
"user_threepids",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
"medium": medium,
|
||||
"address": address,
|
||||
},
|
||||
desc="user_delete_threepids",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def save_or_get_3pid_guest_access_token(
|
||||
self, medium, address, access_token, inviter_user_id
|
||||
|
||||
@@ -216,28 +216,36 @@ def filter_events_for_client(store, user_id, events, is_peeking=False,
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def filter_events_for_server(store, server_name, events):
|
||||
# Whatever else we do, we need to check for senders which have requested
|
||||
# erasure of their data.
|
||||
erased_senders = yield store.are_users_erased(
|
||||
(e.sender for e in events),
|
||||
)
|
||||
def filter_events_for_server(store, server_name, events, redact=True,
|
||||
check_history_visibility_only=False):
|
||||
"""Filter a list of events based on whether given server is allowed to
|
||||
see them.
|
||||
|
||||
def redact_disallowed(event, state):
|
||||
# if the sender has been gdpr17ed, always return a redacted
|
||||
# copy of the event.
|
||||
if erased_senders[event.sender]:
|
||||
Args:
|
||||
store (DataStore)
|
||||
server_name (str)
|
||||
events (iterable[FrozenEvent])
|
||||
redact (bool): Whether to return a redacted version of the event, or
|
||||
to filter them out entirely.
|
||||
check_history_visibility_only (bool): Whether to only check the
|
||||
history visibility, rather than things like if the sender has been
|
||||
erased. This is used e.g. during pagination to decide whether to
|
||||
backfill or not.
|
||||
|
||||
Returns
|
||||
Deferred[list[FrozenEvent]]
|
||||
"""
|
||||
|
||||
def is_sender_erased(event, erased_senders):
|
||||
if erased_senders and erased_senders[event.sender]:
|
||||
logger.info(
|
||||
"Sender of %s has been erased, redacting",
|
||||
event.event_id,
|
||||
)
|
||||
return prune_event(event)
|
||||
|
||||
# state will be None if we decided we didn't need to filter by
|
||||
# room membership.
|
||||
if not state:
|
||||
return event
|
||||
return True
|
||||
return False
|
||||
|
||||
def check_event_is_visible(event, state):
|
||||
history = state.get((EventTypes.RoomHistoryVisibility, ''), None)
|
||||
if history:
|
||||
visibility = history.content.get("history_visibility", "shared")
|
||||
@@ -259,17 +267,17 @@ def filter_events_for_server(store, server_name, events):
|
||||
|
||||
memtype = ev.membership
|
||||
if memtype == Membership.JOIN:
|
||||
return event
|
||||
return True
|
||||
elif memtype == Membership.INVITE:
|
||||
if visibility == "invited":
|
||||
return event
|
||||
return True
|
||||
else:
|
||||
# server has no users in the room: redact
|
||||
return prune_event(event)
|
||||
return False
|
||||
|
||||
return event
|
||||
return True
|
||||
|
||||
# Next lets check to see if all the events have a history visibility
|
||||
# Lets check to see if all the events have a history visibility
|
||||
# of "shared" or "world_readable". If thats the case then we don't
|
||||
# need to check membership (as we know the server is in the room).
|
||||
event_to_state_ids = yield store.get_state_ids_for_events(
|
||||
@@ -296,16 +304,31 @@ def filter_events_for_server(store, server_name, events):
|
||||
for e in itervalues(event_map)
|
||||
)
|
||||
|
||||
if not check_history_visibility_only:
|
||||
erased_senders = yield store.are_users_erased(
|
||||
(e.sender for e in events),
|
||||
)
|
||||
else:
|
||||
# We don't want to check whether users are erased, which is equivalent
|
||||
# to no users having been erased.
|
||||
erased_senders = {}
|
||||
|
||||
if all_open:
|
||||
# all the history_visibility state affecting these events is open, so
|
||||
# we don't need to filter by membership state. We *do* need to check
|
||||
# for user erasure, though.
|
||||
if erased_senders:
|
||||
events = [
|
||||
redact_disallowed(e, None)
|
||||
for e in events
|
||||
]
|
||||
to_return = []
|
||||
for e in events:
|
||||
if not is_sender_erased(e, erased_senders):
|
||||
to_return.append(e)
|
||||
elif redact:
|
||||
to_return.append(prune_event(e))
|
||||
|
||||
defer.returnValue(to_return)
|
||||
|
||||
# If there are no erased users then we can just return the given list
|
||||
# of events without having to copy it.
|
||||
defer.returnValue(events)
|
||||
|
||||
# Ok, so we're dealing with events that have non-trivial visibility
|
||||
@@ -361,7 +384,13 @@ def filter_events_for_server(store, server_name, events):
|
||||
for e_id, key_to_eid in iteritems(event_to_state_ids)
|
||||
}
|
||||
|
||||
defer.returnValue([
|
||||
redact_disallowed(e, event_to_state[e.event_id])
|
||||
for e in events
|
||||
])
|
||||
to_return = []
|
||||
for e in events:
|
||||
erased = is_sender_erased(e, erased_senders)
|
||||
visible = check_event_is_visible(e, event_to_state[e.event_id])
|
||||
if visible and not erased:
|
||||
to_return.append(e)
|
||||
elif redact:
|
||||
to_return.append(prune_event(e))
|
||||
|
||||
defer.returnValue(to_return)
|
||||
|
||||
@@ -24,13 +24,17 @@ from synapse.api.errors import AuthError
|
||||
from synapse.types import UserID
|
||||
|
||||
from tests import unittest
|
||||
from tests.utils import register_federation_servlets
|
||||
|
||||
from ..utils import (
|
||||
DeferredMockCallable,
|
||||
MockClock,
|
||||
MockHttpResource,
|
||||
setup_test_homeserver,
|
||||
)
|
||||
# Some local users to test with
|
||||
U_APPLE = UserID.from_string("@apple:test")
|
||||
U_BANANA = UserID.from_string("@banana:test")
|
||||
|
||||
# Remote user
|
||||
U_ONION = UserID.from_string("@onion:farm")
|
||||
|
||||
# Test room id
|
||||
ROOM_ID = "a-room"
|
||||
|
||||
|
||||
def _expect_edu_transaction(edu_type, content, origin="test"):
|
||||
@@ -46,30 +50,21 @@ def _make_edu_transaction_json(edu_type, content):
|
||||
return json.dumps(_expect_edu_transaction(edu_type, content)).encode('utf8')
|
||||
|
||||
|
||||
class TypingNotificationsTestCase(unittest.TestCase):
|
||||
"""Tests typing notifications to rooms."""
|
||||
class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
||||
servlets = [register_federation_servlets]
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def setUp(self):
|
||||
self.clock = MockClock()
|
||||
def make_homeserver(self, reactor, clock):
|
||||
# we mock out the keyring so as to skip the authentication check on the
|
||||
# federation API call.
|
||||
mock_keyring = Mock(spec=["verify_json_for_server"])
|
||||
mock_keyring.verify_json_for_server.return_value = defer.succeed(True)
|
||||
|
||||
self.mock_http_client = Mock(spec=[])
|
||||
self.mock_http_client.put_json = DeferredMockCallable()
|
||||
# we mock out the federation client too
|
||||
mock_federation_client = Mock(spec=["put_json"])
|
||||
mock_federation_client.put_json.return_value = defer.succeed((200, "OK"))
|
||||
|
||||
self.mock_federation_resource = MockHttpResource()
|
||||
|
||||
mock_notifier = Mock()
|
||||
self.on_new_event = mock_notifier.on_new_event
|
||||
|
||||
self.auth = Mock(spec=[])
|
||||
self.state_handler = Mock()
|
||||
|
||||
hs = yield setup_test_homeserver(
|
||||
self.addCleanup,
|
||||
"test",
|
||||
auth=self.auth,
|
||||
clock=self.clock,
|
||||
datastore=Mock(
|
||||
hs = self.setup_test_homeserver(
|
||||
datastore=(Mock(
|
||||
spec=[
|
||||
# Bits that Federation needs
|
||||
"prep_send_transaction",
|
||||
@@ -82,16 +77,21 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
||||
"get_user_directory_stream_pos",
|
||||
"get_current_state_deltas",
|
||||
]
|
||||
),
|
||||
state_handler=self.state_handler,
|
||||
handlers=Mock(),
|
||||
notifier=mock_notifier,
|
||||
resource_for_client=Mock(),
|
||||
resource_for_federation=self.mock_federation_resource,
|
||||
http_client=self.mock_http_client,
|
||||
keyring=Mock(),
|
||||
)),
|
||||
notifier=Mock(),
|
||||
http_client=mock_federation_client,
|
||||
keyring=mock_keyring,
|
||||
)
|
||||
|
||||
return hs
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
# the tests assume that we are starting at unix time 1000
|
||||
reactor.pump((1000, ))
|
||||
|
||||
mock_notifier = hs.get_notifier()
|
||||
self.on_new_event = mock_notifier.on_new_event
|
||||
|
||||
self.handler = hs.get_typing_handler()
|
||||
|
||||
self.event_source = hs.get_event_sources().sources["typing"]
|
||||
@@ -109,13 +109,12 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
||||
|
||||
self.datastore.get_received_txn_response = get_received_txn_response
|
||||
|
||||
self.room_id = "a-room"
|
||||
|
||||
self.room_members = []
|
||||
|
||||
def check_joined_room(room_id, user_id):
|
||||
if user_id not in [u.to_string() for u in self.room_members]:
|
||||
raise AuthError(401, "User is not in the room")
|
||||
hs.get_auth().check_joined_room = check_joined_room
|
||||
|
||||
def get_joined_hosts_for_room(room_id):
|
||||
return set(member.domain for member in self.room_members)
|
||||
@@ -124,8 +123,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
||||
|
||||
def get_current_user_in_room(room_id):
|
||||
return set(str(u) for u in self.room_members)
|
||||
|
||||
self.state_handler.get_current_user_in_room = get_current_user_in_room
|
||||
hs.get_state_handler().get_current_user_in_room = get_current_user_in_room
|
||||
|
||||
self.datastore.get_user_directory_stream_pos.return_value = (
|
||||
# we deliberately return a non-None stream pos to avoid doing an initial_spam
|
||||
@@ -134,230 +132,208 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
||||
|
||||
self.datastore.get_current_state_deltas.return_value = None
|
||||
|
||||
self.auth.check_joined_room = check_joined_room
|
||||
|
||||
self.datastore.get_to_device_stream_token = lambda: 0
|
||||
self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: ([], 0)
|
||||
self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None
|
||||
|
||||
# Some local users to test with
|
||||
self.u_apple = UserID.from_string("@apple:test")
|
||||
self.u_banana = UserID.from_string("@banana:test")
|
||||
|
||||
# Remote user
|
||||
self.u_onion = UserID.from_string("@onion:farm")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_started_typing_local(self):
|
||||
self.room_members = [self.u_apple, self.u_banana]
|
||||
self.room_members = [U_APPLE, U_BANANA]
|
||||
|
||||
self.assertEquals(self.event_source.get_current_key(), 0)
|
||||
|
||||
yield self.handler.started_typing(
|
||||
target_user=self.u_apple,
|
||||
auth_user=self.u_apple,
|
||||
room_id=self.room_id,
|
||||
self.successResultOf(self.handler.started_typing(
|
||||
target_user=U_APPLE,
|
||||
auth_user=U_APPLE,
|
||||
room_id=ROOM_ID,
|
||||
timeout=20000,
|
||||
)
|
||||
))
|
||||
|
||||
self.on_new_event.assert_has_calls(
|
||||
[call('typing_key', 1, rooms=[self.room_id])]
|
||||
[call('typing_key', 1, rooms=[ROOM_ID])]
|
||||
)
|
||||
|
||||
self.assertEquals(self.event_source.get_current_key(), 1)
|
||||
events = yield self.event_source.get_new_events(
|
||||
room_ids=[self.room_id], from_key=0
|
||||
events = self.event_source.get_new_events(
|
||||
room_ids=[ROOM_ID], from_key=0
|
||||
)
|
||||
self.assertEquals(
|
||||
events[0],
|
||||
[
|
||||
{
|
||||
"type": "m.typing",
|
||||
"room_id": self.room_id,
|
||||
"content": {"user_ids": [self.u_apple.to_string()]},
|
||||
"room_id": ROOM_ID,
|
||||
"content": {"user_ids": [U_APPLE.to_string()]},
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_started_typing_remote_send(self):
|
||||
self.room_members = [self.u_apple, self.u_onion]
|
||||
self.room_members = [U_APPLE, U_ONION]
|
||||
|
||||
put_json = self.mock_http_client.put_json
|
||||
put_json.expect_call_and_return(
|
||||
call(
|
||||
"farm",
|
||||
path="/_matrix/federation/v1/send/1000000/",
|
||||
data=_expect_edu_transaction(
|
||||
"m.typing",
|
||||
content={
|
||||
"room_id": self.room_id,
|
||||
"user_id": self.u_apple.to_string(),
|
||||
"typing": True,
|
||||
},
|
||||
),
|
||||
json_data_callback=ANY,
|
||||
long_retries=True,
|
||||
backoff_on_404=True,
|
||||
),
|
||||
defer.succeed((200, "OK")),
|
||||
)
|
||||
|
||||
yield self.handler.started_typing(
|
||||
target_user=self.u_apple,
|
||||
auth_user=self.u_apple,
|
||||
room_id=self.room_id,
|
||||
self.successResultOf(self.handler.started_typing(
|
||||
target_user=U_APPLE,
|
||||
auth_user=U_APPLE,
|
||||
room_id=ROOM_ID,
|
||||
timeout=20000,
|
||||
))
|
||||
|
||||
put_json = self.hs.get_http_client().put_json
|
||||
put_json.assert_called_once_with(
|
||||
"farm",
|
||||
path="/_matrix/federation/v1/send/1000000/",
|
||||
data=_expect_edu_transaction(
|
||||
"m.typing",
|
||||
content={
|
||||
"room_id": ROOM_ID,
|
||||
"user_id": U_APPLE.to_string(),
|
||||
"typing": True,
|
||||
},
|
||||
),
|
||||
json_data_callback=ANY,
|
||||
long_retries=True,
|
||||
backoff_on_404=True,
|
||||
)
|
||||
|
||||
yield put_json.await_calls()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_started_typing_remote_recv(self):
|
||||
self.room_members = [self.u_apple, self.u_onion]
|
||||
self.room_members = [U_APPLE, U_ONION]
|
||||
|
||||
self.assertEquals(self.event_source.get_current_key(), 0)
|
||||
|
||||
(code, response) = yield self.mock_federation_resource.trigger(
|
||||
(request, channel) = self.make_request(
|
||||
"PUT",
|
||||
"/_matrix/federation/v1/send/1000000/",
|
||||
_make_edu_transaction_json(
|
||||
"m.typing",
|
||||
content={
|
||||
"room_id": self.room_id,
|
||||
"user_id": self.u_onion.to_string(),
|
||||
"room_id": ROOM_ID,
|
||||
"user_id": U_ONION.to_string(),
|
||||
"typing": True,
|
||||
},
|
||||
),
|
||||
federation_auth_origin=b'farm',
|
||||
)
|
||||
self.render(request)
|
||||
self.assertEqual(channel.code, 200)
|
||||
|
||||
self.on_new_event.assert_has_calls(
|
||||
[call('typing_key', 1, rooms=[self.room_id])]
|
||||
[call('typing_key', 1, rooms=[ROOM_ID])]
|
||||
)
|
||||
|
||||
self.assertEquals(self.event_source.get_current_key(), 1)
|
||||
events = yield self.event_source.get_new_events(
|
||||
room_ids=[self.room_id], from_key=0
|
||||
events = self.event_source.get_new_events(
|
||||
room_ids=[ROOM_ID], from_key=0
|
||||
)
|
||||
self.assertEquals(
|
||||
events[0],
|
||||
[
|
||||
{
|
||||
"type": "m.typing",
|
||||
"room_id": self.room_id,
|
||||
"content": {"user_ids": [self.u_onion.to_string()]},
|
||||
"room_id": ROOM_ID,
|
||||
"content": {"user_ids": [U_ONION.to_string()]},
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_stopped_typing(self):
|
||||
self.room_members = [self.u_apple, self.u_banana, self.u_onion]
|
||||
|
||||
put_json = self.mock_http_client.put_json
|
||||
put_json.expect_call_and_return(
|
||||
call(
|
||||
"farm",
|
||||
path="/_matrix/federation/v1/send/1000000/",
|
||||
data=_expect_edu_transaction(
|
||||
"m.typing",
|
||||
content={
|
||||
"room_id": self.room_id,
|
||||
"user_id": self.u_apple.to_string(),
|
||||
"typing": False,
|
||||
},
|
||||
),
|
||||
json_data_callback=ANY,
|
||||
long_retries=True,
|
||||
backoff_on_404=True,
|
||||
),
|
||||
defer.succeed((200, "OK")),
|
||||
)
|
||||
self.room_members = [U_APPLE, U_BANANA, U_ONION]
|
||||
|
||||
# Gut-wrenching
|
||||
from synapse.handlers.typing import RoomMember
|
||||
|
||||
member = RoomMember(self.room_id, self.u_apple.to_string())
|
||||
member = RoomMember(ROOM_ID, U_APPLE.to_string())
|
||||
self.handler._member_typing_until[member] = 1002000
|
||||
self.handler._room_typing[self.room_id] = set([self.u_apple.to_string()])
|
||||
self.handler._room_typing[ROOM_ID] = set([U_APPLE.to_string()])
|
||||
|
||||
self.assertEquals(self.event_source.get_current_key(), 0)
|
||||
|
||||
yield self.handler.stopped_typing(
|
||||
target_user=self.u_apple, auth_user=self.u_apple, room_id=self.room_id
|
||||
)
|
||||
self.successResultOf(self.handler.stopped_typing(
|
||||
target_user=U_APPLE, auth_user=U_APPLE, room_id=ROOM_ID
|
||||
))
|
||||
|
||||
self.on_new_event.assert_has_calls(
|
||||
[call('typing_key', 1, rooms=[self.room_id])]
|
||||
[call('typing_key', 1, rooms=[ROOM_ID])]
|
||||
)
|
||||
|
||||
yield put_json.await_calls()
|
||||
put_json = self.hs.get_http_client().put_json
|
||||
put_json.assert_called_once_with(
|
||||
"farm",
|
||||
path="/_matrix/federation/v1/send/1000000/",
|
||||
data=_expect_edu_transaction(
|
||||
"m.typing",
|
||||
content={
|
||||
"room_id": ROOM_ID,
|
||||
"user_id": U_APPLE.to_string(),
|
||||
"typing": False,
|
||||
},
|
||||
),
|
||||
json_data_callback=ANY,
|
||||
long_retries=True,
|
||||
backoff_on_404=True,
|
||||
)
|
||||
|
||||
self.assertEquals(self.event_source.get_current_key(), 1)
|
||||
events = yield self.event_source.get_new_events(
|
||||
room_ids=[self.room_id], from_key=0
|
||||
events = self.event_source.get_new_events(
|
||||
room_ids=[ROOM_ID], from_key=0
|
||||
)
|
||||
self.assertEquals(
|
||||
events[0],
|
||||
[
|
||||
{
|
||||
"type": "m.typing",
|
||||
"room_id": self.room_id,
|
||||
"room_id": ROOM_ID,
|
||||
"content": {"user_ids": []},
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_typing_timeout(self):
|
||||
self.room_members = [self.u_apple, self.u_banana]
|
||||
self.room_members = [U_APPLE, U_BANANA]
|
||||
|
||||
self.assertEquals(self.event_source.get_current_key(), 0)
|
||||
|
||||
yield self.handler.started_typing(
|
||||
target_user=self.u_apple,
|
||||
auth_user=self.u_apple,
|
||||
room_id=self.room_id,
|
||||
self.successResultOf(self.handler.started_typing(
|
||||
target_user=U_APPLE,
|
||||
auth_user=U_APPLE,
|
||||
room_id=ROOM_ID,
|
||||
timeout=10000,
|
||||
)
|
||||
))
|
||||
|
||||
self.on_new_event.assert_has_calls(
|
||||
[call('typing_key', 1, rooms=[self.room_id])]
|
||||
[call('typing_key', 1, rooms=[ROOM_ID])]
|
||||
)
|
||||
self.on_new_event.reset_mock()
|
||||
|
||||
self.assertEquals(self.event_source.get_current_key(), 1)
|
||||
events = yield self.event_source.get_new_events(
|
||||
room_ids=[self.room_id], from_key=0
|
||||
events = self.event_source.get_new_events(
|
||||
room_ids=[ROOM_ID], from_key=0
|
||||
)
|
||||
self.assertEquals(
|
||||
events[0],
|
||||
[
|
||||
{
|
||||
"type": "m.typing",
|
||||
"room_id": self.room_id,
|
||||
"content": {"user_ids": [self.u_apple.to_string()]},
|
||||
"room_id": ROOM_ID,
|
||||
"content": {"user_ids": [U_APPLE.to_string()]},
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
self.clock.advance_time(16)
|
||||
self.reactor.pump([16, ])
|
||||
|
||||
self.on_new_event.assert_has_calls(
|
||||
[call('typing_key', 2, rooms=[self.room_id])]
|
||||
[call('typing_key', 2, rooms=[ROOM_ID])]
|
||||
)
|
||||
|
||||
self.assertEquals(self.event_source.get_current_key(), 2)
|
||||
events = yield self.event_source.get_new_events(
|
||||
room_ids=[self.room_id], from_key=1
|
||||
events = self.event_source.get_new_events(
|
||||
room_ids=[ROOM_ID], from_key=1
|
||||
)
|
||||
self.assertEquals(
|
||||
events[0],
|
||||
[
|
||||
{
|
||||
"type": "m.typing",
|
||||
"room_id": self.room_id,
|
||||
"room_id": ROOM_ID,
|
||||
"content": {"user_ids": []},
|
||||
}
|
||||
],
|
||||
@@ -365,29 +341,29 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
||||
|
||||
# SYN-230 - see if we can still set after timeout
|
||||
|
||||
yield self.handler.started_typing(
|
||||
target_user=self.u_apple,
|
||||
auth_user=self.u_apple,
|
||||
room_id=self.room_id,
|
||||
self.successResultOf(self.handler.started_typing(
|
||||
target_user=U_APPLE,
|
||||
auth_user=U_APPLE,
|
||||
room_id=ROOM_ID,
|
||||
timeout=10000,
|
||||
)
|
||||
))
|
||||
|
||||
self.on_new_event.assert_has_calls(
|
||||
[call('typing_key', 3, rooms=[self.room_id])]
|
||||
[call('typing_key', 3, rooms=[ROOM_ID])]
|
||||
)
|
||||
self.on_new_event.reset_mock()
|
||||
|
||||
self.assertEquals(self.event_source.get_current_key(), 3)
|
||||
events = yield self.event_source.get_new_events(
|
||||
room_ids=[self.room_id], from_key=0
|
||||
events = self.event_source.get_new_events(
|
||||
room_ids=[ROOM_ID], from_key=0
|
||||
)
|
||||
self.assertEquals(
|
||||
events[0],
|
||||
[
|
||||
{
|
||||
"type": "m.typing",
|
||||
"room_id": self.room_id,
|
||||
"content": {"user_ids": [self.u_apple.to_string()]},
|
||||
"room_id": ROOM_ID,
|
||||
"content": {"user_ids": [U_APPLE.to_string()]},
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
45
tests/rest/media/v1/test_base.py
Normal file
45
tests/rest/media/v1/test_base.py
Normal file
@@ -0,0 +1,45 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2019 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.rest.media.v1._base import get_filename_from_headers
|
||||
|
||||
from tests import unittest
|
||||
|
||||
|
||||
class GetFileNameFromHeadersTests(unittest.TestCase):
|
||||
# input -> expected result
|
||||
TEST_CASES = {
|
||||
b"inline; filename=abc.txt": u"abc.txt",
|
||||
b'inline; filename="azerty"': u"azerty",
|
||||
b'inline; filename="aze%20rty"': u"aze%20rty",
|
||||
b'inline; filename="aze\"rty"': u'aze"rty',
|
||||
b'inline; filename="azer;ty"': u"azer;ty",
|
||||
|
||||
b"inline; filename*=utf-8''foo%C2%A3bar": u"foo£bar",
|
||||
}
|
||||
|
||||
def tests(self):
|
||||
for hdr, expected in self.TEST_CASES.items():
|
||||
res = get_filename_from_headers(
|
||||
{
|
||||
b'Content-Disposition': [hdr],
|
||||
},
|
||||
)
|
||||
self.assertEqual(
|
||||
res, expected,
|
||||
"expected output for %s to be %s but was %s" % (
|
||||
hdr, expected, res,
|
||||
)
|
||||
)
|
||||
@@ -137,6 +137,7 @@ def make_request(
|
||||
access_token=None,
|
||||
request=SynapseRequest,
|
||||
shorthand=True,
|
||||
federation_auth_origin=None,
|
||||
):
|
||||
"""
|
||||
Make a web request using the given method and path, feed it the
|
||||
@@ -150,9 +151,11 @@ def make_request(
|
||||
a dict.
|
||||
shorthand: Whether to try and be helpful and prefix the given URL
|
||||
with the usual REST API path, if it doesn't contain it.
|
||||
federation_auth_origin (bytes|None): if set to not-None, we will add a fake
|
||||
Authorization header pretenting to be the given server name.
|
||||
|
||||
Returns:
|
||||
A synapse.http.site.SynapseRequest.
|
||||
Tuple[synapse.http.site.SynapseRequest, channel]
|
||||
"""
|
||||
if not isinstance(method, bytes):
|
||||
method = method.encode('ascii')
|
||||
@@ -184,6 +187,11 @@ def make_request(
|
||||
b"Authorization", b"Bearer " + access_token.encode('ascii')
|
||||
)
|
||||
|
||||
if federation_auth_origin is not None:
|
||||
req.requestHeaders.addRawHeader(
|
||||
b"Authorization", b"X-Matrix origin=%s,key=,sig=" % (federation_auth_origin,)
|
||||
)
|
||||
|
||||
if content:
|
||||
req.requestHeaders.addRawHeader(b"Content-Type", b"application/json")
|
||||
|
||||
@@ -288,9 +296,6 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
|
||||
**kwargs
|
||||
)
|
||||
|
||||
pool.runWithConnection = runWithConnection
|
||||
pool.runInteraction = runInteraction
|
||||
|
||||
class ThreadPool:
|
||||
"""
|
||||
Threadless thread pool.
|
||||
@@ -316,8 +321,12 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
|
||||
return d
|
||||
|
||||
clock.threadpool = ThreadPool()
|
||||
pool.threadpool = ThreadPool()
|
||||
pool.running = True
|
||||
|
||||
if pool:
|
||||
pool.runWithConnection = runWithConnection
|
||||
pool.runInteraction = runInteraction
|
||||
pool.threadpool = ThreadPool()
|
||||
pool.running = True
|
||||
return d
|
||||
|
||||
|
||||
|
||||
@@ -262,6 +262,7 @@ class HomeserverTestCase(TestCase):
|
||||
access_token=None,
|
||||
request=SynapseRequest,
|
||||
shorthand=True,
|
||||
federation_auth_origin=None,
|
||||
):
|
||||
"""
|
||||
Create a SynapseRequest at the path using the method and containing the
|
||||
@@ -275,15 +276,18 @@ class HomeserverTestCase(TestCase):
|
||||
a dict.
|
||||
shorthand: Whether to try and be helpful and prefix the given URL
|
||||
with the usual REST API path, if it doesn't contain it.
|
||||
federation_auth_origin (bytes|None): if set to not-None, we will add a fake
|
||||
Authorization header pretenting to be the given server name.
|
||||
|
||||
Returns:
|
||||
A synapse.http.site.SynapseRequest.
|
||||
Tuple[synapse.http.site.SynapseRequest, channel]
|
||||
"""
|
||||
if isinstance(content, dict):
|
||||
content = json.dumps(content).encode('utf8')
|
||||
|
||||
return make_request(
|
||||
self.reactor, method, path, content, access_token, request, shorthand
|
||||
self.reactor, method, path, content, access_token, request, shorthand,
|
||||
federation_auth_origin,
|
||||
)
|
||||
|
||||
def render(self, request):
|
||||
|
||||
@@ -29,7 +29,7 @@ from twisted.internet import defer, reactor
|
||||
from synapse.api.constants import EventTypes, RoomVersions
|
||||
from synapse.api.errors import CodeMessageException, cs_error
|
||||
from synapse.config.server import ServerConfig
|
||||
from synapse.federation.transport import server
|
||||
from synapse.federation.transport import server as federation_server
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.server import HomeServer
|
||||
from synapse.storage import DataStore
|
||||
@@ -45,7 +45,9 @@ from synapse.util.ratelimitutils import FederationRateLimiter
|
||||
# set this to True to run the tests against postgres instead of sqlite.
|
||||
USE_POSTGRES_FOR_TESTS = os.environ.get("SYNAPSE_POSTGRES", False)
|
||||
LEAVE_DB = os.environ.get("SYNAPSE_LEAVE_DB", False)
|
||||
POSTGRES_USER = os.environ.get("SYNAPSE_POSTGRES_USER", "postgres")
|
||||
POSTGRES_USER = os.environ.get("SYNAPSE_POSTGRES_USER", None)
|
||||
POSTGRES_HOST = os.environ.get("SYNAPSE_POSTGRES_HOST", None)
|
||||
POSTGRES_PASSWORD = os.environ.get("SYNAPSE_POSTGRES_PASSWORD", None)
|
||||
POSTGRES_BASE_DB = "_synapse_unit_tests_base_%s" % (os.getpid(),)
|
||||
|
||||
|
||||
@@ -58,6 +60,8 @@ def setupdb():
|
||||
"args": {
|
||||
"database": POSTGRES_BASE_DB,
|
||||
"user": POSTGRES_USER,
|
||||
"host": POSTGRES_HOST,
|
||||
"password": POSTGRES_PASSWORD,
|
||||
"cp_min": 1,
|
||||
"cp_max": 5,
|
||||
},
|
||||
@@ -66,7 +70,9 @@ def setupdb():
|
||||
config.password_providers = []
|
||||
config.database_config = pgconfig
|
||||
db_engine = create_engine(pgconfig)
|
||||
db_conn = db_engine.module.connect(user=POSTGRES_USER)
|
||||
db_conn = db_engine.module.connect(
|
||||
user=POSTGRES_USER, host=POSTGRES_HOST, password=POSTGRES_PASSWORD
|
||||
)
|
||||
db_conn.autocommit = True
|
||||
cur = db_conn.cursor()
|
||||
cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,))
|
||||
@@ -76,7 +82,10 @@ def setupdb():
|
||||
|
||||
# Set up in the db
|
||||
db_conn = db_engine.module.connect(
|
||||
database=POSTGRES_BASE_DB, user=POSTGRES_USER
|
||||
database=POSTGRES_BASE_DB,
|
||||
user=POSTGRES_USER,
|
||||
host=POSTGRES_HOST,
|
||||
password=POSTGRES_PASSWORD,
|
||||
)
|
||||
cur = db_conn.cursor()
|
||||
_get_or_create_schema_state(cur, db_engine)
|
||||
@@ -86,7 +95,9 @@ def setupdb():
|
||||
db_conn.close()
|
||||
|
||||
def _cleanup():
|
||||
db_conn = db_engine.module.connect(user=POSTGRES_USER)
|
||||
db_conn = db_engine.module.connect(
|
||||
user=POSTGRES_USER, host=POSTGRES_HOST, password=POSTGRES_PASSWORD
|
||||
)
|
||||
db_conn.autocommit = True
|
||||
cur = db_conn.cursor()
|
||||
cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,))
|
||||
@@ -142,6 +153,9 @@ def default_config(name):
|
||||
config.saml2_enabled = False
|
||||
config.public_baseurl = None
|
||||
config.default_identity_server = None
|
||||
config.key_refresh_interval = 24 * 60 * 60 * 1000
|
||||
config.old_signing_keys = {}
|
||||
config.tls_fingerprints = []
|
||||
|
||||
config.use_frozen_dicts = False
|
||||
|
||||
@@ -186,6 +200,9 @@ def setup_test_homeserver(
|
||||
Args:
|
||||
cleanup_func : The function used to register a cleanup routine for
|
||||
after the test.
|
||||
|
||||
Calling this method directly is deprecated: you should instead derive from
|
||||
HomeserverTestCase.
|
||||
"""
|
||||
if reactor is None:
|
||||
from twisted.internet import reactor
|
||||
@@ -203,7 +220,14 @@ def setup_test_homeserver(
|
||||
|
||||
config.database_config = {
|
||||
"name": "psycopg2",
|
||||
"args": {"database": test_db, "cp_min": 1, "cp_max": 5},
|
||||
"args": {
|
||||
"database": test_db,
|
||||
"host": POSTGRES_HOST,
|
||||
"password": POSTGRES_PASSWORD,
|
||||
"user": POSTGRES_USER,
|
||||
"cp_min": 1,
|
||||
"cp_max": 5,
|
||||
},
|
||||
}
|
||||
else:
|
||||
config.database_config = {
|
||||
@@ -217,7 +241,10 @@ def setup_test_homeserver(
|
||||
# the template database we generate in setupdb()
|
||||
if datastore is None and isinstance(db_engine, PostgresEngine):
|
||||
db_conn = db_engine.module.connect(
|
||||
database=POSTGRES_BASE_DB, user=POSTGRES_USER
|
||||
database=POSTGRES_BASE_DB,
|
||||
user=POSTGRES_USER,
|
||||
host=POSTGRES_HOST,
|
||||
password=POSTGRES_PASSWORD,
|
||||
)
|
||||
db_conn.autocommit = True
|
||||
cur = db_conn.cursor()
|
||||
@@ -267,7 +294,10 @@ def setup_test_homeserver(
|
||||
|
||||
# Drop the test database
|
||||
db_conn = db_engine.module.connect(
|
||||
database=POSTGRES_BASE_DB, user=POSTGRES_USER
|
||||
database=POSTGRES_BASE_DB,
|
||||
user=POSTGRES_USER,
|
||||
host=POSTGRES_HOST,
|
||||
password=POSTGRES_PASSWORD,
|
||||
)
|
||||
db_conn.autocommit = True
|
||||
cur = db_conn.cursor()
|
||||
@@ -324,23 +354,27 @@ def setup_test_homeserver(
|
||||
|
||||
fed = kargs.get("resource_for_federation", None)
|
||||
if fed:
|
||||
server.register_servlets(
|
||||
hs,
|
||||
resource=fed,
|
||||
authenticator=server.Authenticator(hs),
|
||||
ratelimiter=FederationRateLimiter(
|
||||
hs.get_clock(),
|
||||
window_size=hs.config.federation_rc_window_size,
|
||||
sleep_limit=hs.config.federation_rc_sleep_limit,
|
||||
sleep_msec=hs.config.federation_rc_sleep_delay,
|
||||
reject_limit=hs.config.federation_rc_reject_limit,
|
||||
concurrent_requests=hs.config.federation_rc_concurrent,
|
||||
),
|
||||
)
|
||||
register_federation_servlets(hs, fed)
|
||||
|
||||
defer.returnValue(hs)
|
||||
|
||||
|
||||
def register_federation_servlets(hs, resource):
|
||||
federation_server.register_servlets(
|
||||
hs,
|
||||
resource=resource,
|
||||
authenticator=federation_server.Authenticator(hs),
|
||||
ratelimiter=FederationRateLimiter(
|
||||
hs.get_clock(),
|
||||
window_size=hs.config.federation_rc_window_size,
|
||||
sleep_limit=hs.config.federation_rc_sleep_limit,
|
||||
sleep_msec=hs.config.federation_rc_sleep_delay,
|
||||
reject_limit=hs.config.federation_rc_reject_limit,
|
||||
concurrent_requests=hs.config.federation_rc_concurrent,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def get_mock_call_args(pattern_func, mock_func):
|
||||
""" Return the arguments the mock function was called with interpreted
|
||||
by the pattern functions argument list.
|
||||
@@ -457,6 +491,9 @@ class MockKey(object):
|
||||
def verify(self, message, sig):
|
||||
assert sig == b"\x9a\x87$"
|
||||
|
||||
def encode(self):
|
||||
return b"<fake_encoded_key>"
|
||||
|
||||
|
||||
class MockClock(object):
|
||||
now = 1000
|
||||
@@ -486,7 +523,7 @@ class MockClock(object):
|
||||
return t
|
||||
|
||||
def looping_call(self, function, interval):
|
||||
self.loopers.append([function, interval / 1000., self.now])
|
||||
self.loopers.append([function, interval / 1000.0, self.now])
|
||||
|
||||
def cancel_call_later(self, timer, ignore_errs=False):
|
||||
if timer[2]:
|
||||
@@ -522,7 +559,7 @@ class MockClock(object):
|
||||
looped[2] = self.now
|
||||
|
||||
def advance_time_msec(self, ms):
|
||||
self.advance_time(ms / 1000.)
|
||||
self.advance_time(ms / 1000.0)
|
||||
|
||||
def time_bound_deferred(self, d, *args, **kwargs):
|
||||
# We don't bother timing things out for now.
|
||||
@@ -631,7 +668,7 @@ def create_room(hs, room_id, creator_id):
|
||||
"sender": creator_id,
|
||||
"room_id": room_id,
|
||||
"content": {},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
event, context = yield event_creation_handler.create_new_client_event(builder)
|
||||
|
||||
Reference in New Issue
Block a user