
Compare commits


75 Commits
hhs-1 ... hhs-6

Author SHA1 Message Date
Neil Johnson
c5440b2ca0 Merge pull request #3800 from matrix-org/neilj/remove-guests-from-mau-count
guest users should not be part of mau total
2018-09-06 17:45:55 +01:00
Neil Johnson
84a750e0c3 ensure guests never enter mau list 2018-09-06 17:22:53 +01:00
Amber Brown
654324eded Merge pull request #3805 from matrix-org/erikj/limit_transaction_pdus_edus
Limit the number of PDUs/EDUs per federation transaction
2018-09-07 01:33:31 +10:00
Amber Brown
1d371fc5b3 Merge pull request #3806 from matrix-org/erikj/limit_postgres_travis
Only start postgres instance for postgres tests on Travis CI
2018-09-07 01:06:21 +10:00
Erik Johnston
b07a2cbee9 Spelling 2018-09-06 15:53:02 +01:00
Amber Brown
70fd75cd1d Merge pull request #3788 from matrix-org/erikj/remove_conn_id
Remove conn_id from repl prometheus metrics
2018-09-07 00:48:19 +10:00
Erik Johnston
5d848992bf Newsfile 2018-09-06 15:28:46 +01:00
Erik Johnston
3b4223aa23 Only start postgres instance for postgres tests on Travis CI 2018-09-06 15:27:54 +01:00
Erik Johnston
28f5bfdcf7 Newsfile 2018-09-06 15:25:58 +01:00
Amber Brown
ee7c8bd2b5 Merge pull request #3795 from matrix-org/erikj/faster_sync_state
Use iter* during sync state calculations
2018-09-07 00:24:43 +10:00
Erik Johnston
6707a3212c Limit the number of PDUs/EDUs per federation transaction 2018-09-06 15:23:55 +01:00
Amber Brown
135f3b4390 Merge pull request #3804 from matrix-org/rav/fix_openssl_dep
bump dep on pyopenssl to 16.x
2018-09-07 00:23:39 +10:00
Amber Brown
2608ebc04c Port handlers/ to Python 3 (#3803) 2018-09-07 00:22:23 +10:00
Richard van der Hoff
4f8baab0c4 Merge branch 'master' into develop 2018-09-06 13:05:22 +01:00
Richard van der Hoff
b3c2ebba32 changelog 2018-09-06 12:55:17 +01:00
Richard van der Hoff
625542878d bump dep on pyopenssl to 16.x 2018-09-06 12:53:15 +01:00
Richard van der Hoff
2fd17b5ad1 Merge tag 'v0.33.3.1'
Synapse 0.33.3.1 (2018-09-06)
=============================

SECURITY FIXES
--------------

- Fix an issue where event signatures were not always correctly validated ([\#3796](https://github.com/matrix-org/synapse/issues/3796))
- Fix an issue where server_acls could be circumvented for incoming events ([\#3796](https://github.com/matrix-org/synapse/issues/3796))

Internal Changes
----------------

- Unignore synctl in .dockerignore to fix docker builds ([\#3802](https://github.com/matrix-org/synapse/issues/3802))
2018-09-06 12:31:35 +01:00
Richard van der Hoff
80189ed27c prepare v0.33.3.1 2018-09-06 10:26:23 +01:00
Jan Christian Grünhage
0cd7b209e2 Create 3802.misc 2018-09-06 10:24:59 +01:00
Jan Christian Grünhage
78d1042c10 remove synctl from .dockerignore 2018-09-06 10:24:59 +01:00
Neil Johnson
92657be7d0 towncrier 2018-09-05 22:33:45 +01:00
Neil Johnson
61b05727fa guest users should not be part of mau total 2018-09-05 22:30:36 +01:00
Richard van der Hoff
dfba1d843d Merge pull request #3790 from matrix-org/rav/respect_event_format_in_filter
Implement 'event_format' filter param in /sync
2018-09-05 16:24:14 +01:00
Erik Johnston
2254790ae4 Newsfile 2018-09-05 16:21:18 +01:00
Erik Johnston
7419764351 Use iter* during sync state calculations 2018-09-05 16:19:50 +01:00
Amber Brown
2d2828dcbc Port http/ to Python 3 (#3771) 2018-09-06 00:10:47 +10:00
Richard van der Hoff
c127c8d042 Fix origin handling for pushed transactions
Use the actual origin for push transactions, rather than whatever the remote
server claimed.
2018-09-05 13:08:07 +01:00
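
Below is a minimal sketch of the idea behind this fix (hypothetical names, not the real Synapse signatures): take the origin from the authenticated request, never from the transaction body.

def handle_transaction(authenticated_origin, transaction_data):
    # The signature check already authenticated `authenticated_origin`;
    # the "origin" field inside the body is remote-controlled and untrusted.
    claimed = transaction_data.get("origin")
    if claimed != authenticated_origin:
        pass  # ignore the claim; all origin-based logic uses the real value
    return authenticated_origin

handle_transaction("good.example.com", {"origin": "evil.example.com"})  # -> "good.example.com"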
Richard van der Hoff
804dd41e18 Check that signatures on events are valid
We should check that both the sender's server, and the server which created the
event_id (which may be different from whatever the remote server has told us
the origin is), have signed the event.
2018-09-05 13:08:07 +01:00
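
A minimal sketch of which domains are required to sign a (room v1) event, assuming Matrix identifiers of the form "$id:domain" and "@user:domain"; the full implementation is in the federation_base.py diff further down.

def signing_domains(event):
    def domain(identifier):
        return identifier.split(":", 1)[1]  # "$abc:hs.example" -> "hs.example"
    # both the server that minted the event_id and the sender's server
    return {domain(event["event_id"]), domain(event["sender"])}

signing_domains({"event_id": "$abc:origin.example", "sender": "@u:other.example"})
# -> {"origin.example", "other.example"}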
Richard van der Hoff
c91bd295f5 changelog 2018-09-04 15:23:19 +01:00
Richard van der Hoff
87c18d12ee Implement 'event_format' filter param in /sync
This has been specced and part-implemented; let's implement it for /sync (but
no other endpoints yet :/).
2018-09-04 15:20:09 +01:00
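
A minimal sketch of the new filter param in use (illustrative values, matching the filter JSON this PR reads in synapse/api/filtering.py): "event_format" selects between the default client-format events and raw federation-format PDUs.

filter_json = {
    "event_format": "federation",  # or "client", the default
    "event_fields": ["type", "content", "sender"],
}
# sent as the (URL-encoded) `filter` query parameter to /sync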
Neil Johnson
a6cf7d9d9a Merge pull request #3789 from matrix-org/neilj/improve_threepid_error_strings
improve human readable error messages
2018-09-04 13:16:00 +00:00
Neil Johnson
70fc599ede towncrier 2018-09-04 12:08:27 +01:00
Neil Johnson
bae37cd811 improve human readable error message 2018-09-04 12:07:00 +01:00
Neil Johnson
c42f7fd7b9 improve human readable error messages 2018-09-04 12:03:17 +01:00
Erik Johnston
3e242dc149 Remove conn_id 2018-09-04 11:45:52 +01:00
Erik Johnston
87b111f96a Newsfile 2018-09-03 17:26:15 +01:00
Erik Johnston
b13836da7f Remove conn_id from repl prometheus metrics
`conn_id` gets set to a random string, and so we end up filling up
prometheus with tonnes of data series, which is bad.
2018-09-03 17:22:49 +01:00
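
A minimal sketch (not the Synapse metrics code) of why a random label value is a problem: every distinct label value mints a new Prometheus time series.

import uuid
from prometheus_client import Counter

commands = Counter("repl_commands", "Replication commands", ["command", "conn_id"])
for _ in range(3):
    # each connection gets a fresh random conn_id, so each connection
    # creates brand-new series that accumulate without bound
    commands.labels(command="RDATA", conn_id=str(uuid.uuid4())).inc()
# dropping the conn_id label bounds the series count by command name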
Amber Brown
77055dba92 Fix tests on postgresql (#3740) 2018-09-04 02:21:48 +10:00
Erik Johnston
567363e497 Merge pull request #3737 from matrix-org/erikj/remove_redundant_state_func
Remove unnecessary resolve_events_with_state_map
2018-09-03 16:19:41 +01:00
Erik Johnston
06ee2b7cc5 Newsfile 2018-09-03 15:43:01 +01:00
Amber Brown
30cf40ff30 Merge pull request #3378 from NickEckardt/develop
matrix-synapse-auto-deploy is no longer maintained.
2018-09-03 22:47:23 +10:00
Amber Brown
905f8de673 Create 3378.misc 2018-09-03 22:17:06 +10:00
Amber Brown
4fc4b881c5 Merge branch 'develop' into develop 2018-09-03 21:08:35 +10:00
Neil Johnson
99178f8602 Merge pull request #3777 from matrix-org/neilj/fix_register_user_registration
fix bug where preserved threepid user comes to sign up and server is …
2018-08-31 16:31:48 +00:00
Neil Johnson
301cb60d0b assert rather than warn 2018-08-31 17:29:35 +01:00
Neil Johnson
0b01281e77 move threepid checker to config, add missing yields 2018-08-31 17:11:11 +01:00
Neil Johnson
e8e540630e fix reference to is_threepid_reserved 2018-08-31 16:09:15 +01:00
Neil Johnson
a796bdd35e typo 2018-08-31 15:57:33 +01:00
Neil Johnson
09f3cf1a7e ensure post registration auth checks do not fail erroneously 2018-08-31 15:42:51 +01:00
Neil Johnson
3d6aa06577 news fragment 2018-08-31 10:56:37 +01:00
Neil Johnson
ea068d6f3c fix bug where preserved threepid user comes to sign up and server is mau blocked 2018-08-31 10:49:14 +01:00
Amber Brown
14e4d4f4bf Port storage/ to Python 3 (#3725) 2018-08-31 00:19:58 +10:00
Richard van der Hoff
475253a88e Merge pull request #3764 from matrix-org/rav/close_db_conn_after_init
Make sure that we close db connections opened during init
2018-08-30 10:47:27 +01:00
Erik Johnston
7f0399586d Merge pull request #3768 from krombel/fix_3445
fix #3445 - do not use itervalues() on SortedDict()
2018-08-29 16:29:57 +01:00
Krombel
8c0c51ecb3 changelog 2018-08-29 16:49:26 +02:00
Krombel
79a8a347a6 fix #3445
itervalues(d) calls d.itervalues() [PY2] and d.values() [PY3]
but SortedDict only implements d.values()
2018-08-29 16:28:25 +02:00
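
A minimal sketch of the incompatibility, assuming six and sortedcontainers are installed: on Python 2, six.itervalues(d) calls d.itervalues(), which SortedDict does not define.

from six import itervalues
from sortedcontainers import SortedDict

d = SortedDict({1: "a", 2: "b"})
list(d.values())     # works on both Python 2 and 3
list(itervalues(d))  # AttributeError on Python 2: SortedDict has no itervalues()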
Amber Brown
82276a18d1 Update CONTRIBUTING to clarify miscs & Markdown (#3730) 2018-08-29 19:06:59 +10:00
Matthew Hodgson
b1580f50fe don't return non-LL-member state in incremental sync state blocks (#3760)
don't return non-LL-member state in incremental sync state blocks
2018-08-28 23:25:58 +01:00
Richard van der Hoff
414fa36f3e Fix up tests 2018-08-28 17:21:05 +01:00
Richard van der Hoff
32eb1dedd2 use abc.abstractproperty
This gives clearer messages when someone gets it wrong
2018-08-28 17:10:43 +01:00
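
A minimal sketch (hypothetical class names) of the pattern: declaring the attribute as an abc.abstractproperty makes a subclass that forgets to define it fail at instantiation time with a clear message, rather than with a confusing AttributeError later on.

import abc

import six


@six.add_metaclass(abc.ABCMeta)
class ServerBase(object):
    @abc.abstractproperty
    def DATASTORE_CLASS(self):
        raise NotImplementedError


class BrokenServer(ServerBase):
    pass  # forgot to define DATASTORE_CLASS


try:
    BrokenServer()
except TypeError as e:
    print(e)  # Can't instantiate abstract class BrokenServer with abstract methods DATASTORE_CLASS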
Richard van der Hoff
71990b3cae changelog 2018-08-28 13:42:20 +01:00
Richard van der Hoff
0b07f02e19 Make sure that we close db connections opened during init
We should explicitly close any db connections we open, because failing to do so
can block other transactions as per
https://github.com/matrix-org/synapse/issues/3682.

Let's also try to factor out some of the boilerplate by having server classes
define their datastore class rather than duplicating the whole of `setup`.
2018-08-28 13:39:49 +01:00
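
A minimal sketch of the explicit-close pattern, using sqlite3 as a stand-in engine (Synapse's own connection wrapper differs): close the connection as soon as the startup checks are done, so it cannot hold locks afterwards.

import sqlite3
from contextlib import closing


def run_startup_checks(conn):
    conn.execute("SELECT 1")  # placeholder for the real checks


with closing(sqlite3.connect(":memory:")) as db_conn:
    run_startup_checks(db_conn)
    db_conn.commit()
# db_conn is closed here and can no longer block other transactions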
Erik Johnston
9fbaed325f Merge pull request #3758 from matrix-org/erikj/admin_contact
Change admin_uri to admin_contact in config and errors
2018-08-24 17:15:32 +01:00
Erik Johnston
9db2476991 Merge branch 'develop' of github.com:matrix-org/synapse into erikj/admin_contact 2018-08-24 17:00:37 +01:00
Erik Johnston
69f7f1418f Newsfile 2018-08-24 16:54:06 +01:00
Erik Johnston
48b04c4a4c Merge pull request #3756 from matrix-org/erikj/fix_tags_server_notices
Fix checking if service notice room is already tagged
2018-08-24 16:53:11 +01:00
Erik Johnston
08abe8e13c Newsfile 2018-08-24 16:52:52 +01:00
Erik Johnston
05077e06fa Change admin_uri to admin_contact in config and errors 2018-08-24 16:51:27 +01:00
Erik Johnston
01a5a8b9e3 Fix checking if service notice room is already tagged
This manifested in synapse repeatedly setting the tag for the room
2018-08-24 16:22:37 +01:00
Erik Johnston
897d976c1e Merge pull request #3755 from matrix-org/erikj/fix_server_notice_tags
Fix up tagging server notice rooms.
2018-08-24 15:04:41 +01:00
Erik Johnston
45fc0ead10 Newsfile 2018-08-24 15:04:29 +01:00
Erik Johnston
cdd24449ee Ensure we wake up /sync when we add tag to notice room 2018-08-24 14:50:03 +01:00
Erik Johnston
14d49c51db Make content of tag an empty object rather than null 2018-08-24 14:44:16 +01:00
Erik Johnston
5c261107c9 Remove unnecessary resolve_events_with_state_map
We only ever used the synchronous resolve_events_with_state_map in one
place, which is trivial to replace with the async version.
2018-08-22 15:41:15 +01:00
Nicholas Eckardt
76c80e3fdf The project matrix-synapse-auto-deploy does not seem to be maintained anymore.
It has been over a year since any code has been committed. Some of the relevant links in
the documentation are broken, but since no pull requests are being accepted, they
won't get fixed. We should probably remove this from the README.
2018-06-10 18:24:12 -07:00
109 changed files with 1209 additions and 821 deletions


@@ -3,6 +3,5 @@ Dockerfile
.gitignore
demo/etc
tox.ini
synctl
.git/*
.tox/*

.gitignore

@@ -44,6 +44,7 @@ media_store/
build/
venv/
venv*/
*venv/
localhost-800*/
static/client/register/register_config.js


@@ -8,9 +8,6 @@ before_script:
- git remote set-branches --add origin develop
- git fetch origin develop
services:
- postgresql
matrix:
fast_finish: true
include:
@@ -25,6 +22,8 @@ matrix:
- python: 2.7
env: TOX_ENV=py27-postgres TRIAL_FLAGS="-j 4"
services:
- postgresql
- python: 3.6
env: TOX_ENV=py36
@@ -35,10 +34,6 @@ matrix:
- python: 3.6
env: TOX_ENV=check-newsfragment
allow_failures:
- python: 2.7
env: TOX_ENV=py27-postgres TRIAL_FLAGS="-j 4"
install:
- pip install tox


@@ -1,3 +1,17 @@
Synapse 0.33.3.1 (2018-09-06)
=============================
SECURITY FIXES
--------------
- Fix an issue where event signatures were not always correctly validated ([\#3796](https://github.com/matrix-org/synapse/issues/3796))
- Fix an issue where server_acls could be circumvented for incoming events ([\#3796](https://github.com/matrix-org/synapse/issues/3796))
Internal Changes
----------------
- Unignore synctl in .dockerignore to fix docker builds ([\#3802](https://github.com/matrix-org/synapse/issues/3802))
Synapse 0.33.3 (2018-08-22)
===========================


@@ -59,9 +59,10 @@ To create a changelog entry, make a new file in the ``changelog.d``
file named in the format of ``PRnumber.type``. The type can be
one of ``feature``, ``bugfix``, ``removal`` (also used for
deprecations), or ``misc`` (for internal-only changes). The content of
the file is your changelog entry, which can contain RestructuredText
formatting. A note of contributors is welcomed in changelogs for
non-misc changes (the content of misc changes is not displayed).
the file is your changelog entry, which can contain Markdown
formatting. Adding credits to the changelog is encouraged, we value
your contributions and would like to have you shouted out in the
release notes!
For example, a fix in PR #1234 would have its changelog entry in
``changelog.d/1234.bugfix``, and contain content like "The security levels of


@@ -167,11 +167,6 @@ Alternatively, Andreas Peters (previously Silvio Fricke) has contributed a
Dockerfile to automate a synapse server in a single Docker image, at
https://hub.docker.com/r/avhost/docker-matrix/tags/
Also, Martin Giess has created an auto-deployment process with vagrant/ansible,
tested with VirtualBox/AWS/DigitalOcean - see
https://github.com/EMnify/matrix-synapse-auto-deploy
for details.
Configuring synapse
-------------------

changelog.d/3378.misc

@@ -0,0 +1 @@
Removed the link to the unmaintained matrix-synapse-auto-deploy project from the readme.

changelog.d/3725.misc

@@ -0,0 +1 @@
The synapse.storage module has been ported to Python 3.

changelog.d/3730.misc

@@ -0,0 +1 @@
The CONTRIBUTING guidelines have been updated to mention our use of Markdown and that .misc files have content.

changelog.d/3737.misc

@@ -0,0 +1 @@
Remove redundant state resolution function

changelog.d/3740.misc

@@ -0,0 +1 @@
The test suite now passes on PostgreSQL.

changelog.d/3755.bugfix

@@ -0,0 +1 @@
Fix tagging of server notice rooms

changelog.d/3756.bugfix

@@ -0,0 +1 @@
Fix tagging of server notice rooms

changelog.d/3758.bugfix

@@ -0,0 +1 @@
Fix 'admin_uri' config variable and error parameter to be 'admin_contact' to match the spec.

changelog.d/3760.bugfix

@@ -0,0 +1 @@
Don't return non-LL-member state in incremental sync state blocks

changelog.d/3764.misc

@@ -0,0 +1 @@
Make sure that we close db connections opened during init

changelog.d/3768.bugfix

@@ -0,0 +1 @@
Fix bug in sending presence over federation

changelog.d/3771.misc

@@ -0,0 +1 @@
http/ is now ported to Python 3.

changelog.d/3777.bugfix

@@ -0,0 +1 @@
Fix bug where preserved threepid user comes to sign up and server is mau blocked

changelog.d/3788.bugfix

@@ -0,0 +1 @@
Remove connection ID for replication prometheus metrics, as it creates a large number of new series.

changelog.d/3789.misc

@@ -0,0 +1 @@
Improve human readable error messages for threepid registration/account update

changelog.d/3790.feature

@@ -0,0 +1 @@
Implement `event_format` filter param in `/sync`

changelog.d/3795.misc

@@ -0,0 +1 @@
Make /sync slightly faster by avoiding needless copies

changelog.d/3800.bugfix

@@ -0,0 +1 @@
guest users should not be part of mau total

changelog.d/3803.misc

@@ -0,0 +1 @@
handlers/ is now ported to Python 3.

changelog.d/3804.bugfix

@@ -0,0 +1 @@
Bump dependency on pyopenssl 16.x, to avoid incompatibility with recent Twisted.

changelog.d/3805.misc

@@ -0,0 +1 @@
Limit the number of PDUs/EDUs per federation transaction

changelog.d/3806.misc

@@ -0,0 +1 @@
Only start postgres instance for postgres tests on Travis CI


@@ -31,5 +31,5 @@ $TOX_BIN/pip install 'setuptools>=18.5'
$TOX_BIN/pip install 'pip>=10'
{ python synapse/python_dependencies.py
echo lxml psycopg2
echo lxml
} | xargs $TOX_BIN/pip install


@@ -17,13 +17,14 @@ ignore =
[pep8]
max-line-length = 90
# W503 requires that binary operators be at the end, not start, of lines. Erik
# doesn't like it. E203 is contrary to PEP8.
ignore = W503,E203
# doesn't like it. E203 is contrary to PEP8. E731 is silly.
ignore = W503,E203,E731
[flake8]
# note that flake8 inherits the "ignore" settings from "pep8" (because it uses
# pep8 to do those checks), but not the "max-line-length" setting
max-line-length = 90
ignore=W503,E203,E731
[isort]
line_length = 89


@@ -17,4 +17,4 @@
""" This is a reference implementation of a Matrix home server.
"""
__version__ = "0.33.3"
__version__ = "0.33.3.1"


@@ -26,6 +26,7 @@ import synapse.types
from synapse import event_auth
from synapse.api.constants import EventTypes, JoinRules, Membership
from synapse.api.errors import AuthError, Codes, ResourceLimitError
from synapse.config.server import is_threepid_reserved
from synapse.types import UserID
from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache
from synapse.util.caches.lrucache import LruCache
@@ -775,13 +776,19 @@ class Auth(object):
)
@defer.inlineCallbacks
def check_auth_blocking(self, user_id=None):
def check_auth_blocking(self, user_id=None, threepid=None):
"""Checks if the user should be rejected for some external reason,
such as monthly active user limiting or global disable flag
Args:
user_id(str|None): If present, checks for presence against existing
MAU cohort
threepid(dict|None): If present, checks for presence against configured
reserved threepid. Used in cases where the user is trying register
with a MAU blocked server, normally they would be rejected but their
threepid is on the reserved list. user_id and
threepid should never be set at the same time.
"""
# Never fail an auth check for the server notices users
@@ -793,10 +800,12 @@ class Auth(object):
raise ResourceLimitError(
403, self.hs.config.hs_disabled_message,
errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
admin_uri=self.hs.config.admin_uri,
admin_contact=self.hs.config.admin_contact,
limit_type=self.hs.config.hs_disabled_limit_type
)
if self.hs.config.limit_usage_by_mau is True:
assert not (user_id and threepid)
# If the user is already part of the MAU cohort or a trial user
if user_id:
timestamp = yield self.store.user_last_seen_monthly_active(user_id)
@@ -806,13 +815,17 @@ class Auth(object):
is_trial = yield self.store.is_trial_user(user_id)
if is_trial:
return
elif threepid:
# If the user does not exist yet, but is signing up with a
# reserved threepid then pass auth check
if is_threepid_reserved(self.hs.config, threepid):
return
# Else if there is no room in the MAU bucket, bail
current_mau = yield self.store.get_monthly_active_count()
if current_mau >= self.hs.config.max_mau_value:
raise ResourceLimitError(
403, "Monthly Active User Limit Exceeded",
admin_uri=self.hs.config.admin_uri,
admin_contact=self.hs.config.admin_contact,
errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
limit_type="monthly_active_user"
)


@@ -239,10 +239,10 @@ class ResourceLimitError(SynapseError):
def __init__(
self, code, msg,
errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
admin_uri=None,
admin_contact=None,
limit_type=None,
):
self.admin_uri = admin_uri
self.admin_contact = admin_contact
self.limit_type = limit_type
super(ResourceLimitError, self).__init__(code, msg, errcode=errcode)
@@ -250,7 +250,7 @@ class ResourceLimitError(SynapseError):
return cs_error(
self.msg,
self.errcode,
admin_uri=self.admin_uri,
admin_contact=self.admin_contact,
limit_type=self.limit_type
)


@@ -251,6 +251,7 @@ class FilterCollection(object):
"include_leave", False
)
self.event_fields = filter_json.get("event_fields", [])
self.event_format = filter_json.get("event_format", "client")
def __repr__(self):
return "<FilterCollection %s>" % (json.dumps(self._filter_json),)


@@ -51,10 +51,7 @@ class AppserviceSlaveStore(
class AppserviceServer(HomeServer):
def setup(self):
logger.info("Setting up.")
self.datastore = AppserviceSlaveStore(self.get_db_conn(), self)
logger.info("Finished setting up.")
DATASTORE_CLASS = AppserviceSlaveStore
def _listen_http(self, listener_config):
port = listener_config["port"]


@@ -74,10 +74,7 @@ class ClientReaderSlavedStore(
class ClientReaderServer(HomeServer):
def setup(self):
logger.info("Setting up.")
self.datastore = ClientReaderSlavedStore(self.get_db_conn(), self)
logger.info("Finished setting up.")
DATASTORE_CLASS = ClientReaderSlavedStore
def _listen_http(self, listener_config):
port = listener_config["port"]


@@ -90,10 +90,7 @@ class EventCreatorSlavedStore(
class EventCreatorServer(HomeServer):
def setup(self):
logger.info("Setting up.")
self.datastore = EventCreatorSlavedStore(self.get_db_conn(), self)
logger.info("Finished setting up.")
DATASTORE_CLASS = EventCreatorSlavedStore
def _listen_http(self, listener_config):
port = listener_config["port"]


@@ -72,10 +72,7 @@ class FederationReaderSlavedStore(
class FederationReaderServer(HomeServer):
def setup(self):
logger.info("Setting up.")
self.datastore = FederationReaderSlavedStore(self.get_db_conn(), self)
logger.info("Finished setting up.")
DATASTORE_CLASS = FederationReaderSlavedStore
def _listen_http(self, listener_config):
port = listener_config["port"]


@@ -78,10 +78,7 @@ class FederationSenderSlaveStore(
class FederationSenderServer(HomeServer):
def setup(self):
logger.info("Setting up.")
self.datastore = FederationSenderSlaveStore(self.get_db_conn(), self)
logger.info("Finished setting up.")
DATASTORE_CLASS = FederationSenderSlaveStore
def _listen_http(self, listener_config):
port = listener_config["port"]


@@ -148,10 +148,7 @@ class FrontendProxySlavedStore(
class FrontendProxyServer(HomeServer):
def setup(self):
logger.info("Setting up.")
self.datastore = FrontendProxySlavedStore(self.get_db_conn(), self)
logger.info("Finished setting up.")
DATASTORE_CLASS = FrontendProxySlavedStore
def _listen_http(self, listener_config):
port = listener_config["port"]


@@ -62,7 +62,7 @@ from synapse.rest.key.v1.server_key_resource import LocalKey
from synapse.rest.key.v2 import KeyApiV2Resource
from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.server import HomeServer
from synapse.storage import are_all_users_on_domain
from synapse.storage import DataStore, are_all_users_on_domain
from synapse.storage.engines import IncorrectDatabaseSetup, create_engine
from synapse.storage.prepare_database import UpgradeDatabaseException, prepare_database
from synapse.util.caches import CACHE_SIZE_FACTOR
@@ -111,6 +111,8 @@ def build_resource_for_web_client(hs):
class SynapseHomeServer(HomeServer):
DATASTORE_CLASS = DataStore
def _listener_http(self, config, listener_config):
port = listener_config["port"]
bind_addresses = listener_config["bind_addresses"]
@@ -356,13 +358,13 @@ def setup(config_options):
logger.info("Preparing database: %s...", config.database_config['name'])
try:
db_conn = hs.get_db_conn(run_new_connection=False)
prepare_database(db_conn, database_engine, config=config)
database_engine.on_new_connection(db_conn)
with hs.get_db_conn(run_new_connection=False) as db_conn:
prepare_database(db_conn, database_engine, config=config)
database_engine.on_new_connection(db_conn)
hs.run_startup_checks(db_conn, database_engine)
hs.run_startup_checks(db_conn, database_engine)
db_conn.commit()
db_conn.commit()
except UpgradeDatabaseException:
sys.stderr.write(
"\nFailed to upgrade database.\n"


@@ -60,10 +60,7 @@ class MediaRepositorySlavedStore(
class MediaRepositoryServer(HomeServer):
def setup(self):
logger.info("Setting up.")
self.datastore = MediaRepositorySlavedStore(self.get_db_conn(), self)
logger.info("Finished setting up.")
DATASTORE_CLASS = MediaRepositorySlavedStore
def _listen_http(self, listener_config):
port = listener_config["port"]


@@ -78,10 +78,7 @@ class PusherSlaveStore(
class PusherServer(HomeServer):
def setup(self):
logger.info("Setting up.")
self.datastore = PusherSlaveStore(self.get_db_conn(), self)
logger.info("Finished setting up.")
DATASTORE_CLASS = PusherSlaveStore
def remove_pusher(self, app_id, push_key, user_id):
self.get_tcp_replication().send_remove_pusher(app_id, push_key, user_id)


@@ -249,10 +249,7 @@ class SynchrotronApplicationService(object):
class SynchrotronServer(HomeServer):
def setup(self):
logger.info("Setting up.")
self.datastore = SynchrotronSlavedStore(self.get_db_conn(), self)
logger.info("Finished setting up.")
DATASTORE_CLASS = SynchrotronSlavedStore
def _listen_http(self, listener_config):
port = listener_config["port"]


@@ -94,10 +94,7 @@ class UserDirectorySlaveStore(
class UserDirectoryServer(HomeServer):
def setup(self):
logger.info("Setting up.")
self.datastore = UserDirectorySlaveStore(self.get_db_conn(), self)
logger.info("Finished setting up.")
DATASTORE_CLASS = UserDirectorySlaveStore
def _listen_http(self, listener_config):
port = listener_config["port"]


@@ -13,7 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import urllib
from six.moves import urllib
from prometheus_client import Counter
@@ -98,7 +99,7 @@ class ApplicationServiceApi(SimpleHttpClient):
def query_user(self, service, user_id):
if service.url is None:
defer.returnValue(False)
uri = service.url + ("/users/%s" % urllib.quote(user_id))
uri = service.url + ("/users/%s" % urllib.parse.quote(user_id))
response = None
try:
response = yield self.get_json(uri, {
@@ -119,7 +120,7 @@ class ApplicationServiceApi(SimpleHttpClient):
def query_alias(self, service, alias):
if service.url is None:
defer.returnValue(False)
uri = service.url + ("/rooms/%s" % urllib.quote(alias))
uri = service.url + ("/rooms/%s" % urllib.parse.quote(alias))
response = None
try:
response = yield self.get_json(uri, {
@@ -153,7 +154,7 @@ class ApplicationServiceApi(SimpleHttpClient):
service.url,
APP_SERVICE_PREFIX,
kind,
urllib.quote(protocol)
urllib.parse.quote(protocol)
)
try:
response = yield self.get_json(uri, fields)
@@ -188,7 +189,7 @@ class ApplicationServiceApi(SimpleHttpClient):
uri = "%s%s/thirdparty/protocol/%s" % (
service.url,
APP_SERVICE_PREFIX,
urllib.quote(protocol)
urllib.parse.quote(protocol)
)
try:
info = yield self.get_json(uri, {})
@@ -228,7 +229,7 @@ class ApplicationServiceApi(SimpleHttpClient):
txn_id = str(txn_id)
uri = service.url + ("/transactions/%s" %
urllib.quote(txn_id))
urllib.parse.quote(txn_id))
try:
yield self.put_json(
uri=uri,


@@ -93,7 +93,7 @@ class ServerConfig(Config):
# Admin uri to direct users at should their instance become blocked
# due to resource constraints
self.admin_uri = config.get("admin_uri", None)
self.admin_contact = config.get("admin_contact", None)
# FIXME: federation_domain_whitelist needs sytests
self.federation_domain_whitelist = None
@@ -357,7 +357,7 @@ class ServerConfig(Config):
# Homeserver blocking
#
# How to reach the server admin, used in ResourceLimitError
# admin_uri: 'mailto:admin@server.com'
# admin_contact: 'mailto:admin@server.com'
#
# Global block config
#
@@ -404,6 +404,23 @@ class ServerConfig(Config):
" service on the given port.")
def is_threepid_reserved(config, threepid):
"""Check the threepid against the reserved threepid config
Args:
config(ServerConfig) - to access server config attributes
threepid(dict) - The threepid to test for
Returns:
boolean Is the threepid undertest reserved_user
"""
for tp in config.mau_limits_reserved_threepids:
if (threepid['medium'] == tp['medium']
and threepid['address'] == tp['address']):
return True
return False
def read_gc_thresholds(thresholds):
"""Reads the three integer thresholds for garbage collection. Ensures that
the thresholds are integers if thresholds are supplied.
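
A hypothetical usage sketch of is_threepid_reserved, assuming `config` is a loaded ServerConfig whose YAML contains something like "mau_limits_reserved_threepids: [{medium: email, address: admin@example.com}]":

threepid = {"medium": "email", "address": "admin@example.com"}
if is_threepid_reserved(config, threepid):
    pass  # reserved threepid: let registration proceed despite the MAU block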


@@ -13,17 +13,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from collections import namedtuple
import six
from twisted.internet import defer
from twisted.internet.defer import DeferredList
from synapse.api.constants import MAX_DEPTH
from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
from synapse.api.errors import Codes, SynapseError
from synapse.crypto.event_signing import check_event_content_hash
from synapse.events import FrozenEvent
from synapse.events.utils import prune_event
from synapse.http.servlet import assert_params_in_dict
from synapse.types import get_domain_from_id
from synapse.util import logcontext, unwrapFirstError
logger = logging.getLogger(__name__)
@@ -133,34 +136,25 @@ class FederationBase(object):
* throws a SynapseError if the signature check failed.
The deferreds run their callbacks in the sentinel logcontext.
"""
redacted_pdus = [
prune_event(pdu)
for pdu in pdus
]
deferreds = self.keyring.verify_json_objects_for_server([
(p.origin, p.get_pdu_json())
for p in redacted_pdus
])
deferreds = _check_sigs_on_pdus(self.keyring, pdus)
ctx = logcontext.LoggingContext.current_context()
def callback(_, pdu, redacted):
def callback(_, pdu):
with logcontext.PreserveLoggingContext(ctx):
if not check_event_content_hash(pdu):
logger.warn(
"Event content has been tampered, redacting %s: %s",
pdu.event_id, pdu.get_pdu_json()
)
return redacted
return prune_event(pdu)
if self.spam_checker.check_event_for_spam(pdu):
logger.warn(
"Event contains spam, redacting %s: %s",
pdu.event_id, pdu.get_pdu_json()
)
return redacted
return prune_event(pdu)
return pdu
@@ -173,16 +167,116 @@ class FederationBase(object):
)
return failure
for deferred, pdu, redacted in zip(deferreds, pdus, redacted_pdus):
for deferred, pdu in zip(deferreds, pdus):
deferred.addCallbacks(
callback, errback,
callbackArgs=[pdu, redacted],
callbackArgs=[pdu],
errbackArgs=[pdu],
)
return deferreds
class PduToCheckSig(namedtuple("PduToCheckSig", [
"pdu", "redacted_pdu_json", "event_id_domain", "sender_domain", "deferreds",
])):
pass
def _check_sigs_on_pdus(keyring, pdus):
"""Check that the given events are correctly signed
Args:
keyring (synapse.crypto.Keyring): keyring object to do the checks
pdus (Collection[EventBase]): the events to be checked
Returns:
List[Deferred]: a Deferred for each event in pdus, which will either succeed if
the signatures are valid, or fail (with a SynapseError) if not.
"""
# (currently this is written assuming the v1 room structure; we'll probably want a
# separate function for checking v2 rooms)
# we want to check that the event is signed by:
#
# (a) the server which created the event_id
#
# (b) the sender's server.
#
# - except in the case of invites created from a 3pid invite, which are exempt
# from this check, because the sender has to match that of the original 3pid
# invite, but the event may come from a different HS, for reasons that I don't
# entirely grok (why do the senders have to match? and if they do, why doesn't the
# joining server ask the inviting server to do the switcheroo with
# exchange_third_party_invite?).
#
# That's pretty awful, since redacting such an invite will render it invalid
# (because it will then look like a regular invite without a valid signature),
# and signatures are *supposed* to be valid whether or not an event has been
# redacted. But this isn't the worst of the ways that 3pid invites are broken.
#
# let's start by getting the domain for each pdu, and flattening the event back
# to JSON.
pdus_to_check = [
PduToCheckSig(
pdu=p,
redacted_pdu_json=prune_event(p).get_pdu_json(),
event_id_domain=get_domain_from_id(p.event_id),
sender_domain=get_domain_from_id(p.sender),
deferreds=[],
)
for p in pdus
]
# first make sure that the event is signed by the event_id's domain
deferreds = keyring.verify_json_objects_for_server([
(p.event_id_domain, p.redacted_pdu_json)
for p in pdus_to_check
])
for p, d in zip(pdus_to_check, deferreds):
p.deferreds.append(d)
# now let's look for events where the sender's domain is different to the
# event id's domain (normally only the case for joins/leaves), and add additional
# checks.
pdus_to_check_sender = [
p for p in pdus_to_check
if p.sender_domain != p.event_id_domain and not _is_invite_via_3pid(p.pdu)
]
more_deferreds = keyring.verify_json_objects_for_server([
(p.sender_domain, p.redacted_pdu_json)
for p in pdus_to_check_sender
])
for p, d in zip(pdus_to_check_sender, more_deferreds):
p.deferreds.append(d)
# replace lists of deferreds with single Deferreds
return [_flatten_deferred_list(p.deferreds) for p in pdus_to_check]
def _flatten_deferred_list(deferreds):
"""Given a list of one or more deferreds, either return the single deferred, or
combine into a DeferredList.
"""
if len(deferreds) > 1:
return DeferredList(deferreds, fireOnOneErrback=True, consumeErrors=True)
else:
assert len(deferreds) == 1
return deferreds[0]
def _is_invite_via_3pid(event):
return (
event.type == EventTypes.Member
and event.membership == Membership.INVITE
and "third_party_invite" in event.content
)
def event_from_pdu_json(pdu_json, outlier=False):
"""Construct a FrozenEvent from an event json received over federation


@@ -99,7 +99,7 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks
@log_function
def on_incoming_transaction(self, transaction_data):
def on_incoming_transaction(self, origin, transaction_data):
# keep this as early as possible to make the calculated origin ts as
# accurate as possible.
request_time = self._clock.time_msec()
@@ -108,34 +108,33 @@ class FederationServer(FederationBase):
if not transaction.transaction_id:
raise Exception("Transaction missing transaction_id")
if not transaction.origin:
raise Exception("Transaction missing origin")
logger.debug("[%s] Got transaction", transaction.transaction_id)
# use a linearizer to ensure that we don't process the same transaction
# multiple times in parallel.
with (yield self._transaction_linearizer.queue(
(transaction.origin, transaction.transaction_id),
(origin, transaction.transaction_id),
)):
result = yield self._handle_incoming_transaction(
transaction, request_time,
origin, transaction, request_time,
)
defer.returnValue(result)
@defer.inlineCallbacks
def _handle_incoming_transaction(self, transaction, request_time):
def _handle_incoming_transaction(self, origin, transaction, request_time):
""" Process an incoming transaction and return the HTTP response
Args:
origin (unicode): the server making the request
transaction (Transaction): incoming transaction
request_time (int): timestamp that the HTTP request arrived at
Returns:
Deferred[(int, object)]: http response code and body
"""
response = yield self.transaction_actions.have_responded(transaction)
response = yield self.transaction_actions.have_responded(origin, transaction)
if response:
logger.debug(
@@ -149,7 +148,7 @@ class FederationServer(FederationBase):
received_pdus_counter.inc(len(transaction.pdus))
origin_host, _ = parse_server_name(transaction.origin)
origin_host, _ = parse_server_name(origin)
pdus_by_room = {}
@@ -190,7 +189,7 @@ class FederationServer(FederationBase):
event_id = pdu.event_id
try:
yield self._handle_received_pdu(
transaction.origin, pdu
origin, pdu
)
pdu_results[event_id] = {}
except FederationError as e:
@@ -212,7 +211,7 @@ class FederationServer(FederationBase):
if hasattr(transaction, "edus"):
for edu in (Edu(**x) for x in transaction.edus):
yield self.received_edu(
transaction.origin,
origin,
edu.edu_type,
edu.content
)
@@ -224,6 +223,7 @@ class FederationServer(FederationBase):
logger.debug("Returning: %s", str(response))
yield self.transaction_actions.set_response(
origin,
transaction,
200, response
)
@@ -838,9 +838,9 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry):
)
return self._send_edu(
edu_type=edu_type,
origin=origin,
content=content,
edu_type=edu_type,
origin=origin,
content=content,
)
def on_query(self, query_type, args):
@@ -851,6 +851,6 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry):
return handler(args)
return self._get_query_client(
query_type=query_type,
args=args,
query_type=query_type,
args=args,
)


@@ -36,7 +36,7 @@ class TransactionActions(object):
self.store = datastore
@log_function
def have_responded(self, transaction):
def have_responded(self, origin, transaction):
""" Have we already responded to a transaction with the same id and
origin?
@@ -50,11 +50,11 @@ class TransactionActions(object):
"transaction_id")
return self.store.get_received_txn_response(
transaction.transaction_id, transaction.origin
transaction.transaction_id, origin
)
@log_function
def set_response(self, transaction, code, response):
def set_response(self, origin, transaction, code, response):
""" Persist how we responded to a transaction.
Returns:
@@ -66,7 +66,7 @@ class TransactionActions(object):
return self.store.set_received_txn_response(
transaction.transaction_id,
transaction.origin,
origin,
code,
response,
)


@@ -32,7 +32,7 @@ Events are replicated via a separate events stream.
import logging
from collections import namedtuple
from six import iteritems, itervalues
from six import iteritems
from sortedcontainers import SortedDict
@@ -117,7 +117,7 @@ class FederationRemoteSendQueue(object):
user_ids = set(
user_id
for uids in itervalues(self.presence_changed)
for uids in self.presence_changed.values()
for user_id in uids
)


@@ -463,7 +463,19 @@ class TransactionQueue(object):
# pending_transactions flag.
pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
# We can only include at most 50 PDUs per transactions
pending_pdus, leftover_pdus = pending_pdus[:50], pending_pdus[50:]
if leftover_pdus:
self.pending_pdus_by_dest[destination] = leftover_pdus
pending_edus = self.pending_edus_by_dest.pop(destination, [])
# We can only include at most 100 EDUs per transactions
pending_edus, leftover_edus = pending_edus[:100], pending_edus[100:]
if leftover_edus:
self.pending_edus_by_dest[destination] = leftover_edus
pending_presence = self.pending_presence_by_dest.pop(destination, {})
pending_edus.extend(


@@ -353,7 +353,7 @@ class FederationSendServlet(BaseFederationServlet):
try:
code, response = yield self.handler.on_incoming_transaction(
transaction_data
origin, transaction_data,
)
except Exception:
logger.exception("on_incoming_transaction failed")


@@ -895,22 +895,24 @@ class AuthHandler(BaseHandler):
Args:
password (unicode): Password to hash.
stored_hash (unicode): Expected hash value.
stored_hash (bytes): Expected hash value.
Returns:
Deferred(bool): Whether self.hash(password) == stored_hash.
"""
def _do_validate_hash():
# Normalise the Unicode in the password
pw = unicodedata.normalize("NFKC", password)
return bcrypt.checkpw(
pw.encode('utf8') + self.hs.config.password_pepper.encode("utf8"),
stored_hash.encode('utf8')
stored_hash
)
if stored_hash:
if not isinstance(stored_hash, bytes):
stored_hash = stored_hash.encode('ascii')
return make_deferred_yieldable(
threads.deferToThreadPool(
self.hs.get_reactor(),


@@ -330,7 +330,8 @@ class E2eKeysHandler(object):
(algorithm, key_id, ex_json, key)
)
else:
new_keys.append((algorithm, key_id, encode_canonical_json(key)))
new_keys.append((
algorithm, key_id, encode_canonical_json(key).decode('ascii')))
yield self.store.add_e2e_one_time_keys(
user_id, device_id, time_now, new_keys
@@ -358,7 +359,7 @@ def _exception_to_failure(e):
# Note that some Exceptions (notably twisted's ResponseFailed etc) don't
# give a string for e.message, which json then fails to serialize.
return {
"status": 503, "message": str(e.message),
"status": 503, "message": str(e),
}


@@ -594,7 +594,7 @@ class FederationHandler(BaseHandler):
required_auth = set(
a_id
for event in events + state_events.values() + auth_events.values()
for event in events + list(state_events.values()) + list(auth_events.values())
for a_id, _ in event.auth_events
)
auth_events.update({
@@ -802,7 +802,7 @@ class FederationHandler(BaseHandler):
)
continue
except NotRetryingDestination as e:
logger.info(e.message)
logger.info(str(e))
continue
except FederationDeniedError as e:
logger.info(e)
@@ -1358,7 +1358,7 @@ class FederationHandler(BaseHandler):
)
if state_groups:
_, state = state_groups.items().pop()
_, state = list(state_groups.items()).pop()
results = state
if event.is_state():
@@ -1831,7 +1831,7 @@ class FederationHandler(BaseHandler):
room_version = yield self.store.get_room_version(event.room_id)
new_state = self.state_handler.resolve_events(
new_state = yield self.state_handler.resolve_events(
room_version,
[list(local_view.values()), list(remote_view.values())],
event


@@ -125,6 +125,7 @@ class RegistrationHandler(BaseHandler):
guest_access_token=None,
make_guest=False,
admin=False,
threepid=None,
):
"""Registers a new client on the server.
@@ -145,7 +146,7 @@ class RegistrationHandler(BaseHandler):
RegistrationError if there was a problem registering.
"""
yield self.auth.check_auth_blocking()
yield self.auth.check_auth_blocking(threepid=threepid)
password_hash = None
if password:
password_hash = yield self.auth_handler().hash(password)


@@ -162,7 +162,7 @@ class RoomListHandler(BaseHandler):
# Filter out rooms that we don't want to return
rooms_to_scan = [
r for r in sorted_rooms
if r not in newly_unpublished and rooms_to_num_joined[room_id] > 0
if r not in newly_unpublished and rooms_to_num_joined[r] > 0
]
total_room_count = len(rooms_to_scan)


@@ -54,7 +54,7 @@ class SearchHandler(BaseHandler):
batch_token = None
if batch:
try:
b = decode_base64(batch)
b = decode_base64(batch).decode('ascii')
batch_group, batch_group_key, batch_token = b.split("\n")
assert batch_group is not None
@@ -258,18 +258,18 @@ class SearchHandler(BaseHandler):
# it returns more from the same group (if applicable) rather
# than reverting to searching all results again.
if batch_group and batch_group_key:
global_next_batch = encode_base64("%s\n%s\n%s" % (
global_next_batch = encode_base64(("%s\n%s\n%s" % (
batch_group, batch_group_key, pagination_token
))
)).encode('ascii'))
else:
global_next_batch = encode_base64("%s\n%s\n%s" % (
global_next_batch = encode_base64(("%s\n%s\n%s" % (
"all", "", pagination_token
))
)).encode('ascii'))
for room_id, group in room_groups.items():
group["next_batch"] = encode_base64("%s\n%s\n%s" % (
group["next_batch"] = encode_base64(("%s\n%s\n%s" % (
"room_id", room_id, pagination_token
))
)).encode('ascii'))
allowed_events.extend(room_events)


@@ -545,7 +545,7 @@ class SyncHandler(object):
member_ids = {
state_key: event_id
for (t, state_key), event_id in state_ids.iteritems()
for (t, state_key), event_id in iteritems(state_ids)
if t == EventTypes.Member
}
name_id = state_ids.get((EventTypes.Name, ''))
@@ -745,9 +745,16 @@ class SyncHandler(object):
state_ids = {}
if lazy_load_members:
if types:
# We're returning an incremental sync, with no "gap" since
# the previous sync, so normally there would be no state to return
# But we're lazy-loading, so the client might need some more
# member events to understand the events in this timeline.
# So we fish out all the member events corresponding to the
# timeline here, and then dedupe any redundant ones below.
state_ids = yield self.store.get_state_ids_for_event(
batch.events[0].event_id, types=types,
filtered_types=filtered_types,
filtered_types=None, # we only want members!
)
if lazy_load_members and not include_redundant_members:
@@ -767,7 +774,7 @@ class SyncHandler(object):
logger.debug("filtering state from %r...", state_ids)
state_ids = {
t: event_id
for t, event_id in state_ids.iteritems()
for t, event_id in iteritems(state_ids)
if cache.get(t[1]) != event_id
}
logger.debug("...to %r", state_ids)
@@ -1722,17 +1729,17 @@ def _calculate_state(
event_id_to_key = {
e: key
for key, e in itertools.chain(
timeline_contains.items(),
previous.items(),
timeline_start.items(),
current.items(),
iteritems(timeline_contains),
iteritems(previous),
iteritems(timeline_start),
iteritems(current),
)
}
c_ids = set(e for e in current.values())
ts_ids = set(e for e in timeline_start.values())
p_ids = set(e for e in previous.values())
tc_ids = set(e for e in timeline_contains.values())
c_ids = set(e for e in itervalues(current))
ts_ids = set(e for e in itervalues(timeline_start))
p_ids = set(e for e in itervalues(previous))
tc_ids = set(e for e in itervalues(timeline_contains))
# If we are lazyloading room members, we explicitly add the membership events
# for the senders in the timeline into the state block returned by /sync,
@@ -1746,7 +1753,7 @@ def _calculate_state(
if lazy_load_members:
p_ids.difference_update(
e for t, e in timeline_start.iteritems()
e for t, e in iteritems(timeline_start)
if t[0] == EventTypes.Member
)


@@ -13,24 +13,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import urllib
from six import StringIO
from six import text_type
from six.moves import urllib
import treq
from canonicaljson import encode_canonical_json, json
from prometheus_client import Counter
from OpenSSL import SSL
from OpenSSL.SSL import VERIFY_NONE
from twisted.internet import defer, protocol, reactor, ssl, task
from twisted.internet import defer, protocol, reactor, ssl
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.web._newclient import ResponseDone
from twisted.web.client import (
Agent,
BrowserLikeRedirectAgent,
ContentDecoderAgent,
FileBodyProducer as TwistedFileBodyProducer,
GzipDecoder,
HTTPConnectionPool,
PartialDownloadError,
@@ -83,18 +84,20 @@ class SimpleHttpClient(object):
if hs.config.user_agent_suffix:
self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix,)
self.user_agent = self.user_agent.encode('ascii')
@defer.inlineCallbacks
def request(self, method, uri, *args, **kwargs):
def request(self, method, uri, data=b'', headers=None):
# A small wrapper around self.agent.request() so we can easily attach
# counters to it
outgoing_requests_counter.labels(method).inc()
# log request but strip `access_token` (AS requests for example include this)
logger.info("Sending request %s %s", method, redact_uri(uri))
logger.info("Sending request %s %s", method, redact_uri(uri.encode('ascii')))
try:
request_deferred = self.agent.request(
method, uri, *args, **kwargs
request_deferred = treq.request(
method, uri, agent=self.agent, data=data, headers=headers
)
add_timeout_to_deferred(
request_deferred, 60, self.hs.get_reactor(),
@@ -105,14 +108,14 @@ class SimpleHttpClient(object):
incoming_responses_counter.labels(method, response.code).inc()
logger.info(
"Received response to %s %s: %s",
method, redact_uri(uri), response.code
method, redact_uri(uri.encode('ascii')), response.code
)
defer.returnValue(response)
except Exception as e:
incoming_responses_counter.labels(method, "ERR").inc()
logger.info(
"Error sending request to %s %s: %s %s",
method, redact_uri(uri), type(e).__name__, e.message
method, redact_uri(uri.encode('ascii')), type(e).__name__, e.args[0]
)
raise
@@ -137,7 +140,8 @@ class SimpleHttpClient(object):
# TODO: Do we ever want to log message contents?
logger.debug("post_urlencoded_get_json args: %s", args)
query_bytes = urllib.urlencode(encode_urlencode_args(args), True)
query_bytes = urllib.parse.urlencode(
encode_urlencode_args(args), True).encode("utf8")
actual_headers = {
b"Content-Type": [b"application/x-www-form-urlencoded"],
@@ -148,15 +152,14 @@ class SimpleHttpClient(object):
response = yield self.request(
"POST",
uri.encode("ascii"),
uri,
headers=Headers(actual_headers),
bodyProducer=FileBodyProducer(StringIO(query_bytes))
data=query_bytes
)
body = yield make_deferred_yieldable(readBody(response))
if 200 <= response.code < 300:
defer.returnValue(json.loads(body))
body = yield make_deferred_yieldable(treq.json_content(response))
defer.returnValue(body)
else:
raise HttpResponseException(response.code, response.phrase, body)
@@ -191,9 +194,9 @@ class SimpleHttpClient(object):
response = yield self.request(
"POST",
uri.encode("ascii"),
uri,
headers=Headers(actual_headers),
bodyProducer=FileBodyProducer(StringIO(json_str))
data=json_str
)
body = yield make_deferred_yieldable(readBody(response))
@@ -248,7 +251,7 @@ class SimpleHttpClient(object):
ValueError: if the response was not JSON
"""
if len(args):
query_bytes = urllib.urlencode(args, True)
query_bytes = urllib.parse.urlencode(args, True)
uri = "%s?%s" % (uri, query_bytes)
json_str = encode_canonical_json(json_body)
@@ -262,9 +265,9 @@ class SimpleHttpClient(object):
response = yield self.request(
"PUT",
uri.encode("ascii"),
uri,
headers=Headers(actual_headers),
bodyProducer=FileBodyProducer(StringIO(json_str))
data=json_str
)
body = yield make_deferred_yieldable(readBody(response))
@@ -293,7 +296,7 @@ class SimpleHttpClient(object):
HttpResponseException on a non-2xx HTTP response.
"""
if len(args):
query_bytes = urllib.urlencode(args, True)
query_bytes = urllib.parse.urlencode(args, True)
uri = "%s?%s" % (uri, query_bytes)
actual_headers = {
@@ -304,7 +307,7 @@ class SimpleHttpClient(object):
response = yield self.request(
"GET",
uri.encode("ascii"),
uri,
headers=Headers(actual_headers),
)
@@ -339,7 +342,7 @@ class SimpleHttpClient(object):
response = yield self.request(
"GET",
url.encode("ascii"),
url,
headers=Headers(actual_headers),
)
@@ -434,12 +437,12 @@ class CaptchaServerHttpClient(SimpleHttpClient):
@defer.inlineCallbacks
def post_urlencoded_get_raw(self, url, args={}):
query_bytes = urllib.urlencode(encode_urlencode_args(args), True)
query_bytes = urllib.parse.urlencode(encode_urlencode_args(args), True)
response = yield self.request(
"POST",
url.encode("ascii"),
bodyProducer=FileBodyProducer(StringIO(query_bytes)),
url,
data=query_bytes,
headers=Headers({
b"Content-Type": [b"application/x-www-form-urlencoded"],
b"User-Agent": [self.user_agent],
@@ -510,7 +513,7 @@ def encode_urlencode_args(args):
def encode_urlencode_arg(arg):
if isinstance(arg, unicode):
if isinstance(arg, text_type):
return arg.encode('utf-8')
elif isinstance(arg, list):
return [encode_urlencode_arg(i) for i in arg]
@@ -542,26 +545,3 @@ class InsecureInterceptableContextFactory(ssl.ContextFactory):
def creatorForNetloc(self, hostname, port):
return self
class FileBodyProducer(TwistedFileBodyProducer):
"""Workaround for https://twistedmatrix.com/trac/ticket/8473
We override the pauseProducing and resumeProducing methods in twisted's
FileBodyProducer so that they do not raise exceptions if the task has
already completed.
"""
def pauseProducing(self):
try:
super(FileBodyProducer, self).pauseProducing()
except task.TaskDone:
# task has already completed
pass
def resumeProducing(self):
try:
super(FileBodyProducer, self).resumeProducing()
except task.NotPaused:
# task was not paused (probably because it had already completed)
pass


@@ -17,19 +17,19 @@ import cgi
import logging
import random
import sys
import urllib
from six import string_types
from six.moves.urllib import parse as urlparse
from six import PY3, string_types
from six.moves import urllib
from canonicaljson import encode_canonical_json, json
import treq
from canonicaljson import encode_canonical_json
from prometheus_client import Counter
from signedjson.sign import sign_json
from twisted.internet import defer, protocol, reactor
from twisted.internet.error import DNSLookupError
from twisted.web._newclient import ResponseDone
from twisted.web.client import Agent, HTTPConnectionPool, readBody
from twisted.web.client import Agent, HTTPConnectionPool
from twisted.web.http_headers import Headers
import synapse.metrics
@@ -58,13 +58,18 @@ incoming_responses_counter = Counter("synapse_http_matrixfederationclient_respon
MAX_LONG_RETRIES = 10
MAX_SHORT_RETRIES = 3
if PY3:
MAXINT = sys.maxsize
else:
MAXINT = sys.maxint
class MatrixFederationEndpointFactory(object):
def __init__(self, hs):
self.tls_client_options_factory = hs.tls_client_options_factory
def endpointForURI(self, uri):
destination = uri.netloc
destination = uri.netloc.decode('ascii')
return matrix_federation_endpoint(
reactor, destination, timeout=10,
@@ -93,26 +98,32 @@ class MatrixFederationHttpClient(object):
)
self.clock = hs.get_clock()
self._store = hs.get_datastore()
self.version_string = hs.version_string
self.version_string = hs.version_string.encode('ascii')
self._next_id = 1
def _create_url(self, destination, path_bytes, param_bytes, query_bytes):
return urlparse.urlunparse(
("matrix", destination, path_bytes, param_bytes, query_bytes, "")
return urllib.parse.urlunparse(
(b"matrix", destination, path_bytes, param_bytes, query_bytes, b"")
)
@defer.inlineCallbacks
def _request(self, destination, method, path,
body_callback, headers_dict={}, param_bytes=b"",
query_bytes=b"", retry_on_dns_fail=True,
json=None, json_callback=None,
param_bytes=b"",
query=None, retry_on_dns_fail=True,
timeout=None, long_retries=False,
ignore_backoff=False,
backoff_on_404=False):
""" Creates and sends a request to the given server
"""
Creates and sends a request to the given server.
Args:
destination (str): The remote server to send the HTTP request to.
method (str): HTTP method
path (str): The HTTP path
json (dict or None): JSON to send in the body.
json_callback (func or None): A callback to generate the JSON.
query (dict or None): Query arguments.
ignore_backoff (bool): true to ignore the historical backoff data
and try the request anyway.
backoff_on_404 (bool): Back off if we get a 404
@@ -146,22 +157,29 @@ class MatrixFederationHttpClient(object):
ignore_backoff=ignore_backoff,
)
destination = destination.encode("ascii")
headers_dict = {}
path_bytes = path.encode("ascii")
with limiter:
headers_dict[b"User-Agent"] = [self.version_string]
headers_dict[b"Host"] = [destination]
if query:
query_bytes = encode_query_args(query)
else:
query_bytes = b""
url_bytes = self._create_url(
destination, path_bytes, param_bytes, query_bytes
)
headers_dict = {
"User-Agent": [self.version_string],
"Host": [destination],
}
with limiter:
url = self._create_url(
destination.encode("ascii"), path_bytes, param_bytes, query_bytes
).decode('ascii')
txn_id = "%s-O-%s" % (method, self._next_id)
self._next_id = (self._next_id + 1) % (sys.maxint - 1)
self._next_id = (self._next_id + 1) % (MAXINT - 1)
outbound_logger.info(
"{%s} [%s] Sending request: %s %s",
txn_id, destination, method, url_bytes
txn_id, destination, method, url
)
# XXX: Would be much nicer to retry only at the transaction-layer
@@ -171,23 +189,33 @@ class MatrixFederationHttpClient(object):
else:
retries_left = MAX_SHORT_RETRIES
http_url_bytes = urlparse.urlunparse(
("", "", path_bytes, param_bytes, query_bytes, "")
)
http_url = urllib.parse.urlunparse(
(b"", b"", path_bytes, param_bytes, query_bytes, b"")
).decode('ascii')
log_result = None
try:
while True:
producer = None
if body_callback:
producer = body_callback(method, http_url_bytes, headers_dict)
try:
request_deferred = self.agent.request(
if json_callback:
json = json_callback()
if json:
data = encode_canonical_json(json)
headers_dict["Content-Type"] = ["application/json"]
self.sign_request(
destination, method, http_url, headers_dict, json
)
else:
data = None
self.sign_request(destination, method, http_url, headers_dict)
request_deferred = treq.request(
method,
url_bytes,
Headers(headers_dict),
producer
url,
headers=Headers(headers_dict),
data=data,
agent=self.agent,
)
add_timeout_to_deferred(
request_deferred,
@@ -218,7 +246,7 @@ class MatrixFederationHttpClient(object):
txn_id,
destination,
method,
url_bytes,
url,
_flatten_response_never_received(e),
)
@@ -252,7 +280,7 @@ class MatrixFederationHttpClient(object):
# :'(
# Update transactions table?
with logcontext.PreserveLoggingContext():
body = yield readBody(response)
body = yield treq.content(response)
raise HttpResponseException(
response.code, response.phrase, body
)
@@ -297,11 +325,11 @@ class MatrixFederationHttpClient(object):
auth_headers = []
for key, sig in request["signatures"][self.server_name].items():
auth_headers.append(bytes(
auth_headers.append((
"X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % (
self.server_name, key, sig,
)
))
)).encode('ascii')
)
headers_dict[b"Authorization"] = auth_headers
@@ -347,24 +375,14 @@ class MatrixFederationHttpClient(object):
"""
if not json_data_callback:
def json_data_callback():
return data
def body_callback(method, url_bytes, headers_dict):
json_data = json_data_callback()
self.sign_request(
destination, method, url_bytes, headers_dict, json_data
)
producer = _JsonProducer(json_data)
return producer
json_data_callback = lambda: data
response = yield self._request(
destination,
"PUT",
path,
body_callback=body_callback,
headers_dict={"Content-Type": ["application/json"]},
query_bytes=encode_query_args(args),
json_callback=json_data_callback,
query=args,
long_retries=long_retries,
timeout=timeout,
ignore_backoff=ignore_backoff,
@@ -376,8 +394,8 @@ class MatrixFederationHttpClient(object):
check_content_type_is_json(response.headers)
with logcontext.PreserveLoggingContext():
body = yield readBody(response)
defer.returnValue(json.loads(body))
body = yield treq.json_content(response)
defer.returnValue(body)
@defer.inlineCallbacks
def post_json(self, destination, path, data={}, long_retries=False,
@@ -410,20 +428,12 @@ class MatrixFederationHttpClient(object):
Fails with ``FederationDeniedError`` if this destination
is not on our federation whitelist
"""
def body_callback(method, url_bytes, headers_dict):
self.sign_request(
destination, method, url_bytes, headers_dict, data
)
return _JsonProducer(data)
response = yield self._request(
destination,
"POST",
path,
query_bytes=encode_query_args(args),
body_callback=body_callback,
headers_dict={"Content-Type": ["application/json"]},
query=args,
json=data,
long_retries=long_retries,
timeout=timeout,
ignore_backoff=ignore_backoff,
@@ -434,9 +444,9 @@ class MatrixFederationHttpClient(object):
check_content_type_is_json(response.headers)
with logcontext.PreserveLoggingContext():
body = yield readBody(response)
body = yield treq.json_content(response)
defer.returnValue(json.loads(body))
defer.returnValue(body)
@defer.inlineCallbacks
def get_json(self, destination, path, args=None, retry_on_dns_fail=True,
@@ -471,16 +481,11 @@ class MatrixFederationHttpClient(object):
logger.debug("Query bytes: %s Retry DNS: %s", args, retry_on_dns_fail)
def body_callback(method, url_bytes, headers_dict):
self.sign_request(destination, method, url_bytes, headers_dict)
return None
response = yield self._request(
destination,
"GET",
path,
query_bytes=encode_query_args(args),
body_callback=body_callback,
query=args,
retry_on_dns_fail=retry_on_dns_fail,
timeout=timeout,
ignore_backoff=ignore_backoff,
@@ -491,9 +496,9 @@ class MatrixFederationHttpClient(object):
check_content_type_is_json(response.headers)
with logcontext.PreserveLoggingContext():
body = yield readBody(response)
body = yield treq.json_content(response)
defer.returnValue(json.loads(body))
defer.returnValue(body)
@defer.inlineCallbacks
def delete_json(self, destination, path, long_retries=False,
@@ -523,13 +528,11 @@ class MatrixFederationHttpClient(object):
Fails with ``FederationDeniedError`` if this destination
is not on our federation whitelist
"""
response = yield self._request(
destination,
"DELETE",
path,
query_bytes=encode_query_args(args),
headers_dict={"Content-Type": ["application/json"]},
query=args,
long_retries=long_retries,
timeout=timeout,
ignore_backoff=ignore_backoff,
@@ -540,9 +543,9 @@ class MatrixFederationHttpClient(object):
check_content_type_is_json(response.headers)
with logcontext.PreserveLoggingContext():
body = yield readBody(response)
body = yield treq.json_content(response)
defer.returnValue(json.loads(body))
defer.returnValue(body)
@defer.inlineCallbacks
def get_file(self, destination, path, output_stream, args={},
@@ -569,26 +572,11 @@ class MatrixFederationHttpClient(object):
Fails with ``FederationDeniedError`` if this destination
is not on our federation whitelist
"""
encoded_args = {}
for k, vs in args.items():
if isinstance(vs, string_types):
vs = [vs]
encoded_args[k] = [v.encode("UTF-8") for v in vs]
query_bytes = urllib.urlencode(encoded_args, True)
logger.debug("Query bytes: %s Retry DNS: %s", query_bytes, retry_on_dns_fail)
def body_callback(method, url_bytes, headers_dict):
self.sign_request(destination, method, url_bytes, headers_dict)
return None
response = yield self._request(
destination,
"GET",
path,
query_bytes=query_bytes,
body_callback=body_callback,
query=args,
retry_on_dns_fail=retry_on_dns_fail,
ignore_backoff=ignore_backoff,
)
@@ -639,30 +627,6 @@ def _readBodyToFile(response, stream, max_size):
return d
class _JsonProducer(object):
""" Used by the twisted http client to create the HTTP body from json
"""
def __init__(self, jsn):
self.reset(jsn)
def reset(self, jsn):
self.body = encode_canonical_json(jsn)
self.length = len(self.body)
def startProducing(self, consumer):
consumer.write(self.body)
return defer.succeed(None)
def pauseProducing(self):
pass
def stopProducing(self):
pass
def resumeProducing(self):
pass
def _flatten_response_never_received(e):
if hasattr(e, "reasons"):
reasons = ", ".join(
@@ -693,7 +657,7 @@ def check_content_type_is_json(headers):
"No Content-Type header"
)
c_type = c_type[0] # only the first header
c_type = c_type[0].decode('ascii') # only the first header
val, options = cgi.parse_header(c_type)
if val != "application/json":
raise RuntimeError(
@@ -711,6 +675,6 @@ def encode_query_args(args):
vs = [vs]
encoded_args[k] = [v.encode("UTF-8") for v in vs]
query_bytes = urllib.urlencode(encoded_args, True)
query_bytes = urllib.parse.urlencode(encoded_args, True)
return query_bytes
return query_bytes.encode('utf8')
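For context, the producer/readBody plumbing above collapses into a single treq call. A minimal sketch of the new pattern, with illustrative function and argument names rather than the real method:

import treq
from canonicaljson import encode_canonical_json
from twisted.internet import defer
from twisted.web.http_headers import Headers

@defer.inlineCallbacks
def send_signed_json(agent, url, headers_dict, json_body):
    # Canonical JSON gives a stable byte encoding, which matters because
    # the request signature is computed over the JSON content.
    data = encode_canonical_json(json_body)
    headers_dict["Content-Type"] = ["application/json"]
    response = yield treq.request(
        "PUT",
        url,
        headers=Headers(headers_dict),
        data=data,            # plain bytes; no custom IBodyProducer needed
        agent=agent,
    )
    # treq reads and JSON-decodes the body, replacing readBody + json.loads
    body = yield treq.json_content(response)
    defer.returnValue(body)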

View File

@@ -204,14 +204,14 @@ class SynapseRequest(Request):
self.start_time = time.time()
self.request_metrics = RequestMetrics()
self.request_metrics.start(
self.start_time, name=servlet_name, method=self.method,
self.start_time, name=servlet_name, method=self.method.decode('ascii'),
)
self.site.access_logger.info(
"%s - %s - Received request: %s %s",
self.getClientIP(),
self.site.site_tag,
self.method,
self.method.decode('ascii'),
self.get_redacted_uri()
)
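The .decode('ascii') calls are needed because Twisted exposes Request.method as bytes on Python 3; a minimal illustration:

method = b"GET"                          # what Twisted hands the request object
label = "Received request: %s" % (method.decode("ascii"),)
assert label == "Received request: GET"  # without decode, py3 would log "b'GET'"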

View File

@@ -40,9 +40,10 @@ REQUIREMENTS = {
"pynacl>=1.2.1": ["nacl>=1.2.1", "nacl.bindings"],
"service_identity>=1.0.0": ["service_identity>=1.0.0"],
"Twisted>=17.1.0": ["twisted>=17.1.0"],
"treq>=15.1": ["treq>=15.1"],
# We use crypto.get_elliptic_curve which is only supported in >=0.15
"pyopenssl>=0.15": ["OpenSSL>=0.15"],
# Twisted has required pyopenssl 16.0 since about Twisted 16.6.
"pyopenssl>=16.0.0": ["OpenSSL>=16.0.0"],
"pyyaml": ["yaml"],
"pyasn1": ["pyasn1"],
@@ -78,6 +79,9 @@ CONDITIONAL_REQUIREMENTS = {
"affinity": {
"affinity": ["affinity"],
},
"postgres": {
"psycopg2>=2.6": ["psycopg2"]
}
}
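A sketch of how a conditional requirement like this typically surfaces as a pip extra (the setup() fragment below is illustrative, not the project's actual setup.py):

from setuptools import setup

setup(
    name="example-homeserver",        # illustrative package name
    install_requires=["Twisted>=17.1.0", "treq>=15.1", "pyopenssl>=16.0.0"],
    extras_require={
        # enables `pip install example-homeserver[postgres]`
        "postgres": ["psycopg2>=2.6"],
    },
)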

View File

@@ -590,9 +590,9 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
pending_commands = LaterGauge(
"synapse_replication_tcp_protocol_pending_commands",
"",
["name", "conn_id"],
["name"],
lambda: {
(p.name, p.conn_id): len(p.pending_commands) for p in connected_connections
(p.name,): len(p.pending_commands) for p in connected_connections
},
)
@@ -607,9 +607,9 @@ def transport_buffer_size(protocol):
transport_send_buffer = LaterGauge(
"synapse_replication_tcp_protocol_transport_send_buffer",
"",
["name", "conn_id"],
["name"],
lambda: {
(p.name, p.conn_id): transport_buffer_size(p) for p in connected_connections
(p.name,): transport_buffer_size(p) for p in connected_connections
},
)
@@ -632,9 +632,9 @@ def transport_kernel_read_buffer_size(protocol, read=True):
tcp_transport_kernel_send_buffer = LaterGauge(
"synapse_replication_tcp_protocol_transport_kernel_send_buffer",
"",
["name", "conn_id"],
["name"],
lambda: {
(p.name, p.conn_id): transport_kernel_read_buffer_size(p, False)
(p.name,): transport_kernel_read_buffer_size(p, False)
for p in connected_connections
},
)
@@ -643,9 +643,9 @@ tcp_transport_kernel_send_buffer = LaterGauge(
tcp_transport_kernel_read_buffer = LaterGauge(
"synapse_replication_tcp_protocol_transport_kernel_read_buffer",
"",
["name", "conn_id"],
["name"],
lambda: {
(p.name, p.conn_id): transport_kernel_read_buffer_size(p, True)
(p.name,): transport_kernel_read_buffer_size(p, True)
for p in connected_connections
},
)
@@ -654,9 +654,9 @@ tcp_transport_kernel_read_buffer = LaterGauge(
tcp_inbound_commands = LaterGauge(
"synapse_replication_tcp_protocol_inbound_commands",
"",
["command", "name", "conn_id"],
["command", "name"],
lambda: {
(k[0], p.name, p.conn_id): count
(k[0], p.name,): count
for p in connected_connections
for k, count in iteritems(p.inbound_commands_counter)
},
@@ -665,9 +665,9 @@ tcp_inbound_commands = LaterGauge(
tcp_outbound_commands = LaterGauge(
"synapse_replication_tcp_protocol_outbound_commands",
"",
["command", "name", "conn_id"],
["command", "name"],
lambda: {
(k[0], p.name, p.conn_id): count
(k[0], p.name,): count
for p in connected_connections
for k, count in iteritems(p.outbound_commands_counter)
},
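Why conn_id is dropped from every label set: it is unique per connection, so each reconnect minted a fresh Prometheus time series. A toy illustration of the cardinality difference, with made-up values:

before = {("client", "conn-%d" % i): 0 for i in range(1000)}  # 1000 series
after = {("client",): 0}                                      # 1 series
assert len(before) == 1000 and len(after) == 1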

View File

@@ -23,6 +23,7 @@ from twisted.internet import defer
import synapse.util.stringutils as stringutils
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, SynapseError
from synapse.config.server import is_threepid_reserved
from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request
from synapse.rest.client.v1.base import ClientV1RestServlet
from synapse.types import create_requester
@@ -281,12 +282,20 @@ class RegisterRestServlet(ClientV1RestServlet):
register_json["user"].encode("utf-8")
if "user" in register_json else None
)
threepid = None
if session.get(LoginType.EMAIL_IDENTITY):
threepid = session["threepidCreds"]
handler = self.handlers.registration_handler
(user_id, token) = yield handler.register(
localpart=desired_user_id,
password=password
password=password,
threepid=threepid,
)
# Necessary due to auth checks prior to the threepid being
# written to the db
if is_threepid_reserved(self.hs.config, threepid):
yield self.store.upsert_monthly_active_user(user_id)
if session[LoginType.EMAIL_IDENTITY]:
logger.debug("Binding emails %s to %s" % (

View File

@@ -53,7 +53,9 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
if not check_3pid_allowed(self.hs, "email", body['email']):
raise SynapseError(
403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
403,
"Your email domain is not authorized on this server",
Codes.THREEPID_DENIED,
)
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
@@ -89,7 +91,9 @@ class MsisdnPasswordRequestTokenRestServlet(RestServlet):
if not check_3pid_allowed(self.hs, "msisdn", msisdn):
raise SynapseError(
403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
403,
"Account phone numbers are not authorized on this server",
Codes.THREEPID_DENIED,
)
existingUid = yield self.datastore.get_user_id_by_threepid(
@@ -241,7 +245,9 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
if not check_3pid_allowed(self.hs, "email", body['email']):
raise SynapseError(
403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
403,
"Your email domain is not authorized on this server",
Codes.THREEPID_DENIED,
)
existingUid = yield self.datastore.get_user_id_by_threepid(
@@ -276,7 +282,9 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
if not check_3pid_allowed(self.hs, "msisdn", msisdn):
raise SynapseError(
403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
403,
"Account phone numbers are not authorized on this server",
Codes.THREEPID_DENIED,
)
existingUid = yield self.datastore.get_user_id_by_threepid(

View File

@@ -26,6 +26,7 @@ import synapse
import synapse.types
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, SynapseError, UnrecognizedRequestError
from synapse.config.server import is_threepid_reserved
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
@@ -74,7 +75,9 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
if not check_3pid_allowed(self.hs, "email", body['email']):
raise SynapseError(
403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
403,
"Your email domain is not authorized to register on this server",
Codes.THREEPID_DENIED,
)
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
@@ -114,7 +117,9 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
if not check_3pid_allowed(self.hs, "msisdn", msisdn):
raise SynapseError(
403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
403,
"Phone numbers are not authorized to register on this server",
Codes.THREEPID_DENIED,
)
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
@@ -372,7 +377,9 @@ class RegisterRestServlet(RestServlet):
if not check_3pid_allowed(self.hs, medium, address):
raise SynapseError(
403, "Third party identifier is not allowed",
403,
"Third party identifiers (email/phone numbers)" +
" are not authorized on this server",
Codes.THREEPID_DENIED,
)
@@ -395,12 +402,21 @@ class RegisterRestServlet(RestServlet):
if desired_username is not None:
desired_username = desired_username.lower()
threepid = None
if auth_result:
threepid = auth_result.get(LoginType.EMAIL_IDENTITY)
(registered_user_id, _) = yield self.registration_handler.register(
localpart=desired_username,
password=new_password,
guest_access_token=guest_access_token,
generate_token=False,
threepid=threepid,
)
# Necessary due to auth checks prior to the threepid being
# written to the db
if is_threepid_reserved(self.hs.config, threepid):
yield self.store.upsert_monthly_active_user(registered_user_id)
# remember that we've now registered that user account, and with
# what user ID (since the user may not have specified)
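is_threepid_reserved is imported but not shown in this diff; a plausible sketch of its check, assuming the mau_limits_reserved_threepids config list seen elsewhere in this changeset:

def is_threepid_reserved(config, threepid):
    """True if the threepid matches one of the configured reserved entries,
    meaning the user should be upserted into the MAU list even though the
    auth checks ran before the threepid was written to the db."""
    if not threepid:
        return False
    return any(
        threepid["medium"] == tp["medium"]
        and threepid["address"] == tp["address"]
        for tp in config.mau_limits_reserved_threepids
    )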

View File

@@ -25,6 +25,7 @@ from synapse.api.errors import SynapseError
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection
from synapse.events.utils import (
format_event_for_client_v2_without_room_id,
format_event_raw,
serialize_event,
)
from synapse.handlers.presence import format_user_presence_state
@@ -175,17 +176,28 @@ class SyncRestServlet(RestServlet):
@staticmethod
def encode_response(time_now, sync_result, access_token_id, filter):
if filter.event_format == 'client':
event_formatter = format_event_for_client_v2_without_room_id
elif filter.event_format == 'federation':
event_formatter = format_event_raw
else:
raise Exception("Unknown event format %s" % (filter.event_format, ))
joined = SyncRestServlet.encode_joined(
sync_result.joined, time_now, access_token_id, filter.event_fields
sync_result.joined, time_now, access_token_id,
filter.event_fields,
event_formatter,
)
invited = SyncRestServlet.encode_invited(
sync_result.invited, time_now, access_token_id,
event_formatter,
)
archived = SyncRestServlet.encode_archived(
sync_result.archived, time_now, access_token_id,
filter.event_fields,
event_formatter,
)
return {
@@ -228,7 +240,7 @@ class SyncRestServlet(RestServlet):
}
@staticmethod
def encode_joined(rooms, time_now, token_id, event_fields):
def encode_joined(rooms, time_now, token_id, event_fields, event_formatter):
"""
Encode the joined rooms in a sync result
@@ -240,7 +252,9 @@ class SyncRestServlet(RestServlet):
token_id(int): ID of the user's auth token - used for namespacing
of transaction IDs
event_fields(list<str>): List of event fields to include. If empty,
all fields will be returned.
event_formatter (func[dict]): function to convert from federation format
to client format
Returns:
dict[str, dict[str, object]]: the joined rooms list, in our
response format
@@ -248,13 +262,14 @@ class SyncRestServlet(RestServlet):
joined = {}
for room in rooms:
joined[room.room_id] = SyncRestServlet.encode_room(
room, time_now, token_id, only_fields=event_fields
room, time_now, token_id, joined=True, only_fields=event_fields,
event_formatter=event_formatter,
)
return joined
@staticmethod
def encode_invited(rooms, time_now, token_id):
def encode_invited(rooms, time_now, token_id, event_formatter):
"""
Encode the invited rooms in a sync result
@@ -264,7 +279,9 @@ class SyncRestServlet(RestServlet):
time_now(int): current time - used as a baseline for age
calculations
token_id(int): ID of the user's auth token - used for namespacing
of transaction IDs
event_formatter (func[dict]): function to convert from federation format
to client format
Returns:
dict[str, dict[str, object]]: the invited rooms list, in our
@@ -274,7 +291,7 @@ class SyncRestServlet(RestServlet):
for room in rooms:
invite = serialize_event(
room.invite, time_now, token_id=token_id,
event_format=format_event_for_client_v2_without_room_id,
event_format=event_formatter,
is_invite=True,
)
unsigned = dict(invite.get("unsigned", {}))
@@ -288,7 +305,7 @@ class SyncRestServlet(RestServlet):
return invited
@staticmethod
def encode_archived(rooms, time_now, token_id, event_fields):
def encode_archived(rooms, time_now, token_id, event_fields, event_formatter):
"""
Encode the archived rooms in a sync result
@@ -300,7 +317,9 @@ class SyncRestServlet(RestServlet):
token_id(int): ID of the user's auth token - used for namespacing
of transaction IDs
event_fields(list<str>): List of event fields to include. If empty,
all fields will be returned.
event_formatter (func[dict]): function to convert from federation format
to client format
Returns:
dict[str, dict[str, object]]: The invited rooms list, in our
response format
@@ -308,13 +327,18 @@ class SyncRestServlet(RestServlet):
joined = {}
for room in rooms:
joined[room.room_id] = SyncRestServlet.encode_room(
room, time_now, token_id, joined=False, only_fields=event_fields
room, time_now, token_id, joined=False,
only_fields=event_fields,
event_formatter=event_formatter,
)
return joined
@staticmethod
def encode_room(room, time_now, token_id, joined=True, only_fields=None):
def encode_room(
room, time_now, token_id, joined,
only_fields, event_formatter,
):
"""
Args:
room (JoinedSyncResult|ArchivedSyncResult): sync result for a
@@ -326,14 +350,15 @@ class SyncRestServlet(RestServlet):
joined (bool): True if the user is joined to this room - will mean
we handle ephemeral events
only_fields(list<str>): Optional. The list of event fields to include.
event_formatter (func[dict]): function to convert from federation format
to client format
Returns:
dict[str, object]: the room, encoded in our response format
"""
def serialize(event):
# TODO(mjark): Respect formatting requirements in the filter.
return serialize_event(
event, time_now, token_id=token_id,
event_format=format_event_for_client_v2_without_room_id,
event_format=event_formatter,
only_event_fields=only_fields,
)
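For reference, the client-side view of this change: event_format in a /sync filter selects between the two formatters handled above. An illustrative filter body:

sync_filter = {
    "event_format": "federation",         # or "client" (the default)
    "event_fields": ["content", "sender"],
}
# With "federation", format_event_raw is chosen and events come back in raw
# federation (PDU) form; with "client", the client v2 format is used.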

View File

@@ -19,6 +19,7 @@
# partial one for unit test mocking.
# Imports required for the default HomeServer() implementation
import abc
import logging
from twisted.enterprise import adbapi
@@ -81,7 +82,6 @@ from synapse.server_notices.server_notices_manager import ServerNoticesManager
from synapse.server_notices.server_notices_sender import ServerNoticesSender
from synapse.server_notices.worker_server_notices_sender import WorkerServerNoticesSender
from synapse.state import StateHandler, StateResolutionHandler
from synapse.storage import DataStore
from synapse.streams.events import EventSources
from synapse.util import Clock
from synapse.util.distributor import Distributor
@@ -111,6 +111,8 @@ class HomeServer(object):
config (synapse.config.homeserver.HomeserverConfig):
"""
__metaclass__ = abc.ABCMeta
DEPENDENCIES = [
'http_client',
'db_pool',
@@ -172,6 +174,11 @@ class HomeServer(object):
'room_context_handler',
]
# This is overridden in derived application classes
# (such as synapse.app.homeserver.SynapseHomeServer) and gives the class to be
# instantiated during setup() for future return by get_datastore()
DATASTORE_CLASS = abc.abstractproperty()
def __init__(self, hostname, reactor=None, **kwargs):
"""
Args:
@@ -188,13 +195,16 @@ class HomeServer(object):
self.distributor = Distributor()
self.ratelimiter = Ratelimiter()
self.datastore = None
# Other kwargs are explicit dependencies
for depname in kwargs:
setattr(self, depname, kwargs[depname])
def setup(self):
logger.info("Setting up.")
self.datastore = DataStore(self.get_db_conn(), self)
with self.get_db_conn() as conn:
self.datastore = self.DATASTORE_CLASS(conn, self)
logger.info("Finished setting up.")
def get_reactor(self):
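A minimal sketch of what the comment above describes, i.e. how a derived application class satisfies the abstract property (the class name below is illustrative, and DataStore stands in for whichever concrete store class the app uses):

class ExampleHomeServer(HomeServer):
    # A class, not an instance: setup() instantiates it with the db
    # connection, and get_datastore() returns that instance.
    DATASTORE_CLASS = DataStore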

View File

@@ -46,6 +46,8 @@ class ResourceLimitsServerNotices(object):
self._message_handler = hs.get_message_handler()
self._state = hs.get_state_handler()
self._notifier = hs.get_notifier()
@defer.inlineCallbacks
def maybe_send_server_notice_to_user(self, user_id):
"""Check if we need to send a notice to this user, this will be true in
@@ -116,7 +118,7 @@ class ResourceLimitsServerNotices(object):
'body': event_content,
'msgtype': ServerNoticeMsgType,
'server_notice_type': ServerNoticeLimitReached,
'admin_uri': self._config.admin_uri,
'admin_contact': self._config.admin_contact,
'limit_type': event_limit_type
}
event = yield self._server_notices_manager.send_notice(
@@ -144,16 +146,18 @@ class ResourceLimitsServerNotices(object):
user_id(str): the user in question
room_id(str): the server notices room for that user
"""
tags = yield self._store.get_tags_for_user(user_id)
server_notices_tags = tags.get(room_id)
tags = yield self._store.get_tags_for_room(user_id, room_id)
need_to_set_tag = True
if server_notices_tags:
if server_notices_tags.get(SERVER_NOTICE_ROOM_TAG):
if tags:
if SERVER_NOTICE_ROOM_TAG in tags:
# tag already present, nothing to do here
need_to_set_tag = False
if need_to_set_tag:
yield self._store.add_tag_to_room(
user_id, room_id, SERVER_NOTICE_ROOM_TAG, None
max_id = yield self._store.add_tag_to_room(
user_id, room_id, SERVER_NOTICE_ROOM_TAG, {}
)
self._notifier.on_new_event(
"account_data_key", max_id, users=[user_id]
)
@defer.inlineCallbacks

View File

@@ -39,6 +39,8 @@ class ServerNoticesManager(object):
self._event_creation_handler = hs.get_event_creation_handler()
self._is_mine_id = hs.is_mine_id
self._notifier = hs.get_notifier()
def is_enabled(self):
"""Checks if server notices are enabled on this server.
@@ -153,8 +155,12 @@ class ServerNoticesManager(object):
creator_join_profile=join_profile,
)
room_id = info['room_id']
yield self._store.add_tag_to_room(
user_id, room_id, SERVER_NOTICE_ROOM_TAG, None
max_id = yield self._store.add_tag_to_room(
user_id, room_id, SERVER_NOTICE_ROOM_TAG, {},
)
self._notifier.on_new_event(
"account_data_key", max_id, users=[user_id]
)
logger.info("Created server notices room %s for %s", room_id, user_id)

View File

@@ -385,6 +385,7 @@ class StateHandler(object):
ev_ids, get_prev_content=False, check_redacted=False,
)
@defer.inlineCallbacks
def resolve_events(self, room_version, state_sets, event):
logger.info(
"Resolving state for %s with %d groups", event.room_id, len(state_sets)
@@ -401,15 +402,17 @@ class StateHandler(object):
}
with Measure(self.clock, "state._resolve_events"):
new_state = resolve_events_with_state_map(
room_version, state_set_ids, state_map,
new_state = yield resolve_events_with_factory(
room_version, state_set_ids,
event_map=state_map,
state_map_factory=self._state_map_factory
)
new_state = {
key: state_map[ev_id] for key, ev_id in iteritems(new_state)
}
return new_state
defer.returnValue(new_state)
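The switch from `return new_state` to defer.returnValue(new_state) follows from adding @defer.inlineCallbacks: the decorated function is now a generator, and on Python 2 a generator cannot return a value. In miniature:

from twisted.internet import defer

@defer.inlineCallbacks
def get_answer():
    result = yield defer.succeed(42)
    defer.returnValue(result)  # a bare `return result` only works on py3.3+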
class StateResolutionHandler(object):
@@ -589,31 +592,6 @@ def _make_state_cache_entry(
)
def resolve_events_with_state_map(room_version, state_sets, state_map):
"""
Args:
room_version(str): Version of the room
state_sets(list): List of dicts of (type, state_key) -> event_id,
which are the different state groups to resolve.
state_map(dict): a dict from event_id to event, for all events in
state_sets.
Returns
dict[(str, str), str]:
a map from (type, state_key) to event_id.
"""
if room_version in (RoomVersions.V1, RoomVersions.VDH_TEST,):
return v1.resolve_events_with_state_map(
state_sets, state_map,
)
else:
# This should only happen if we added a version but forgot to add it to
# the list above.
raise Exception(
"No state resolution algorithm defined for version %r" % (room_version,)
)
def resolve_events_with_factory(room_version, state_sets, event_map, state_map_factory):
"""
Args:

View File

@@ -30,34 +30,6 @@ logger = logging.getLogger(__name__)
POWER_KEY = (EventTypes.PowerLevels, "")
def resolve_events_with_state_map(state_sets, state_map):
"""
Args:
state_sets(list): List of dicts of (type, state_key) -> event_id,
which are the different state groups to resolve.
state_map(dict): a dict from event_id to event, for all events in
state_sets.
Returns
dict[(str, str), str]:
a map from (type, state_key) to event_id.
"""
if len(state_sets) == 1:
return state_sets[0]
unconflicted_state, conflicted_state = _seperate(
state_sets,
)
auth_events = _create_auth_events_from_maps(
unconflicted_state, conflicted_state, state_map
)
return _resolve_with_state(
unconflicted_state, conflicted_state, auth_events, state_map
)
@defer.inlineCallbacks
def resolve_events_with_factory(state_sets, event_map, state_map_factory):
"""

View File

@@ -17,9 +17,10 @@ import sys
import threading
import time
from six import iteritems, iterkeys, itervalues
from six import PY2, iteritems, iterkeys, itervalues
from six.moves import intern, range
from canonicaljson import json
from prometheus_client import Histogram
from twisted.internet import defer
@@ -1216,3 +1217,32 @@ class _RollbackButIsFineException(Exception):
something went wrong.
"""
pass
def db_to_json(db_content):
"""
Take some data from a database row and return a JSON-decoded object.
Args:
db_content (memoryview|buffer|bytes|bytearray|unicode)
"""
# psycopg2 on Python 3 returns memoryview objects, which we need to
# cast to bytes to decode
if isinstance(db_content, memoryview):
db_content = db_content.tobytes()
# psycopg2 on Python 2 returns buffer objects, which we need to cast to
# bytes to decode
if PY2 and isinstance(db_content, buffer):
db_content = bytes(db_content)
# Decode it to a Unicode string before feeding it to json.loads, so we
# consistently get a Unicode-containing object out.
if isinstance(db_content, (bytes, bytearray)):
db_content = db_content.decode('utf8')
try:
return json.loads(db_content)
except Exception:
logging.warning("Tried to decode '%r' as JSON and failed", db_content)
raise
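Usage of the helper, for illustration (the literals below are made up):

assert db_to_json(b'{"a": 1}') == {"a": 1}               # bytes, e.g. sqlite
assert db_to_json(u'{"a": 1}') == {"a": 1}               # already-decoded text
assert db_to_json(memoryview(b'{"a": 1}')) == {"a": 1}   # psycopg2 on Python 3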

View File

@@ -169,7 +169,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
local_by_user_then_device = {}
for user_id, messages_by_device in messages_by_user_then_device.items():
messages_json_for_user = {}
devices = messages_by_device.keys()
devices = list(messages_by_device.keys())
if len(devices) == 1 and devices[0] == "*":
# Handle wildcard device_ids.
sql = (

View File

@@ -24,7 +24,7 @@ from synapse.api.errors import StoreError
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
from ._base import Cache, SQLBaseStore
from ._base import Cache, SQLBaseStore, db_to_json
logger = logging.getLogger(__name__)
@@ -411,7 +411,7 @@ class DeviceStore(SQLBaseStore):
if device is not None:
key_json = device.get("key_json", None)
if key_json:
result["keys"] = json.loads(key_json)
result["keys"] = db_to_json(key_json)
device_display_name = device.get("device_display_name", None)
if device_display_name:
result["device_display_name"] = device_display_name
@@ -466,7 +466,7 @@ class DeviceStore(SQLBaseStore):
retcol="content",
desc="_get_cached_user_device",
)
defer.returnValue(json.loads(content))
defer.returnValue(db_to_json(content))
@cachedInlineCallbacks()
def _get_cached_devices_for_user(self, user_id):
@@ -479,7 +479,7 @@ class DeviceStore(SQLBaseStore):
desc="_get_cached_devices_for_user",
)
defer.returnValue({
device["device_id"]: json.loads(device["content"])
device["device_id"]: db_to_json(device["content"])
for device in devices
})
@@ -511,7 +511,7 @@ class DeviceStore(SQLBaseStore):
key_json = device.get("key_json", None)
if key_json:
result["keys"] = json.loads(key_json)
result["keys"] = db_to_json(key_json)
device_display_name = device.get("device_display_name", None)
if device_display_name:
result["device_display_name"] = device_display_name

View File

@@ -14,13 +14,13 @@
# limitations under the License.
from six import iteritems
from canonicaljson import encode_canonical_json, json
from canonicaljson import encode_canonical_json
from twisted.internet import defer
from synapse.util.caches.descriptors import cached
from ._base import SQLBaseStore
from ._base import SQLBaseStore, db_to_json
class EndToEndKeyStore(SQLBaseStore):
@@ -90,7 +90,7 @@ class EndToEndKeyStore(SQLBaseStore):
for user_id, device_keys in iteritems(results):
for device_id, device_info in iteritems(device_keys):
device_info["keys"] = json.loads(device_info.pop("key_json"))
device_info["keys"] = db_to_json(device_info.pop("key_json"))
defer.returnValue(results)

View File

@@ -41,13 +41,18 @@ class PostgresEngine(object):
db_conn.set_isolation_level(
self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
)
# Set the bytea output to escape, vs the default of hex
cursor = db_conn.cursor()
cursor.execute("SET bytea_output TO escape")
# Asynchronous commit, don't wait for the server to call fsync before
# ending the transaction.
# https://www.postgresql.org/docs/current/static/wal-async-commit.html
if not self.synchronous_commit:
cursor = db_conn.cursor()
cursor.execute("SET synchronous_commit TO OFF")
cursor.close()
def is_deadlock(self, error):
if isinstance(error, self.module.DatabaseError):

View File

@@ -19,7 +19,7 @@ import logging
from collections import OrderedDict, deque, namedtuple
from functools import wraps
from six import iteritems
from six import iteritems, text_type
from six.moves import range
from canonicaljson import json
@@ -1220,7 +1220,7 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
"sender": event.sender,
"contains_url": (
"url" in event.content
and isinstance(event.content["url"], basestring)
and isinstance(event.content["url"], text_type)
),
}
for event, _ in events_and_contexts
@@ -1529,7 +1529,7 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
contains_url = "url" in content
if contains_url:
contains_url &= isinstance(content["url"], basestring)
contains_url &= isinstance(content["url"], text_type)
except (KeyError, AttributeError):
# If the event is missing a necessary field then
# skip over it.
@@ -1910,9 +1910,9 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
(room_id,)
)
rows = txn.fetchall()
max_depth = max(row[0] for row in rows)
max_depth = max(row[1] for row in rows)
if max_depth <= token.topological:
if max_depth < token.topological:
# We need to ensure we don't delete all the events from the database
# otherwise we wouldn't be able to send any events (due to not
# having any backwards extremities)

View File

@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import logging
from collections import namedtuple
@@ -265,7 +266,7 @@ class EventsWorkerStore(SQLBaseStore):
"""
with Measure(self._clock, "_fetch_event_list"):
try:
event_id_lists = zip(*event_list)[0]
event_id_lists = list(zip(*event_list))[0]
event_ids = [
item for sublist in event_id_lists for item in sublist
]
@@ -299,14 +300,14 @@ class EventsWorkerStore(SQLBaseStore):
logger.exception("do_fetch")
# We only want to resolve deferreds from the main thread
def fire(evs):
def fire(evs, exc):
for _, d in evs:
if not d.called:
with PreserveLoggingContext():
d.errback(e)
d.errback(exc)
with PreserveLoggingContext():
self.hs.get_reactor().callFromThread(fire, event_list)
self.hs.get_reactor().callFromThread(fire, event_list, e)
@defer.inlineCallbacks
def _enqueue_events(self, events, check_redacted=True, allow_rejected=False):
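Why fire() now takes the exception explicitly: on Python 3 the name bound by `except ... as e` is deleted when the except block exits, so a callback run later cannot close over e. A minimal reproduction of the hazard:

def broken():
    try:
        raise ValueError("boom")
    except ValueError as e:
        callback = lambda: e       # closes over the name `e`
    return callback()              # NameError on Python 3: `e` was unbound

def fixed():
    try:
        raise ValueError("boom")
    except ValueError as e:
        callback = lambda exc=e: exc  # capture the value at definition time
    return callback()                 # returns the ValueError instance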

View File

@@ -13,14 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from canonicaljson import encode_canonical_json, json
from canonicaljson import encode_canonical_json
from twisted.internet import defer
from synapse.api.errors import Codes, SynapseError
from synapse.util.caches.descriptors import cachedInlineCallbacks
from ._base import SQLBaseStore
from ._base import SQLBaseStore, db_to_json
class FilteringStore(SQLBaseStore):
@@ -44,7 +44,7 @@ class FilteringStore(SQLBaseStore):
desc="get_user_filter",
)
defer.returnValue(json.loads(bytes(def_json).decode("utf-8")))
defer.returnValue(db_to_json(def_json))
def add_user_filter(self, user_localpart, user_filter):
def_json = encode_canonical_json(user_filter)

View File

@@ -36,7 +36,6 @@ class MonthlyActiveUsersStore(SQLBaseStore):
@defer.inlineCallbacks
def initialise_reserved_users(self, threepids):
# TODO Why can't I do this in init?
store = self.hs.get_datastore()
reserved_user_list = []
@@ -200,10 +199,14 @@ class MonthlyActiveUsersStore(SQLBaseStore):
Args:
user_id(str): the user_id to query
"""
if self.hs.config.limit_usage_by_mau:
# Trial users and guests should not be included as part of MAU group
is_guest = yield self.is_guest(user_id)
if is_guest:
return
is_trial = yield self.is_trial_user(user_id)
if is_trial:
# we don't track trial users in the MAU table.
return
last_seen_timestamp = yield self.user_last_seen_monthly_active(user_id)

View File

@@ -15,7 +15,8 @@
# limitations under the License.
import logging
import types
import six
from canonicaljson import encode_canonical_json, json
@@ -27,6 +28,11 @@ from ._base import SQLBaseStore
logger = logging.getLogger(__name__)
if six.PY2:
db_binary_type = buffer
else:
db_binary_type = memoryview
class PusherWorkerStore(SQLBaseStore):
def _decode_pushers_rows(self, rows):
@@ -34,18 +40,18 @@ class PusherWorkerStore(SQLBaseStore):
dataJson = r['data']
r['data'] = None
try:
if isinstance(dataJson, types.BufferType):
if isinstance(dataJson, db_binary_type):
dataJson = str(dataJson).decode("UTF8")
r['data'] = json.loads(dataJson)
except Exception as e:
logger.warn(
"Invalid JSON in data for pusher %d: %s, %s",
r['id'], dataJson, e.message,
r['id'], dataJson, e.args[0],
)
pass
if isinstance(r['pushkey'], types.BufferType):
if isinstance(r['pushkey'], db_binary_type):
r['pushkey'] = str(r['pushkey']).decode("UTF8")
return rows

View File

@@ -18,14 +18,14 @@ from collections import namedtuple
import six
from canonicaljson import encode_canonical_json, json
from canonicaljson import encode_canonical_json
from twisted.internet import defer
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util.caches.descriptors import cached
from ._base import SQLBaseStore
from ._base import SQLBaseStore, db_to_json
# py2 sqlite has buffer hardcoded as only binary type, so we must use it,
# despite being deprecated and removed in favor of memoryview
@@ -95,7 +95,8 @@ class TransactionStore(SQLBaseStore):
)
if result and result["response_code"]:
return result["response_code"], json.loads(str(result["response_json"]))
return result["response_code"], db_to_json(result["response_json"])
else:
return None

View File

@@ -457,7 +457,7 @@ class AuthTestCase(unittest.TestCase):
with self.assertRaises(ResourceLimitError) as e:
yield self.auth.check_auth_blocking()
self.assertEquals(e.exception.admin_uri, self.hs.config.admin_uri)
self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact)
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.assertEquals(e.exception.code, 403)
@@ -467,13 +467,30 @@ class AuthTestCase(unittest.TestCase):
)
yield self.auth.check_auth_blocking()
@defer.inlineCallbacks
def test_reserved_threepid(self):
self.hs.config.limit_usage_by_mau = True
self.hs.config.max_mau_value = 1
threepid = {'medium': 'email', 'address': 'reserved@server.com'}
unknown_threepid = {'medium': 'email', 'address': 'unreserved@server.com'}
self.hs.config.mau_limits_reserved_threepids = [threepid]
yield self.store.register(user_id='user1', token="123", password_hash=None)
with self.assertRaises(ResourceLimitError):
yield self.auth.check_auth_blocking()
with self.assertRaises(ResourceLimitError):
yield self.auth.check_auth_blocking(threepid=unknown_threepid)
yield self.auth.check_auth_blocking(threepid=threepid)
@defer.inlineCallbacks
def test_hs_disabled(self):
self.hs.config.hs_disabled = True
self.hs.config.hs_disabled_message = "Reason for being disabled"
with self.assertRaises(ResourceLimitError) as e:
yield self.auth.check_auth_blocking()
self.assertEquals(e.exception.admin_uri, self.hs.config.admin_uri)
self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact)
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.assertEquals(e.exception.code, 403)

View File

@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,79 +14,79 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
import synapse.api.errors
import synapse.handlers.device
import synapse.storage
from tests import unittest, utils
from tests import unittest
user1 = "@boris:aaa"
user2 = "@theresa:bbb"
class DeviceTestCase(unittest.TestCase):
def __init__(self, *args, **kwargs):
super(DeviceTestCase, self).__init__(*args, **kwargs)
self.store = None # type: synapse.storage.DataStore
self.handler = None # type: synapse.handlers.device.DeviceHandler
self.clock = None # type: utils.MockClock
@defer.inlineCallbacks
def setUp(self):
hs = yield utils.setup_test_homeserver(self.addCleanup)
class DeviceTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver("server", http_client=None)
self.handler = hs.get_device_handler()
self.store = hs.get_datastore()
self.clock = hs.get_clock()
return hs
def prepare(self, reactor, clock, hs):
# These tests assume that it starts 1000 seconds in.
self.reactor.advance(1000)
@defer.inlineCallbacks
def test_device_is_created_if_doesnt_exist(self):
res = yield self.handler.check_device_registered(
user_id="@boris:foo",
device_id="fco",
initial_device_display_name="display name",
res = self.get_success(
self.handler.check_device_registered(
user_id="@boris:foo",
device_id="fco",
initial_device_display_name="display name",
)
)
self.assertEqual(res, "fco")
dev = yield self.handler.store.get_device("@boris:foo", "fco")
dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco"))
self.assertEqual(dev["display_name"], "display name")
@defer.inlineCallbacks
def test_device_is_preserved_if_exists(self):
res1 = yield self.handler.check_device_registered(
user_id="@boris:foo",
device_id="fco",
initial_device_display_name="display name",
res1 = self.get_success(
self.handler.check_device_registered(
user_id="@boris:foo",
device_id="fco",
initial_device_display_name="display name",
)
)
self.assertEqual(res1, "fco")
res2 = yield self.handler.check_device_registered(
user_id="@boris:foo",
device_id="fco",
initial_device_display_name="new display name",
res2 = self.get_success(
self.handler.check_device_registered(
user_id="@boris:foo",
device_id="fco",
initial_device_display_name="new display name",
)
)
self.assertEqual(res2, "fco")
dev = yield self.handler.store.get_device("@boris:foo", "fco")
dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco"))
self.assertEqual(dev["display_name"], "display name")
@defer.inlineCallbacks
def test_device_id_is_made_up_if_unspecified(self):
device_id = yield self.handler.check_device_registered(
user_id="@theresa:foo",
device_id=None,
initial_device_display_name="display",
device_id = self.get_success(
self.handler.check_device_registered(
user_id="@theresa:foo",
device_id=None,
initial_device_display_name="display",
)
)
dev = yield self.handler.store.get_device("@theresa:foo", device_id)
dev = self.get_success(self.handler.store.get_device("@theresa:foo", device_id))
self.assertEqual(dev["display_name"], "display")
@defer.inlineCallbacks
def test_get_devices_by_user(self):
yield self._record_users()
self._record_users()
res = self.get_success(self.handler.get_devices_by_user(user1))
res = yield self.handler.get_devices_by_user(user1)
self.assertEqual(3, len(res))
device_map = {d["device_id"]: d for d in res}
self.assertDictContainsSubset(
@@ -119,11 +120,10 @@ class DeviceTestCase(unittest.TestCase):
device_map["abc"],
)
@defer.inlineCallbacks
def test_get_device(self):
yield self._record_users()
self._record_users()
res = yield self.handler.get_device(user1, "abc")
res = self.get_success(self.handler.get_device(user1, "abc"))
self.assertDictContainsSubset(
{
"user_id": user1,
@@ -135,59 +135,66 @@ class DeviceTestCase(unittest.TestCase):
res,
)
@defer.inlineCallbacks
def test_delete_device(self):
yield self._record_users()
self._record_users()
# delete the device
yield self.handler.delete_device(user1, "abc")
self.get_success(self.handler.delete_device(user1, "abc"))
# check the device was deleted
with self.assertRaises(synapse.api.errors.NotFoundError):
yield self.handler.get_device(user1, "abc")
res = self.handler.get_device(user1, "abc")
self.pump()
self.assertIsInstance(
self.failureResultOf(res).value, synapse.api.errors.NotFoundError
)
# we'd like to check the access token was invalidated, but that's a
# bit of a PITA.
@defer.inlineCallbacks
def test_update_device(self):
yield self._record_users()
self._record_users()
update = {"display_name": "new display"}
yield self.handler.update_device(user1, "abc", update)
self.get_success(self.handler.update_device(user1, "abc", update))
res = yield self.handler.get_device(user1, "abc")
res = self.get_success(self.handler.get_device(user1, "abc"))
self.assertEqual(res["display_name"], "new display")
@defer.inlineCallbacks
def test_update_unknown_device(self):
update = {"display_name": "new_display"}
with self.assertRaises(synapse.api.errors.NotFoundError):
yield self.handler.update_device("user_id", "unknown_device_id", update)
res = self.handler.update_device("user_id", "unknown_device_id", update)
self.pump()
self.assertIsInstance(
self.failureResultOf(res).value, synapse.api.errors.NotFoundError
)
@defer.inlineCallbacks
def _record_users(self):
# check this works for both devices which have a recorded client_ip,
# and those which don't.
yield self._record_user(user1, "xyz", "display 0")
yield self._record_user(user1, "fco", "display 1", "token1", "ip1")
yield self._record_user(user1, "abc", "display 2", "token2", "ip2")
yield self._record_user(user1, "abc", "display 2", "token3", "ip3")
self._record_user(user1, "xyz", "display 0")
self._record_user(user1, "fco", "display 1", "token1", "ip1")
self._record_user(user1, "abc", "display 2", "token2", "ip2")
self._record_user(user1, "abc", "display 2", "token3", "ip3")
yield self._record_user(user2, "def", "display", "token4", "ip4")
self._record_user(user2, "def", "display", "token4", "ip4")
self.reactor.advance(10000)
@defer.inlineCallbacks
def _record_user(
self, user_id, device_id, display_name, access_token=None, ip=None
):
device_id = yield self.handler.check_device_registered(
user_id=user_id,
device_id=device_id,
initial_device_display_name=display_name,
device_id = self.get_success(
self.handler.check_device_registered(
user_id=user_id,
device_id=device_id,
initial_device_display_name=display_name,
)
)
if ip is not None:
yield self.store.insert_client_ip(
user_id, access_token, ip, "user_agent", device_id
self.get_success(
self.store.insert_client_ip(
user_id, access_token, ip, "user_agent", device_id
)
)
self.clock.advance_time(1000)
self.reactor.advance(1000)
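The migration pattern above in miniature: instead of an @defer.inlineCallbacks test yielding deferreds against a mock clock, HomeserverTestCase drives a fake reactor and get_success unwraps results synchronously. A sketch, assuming it runs as a method on the DeviceTestCase above:

def test_lookup(self):
    self._record_users()
    # old style: dev = yield self.handler.store.get_device(user1, "abc")
    dev = self.get_success(self.handler.store.get_device(user1, "abc"))
    self.assertEqual(dev["display_name"], "display 2")
    self.reactor.advance(1000)  # fake time, replacing clock.advance_time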

View File

@@ -33,7 +33,7 @@ from ..utils import (
)
def _expect_edu(destination, edu_type, content, origin="test"):
def _expect_edu_transaction(edu_type, content, origin="test"):
return {
"origin": origin,
"origin_server_ts": 1000000,
@@ -42,8 +42,8 @@ def _expect_edu(destination, edu_type, content, origin="test"):
}
def _make_edu_json(origin, edu_type, content):
return json.dumps(_expect_edu("test", edu_type, content, origin=origin)).encode(
def _make_edu_transaction_json(edu_type, content):
return json.dumps(_expect_edu_transaction(edu_type, content)).encode(
'utf8'
)
@@ -190,8 +190,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
call(
"farm",
path="/_matrix/federation/v1/send/1000000/",
data=_expect_edu(
"farm",
data=_expect_edu_transaction(
"m.typing",
content={
"room_id": self.room_id,
@@ -221,11 +220,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
self.assertEquals(self.event_source.get_current_key(), 0)
yield self.mock_federation_resource.trigger(
(code, response) = yield self.mock_federation_resource.trigger(
"PUT",
"/_matrix/federation/v1/send/1000000/",
_make_edu_json(
"farm",
_make_edu_transaction_json(
"m.typing",
content={
"room_id": self.room_id,
@@ -233,7 +231,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
"typing": True,
},
),
federation_auth=True,
federation_auth_origin=b'farm',
)
self.on_new_event.assert_has_calls(
@@ -264,8 +262,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
call(
"farm",
path="/_matrix/federation/v1/send/1000000/",
data=_expect_edu(
"farm",
data=_expect_edu_transaction(
"m.typing",
content={
"room_id": self.room_id,

View File

@@ -1,4 +1,5 @@
# Copyright 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,89 +12,91 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
from mock import Mock, NonCallableMock
from twisted.internet import defer, reactor
from twisted.internet.defer import Deferred
import attr
from synapse.replication.tcp.client import (
ReplicationClientFactory,
ReplicationClientHandler,
)
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
from synapse.util.logcontext import PreserveLoggingContext, make_deferred_yieldable
from tests import unittest
from tests.utils import setup_test_homeserver
class TestReplicationClientHandler(ReplicationClientHandler):
"""Overrides on_rdata so that we can wait for it to happen"""
class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
def __init__(self, store):
super(TestReplicationClientHandler, self).__init__(store)
self._rdata_awaiters = []
def await_replication(self):
d = Deferred()
self._rdata_awaiters.append(d)
return make_deferred_yieldable(d)
def on_rdata(self, stream_name, token, rows):
awaiters = self._rdata_awaiters
self._rdata_awaiters = []
super(TestReplicationClientHandler, self).on_rdata(stream_name, token, rows)
with PreserveLoggingContext():
for a in awaiters:
a.callback(None)
class BaseSlavedStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
self.hs = yield setup_test_homeserver(
self.addCleanup,
hs = self.setup_test_homeserver(
"blue",
http_client=None,
federation_client=Mock(),
ratelimiter=NonCallableMock(spec_set=["send_message"]),
)
self.hs.get_ratelimiter().send_message.return_value = (True, 0)
hs.get_ratelimiter().send_message.return_value = (True, 0)
return hs
def prepare(self, reactor, clock, hs):
self.master_store = self.hs.get_datastore()
self.slaved_store = self.STORE_TYPE(self.hs.get_db_conn(), self.hs)
self.event_id = 0
server_factory = ReplicationStreamProtocolFactory(self.hs)
# XXX: mktemp is unsafe and should never be used, but this is just a test.
path = tempfile.mktemp(prefix="base_slaved_store_test_case_socket")
listener = reactor.listenUNIX(path, server_factory)
self.addCleanup(listener.stopListening)
self.streamer = server_factory.streamer
self.replication_handler = TestReplicationClientHandler(self.slaved_store)
self.replication_handler = ReplicationClientHandler(self.slaved_store)
client_factory = ReplicationClientFactory(
self.hs, "client_name", self.replication_handler
)
client_connector = reactor.connectUNIX(path, client_factory)
self.addCleanup(client_factory.stopTrying)
self.addCleanup(client_connector.disconnect)
server = server_factory.buildProtocol(None)
client = client_factory.buildProtocol(None)
@attr.s
class FakeTransport(object):
other = attr.ib()
disconnecting = False
buffer = attr.ib(default=b'')
def registerProducer(self, producer, streaming):
self.producer = producer
def _produce():
self.producer.resumeProducing()
reactor.callLater(0.1, _produce)
reactor.callLater(0.0, _produce)
def write(self, byt):
self.buffer = self.buffer + byt
if getattr(self.other, "transport") is not None:
self.other.dataReceived(self.buffer)
self.buffer = b""
def writeSequence(self, seq):
for x in seq:
self.write(x)
client.makeConnection(FakeTransport(server))
server.makeConnection(FakeTransport(client))
def replicate(self):
"""Tell the master side of replication that something has happened, and then
wait for the replication to occur.
"""
# xxx: should we be more specific in what we wait for?
d = self.replication_handler.await_replication()
self.streamer.on_notifier_poke()
return d
self.pump(0.1)
@defer.inlineCallbacks
def check(self, method, args, expected_result=None):
master_result = yield getattr(self.master_store, method)(*args)
slaved_result = yield getattr(self.slaved_store, method)(*args)
master_result = self.get_success(getattr(self.master_store, method)(*args))
slaved_result = self.get_success(getattr(self.slaved_store, method)(*args))
if expected_result is not None:
self.assertEqual(master_result, expected_result)
self.assertEqual(slaved_result, expected_result)

View File

@@ -12,9 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
from ._base import BaseSlavedStoreTestCase
@@ -27,16 +24,19 @@ class SlavedAccountDataStoreTestCase(BaseSlavedStoreTestCase):
STORE_TYPE = SlavedAccountDataStore
@defer.inlineCallbacks
def test_user_account_data(self):
yield self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 1})
yield self.replicate()
yield self.check(
self.get_success(
self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 1})
)
self.replicate()
self.check(
"get_global_account_data_by_type_for_user", [TYPE, USER_ID], {"a": 1}
)
yield self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 2})
yield self.replicate()
yield self.check(
self.get_success(
self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 2})
)
self.replicate()
self.check(
"get_global_account_data_by_type_for_user", [TYPE, USER_ID], {"a": 2}
)

View File

@@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.events import FrozenEvent, _EventInternalMetadata
from synapse.events.snapshot import EventContext
from synapse.replication.slave.storage.events import SlavedEventStore
@@ -55,70 +53,66 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
def tearDown(self):
[unpatch() for unpatch in self.unpatches]
@defer.inlineCallbacks
def test_get_latest_event_ids_in_room(self):
create = yield self.persist(type="m.room.create", key="", creator=USER_ID)
yield self.replicate()
yield self.check("get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id])
create = self.persist(type="m.room.create", key="", creator=USER_ID)
self.replicate()
self.check("get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id])
join = yield self.persist(
join = self.persist(
type="m.room.member",
key=USER_ID,
membership="join",
prev_events=[(create.event_id, {})],
)
yield self.replicate()
yield self.check("get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id])
self.replicate()
self.check("get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id])
@defer.inlineCallbacks
def test_redactions(self):
yield self.persist(type="m.room.create", key="", creator=USER_ID)
yield self.persist(type="m.room.member", key=USER_ID, membership="join")
self.persist(type="m.room.create", key="", creator=USER_ID)
self.persist(type="m.room.member", key=USER_ID, membership="join")
msg = yield self.persist(type="m.room.message", msgtype="m.text", body="Hello")
yield self.replicate()
yield self.check("get_event", [msg.event_id], msg)
msg = self.persist(type="m.room.message", msgtype="m.text", body="Hello")
self.replicate()
self.check("get_event", [msg.event_id], msg)
redaction = yield self.persist(type="m.room.redaction", redacts=msg.event_id)
yield self.replicate()
redaction = self.persist(type="m.room.redaction", redacts=msg.event_id)
self.replicate()
msg_dict = msg.get_dict()
msg_dict["content"] = {}
msg_dict["unsigned"]["redacted_by"] = redaction.event_id
msg_dict["unsigned"]["redacted_because"] = redaction
redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict())
yield self.check("get_event", [msg.event_id], redacted)
self.check("get_event", [msg.event_id], redacted)
@defer.inlineCallbacks
def test_backfilled_redactions(self):
yield self.persist(type="m.room.create", key="", creator=USER_ID)
yield self.persist(type="m.room.member", key=USER_ID, membership="join")
self.persist(type="m.room.create", key="", creator=USER_ID)
self.persist(type="m.room.member", key=USER_ID, membership="join")
msg = yield self.persist(type="m.room.message", msgtype="m.text", body="Hello")
yield self.replicate()
yield self.check("get_event", [msg.event_id], msg)
msg = self.persist(type="m.room.message", msgtype="m.text", body="Hello")
self.replicate()
self.check("get_event", [msg.event_id], msg)
redaction = yield self.persist(
redaction = self.persist(
type="m.room.redaction", redacts=msg.event_id, backfill=True
)
yield self.replicate()
self.replicate()
msg_dict = msg.get_dict()
msg_dict["content"] = {}
msg_dict["unsigned"]["redacted_by"] = redaction.event_id
msg_dict["unsigned"]["redacted_because"] = redaction
redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict())
yield self.check("get_event", [msg.event_id], redacted)
self.check("get_event", [msg.event_id], redacted)
@defer.inlineCallbacks
def test_invites(self):
yield self.persist(type="m.room.create", key="", creator=USER_ID)
yield self.check("get_invited_rooms_for_user", [USER_ID_2], [])
event = yield self.persist(
type="m.room.member", key=USER_ID_2, membership="invite"
)
yield self.replicate()
yield self.check(
self.persist(type="m.room.create", key="", creator=USER_ID)
self.check("get_invited_rooms_for_user", [USER_ID_2], [])
event = self.persist(type="m.room.member", key=USER_ID_2, membership="invite")
self.replicate()
self.check(
"get_invited_rooms_for_user",
[USER_ID_2],
[
@@ -132,37 +126,34 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
],
)
@defer.inlineCallbacks
def test_push_actions_for_user(self):
yield self.persist(type="m.room.create", key="", creator=USER_ID)
yield self.persist(type="m.room.join", key=USER_ID, membership="join")
yield self.persist(
self.persist(type="m.room.create", key="", creator=USER_ID)
self.persist(type="m.room.join", key=USER_ID, membership="join")
self.persist(
type="m.room.join", sender=USER_ID, key=USER_ID_2, membership="join"
)
event1 = yield self.persist(
type="m.room.message", msgtype="m.text", body="hello"
)
yield self.replicate()
yield self.check(
event1 = self.persist(type="m.room.message", msgtype="m.text", body="hello")
self.replicate()
self.check(
"get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2, event1.event_id],
{"highlight_count": 0, "notify_count": 0},
)
yield self.persist(
self.persist(
type="m.room.message",
msgtype="m.text",
body="world",
push_actions=[(USER_ID_2, ["notify"])],
)
yield self.replicate()
yield self.check(
self.replicate()
self.check(
"get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2, event1.event_id],
{"highlight_count": 0, "notify_count": 1},
)
yield self.persist(
self.persist(
type="m.room.message",
msgtype="m.text",
body="world",
@@ -170,8 +161,8 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
(USER_ID_2, ["notify", {"set_tweak": "highlight", "value": True}])
],
)
yield self.replicate()
yield self.check(
self.replicate()
self.check(
"get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2, event1.event_id],
{"highlight_count": 1, "notify_count": 2},
@@ -179,7 +170,6 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
event_id = 0
@defer.inlineCallbacks
def persist(
self,
sender=USER_ID,
@@ -206,8 +196,8 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
depth = self.event_id
if not prev_events:
latest_event_ids = yield self.master_store.get_latest_event_ids_in_room(
room_id
latest_event_ids = self.get_success(
self.master_store.get_latest_event_ids_in_room(room_id)
)
prev_events = [(ev_id, {}) for ev_id in latest_event_ids]
@@ -240,19 +230,23 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
)
else:
state_handler = self.hs.get_state_handler()
context = yield state_handler.compute_event_context(event)
context = self.get_success(state_handler.compute_event_context(event))
yield self.master_store.add_push_actions_to_staging(
self.master_store.add_push_actions_to_staging(
event.event_id, {user_id: actions for user_id, actions in push_actions}
)
ordering = None
if backfill:
yield self.master_store.persist_events([(event, context)], backfilled=True)
self.get_success(
self.master_store.persist_events([(event, context)], backfilled=True)
)
else:
ordering, _ = yield self.master_store.persist_event(event, context)
ordering, _ = self.get_success(
self.master_store.persist_event(event, context)
)
if ordering:
event.internal_metadata.stream_ordering = ordering
defer.returnValue(event)
return event

View File

@@ -12,8 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from twisted.internet import defer
-
 from synapse.replication.slave.storage.receipts import SlavedReceiptsStore

 from ._base import BaseSlavedStoreTestCase
@@ -27,13 +25,10 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase):
     STORE_TYPE = SlavedReceiptsStore

-    @defer.inlineCallbacks
     def test_receipt(self):
-        yield self.check("get_receipts_for_user", [USER_ID, "m.read"], {})
-        yield self.master_store.insert_receipt(
-            ROOM_ID, "m.read", USER_ID, [EVENT_ID], {}
-        )
-        yield self.replicate()
-        yield self.check(
-            "get_receipts_for_user", [USER_ID, "m.read"], {ROOM_ID: EVENT_ID}
-        )
+        self.check("get_receipts_for_user", [USER_ID, "m.read"], {})
+        self.get_success(
+            self.master_store.insert_receipt(ROOM_ID, "m.read", USER_ID, [EVENT_ID], {})
+        )
+        self.replicate()
+        self.check("get_receipts_for_user", [USER_ID, "m.read"], {ROOM_ID: EVENT_ID})
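For context, check() in these replication tests compares a storage method's result on the master store and on the replicated slave store. A hedged sketch of roughly what it does (the real helper lives on BaseSlavedStoreTestCase; the attribute names here are assumptions based on the calls visible in the diff):

    def check(self, method, args, expected_result):
        # Call the same storage method on the master and the slaved store and
        # require both to agree with the expected value.
        master = self.get_success(getattr(self.master_store, method)(*args))
        slaved = self.get_success(getattr(self.slaved_store, method)(*args))
        self.assertEqual(master, expected_result)
        self.assertEqual(slaved, expected_result)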


@@ -240,7 +240,6 @@ class RestHelper(object):
         self.assertEquals(200, code)
         defer.returnValue(response)

-    @defer.inlineCallbacks
     def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200):
         if txn_id is None:
             txn_id = "m%s" % (str(time.time()))
@@ -248,9 +247,16 @@ class RestHelper(object):
             body = "body_text_here"

         path = "/_matrix/client/r0/rooms/%s/send/m.room.message/%s" % (room_id, txn_id)
-        content = '{"msgtype":"m.text","body":"%s"}' % body
+        content = {"msgtype": "m.text", "body": body}
         if tok:
             path = path + "?access_token=%s" % tok

-        (code, response) = yield self.mock_resource.trigger("PUT", path, content)
-        self.assertEquals(expect_code, code, msg=str(response))
+        request, channel = make_request("PUT", path, json.dumps(content).encode('utf8'))
+        render(request, self.resource, self.hs.get_reactor())
+
+        assert int(channel.result["code"]) == expect_code, (
+            "Expected: %d, got: %d, resp: %r"
+            % (expect_code, int(channel.result["code"]), channel.result["body"])
+        )
+
+        return channel.json_body
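The RestHelper change swaps the old mock_resource.trigger() round-trip for the synchronous make_request()/render() helpers, and the body is now a dict that the caller must JSON-encode to bytes. A usage sketch built from the same calls (helper names as they appear in the diff, from Synapse's test utilities; the room id is a placeholder):

    import json

    path = "/_matrix/client/r0/rooms/%s/send/m.room.message/m1" % ("!room:test",)
    content = {"msgtype": "m.text", "body": "hello"}

    # Build a fake request, render it against the resource on the test
    # reactor, then assert on the recorded channel state.
    request, channel = make_request("PUT", path, json.dumps(content).encode('utf8'))
    render(request, self.resource, self.hs.get_reactor())

    assert int(channel.result["code"]) == 200, channel.result["body"]
    response = channel.json_body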


@@ -232,6 +232,7 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
         clock.threadpool = ThreadPool()
         pool.threadpool = ThreadPool()
+        pool.running = True

     return d
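This tests/utils.py hunk wires a threadpool onto the fake clock and connection pool so code that dispatches work to threads keeps functioning under test. As a hedged illustration of why a stand-in pool suffices here, a synchronous pool can satisfy Twisted's callInThreadWithCallback contract by running the work inline (illustrative only, not Synapse's actual class):

    from twisted.python.failure import Failure


    class ImmediateThreadPool(object):
        """Runs submitted work inline instead of on a real thread."""

        def callInThreadWithCallback(self, onResult, function, *args, **kwargs):
            try:
                result = function(*args, **kwargs)
            except Exception:
                onResult(False, Failure())
            else:
                onResult(True, result)

        def start(self):
            pass

        def stop(self):
            pass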


@@ -55,7 +55,8 @@ class TestResourceLimitsServerNotices(unittest.TestCase):
             returnValue=""
         )
         self._rlsn._store.add_tag_to_room = Mock()
-        self.hs.config.admin_uri = "mailto:user@test.com"
+        self._rlsn._store.get_tags_for_room = Mock(return_value={})
+        self.hs.config.admin_contact = "mailto:user@test.com"

     @defer.inlineCallbacks
     def test_maybe_send_server_notice_to_user_flag_off(self):
@@ -172,7 +173,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.TestCase):
         self.user_id = "@user_id:test"

-        self.hs.config.admin_uri = "mailto:user@test.com"
+        self.hs.config.admin_contact = "mailto:user@test.com"

     @defer.inlineCallbacks
     def test_server_notice_only_sent_once(self):
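Both server-notices test classes track a config rename here: the admin contact option is read as admin_contact rather than admin_uri, so the stubs are updated in kind:

    # old
    self.hs.config.admin_uri = "mailto:user@test.com"
    # new
    self.hs.config.admin_contact = "mailto:user@test.com"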


@@ -37,18 +37,14 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def setUp(self):
         self.as_yaml_files = []
-        config = Mock(
-            app_service_config_files=self.as_yaml_files,
-            event_cache_size=1,
-            password_providers=[],
-        )
         hs = yield setup_test_homeserver(
-            self.addCleanup,
-            config=config,
-            federation_sender=Mock(),
-            federation_client=Mock(),
+            self.addCleanup, federation_sender=Mock(), federation_client=Mock()
         )

+        hs.config.app_service_config_files = self.as_yaml_files
+        hs.config.event_cache_size = 1
+        hs.config.password_providers = []
+
         self.as_token = "token1"
         self.as_url = "some_url"
         self.as_id = "as1"
@@ -58,7 +54,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
         self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob")
         self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob")
         # must be done after inserts
-        self.store = ApplicationServiceStore(None, hs)
+        self.store = ApplicationServiceStore(hs.get_db_conn(), hs)

     def tearDown(self):
         # TODO: suboptimal that we need to create files for tests!
@@ -105,18 +101,16 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
     def setUp(self):
         self.as_yaml_files = []

-        config = Mock(
-            app_service_config_files=self.as_yaml_files,
-            event_cache_size=1,
-            password_providers=[],
-        )
         hs = yield setup_test_homeserver(
-            self.addCleanup,
-            config=config,
-            federation_sender=Mock(),
-            federation_client=Mock(),
+            self.addCleanup, federation_sender=Mock(), federation_client=Mock()
         )

+        hs.config.app_service_config_files = self.as_yaml_files
+        hs.config.event_cache_size = 1
+        hs.config.password_providers = []
+
         self.db_pool = hs.get_db_pool()
+        self.engine = hs.database_engine

         self.as_list = [
             {"token": "token1", "url": "https://matrix-as.org", "id": "id_1"},
@@ -129,7 +123,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         self.as_yaml_files = []

-        self.store = TestTransactionStore(None, hs)
+        self.store = TestTransactionStore(hs.get_db_conn(), hs)

     def _add_service(self, url, as_token, id):
         as_yaml = dict(
@@ -146,29 +140,35 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         self.as_yaml_files.append(as_token)

     def _set_state(self, id, state, txn=None):
-        return self.db_pool.runQuery(
-            "INSERT INTO application_services_state(as_id, state, last_txn) "
-            "VALUES(?,?,?)",
+        return self.db_pool.runOperation(
+            self.engine.convert_param_style(
+                "INSERT INTO application_services_state(as_id, state, last_txn) "
+                "VALUES(?,?,?)"
+            ),
             (id, state, txn),
         )

     def _insert_txn(self, as_id, txn_id, events):
-        return self.db_pool.runQuery(
-            "INSERT INTO application_services_txns(as_id, txn_id, event_ids) "
-            "VALUES(?,?,?)",
+        return self.db_pool.runOperation(
+            self.engine.convert_param_style(
+                "INSERT INTO application_services_txns(as_id, txn_id, event_ids) "
+                "VALUES(?,?,?)"
+            ),
             (as_id, txn_id, json.dumps([e.event_id for e in events])),
         )

     def _set_last_txn(self, as_id, txn_id):
-        return self.db_pool.runQuery(
-            "INSERT INTO application_services_state(as_id, last_txn, state) "
-            "VALUES(?,?,?)",
+        return self.db_pool.runOperation(
+            self.engine.convert_param_style(
+                "INSERT INTO application_services_state(as_id, last_txn, state) "
+                "VALUES(?,?,?)"
+            ),
             (as_id, txn_id, ApplicationServiceState.UP),
         )

     @defer.inlineCallbacks
     def test_get_appservice_state_none(self):
-        service = Mock(id=999)
+        service = Mock(id="999")
         state = yield self.store.get_appservice_state(service)
         self.assertEquals(None, state)
@@ -200,7 +200,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         service = Mock(id=self.as_list[1]["id"])
         yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
         rows = yield self.db_pool.runQuery(
-            "SELECT as_id FROM application_services_state WHERE state=?",
+            self.engine.convert_param_style(
+                "SELECT as_id FROM application_services_state WHERE state=?"
+            ),
             (ApplicationServiceState.DOWN,),
         )
         self.assertEquals(service.id, rows[0][0])
@@ -212,7 +214,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
         yield self.store.set_appservice_state(service, ApplicationServiceState.UP)
         rows = yield self.db_pool.runQuery(
-            "SELECT as_id FROM application_services_state WHERE state=?",
+            self.engine.convert_param_style(
+                "SELECT as_id FROM application_services_state WHERE state=?"
+            ),
             (ApplicationServiceState.UP,),
         )
         self.assertEquals(service.id, rows[0][0])
@@ -279,14 +283,19 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         yield self.store.complete_appservice_txn(txn_id=txn_id, service=service)

         res = yield self.db_pool.runQuery(
-            "SELECT last_txn FROM application_services_state WHERE as_id=?",
+            self.engine.convert_param_style(
+                "SELECT last_txn FROM application_services_state WHERE as_id=?"
+            ),
             (service.id,),
         )
         self.assertEquals(1, len(res))
         self.assertEquals(txn_id, res[0][0])

         res = yield self.db_pool.runQuery(
-            "SELECT * FROM application_services_txns WHERE txn_id=?", (txn_id,)
+            self.engine.convert_param_style(
+                "SELECT * FROM application_services_txns WHERE txn_id=?"
+            ),
+            (txn_id,),
         )
         self.assertEquals(0, len(res))
@@ -300,7 +309,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         yield self.store.complete_appservice_txn(txn_id=txn_id, service=service)

         res = yield self.db_pool.runQuery(
-            "SELECT last_txn, state FROM application_services_state WHERE " "as_id=?",
+            self.engine.convert_param_style(
+                "SELECT last_txn, state FROM application_services_state WHERE as_id=?"
+            ),
             (service.id,),
         )
         self.assertEquals(1, len(res))
@@ -308,7 +319,10 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
         self.assertEquals(ApplicationServiceState.UP, res[0][1])

         res = yield self.db_pool.runQuery(
-            "SELECT * FROM application_services_txns WHERE txn_id=?", (txn_id,)
+            self.engine.convert_param_style(
+                "SELECT * FROM application_services_txns WHERE txn_id=?"
+            ),
+            (txn_id,),
         )
         self.assertEquals(0, len(res))
@@ -394,37 +408,31 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         f1 = self._write_config(suffix="1")
         f2 = self._write_config(suffix="2")

-        config = Mock(
-            app_service_config_files=[f1, f2], event_cache_size=1, password_providers=[]
-        )
         hs = yield setup_test_homeserver(
-            self.addCleanup,
-            config=config,
-            datastore=Mock(),
-            federation_sender=Mock(),
-            federation_client=Mock(),
+            self.addCleanup, federation_sender=Mock(), federation_client=Mock()
         )

-        ApplicationServiceStore(None, hs)
+        hs.config.app_service_config_files = [f1, f2]
+        hs.config.event_cache_size = 1
+        hs.config.password_providers = []
+
+        ApplicationServiceStore(hs.get_db_conn(), hs)

     @defer.inlineCallbacks
     def test_duplicate_ids(self):
         f1 = self._write_config(id="id", suffix="1")
         f2 = self._write_config(id="id", suffix="2")

-        config = Mock(
-            app_service_config_files=[f1, f2], event_cache_size=1, password_providers=[]
-        )
         hs = yield setup_test_homeserver(
-            self.addCleanup,
-            config=config,
-            datastore=Mock(),
-            federation_sender=Mock(),
-            federation_client=Mock(),
+            self.addCleanup, federation_sender=Mock(), federation_client=Mock()
         )

+        hs.config.app_service_config_files = [f1, f2]
+        hs.config.event_cache_size = 1
+        hs.config.password_providers = []
+
         with self.assertRaises(ConfigError) as cm:
-            ApplicationServiceStore(None, hs)
+            ApplicationServiceStore(hs.get_db_conn(), hs)

         e = cm.exception
         self.assertIn(f1, str(e))
@@ -436,19 +444,16 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
         f1 = self._write_config(as_token="as_token", suffix="1")
         f2 = self._write_config(as_token="as_token", suffix="2")

-        config = Mock(
-            app_service_config_files=[f1, f2], event_cache_size=1, password_providers=[]
-        )
         hs = yield setup_test_homeserver(
-            self.addCleanup,
-            config=config,
-            datastore=Mock(),
-            federation_sender=Mock(),
-            federation_client=Mock(),
+            self.addCleanup, federation_sender=Mock(), federation_client=Mock()
         )

+        hs.config.app_service_config_files = [f1, f2]
+        hs.config.event_cache_size = 1
+        hs.config.password_providers = []
+
         with self.assertRaises(ConfigError) as cm:
-            ApplicationServiceStore(None, hs)
+            ApplicationServiceStore(hs.get_db_conn(), hs)

         e = cm.exception
         self.assertIn(f1, str(e))
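Two patterns recur throughout the appservice-store hunks: INSERTs move from runQuery() (which expects a result set back) to runOperation() (for statements that return no rows), and every raw SQL string is passed through the engine's convert_param_style() so the same test runs on both SQLite and Postgres. A hedged sketch of what that conversion amounts to (SQLite's DB-API driver accepts qmark '?' placeholders natively, while psycopg2 expects '%s'; the one-line replace below mirrors what such a paramstyle shim typically does, not necessarily Synapse's exact code):

    def convert_param_style(sql, postgres=False):
        # qmark placeholders pass through unchanged for sqlite3; rewrite them
        # to %s placeholders for psycopg2.
        return sql.replace("?", "%s") if postgres else sql


    assert (
        convert_param_style("SELECT as_id FROM foo WHERE state=?", postgres=True)
        == "SELECT as_id FROM foo WHERE state=%s"
    )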


@@ -20,11 +20,11 @@ from mock import Mock

 from twisted.internet import defer

-from synapse.server import HomeServer
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.engines import create_engine

 from tests import unittest
+from tests.utils import TestHomeServer


 class SQLBaseStoreTestCase(unittest.TestCase):
@@ -51,7 +51,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
         config = Mock()
         config.event_cache_size = 1
         config.database_config = {"name": "sqlite3"}
-        hs = HomeServer(
+        hs = TestHomeServer(
             "test",
             db_pool=self.db_pool,
             config=config,


@@ -101,13 +101,11 @@ class ClientIpStoreTestCase(tests.unittest.TestCase):
         self.hs.config.limit_usage_by_mau = True
         self.hs.config.max_mau_value = 50
         user_id = "@user:server"
-        yield self.store.register(user_id=user_id, token="123", password_hash=None)
-
         active = yield self.store.user_last_seen_monthly_active(user_id)
         self.assertFalse(active)

         yield self.store.insert_client_ip(
             user_id, "access_token", "ip", "user_agent", "device_id"
         )
         yield self.store.insert_client_ip(
             user_id, "access_token", "ip", "user_agent", "device_id"
         )


@@ -16,7 +16,6 @@
 from twisted.internet import defer

-from synapse.storage.directory import DirectoryStore
 from synapse.types import RoomAlias, RoomID

 from tests import unittest
@@ -28,7 +27,7 @@ class DirectoryStoreTestCase(unittest.TestCase):
     def setUp(self):
         hs = yield setup_test_homeserver(self.addCleanup)

-        self.store = DirectoryStore(None, hs)
+        self.store = hs.get_datastore()

         self.room = RoomID.from_string("!abcde:test")
         self.alias = RoomAlias.from_string("#my-room:test")
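The directory-store test shows the same wiring cleanup as the appservice tests: rather than instantiating a store class with a None database connection, the test takes the fully initialised datastore off the homeserver. Sketched usage (setup_test_homeserver as in the diff above):

    hs = yield setup_test_homeserver(self.addCleanup)
    store = hs.get_datastore()  # replaces DirectoryStore(None, hs)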


@@ -37,10 +37,10 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase):
             (
                 "INSERT INTO events ("
                 " room_id, event_id, type, depth, topological_ordering,"
-                " content, processed, outlier) "
-                "VALUES (?, ?, 'm.test', ?, ?, 'test', ?, ?)"
+                " content, processed, outlier, stream_ordering) "
+                "VALUES (?, ?, 'm.test', ?, ?, 'test', ?, ?, ?)"
             ),
-            (room_id, event_id, i, i, True, False),
+            (room_id, event_id, i, i, True, False, i),
         )

         txn.execute(

Some files were not shown because too many files have changed in this diff.