1
0

Compare commits

..

31 Commits
hhs-3 ... hhs-5

Author SHA1 Message Date
Neil Johnson
a6cf7d9d9a Merge pull request #3789 from matrix-org/neilj/improve_threepid_error_strings
improve human readable error messages
2018-09-04 13:16:00 +00:00
Neil Johnson
70fc599ede towncrier 2018-09-04 12:08:27 +01:00
Neil Johnson
bae37cd811 improve human readable error message 2018-09-04 12:07:00 +01:00
Neil Johnson
c42f7fd7b9 improve human readable error messages 2018-09-04 12:03:17 +01:00
Amber Brown
77055dba92 Fix tests on postgresql (#3740) 2018-09-04 02:21:48 +10:00
Erik Johnston
567363e497 Merge pull request #3737 from matrix-org/erikj/remove_redundant_state_func
Remove unnecessary resolve_events_with_state_map
2018-09-03 16:19:41 +01:00
Erik Johnston
06ee2b7cc5 Newsfile 2018-09-03 15:43:01 +01:00
Amber Brown
30cf40ff30 Merge pull request #3378 from NickEckardt/develop
matrix-synapse-auto-deploy is no longer maintained.
2018-09-03 22:47:23 +10:00
Amber Brown
905f8de673 Create 3378.misc 2018-09-03 22:17:06 +10:00
Amber Brown
4fc4b881c5 Merge branch 'develop' into develop 2018-09-03 21:08:35 +10:00
Neil Johnson
99178f8602 Merge pull request #3777 from matrix-org/neilj/fix_register_user_registration
fix bug where preserved threepid user comes to sign up and server is …
2018-08-31 16:31:48 +00:00
Neil Johnson
301cb60d0b assert rather than warn 2018-08-31 17:29:35 +01:00
Neil Johnson
0b01281e77 move threepid checker to config, add missing yields 2018-08-31 17:11:11 +01:00
Neil Johnson
e8e540630e fix reference to is_threepid_reserved 2018-08-31 16:09:15 +01:00
Neil Johnson
a796bdd35e typo 2018-08-31 15:57:33 +01:00
Neil Johnson
09f3cf1a7e ensure post registration auth checks do not fail erroneously 2018-08-31 15:42:51 +01:00
Neil Johnson
3d6aa06577 news fragemnt 2018-08-31 10:56:37 +01:00
Neil Johnson
ea068d6f3c fix bug where preserved threepid user comes to sign up and server is mau blocked 2018-08-31 10:49:14 +01:00
Amber Brown
14e4d4f4bf Port storage/ to Python 3 (#3725) 2018-08-31 00:19:58 +10:00
Richard van der Hoff
475253a88e Merge pull request #3764 from matrix-org/rav/close_db_conn_after_init
Make sure that we close db connections opened during init
2018-08-30 10:47:27 +01:00
Erik Johnston
7f0399586d Merge pull request #3768 from krombel/fix_3445
fix #3445 - do not use itervalues() on SortedDict()
2018-08-29 16:29:57 +01:00
Krombel
8c0c51ecb3 changelog 2018-08-29 16:49:26 +02:00
Krombel
79a8a347a6 fix #3445
itervalues(d) calls d.itervalues() [PY2] and d.values() [PY3]
but SortedDict only implements d.values()
2018-08-29 16:28:25 +02:00
Amber Brown
82276a18d1 Update CONTRIBUTING to clarify miscs & Markdown (#3730) 2018-08-29 19:06:59 +10:00
Matthew Hodgson
b1580f50fe don't return non-LL-member state in incremental sync state blocks (#3760)
don't return non-LL-member state in incremental sync state blocks
2018-08-28 23:25:58 +01:00
Richard van der Hoff
414fa36f3e Fix up tests 2018-08-28 17:21:05 +01:00
Richard van der Hoff
32eb1dedd2 use abc.abstractproperty
This gives clearer messages when someone gets it wrong
2018-08-28 17:10:43 +01:00
Richard van der Hoff
71990b3cae changelog 2018-08-28 13:42:20 +01:00
Richard van der Hoff
0b07f02e19 Make sure that we close db connections opened during init
We should explicitly close any db connections we open, because failing to do so
can block other transactions as per
https://github.com/matrix-org/synapse/issues/3682.

Let's also try to factor out some of the boilerplate by having server classes
define their datastore class rather than duplicating the whole of `setup`.
2018-08-28 13:39:49 +01:00
Erik Johnston
5c261107c9 Remove unnecessary resolve_events_with_state_map
We only ever used the synchronous resolve_events_with_state_map in one
place, which is trivial to replace with the async version.
2018-08-22 15:41:15 +01:00
Nicholas Eckardt
76c80e3fdf The project matrix-synapse-auto-deploy does not seem to be maintained anymore.
It has been over a year since any code has been commited. Some of the relevant links in
the documentation are broken, but since no pull requests are being accepted, they
won't get fixed. We should probably remove this from the README.
2018-06-10 18:24:12 -07:00
70 changed files with 727 additions and 507 deletions

View File

@@ -35,10 +35,6 @@ matrix:
- python: 3.6
env: TOX_ENV=check-newsfragment
allow_failures:
- python: 2.7
env: TOX_ENV=py27-postgres TRIAL_FLAGS="-j 4"
install:
- pip install tox

View File

@@ -59,9 +59,10 @@ To create a changelog entry, make a new file in the ``changelog.d``
file named in the format of ``PRnumber.type``. The type can be
one of ``feature``, ``bugfix``, ``removal`` (also used for
deprecations), or ``misc`` (for internal-only changes). The content of
the file is your changelog entry, which can contain RestructuredText
formatting. A note of contributors is welcomed in changelogs for
non-misc changes (the content of misc changes is not displayed).
the file is your changelog entry, which can contain Markdown
formatting. Adding credits to the changelog is encouraged, we value
your contributions and would like to have you shouted out in the
release notes!
For example, a fix in PR #1234 would have its changelog entry in
``changelog.d/1234.bugfix``, and contain content like "The security levels of

View File

@@ -167,11 +167,6 @@ Alternatively, Andreas Peters (previously Silvio Fricke) has contributed a
Dockerfile to automate a synapse server in a single Docker image, at
https://hub.docker.com/r/avhost/docker-matrix/tags/
Also, Martin Giess has created an auto-deployment process with vagrant/ansible,
tested with VirtualBox/AWS/DigitalOcean - see
https://github.com/EMnify/matrix-synapse-auto-deploy
for details.
Configuring synapse
-------------------

1
changelog.d/3378.misc Normal file
View File

@@ -0,0 +1 @@
Removed the link to the unmaintained matrix-synapse-auto-deploy project from the readme.

1
changelog.d/3725.misc Normal file
View File

@@ -0,0 +1 @@
The synapse.storage module has been ported to Python 3.

1
changelog.d/3730.misc Normal file
View File

@@ -0,0 +1 @@
The CONTRIBUTING guidelines have been updated to mention our use of Markdown and that .misc files have content.

1
changelog.d/3737.misc Normal file
View File

@@ -0,0 +1 @@
Remove redundant state resolution function

1
changelog.d/3740.misc Normal file
View File

@@ -0,0 +1 @@
The test suite now passes on PostgreSQL.

1
changelog.d/3760.bugfix Normal file
View File

@@ -0,0 +1 @@
Don't return non-LL-member state in incremental sync state blocks

1
changelog.d/3764.misc Normal file
View File

@@ -0,0 +1 @@
Make sure that we close db connections opened during init

1
changelog.d/3768.bugfix Normal file
View File

@@ -0,0 +1 @@
Fix bug in sending presence over federation

1
changelog.d/3777.bugfix Normal file
View File

@@ -0,0 +1 @@
Fix bug where preserved threepid user comes to sign up and server is mau blocked

1
changelog.d/3789.misc Normal file
View File

@@ -0,0 +1 @@
Improve human readable error messages for threepid registration/account update

View File

@@ -31,5 +31,5 @@ $TOX_BIN/pip install 'setuptools>=18.5'
$TOX_BIN/pip install 'pip>=10'
{ python synapse/python_dependencies.py
echo lxml psycopg2
echo lxml
} | xargs $TOX_BIN/pip install

View File

@@ -26,6 +26,7 @@ import synapse.types
from synapse import event_auth
from synapse.api.constants import EventTypes, JoinRules, Membership
from synapse.api.errors import AuthError, Codes, ResourceLimitError
from synapse.config.server import is_threepid_reserved
from synapse.types import UserID
from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache
from synapse.util.caches.lrucache import LruCache
@@ -775,13 +776,19 @@ class Auth(object):
)
@defer.inlineCallbacks
def check_auth_blocking(self, user_id=None):
def check_auth_blocking(self, user_id=None, threepid=None):
"""Checks if the user should be rejected for some external reason,
such as monthly active user limiting or global disable flag
Args:
user_id(str|None): If present, checks for presence against existing
MAU cohort
threepid(dict|None): If present, checks for presence against configured
reserved threepid. Used in cases where the user is trying register
with a MAU blocked server, normally they would be rejected but their
threepid is on the reserved list. user_id and
threepid should never be set at the same time.
"""
# Never fail an auth check for the server notices users
@@ -797,6 +804,8 @@ class Auth(object):
limit_type=self.hs.config.hs_disabled_limit_type
)
if self.hs.config.limit_usage_by_mau is True:
assert not (user_id and threepid)
# If the user is already part of the MAU cohort or a trial user
if user_id:
timestamp = yield self.store.user_last_seen_monthly_active(user_id)
@@ -806,12 +815,16 @@ class Auth(object):
is_trial = yield self.store.is_trial_user(user_id)
if is_trial:
return
elif threepid:
# If the user does not exist yet, but is signing up with a
# reserved threepid then pass auth check
if is_threepid_reserved(self.hs.config, threepid):
return
# Else if there is no room in the MAU bucket, bail
current_mau = yield self.store.get_monthly_active_count()
if current_mau >= self.hs.config.max_mau_value:
raise ResourceLimitError(
403, "Monthly Active User Limit Exceeded",
admin_contact=self.hs.config.admin_contact,
errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
limit_type="monthly_active_user"

View File

@@ -51,10 +51,7 @@ class AppserviceSlaveStore(
class AppserviceServer(HomeServer):
def setup(self):
logger.info("Setting up.")
self.datastore = AppserviceSlaveStore(self.get_db_conn(), self)
logger.info("Finished setting up.")
DATASTORE_CLASS = AppserviceSlaveStore
def _listen_http(self, listener_config):
port = listener_config["port"]

View File

@@ -74,10 +74,7 @@ class ClientReaderSlavedStore(
class ClientReaderServer(HomeServer):
def setup(self):
logger.info("Setting up.")
self.datastore = ClientReaderSlavedStore(self.get_db_conn(), self)
logger.info("Finished setting up.")
DATASTORE_CLASS = ClientReaderSlavedStore
def _listen_http(self, listener_config):
port = listener_config["port"]

View File

@@ -90,10 +90,7 @@ class EventCreatorSlavedStore(
class EventCreatorServer(HomeServer):
def setup(self):
logger.info("Setting up.")
self.datastore = EventCreatorSlavedStore(self.get_db_conn(), self)
logger.info("Finished setting up.")
DATASTORE_CLASS = EventCreatorSlavedStore
def _listen_http(self, listener_config):
port = listener_config["port"]

View File

@@ -72,10 +72,7 @@ class FederationReaderSlavedStore(
class FederationReaderServer(HomeServer):
def setup(self):
logger.info("Setting up.")
self.datastore = FederationReaderSlavedStore(self.get_db_conn(), self)
logger.info("Finished setting up.")
DATASTORE_CLASS = FederationReaderSlavedStore
def _listen_http(self, listener_config):
port = listener_config["port"]

View File

@@ -78,10 +78,7 @@ class FederationSenderSlaveStore(
class FederationSenderServer(HomeServer):
def setup(self):
logger.info("Setting up.")
self.datastore = FederationSenderSlaveStore(self.get_db_conn(), self)
logger.info("Finished setting up.")
DATASTORE_CLASS = FederationSenderSlaveStore
def _listen_http(self, listener_config):
port = listener_config["port"]

View File

@@ -148,10 +148,7 @@ class FrontendProxySlavedStore(
class FrontendProxyServer(HomeServer):
def setup(self):
logger.info("Setting up.")
self.datastore = FrontendProxySlavedStore(self.get_db_conn(), self)
logger.info("Finished setting up.")
DATASTORE_CLASS = FrontendProxySlavedStore
def _listen_http(self, listener_config):
port = listener_config["port"]

View File

@@ -62,7 +62,7 @@ from synapse.rest.key.v1.server_key_resource import LocalKey
from synapse.rest.key.v2 import KeyApiV2Resource
from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.server import HomeServer
from synapse.storage import are_all_users_on_domain
from synapse.storage import DataStore, are_all_users_on_domain
from synapse.storage.engines import IncorrectDatabaseSetup, create_engine
from synapse.storage.prepare_database import UpgradeDatabaseException, prepare_database
from synapse.util.caches import CACHE_SIZE_FACTOR
@@ -111,6 +111,8 @@ def build_resource_for_web_client(hs):
class SynapseHomeServer(HomeServer):
DATASTORE_CLASS = DataStore
def _listener_http(self, config, listener_config):
port = listener_config["port"]
bind_addresses = listener_config["bind_addresses"]
@@ -356,13 +358,13 @@ def setup(config_options):
logger.info("Preparing database: %s...", config.database_config['name'])
try:
db_conn = hs.get_db_conn(run_new_connection=False)
prepare_database(db_conn, database_engine, config=config)
database_engine.on_new_connection(db_conn)
with hs.get_db_conn(run_new_connection=False) as db_conn:
prepare_database(db_conn, database_engine, config=config)
database_engine.on_new_connection(db_conn)
hs.run_startup_checks(db_conn, database_engine)
hs.run_startup_checks(db_conn, database_engine)
db_conn.commit()
db_conn.commit()
except UpgradeDatabaseException:
sys.stderr.write(
"\nFailed to upgrade database.\n"

View File

@@ -60,10 +60,7 @@ class MediaRepositorySlavedStore(
class MediaRepositoryServer(HomeServer):
def setup(self):
logger.info("Setting up.")
self.datastore = MediaRepositorySlavedStore(self.get_db_conn(), self)
logger.info("Finished setting up.")
DATASTORE_CLASS = MediaRepositorySlavedStore
def _listen_http(self, listener_config):
port = listener_config["port"]

View File

@@ -78,10 +78,7 @@ class PusherSlaveStore(
class PusherServer(HomeServer):
def setup(self):
logger.info("Setting up.")
self.datastore = PusherSlaveStore(self.get_db_conn(), self)
logger.info("Finished setting up.")
DATASTORE_CLASS = PusherSlaveStore
def remove_pusher(self, app_id, push_key, user_id):
self.get_tcp_replication().send_remove_pusher(app_id, push_key, user_id)

View File

@@ -249,10 +249,7 @@ class SynchrotronApplicationService(object):
class SynchrotronServer(HomeServer):
def setup(self):
logger.info("Setting up.")
self.datastore = SynchrotronSlavedStore(self.get_db_conn(), self)
logger.info("Finished setting up.")
DATASTORE_CLASS = SynchrotronSlavedStore
def _listen_http(self, listener_config):
port = listener_config["port"]

View File

@@ -94,10 +94,7 @@ class UserDirectorySlaveStore(
class UserDirectoryServer(HomeServer):
def setup(self):
logger.info("Setting up.")
self.datastore = UserDirectorySlaveStore(self.get_db_conn(), self)
logger.info("Finished setting up.")
DATASTORE_CLASS = UserDirectorySlaveStore
def _listen_http(self, listener_config):
port = listener_config["port"]

View File

@@ -404,6 +404,23 @@ class ServerConfig(Config):
" service on the given port.")
def is_threepid_reserved(config, threepid):
"""Check the threepid against the reserved threepid config
Args:
config(ServerConfig) - to access server config attributes
threepid(dict) - The threepid to test for
Returns:
boolean Is the threepid undertest reserved_user
"""
for tp in config.mau_limits_reserved_threepids:
if (threepid['medium'] == tp['medium']
and threepid['address'] == tp['address']):
return True
return False
def read_gc_thresholds(thresholds):
"""Reads the three integer thresholds for garbage collection. Ensures that
the thresholds are integers if thresholds are supplied.

View File

@@ -32,7 +32,7 @@ Events are replicated via a separate events stream.
import logging
from collections import namedtuple
from six import iteritems, itervalues
from six import iteritems
from sortedcontainers import SortedDict
@@ -117,7 +117,7 @@ class FederationRemoteSendQueue(object):
user_ids = set(
user_id
for uids in itervalues(self.presence_changed)
for uids in self.presence_changed.values()
for user_id in uids
)

View File

@@ -1831,7 +1831,7 @@ class FederationHandler(BaseHandler):
room_version = yield self.store.get_room_version(event.room_id)
new_state = self.state_handler.resolve_events(
new_state = yield self.state_handler.resolve_events(
room_version,
[list(local_view.values()), list(remote_view.values())],
event

View File

@@ -125,6 +125,7 @@ class RegistrationHandler(BaseHandler):
guest_access_token=None,
make_guest=False,
admin=False,
threepid=None,
):
"""Registers a new client on the server.
@@ -145,7 +146,7 @@ class RegistrationHandler(BaseHandler):
RegistrationError if there was a problem registering.
"""
yield self.auth.check_auth_blocking()
yield self.auth.check_auth_blocking(threepid=threepid)
password_hash = None
if password:
password_hash = yield self.auth_handler().hash(password)

View File

@@ -745,9 +745,16 @@ class SyncHandler(object):
state_ids = {}
if lazy_load_members:
if types:
# We're returning an incremental sync, with no "gap" since
# the previous sync, so normally there would be no state to return
# But we're lazy-loading, so the client might need some more
# member events to understand the events in this timeline.
# So we fish out all the member events corresponding to the
# timeline here, and then dedupe any redundant ones below.
state_ids = yield self.store.get_state_ids_for_event(
batch.events[0].event_id, types=types,
filtered_types=filtered_types,
filtered_types=None, # we only want members!
)
if lazy_load_members and not include_redundant_members:

View File

@@ -78,6 +78,9 @@ CONDITIONAL_REQUIREMENTS = {
"affinity": {
"affinity": ["affinity"],
},
"postgres": {
"psycopg2>=2.6": ["psycopg2"]
}
}

View File

@@ -23,6 +23,7 @@ from twisted.internet import defer
import synapse.util.stringutils as stringutils
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, SynapseError
from synapse.config.server import is_threepid_reserved
from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request
from synapse.rest.client.v1.base import ClientV1RestServlet
from synapse.types import create_requester
@@ -281,12 +282,20 @@ class RegisterRestServlet(ClientV1RestServlet):
register_json["user"].encode("utf-8")
if "user" in register_json else None
)
threepid = None
if session.get(LoginType.EMAIL_IDENTITY):
threepid = session["threepidCreds"]
handler = self.handlers.registration_handler
(user_id, token) = yield handler.register(
localpart=desired_user_id,
password=password
password=password,
threepid=threepid,
)
# Necessary due to auth checks prior to the threepid being
# written to the db
if is_threepid_reserved(self.hs.config, threepid):
yield self.store.upsert_monthly_active_user(user_id)
if session[LoginType.EMAIL_IDENTITY]:
logger.debug("Binding emails %s to %s" % (

View File

@@ -53,7 +53,9 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
if not check_3pid_allowed(self.hs, "email", body['email']):
raise SynapseError(
403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
403,
"Your email domain is not authorized on this server",
Codes.THREEPID_DENIED,
)
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
@@ -89,7 +91,9 @@ class MsisdnPasswordRequestTokenRestServlet(RestServlet):
if not check_3pid_allowed(self.hs, "msisdn", msisdn):
raise SynapseError(
403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
403,
"Account phone numbers are not authorized on this server",
Codes.THREEPID_DENIED,
)
existingUid = yield self.datastore.get_user_id_by_threepid(
@@ -241,7 +245,9 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
if not check_3pid_allowed(self.hs, "email", body['email']):
raise SynapseError(
403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
403,
"Your email domain is not authorized on this server",
Codes.THREEPID_DENIED,
)
existingUid = yield self.datastore.get_user_id_by_threepid(
@@ -276,7 +282,9 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
if not check_3pid_allowed(self.hs, "msisdn", msisdn):
raise SynapseError(
403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
403,
"Account phone numbers are not authorized on this server",
Codes.THREEPID_DENIED,
)
existingUid = yield self.datastore.get_user_id_by_threepid(

View File

@@ -26,6 +26,7 @@ import synapse
import synapse.types
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, SynapseError, UnrecognizedRequestError
from synapse.config.server import is_threepid_reserved
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
@@ -74,7 +75,9 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
if not check_3pid_allowed(self.hs, "email", body['email']):
raise SynapseError(
403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
403,
"Your email domain is not authorized to register on this server",
Codes.THREEPID_DENIED,
)
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
@@ -114,7 +117,9 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
if not check_3pid_allowed(self.hs, "msisdn", msisdn):
raise SynapseError(
403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
403,
"Phone numbers are not authorized to register on this server",
Codes.THREEPID_DENIED,
)
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
@@ -372,7 +377,9 @@ class RegisterRestServlet(RestServlet):
if not check_3pid_allowed(self.hs, medium, address):
raise SynapseError(
403, "Third party identifier is not allowed",
403,
"Third party identifiers (email/phone numbers)" +
" are not authorized on this server",
Codes.THREEPID_DENIED,
)
@@ -395,12 +402,21 @@ class RegisterRestServlet(RestServlet):
if desired_username is not None:
desired_username = desired_username.lower()
threepid = None
if auth_result:
threepid = auth_result.get(LoginType.EMAIL_IDENTITY)
(registered_user_id, _) = yield self.registration_handler.register(
localpart=desired_username,
password=new_password,
guest_access_token=guest_access_token,
generate_token=False,
threepid=threepid,
)
# Necessary due to auth checks prior to the threepid being
# written to the db
if is_threepid_reserved(self.hs.config, threepid):
yield self.store.upsert_monthly_active_user(registered_user_id)
# remember that we've now registered that user account, and with
# what user ID (since the user may not have specified)

View File

@@ -19,6 +19,7 @@
# partial one for unit test mocking.
# Imports required for the default HomeServer() implementation
import abc
import logging
from twisted.enterprise import adbapi
@@ -81,7 +82,6 @@ from synapse.server_notices.server_notices_manager import ServerNoticesManager
from synapse.server_notices.server_notices_sender import ServerNoticesSender
from synapse.server_notices.worker_server_notices_sender import WorkerServerNoticesSender
from synapse.state import StateHandler, StateResolutionHandler
from synapse.storage import DataStore
from synapse.streams.events import EventSources
from synapse.util import Clock
from synapse.util.distributor import Distributor
@@ -111,6 +111,8 @@ class HomeServer(object):
config (synapse.config.homeserver.HomeserverConfig):
"""
__metaclass__ = abc.ABCMeta
DEPENDENCIES = [
'http_client',
'db_pool',
@@ -172,6 +174,11 @@ class HomeServer(object):
'room_context_handler',
]
# This is overridden in derived application classes
# (such as synapse.app.homeserver.SynapseHomeServer) and gives the class to be
# instantiated during setup() for future return by get_datastore()
DATASTORE_CLASS = abc.abstractproperty()
def __init__(self, hostname, reactor=None, **kwargs):
"""
Args:
@@ -188,13 +195,16 @@ class HomeServer(object):
self.distributor = Distributor()
self.ratelimiter = Ratelimiter()
self.datastore = None
# Other kwargs are explicit dependencies
for depname in kwargs:
setattr(self, depname, kwargs[depname])
def setup(self):
logger.info("Setting up.")
self.datastore = DataStore(self.get_db_conn(), self)
with self.get_db_conn() as conn:
self.datastore = self.DATASTORE_CLASS(conn, self)
logger.info("Finished setting up.")
def get_reactor(self):

View File

@@ -385,6 +385,7 @@ class StateHandler(object):
ev_ids, get_prev_content=False, check_redacted=False,
)
@defer.inlineCallbacks
def resolve_events(self, room_version, state_sets, event):
logger.info(
"Resolving state for %s with %d groups", event.room_id, len(state_sets)
@@ -401,15 +402,17 @@ class StateHandler(object):
}
with Measure(self.clock, "state._resolve_events"):
new_state = resolve_events_with_state_map(
room_version, state_set_ids, state_map,
new_state = yield resolve_events_with_factory(
room_version, state_set_ids,
event_map=state_map,
state_map_factory=self._state_map_factory
)
new_state = {
key: state_map[ev_id] for key, ev_id in iteritems(new_state)
}
return new_state
defer.returnValue(new_state)
class StateResolutionHandler(object):
@@ -589,31 +592,6 @@ def _make_state_cache_entry(
)
def resolve_events_with_state_map(room_version, state_sets, state_map):
"""
Args:
room_version(str): Version of the room
state_sets(list): List of dicts of (type, state_key) -> event_id,
which are the different state groups to resolve.
state_map(dict): a dict from event_id to event, for all events in
state_sets.
Returns
dict[(str, str), str]:
a map from (type, state_key) to event_id.
"""
if room_version in (RoomVersions.V1, RoomVersions.VDH_TEST,):
return v1.resolve_events_with_state_map(
state_sets, state_map,
)
else:
# This should only happen if we added a version but forgot to add it to
# the list above.
raise Exception(
"No state resolution algorithm defined for version %r" % (room_version,)
)
def resolve_events_with_factory(room_version, state_sets, event_map, state_map_factory):
"""
Args:

View File

@@ -30,34 +30,6 @@ logger = logging.getLogger(__name__)
POWER_KEY = (EventTypes.PowerLevels, "")
def resolve_events_with_state_map(state_sets, state_map):
"""
Args:
state_sets(list): List of dicts of (type, state_key) -> event_id,
which are the different state groups to resolve.
state_map(dict): a dict from event_id to event, for all events in
state_sets.
Returns
dict[(str, str), str]:
a map from (type, state_key) to event_id.
"""
if len(state_sets) == 1:
return state_sets[0]
unconflicted_state, conflicted_state = _seperate(
state_sets,
)
auth_events = _create_auth_events_from_maps(
unconflicted_state, conflicted_state, state_map
)
return _resolve_with_state(
unconflicted_state, conflicted_state, auth_events, state_map
)
@defer.inlineCallbacks
def resolve_events_with_factory(state_sets, event_map, state_map_factory):
"""

View File

@@ -17,9 +17,10 @@ import sys
import threading
import time
from six import iteritems, iterkeys, itervalues
from six import PY2, iteritems, iterkeys, itervalues
from six.moves import intern, range
from canonicaljson import json
from prometheus_client import Histogram
from twisted.internet import defer
@@ -1216,3 +1217,32 @@ class _RollbackButIsFineException(Exception):
something went wrong.
"""
pass
def db_to_json(db_content):
"""
Take some data from a database row and return a JSON-decoded object.
Args:
db_content (memoryview|buffer|bytes|bytearray|unicode)
"""
# psycopg2 on Python 3 returns memoryview objects, which we need to
# cast to bytes to decode
if isinstance(db_content, memoryview):
db_content = db_content.tobytes()
# psycopg2 on Python 2 returns buffer objects, which we need to cast to
# bytes to decode
if PY2 and isinstance(db_content, buffer):
db_content = bytes(db_content)
# Decode it to a Unicode string before feeding it to json.loads, so we
# consistenty get a Unicode-containing object out.
if isinstance(db_content, (bytes, bytearray)):
db_content = db_content.decode('utf8')
try:
return json.loads(db_content)
except Exception:
logging.warning("Tried to decode '%r' as JSON and failed", db_content)
raise

View File

@@ -169,7 +169,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
local_by_user_then_device = {}
for user_id, messages_by_device in messages_by_user_then_device.items():
messages_json_for_user = {}
devices = messages_by_device.keys()
devices = list(messages_by_device.keys())
if len(devices) == 1 and devices[0] == "*":
# Handle wildcard device_ids.
sql = (

View File

@@ -24,7 +24,7 @@ from synapse.api.errors import StoreError
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
from ._base import Cache, SQLBaseStore
from ._base import Cache, SQLBaseStore, db_to_json
logger = logging.getLogger(__name__)
@@ -411,7 +411,7 @@ class DeviceStore(SQLBaseStore):
if device is not None:
key_json = device.get("key_json", None)
if key_json:
result["keys"] = json.loads(key_json)
result["keys"] = db_to_json(key_json)
device_display_name = device.get("device_display_name", None)
if device_display_name:
result["device_display_name"] = device_display_name
@@ -466,7 +466,7 @@ class DeviceStore(SQLBaseStore):
retcol="content",
desc="_get_cached_user_device",
)
defer.returnValue(json.loads(content))
defer.returnValue(db_to_json(content))
@cachedInlineCallbacks()
def _get_cached_devices_for_user(self, user_id):
@@ -479,7 +479,7 @@ class DeviceStore(SQLBaseStore):
desc="_get_cached_devices_for_user",
)
defer.returnValue({
device["device_id"]: json.loads(device["content"])
device["device_id"]: db_to_json(device["content"])
for device in devices
})
@@ -511,7 +511,7 @@ class DeviceStore(SQLBaseStore):
key_json = device.get("key_json", None)
if key_json:
result["keys"] = json.loads(key_json)
result["keys"] = db_to_json(key_json)
device_display_name = device.get("device_display_name", None)
if device_display_name:
result["device_display_name"] = device_display_name

View File

@@ -14,13 +14,13 @@
# limitations under the License.
from six import iteritems
from canonicaljson import encode_canonical_json, json
from canonicaljson import encode_canonical_json
from twisted.internet import defer
from synapse.util.caches.descriptors import cached
from ._base import SQLBaseStore
from ._base import SQLBaseStore, db_to_json
class EndToEndKeyStore(SQLBaseStore):
@@ -90,7 +90,7 @@ class EndToEndKeyStore(SQLBaseStore):
for user_id, device_keys in iteritems(results):
for device_id, device_info in iteritems(device_keys):
device_info["keys"] = json.loads(device_info.pop("key_json"))
device_info["keys"] = db_to_json(device_info.pop("key_json"))
defer.returnValue(results)

View File

@@ -41,13 +41,18 @@ class PostgresEngine(object):
db_conn.set_isolation_level(
self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
)
# Set the bytea output to escape, vs the default of hex
cursor = db_conn.cursor()
cursor.execute("SET bytea_output TO escape")
# Asynchronous commit, don't wait for the server to call fsync before
# ending the transaction.
# https://www.postgresql.org/docs/current/static/wal-async-commit.html
if not self.synchronous_commit:
cursor = db_conn.cursor()
cursor.execute("SET synchronous_commit TO OFF")
cursor.close()
cursor.close()
def is_deadlock(self, error):
if isinstance(error, self.module.DatabaseError):

View File

@@ -19,7 +19,7 @@ import logging
from collections import OrderedDict, deque, namedtuple
from functools import wraps
from six import iteritems
from six import iteritems, text_type
from six.moves import range
from canonicaljson import json
@@ -1220,7 +1220,7 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
"sender": event.sender,
"contains_url": (
"url" in event.content
and isinstance(event.content["url"], basestring)
and isinstance(event.content["url"], text_type)
),
}
for event, _ in events_and_contexts
@@ -1529,7 +1529,7 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
contains_url = "url" in content
if contains_url:
contains_url &= isinstance(content["url"], basestring)
contains_url &= isinstance(content["url"], text_type)
except (KeyError, AttributeError):
# If the event is missing a necessary field then
# skip over it.
@@ -1910,9 +1910,9 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
(room_id,)
)
rows = txn.fetchall()
max_depth = max(row[0] for row in rows)
max_depth = max(row[1] for row in rows)
if max_depth <= token.topological:
if max_depth < token.topological:
# We need to ensure we don't delete all the events from the database
# otherwise we wouldn't be able to send any events (due to not
# having any backwards extremeties)

View File

@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import logging
from collections import namedtuple
@@ -265,7 +266,7 @@ class EventsWorkerStore(SQLBaseStore):
"""
with Measure(self._clock, "_fetch_event_list"):
try:
event_id_lists = zip(*event_list)[0]
event_id_lists = list(zip(*event_list))[0]
event_ids = [
item for sublist in event_id_lists for item in sublist
]
@@ -299,14 +300,14 @@ class EventsWorkerStore(SQLBaseStore):
logger.exception("do_fetch")
# We only want to resolve deferreds from the main thread
def fire(evs):
def fire(evs, exc):
for _, d in evs:
if not d.called:
with PreserveLoggingContext():
d.errback(e)
d.errback(exc)
with PreserveLoggingContext():
self.hs.get_reactor().callFromThread(fire, event_list)
self.hs.get_reactor().callFromThread(fire, event_list, e)
@defer.inlineCallbacks
def _enqueue_events(self, events, check_redacted=True, allow_rejected=False):

View File

@@ -13,14 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from canonicaljson import encode_canonical_json, json
from canonicaljson import encode_canonical_json
from twisted.internet import defer
from synapse.api.errors import Codes, SynapseError
from synapse.util.caches.descriptors import cachedInlineCallbacks
from ._base import SQLBaseStore
from ._base import SQLBaseStore, db_to_json
class FilteringStore(SQLBaseStore):
@@ -44,7 +44,7 @@ class FilteringStore(SQLBaseStore):
desc="get_user_filter",
)
defer.returnValue(json.loads(bytes(def_json).decode("utf-8")))
defer.returnValue(db_to_json(def_json))
def add_user_filter(self, user_localpart, user_filter):
def_json = encode_canonical_json(user_filter)

View File

@@ -36,7 +36,6 @@ class MonthlyActiveUsersStore(SQLBaseStore):
@defer.inlineCallbacks
def initialise_reserved_users(self, threepids):
# TODO Why can't I do this in init?
store = self.hs.get_datastore()
reserved_user_list = []

View File

@@ -15,7 +15,8 @@
# limitations under the License.
import logging
import types
import six
from canonicaljson import encode_canonical_json, json
@@ -27,6 +28,11 @@ from ._base import SQLBaseStore
logger = logging.getLogger(__name__)
if six.PY2:
db_binary_type = buffer
else:
db_binary_type = memoryview
class PusherWorkerStore(SQLBaseStore):
def _decode_pushers_rows(self, rows):
@@ -34,18 +40,18 @@ class PusherWorkerStore(SQLBaseStore):
dataJson = r['data']
r['data'] = None
try:
if isinstance(dataJson, types.BufferType):
if isinstance(dataJson, db_binary_type):
dataJson = str(dataJson).decode("UTF8")
r['data'] = json.loads(dataJson)
except Exception as e:
logger.warn(
"Invalid JSON in data for pusher %d: %s, %s",
r['id'], dataJson, e.message,
r['id'], dataJson, e.args[0],
)
pass
if isinstance(r['pushkey'], types.BufferType):
if isinstance(r['pushkey'], db_binary_type):
r['pushkey'] = str(r['pushkey']).decode("UTF8")
return rows

View File

@@ -18,14 +18,14 @@ from collections import namedtuple
import six
from canonicaljson import encode_canonical_json, json
from canonicaljson import encode_canonical_json
from twisted.internet import defer
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util.caches.descriptors import cached
from ._base import SQLBaseStore
from ._base import SQLBaseStore, db_to_json
# py2 sqlite has buffer hardcoded as only binary type, so we must use it,
# despite being deprecated and removed in favor of memoryview
@@ -95,7 +95,8 @@ class TransactionStore(SQLBaseStore):
)
if result and result["response_code"]:
return result["response_code"], json.loads(str(result["response_json"]))
return result["response_code"], db_to_json(result["response_json"])
else:
return None

View File

@@ -467,6 +467,23 @@ class AuthTestCase(unittest.TestCase):
)
yield self.auth.check_auth_blocking()
@defer.inlineCallbacks
def test_reserved_threepid(self):
self.hs.config.limit_usage_by_mau = True
self.hs.config.max_mau_value = 1
threepid = {'medium': 'email', 'address': 'reserved@server.com'}
unknown_threepid = {'medium': 'email', 'address': 'unreserved@server.com'}
self.hs.config.mau_limits_reserved_threepids = [threepid]
yield self.store.register(user_id='user1', token="123", password_hash=None)
with self.assertRaises(ResourceLimitError):
yield self.auth.check_auth_blocking()
with self.assertRaises(ResourceLimitError):
yield self.auth.check_auth_blocking(threepid=unknown_threepid)
yield self.auth.check_auth_blocking(threepid=threepid)
@defer.inlineCallbacks
def test_hs_disabled(self):
self.hs.config.hs_disabled = True

View File

@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,79 +14,79 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
import synapse.api.errors
import synapse.handlers.device
import synapse.storage
from tests import unittest, utils
from tests import unittest
user1 = "@boris:aaa"
user2 = "@theresa:bbb"
class DeviceTestCase(unittest.TestCase):
def __init__(self, *args, **kwargs):
super(DeviceTestCase, self).__init__(*args, **kwargs)
self.store = None # type: synapse.storage.DataStore
self.handler = None # type: synapse.handlers.device.DeviceHandler
self.clock = None # type: utils.MockClock
@defer.inlineCallbacks
def setUp(self):
hs = yield utils.setup_test_homeserver(self.addCleanup)
class DeviceTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver("server", http_client=None)
self.handler = hs.get_device_handler()
self.store = hs.get_datastore()
self.clock = hs.get_clock()
return hs
def prepare(self, reactor, clock, hs):
# These tests assume that it starts 1000 seconds in.
self.reactor.advance(1000)
@defer.inlineCallbacks
def test_device_is_created_if_doesnt_exist(self):
res = yield self.handler.check_device_registered(
user_id="@boris:foo",
device_id="fco",
initial_device_display_name="display name",
res = self.get_success(
self.handler.check_device_registered(
user_id="@boris:foo",
device_id="fco",
initial_device_display_name="display name",
)
)
self.assertEqual(res, "fco")
dev = yield self.handler.store.get_device("@boris:foo", "fco")
dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco"))
self.assertEqual(dev["display_name"], "display name")
@defer.inlineCallbacks
def test_device_is_preserved_if_exists(self):
res1 = yield self.handler.check_device_registered(
user_id="@boris:foo",
device_id="fco",
initial_device_display_name="display name",
res1 = self.get_success(
self.handler.check_device_registered(
user_id="@boris:foo",
device_id="fco",
initial_device_display_name="display name",
)
)
self.assertEqual(res1, "fco")
res2 = yield self.handler.check_device_registered(
user_id="@boris:foo",
device_id="fco",
initial_device_display_name="new display name",
res2 = self.get_success(
self.handler.check_device_registered(
user_id="@boris:foo",
device_id="fco",
initial_device_display_name="new display name",
)
)
self.assertEqual(res2, "fco")
dev = yield self.handler.store.get_device("@boris:foo", "fco")
dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco"))
self.assertEqual(dev["display_name"], "display name")
@defer.inlineCallbacks
def test_device_id_is_made_up_if_unspecified(self):
device_id = yield self.handler.check_device_registered(
user_id="@theresa:foo",
device_id=None,
initial_device_display_name="display",
device_id = self.get_success(
self.handler.check_device_registered(
user_id="@theresa:foo",
device_id=None,
initial_device_display_name="display",
)
)
dev = yield self.handler.store.get_device("@theresa:foo", device_id)
dev = self.get_success(self.handler.store.get_device("@theresa:foo", device_id))
self.assertEqual(dev["display_name"], "display")
@defer.inlineCallbacks
def test_get_devices_by_user(self):
yield self._record_users()
self._record_users()
res = self.get_success(self.handler.get_devices_by_user(user1))
res = yield self.handler.get_devices_by_user(user1)
self.assertEqual(3, len(res))
device_map = {d["device_id"]: d for d in res}
self.assertDictContainsSubset(
@@ -119,11 +120,10 @@ class DeviceTestCase(unittest.TestCase):
device_map["abc"],
)
@defer.inlineCallbacks
def test_get_device(self):
yield self._record_users()
self._record_users()
res = yield self.handler.get_device(user1, "abc")
res = self.get_success(self.handler.get_device(user1, "abc"))
self.assertDictContainsSubset(
{
"user_id": user1,
@@ -135,59 +135,66 @@ class DeviceTestCase(unittest.TestCase):
res,
)
@defer.inlineCallbacks
def test_delete_device(self):
yield self._record_users()
self._record_users()
# delete the device
yield self.handler.delete_device(user1, "abc")
self.get_success(self.handler.delete_device(user1, "abc"))
# check the device was deleted
with self.assertRaises(synapse.api.errors.NotFoundError):
yield self.handler.get_device(user1, "abc")
res = self.handler.get_device(user1, "abc")
self.pump()
self.assertIsInstance(
self.failureResultOf(res).value, synapse.api.errors.NotFoundError
)
# we'd like to check the access token was invalidated, but that's a
# bit of a PITA.
@defer.inlineCallbacks
def test_update_device(self):
yield self._record_users()
self._record_users()
update = {"display_name": "new display"}
yield self.handler.update_device(user1, "abc", update)
self.get_success(self.handler.update_device(user1, "abc", update))
res = yield self.handler.get_device(user1, "abc")
res = self.get_success(self.handler.get_device(user1, "abc"))
self.assertEqual(res["display_name"], "new display")
@defer.inlineCallbacks
def test_update_unknown_device(self):
update = {"display_name": "new_display"}
with self.assertRaises(synapse.api.errors.NotFoundError):
yield self.handler.update_device("user_id", "unknown_device_id", update)
res = self.handler.update_device("user_id", "unknown_device_id", update)
self.pump()
self.assertIsInstance(
self.failureResultOf(res).value, synapse.api.errors.NotFoundError
)
@defer.inlineCallbacks
def _record_users(self):
# check this works for both devices which have a recorded client_ip,
# and those which don't.
yield self._record_user(user1, "xyz", "display 0")
yield self._record_user(user1, "fco", "display 1", "token1", "ip1")
yield self._record_user(user1, "abc", "display 2", "token2", "ip2")
yield self._record_user(user1, "abc", "display 2", "token3", "ip3")
self._record_user(user1, "xyz", "display 0")
self._record_user(user1, "fco", "display 1", "token1", "ip1")
self._record_user(user1, "abc", "display 2", "token2", "ip2")
self._record_user(user1, "abc", "display 2", "token3", "ip3")
yield self._record_user(user2, "def", "dispkay", "token4", "ip4")
self._record_user(user2, "def", "dispkay", "token4", "ip4")
self.reactor.advance(10000)
@defer.inlineCallbacks
def _record_user(
self, user_id, device_id, display_name, access_token=None, ip=None
):
device_id = yield self.handler.check_device_registered(
user_id=user_id,
device_id=device_id,
initial_device_display_name=display_name,
device_id = self.get_success(
self.handler.check_device_registered(
user_id=user_id,
device_id=device_id,
initial_device_display_name=display_name,
)
)
if ip is not None:
yield self.store.insert_client_ip(
user_id, access_token, ip, "user_agent", device_id
self.get_success(
self.store.insert_client_ip(
user_id, access_token, ip, "user_agent", device_id
)
)
self.clock.advance_time(1000)
self.reactor.advance(1000)

View File

@@ -1,4 +1,5 @@
# Copyright 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,89 +12,91 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
from mock import Mock, NonCallableMock
from twisted.internet import defer, reactor
from twisted.internet.defer import Deferred
import attr
from synapse.replication.tcp.client import (
ReplicationClientFactory,
ReplicationClientHandler,
)
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
from synapse.util.logcontext import PreserveLoggingContext, make_deferred_yieldable
from tests import unittest
from tests.utils import setup_test_homeserver
class TestReplicationClientHandler(ReplicationClientHandler):
"""Overrides on_rdata so that we can wait for it to happen"""
class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
def __init__(self, store):
super(TestReplicationClientHandler, self).__init__(store)
self._rdata_awaiters = []
def await_replication(self):
d = Deferred()
self._rdata_awaiters.append(d)
return make_deferred_yieldable(d)
def on_rdata(self, stream_name, token, rows):
awaiters = self._rdata_awaiters
self._rdata_awaiters = []
super(TestReplicationClientHandler, self).on_rdata(stream_name, token, rows)
with PreserveLoggingContext():
for a in awaiters:
a.callback(None)
class BaseSlavedStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
self.hs = yield setup_test_homeserver(
self.addCleanup,
hs = self.setup_test_homeserver(
"blue",
http_client=None,
federation_client=Mock(),
ratelimiter=NonCallableMock(spec_set=["send_message"]),
)
self.hs.get_ratelimiter().send_message.return_value = (True, 0)
hs.get_ratelimiter().send_message.return_value = (True, 0)
return hs
def prepare(self, reactor, clock, hs):
self.master_store = self.hs.get_datastore()
self.slaved_store = self.STORE_TYPE(self.hs.get_db_conn(), self.hs)
self.event_id = 0
server_factory = ReplicationStreamProtocolFactory(self.hs)
# XXX: mktemp is unsafe and should never be used. but we're just a test.
path = tempfile.mktemp(prefix="base_slaved_store_test_case_socket")
listener = reactor.listenUNIX(path, server_factory)
self.addCleanup(listener.stopListening)
self.streamer = server_factory.streamer
self.replication_handler = TestReplicationClientHandler(self.slaved_store)
self.replication_handler = ReplicationClientHandler(self.slaved_store)
client_factory = ReplicationClientFactory(
self.hs, "client_name", self.replication_handler
)
client_connector = reactor.connectUNIX(path, client_factory)
self.addCleanup(client_factory.stopTrying)
self.addCleanup(client_connector.disconnect)
server = server_factory.buildProtocol(None)
client = client_factory.buildProtocol(None)
@attr.s
class FakeTransport(object):
other = attr.ib()
disconnecting = False
buffer = attr.ib(default=b'')
def registerProducer(self, producer, streaming):
self.producer = producer
def _produce():
self.producer.resumeProducing()
reactor.callLater(0.1, _produce)
reactor.callLater(0.0, _produce)
def write(self, byt):
self.buffer = self.buffer + byt
if getattr(self.other, "transport") is not None:
self.other.dataReceived(self.buffer)
self.buffer = b""
def writeSequence(self, seq):
for x in seq:
self.write(x)
client.makeConnection(FakeTransport(server))
server.makeConnection(FakeTransport(client))
def replicate(self):
"""Tell the master side of replication that something has happened, and then
wait for the replication to occur.
"""
# xxx: should we be more specific in what we wait for?
d = self.replication_handler.await_replication()
self.streamer.on_notifier_poke()
return d
self.pump(0.1)
@defer.inlineCallbacks
def check(self, method, args, expected_result=None):
master_result = yield getattr(self.master_store, method)(*args)
slaved_result = yield getattr(self.slaved_store, method)(*args)
master_result = self.get_success(getattr(self.master_store, method)(*args))
slaved_result = self.get_success(getattr(self.slaved_store, method)(*args))
if expected_result is not None:
self.assertEqual(master_result, expected_result)
self.assertEqual(slaved_result, expected_result)

View File

@@ -12,9 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
from ._base import BaseSlavedStoreTestCase
@@ -27,16 +24,19 @@ class SlavedAccountDataStoreTestCase(BaseSlavedStoreTestCase):
STORE_TYPE = SlavedAccountDataStore
@defer.inlineCallbacks
def test_user_account_data(self):
yield self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 1})
yield self.replicate()
yield self.check(
self.get_success(
self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 1})
)
self.replicate()
self.check(
"get_global_account_data_by_type_for_user", [TYPE, USER_ID], {"a": 1}
)
yield self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 2})
yield self.replicate()
yield self.check(
self.get_success(
self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 2})
)
self.replicate()
self.check(
"get_global_account_data_by_type_for_user", [TYPE, USER_ID], {"a": 2}
)

View File

@@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.events import FrozenEvent, _EventInternalMetadata
from synapse.events.snapshot import EventContext
from synapse.replication.slave.storage.events import SlavedEventStore
@@ -55,70 +53,66 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
def tearDown(self):
[unpatch() for unpatch in self.unpatches]
@defer.inlineCallbacks
def test_get_latest_event_ids_in_room(self):
create = yield self.persist(type="m.room.create", key="", creator=USER_ID)
yield self.replicate()
yield self.check("get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id])
create = self.persist(type="m.room.create", key="", creator=USER_ID)
self.replicate()
self.check("get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id])
join = yield self.persist(
join = self.persist(
type="m.room.member",
key=USER_ID,
membership="join",
prev_events=[(create.event_id, {})],
)
yield self.replicate()
yield self.check("get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id])
self.replicate()
self.check("get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id])
@defer.inlineCallbacks
def test_redactions(self):
yield self.persist(type="m.room.create", key="", creator=USER_ID)
yield self.persist(type="m.room.member", key=USER_ID, membership="join")
self.persist(type="m.room.create", key="", creator=USER_ID)
self.persist(type="m.room.member", key=USER_ID, membership="join")
msg = yield self.persist(type="m.room.message", msgtype="m.text", body="Hello")
yield self.replicate()
yield self.check("get_event", [msg.event_id], msg)
msg = self.persist(type="m.room.message", msgtype="m.text", body="Hello")
self.replicate()
self.check("get_event", [msg.event_id], msg)
redaction = yield self.persist(type="m.room.redaction", redacts=msg.event_id)
yield self.replicate()
redaction = self.persist(type="m.room.redaction", redacts=msg.event_id)
self.replicate()
msg_dict = msg.get_dict()
msg_dict["content"] = {}
msg_dict["unsigned"]["redacted_by"] = redaction.event_id
msg_dict["unsigned"]["redacted_because"] = redaction
redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict())
yield self.check("get_event", [msg.event_id], redacted)
self.check("get_event", [msg.event_id], redacted)
@defer.inlineCallbacks
def test_backfilled_redactions(self):
yield self.persist(type="m.room.create", key="", creator=USER_ID)
yield self.persist(type="m.room.member", key=USER_ID, membership="join")
self.persist(type="m.room.create", key="", creator=USER_ID)
self.persist(type="m.room.member", key=USER_ID, membership="join")
msg = yield self.persist(type="m.room.message", msgtype="m.text", body="Hello")
yield self.replicate()
yield self.check("get_event", [msg.event_id], msg)
msg = self.persist(type="m.room.message", msgtype="m.text", body="Hello")
self.replicate()
self.check("get_event", [msg.event_id], msg)
redaction = yield self.persist(
redaction = self.persist(
type="m.room.redaction", redacts=msg.event_id, backfill=True
)
yield self.replicate()
self.replicate()
msg_dict = msg.get_dict()
msg_dict["content"] = {}
msg_dict["unsigned"]["redacted_by"] = redaction.event_id
msg_dict["unsigned"]["redacted_because"] = redaction
redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict())
yield self.check("get_event", [msg.event_id], redacted)
self.check("get_event", [msg.event_id], redacted)
@defer.inlineCallbacks
def test_invites(self):
yield self.persist(type="m.room.create", key="", creator=USER_ID)
yield self.check("get_invited_rooms_for_user", [USER_ID_2], [])
event = yield self.persist(
type="m.room.member", key=USER_ID_2, membership="invite"
)
yield self.replicate()
yield self.check(
self.persist(type="m.room.create", key="", creator=USER_ID)
self.check("get_invited_rooms_for_user", [USER_ID_2], [])
event = self.persist(type="m.room.member", key=USER_ID_2, membership="invite")
self.replicate()
self.check(
"get_invited_rooms_for_user",
[USER_ID_2],
[
@@ -132,37 +126,34 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
],
)
@defer.inlineCallbacks
def test_push_actions_for_user(self):
yield self.persist(type="m.room.create", key="", creator=USER_ID)
yield self.persist(type="m.room.join", key=USER_ID, membership="join")
yield self.persist(
self.persist(type="m.room.create", key="", creator=USER_ID)
self.persist(type="m.room.join", key=USER_ID, membership="join")
self.persist(
type="m.room.join", sender=USER_ID, key=USER_ID_2, membership="join"
)
event1 = yield self.persist(
type="m.room.message", msgtype="m.text", body="hello"
)
yield self.replicate()
yield self.check(
event1 = self.persist(type="m.room.message", msgtype="m.text", body="hello")
self.replicate()
self.check(
"get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2, event1.event_id],
{"highlight_count": 0, "notify_count": 0},
)
yield self.persist(
self.persist(
type="m.room.message",
msgtype="m.text",
body="world",
push_actions=[(USER_ID_2, ["notify"])],
)
yield self.replicate()
yield self.check(
self.replicate()
self.check(
"get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2, event1.event_id],
{"highlight_count": 0, "notify_count": 1},
)
yield self.persist(
self.persist(
type="m.room.message",
msgtype="m.text",
body="world",
@@ -170,8 +161,8 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
(USER_ID_2, ["notify", {"set_tweak": "highlight", "value": True}])
],
)
yield self.replicate()
yield self.check(
self.replicate()
self.check(
"get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2, event1.event_id],
{"highlight_count": 1, "notify_count": 2},
@@ -179,7 +170,6 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
event_id = 0
@defer.inlineCallbacks
def persist(
self,
sender=USER_ID,
@@ -206,8 +196,8 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
depth = self.event_id
if not prev_events:
latest_event_ids = yield self.master_store.get_latest_event_ids_in_room(
room_id
latest_event_ids = self.get_success(
self.master_store.get_latest_event_ids_in_room(room_id)
)
prev_events = [(ev_id, {}) for ev_id in latest_event_ids]
@@ -240,19 +230,23 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
)
else:
state_handler = self.hs.get_state_handler()
context = yield state_handler.compute_event_context(event)
context = self.get_success(state_handler.compute_event_context(event))
yield self.master_store.add_push_actions_to_staging(
self.master_store.add_push_actions_to_staging(
event.event_id, {user_id: actions for user_id, actions in push_actions}
)
ordering = None
if backfill:
yield self.master_store.persist_events([(event, context)], backfilled=True)
self.get_success(
self.master_store.persist_events([(event, context)], backfilled=True)
)
else:
ordering, _ = yield self.master_store.persist_event(event, context)
ordering, _ = self.get_success(
self.master_store.persist_event(event, context)
)
if ordering:
event.internal_metadata.stream_ordering = ordering
defer.returnValue(event)
return event

View File

@@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from ._base import BaseSlavedStoreTestCase
@@ -27,13 +25,10 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase):
STORE_TYPE = SlavedReceiptsStore
@defer.inlineCallbacks
def test_receipt(self):
yield self.check("get_receipts_for_user", [USER_ID, "m.read"], {})
yield self.master_store.insert_receipt(
ROOM_ID, "m.read", USER_ID, [EVENT_ID], {}
)
yield self.replicate()
yield self.check(
"get_receipts_for_user", [USER_ID, "m.read"], {ROOM_ID: EVENT_ID}
self.check("get_receipts_for_user", [USER_ID, "m.read"], {})
self.get_success(
self.master_store.insert_receipt(ROOM_ID, "m.read", USER_ID, [EVENT_ID], {})
)
self.replicate()
self.check("get_receipts_for_user", [USER_ID, "m.read"], {ROOM_ID: EVENT_ID})

View File

@@ -240,7 +240,6 @@ class RestHelper(object):
self.assertEquals(200, code)
defer.returnValue(response)
@defer.inlineCallbacks
def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200):
if txn_id is None:
txn_id = "m%s" % (str(time.time()))
@@ -248,9 +247,16 @@ class RestHelper(object):
body = "body_text_here"
path = "/_matrix/client/r0/rooms/%s/send/m.room.message/%s" % (room_id, txn_id)
content = '{"msgtype":"m.text","body":"%s"}' % body
content = {"msgtype": "m.text", "body": body}
if tok:
path = path + "?access_token=%s" % tok
(code, response) = yield self.mock_resource.trigger("PUT", path, content)
self.assertEquals(expect_code, code, msg=str(response))
request, channel = make_request("PUT", path, json.dumps(content).encode('utf8'))
render(request, self.resource, self.hs.get_reactor())
assert int(channel.result["code"]) == expect_code, (
"Expected: %d, got: %d, resp: %r"
% (expect_code, int(channel.result["code"]), channel.result["body"])
)
return channel.json_body

View File

@@ -232,6 +232,7 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):
clock.threadpool = ThreadPool()
pool.threadpool = ThreadPool()
pool.running = True
return d

View File

@@ -37,18 +37,14 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
self.as_yaml_files = []
config = Mock(
app_service_config_files=self.as_yaml_files,
event_cache_size=1,
password_providers=[],
)
hs = yield setup_test_homeserver(
self.addCleanup,
config=config,
federation_sender=Mock(),
federation_client=Mock(),
self.addCleanup, federation_sender=Mock(), federation_client=Mock()
)
hs.config.app_service_config_files = self.as_yaml_files
hs.config.event_cache_size = 1
hs.config.password_providers = []
self.as_token = "token1"
self.as_url = "some_url"
self.as_id = "as1"
@@ -58,7 +54,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob")
self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob")
# must be done after inserts
self.store = ApplicationServiceStore(None, hs)
self.store = ApplicationServiceStore(hs.get_db_conn(), hs)
def tearDown(self):
# TODO: suboptimal that we need to create files for tests!
@@ -105,18 +101,16 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
def setUp(self):
self.as_yaml_files = []
config = Mock(
app_service_config_files=self.as_yaml_files,
event_cache_size=1,
password_providers=[],
)
hs = yield setup_test_homeserver(
self.addCleanup,
config=config,
federation_sender=Mock(),
federation_client=Mock(),
self.addCleanup, federation_sender=Mock(), federation_client=Mock()
)
hs.config.app_service_config_files = self.as_yaml_files
hs.config.event_cache_size = 1
hs.config.password_providers = []
self.db_pool = hs.get_db_pool()
self.engine = hs.database_engine
self.as_list = [
{"token": "token1", "url": "https://matrix-as.org", "id": "id_1"},
@@ -129,7 +123,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
self.as_yaml_files = []
self.store = TestTransactionStore(None, hs)
self.store = TestTransactionStore(hs.get_db_conn(), hs)
def _add_service(self, url, as_token, id):
as_yaml = dict(
@@ -146,29 +140,35 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
self.as_yaml_files.append(as_token)
def _set_state(self, id, state, txn=None):
return self.db_pool.runQuery(
"INSERT INTO application_services_state(as_id, state, last_txn) "
"VALUES(?,?,?)",
return self.db_pool.runOperation(
self.engine.convert_param_style(
"INSERT INTO application_services_state(as_id, state, last_txn) "
"VALUES(?,?,?)"
),
(id, state, txn),
)
def _insert_txn(self, as_id, txn_id, events):
return self.db_pool.runQuery(
"INSERT INTO application_services_txns(as_id, txn_id, event_ids) "
"VALUES(?,?,?)",
return self.db_pool.runOperation(
self.engine.convert_param_style(
"INSERT INTO application_services_txns(as_id, txn_id, event_ids) "
"VALUES(?,?,?)"
),
(as_id, txn_id, json.dumps([e.event_id for e in events])),
)
def _set_last_txn(self, as_id, txn_id):
return self.db_pool.runQuery(
"INSERT INTO application_services_state(as_id, last_txn, state) "
"VALUES(?,?,?)",
return self.db_pool.runOperation(
self.engine.convert_param_style(
"INSERT INTO application_services_state(as_id, last_txn, state) "
"VALUES(?,?,?)"
),
(as_id, txn_id, ApplicationServiceState.UP),
)
@defer.inlineCallbacks
def test_get_appservice_state_none(self):
service = Mock(id=999)
service = Mock(id="999")
state = yield self.store.get_appservice_state(service)
self.assertEquals(None, state)
@@ -200,7 +200,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
service = Mock(id=self.as_list[1]["id"])
yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
rows = yield self.db_pool.runQuery(
"SELECT as_id FROM application_services_state WHERE state=?",
self.engine.convert_param_style(
"SELECT as_id FROM application_services_state WHERE state=?"
),
(ApplicationServiceState.DOWN,),
)
self.assertEquals(service.id, rows[0][0])
@@ -212,7 +214,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
yield self.store.set_appservice_state(service, ApplicationServiceState.UP)
rows = yield self.db_pool.runQuery(
"SELECT as_id FROM application_services_state WHERE state=?",
self.engine.convert_param_style(
"SELECT as_id FROM application_services_state WHERE state=?"
),
(ApplicationServiceState.UP,),
)
self.assertEquals(service.id, rows[0][0])
@@ -279,14 +283,19 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
yield self.store.complete_appservice_txn(txn_id=txn_id, service=service)
res = yield self.db_pool.runQuery(
"SELECT last_txn FROM application_services_state WHERE as_id=?",
self.engine.convert_param_style(
"SELECT last_txn FROM application_services_state WHERE as_id=?"
),
(service.id,),
)
self.assertEquals(1, len(res))
self.assertEquals(txn_id, res[0][0])
res = yield self.db_pool.runQuery(
"SELECT * FROM application_services_txns WHERE txn_id=?", (txn_id,)
self.engine.convert_param_style(
"SELECT * FROM application_services_txns WHERE txn_id=?"
),
(txn_id,),
)
self.assertEquals(0, len(res))
@@ -300,7 +309,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
yield self.store.complete_appservice_txn(txn_id=txn_id, service=service)
res = yield self.db_pool.runQuery(
"SELECT last_txn, state FROM application_services_state WHERE " "as_id=?",
self.engine.convert_param_style(
"SELECT last_txn, state FROM application_services_state WHERE as_id=?"
),
(service.id,),
)
self.assertEquals(1, len(res))
@@ -308,7 +319,10 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
self.assertEquals(ApplicationServiceState.UP, res[0][1])
res = yield self.db_pool.runQuery(
"SELECT * FROM application_services_txns WHERE txn_id=?", (txn_id,)
self.engine.convert_param_style(
"SELECT * FROM application_services_txns WHERE txn_id=?"
),
(txn_id,),
)
self.assertEquals(0, len(res))
@@ -394,37 +408,31 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
f1 = self._write_config(suffix="1")
f2 = self._write_config(suffix="2")
config = Mock(
app_service_config_files=[f1, f2], event_cache_size=1, password_providers=[]
)
hs = yield setup_test_homeserver(
self.addCleanup,
config=config,
datastore=Mock(),
federation_sender=Mock(),
federation_client=Mock(),
self.addCleanup, federation_sender=Mock(), federation_client=Mock()
)
ApplicationServiceStore(None, hs)
hs.config.app_service_config_files = [f1, f2]
hs.config.event_cache_size = 1
hs.config.password_providers = []
ApplicationServiceStore(hs.get_db_conn(), hs)
@defer.inlineCallbacks
def test_duplicate_ids(self):
f1 = self._write_config(id="id", suffix="1")
f2 = self._write_config(id="id", suffix="2")
config = Mock(
app_service_config_files=[f1, f2], event_cache_size=1, password_providers=[]
)
hs = yield setup_test_homeserver(
self.addCleanup,
config=config,
datastore=Mock(),
federation_sender=Mock(),
federation_client=Mock(),
self.addCleanup, federation_sender=Mock(), federation_client=Mock()
)
hs.config.app_service_config_files = [f1, f2]
hs.config.event_cache_size = 1
hs.config.password_providers = []
with self.assertRaises(ConfigError) as cm:
ApplicationServiceStore(None, hs)
ApplicationServiceStore(hs.get_db_conn(), hs)
e = cm.exception
self.assertIn(f1, str(e))
@@ -436,19 +444,16 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
f1 = self._write_config(as_token="as_token", suffix="1")
f2 = self._write_config(as_token="as_token", suffix="2")
config = Mock(
app_service_config_files=[f1, f2], event_cache_size=1, password_providers=[]
)
hs = yield setup_test_homeserver(
self.addCleanup,
config=config,
datastore=Mock(),
federation_sender=Mock(),
federation_client=Mock(),
self.addCleanup, federation_sender=Mock(), federation_client=Mock()
)
hs.config.app_service_config_files = [f1, f2]
hs.config.event_cache_size = 1
hs.config.password_providers = []
with self.assertRaises(ConfigError) as cm:
ApplicationServiceStore(None, hs)
ApplicationServiceStore(hs.get_db_conn(), hs)
e = cm.exception
self.assertIn(f1, str(e))

View File

@@ -20,11 +20,11 @@ from mock import Mock
from twisted.internet import defer
from synapse.server import HomeServer
from synapse.storage._base import SQLBaseStore
from synapse.storage.engines import create_engine
from tests import unittest
from tests.utils import TestHomeServer
class SQLBaseStoreTestCase(unittest.TestCase):
@@ -51,7 +51,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
config = Mock()
config.event_cache_size = 1
config.database_config = {"name": "sqlite3"}
hs = HomeServer(
hs = TestHomeServer(
"test",
db_pool=self.db_pool,
config=config,

View File

@@ -16,7 +16,6 @@
from twisted.internet import defer
from synapse.storage.directory import DirectoryStore
from synapse.types import RoomAlias, RoomID
from tests import unittest
@@ -28,7 +27,7 @@ class DirectoryStoreTestCase(unittest.TestCase):
def setUp(self):
hs = yield setup_test_homeserver(self.addCleanup)
self.store = DirectoryStore(None, hs)
self.store = hs.get_datastore()
self.room = RoomID.from_string("!abcde:test")
self.alias = RoomAlias.from_string("#my-room:test")

View File

@@ -37,10 +37,10 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase):
(
"INSERT INTO events ("
" room_id, event_id, type, depth, topological_ordering,"
" content, processed, outlier) "
"VALUES (?, ?, 'm.test', ?, ?, 'test', ?, ?)"
" content, processed, outlier, stream_ordering) "
"VALUES (?, ?, 'm.test', ?, ?, 'test', ?, ?, ?)"
),
(room_id, event_id, i, i, True, False),
(room_id, event_id, i, i, True, False, i),
)
txn.execute(

View File

@@ -13,25 +13,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
import tests.unittest
import tests.utils
from tests.utils import setup_test_homeserver
from tests.unittest import HomeserverTestCase
FORTY_DAYS = 40 * 24 * 60 * 60
class MonthlyActiveUsersTestCase(tests.unittest.TestCase):
def __init__(self, *args, **kwargs):
super(MonthlyActiveUsersTestCase, self).__init__(*args, **kwargs)
class MonthlyActiveUsersTestCase(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
@defer.inlineCallbacks
def setUp(self):
self.hs = yield setup_test_homeserver(self.addCleanup)
self.store = self.hs.get_datastore()
hs = self.setup_test_homeserver()
self.store = hs.get_datastore()
# Advance the clock a bit
reactor.advance(FORTY_DAYS)
return hs
@defer.inlineCallbacks
def test_initialise_reserved_users(self):
self.hs.config.max_mau_value = 5
user1 = "@user1:server"
@@ -44,88 +41,101 @@ class MonthlyActiveUsersTestCase(tests.unittest.TestCase):
]
user_num = len(threepids)
yield self.store.register(user_id=user1, token="123", password_hash=None)
yield self.store.register(user_id=user2, token="456", password_hash=None)
self.store.register(user_id=user1, token="123", password_hash=None)
self.store.register(user_id=user2, token="456", password_hash=None)
self.pump()
now = int(self.hs.get_clock().time_msec())
yield self.store.user_add_threepid(user1, "email", user1_email, now, now)
yield self.store.user_add_threepid(user2, "email", user2_email, now, now)
yield self.store.initialise_reserved_users(threepids)
self.store.user_add_threepid(user1, "email", user1_email, now, now)
self.store.user_add_threepid(user2, "email", user2_email, now, now)
self.store.initialise_reserved_users(threepids)
self.pump()
active_count = yield self.store.get_monthly_active_count()
active_count = self.store.get_monthly_active_count()
# Test total counts
self.assertEquals(active_count, user_num)
self.assertEquals(self.get_success(active_count), user_num)
# Test user is marked as active
timestamp = yield self.store.user_last_seen_monthly_active(user1)
self.assertTrue(timestamp)
timestamp = yield self.store.user_last_seen_monthly_active(user2)
self.assertTrue(timestamp)
timestamp = self.store.user_last_seen_monthly_active(user1)
self.assertTrue(self.get_success(timestamp))
timestamp = self.store.user_last_seen_monthly_active(user2)
self.assertTrue(self.get_success(timestamp))
# Test that users are never removed from the db.
self.hs.config.max_mau_value = 0
self.hs.get_clock().advance_time(FORTY_DAYS)
self.reactor.advance(FORTY_DAYS)
yield self.store.reap_monthly_active_users()
self.store.reap_monthly_active_users()
self.pump()
active_count = yield self.store.get_monthly_active_count()
self.assertEquals(active_count, user_num)
active_count = self.store.get_monthly_active_count()
self.assertEquals(self.get_success(active_count), user_num)
# Test that regalar users are removed from the db
ru_count = 2
yield self.store.upsert_monthly_active_user("@ru1:server")
yield self.store.upsert_monthly_active_user("@ru2:server")
active_count = yield self.store.get_monthly_active_count()
self.store.upsert_monthly_active_user("@ru1:server")
self.store.upsert_monthly_active_user("@ru2:server")
self.pump()
self.assertEqual(active_count, user_num + ru_count)
active_count = self.store.get_monthly_active_count()
self.assertEqual(self.get_success(active_count), user_num + ru_count)
self.hs.config.max_mau_value = user_num
yield self.store.reap_monthly_active_users()
self.store.reap_monthly_active_users()
self.pump()
active_count = yield self.store.get_monthly_active_count()
self.assertEquals(active_count, user_num)
active_count = self.store.get_monthly_active_count()
self.assertEquals(self.get_success(active_count), user_num)
@defer.inlineCallbacks
def test_can_insert_and_count_mau(self):
count = yield self.store.get_monthly_active_count()
self.assertEqual(0, count)
count = self.store.get_monthly_active_count()
self.assertEqual(0, self.get_success(count))
yield self.store.upsert_monthly_active_user("@user:server")
count = yield self.store.get_monthly_active_count()
self.store.upsert_monthly_active_user("@user:server")
self.pump()
self.assertEqual(1, count)
count = self.store.get_monthly_active_count()
self.assertEqual(1, self.get_success(count))
@defer.inlineCallbacks
def test_user_last_seen_monthly_active(self):
user_id1 = "@user1:server"
user_id2 = "@user2:server"
user_id3 = "@user3:server"
result = yield self.store.user_last_seen_monthly_active(user_id1)
self.assertFalse(result == 0)
yield self.store.upsert_monthly_active_user(user_id1)
yield self.store.upsert_monthly_active_user(user_id2)
result = yield self.store.user_last_seen_monthly_active(user_id1)
self.assertTrue(result > 0)
result = yield self.store.user_last_seen_monthly_active(user_id3)
self.assertFalse(result == 0)
result = self.store.user_last_seen_monthly_active(user_id1)
self.assertFalse(self.get_success(result) == 0)
self.store.upsert_monthly_active_user(user_id1)
self.store.upsert_monthly_active_user(user_id2)
self.pump()
result = self.store.user_last_seen_monthly_active(user_id1)
self.assertGreater(self.get_success(result), 0)
result = self.store.user_last_seen_monthly_active(user_id3)
self.assertNotEqual(self.get_success(result), 0)
@defer.inlineCallbacks
def test_reap_monthly_active_users(self):
self.hs.config.max_mau_value = 5
initial_users = 10
for i in range(initial_users):
yield self.store.upsert_monthly_active_user("@user%d:server" % i)
count = yield self.store.get_monthly_active_count()
self.assertTrue(count, initial_users)
yield self.store.reap_monthly_active_users()
count = yield self.store.get_monthly_active_count()
self.assertEquals(count, initial_users - self.hs.config.max_mau_value)
self.store.upsert_monthly_active_user("@user%d:server" % i)
self.pump()
self.hs.get_clock().advance_time(FORTY_DAYS)
yield self.store.reap_monthly_active_users()
count = yield self.store.get_monthly_active_count()
self.assertEquals(count, 0)
count = self.store.get_monthly_active_count()
self.assertTrue(self.get_success(count), initial_users)
self.store.reap_monthly_active_users()
self.pump()
count = self.store.get_monthly_active_count()
self.assertEquals(
self.get_success(count), initial_users - self.hs.config.max_mau_value
)
self.reactor.advance(FORTY_DAYS)
self.store.reap_monthly_active_users()
self.pump()
count = self.store.get_monthly_active_count()
self.assertEquals(self.get_success(count), 0)

View File

@@ -16,19 +16,18 @@
from twisted.internet import defer
from synapse.storage.presence import PresenceStore
from synapse.types import UserID
from tests import unittest
from tests.utils import MockClock, setup_test_homeserver
from tests.utils import setup_test_homeserver
class PresenceStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
hs = yield setup_test_homeserver(self.addCleanup, clock=MockClock())
hs = yield setup_test_homeserver(self.addCleanup)
self.store = PresenceStore(None, hs)
self.store = hs.get_datastore()
self.u_apple = UserID.from_string("@apple:test")
self.u_banana = UserID.from_string("@banana:test")

View File

@@ -28,7 +28,7 @@ class ProfileStoreTestCase(unittest.TestCase):
def setUp(self):
hs = yield setup_test_homeserver(self.addCleanup)
self.store = ProfileStore(None, hs)
self.store = ProfileStore(hs.get_db_conn(), hs)
self.u_frank = UserID.from_string("@frank:test")

106
tests/storage/test_purge.py Normal file
View File

@@ -0,0 +1,106 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.rest.client.v1 import room
from tests.unittest import HomeserverTestCase
class PurgeTests(HomeserverTestCase):
user_id = "@red:server"
servlets = [room.register_servlets]
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver("server", http_client=None)
return hs
def prepare(self, reactor, clock, hs):
self.room_id = self.helper.create_room_as(self.user_id)
def test_purge(self):
"""
Purging a room will delete everything before the topological point.
"""
# Send four messages to the room
first = self.helper.send(self.room_id, body="test1")
second = self.helper.send(self.room_id, body="test2")
third = self.helper.send(self.room_id, body="test3")
last = self.helper.send(self.room_id, body="test4")
storage = self.hs.get_datastore()
# Get the topological token
event = storage.get_topological_token_for_event(last["event_id"])
self.pump()
event = self.successResultOf(event)
# Purge everything before this topological token
purge = storage.purge_history(self.room_id, event, True)
self.pump()
self.assertEqual(self.successResultOf(purge), None)
# Try and get the events
get_first = storage.get_event(first["event_id"])
get_second = storage.get_event(second["event_id"])
get_third = storage.get_event(third["event_id"])
get_last = storage.get_event(last["event_id"])
self.pump()
# 1-3 should fail and last will succeed, meaning that 1-3 are deleted
# and last is not.
self.failureResultOf(get_first)
self.failureResultOf(get_second)
self.failureResultOf(get_third)
self.successResultOf(get_last)
def test_purge_wont_delete_extrems(self):
"""
Purging a room will delete everything before the topological point.
"""
# Send four messages to the room
first = self.helper.send(self.room_id, body="test1")
second = self.helper.send(self.room_id, body="test2")
third = self.helper.send(self.room_id, body="test3")
last = self.helper.send(self.room_id, body="test4")
storage = self.hs.get_datastore()
# Set the topological token higher than it should be
event = storage.get_topological_token_for_event(last["event_id"])
self.pump()
event = self.successResultOf(event)
event = "t{}-{}".format(
*list(map(lambda x: x + 1, map(int, event[1:].split("-"))))
)
# Purge everything before this topological token
purge = storage.purge_history(self.room_id, event, True)
self.pump()
f = self.failureResultOf(purge)
self.assertIn("greater than forward", f.value.args[0])
# Try and get the events
get_first = storage.get_event(first["event_id"])
get_second = storage.get_event(second["event_id"])
get_third = storage.get_event(third["event_id"])
get_last = storage.get_event(last["event_id"])
self.pump()
# Nothing is deleted.
self.successResultOf(get_first)
self.successResultOf(get_second)
self.successResultOf(get_third)
self.successResultOf(get_last)

View File

@@ -30,7 +30,7 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
self.hs = yield setup_test_homeserver(self.addCleanup)
self.store = UserDirectoryStore(None, self.hs)
self.store = UserDirectoryStore(self.hs.get_db_conn(), self.hs)
# alice and bob are both in !room_id. bobby is not but shares
# a homeserver with alice.

View File

@@ -14,12 +14,12 @@
# limitations under the License.
from synapse.api.errors import SynapseError
from synapse.server import HomeServer
from synapse.types import GroupID, RoomAlias, UserID
from tests import unittest
from tests.utils import TestHomeServer
mock_homeserver = HomeServer(hostname="my.domain")
mock_homeserver = TestHomeServer(hostname="my.domain")
class UserIDTestCase(unittest.TestCase):

View File

@@ -96,7 +96,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
events_to_filter.append(evt)
# the erasey user gets erased
self.hs.get_datastore().mark_user_erased("@erased:local_hs")
yield self.hs.get_datastore().mark_user_erased("@erased:local_hs")
# ... and the filtering happens.
filtered = yield filter_events_for_server(

View File

@@ -22,6 +22,7 @@ from canonicaljson import json
import twisted
import twisted.logger
from twisted.internet.defer import Deferred
from twisted.trial import unittest
from synapse.http.server import JsonResource
@@ -151,6 +152,7 @@ class HomeserverTestCase(TestCase):
hijack_auth (bool): Whether to hijack auth to return the user specified
in user_id.
"""
servlets = []
hijack_auth = True
@@ -279,3 +281,15 @@ class HomeserverTestCase(TestCase):
kwargs = dict(kwargs)
kwargs.update(self._hs_args)
return setup_test_homeserver(self.addCleanup, *args, **kwargs)
def pump(self, by=0.0):
"""
Pump the reactor enough that Deferreds will fire.
"""
self.reactor.pump([by] * 100)
def get_success(self, d):
if not isinstance(d, Deferred):
return d
self.pump()
return self.successResultOf(d)

View File

@@ -26,11 +26,12 @@ from twisted.internet import defer, reactor
from synapse.api.constants import EventTypes
from synapse.api.errors import CodeMessageException, cs_error
from synapse.config.server import ServerConfig
from synapse.federation.transport import server
from synapse.http.server import HttpServer
from synapse.server import HomeServer
from synapse.storage import PostgresEngine
from synapse.storage.engines import create_engine
from synapse.storage import DataStore
from synapse.storage.engines import PostgresEngine, create_engine
from synapse.storage.prepare_database import (
_get_or_create_schema_state,
_setup_new_database,
@@ -41,6 +42,7 @@ from synapse.util.ratelimitutils import FederationRateLimiter
# set this to True to run the tests against postgres instead of sqlite.
USE_POSTGRES_FOR_TESTS = os.environ.get("SYNAPSE_POSTGRES", False)
LEAVE_DB = os.environ.get("SYNAPSE_LEAVE_DB", False)
POSTGRES_USER = os.environ.get("SYNAPSE_POSTGRES_USER", "postgres")
POSTGRES_BASE_DB = "_synapse_unit_tests_base_%s" % (os.getpid(),)
@@ -92,10 +94,14 @@ def setupdb():
atexit.register(_cleanup)
class TestHomeServer(HomeServer):
DATASTORE_CLASS = DataStore
@defer.inlineCallbacks
def setup_test_homeserver(
cleanup_func, name="test", datastore=None, config=None, reactor=None,
homeserverToUse=HomeServer, **kargs
homeserverToUse=TestHomeServer, **kargs
):
"""
Setup a homeserver suitable for running tests against. Keyword arguments
@@ -143,6 +149,8 @@ def setup_test_homeserver(
config.max_mau_value = 50
config.mau_limits_reserved_threepids = []
config.admin_contact = None
config.rc_messages_per_second = 10000
config.rc_message_burst_count = 10000
# we need a sane default_room_version, otherwise attempts to create rooms will
# fail.
@@ -152,6 +160,11 @@ def setup_test_homeserver(
# background, which upsets the test runner.
config.update_user_directory = False
def is_threepid_reserved(threepid):
return ServerConfig.is_threepid_reserved(config, threepid)
config.is_threepid_reserved.side_effect = is_threepid_reserved
config.use_frozen_dicts = True
config.ldap_enabled = False
@@ -232,8 +245,9 @@ def setup_test_homeserver(
cur.close()
db_conn.close()
# Register the cleanup hook
cleanup_func(cleanup)
if not LEAVE_DB:
# Register the cleanup hook
cleanup_func(cleanup)
hs.setup()
else: