Compare commits

31 Commits

| Author | SHA1 | Date |
|---|---|---|
| | a6cf7d9d9a | |
| | 70fc599ede | |
| | bae37cd811 | |
| | c42f7fd7b9 | |
| | 77055dba92 | |
| | 567363e497 | |
| | 06ee2b7cc5 | |
| | 30cf40ff30 | |
| | 905f8de673 | |
| | 4fc4b881c5 | |
| | 99178f8602 | |
| | 301cb60d0b | |
| | 0b01281e77 | |
| | e8e540630e | |
| | a796bdd35e | |
| | 09f3cf1a7e | |
| | 3d6aa06577 | |
| | ea068d6f3c | |
| | 14e4d4f4bf | |
| | 475253a88e | |
| | 7f0399586d | |
| | 8c0c51ecb3 | |
| | 79a8a347a6 | |
| | 82276a18d1 | |
| | b1580f50fe | |
| | 414fa36f3e | |
| | 32eb1dedd2 | |
| | 71990b3cae | |
| | 0b07f02e19 | |
| | 5c261107c9 | |
| | 76c80e3fdf | |

(The author and date columns were not preserved in this capture; only the commit SHA1s survive.)
.travis.yml
@@ -35,10 +35,6 @@ matrix:
   - python: 3.6
     env: TOX_ENV=check-newsfragment
 
-  allow_failures:
-    - python: 2.7
-      env: TOX_ENV=py27-postgres TRIAL_FLAGS="-j 4"
-
 install:
   - pip install tox
CONTRIBUTING.rst
@@ -59,9 +59,10 @@ To create a changelog entry, make a new file in the ``changelog.d``
 file named in the format of ``PRnumber.type``. The type can be
 one of ``feature``, ``bugfix``, ``removal`` (also used for
 deprecations), or ``misc`` (for internal-only changes). The content of
-the file is your changelog entry, which can contain RestructuredText
-formatting. A note of contributors is welcomed in changelogs for
-non-misc changes (the content of misc changes is not displayed).
+the file is your changelog entry, which can contain Markdown
+formatting. Adding credits to the changelog is encouraged, we value
+your contributions and would like to have you shouted out in the
+release notes!
 
 For example, a fix in PR #1234 would have its changelog entry in
 ``changelog.d/1234.bugfix``, and contain content like "The security levels of
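As a concrete illustration of the convention this hunk documents, here is a minimal sketch of creating such an entry; the PR number 1234 and the entry text are purely illustrative, matching the example in the guidelines above.

```python
# Minimal sketch: create a changelog entry named <PRnumber>.<type> under
# changelog.d/. Both the number and the text below are made up.
from pathlib import Path

entry = Path("changelog.d") / "1234.bugfix"
entry.parent.mkdir(exist_ok=True)
entry.write_text("Fixed a bug where ... Thanks to @example-contributor!\n")
```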
README.rst
@@ -167,11 +167,6 @@ Alternatively, Andreas Peters (previously Silvio Fricke) has contributed a
 Dockerfile to automate a synapse server in a single Docker image, at
 https://hub.docker.com/r/avhost/docker-matrix/tags/
 
-Also, Martin Giess has created an auto-deployment process with vagrant/ansible,
-tested with VirtualBox/AWS/DigitalOcean - see
-https://github.com/EMnify/matrix-synapse-auto-deploy
-for details.
-
 Configuring synapse
 -------------------
changelog.d/3378.misc (new file)
@@ -0,0 +1 @@
+Removed the link to the unmaintained matrix-synapse-auto-deploy project from the readme.

changelog.d/3725.misc (new file)
@@ -0,0 +1 @@
+The synapse.storage module has been ported to Python 3.

changelog.d/3730.misc (new file)
@@ -0,0 +1 @@
+The CONTRIBUTING guidelines have been updated to mention our use of Markdown and that .misc files have content.

changelog.d/3737.misc (new file)
@@ -0,0 +1 @@
+Remove redundant state resolution function

changelog.d/3740.misc (new file)
@@ -0,0 +1 @@
+The test suite now passes on PostgreSQL.

changelog.d/3760.bugfix (new file)
@@ -0,0 +1 @@
+Don't return non-LL-member state in incremental sync state blocks

changelog.d/3764.misc (new file)
@@ -0,0 +1 @@
+Make sure that we close db connections opened during init

changelog.d/3768.bugfix (new file)
@@ -0,0 +1 @@
+Fix bug in sending presence over federation

changelog.d/3777.bugfix (new file)
@@ -0,0 +1 @@
+Fix bug where preserved threepid user comes to sign up and server is mau blocked

changelog.d/3789.misc (new file)
@@ -0,0 +1 @@
+Improve human readable error messages for threepid registration/account update
@@ -31,5 +31,5 @@ $TOX_BIN/pip install 'setuptools>=18.5'
 $TOX_BIN/pip install 'pip>=10'
 
 { python synapse/python_dependencies.py
-  echo lxml psycopg2
+  echo lxml
 } | xargs $TOX_BIN/pip install
synapse/api/auth.py
@@ -26,6 +26,7 @@ import synapse.types
 from synapse import event_auth
 from synapse.api.constants import EventTypes, JoinRules, Membership
 from synapse.api.errors import AuthError, Codes, ResourceLimitError
+from synapse.config.server import is_threepid_reserved
 from synapse.types import UserID
 from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache
 from synapse.util.caches.lrucache import LruCache
@@ -775,13 +776,19 @@ class Auth(object):
         )
 
     @defer.inlineCallbacks
-    def check_auth_blocking(self, user_id=None):
+    def check_auth_blocking(self, user_id=None, threepid=None):
         """Checks if the user should be rejected for some external reason,
         such as monthly active user limiting or global disable flag
 
         Args:
             user_id(str|None): If present, checks for presence against existing
                 MAU cohort
+
+            threepid(dict|None): If present, checks for presence against configured
+                reserved threepid. Used in cases where the user is trying register
+                with a MAU blocked server, normally they would be rejected but their
+                threepid is on the reserved list. user_id and
+                threepid should never be set at the same time.
         """
 
         # Never fail an auth check for the server notices users
@@ -797,6 +804,8 @@
                 limit_type=self.hs.config.hs_disabled_limit_type
             )
         if self.hs.config.limit_usage_by_mau is True:
+            assert not (user_id and threepid)
+
             # If the user is already part of the MAU cohort or a trial user
             if user_id:
                 timestamp = yield self.store.user_last_seen_monthly_active(user_id)
@@ -806,12 +815,16 @@
                 is_trial = yield self.store.is_trial_user(user_id)
                 if is_trial:
                     return
+            elif threepid:
+                # If the user does not exist yet, but is signing up with a
+                # reserved threepid then pass auth check
+                if is_threepid_reserved(self.hs.config, threepid):
+                    return
             # Else if there is no room in the MAU bucket, bail
             current_mau = yield self.store.get_monthly_active_count()
             if current_mau >= self.hs.config.max_mau_value:
                 raise ResourceLimitError(
                     403, "Monthly Active User Limit Exceeded",
                     admin_contact=self.hs.config.admin_contact,
                     errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
                     limit_type="monthly_active_user"
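Taken together, the auth.py hunks above give check_auth_blocking the following decision order: existing MAU-cohort or trial users always pass, a new user signing up with a reserved threepid passes, and otherwise the request is rejected once the monthly-active-user bucket is full. The sketch below is a standalone paraphrase of that logic under made-up names, not Synapse's actual code.

```python
# Standalone paraphrase of the blocking decision; all names are illustrative.
def should_block(limit_usage_by_mau, max_mau_value, current_mau,
                 user_is_in_cohort=False, user_is_trial=False,
                 threepid=None, reserved_threepids=()):
    if not limit_usage_by_mau:
        return False
    if user_is_in_cohort or user_is_trial:
        return False          # existing/trial users always pass
    if threepid is not None and threepid in reserved_threepids:
        return False          # reserved threepids bypass the MAU cap
    return current_mau >= max_mau_value  # block only when the bucket is full

reserved = ({"medium": "email", "address": "a@b"},)
assert not should_block(True, 10, 10,
                        threepid={"medium": "email", "address": "a@b"},
                        reserved_threepids=reserved)
assert should_block(True, 10, 10,
                    threepid={"medium": "email", "address": "c@d"},
                    reserved_threepids=reserved)
```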
synapse/app/appservice.py
@@ -51,10 +51,7 @@ class AppserviceSlaveStore(
 
 
 class AppserviceServer(HomeServer):
-    def setup(self):
-        logger.info("Setting up.")
-        self.datastore = AppserviceSlaveStore(self.get_db_conn(), self)
-        logger.info("Finished setting up.")
+    DATASTORE_CLASS = AppserviceSlaveStore
 
     def _listen_http(self, listener_config):
         port = listener_config["port"]
synapse/app/client_reader.py
@@ -74,10 +74,7 @@ class ClientReaderSlavedStore(
 
 
 class ClientReaderServer(HomeServer):
-    def setup(self):
-        logger.info("Setting up.")
-        self.datastore = ClientReaderSlavedStore(self.get_db_conn(), self)
-        logger.info("Finished setting up.")
+    DATASTORE_CLASS = ClientReaderSlavedStore
 
     def _listen_http(self, listener_config):
         port = listener_config["port"]
synapse/app/event_creator.py
@@ -90,10 +90,7 @@ class EventCreatorSlavedStore(
 
 
 class EventCreatorServer(HomeServer):
-    def setup(self):
-        logger.info("Setting up.")
-        self.datastore = EventCreatorSlavedStore(self.get_db_conn(), self)
-        logger.info("Finished setting up.")
+    DATASTORE_CLASS = EventCreatorSlavedStore
 
     def _listen_http(self, listener_config):
         port = listener_config["port"]
synapse/app/federation_reader.py
@@ -72,10 +72,7 @@ class FederationReaderSlavedStore(
 
 
 class FederationReaderServer(HomeServer):
-    def setup(self):
-        logger.info("Setting up.")
-        self.datastore = FederationReaderSlavedStore(self.get_db_conn(), self)
-        logger.info("Finished setting up.")
+    DATASTORE_CLASS = FederationReaderSlavedStore
 
     def _listen_http(self, listener_config):
         port = listener_config["port"]
synapse/app/federation_sender.py
@@ -78,10 +78,7 @@ class FederationSenderSlaveStore(
 
 
 class FederationSenderServer(HomeServer):
-    def setup(self):
-        logger.info("Setting up.")
-        self.datastore = FederationSenderSlaveStore(self.get_db_conn(), self)
-        logger.info("Finished setting up.")
+    DATASTORE_CLASS = FederationSenderSlaveStore
 
     def _listen_http(self, listener_config):
         port = listener_config["port"]
synapse/app/frontend_proxy.py
@@ -148,10 +148,7 @@ class FrontendProxySlavedStore(
 
 
 class FrontendProxyServer(HomeServer):
-    def setup(self):
-        logger.info("Setting up.")
-        self.datastore = FrontendProxySlavedStore(self.get_db_conn(), self)
-        logger.info("Finished setting up.")
+    DATASTORE_CLASS = FrontendProxySlavedStore
 
     def _listen_http(self, listener_config):
         port = listener_config["port"]
synapse/app/homeserver.py
@@ -62,7 +62,7 @@ from synapse.rest.key.v1.server_key_resource import LocalKey
 from synapse.rest.key.v2 import KeyApiV2Resource
 from synapse.rest.media.v0.content_repository import ContentRepoResource
 from synapse.server import HomeServer
-from synapse.storage import are_all_users_on_domain
+from synapse.storage import DataStore, are_all_users_on_domain
 from synapse.storage.engines import IncorrectDatabaseSetup, create_engine
 from synapse.storage.prepare_database import UpgradeDatabaseException, prepare_database
 from synapse.util.caches import CACHE_SIZE_FACTOR
@@ -111,6 +111,8 @@ def build_resource_for_web_client(hs):
 
 
 class SynapseHomeServer(HomeServer):
+    DATASTORE_CLASS = DataStore
+
     def _listener_http(self, config, listener_config):
         port = listener_config["port"]
         bind_addresses = listener_config["bind_addresses"]
@@ -356,13 +358,13 @@ def setup(config_options):
     logger.info("Preparing database: %s...", config.database_config['name'])
 
     try:
-        db_conn = hs.get_db_conn(run_new_connection=False)
-        prepare_database(db_conn, database_engine, config=config)
-        database_engine.on_new_connection(db_conn)
+        with hs.get_db_conn(run_new_connection=False) as db_conn:
+            prepare_database(db_conn, database_engine, config=config)
+            database_engine.on_new_connection(db_conn)
 
-        hs.run_startup_checks(db_conn, database_engine)
+            hs.run_startup_checks(db_conn, database_engine)
 
-        db_conn.commit()
+            db_conn.commit()
     except UpgradeDatabaseException:
         sys.stderr.write(
             "\nFailed to upgrade database.\n"
synapse/app/media_repository.py
@@ -60,10 +60,7 @@ class MediaRepositorySlavedStore(
 
 
 class MediaRepositoryServer(HomeServer):
-    def setup(self):
-        logger.info("Setting up.")
-        self.datastore = MediaRepositorySlavedStore(self.get_db_conn(), self)
-        logger.info("Finished setting up.")
+    DATASTORE_CLASS = MediaRepositorySlavedStore
 
     def _listen_http(self, listener_config):
         port = listener_config["port"]
synapse/app/pusher.py
@@ -78,10 +78,7 @@ class PusherSlaveStore(
 
 
 class PusherServer(HomeServer):
-    def setup(self):
-        logger.info("Setting up.")
-        self.datastore = PusherSlaveStore(self.get_db_conn(), self)
-        logger.info("Finished setting up.")
+    DATASTORE_CLASS = PusherSlaveStore
 
     def remove_pusher(self, app_id, push_key, user_id):
         self.get_tcp_replication().send_remove_pusher(app_id, push_key, user_id)
synapse/app/synchrotron.py
@@ -249,10 +249,7 @@ class SynchrotronApplicationService(object):
 
 
 class SynchrotronServer(HomeServer):
-    def setup(self):
-        logger.info("Setting up.")
-        self.datastore = SynchrotronSlavedStore(self.get_db_conn(), self)
-        logger.info("Finished setting up.")
+    DATASTORE_CLASS = SynchrotronSlavedStore
 
     def _listen_http(self, listener_config):
         port = listener_config["port"]
synapse/app/user_dir.py
@@ -94,10 +94,7 @@ class UserDirectorySlaveStore(
 
 
 class UserDirectoryServer(HomeServer):
-    def setup(self):
-        logger.info("Setting up.")
-        self.datastore = UserDirectorySlaveStore(self.get_db_conn(), self)
-        logger.info("Finished setting up.")
+    DATASTORE_CLASS = UserDirectorySlaveStore
 
     def _listen_http(self, listener_config):
         port = listener_config["port"]
synapse/config/server.py
@@ -404,6 +404,23 @@ class ServerConfig(Config):
                          " service on the given port.")
 
 
+def is_threepid_reserved(config, threepid):
+    """Check the threepid against the reserved threepid config
+    Args:
+        config(ServerConfig) - to access server config attributes
+        threepid(dict) - The threepid to test for
+
+    Returns:
+        boolean Is the threepid undertest reserved_user
+    """
+
+    for tp in config.mau_limits_reserved_threepids:
+        if (threepid['medium'] == tp['medium']
+                and threepid['address'] == tp['address']):
+            return True
+    return False
+
+
 def read_gc_thresholds(thresholds):
     """Reads the three integer thresholds for garbage collection. Ensures that
     the thresholds are integers if thresholds are supplied.
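A hedged usage sketch of the new helper (the import path matches the one used by the other hunks in this compare; SimpleNamespace stands in for a real ServerConfig, and the addresses are invented):

```python
# Assumes synapse is installed; otherwise paste the function body above.
from types import SimpleNamespace

from synapse.config.server import is_threepid_reserved

config = SimpleNamespace(mau_limits_reserved_threepids=[
    {"medium": "email", "address": "support@example.com"},
])

assert is_threepid_reserved(config, {"medium": "email", "address": "support@example.com"})
assert not is_threepid_reserved(config, {"medium": "email", "address": "other@example.com"})
```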
synapse/federation/send_queue.py
@@ -32,7 +32,7 @@ Events are replicated via a separate events stream.
 import logging
 from collections import namedtuple
 
-from six import iteritems, itervalues
+from six import iteritems
 
 from sortedcontainers import SortedDict
@@ -117,7 +117,7 @@ class FederationRemoteSendQueue(object):
 
         user_ids = set(
             user_id
-            for uids in itervalues(self.presence_changed)
+            for uids in self.presence_changed.values()
             for user_id in uids
         )
synapse/handlers/federation.py
@@ -1831,7 +1831,7 @@ class FederationHandler(BaseHandler):
 
         room_version = yield self.store.get_room_version(event.room_id)
 
-        new_state = self.state_handler.resolve_events(
+        new_state = yield self.state_handler.resolve_events(
             room_version,
             [list(local_view.values()), list(remote_view.values())],
             event
synapse/handlers/register.py
@@ -125,6 +125,7 @@ class RegistrationHandler(BaseHandler):
             guest_access_token=None,
             make_guest=False,
             admin=False,
+            threepid=None,
         ):
         """Registers a new client on the server.
@@ -145,7 +146,7 @@ class RegistrationHandler(BaseHandler):
             RegistrationError if there was a problem registering.
         """
 
-        yield self.auth.check_auth_blocking()
+        yield self.auth.check_auth_blocking(threepid=threepid)
         password_hash = None
         if password:
             password_hash = yield self.auth_handler().hash(password)
synapse/handlers/sync.py
@@ -745,9 +745,16 @@ class SyncHandler(object):
             state_ids = {}
             if lazy_load_members:
                 if types:
+                    # We're returning an incremental sync, with no "gap" since
+                    # the previous sync, so normally there would be no state to return
+                    # But we're lazy-loading, so the client might need some more
+                    # member events to understand the events in this timeline.
+                    # So we fish out all the member events corresponding to the
+                    # timeline here, and then dedupe any redundant ones below.
+
                     state_ids = yield self.store.get_state_ids_for_event(
                         batch.events[0].event_id, types=types,
-                        filtered_types=filtered_types,
+                        filtered_types=None,  # we only want members!
                     )
 
             if lazy_load_members and not include_redundant_members:
synapse/python_dependencies.py
@@ -78,6 +78,9 @@ CONDITIONAL_REQUIREMENTS = {
     "affinity": {
         "affinity": ["affinity"],
    },
+    "postgres": {
+        "psycopg2>=2.6": ["psycopg2"]
+    }
 }
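The new "postgres" entry makes psycopg2 an optional, conditional dependency rather than one every install pulls in. A rough, illustrative sketch of how a mapping in this shape can be probed at runtime (not Synapse's actual check):

```python
# Illustrative only: try importing each module an optional feature needs.
import importlib

CONDITIONAL = {"postgres": {"psycopg2>=2.6": ["psycopg2"]}}

def feature_available(name):
    for modules in CONDITIONAL[name].values():
        for mod in modules:
            try:
                importlib.import_module(mod)
            except ImportError:
                return False
    return True

print("postgres support:", feature_available("postgres"))
```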
synapse/rest/client/v1/register.py
@@ -23,6 +23,7 @@ from twisted.internet import defer
 import synapse.util.stringutils as stringutils
 from synapse.api.constants import LoginType
 from synapse.api.errors import Codes, SynapseError
+from synapse.config.server import is_threepid_reserved
 from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request
 from synapse.rest.client.v1.base import ClientV1RestServlet
 from synapse.types import create_requester
@@ -281,12 +282,20 @@ class RegisterRestServlet(ClientV1RestServlet):
             register_json["user"].encode("utf-8")
             if "user" in register_json else None
         )
+        threepid = None
+        if session.get(LoginType.EMAIL_IDENTITY):
+            threepid = session["threepidCreds"]
 
         handler = self.handlers.registration_handler
         (user_id, token) = yield handler.register(
             localpart=desired_user_id,
-            password=password
+            password=password,
+            threepid=threepid,
         )
+        # Necessary due to auth checks prior to the threepid being
+        # written to the db
+        if is_threepid_reserved(self.hs.config, threepid):
+            yield self.store.upsert_monthly_active_user(user_id)
 
         if session[LoginType.EMAIL_IDENTITY]:
             logger.debug("Binding emails %s to %s" % (
synapse/rest/client/v2_alpha/account.py
@@ -53,7 +53,9 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
 
         if not check_3pid_allowed(self.hs, "email", body['email']):
             raise SynapseError(
-                403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+                403,
+                "Your email domain is not authorized on this server",
+                Codes.THREEPID_DENIED,
             )
 
         existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
@@ -89,7 +91,9 @@ class MsisdnPasswordRequestTokenRestServlet(RestServlet):
 
         if not check_3pid_allowed(self.hs, "msisdn", msisdn):
             raise SynapseError(
-                403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+                403,
+                "Account phone numbers are not authorized on this server",
+                Codes.THREEPID_DENIED,
             )
 
         existingUid = yield self.datastore.get_user_id_by_threepid(
@@ -241,7 +245,9 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
 
         if not check_3pid_allowed(self.hs, "email", body['email']):
             raise SynapseError(
-                403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+                403,
+                "Your email domain is not authorized on this server",
+                Codes.THREEPID_DENIED,
             )
 
         existingUid = yield self.datastore.get_user_id_by_threepid(
@@ -276,7 +282,9 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
 
         if not check_3pid_allowed(self.hs, "msisdn", msisdn):
             raise SynapseError(
-                403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+                403,
+                "Account phone numbers are not authorized on this server",
+                Codes.THREEPID_DENIED,
             )
 
         existingUid = yield self.datastore.get_user_id_by_threepid(
synapse/rest/client/v2_alpha/register.py
@@ -26,6 +26,7 @@ import synapse
 import synapse.types
 from synapse.api.constants import LoginType
 from synapse.api.errors import Codes, SynapseError, UnrecognizedRequestError
+from synapse.config.server import is_threepid_reserved
 from synapse.http.servlet import (
     RestServlet,
     assert_params_in_dict,
@@ -74,7 +75,9 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
 
         if not check_3pid_allowed(self.hs, "email", body['email']):
             raise SynapseError(
-                403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+                403,
+                "Your email domain is not authorized to register on this server",
+                Codes.THREEPID_DENIED,
             )
 
         existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
@@ -114,7 +117,9 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
 
         if not check_3pid_allowed(self.hs, "msisdn", msisdn):
             raise SynapseError(
-                403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+                403,
+                "Phone numbers are not authorized to register on this server",
+                Codes.THREEPID_DENIED,
             )
 
         existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
@@ -372,7 +377,9 @@ class RegisterRestServlet(RestServlet):
 
             if not check_3pid_allowed(self.hs, medium, address):
                 raise SynapseError(
-                    403, "Third party identifier is not allowed",
+                    403,
+                    "Third party identifiers (email/phone numbers)" +
+                    " are not authorized on this server",
                     Codes.THREEPID_DENIED,
                 )
@@ -395,12 +402,21 @@ class RegisterRestServlet(RestServlet):
             if desired_username is not None:
                 desired_username = desired_username.lower()
 
+            threepid = None
+            if auth_result:
+                threepid = auth_result.get(LoginType.EMAIL_IDENTITY)
+
             (registered_user_id, _) = yield self.registration_handler.register(
                 localpart=desired_username,
                 password=new_password,
                 guest_access_token=guest_access_token,
                 generate_token=False,
+                threepid=threepid,
             )
+            # Necessary due to auth checks prior to the threepid being
+            # written to the db
+            if is_threepid_reserved(self.hs.config, threepid):
+                yield self.store.upsert_monthly_active_user(registered_user_id)
 
             # remember that we've now registered that user account, and with
             # what user ID (since the user may not have specified)
synapse/server.py
@@ -19,6 +19,7 @@
 # partial one for unit test mocking.
 
 # Imports required for the default HomeServer() implementation
+import abc
 import logging
 
 from twisted.enterprise import adbapi
@@ -81,7 +82,6 @@ from synapse.server_notices.server_notices_manager import ServerNoticesManager
 from synapse.server_notices.server_notices_sender import ServerNoticesSender
 from synapse.server_notices.worker_server_notices_sender import WorkerServerNoticesSender
 from synapse.state import StateHandler, StateResolutionHandler
-from synapse.storage import DataStore
 from synapse.streams.events import EventSources
 from synapse.util import Clock
 from synapse.util.distributor import Distributor
@@ -111,6 +111,8 @@ class HomeServer(object):
         config (synapse.config.homeserver.HomeserverConfig):
     """
 
+    __metaclass__ = abc.ABCMeta
+
     DEPENDENCIES = [
         'http_client',
         'db_pool',
@@ -172,6 +174,11 @@ class HomeServer(object):
         'room_context_handler',
     ]
 
+    # This is overridden in derived application classes
+    # (such as synapse.app.homeserver.SynapseHomeServer) and gives the class to be
+    # instantiated during setup() for future return by get_datastore()
+    DATASTORE_CLASS = abc.abstractproperty()
+
     def __init__(self, hostname, reactor=None, **kwargs):
         """
         Args:
@@ -188,13 +195,16 @@ class HomeServer(object):
         self.distributor = Distributor()
         self.ratelimiter = Ratelimiter()
 
+        self.datastore = None
+
         # Other kwargs are explicit dependencies
         for depname in kwargs:
             setattr(self, depname, kwargs[depname])
 
     def setup(self):
         logger.info("Setting up.")
-        self.datastore = DataStore(self.get_db_conn(), self)
+        with self.get_db_conn() as conn:
+            self.datastore = self.DATASTORE_CLASS(conn, self)
         logger.info("Finished setting up.")
 
     def get_reactor(self):
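The server.py hunks above introduce the pattern that the worker-app hunks earlier in this compare apply: the base class declares an abstract DATASTORE_CLASS attribute and setup() instantiates whatever store the subclass supplies. A minimal, self-contained sketch of the same pattern with invented class names:

```python
# Minimal sketch of the abstract-class-attribute pattern; all names invented.
import abc


class Base(object):
    __metaclass__ = abc.ABCMeta          # py2 spelling, as in the diff

    DATASTORE_CLASS = abc.abstractproperty()

    def setup(self, conn):
        # Instantiates whichever store class the subclass declared.
        self.datastore = self.DATASTORE_CLASS(conn)


class FakeStore(object):
    def __init__(self, conn):
        self.conn = conn


class Worker(Base):
    DATASTORE_CLASS = FakeStore


w = Worker()
w.setup(conn="dummy-connection")
assert isinstance(w.datastore, FakeStore)
```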
synapse/state/__init__.py
@@ -385,6 +385,7 @@ class StateHandler(object):
             ev_ids, get_prev_content=False, check_redacted=False,
         )
 
+    @defer.inlineCallbacks
     def resolve_events(self, room_version, state_sets, event):
         logger.info(
             "Resolving state for %s with %d groups", event.room_id, len(state_sets)
@@ -401,15 +402,17 @@
         }
 
         with Measure(self.clock, "state._resolve_events"):
-            new_state = resolve_events_with_state_map(
-                room_version, state_set_ids, state_map,
+            new_state = yield resolve_events_with_factory(
+                room_version, state_set_ids,
+                event_map=state_map,
+                state_map_factory=self._state_map_factory
             )
 
         new_state = {
             key: state_map[ev_id] for key, ev_id in iteritems(new_state)
         }
 
-        return new_state
+        defer.returnValue(new_state)
 
 
 class StateResolutionHandler(object):
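The return new_state → defer.returnValue(new_state) change above follows from resolve_events becoming a generator under @defer.inlineCallbacks: Python 2 generators cannot return a value. A tiny runnable sketch, assuming Twisted is installed (the function name and payload are illustrative):

```python
from twisted.internet import defer


@defer.inlineCallbacks
def resolve(state):
    # Stands in for the deferred-returning resolve_events_with_factory call.
    resolved = yield defer.succeed(sorted(state))
    defer.returnValue(resolved)  # a plain `return resolved` breaks on py2


d = resolve(["b", "a"])
d.addCallback(print)  # prints ['a', 'b'] (defer.succeed fires synchronously)
```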
@@ -589,31 +592,6 @@ def _make_state_cache_entry(
     )
 
 
-def resolve_events_with_state_map(room_version, state_sets, state_map):
-    """
-    Args:
-        room_version(str): Version of the room
-        state_sets(list): List of dicts of (type, state_key) -> event_id,
-            which are the different state groups to resolve.
-        state_map(dict): a dict from event_id to event, for all events in
-            state_sets.
-
-    Returns
-        dict[(str, str), str]:
-            a map from (type, state_key) to event_id.
-    """
-    if room_version in (RoomVersions.V1, RoomVersions.VDH_TEST,):
-        return v1.resolve_events_with_state_map(
-            state_sets, state_map,
-        )
-    else:
-        # This should only happen if we added a version but forgot to add it to
-        # the list above.
-        raise Exception(
-            "No state resolution algorithm defined for version %r" % (room_version,)
-        )
-
-
 def resolve_events_with_factory(room_version, state_sets, event_map, state_map_factory):
     """
     Args:
synapse/state/v1.py
@@ -30,34 +30,6 @@ logger = logging.getLogger(__name__)
 POWER_KEY = (EventTypes.PowerLevels, "")
 
 
-def resolve_events_with_state_map(state_sets, state_map):
-    """
-    Args:
-        state_sets(list): List of dicts of (type, state_key) -> event_id,
-            which are the different state groups to resolve.
-        state_map(dict): a dict from event_id to event, for all events in
-            state_sets.
-
-    Returns
-        dict[(str, str), str]:
-            a map from (type, state_key) to event_id.
-    """
-    if len(state_sets) == 1:
-        return state_sets[0]
-
-    unconflicted_state, conflicted_state = _seperate(
-        state_sets,
-    )
-
-    auth_events = _create_auth_events_from_maps(
-        unconflicted_state, conflicted_state, state_map
-    )
-
-    return _resolve_with_state(
-        unconflicted_state, conflicted_state, auth_events, state_map
-    )
-
-
 @defer.inlineCallbacks
 def resolve_events_with_factory(state_sets, event_map, state_map_factory):
     """
synapse/storage/_base.py
@@ -17,9 +17,10 @@ import sys
 import threading
 import time
 
-from six import iteritems, iterkeys, itervalues
+from six import PY2, iteritems, iterkeys, itervalues
 from six.moves import intern, range
 
 from canonicaljson import json
 from prometheus_client import Histogram
 
 from twisted.internet import defer
@@ -1216,3 +1217,32 @@ class _RollbackButIsFineException(Exception):
     something went wrong.
     """
     pass
+
+
+def db_to_json(db_content):
+    """
+    Take some data from a database row and return a JSON-decoded object.
+
+    Args:
+        db_content (memoryview|buffer|bytes|bytearray|unicode)
+    """
+    # psycopg2 on Python 3 returns memoryview objects, which we need to
+    # cast to bytes to decode
+    if isinstance(db_content, memoryview):
+        db_content = db_content.tobytes()
+
+    # psycopg2 on Python 2 returns buffer objects, which we need to cast to
+    # bytes to decode
+    if PY2 and isinstance(db_content, buffer):
+        db_content = bytes(db_content)
+
+    # Decode it to a Unicode string before feeding it to json.loads, so we
+    # consistenty get a Unicode-containing object out.
+    if isinstance(db_content, (bytes, bytearray)):
+        db_content = db_content.decode('utf8')
+
+    try:
+        return json.loads(db_content)
+    except Exception:
+        logging.warning("Tried to decode '%r' as JSON and failed", db_content)
+        raise
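A behaviour sketch for the new db_to_json helper, reimplemented here for Python 3 only so it runs standalone; it shows the normalisation the helper performs on driver-specific binary types before JSON decoding:

```python
import json


def db_to_json_demo(db_content):
    # Same normalisation steps as the helper above (py3-only sketch).
    if isinstance(db_content, memoryview):
        db_content = db_content.tobytes()
    if isinstance(db_content, (bytes, bytearray)):
        db_content = db_content.decode("utf8")
    return json.loads(db_content)


assert db_to_json_demo(b'{"a": 1}') == {"a": 1}
assert db_to_json_demo(memoryview(b'{"a": 1}')) == {"a": 1}
assert db_to_json_demo('{"a": 1}') == {"a": 1}
```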
synapse/storage/deviceinbox.py
@@ -169,7 +169,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
         local_by_user_then_device = {}
         for user_id, messages_by_device in messages_by_user_then_device.items():
             messages_json_for_user = {}
-            devices = messages_by_device.keys()
+            devices = list(messages_by_device.keys())
             if len(devices) == 1 and devices[0] == "*":
                 # Handle wildcard device_ids.
                 sql = (
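The list(...) wrapper above is a Python 3 fix: dict.keys() returns a view there, which cannot be indexed. A two-line demonstration:

```python
d = {"*": ["msg"]}
devices = list(d.keys())  # on py3, d.keys()[0] would raise TypeError
assert len(devices) == 1 and devices[0] == "*"
```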
synapse/storage/devices.py
@@ -24,7 +24,7 @@ from synapse.api.errors import StoreError
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
 
-from ._base import Cache, SQLBaseStore
+from ._base import Cache, SQLBaseStore, db_to_json
 
 logger = logging.getLogger(__name__)
@@ -411,7 +411,7 @@ class DeviceStore(SQLBaseStore):
         if device is not None:
             key_json = device.get("key_json", None)
             if key_json:
-                result["keys"] = json.loads(key_json)
+                result["keys"] = db_to_json(key_json)
             device_display_name = device.get("device_display_name", None)
             if device_display_name:
                 result["device_display_name"] = device_display_name
@@ -466,7 +466,7 @@ class DeviceStore(SQLBaseStore):
             retcol="content",
             desc="_get_cached_user_device",
         )
-        defer.returnValue(json.loads(content))
+        defer.returnValue(db_to_json(content))
 
     @cachedInlineCallbacks()
     def _get_cached_devices_for_user(self, user_id):
@@ -479,7 +479,7 @@ class DeviceStore(SQLBaseStore):
             desc="_get_cached_devices_for_user",
         )
         defer.returnValue({
-            device["device_id"]: json.loads(device["content"])
+            device["device_id"]: db_to_json(device["content"])
             for device in devices
         })
@@ -511,7 +511,7 @@ class DeviceStore(SQLBaseStore):
 
                 key_json = device.get("key_json", None)
                 if key_json:
-                    result["keys"] = json.loads(key_json)
+                    result["keys"] = db_to_json(key_json)
                 device_display_name = device.get("device_display_name", None)
                 if device_display_name:
                     result["device_display_name"] = device_display_name
synapse/storage/end_to_end_keys.py
@@ -14,13 +14,13 @@
 # limitations under the License.
 from six import iteritems
 
-from canonicaljson import encode_canonical_json, json
+from canonicaljson import encode_canonical_json
 
 from twisted.internet import defer
 
 from synapse.util.caches.descriptors import cached
 
-from ._base import SQLBaseStore
+from ._base import SQLBaseStore, db_to_json
 
 
 class EndToEndKeyStore(SQLBaseStore):
@@ -90,7 +90,7 @@ class EndToEndKeyStore(SQLBaseStore):
 
         for user_id, device_keys in iteritems(results):
             for device_id, device_info in iteritems(device_keys):
-                device_info["keys"] = json.loads(device_info.pop("key_json"))
+                device_info["keys"] = db_to_json(device_info.pop("key_json"))
 
         defer.returnValue(results)
synapse/storage/engines/postgres.py
@@ -41,13 +41,18 @@ class PostgresEngine(object):
         db_conn.set_isolation_level(
             self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
         )
 
+        # Set the bytea output to escape, vs the default of hex
+        cursor = db_conn.cursor()
+        cursor.execute("SET bytea_output TO escape")
+
         # Asynchronous commit, don't wait for the server to call fsync before
         # ending the transaction.
         # https://www.postgresql.org/docs/current/static/wal-async-commit.html
         if not self.synchronous_commit:
-            cursor = db_conn.cursor()
             cursor.execute("SET synchronous_commit TO OFF")
-            cursor.close()
+
+        cursor.close()
 
     def is_deadlock(self, error):
         if isinstance(error, self.module.DatabaseError):
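For reference, the same per-connection session settings issued with plain psycopg2 outside Synapse; the DSN is a placeholder and this assumes a reachable PostgreSQL server:

```python
import psycopg2

conn = psycopg2.connect("dbname=synapse user=synapse")  # placeholder DSN
cursor = conn.cursor()
cursor.execute("SET bytea_output TO escape")     # return bytea as escaped text
cursor.execute("SET synchronous_commit TO OFF")  # don't wait for fsync per txn
cursor.close()
```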
synapse/storage/events.py
@@ -19,7 +19,7 @@ import logging
 from collections import OrderedDict, deque, namedtuple
 from functools import wraps
 
-from six import iteritems
+from six import iteritems, text_type
 from six.moves import range
 
 from canonicaljson import json
@@ -1220,7 +1220,7 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
                 "sender": event.sender,
                 "contains_url": (
                     "url" in event.content
-                    and isinstance(event.content["url"], basestring)
+                    and isinstance(event.content["url"], text_type)
                 ),
             }
             for event, _ in events_and_contexts
@@ -1529,7 +1529,7 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
 
                 contains_url = "url" in content
                 if contains_url:
-                    contains_url &= isinstance(content["url"], basestring)
+                    contains_url &= isinstance(content["url"], text_type)
             except (KeyError, AttributeError):
                 # If the event is missing a necessary field then
                 # skip over it.
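basestring does not exist on Python 3; six.text_type is unicode on Python 2 and str on Python 3, so the isinstance checks above work on both. Quick demonstration:

```python
from six import text_type

assert isinstance(u"https://example.com/img.png", text_type)
assert not isinstance(b"not text", text_type)
```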
@@ -1910,9 +1910,9 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
             (room_id,)
         )
         rows = txn.fetchall()
-        max_depth = max(row[0] for row in rows)
+        max_depth = max(row[1] for row in rows)
 
-        if max_depth <= token.topological:
+        if max_depth < token.topological:
             # We need to ensure we don't delete all the events from the database
             # otherwise we wouldn't be able to send any events (due to not
             # having any backwards extremeties)
synapse/storage/events_worker.py
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import itertools
 import logging
 from collections import namedtuple

@@ -265,7 +266,7 @@ class EventsWorkerStore(SQLBaseStore):
         """
         with Measure(self._clock, "_fetch_event_list"):
             try:
-                event_id_lists = zip(*event_list)[0]
+                event_id_lists = list(zip(*event_list))[0]
                 event_ids = [
                     item for sublist in event_id_lists for item in sublist
                 ]
@@ -299,14 +300,14 @@ class EventsWorkerStore(SQLBaseStore):
                 logger.exception("do_fetch")
 
             # We only want to resolve deferreds from the main thread
-            def fire(evs):
+            def fire(evs, exc):
                 for _, d in evs:
                     if not d.called:
                         with PreserveLoggingContext():
-                            d.errback(e)
+                            d.errback(exc)
 
             with PreserveLoggingContext():
-                self.hs.get_reactor().callFromThread(fire, event_list)
+                self.hs.get_reactor().callFromThread(fire, event_list, e)
 
     @defer.inlineCallbacks
     def _enqueue_events(self, events, check_redacted=True, allow_rejected=False):
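The reworked fire() takes the exception as a parameter because on Python 3 the `except ... as e` name is unbound when the except block ends, so a callback that closed over `e` and ran later would raise NameError. A minimal sketch of the safe pattern:

```python
def make_callback():
    try:
        raise ValueError("boom")
    except Exception as e:
        captured = e  # rebind explicitly; `e` itself is unbound after the block

    def fire(exc=captured):  # pass the exception in, as the diff does
        return repr(exc)

    return fire


print(make_callback()())  # "ValueError('boom')"
```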
synapse/storage/filtering.py
@@ -13,14 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from canonicaljson import encode_canonical_json, json
+from canonicaljson import encode_canonical_json
 
 from twisted.internet import defer
 
 from synapse.api.errors import Codes, SynapseError
 from synapse.util.caches.descriptors import cachedInlineCallbacks
 
-from ._base import SQLBaseStore
+from ._base import SQLBaseStore, db_to_json
 
 
 class FilteringStore(SQLBaseStore):
@@ -44,7 +44,7 @@ class FilteringStore(SQLBaseStore):
             desc="get_user_filter",
         )
 
-        defer.returnValue(json.loads(bytes(def_json).decode("utf-8")))
+        defer.returnValue(db_to_json(def_json))
 
     def add_user_filter(self, user_localpart, user_filter):
         def_json = encode_canonical_json(user_filter)
synapse/storage/monthly_active_users.py
@@ -36,7 +36,6 @@ class MonthlyActiveUsersStore(SQLBaseStore):
 
     @defer.inlineCallbacks
     def initialise_reserved_users(self, threepids):
-        # TODO Why can't I do this in init?
         store = self.hs.get_datastore()
         reserved_user_list = []
synapse/storage/pusher.py
@@ -15,7 +15,8 @@
 # limitations under the License.
 
 import logging
-import types
+
+import six
 
 from canonicaljson import encode_canonical_json, json

@@ -27,6 +28,11 @@ from ._base import SQLBaseStore
 
 logger = logging.getLogger(__name__)
 
+if six.PY2:
+    db_binary_type = buffer
+else:
+    db_binary_type = memoryview
+
 
 class PusherWorkerStore(SQLBaseStore):
     def _decode_pushers_rows(self, rows):
@@ -34,18 +40,18 @@ class PusherWorkerStore(SQLBaseStore):
             dataJson = r['data']
             r['data'] = None
             try:
-                if isinstance(dataJson, types.BufferType):
+                if isinstance(dataJson, db_binary_type):
                     dataJson = str(dataJson).decode("UTF8")
 
                 r['data'] = json.loads(dataJson)
             except Exception as e:
                 logger.warn(
                     "Invalid JSON in data for pusher %d: %s, %s",
-                    r['id'], dataJson, e.message,
+                    r['id'], dataJson, e.args[0],
                 )
                 pass
 
-            if isinstance(r['pushkey'], types.BufferType):
+            if isinstance(r['pushkey'], db_binary_type):
                 r['pushkey'] = str(r['pushkey']).decode("UTF8")
 
         return rows
synapse/storage/transactions.py
@@ -18,14 +18,14 @@ from collections import namedtuple
 
 import six
 
-from canonicaljson import encode_canonical_json, json
+from canonicaljson import encode_canonical_json
 
 from twisted.internet import defer
 
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.util.caches.descriptors import cached
 
-from ._base import SQLBaseStore
+from ._base import SQLBaseStore, db_to_json
 
 # py2 sqlite has buffer hardcoded as only binary type, so we must use it,
 # despite being deprecated and removed in favor of memoryview

@@ -95,7 +95,8 @@ class TransactionStore(SQLBaseStore):
         )
 
         if result and result["response_code"]:
-            return result["response_code"], json.loads(str(result["response_json"]))
+            return result["response_code"], db_to_json(result["response_json"])
+
         else:
             return None
tests/api/test_auth.py
@@ -467,6 +467,23 @@ class AuthTestCase(unittest.TestCase):
         )
         yield self.auth.check_auth_blocking()
 
+    @defer.inlineCallbacks
+    def test_reserved_threepid(self):
+        self.hs.config.limit_usage_by_mau = True
+        self.hs.config.max_mau_value = 1
+        threepid = {'medium': 'email', 'address': 'reserved@server.com'}
+        unknown_threepid = {'medium': 'email', 'address': 'unreserved@server.com'}
+        self.hs.config.mau_limits_reserved_threepids = [threepid]
+
+        yield self.store.register(user_id='user1', token="123", password_hash=None)
+        with self.assertRaises(ResourceLimitError):
+            yield self.auth.check_auth_blocking()
+
+        with self.assertRaises(ResourceLimitError):
+            yield self.auth.check_auth_blocking(threepid=unknown_threepid)
+
+        yield self.auth.check_auth_blocking(threepid=threepid)
+
     @defer.inlineCallbacks
     def test_hs_disabled(self):
         self.hs.config.hs_disabled = True
tests/handlers/test_device.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,79 +14,79 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from twisted.internet import defer
-
 import synapse.api.errors
 import synapse.handlers.device
 import synapse.storage
 
-from tests import unittest, utils
+from tests import unittest
 
 user1 = "@boris:aaa"
 user2 = "@theresa:bbb"
 
 
-class DeviceTestCase(unittest.TestCase):
-    def __init__(self, *args, **kwargs):
-        super(DeviceTestCase, self).__init__(*args, **kwargs)
-        self.store = None  # type: synapse.storage.DataStore
-        self.handler = None  # type: synapse.handlers.device.DeviceHandler
-        self.clock = None  # type: utils.MockClock
-
-    @defer.inlineCallbacks
-    def setUp(self):
-        hs = yield utils.setup_test_homeserver(self.addCleanup)
+class DeviceTestCase(unittest.HomeserverTestCase):
+    def make_homeserver(self, reactor, clock):
+        hs = self.setup_test_homeserver("server", http_client=None)
         self.handler = hs.get_device_handler()
         self.store = hs.get_datastore()
-        self.clock = hs.get_clock()
+        return hs
 
+    def prepare(self, reactor, clock, hs):
+        # These tests assume that it starts 1000 seconds in.
+        self.reactor.advance(1000)
 
-    @defer.inlineCallbacks
     def test_device_is_created_if_doesnt_exist(self):
-        res = yield self.handler.check_device_registered(
-            user_id="@boris:foo",
-            device_id="fco",
-            initial_device_display_name="display name",
+        res = self.get_success(
+            self.handler.check_device_registered(
+                user_id="@boris:foo",
+                device_id="fco",
+                initial_device_display_name="display name",
+            )
         )
         self.assertEqual(res, "fco")
 
-        dev = yield self.handler.store.get_device("@boris:foo", "fco")
+        dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco"))
         self.assertEqual(dev["display_name"], "display name")
 
-    @defer.inlineCallbacks
     def test_device_is_preserved_if_exists(self):
-        res1 = yield self.handler.check_device_registered(
-            user_id="@boris:foo",
-            device_id="fco",
-            initial_device_display_name="display name",
+        res1 = self.get_success(
+            self.handler.check_device_registered(
+                user_id="@boris:foo",
+                device_id="fco",
+                initial_device_display_name="display name",
+            )
         )
         self.assertEqual(res1, "fco")
 
-        res2 = yield self.handler.check_device_registered(
-            user_id="@boris:foo",
-            device_id="fco",
-            initial_device_display_name="new display name",
+        res2 = self.get_success(
+            self.handler.check_device_registered(
+                user_id="@boris:foo",
+                device_id="fco",
+                initial_device_display_name="new display name",
+            )
        )
         self.assertEqual(res2, "fco")
 
-        dev = yield self.handler.store.get_device("@boris:foo", "fco")
+        dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco"))
         self.assertEqual(dev["display_name"], "display name")
 
-    @defer.inlineCallbacks
     def test_device_id_is_made_up_if_unspecified(self):
-        device_id = yield self.handler.check_device_registered(
-            user_id="@theresa:foo",
-            device_id=None,
-            initial_device_display_name="display",
+        device_id = self.get_success(
+            self.handler.check_device_registered(
+                user_id="@theresa:foo",
+                device_id=None,
+                initial_device_display_name="display",
+            )
         )
 
-        dev = yield self.handler.store.get_device("@theresa:foo", device_id)
+        dev = self.get_success(self.handler.store.get_device("@theresa:foo", device_id))
         self.assertEqual(dev["display_name"], "display")
 
-    @defer.inlineCallbacks
     def test_get_devices_by_user(self):
-        yield self._record_users()
+        self._record_users()
+
+        res = self.get_success(self.handler.get_devices_by_user(user1))
 
-        res = yield self.handler.get_devices_by_user(user1)
         self.assertEqual(3, len(res))
         device_map = {d["device_id"]: d for d in res}
         self.assertDictContainsSubset(

@@ -119,11 +120,10 @@ class DeviceTestCase(unittest.TestCase):
             device_map["abc"],
         )
 
-    @defer.inlineCallbacks
     def test_get_device(self):
-        yield self._record_users()
+        self._record_users()
 
-        res = yield self.handler.get_device(user1, "abc")
+        res = self.get_success(self.handler.get_device(user1, "abc"))
         self.assertDictContainsSubset(
             {
                 "user_id": user1,

@@ -135,59 +135,66 @@ class DeviceTestCase(unittest.TestCase):
             res,
         )
 
-    @defer.inlineCallbacks
     def test_delete_device(self):
-        yield self._record_users()
+        self._record_users()
 
         # delete the device
-        yield self.handler.delete_device(user1, "abc")
+        self.get_success(self.handler.delete_device(user1, "abc"))
 
         # check the device was deleted
-        with self.assertRaises(synapse.api.errors.NotFoundError):
-            yield self.handler.get_device(user1, "abc")
+        res = self.handler.get_device(user1, "abc")
+        self.pump()
+        self.assertIsInstance(
+            self.failureResultOf(res).value, synapse.api.errors.NotFoundError
+        )
 
         # we'd like to check the access token was invalidated, but that's a
         # bit of a PITA.
 
-    @defer.inlineCallbacks
     def test_update_device(self):
-        yield self._record_users()
+        self._record_users()
 
         update = {"display_name": "new display"}
-        yield self.handler.update_device(user1, "abc", update)
+        self.get_success(self.handler.update_device(user1, "abc", update))
 
-        res = yield self.handler.get_device(user1, "abc")
+        res = self.get_success(self.handler.get_device(user1, "abc"))
         self.assertEqual(res["display_name"], "new display")
 
-    @defer.inlineCallbacks
     def test_update_unknown_device(self):
         update = {"display_name": "new_display"}
-        with self.assertRaises(synapse.api.errors.NotFoundError):
-            yield self.handler.update_device("user_id", "unknown_device_id", update)
+        res = self.handler.update_device("user_id", "unknown_device_id", update)
+        self.pump()
+        self.assertIsInstance(
+            self.failureResultOf(res).value, synapse.api.errors.NotFoundError
+        )
 
-    @defer.inlineCallbacks
     def _record_users(self):
         # check this works for both devices which have a recorded client_ip,
         # and those which don't.
-        yield self._record_user(user1, "xyz", "display 0")
-        yield self._record_user(user1, "fco", "display 1", "token1", "ip1")
-        yield self._record_user(user1, "abc", "display 2", "token2", "ip2")
-        yield self._record_user(user1, "abc", "display 2", "token3", "ip3")
+        self._record_user(user1, "xyz", "display 0")
+        self._record_user(user1, "fco", "display 1", "token1", "ip1")
+        self._record_user(user1, "abc", "display 2", "token2", "ip2")
+        self._record_user(user1, "abc", "display 2", "token3", "ip3")
 
-        yield self._record_user(user2, "def", "dispkay", "token4", "ip4")
+        self._record_user(user2, "def", "dispkay", "token4", "ip4")
 
+        self.reactor.advance(10000)
 
-    @defer.inlineCallbacks
     def _record_user(
         self, user_id, device_id, display_name, access_token=None, ip=None
     ):
-        device_id = yield self.handler.check_device_registered(
-            user_id=user_id,
-            device_id=device_id,
-            initial_device_display_name=display_name,
+        device_id = self.get_success(
+            self.handler.check_device_registered(
+                user_id=user_id,
+                device_id=device_id,
+                initial_device_display_name=display_name,
+            )
        )
 
         if ip is not None:
-            yield self.store.insert_client_ip(
-                user_id, access_token, ip, "user_agent", device_id
+            self.get_success(
+                self.store.insert_client_ip(
+                    user_id, access_token, ip, "user_agent", device_id
+                )
             )
-        self.clock.advance_time(1000)
+        self.reactor.advance(1000)
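The refactor above replaces inlineCallbacks-style tests with a get_success() helper from Synapse's HomeserverTestCase, which runs the fake reactor and extracts a Deferred's result synchronously. A generic analogue using only stock Twisted, for readers without the Synapse test framework:

```python
from twisted.internet import defer
from twisted.trial import unittest


class ExampleTestCase(unittest.SynchronousTestCase):
    def test_deferred_result(self):
        d = defer.succeed("fco")
        # successResultOf extracts an already-fired Deferred's result,
        # failing the test if it has not fired or fired with an error.
        self.assertEqual(self.successResultOf(d), "fco")
```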
tests/replication/slave/storage/_base.py
@@ -1,4 +1,5 @@
 # Copyright 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -11,89 +12,91 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import tempfile
 
 from mock import Mock, NonCallableMock
 
-from twisted.internet import defer, reactor
-from twisted.internet.defer import Deferred
+import attr
 
 from synapse.replication.tcp.client import (
     ReplicationClientFactory,
     ReplicationClientHandler,
 )
 from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
-from synapse.util.logcontext import PreserveLoggingContext, make_deferred_yieldable
 
 from tests import unittest
-from tests.utils import setup_test_homeserver
 
 
-class TestReplicationClientHandler(ReplicationClientHandler):
-    """Overrides on_rdata so that we can wait for it to happen"""
-
-    def __init__(self, store):
-        super(TestReplicationClientHandler, self).__init__(store)
-        self._rdata_awaiters = []
-
-    def await_replication(self):
-        d = Deferred()
-        self._rdata_awaiters.append(d)
-        return make_deferred_yieldable(d)
-
-    def on_rdata(self, stream_name, token, rows):
-        awaiters = self._rdata_awaiters
-        self._rdata_awaiters = []
-        super(TestReplicationClientHandler, self).on_rdata(stream_name, token, rows)
-        with PreserveLoggingContext():
-            for a in awaiters:
-                a.callback(None)
-
-
-class BaseSlavedStoreTestCase(unittest.TestCase):
-    @defer.inlineCallbacks
-    def setUp(self):
-        self.hs = yield setup_test_homeserver(
-            self.addCleanup,
+class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
+    def make_homeserver(self, reactor, clock):
+
+        hs = self.setup_test_homeserver(
+            "blue",
             http_client=None,
             federation_client=Mock(),
             ratelimiter=NonCallableMock(spec_set=["send_message"]),
         )
-        self.hs.get_ratelimiter().send_message.return_value = (True, 0)
+
+        hs.get_ratelimiter().send_message.return_value = (True, 0)
+
+        return hs
+
+    def prepare(self, reactor, clock, hs):
 
         self.master_store = self.hs.get_datastore()
         self.slaved_store = self.STORE_TYPE(self.hs.get_db_conn(), self.hs)
         self.event_id = 0
 
         server_factory = ReplicationStreamProtocolFactory(self.hs)
-        # XXX: mktemp is unsafe and should never be used. but we're just a test.
-        path = tempfile.mktemp(prefix="base_slaved_store_test_case_socket")
-        listener = reactor.listenUNIX(path, server_factory)
-        self.addCleanup(listener.stopListening)
         self.streamer = server_factory.streamer
 
-        self.replication_handler = TestReplicationClientHandler(self.slaved_store)
+        self.replication_handler = ReplicationClientHandler(self.slaved_store)
         client_factory = ReplicationClientFactory(
             self.hs, "client_name", self.replication_handler
         )
-        client_connector = reactor.connectUNIX(path, client_factory)
-        self.addCleanup(client_factory.stopTrying)
-        self.addCleanup(client_connector.disconnect)
+
+        server = server_factory.buildProtocol(None)
+        client = client_factory.buildProtocol(None)
+
+        @attr.s
+        class FakeTransport(object):
+
+            other = attr.ib()
+            disconnecting = False
+            buffer = attr.ib(default=b'')
+
+            def registerProducer(self, producer, streaming):
+                self.producer = producer
+
+                def _produce():
+                    self.producer.resumeProducing()
+                    reactor.callLater(0.1, _produce)
+
+                reactor.callLater(0.0, _produce)
+
+            def write(self, byt):
+                self.buffer = self.buffer + byt
+
+                if getattr(self.other, "transport") is not None:
+                    self.other.dataReceived(self.buffer)
+                    self.buffer = b""
+
+            def writeSequence(self, seq):
+                for x in seq:
+                    self.write(x)
+
+        client.makeConnection(FakeTransport(server))
+        server.makeConnection(FakeTransport(client))
 
     def replicate(self):
         """Tell the master side of replication that something has happened, and then
         wait for the replication to occur.
         """
-        # xxx: should we be more specific in what we wait for?
-        d = self.replication_handler.await_replication()
         self.streamer.on_notifier_poke()
-        return d
+        self.pump(0.1)
 
-    @defer.inlineCallbacks
     def check(self, method, args, expected_result=None):
-        master_result = yield getattr(self.master_store, method)(*args)
-        slaved_result = yield getattr(self.slaved_store, method)(*args)
+        master_result = self.get_success(getattr(self.master_store, method)(*args))
+        slaved_result = self.get_success(getattr(self.slaved_store, method)(*args))
         if expected_result is not None:
             self.assertEqual(master_result, expected_result)
             self.assertEqual(slaved_result, expected_result)
tests/replication/slave/storage/test_account_data.py
@@ -12,9 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
-from twisted.internet import defer
-
 from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
 
 from ._base import BaseSlavedStoreTestCase

@@ -27,16 +24,19 @@ class SlavedAccountDataStoreTestCase(BaseSlavedStoreTestCase):
 
     STORE_TYPE = SlavedAccountDataStore
 
-    @defer.inlineCallbacks
     def test_user_account_data(self):
-        yield self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 1})
-        yield self.replicate()
-        yield self.check(
+        self.get_success(
+            self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 1})
+        )
+        self.replicate()
+        self.check(
             "get_global_account_data_by_type_for_user", [TYPE, USER_ID], {"a": 1}
         )
 
-        yield self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 2})
-        yield self.replicate()
-        yield self.check(
+        self.get_success(
+            self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 2})
+        )
+        self.replicate()
+        self.check(
             "get_global_account_data_by_type_for_user", [TYPE, USER_ID], {"a": 2}
         )
@@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from twisted.internet import defer

from synapse.events import FrozenEvent, _EventInternalMetadata
from synapse.events.snapshot import EventContext
from synapse.replication.slave.storage.events import SlavedEventStore

@@ -55,70 +53,66 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
    def tearDown(self):
        [unpatch() for unpatch in self.unpatches]

    @defer.inlineCallbacks
    def test_get_latest_event_ids_in_room(self):
        create = yield self.persist(type="m.room.create", key="", creator=USER_ID)
        yield self.replicate()
        yield self.check("get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id])
        create = self.persist(type="m.room.create", key="", creator=USER_ID)
        self.replicate()
        self.check("get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id])

        join = yield self.persist(
        join = self.persist(
            type="m.room.member",
            key=USER_ID,
            membership="join",
            prev_events=[(create.event_id, {})],
        )
        yield self.replicate()
        yield self.check("get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id])
        self.replicate()
        self.check("get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id])

    @defer.inlineCallbacks
    def test_redactions(self):
        yield self.persist(type="m.room.create", key="", creator=USER_ID)
        yield self.persist(type="m.room.member", key=USER_ID, membership="join")
        self.persist(type="m.room.create", key="", creator=USER_ID)
        self.persist(type="m.room.member", key=USER_ID, membership="join")

        msg = yield self.persist(type="m.room.message", msgtype="m.text", body="Hello")
        yield self.replicate()
        yield self.check("get_event", [msg.event_id], msg)
        msg = self.persist(type="m.room.message", msgtype="m.text", body="Hello")
        self.replicate()
        self.check("get_event", [msg.event_id], msg)

        redaction = yield self.persist(type="m.room.redaction", redacts=msg.event_id)
        yield self.replicate()
        redaction = self.persist(type="m.room.redaction", redacts=msg.event_id)
        self.replicate()

        msg_dict = msg.get_dict()
        msg_dict["content"] = {}
        msg_dict["unsigned"]["redacted_by"] = redaction.event_id
        msg_dict["unsigned"]["redacted_because"] = redaction
        redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict())
        yield self.check("get_event", [msg.event_id], redacted)
        self.check("get_event", [msg.event_id], redacted)

    @defer.inlineCallbacks
    def test_backfilled_redactions(self):
        yield self.persist(type="m.room.create", key="", creator=USER_ID)
        yield self.persist(type="m.room.member", key=USER_ID, membership="join")
        self.persist(type="m.room.create", key="", creator=USER_ID)
        self.persist(type="m.room.member", key=USER_ID, membership="join")

        msg = yield self.persist(type="m.room.message", msgtype="m.text", body="Hello")
        yield self.replicate()
        yield self.check("get_event", [msg.event_id], msg)
        msg = self.persist(type="m.room.message", msgtype="m.text", body="Hello")
        self.replicate()
        self.check("get_event", [msg.event_id], msg)

        redaction = yield self.persist(
        redaction = self.persist(
            type="m.room.redaction", redacts=msg.event_id, backfill=True
        )
        yield self.replicate()
        self.replicate()

        msg_dict = msg.get_dict()
        msg_dict["content"] = {}
        msg_dict["unsigned"]["redacted_by"] = redaction.event_id
        msg_dict["unsigned"]["redacted_because"] = redaction
        redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict())
        yield self.check("get_event", [msg.event_id], redacted)
        self.check("get_event", [msg.event_id], redacted)

    @defer.inlineCallbacks
    def test_invites(self):
        yield self.persist(type="m.room.create", key="", creator=USER_ID)
        yield self.check("get_invited_rooms_for_user", [USER_ID_2], [])
        event = yield self.persist(
            type="m.room.member", key=USER_ID_2, membership="invite"
        )
        yield self.replicate()
        yield self.check(
        self.persist(type="m.room.create", key="", creator=USER_ID)
        self.check("get_invited_rooms_for_user", [USER_ID_2], [])
        event = self.persist(type="m.room.member", key=USER_ID_2, membership="invite")

        self.replicate()

        self.check(
            "get_invited_rooms_for_user",
            [USER_ID_2],
            [
@@ -132,37 +126,34 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
            ],
        )

    @defer.inlineCallbacks
    def test_push_actions_for_user(self):
        yield self.persist(type="m.room.create", key="", creator=USER_ID)
        yield self.persist(type="m.room.join", key=USER_ID, membership="join")
        yield self.persist(
        self.persist(type="m.room.create", key="", creator=USER_ID)
        self.persist(type="m.room.join", key=USER_ID, membership="join")
        self.persist(
            type="m.room.join", sender=USER_ID, key=USER_ID_2, membership="join"
        )
        event1 = yield self.persist(
            type="m.room.message", msgtype="m.text", body="hello"
        )
        yield self.replicate()
        yield self.check(
        event1 = self.persist(type="m.room.message", msgtype="m.text", body="hello")
        self.replicate()
        self.check(
            "get_unread_event_push_actions_by_room_for_user",
            [ROOM_ID, USER_ID_2, event1.event_id],
            {"highlight_count": 0, "notify_count": 0},
        )

        yield self.persist(
        self.persist(
            type="m.room.message",
            msgtype="m.text",
            body="world",
            push_actions=[(USER_ID_2, ["notify"])],
        )
        yield self.replicate()
        yield self.check(
        self.replicate()
        self.check(
            "get_unread_event_push_actions_by_room_for_user",
            [ROOM_ID, USER_ID_2, event1.event_id],
            {"highlight_count": 0, "notify_count": 1},
        )

        yield self.persist(
        self.persist(
            type="m.room.message",
            msgtype="m.text",
            body="world",
@@ -170,8 +161,8 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
                (USER_ID_2, ["notify", {"set_tweak": "highlight", "value": True}])
            ],
        )
        yield self.replicate()
        yield self.check(
        self.replicate()
        self.check(
            "get_unread_event_push_actions_by_room_for_user",
            [ROOM_ID, USER_ID_2, event1.event_id],
            {"highlight_count": 1, "notify_count": 2},
@@ -179,7 +170,6 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):

    event_id = 0

    @defer.inlineCallbacks
    def persist(
        self,
        sender=USER_ID,
@@ -206,8 +196,8 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
            depth = self.event_id

        if not prev_events:
            latest_event_ids = yield self.master_store.get_latest_event_ids_in_room(
                room_id
            latest_event_ids = self.get_success(
                self.master_store.get_latest_event_ids_in_room(room_id)
            )
            prev_events = [(ev_id, {}) for ev_id in latest_event_ids]

@@ -240,19 +230,23 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
            )
        else:
            state_handler = self.hs.get_state_handler()
            context = yield state_handler.compute_event_context(event)
            context = self.get_success(state_handler.compute_event_context(event))

        yield self.master_store.add_push_actions_to_staging(
        self.master_store.add_push_actions_to_staging(
            event.event_id, {user_id: actions for user_id, actions in push_actions}
        )

        ordering = None
        if backfill:
            yield self.master_store.persist_events([(event, context)], backfilled=True)
            self.get_success(
                self.master_store.persist_events([(event, context)], backfilled=True)
            )
        else:
            ordering, _ = yield self.master_store.persist_event(event, context)
            ordering, _ = self.get_success(
                self.master_store.persist_event(event, context)
            )

        if ordering:
            event.internal_metadata.stream_ordering = ordering

        defer.returnValue(event)
        return event
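A Python 3 detail surfaces at the end of the persist hunk above: once @defer.inlineCallbacks is dropped, defer.returnValue(event) must become a plain return event, because returnValue works by raising a special exception inside a generator and has no effect in an ordinary function. A hedged before/after sketch (persist_old and persist_new are illustrative names, not from the codebase):

    from twisted.internet import defer

    @defer.inlineCallbacks
    def persist_old(store, event, context):
        yield store.persist_event(event, context)
        defer.returnValue(event)  # generator-only, exception-based return

    def persist_new(testcase, store, event, context):
        # With the generator gone, the result is returned the ordinary way.
        testcase.get_success(store.persist_event(event, context))
        return event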
@@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from twisted.internet import defer

from synapse.replication.slave.storage.receipts import SlavedReceiptsStore

from ._base import BaseSlavedStoreTestCase

@@ -27,13 +25,10 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase):

    STORE_TYPE = SlavedReceiptsStore

    @defer.inlineCallbacks
    def test_receipt(self):
        yield self.check("get_receipts_for_user", [USER_ID, "m.read"], {})
        yield self.master_store.insert_receipt(
            ROOM_ID, "m.read", USER_ID, [EVENT_ID], {}
        )
        yield self.replicate()
        yield self.check(
            "get_receipts_for_user", [USER_ID, "m.read"], {ROOM_ID: EVENT_ID}
        self.check("get_receipts_for_user", [USER_ID, "m.read"], {})
        self.get_success(
            self.master_store.insert_receipt(ROOM_ID, "m.read", USER_ID, [EVENT_ID], {})
        )
        self.replicate()
        self.check("get_receipts_for_user", [USER_ID, "m.read"], {ROOM_ID: EVENT_ID})
@@ -240,7 +240,6 @@ class RestHelper(object):
        self.assertEquals(200, code)
        defer.returnValue(response)

    @defer.inlineCallbacks
    def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200):
        if txn_id is None:
            txn_id = "m%s" % (str(time.time()))
@@ -248,9 +247,16 @@ class RestHelper(object):
            body = "body_text_here"

        path = "/_matrix/client/r0/rooms/%s/send/m.room.message/%s" % (room_id, txn_id)
        content = '{"msgtype":"m.text","body":"%s"}' % body
        content = {"msgtype": "m.text", "body": body}
        if tok:
            path = path + "?access_token=%s" % tok

        (code, response) = yield self.mock_resource.trigger("PUT", path, content)
        self.assertEquals(expect_code, code, msg=str(response))
        request, channel = make_request("PUT", path, json.dumps(content).encode('utf8'))
        render(request, self.resource, self.hs.get_reactor())

        assert int(channel.result["code"]) == expect_code, (
            "Expected: %d, got: %d, resp: %r"
            % (expect_code, int(channel.result["code"]), channel.result["body"])
        )

        return channel.json_body
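The rewritten RestHelper.send above swaps a mocked resource for the in-memory request machinery: build a request, render it against the real JsonResource on a fake reactor, then read the parsed body off the channel. A sketch of that flow as a free function (make_request and render are the helpers this file imports; the resource/reactor wiring is assumed):

    import json

    from tests.server import make_request, render

    def send_text(resource, reactor, room_id, txn_id, body, tok=None):
        path = "/_matrix/client/r0/rooms/%s/send/m.room.message/%s" % (room_id, txn_id)
        if tok:
            path += "?access_token=%s" % tok
        content = {"msgtype": "m.text", "body": body}
        # Drive the request to completion synchronously on the fake reactor.
        request, channel = make_request("PUT", path, json.dumps(content).encode("utf8"))
        render(request, resource, reactor)
        return channel.json_body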
@@ -232,6 +232,7 @@ def setup_test_homeserver(cleanup_func, *args, **kwargs):

    clock.threadpool = ThreadPool()
    pool.threadpool = ThreadPool()
    pool.running = True
    return d
@@ -37,18 +37,14 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
    @defer.inlineCallbacks
    def setUp(self):
        self.as_yaml_files = []
        config = Mock(
            app_service_config_files=self.as_yaml_files,
            event_cache_size=1,
            password_providers=[],
        )
        hs = yield setup_test_homeserver(
            self.addCleanup,
            config=config,
            federation_sender=Mock(),
            federation_client=Mock(),
            self.addCleanup, federation_sender=Mock(), federation_client=Mock()
        )

        hs.config.app_service_config_files = self.as_yaml_files
        hs.config.event_cache_size = 1
        hs.config.password_providers = []

        self.as_token = "token1"
        self.as_url = "some_url"
        self.as_id = "as1"
@@ -58,7 +54,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
        self._add_appservice("token2", "as2", "some_url", "some_hs_token", "bob")
        self._add_appservice("token3", "as3", "some_url", "some_hs_token", "bob")
        # must be done after inserts
        self.store = ApplicationServiceStore(None, hs)
        self.store = ApplicationServiceStore(hs.get_db_conn(), hs)

    def tearDown(self):
        # TODO: suboptimal that we need to create files for tests!
@@ -105,18 +101,16 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
    def setUp(self):
        self.as_yaml_files = []

        config = Mock(
            app_service_config_files=self.as_yaml_files,
            event_cache_size=1,
            password_providers=[],
        )
        hs = yield setup_test_homeserver(
            self.addCleanup,
            config=config,
            federation_sender=Mock(),
            federation_client=Mock(),
            self.addCleanup, federation_sender=Mock(), federation_client=Mock()
        )

        hs.config.app_service_config_files = self.as_yaml_files
        hs.config.event_cache_size = 1
        hs.config.password_providers = []

        self.db_pool = hs.get_db_pool()
        self.engine = hs.database_engine

        self.as_list = [
            {"token": "token1", "url": "https://matrix-as.org", "id": "id_1"},
@@ -129,7 +123,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):

        self.as_yaml_files = []

        self.store = TestTransactionStore(None, hs)
        self.store = TestTransactionStore(hs.get_db_conn(), hs)

    def _add_service(self, url, as_token, id):
        as_yaml = dict(
@@ -146,29 +140,35 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
        self.as_yaml_files.append(as_token)

    def _set_state(self, id, state, txn=None):
        return self.db_pool.runQuery(
            "INSERT INTO application_services_state(as_id, state, last_txn) "
            "VALUES(?,?,?)",
        return self.db_pool.runOperation(
            self.engine.convert_param_style(
                "INSERT INTO application_services_state(as_id, state, last_txn) "
                "VALUES(?,?,?)"
            ),
            (id, state, txn),
        )

    def _insert_txn(self, as_id, txn_id, events):
        return self.db_pool.runQuery(
            "INSERT INTO application_services_txns(as_id, txn_id, event_ids) "
            "VALUES(?,?,?)",
        return self.db_pool.runOperation(
            self.engine.convert_param_style(
                "INSERT INTO application_services_txns(as_id, txn_id, event_ids) "
                "VALUES(?,?,?)"
            ),
            (as_id, txn_id, json.dumps([e.event_id for e in events])),
        )

    def _set_last_txn(self, as_id, txn_id):
        return self.db_pool.runQuery(
            "INSERT INTO application_services_state(as_id, last_txn, state) "
            "VALUES(?,?,?)",
        return self.db_pool.runOperation(
            self.engine.convert_param_style(
                "INSERT INTO application_services_state(as_id, last_txn, state) "
                "VALUES(?,?,?)"
            ),
            (as_id, txn_id, ApplicationServiceState.UP),
        )
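The switch from runQuery to runOperation plus convert_param_style in the three helpers above is what makes the raw SQL portable: INSERTs return no rows (hence runOperation), and sqlite expects ? placeholders while psycopg2 expects %s. A sketch of what the engine hook is assumed to do on each backend:

    class Sqlite3Engine(object):
        def convert_param_style(self, sql):
            # sqlite3 already uses "?" placeholders; pass through unchanged.
            return sql

    class PostgresEngine(object):
        def convert_param_style(self, sql):
            # psycopg2 uses "%s" placeholders, so rewrite the query.
            return sql.replace("?", "%s")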
    @defer.inlineCallbacks
    def test_get_appservice_state_none(self):
        service = Mock(id=999)
        service = Mock(id="999")
        state = yield self.store.get_appservice_state(service)
        self.assertEquals(None, state)

@@ -200,7 +200,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
        service = Mock(id=self.as_list[1]["id"])
        yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
        rows = yield self.db_pool.runQuery(
            "SELECT as_id FROM application_services_state WHERE state=?",
            self.engine.convert_param_style(
                "SELECT as_id FROM application_services_state WHERE state=?"
            ),
            (ApplicationServiceState.DOWN,),
        )
        self.assertEquals(service.id, rows[0][0])
@@ -212,7 +214,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
        yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
        yield self.store.set_appservice_state(service, ApplicationServiceState.UP)
        rows = yield self.db_pool.runQuery(
            "SELECT as_id FROM application_services_state WHERE state=?",
            self.engine.convert_param_style(
                "SELECT as_id FROM application_services_state WHERE state=?"
            ),
            (ApplicationServiceState.UP,),
        )
        self.assertEquals(service.id, rows[0][0])
@@ -279,14 +283,19 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
        yield self.store.complete_appservice_txn(txn_id=txn_id, service=service)

        res = yield self.db_pool.runQuery(
            "SELECT last_txn FROM application_services_state WHERE as_id=?",
            self.engine.convert_param_style(
                "SELECT last_txn FROM application_services_state WHERE as_id=?"
            ),
            (service.id,),
        )
        self.assertEquals(1, len(res))
        self.assertEquals(txn_id, res[0][0])

        res = yield self.db_pool.runQuery(
            "SELECT * FROM application_services_txns WHERE txn_id=?", (txn_id,)
            self.engine.convert_param_style(
                "SELECT * FROM application_services_txns WHERE txn_id=?"
            ),
            (txn_id,),
        )
        self.assertEquals(0, len(res))

@@ -300,7 +309,9 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
        yield self.store.complete_appservice_txn(txn_id=txn_id, service=service)

        res = yield self.db_pool.runQuery(
            "SELECT last_txn, state FROM application_services_state WHERE " "as_id=?",
            self.engine.convert_param_style(
                "SELECT last_txn, state FROM application_services_state WHERE as_id=?"
            ),
            (service.id,),
        )
        self.assertEquals(1, len(res))
@@ -308,7 +319,10 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
        self.assertEquals(ApplicationServiceState.UP, res[0][1])

        res = yield self.db_pool.runQuery(
            "SELECT * FROM application_services_txns WHERE txn_id=?", (txn_id,)
            self.engine.convert_param_style(
                "SELECT * FROM application_services_txns WHERE txn_id=?"
            ),
            (txn_id,),
        )
        self.assertEquals(0, len(res))

@@ -394,37 +408,31 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
        f1 = self._write_config(suffix="1")
        f2 = self._write_config(suffix="2")

        config = Mock(
            app_service_config_files=[f1, f2], event_cache_size=1, password_providers=[]
        )
        hs = yield setup_test_homeserver(
            self.addCleanup,
            config=config,
            datastore=Mock(),
            federation_sender=Mock(),
            federation_client=Mock(),
            self.addCleanup, federation_sender=Mock(), federation_client=Mock()
        )

        ApplicationServiceStore(None, hs)
        hs.config.app_service_config_files = [f1, f2]
        hs.config.event_cache_size = 1
        hs.config.password_providers = []

        ApplicationServiceStore(hs.get_db_conn(), hs)

    @defer.inlineCallbacks
    def test_duplicate_ids(self):
        f1 = self._write_config(id="id", suffix="1")
        f2 = self._write_config(id="id", suffix="2")

        config = Mock(
            app_service_config_files=[f1, f2], event_cache_size=1, password_providers=[]
        )
        hs = yield setup_test_homeserver(
            self.addCleanup,
            config=config,
            datastore=Mock(),
            federation_sender=Mock(),
            federation_client=Mock(),
            self.addCleanup, federation_sender=Mock(), federation_client=Mock()
        )

        hs.config.app_service_config_files = [f1, f2]
        hs.config.event_cache_size = 1
        hs.config.password_providers = []

        with self.assertRaises(ConfigError) as cm:
            ApplicationServiceStore(None, hs)
            ApplicationServiceStore(hs.get_db_conn(), hs)

        e = cm.exception
        self.assertIn(f1, str(e))
@@ -436,19 +444,16 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
        f1 = self._write_config(as_token="as_token", suffix="1")
        f2 = self._write_config(as_token="as_token", suffix="2")

        config = Mock(
            app_service_config_files=[f1, f2], event_cache_size=1, password_providers=[]
        )
        hs = yield setup_test_homeserver(
            self.addCleanup,
            config=config,
            datastore=Mock(),
            federation_sender=Mock(),
            federation_client=Mock(),
            self.addCleanup, federation_sender=Mock(), federation_client=Mock()
        )

        hs.config.app_service_config_files = [f1, f2]
        hs.config.event_cache_size = 1
        hs.config.password_providers = []

        with self.assertRaises(ConfigError) as cm:
            ApplicationServiceStore(None, hs)
            ApplicationServiceStore(hs.get_db_conn(), hs)

        e = cm.exception
        self.assertIn(f1, str(e))
@@ -20,11 +20,11 @@ from mock import Mock

from twisted.internet import defer

from synapse.server import HomeServer
from synapse.storage._base import SQLBaseStore
from synapse.storage.engines import create_engine

from tests import unittest
from tests.utils import TestHomeServer


class SQLBaseStoreTestCase(unittest.TestCase):
@@ -51,7 +51,7 @@ class SQLBaseStoreTestCase(unittest.TestCase):
        config = Mock()
        config.event_cache_size = 1
        config.database_config = {"name": "sqlite3"}
        hs = HomeServer(
        hs = TestHomeServer(
            "test",
            db_pool=self.db_pool,
            config=config,
@@ -16,7 +16,6 @@

from twisted.internet import defer

from synapse.storage.directory import DirectoryStore
from synapse.types import RoomAlias, RoomID

from tests import unittest
@@ -28,7 +27,7 @@ class DirectoryStoreTestCase(unittest.TestCase):
    def setUp(self):
        hs = yield setup_test_homeserver(self.addCleanup)

        self.store = DirectoryStore(None, hs)
        self.store = hs.get_datastore()

        self.room = RoomID.from_string("!abcde:test")
        self.alias = RoomAlias.from_string("#my-room:test")
@@ -37,10 +37,10 @@ class EventFederationWorkerStoreTestCase(tests.unittest.TestCase):
                (
                    "INSERT INTO events ("
                    " room_id, event_id, type, depth, topological_ordering,"
                    " content, processed, outlier) "
                    "VALUES (?, ?, 'm.test', ?, ?, 'test', ?, ?)"
                    " content, processed, outlier, stream_ordering) "
                    "VALUES (?, ?, 'm.test', ?, ?, 'test', ?, ?, ?)"
                ),
                (room_id, event_id, i, i, True, False),
                (room_id, event_id, i, i, True, False, i),
            )

            txn.execute(
@@ -13,25 +13,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from twisted.internet import defer

import tests.unittest
import tests.utils
from tests.utils import setup_test_homeserver
from tests.unittest import HomeserverTestCase

FORTY_DAYS = 40 * 24 * 60 * 60


class MonthlyActiveUsersTestCase(tests.unittest.TestCase):
    def __init__(self, *args, **kwargs):
        super(MonthlyActiveUsersTestCase, self).__init__(*args, **kwargs)
class MonthlyActiveUsersTestCase(HomeserverTestCase):
    def make_homeserver(self, reactor, clock):

    @defer.inlineCallbacks
    def setUp(self):
        self.hs = yield setup_test_homeserver(self.addCleanup)
        self.store = self.hs.get_datastore()
        hs = self.setup_test_homeserver()
        self.store = hs.get_datastore()

        # Advance the clock a bit
        reactor.advance(FORTY_DAYS)

        return hs

    @defer.inlineCallbacks
    def test_initialise_reserved_users(self):
        self.hs.config.max_mau_value = 5
        user1 = "@user1:server"
@@ -44,88 +41,101 @@ class MonthlyActiveUsersTestCase(tests.unittest.TestCase):
        ]
        user_num = len(threepids)

        yield self.store.register(user_id=user1, token="123", password_hash=None)
        yield self.store.register(user_id=user2, token="456", password_hash=None)
        self.store.register(user_id=user1, token="123", password_hash=None)
        self.store.register(user_id=user2, token="456", password_hash=None)
        self.pump()

        now = int(self.hs.get_clock().time_msec())
        yield self.store.user_add_threepid(user1, "email", user1_email, now, now)
        yield self.store.user_add_threepid(user2, "email", user2_email, now, now)
        yield self.store.initialise_reserved_users(threepids)
        self.store.user_add_threepid(user1, "email", user1_email, now, now)
        self.store.user_add_threepid(user2, "email", user2_email, now, now)
        self.store.initialise_reserved_users(threepids)
        self.pump()

        active_count = yield self.store.get_monthly_active_count()
        active_count = self.store.get_monthly_active_count()

        # Test total counts
        self.assertEquals(active_count, user_num)
        self.assertEquals(self.get_success(active_count), user_num)

        # Test user is marked as active
        timestamp = yield self.store.user_last_seen_monthly_active(user1)
        self.assertTrue(timestamp)
        timestamp = yield self.store.user_last_seen_monthly_active(user2)
        self.assertTrue(timestamp)
        timestamp = self.store.user_last_seen_monthly_active(user1)
        self.assertTrue(self.get_success(timestamp))
        timestamp = self.store.user_last_seen_monthly_active(user2)
        self.assertTrue(self.get_success(timestamp))

        # Test that users are never removed from the db.
        self.hs.config.max_mau_value = 0

        self.hs.get_clock().advance_time(FORTY_DAYS)
        self.reactor.advance(FORTY_DAYS)

        yield self.store.reap_monthly_active_users()
        self.store.reap_monthly_active_users()
        self.pump()

        active_count = yield self.store.get_monthly_active_count()
        self.assertEquals(active_count, user_num)
        active_count = self.store.get_monthly_active_count()
        self.assertEquals(self.get_success(active_count), user_num)

        # Test that regular users are removed from the db
        ru_count = 2
        yield self.store.upsert_monthly_active_user("@ru1:server")
        yield self.store.upsert_monthly_active_user("@ru2:server")
        active_count = yield self.store.get_monthly_active_count()
        self.store.upsert_monthly_active_user("@ru1:server")
        self.store.upsert_monthly_active_user("@ru2:server")
        self.pump()

        self.assertEqual(active_count, user_num + ru_count)
        active_count = self.store.get_monthly_active_count()
        self.assertEqual(self.get_success(active_count), user_num + ru_count)
        self.hs.config.max_mau_value = user_num
        yield self.store.reap_monthly_active_users()
        self.store.reap_monthly_active_users()
        self.pump()

        active_count = yield self.store.get_monthly_active_count()
        self.assertEquals(active_count, user_num)
        active_count = self.store.get_monthly_active_count()
        self.assertEquals(self.get_success(active_count), user_num)

    @defer.inlineCallbacks
    def test_can_insert_and_count_mau(self):
        count = yield self.store.get_monthly_active_count()
        self.assertEqual(0, count)
        count = self.store.get_monthly_active_count()
        self.assertEqual(0, self.get_success(count))

        yield self.store.upsert_monthly_active_user("@user:server")
        count = yield self.store.get_monthly_active_count()
        self.store.upsert_monthly_active_user("@user:server")
        self.pump()

        self.assertEqual(1, count)
        count = self.store.get_monthly_active_count()
        self.assertEqual(1, self.get_success(count))

    @defer.inlineCallbacks
    def test_user_last_seen_monthly_active(self):
        user_id1 = "@user1:server"
        user_id2 = "@user2:server"
        user_id3 = "@user3:server"

        result = yield self.store.user_last_seen_monthly_active(user_id1)
        self.assertFalse(result == 0)
        yield self.store.upsert_monthly_active_user(user_id1)
        yield self.store.upsert_monthly_active_user(user_id2)
        result = yield self.store.user_last_seen_monthly_active(user_id1)
        self.assertTrue(result > 0)
        result = yield self.store.user_last_seen_monthly_active(user_id3)
        self.assertFalse(result == 0)
        result = self.store.user_last_seen_monthly_active(user_id1)
        self.assertFalse(self.get_success(result) == 0)

        self.store.upsert_monthly_active_user(user_id1)
        self.store.upsert_monthly_active_user(user_id2)
        self.pump()

        result = self.store.user_last_seen_monthly_active(user_id1)
        self.assertGreater(self.get_success(result), 0)

        result = self.store.user_last_seen_monthly_active(user_id3)
        self.assertNotEqual(self.get_success(result), 0)

    @defer.inlineCallbacks
    def test_reap_monthly_active_users(self):
        self.hs.config.max_mau_value = 5
        initial_users = 10
        for i in range(initial_users):
            yield self.store.upsert_monthly_active_user("@user%d:server" % i)
        count = yield self.store.get_monthly_active_count()
        self.assertTrue(count, initial_users)
        yield self.store.reap_monthly_active_users()
        count = yield self.store.get_monthly_active_count()
        self.assertEquals(count, initial_users - self.hs.config.max_mau_value)
            self.store.upsert_monthly_active_user("@user%d:server" % i)
        self.pump()

        self.hs.get_clock().advance_time(FORTY_DAYS)
        yield self.store.reap_monthly_active_users()
        count = yield self.store.get_monthly_active_count()
        self.assertEquals(count, 0)
        count = self.store.get_monthly_active_count()
        self.assertTrue(self.get_success(count), initial_users)

        self.store.reap_monthly_active_users()
        self.pump()
        count = self.store.get_monthly_active_count()
        self.assertEquals(
            self.get_success(count), initial_users - self.hs.config.max_mau_value
        )

        self.reactor.advance(FORTY_DAYS)
        self.store.reap_monthly_active_users()
        self.pump()

        count = self.store.get_monthly_active_count()
        self.assertEquals(self.get_success(count), 0)
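The ported MAU tests drive time explicitly: self.reactor.advance(FORTY_DAYS) jumps the deterministic clock forward and self.pump() flushes any pending Deferreds, replacing the old MockClock-style advance_time. The mechanism underneath is Twisted's task.Clock; a self-contained sketch:

    from twisted.internet.task import Clock

    FORTY_DAYS = 40 * 24 * 60 * 60

    clock = Clock()
    fired = []
    clock.callLater(FORTY_DAYS, fired.append, "reap")  # schedule like a timer
    clock.advance(FORTY_DAYS)  # move virtual time forward; no real waiting
    assert fired == ["reap"]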
@@ -16,19 +16,18 @@

from twisted.internet import defer

from synapse.storage.presence import PresenceStore
from synapse.types import UserID

from tests import unittest
from tests.utils import MockClock, setup_test_homeserver
from tests.utils import setup_test_homeserver


class PresenceStoreTestCase(unittest.TestCase):
    @defer.inlineCallbacks
    def setUp(self):
        hs = yield setup_test_homeserver(self.addCleanup, clock=MockClock())
        hs = yield setup_test_homeserver(self.addCleanup)

        self.store = PresenceStore(None, hs)
        self.store = hs.get_datastore()

        self.u_apple = UserID.from_string("@apple:test")
        self.u_banana = UserID.from_string("@banana:test")
@@ -28,7 +28,7 @@ class ProfileStoreTestCase(unittest.TestCase):
    def setUp(self):
        hs = yield setup_test_homeserver(self.addCleanup)

        self.store = ProfileStore(None, hs)
        self.store = ProfileStore(hs.get_db_conn(), hs)

        self.u_frank = UserID.from_string("@frank:test")
106 tests/storage/test_purge.py Normal file
@@ -0,0 +1,106 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from synapse.rest.client.v1 import room

from tests.unittest import HomeserverTestCase


class PurgeTests(HomeserverTestCase):

    user_id = "@red:server"
    servlets = [room.register_servlets]

    def make_homeserver(self, reactor, clock):
        hs = self.setup_test_homeserver("server", http_client=None)
        return hs

    def prepare(self, reactor, clock, hs):
        self.room_id = self.helper.create_room_as(self.user_id)

    def test_purge(self):
        """
        Purging a room will delete everything before the topological point.
        """
        # Send four messages to the room
        first = self.helper.send(self.room_id, body="test1")
        second = self.helper.send(self.room_id, body="test2")
        third = self.helper.send(self.room_id, body="test3")
        last = self.helper.send(self.room_id, body="test4")

        storage = self.hs.get_datastore()

        # Get the topological token
        event = storage.get_topological_token_for_event(last["event_id"])
        self.pump()
        event = self.successResultOf(event)

        # Purge everything before this topological token
        purge = storage.purge_history(self.room_id, event, True)
        self.pump()
        self.assertEqual(self.successResultOf(purge), None)

        # Try and get the events
        get_first = storage.get_event(first["event_id"])
        get_second = storage.get_event(second["event_id"])
        get_third = storage.get_event(third["event_id"])
        get_last = storage.get_event(last["event_id"])
        self.pump()

        # 1-3 should fail and last will succeed, meaning that 1-3 are deleted
        # and last is not.
        self.failureResultOf(get_first)
        self.failureResultOf(get_second)
        self.failureResultOf(get_third)
        self.successResultOf(get_last)

    def test_purge_wont_delete_extrems(self):
        """
        Purging a room will delete everything before the topological point.
        """
        # Send four messages to the room
        first = self.helper.send(self.room_id, body="test1")
        second = self.helper.send(self.room_id, body="test2")
        third = self.helper.send(self.room_id, body="test3")
        last = self.helper.send(self.room_id, body="test4")

        storage = self.hs.get_datastore()

        # Set the topological token higher than it should be
        event = storage.get_topological_token_for_event(last["event_id"])
        self.pump()
        event = self.successResultOf(event)
        event = "t{}-{}".format(
            *list(map(lambda x: x + 1, map(int, event[1:].split("-"))))
        )

        # Purge everything before this topological token
        purge = storage.purge_history(self.room_id, event, True)
        self.pump()
        f = self.failureResultOf(purge)
        self.assertIn("greater than forward", f.value.args[0])

        # Try and get the events
        get_first = storage.get_event(first["event_id"])
        get_second = storage.get_event(second["event_id"])
        get_third = storage.get_event(third["event_id"])
        get_last = storage.get_event(last["event_id"])
        self.pump()

        # Nothing is deleted.
        self.successResultOf(get_first)
        self.successResultOf(get_second)
        self.successResultOf(get_third)
        self.successResultOf(get_last)
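test_purge_wont_delete_extrems manufactures an out-of-range token by incrementing both halves of a topological token, which this version of Synapse serialises as t<topological>-<stream>. The string surgery in isolation (the token value is hypothetical):

    token = "t4-9"
    topological, stream = map(int, token[1:].split("-"))
    bumped = "t{}-{}".format(topological + 1, stream + 1)
    assert bumped == "t5-10"  # now points past the room's forward extremities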
@@ -30,7 +30,7 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
    @defer.inlineCallbacks
    def setUp(self):
        self.hs = yield setup_test_homeserver(self.addCleanup)
        self.store = UserDirectoryStore(None, self.hs)
        self.store = UserDirectoryStore(self.hs.get_db_conn(), self.hs)

        # alice and bob are both in !room_id. bobby is not but shares
        # a homeserver with alice.
@@ -14,12 +14,12 @@
# limitations under the License.

from synapse.api.errors import SynapseError
from synapse.server import HomeServer
from synapse.types import GroupID, RoomAlias, UserID

from tests import unittest
from tests.utils import TestHomeServer

mock_homeserver = HomeServer(hostname="my.domain")
mock_homeserver = TestHomeServer(hostname="my.domain")


class UserIDTestCase(unittest.TestCase):
@@ -96,7 +96,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
            events_to_filter.append(evt)

        # the erasey user gets erased
        self.hs.get_datastore().mark_user_erased("@erased:local_hs")
        yield self.hs.get_datastore().mark_user_erased("@erased:local_hs")

        # ... and the filtering happens.
        filtered = yield filter_events_for_server(
@@ -22,6 +22,7 @@ from canonicaljson import json

import twisted
import twisted.logger
from twisted.internet.defer import Deferred
from twisted.trial import unittest

from synapse.http.server import JsonResource
@@ -151,6 +152,7 @@ class HomeserverTestCase(TestCase):
        hijack_auth (bool): Whether to hijack auth to return the user specified
        in user_id.
    """

    servlets = []
    hijack_auth = True

@@ -279,3 +281,15 @@ class HomeserverTestCase(TestCase):
        kwargs = dict(kwargs)
        kwargs.update(self._hs_args)
        return setup_test_homeserver(self.addCleanup, *args, **kwargs)

    def pump(self, by=0.0):
        """
        Pump the reactor enough that Deferreds will fire.
        """
        self.reactor.pump([by] * 100)

    def get_success(self, d):
        if not isinstance(d, Deferred):
            return d
        self.pump()
        return self.successResultOf(d)
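With pump and get_success on HomeserverTestCase, any test can unwrap a storage Deferred inline, and the isinstance fast path means the same call site works whether the store returns a plain value or a pending Deferred. Typical usage, as seen throughout the hunks above:

    count = self.get_success(self.store.get_monthly_active_count())
    self.assertEqual(count, 0)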
@@ -26,11 +26,12 @@ from twisted.internet import defer, reactor

from synapse.api.constants import EventTypes
from synapse.api.errors import CodeMessageException, cs_error
from synapse.config.server import ServerConfig
from synapse.federation.transport import server
from synapse.http.server import HttpServer
from synapse.server import HomeServer
from synapse.storage import PostgresEngine
from synapse.storage.engines import create_engine
from synapse.storage import DataStore
from synapse.storage.engines import PostgresEngine, create_engine
from synapse.storage.prepare_database import (
    _get_or_create_schema_state,
    _setup_new_database,
@@ -41,6 +42,7 @@ from synapse.util.ratelimitutils import FederationRateLimiter

# set this to True to run the tests against postgres instead of sqlite.
USE_POSTGRES_FOR_TESTS = os.environ.get("SYNAPSE_POSTGRES", False)
LEAVE_DB = os.environ.get("SYNAPSE_LEAVE_DB", False)
POSTGRES_USER = os.environ.get("SYNAPSE_POSTGRES_USER", "postgres")
POSTGRES_BASE_DB = "_synapse_unit_tests_base_%s" % (os.getpid(),)

@@ -92,10 +94,14 @@ def setupdb():
    atexit.register(_cleanup)


class TestHomeServer(HomeServer):
    DATASTORE_CLASS = DataStore


@defer.inlineCallbacks
def setup_test_homeserver(
    cleanup_func, name="test", datastore=None, config=None, reactor=None,
    homeserverToUse=HomeServer, **kargs
    homeserverToUse=TestHomeServer, **kargs
):
    """
    Setup a homeserver suitable for running tests against. Keyword arguments
@@ -143,6 +149,8 @@ def setup_test_homeserver(
        config.max_mau_value = 50
        config.mau_limits_reserved_threepids = []
        config.admin_contact = None
        config.rc_messages_per_second = 10000
        config.rc_message_burst_count = 10000

        # we need a sane default_room_version, otherwise attempts to create rooms will
        # fail.
@@ -152,6 +160,11 @@ def setup_test_homeserver(
        # background, which upsets the test runner.
        config.update_user_directory = False

        def is_threepid_reserved(threepid):
            return ServerConfig.is_threepid_reserved(config, threepid)

        config.is_threepid_reserved.side_effect = is_threepid_reserved

    config.use_frozen_dicts = True
    config.ldap_enabled = False

@@ -232,8 +245,9 @@ def setup_test_homeserver(
            cur.close()
            db_conn.close()

        # Register the cleanup hook
        cleanup_func(cleanup)
        if not LEAVE_DB:
            # Register the cleanup hook
            cleanup_func(cleanup)

        hs.setup()
    else: