diff --git a/docs/sphinx/conf.py b/docs/sphinx/conf.py index 15c19834fc..0b15bd8912 100644 --- a/docs/sphinx/conf.py +++ b/docs/sphinx/conf.py @@ -50,7 +50,7 @@ master_doc = 'index' # General information about the project. project = u'Synapse' -copyright = u'2014, TNG' +copyright = u'Copyright 2014-2017 OpenMarket Ltd, 2017 Vector Creations Ltd, 2017 New Vector Ltd' # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 1ab5593c6e..fa105bce72 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -81,22 +81,38 @@ class Config(object): def abspath(file_path): return os.path.abspath(file_path) if file_path else file_path + @classmethod + def path_exists(cls, file_path): + """Check if a file exists + + Unlike os.path.exists, this throws an exception if there is an error + checking if the file exists (for example, if there is a perms error on + the parent dir). + + Returns: + bool: True if the file exists; False if not. + """ + try: + os.stat(file_path) + return True + except OSError as e: + if e.errno != errno.ENOENT: + raise e + return False + @classmethod def check_file(cls, file_path, config_name): if file_path is None: raise ConfigError( "Missing config for %s." - " You must specify a path for the config file. You can " - "do this with the -c or --config-path option. " - "Adding --generate-config along with --server-name " - " will generate a config file at the given path." % (config_name,) ) - if not os.path.exists(file_path): + try: + os.stat(file_path) + except OSError as e: raise ConfigError( - "File %s config for %s doesn't exist." - " Try running again with --generate-config" - % (file_path, config_name,) + "Error accessing file '%s' (config for %s): %s" + % (file_path, config_name, e.strerror) ) return cls.abspath(file_path) @@ -248,7 +264,7 @@ class Config(object): " -c CONFIG-FILE\"" ) (config_path,) = config_files - if not os.path.exists(config_path): + if not cls.path_exists(config_path): if config_args.keys_directory: config_dir_path = config_args.keys_directory else: @@ -261,7 +277,7 @@ class Config(object): "Must specify a server_name to a generate config for." " Pass -H server.name." ) - if not os.path.exists(config_dir_path): + if not cls.path_exists(config_dir_path): os.makedirs(config_dir_path) with open(config_path, "wb") as config_file: config_bytes, config = obj.generate_config( diff --git a/synapse/config/key.py b/synapse/config/key.py index 6ee643793e..4b8fc063d0 100644 --- a/synapse/config/key.py +++ b/synapse/config/key.py @@ -118,10 +118,9 @@ class KeyConfig(Config): signing_keys = self.read_file(signing_key_path, "signing_key") try: return read_signing_keys(signing_keys.splitlines(True)) - except Exception: + except Exception as e: raise ConfigError( - "Error reading signing_key." - " Try running again with --generate-config" + "Error reading signing_key: %s" % (str(e)) ) def read_old_signing_keys(self, old_signing_keys): @@ -141,7 +140,8 @@ class KeyConfig(Config): def generate_files(self, config): signing_key_path = config["signing_key_path"] - if not os.path.exists(signing_key_path): + + if not self.path_exists(signing_key_path): with open(signing_key_path, "w") as signing_key_file: key_id = "a_" + random_string(4) write_signing_keys( diff --git a/synapse/config/registration.py b/synapse/config/registration.py index f7e03c4cde..ef917fc9f2 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -41,6 +41,8 @@ class RegistrationConfig(Config): self.allow_guest_access and config.get("invite_3pid_guest", False) ) + self.auto_join_rooms = config.get("auto_join_rooms", []) + def default_config(self, **kwargs): registration_shared_secret = random_string_with_symbols(50) @@ -70,6 +72,11 @@ class RegistrationConfig(Config): - matrix.org - vector.im - riot.im + + # Users who register on this homeserver will automatically be joined + # to these rooms + #auto_join_rooms: + # - "#example:example.com" """ % locals() def add_arguments(self, parser): diff --git a/synapse/config/repository.py b/synapse/config/repository.py index 2c6f57168e..6baa474931 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -70,7 +70,19 @@ class ContentRepositoryConfig(Config): self.max_upload_size = self.parse_size(config["max_upload_size"]) self.max_image_pixels = self.parse_size(config["max_image_pixels"]) self.max_spider_size = self.parse_size(config["max_spider_size"]) + self.media_store_path = self.ensure_directory(config["media_store_path"]) + + self.backup_media_store_path = config.get("backup_media_store_path") + if self.backup_media_store_path: + self.backup_media_store_path = self.ensure_directory( + self.backup_media_store_path + ) + + self.synchronous_backup_media_store = config.get( + "synchronous_backup_media_store", False + ) + self.uploads_path = self.ensure_directory(config["uploads_path"]) self.dynamic_thumbnails = config["dynamic_thumbnails"] self.thumbnail_requirements = parse_thumbnail_requirements( @@ -115,6 +127,14 @@ class ContentRepositoryConfig(Config): # Directory where uploaded images and attachments are stored. media_store_path: "%(media_store)s" + # A secondary directory where uploaded images and attachments are + # stored as a backup. + # backup_media_store_path: "%(media_store)s" + + # Whether to wait for successful write to backup media store before + # returning successfully. + # synchronous_backup_media_store: false + # Directory where in-progress uploads are stored. uploads_path: "%(uploads_path)s" diff --git a/synapse/config/tls.py b/synapse/config/tls.py index e081840a83..247f18f454 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -126,7 +126,7 @@ class TlsConfig(Config): tls_private_key_path = config["tls_private_key_path"] tls_dh_params_path = config["tls_dh_params_path"] - if not os.path.exists(tls_private_key_path): + if not self.path_exists(tls_private_key_path): with open(tls_private_key_path, "w") as private_key_file: tls_private_key = crypto.PKey() tls_private_key.generate_key(crypto.TYPE_RSA, 2048) @@ -141,7 +141,7 @@ class TlsConfig(Config): crypto.FILETYPE_PEM, private_key_pem ) - if not os.path.exists(tls_certificate_path): + if not self.path_exists(tls_certificate_path): with open(tls_certificate_path, "w") as certificate_file: cert = crypto.X509() subject = cert.get_subject() @@ -159,7 +159,7 @@ class TlsConfig(Config): certificate_file.write(cert_pem) - if not os.path.exists(tls_dh_params_path): + if not self.path_exists(tls_dh_params_path): if GENERATE_DH_PARAMS: subprocess.check_call([ "openssl", "dhparam", diff --git a/synapse/event_auth.py b/synapse/event_auth.py index 4096c606f1..9e746a28bf 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -470,14 +470,14 @@ def _check_power_levels(event, auth_events): ("invite", None), ] - old_list = current_state.content.get("users") + old_list = current_state.content.get("users", {}) for user in set(old_list.keys() + user_list.keys()): levels_to_check.append( (user, "users") ) - old_list = current_state.content.get("events") - new_list = event.content.get("events") + old_list = current_state.content.get("events", {}) + new_list = event.content.get("events", {}) for ev_id in set(old_list.keys() + new_list.keys()): levels_to_check.append( (ev_id, "events") diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index a8034bddc6..e15228e70b 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -192,7 +192,6 @@ class FederationServer(FederationBase): pdu_results[event_id] = {} except FederationError as e: logger.warn("Error handling PDU %s: %s", event_id, e) - self.send_failure(e, transaction.origin) pdu_results[event_id] = {"error": str(e)} except Exception as e: pdu_results[event_id] = {"error": str(e)} diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index f96561c1fe..125d8f3598 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -550,6 +550,19 @@ class TransportLayerClient(object): ignore_backoff=True, ) + @log_function + def get_invited_users_in_group(self, destination, group_id, requester_user_id): + """Get users that have been invited to a group + """ + path = PREFIX + "/groups/%s/invited_users" % (group_id,) + + return self.client.get_json( + destination=destination, + path=path, + args={"requester_user_id": requester_user_id}, + ignore_backoff=True, + ) + @log_function def accept_group_invite(self, destination, group_id, user_id, content): """Accept a group invite diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index c7565e0737..09b97138c3 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -721,6 +721,24 @@ class FederationGroupsUsersServlet(BaseFederationServlet): defer.returnValue((200, new_content)) +class FederationGroupsInvitedUsersServlet(BaseFederationServlet): + """Get the users that have been invited to a group + """ + PATH = "/groups/(?P[^/]*)/invited_users$" + + @defer.inlineCallbacks + def on_GET(self, origin, content, query, group_id): + requester_user_id = parse_string_from_args(query, "requester_user_id") + if get_domain_from_id(requester_user_id) != origin: + raise SynapseError(403, "requester_user_id doesn't match origin") + + new_content = yield self.handler.get_invited_users_in_group( + group_id, requester_user_id + ) + + defer.returnValue((200, new_content)) + + class FederationGroupsInviteServlet(BaseFederationServlet): """Ask a group server to invite someone to the group """ @@ -817,7 +835,7 @@ class FederationGroupsRenewAttestaionServlet(BaseFederationServlet): def on_POST(self, origin, content, query, group_id, user_id): # We don't need to check auth here as we check the attestation signatures - new_content = yield self.handler.on_renew_group_attestation( + new_content = yield self.handler.on_renew_attestation( origin, content, group_id, user_id ) @@ -1109,12 +1127,12 @@ ROOM_LIST_CLASSES = ( PublicRoomList, ) - GROUP_SERVER_SERVLET_CLASSES = ( FederationGroupsProfileServlet, FederationGroupsSummaryServlet, FederationGroupsRoomsServlet, FederationGroupsUsersServlet, + FederationGroupsInvitedUsersServlet, FederationGroupsInviteServlet, FederationGroupsAcceptInviteServlet, FederationGroupsRemoveUserServlet, diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py index 5ef7a12cb7..b751cf5e43 100644 --- a/synapse/groups/attestations.py +++ b/synapse/groups/attestations.py @@ -90,6 +90,7 @@ class GroupAttestionRenewer(object): self.assestations = hs.get_groups_attestation_signing() self.transport_client = hs.get_federation_transport_client() self.is_mine_id = hs.is_mine_id + self.attestations = hs.get_groups_attestation_signing() self._renew_attestations_loop = self.clock.looping_call( self._renew_attestations, 30 * 60 * 1000, @@ -126,10 +127,10 @@ class GroupAttestionRenewer(object): ) @defer.inlineCallbacks - def _renew_attestation(self, group_id, user_id): + def _renew_attestation(group_id, user_id): attestation = self.attestations.create_attestation(group_id, user_id) - if self.hs.is_mine_id(group_id): + if self.is_mine_id(group_id): destination = get_domain_from_id(user_id) else: destination = get_domain_from_id(group_id) diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py index 1083bc2990..a3a500b9d6 100644 --- a/synapse/groups/groups_server.py +++ b/synapse/groups/groups_server.py @@ -420,6 +420,40 @@ class GroupsServerHandler(object): "total_user_count_estimate": len(user_results), }) + @defer.inlineCallbacks + def get_invited_users_in_group(self, group_id, requester_user_id): + """Get the users that have been invited to a group as seen by requester_user_id. + + The ordering is arbitrary at the moment + """ + + yield self.check_group_is_ours(group_id, and_exists=True) + + is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id) + + if not is_user_in_group: + raise SynapseError(403, "User not in group") + + invited_users = yield self.store.get_invited_users_in_group(group_id) + + user_profiles = [] + + for user_id in invited_users: + user_profile = { + "user_id": user_id + } + try: + profile = yield self.profile_handler.get_profile_from_cache(user_id) + user_profile.update(profile) + except Exception as e: + logger.warn("Error getting profile for %s: %s", user_id, e) + user_profiles.append(user_profile) + + defer.returnValue({ + "chunk": user_profiles, + "total_user_count_estimate": len(invited_users), + }) + @defer.inlineCallbacks def get_rooms_in_group(self, group_id, requester_user_id): """Get the rooms in group as seen by requester_user_id diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py index 3b676d46bd..6699d0888f 100644 --- a/synapse/handlers/groups_local.py +++ b/synapse/handlers/groups_local.py @@ -68,6 +68,8 @@ class GroupsLocalHandler(object): update_group_profile = _create_rerouter("update_group_profile") get_rooms_in_group = _create_rerouter("get_rooms_in_group") + get_invited_users_in_group = _create_rerouter("get_invited_users_in_group") + add_room_to_group = _create_rerouter("add_room_to_group") remove_room_from_group = _create_rerouter("remove_room_from_group") @@ -313,8 +315,11 @@ class GroupsLocalHandler(object): self.notifier.on_new_event( "groups_key", token, users=[user_id], ) - - user_profile = yield self.profile_handler.get_profile(user_id) + try: + user_profile = yield self.profile_handler.get_profile(user_id) + except Exception as e: + logger.warn("No profile for user %s: %s", user_id, e) + user_profile = {} defer.returnValue({"state": "invite", "user_profile": user_profile}) diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py index 9dce99ebec..7a18afe5f9 100644 --- a/synapse/push/baserules.py +++ b/synapse/push/baserules.py @@ -245,7 +245,7 @@ BASE_APPEND_OVERRIDE_RULES = [ { 'kind': 'event_match', 'key': 'content.body', - 'pattern': '*@room*', + 'pattern': '@room', '_id': '_roomnotif_content', }, { diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py index 8f3ce15b02..d11bccc1da 100644 --- a/synapse/rest/client/v2_alpha/groups.py +++ b/synapse/rest/client/v2_alpha/groups.py @@ -371,6 +371,27 @@ class GroupUsersServlet(RestServlet): defer.returnValue((200, result)) +class GroupInvitedUsersServlet(RestServlet): + """Get users invited to a group + """ + PATTERNS = client_v2_patterns("/groups/(?P[^/]*)/invited_users$") + + def __init__(self, hs): + super(GroupInvitedUsersServlet, self).__init__() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.groups_handler = hs.get_groups_local_handler() + + @defer.inlineCallbacks + def on_GET(self, request, group_id): + requester = yield self.auth.get_user_by_req(request) + user_id = requester.user.to_string() + + result = yield self.groups_handler.get_invited_users_in_group(group_id, user_id) + + defer.returnValue((200, result)) + + class GroupCreateServlet(RestServlet): """Create a group """ @@ -674,6 +695,7 @@ class GroupsForUserServlet(RestServlet): def register_servlets(hs, http_server): GroupServlet(hs).register(http_server) GroupSummaryServlet(hs).register(http_server) + GroupInvitedUsersServlet(hs).register(http_server) GroupUsersServlet(hs).register(http_server) GroupRoomServlet(hs).register(http_server) GroupCreateServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 1421c18152..d9a8cdbbb5 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -17,8 +17,10 @@ from twisted.internet import defer import synapse +import synapse.types from synapse.api.auth import get_access_token_from_request, has_access_token from synapse.api.constants import LoginType +from synapse.types import RoomID, RoomAlias from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError from synapse.http.servlet import ( RestServlet, parse_json_object_from_request, assert_params_in_request, parse_string @@ -170,6 +172,7 @@ class RegisterRestServlet(RestServlet): self.auth_handler = hs.get_auth_handler() self.registration_handler = hs.get_handlers().registration_handler self.identity_handler = hs.get_handlers().identity_handler + self.room_member_handler = hs.get_handlers().room_member_handler self.device_handler = hs.get_device_handler() self.macaroon_gen = hs.get_macaroon_generator() @@ -340,6 +343,14 @@ class RegisterRestServlet(RestServlet): generate_token=False, ) + # auto-join the user to any rooms we're supposed to dump them into + fake_requester = synapse.types.create_requester(registered_user_id) + for r in self.hs.config.auto_join_rooms: + try: + yield self._join_user_to_room(fake_requester, r) + except Exception as e: + logger.error("Failed to join new user to %r: %r", r, e) + # remember that we've now registered that user account, and with # what user ID (since the user may not have specified) self.auth_handler.set_session_data( @@ -372,6 +383,29 @@ class RegisterRestServlet(RestServlet): def on_OPTIONS(self, _): return 200, {} + @defer.inlineCallbacks + def _join_user_to_room(self, requester, room_identifier): + room_id = None + if RoomID.is_valid(room_identifier): + room_id = room_identifier + elif RoomAlias.is_valid(room_identifier): + room_alias = RoomAlias.from_string(room_identifier) + room_id, remote_room_hosts = ( + yield self.room_member_handler.lookup_room_alias(room_alias) + ) + room_id = room_id.to_string() + else: + raise SynapseError(400, "%s was not legal room ID or room alias" % ( + room_identifier, + )) + + yield self.room_member_handler.update_membership( + requester=requester, + target=requester.user, + room_id=room_id, + action="join", + ) + @defer.inlineCallbacks def _do_appservice_registration(self, username, as_token, body): user_id = yield self.registration_handler.appservice_register( diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py index d5cec10127..d5164e47e0 100644 --- a/synapse/rest/media/v1/filepath.py +++ b/synapse/rest/media/v1/filepath.py @@ -15,80 +15,111 @@ import os import re +import functools NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d") +def _wrap_in_base_path(func): + """Takes a function that returns a relative path and turns it into an + absolute path based on the location of the primary media store + """ + @functools.wraps(func) + def _wrapped(self, *args, **kwargs): + path = func(self, *args, **kwargs) + return os.path.join(self.base_path, path) + + return _wrapped + + class MediaFilePaths(object): + """Describes where files are stored on disk. - def __init__(self, base_path): - self.base_path = base_path + Most of the functions have a `*_rel` variant which returns a file path that + is relative to the base media store path. This is mainly used when we want + to write to the backup media store (when one is configured) + """ - def default_thumbnail(self, default_top_level, default_sub_type, width, - height, content_type, method): + def __init__(self, primary_base_path): + self.base_path = primary_base_path + + def default_thumbnail_rel(self, default_top_level, default_sub_type, width, + height, content_type, method): top_level_type, sub_type = content_type.split("/") file_name = "%i-%i-%s-%s-%s" % ( width, height, top_level_type, sub_type, method ) return os.path.join( - self.base_path, "default_thumbnails", default_top_level, + "default_thumbnails", default_top_level, default_sub_type, file_name ) - def local_media_filepath(self, media_id): + default_thumbnail = _wrap_in_base_path(default_thumbnail_rel) + + def local_media_filepath_rel(self, media_id): return os.path.join( - self.base_path, "local_content", + "local_content", media_id[0:2], media_id[2:4], media_id[4:] ) - def local_media_thumbnail(self, media_id, width, height, content_type, - method): + local_media_filepath = _wrap_in_base_path(local_media_filepath_rel) + + def local_media_thumbnail_rel(self, media_id, width, height, content_type, + method): top_level_type, sub_type = content_type.split("/") file_name = "%i-%i-%s-%s-%s" % ( width, height, top_level_type, sub_type, method ) return os.path.join( - self.base_path, "local_thumbnails", + "local_thumbnails", media_id[0:2], media_id[2:4], media_id[4:], file_name ) - def remote_media_filepath(self, server_name, file_id): + local_media_thumbnail = _wrap_in_base_path(local_media_thumbnail_rel) + + def remote_media_filepath_rel(self, server_name, file_id): return os.path.join( - self.base_path, "remote_content", server_name, + "remote_content", server_name, file_id[0:2], file_id[2:4], file_id[4:] ) - def remote_media_thumbnail(self, server_name, file_id, width, height, - content_type, method): + remote_media_filepath = _wrap_in_base_path(remote_media_filepath_rel) + + def remote_media_thumbnail_rel(self, server_name, file_id, width, height, + content_type, method): top_level_type, sub_type = content_type.split("/") file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type) return os.path.join( - self.base_path, "remote_thumbnail", server_name, + "remote_thumbnail", server_name, file_id[0:2], file_id[2:4], file_id[4:], file_name ) + remote_media_thumbnail = _wrap_in_base_path(remote_media_thumbnail_rel) + def remote_media_thumbnail_dir(self, server_name, file_id): return os.path.join( self.base_path, "remote_thumbnail", server_name, file_id[0:2], file_id[2:4], file_id[4:], ) - def url_cache_filepath(self, media_id): + def url_cache_filepath_rel(self, media_id): if NEW_FORMAT_ID_RE.match(media_id): # Media id is of the form # E.g.: 2017-09-28-fsdRDt24DS234dsf return os.path.join( - self.base_path, "url_cache", + "url_cache", media_id[:10], media_id[11:] ) else: return os.path.join( - self.base_path, "url_cache", + "url_cache", media_id[0:2], media_id[2:4], media_id[4:], ) + url_cache_filepath = _wrap_in_base_path(url_cache_filepath_rel) + def url_cache_filepath_dirs_to_delete(self, media_id): "The dirs to try and remove if we delete the media_id file" if NEW_FORMAT_ID_RE.match(media_id): @@ -110,8 +141,8 @@ class MediaFilePaths(object): ), ] - def url_cache_thumbnail(self, media_id, width, height, content_type, - method): + def url_cache_thumbnail_rel(self, media_id, width, height, content_type, + method): # Media id is of the form # E.g.: 2017-09-28-fsdRDt24DS234dsf @@ -122,17 +153,19 @@ class MediaFilePaths(object): if NEW_FORMAT_ID_RE.match(media_id): return os.path.join( - self.base_path, "url_cache_thumbnails", + "url_cache_thumbnails", media_id[:10], media_id[11:], file_name ) else: return os.path.join( - self.base_path, "url_cache_thumbnails", + "url_cache_thumbnails", media_id[0:2], media_id[2:4], media_id[4:], file_name ) + url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel) + def url_cache_thumbnail_directory(self, media_id): # Media id is of the form # E.g.: 2017-09-28-fsdRDt24DS234dsf diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 0ea1248ce6..6b50b45b1f 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -33,7 +33,7 @@ from synapse.api.errors import SynapseError, HttpResponseException, \ from synapse.util.async import Linearizer from synapse.util.stringutils import is_ascii -from synapse.util.logcontext import preserve_context_over_fn +from synapse.util.logcontext import make_deferred_yieldable, preserve_fn from synapse.util.retryutils import NotRetryingDestination import os @@ -59,7 +59,14 @@ class MediaRepository(object): self.store = hs.get_datastore() self.max_upload_size = hs.config.max_upload_size self.max_image_pixels = hs.config.max_image_pixels - self.filepaths = MediaFilePaths(hs.config.media_store_path) + + self.primary_base_path = hs.config.media_store_path + self.filepaths = MediaFilePaths(self.primary_base_path) + + self.backup_base_path = hs.config.backup_media_store_path + + self.synchronous_backup_media_store = hs.config.synchronous_backup_media_store + self.dynamic_thumbnails = hs.config.dynamic_thumbnails self.thumbnail_requirements = hs.config.thumbnail_requirements @@ -87,18 +94,86 @@ class MediaRepository(object): if not os.path.exists(dirname): os.makedirs(dirname) + @staticmethod + def _write_file_synchronously(source, fname): + """Write `source` to the path `fname` synchronously. Should be called + from a thread. + + Args: + source: A file like object to be written + fname (str): Path to write to + """ + MediaRepository._makedirs(fname) + source.seek(0) # Ensure we read from the start of the file + with open(fname, "wb") as f: + shutil.copyfileobj(source, f) + + @defer.inlineCallbacks + def write_to_file_and_backup(self, source, path): + """Write `source` to the on disk media store, and also the backup store + if configured. + + Args: + source: A file like object that should be written + path (str): Relative path to write file to + + Returns: + Deferred[str]: the file path written to in the primary media store + """ + fname = os.path.join(self.primary_base_path, path) + + # Write to the main repository + yield make_deferred_yieldable(threads.deferToThread( + self._write_file_synchronously, source, fname, + )) + + # Write to backup repository + yield self.copy_to_backup(path) + + defer.returnValue(fname) + + @defer.inlineCallbacks + def copy_to_backup(self, path): + """Copy a file from the primary to backup media store, if configured. + + Args: + path(str): Relative path to write file to + """ + if self.backup_base_path: + primary_fname = os.path.join(self.primary_base_path, path) + backup_fname = os.path.join(self.backup_base_path, path) + + # We can either wait for successful writing to the backup repository + # or write in the background and immediately return + if self.synchronous_backup_media_store: + yield make_deferred_yieldable(threads.deferToThread( + shutil.copyfile, primary_fname, backup_fname, + )) + else: + preserve_fn(threads.deferToThread)( + shutil.copyfile, primary_fname, backup_fname, + ) + @defer.inlineCallbacks def create_content(self, media_type, upload_name, content, content_length, auth_user): + """Store uploaded content for a local user and return the mxc URL + + Args: + media_type(str): The content type of the file + upload_name(str): The name of the file + content: A file like object that is the content to store + content_length(int): The length of the content + auth_user(str): The user_id of the uploader + + Returns: + Deferred[str]: The mxc url of the stored content + """ media_id = random_string(24) - fname = self.filepaths.local_media_filepath(media_id) - self._makedirs(fname) - - # This shouldn't block for very long because the content will have - # already been uploaded at this point. - with open(fname, "wb") as f: - f.write(content) + fname = yield self.write_to_file_and_backup( + content, self.filepaths.local_media_filepath_rel(media_id) + ) logger.info("Stored local media in file %r", fname) @@ -115,7 +190,7 @@ class MediaRepository(object): "media_length": content_length, } - yield self._generate_local_thumbnails(media_id, media_info) + yield self._generate_thumbnails(None, media_id, media_info) defer.returnValue("mxc://%s/%s" % (self.server_name, media_id)) @@ -148,9 +223,10 @@ class MediaRepository(object): def _download_remote_file(self, server_name, media_id): file_id = random_string(24) - fname = self.filepaths.remote_media_filepath( + fpath = self.filepaths.remote_media_filepath_rel( server_name, file_id ) + fname = os.path.join(self.primary_base_path, fpath) self._makedirs(fname) try: @@ -192,6 +268,8 @@ class MediaRepository(object): server_name, media_id) raise SynapseError(502, "Failed to fetch remote media") + yield self.copy_to_backup(fpath) + media_type = headers["Content-Type"][0] time_now_ms = self.clock.time_msec() @@ -244,7 +322,7 @@ class MediaRepository(object): "filesystem_id": file_id, } - yield self._generate_remote_thumbnails( + yield self._generate_thumbnails( server_name, media_id, media_info ) @@ -253,9 +331,8 @@ class MediaRepository(object): def _get_thumbnail_requirements(self, media_type): return self.thumbnail_requirements.get(media_type, ()) - def _generate_thumbnail(self, input_path, t_path, t_width, t_height, + def _generate_thumbnail(self, thumbnailer, t_width, t_height, t_method, t_type): - thumbnailer = Thumbnailer(input_path) m_width = thumbnailer.width m_height = thumbnailer.height @@ -267,72 +344,105 @@ class MediaRepository(object): return if t_method == "crop": - t_len = thumbnailer.crop(t_path, t_width, t_height, t_type) + t_byte_source = thumbnailer.crop(t_width, t_height, t_type) elif t_method == "scale": t_width, t_height = thumbnailer.aspect(t_width, t_height) t_width = min(m_width, t_width) t_height = min(m_height, t_height) - t_len = thumbnailer.scale(t_path, t_width, t_height, t_type) + t_byte_source = thumbnailer.scale(t_width, t_height, t_type) else: - t_len = None + t_byte_source = None - return t_len + return t_byte_source @defer.inlineCallbacks def generate_local_exact_thumbnail(self, media_id, t_width, t_height, t_method, t_type): input_path = self.filepaths.local_media_filepath(media_id) - t_path = self.filepaths.local_media_thumbnail( - media_id, t_width, t_height, t_type, t_method - ) - self._makedirs(t_path) - - t_len = yield preserve_context_over_fn( - threads.deferToThread, + thumbnailer = Thumbnailer(input_path) + t_byte_source = yield make_deferred_yieldable(threads.deferToThread( self._generate_thumbnail, - input_path, t_path, t_width, t_height, t_method, t_type - ) + thumbnailer, t_width, t_height, t_method, t_type + )) + + if t_byte_source: + try: + output_path = yield self.write_to_file_and_backup( + t_byte_source, + self.filepaths.local_media_thumbnail_rel( + media_id, t_width, t_height, t_type, t_method + ) + ) + finally: + t_byte_source.close() + + logger.info("Stored thumbnail in file %r", output_path) + + t_len = os.path.getsize(output_path) - if t_len: yield self.store.store_local_thumbnail( media_id, t_width, t_height, t_type, t_method, t_len ) - defer.returnValue(t_path) + defer.returnValue(output_path) @defer.inlineCallbacks def generate_remote_exact_thumbnail(self, server_name, file_id, media_id, t_width, t_height, t_method, t_type): input_path = self.filepaths.remote_media_filepath(server_name, file_id) - t_path = self.filepaths.remote_media_thumbnail( - server_name, file_id, t_width, t_height, t_type, t_method - ) - self._makedirs(t_path) - - t_len = yield preserve_context_over_fn( - threads.deferToThread, + thumbnailer = Thumbnailer(input_path) + t_byte_source = yield make_deferred_yieldable(threads.deferToThread( self._generate_thumbnail, - input_path, t_path, t_width, t_height, t_method, t_type - ) + thumbnailer, t_width, t_height, t_method, t_type + )) + + if t_byte_source: + try: + output_path = yield self.write_to_file_and_backup( + t_byte_source, + self.filepaths.remote_media_thumbnail_rel( + server_name, file_id, t_width, t_height, t_type, t_method + ) + ) + finally: + t_byte_source.close() + + logger.info("Stored thumbnail in file %r", output_path) + + t_len = os.path.getsize(output_path) - if t_len: yield self.store.store_remote_media_thumbnail( server_name, media_id, file_id, t_width, t_height, t_type, t_method, t_len ) - defer.returnValue(t_path) + defer.returnValue(output_path) @defer.inlineCallbacks - def _generate_local_thumbnails(self, media_id, media_info, url_cache=False): + def _generate_thumbnails(self, server_name, media_id, media_info, url_cache=False): + """Generate and store thumbnails for an image. + + Args: + server_name(str|None): The server name if remote media, else None if local + media_id(str) + media_info(dict) + url_cache(bool): If we are thumbnailing images downloaded for the URL cache, + used exclusively by the url previewer + + Returns: + Deferred[dict]: Dict with "width" and "height" keys of original image + """ media_type = media_info["media_type"] + file_id = media_info.get("filesystem_id") requirements = self._get_thumbnail_requirements(media_type) if not requirements: return - if url_cache: + if server_name: + input_path = self.filepaths.remote_media_filepath(server_name, file_id) + elif url_cache: input_path = self.filepaths.url_cache_filepath(media_id) else: input_path = self.filepaths.local_media_filepath(media_id) @@ -348,135 +458,72 @@ class MediaRepository(object): ) return - local_thumbnails = [] + # We deduplicate the thumbnail sizes by ignoring the cropped versions if + # they have the same dimensions of a scaled one. + thumbnails = {} + for r_width, r_height, r_method, r_type in requirements: + if r_method == "crop": + thumbnails.setdefault((r_width, r_height, r_type), r_method) + elif r_method == "scale": + t_width, t_height = thumbnailer.aspect(r_width, r_height) + t_width = min(m_width, t_width) + t_height = min(m_height, t_height) + thumbnails[(t_width, t_height, r_type)] = r_method - def generate_thumbnails(): - scales = set() - crops = set() - for r_width, r_height, r_method, r_type in requirements: - if r_method == "scale": - t_width, t_height = thumbnailer.aspect(r_width, r_height) - scales.add(( - min(m_width, t_width), min(m_height, t_height), r_type, - )) - elif r_method == "crop": - crops.add((r_width, r_height, r_type)) - - for t_width, t_height, t_type in scales: - t_method = "scale" - if url_cache: - t_path = self.filepaths.url_cache_thumbnail( - media_id, t_width, t_height, t_type, t_method - ) - else: - t_path = self.filepaths.local_media_thumbnail( - media_id, t_width, t_height, t_type, t_method - ) - self._makedirs(t_path) - t_len = thumbnailer.scale(t_path, t_width, t_height, t_type) - - local_thumbnails.append(( - media_id, t_width, t_height, t_type, t_method, t_len - )) - - for t_width, t_height, t_type in crops: - if (t_width, t_height, t_type) in scales: - # If the aspect ratio of the cropped thumbnail matches a purely - # scaled one then there is no point in calculating a separate - # thumbnail. - continue - t_method = "crop" - if url_cache: - t_path = self.filepaths.url_cache_thumbnail( - media_id, t_width, t_height, t_type, t_method - ) - else: - t_path = self.filepaths.local_media_thumbnail( - media_id, t_width, t_height, t_type, t_method - ) - self._makedirs(t_path) - t_len = thumbnailer.crop(t_path, t_width, t_height, t_type) - local_thumbnails.append(( - media_id, t_width, t_height, t_type, t_method, t_len - )) - - yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails) - - for l in local_thumbnails: - yield self.store.store_local_thumbnail(*l) - - defer.returnValue({ - "width": m_width, - "height": m_height, - }) - - @defer.inlineCallbacks - def _generate_remote_thumbnails(self, server_name, media_id, media_info): - media_type = media_info["media_type"] - file_id = media_info["filesystem_id"] - requirements = self._get_thumbnail_requirements(media_type) - if not requirements: - return - - remote_thumbnails = [] - - input_path = self.filepaths.remote_media_filepath(server_name, file_id) - thumbnailer = Thumbnailer(input_path) - m_width = thumbnailer.width - m_height = thumbnailer.height - - def generate_thumbnails(): - if m_width * m_height >= self.max_image_pixels: - logger.info( - "Image too large to thumbnail %r x %r > %r", - m_width, m_height, self.max_image_pixels - ) - return - - scales = set() - crops = set() - for r_width, r_height, r_method, r_type in requirements: - if r_method == "scale": - t_width, t_height = thumbnailer.aspect(r_width, r_height) - scales.add(( - min(m_width, t_width), min(m_height, t_height), r_type, - )) - elif r_method == "crop": - crops.add((r_width, r_height, r_type)) - - for t_width, t_height, t_type in scales: - t_method = "scale" - t_path = self.filepaths.remote_media_thumbnail( + # Now we generate the thumbnails for each dimension, store it + for (t_width, t_height, t_type), t_method in thumbnails.iteritems(): + # Work out the correct file name for thumbnail + if server_name: + file_path = self.filepaths.remote_media_thumbnail_rel( server_name, file_id, t_width, t_height, t_type, t_method ) - self._makedirs(t_path) - t_len = thumbnailer.scale(t_path, t_width, t_height, t_type) - remote_thumbnails.append([ - server_name, media_id, file_id, - t_width, t_height, t_type, t_method, t_len - ]) - - for t_width, t_height, t_type in crops: - if (t_width, t_height, t_type) in scales: - # If the aspect ratio of the cropped thumbnail matches a purely - # scaled one then there is no point in calculating a separate - # thumbnail. - continue - t_method = "crop" - t_path = self.filepaths.remote_media_thumbnail( - server_name, file_id, t_width, t_height, t_type, t_method + elif url_cache: + file_path = self.filepaths.url_cache_thumbnail_rel( + media_id, t_width, t_height, t_type, t_method ) - self._makedirs(t_path) - t_len = thumbnailer.crop(t_path, t_width, t_height, t_type) - remote_thumbnails.append([ + else: + file_path = self.filepaths.local_media_thumbnail_rel( + media_id, t_width, t_height, t_type, t_method + ) + + # Generate the thumbnail + if t_method == "crop": + t_byte_source = yield make_deferred_yieldable(threads.deferToThread( + thumbnailer.crop, + t_width, t_height, t_type, + )) + elif t_method == "scale": + t_byte_source = yield make_deferred_yieldable(threads.deferToThread( + thumbnailer.scale, + t_width, t_height, t_type, + )) + else: + logger.error("Unrecognized method: %r", t_method) + continue + + if not t_byte_source: + continue + + try: + # Write to disk + output_path = yield self.write_to_file_and_backup( + t_byte_source, file_path, + ) + finally: + t_byte_source.close() + + t_len = os.path.getsize(output_path) + + # Write to database + if server_name: + yield self.store.store_remote_media_thumbnail( server_name, media_id, file_id, t_width, t_height, t_type, t_method, t_len - ]) - - yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails) - - for r in remote_thumbnails: - yield self.store.store_remote_media_thumbnail(*r) + ) + else: + yield self.store.store_local_thumbnail( + media_id, t_width, t_height, t_type, t_method, t_len + ) defer.returnValue({ "width": m_width, @@ -497,6 +544,8 @@ class MediaRepository(object): logger.info("Deleting: %r", key) + # TODO: Should we delete from the backup store + with (yield self.remote_media_linearizer.queue(key)): full_path = self.filepaths.remote_media_filepath(origin, file_id) try: diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 895b480d5c..2a3e37fdf4 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -59,6 +59,7 @@ class PreviewUrlResource(Resource): self.store = hs.get_datastore() self.client = SpiderHttpClient(hs) self.media_repo = media_repo + self.primary_base_path = media_repo.primary_base_path self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist @@ -170,8 +171,8 @@ class PreviewUrlResource(Resource): logger.debug("got media_info of '%s'" % media_info) if _is_media(media_info['media_type']): - dims = yield self.media_repo._generate_local_thumbnails( - media_info['filesystem_id'], media_info, url_cache=True, + dims = yield self.media_repo._generate_thumbnails( + None, media_info['filesystem_id'], media_info, url_cache=True, ) og = { @@ -216,8 +217,8 @@ class PreviewUrlResource(Resource): if _is_media(image_info['media_type']): # TODO: make sure we don't choke on white-on-transparent images - dims = yield self.media_repo._generate_local_thumbnails( - image_info['filesystem_id'], image_info, url_cache=True, + dims = yield self.media_repo._generate_thumbnails( + None, image_info['filesystem_id'], image_info, url_cache=True, ) if dims: og["og:image:width"] = dims['width'] @@ -262,7 +263,8 @@ class PreviewUrlResource(Resource): file_id = datetime.date.today().isoformat() + '_' + random_string(16) - fname = self.filepaths.url_cache_filepath(file_id) + fpath = self.filepaths.url_cache_filepath_rel(file_id) + fname = os.path.join(self.primary_base_path, fpath) self.media_repo._makedirs(fname) try: @@ -273,6 +275,8 @@ class PreviewUrlResource(Resource): ) # FIXME: pass through 404s and other error messages nicely + yield self.media_repo.copy_to_backup(fpath) + media_type = headers["Content-Type"][0] time_now_ms = self.clock.time_msec() @@ -338,6 +342,9 @@ class PreviewUrlResource(Resource): def _expire_url_cache_data(self): """Clean up expired url cache content, media and thumbnails. """ + + # TODO: Delete from backup media store + now = self.clock.time_msec() # First we delete expired url cache entries diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py index 3868d4f65f..e1ee535b9a 100644 --- a/synapse/rest/media/v1/thumbnailer.py +++ b/synapse/rest/media/v1/thumbnailer.py @@ -50,12 +50,16 @@ class Thumbnailer(object): else: return ((max_height * self.width) // self.height, max_height) - def scale(self, output_path, width, height, output_type): - """Rescales the image to the given dimensions""" - scaled = self.image.resize((width, height), Image.ANTIALIAS) - return self.save_image(scaled, output_type, output_path) + def scale(self, width, height, output_type): + """Rescales the image to the given dimensions. - def crop(self, output_path, width, height, output_type): + Returns: + BytesIO: the bytes of the encoded image ready to be written to disk + """ + scaled = self.image.resize((width, height), Image.ANTIALIAS) + return self._encode_image(scaled, output_type) + + def crop(self, width, height, output_type): """Rescales and crops the image to the given dimensions preserving aspect:: (w_in / h_in) = (w_scaled / h_scaled) @@ -65,6 +69,9 @@ class Thumbnailer(object): Args: max_width: The largest possible width. max_height: The larget possible height. + + Returns: + BytesIO: the bytes of the encoded image ready to be written to disk """ if width * self.height > height * self.width: scaled_height = (width * self.height) // self.width @@ -82,13 +89,9 @@ class Thumbnailer(object): crop_left = (scaled_width - width) // 2 crop_right = width + crop_left cropped = scaled_image.crop((crop_left, 0, crop_right, height)) - return self.save_image(cropped, output_type, output_path) + return self._encode_image(cropped, output_type) - def save_image(self, output_image, output_type, output_path): + def _encode_image(self, output_image, output_type): output_bytes_io = BytesIO() output_image.save(output_bytes_io, self.FORMATS[output_type], quality=80) - output_bytes = output_bytes_io.getvalue() - with open(output_path, "wb") as output_file: - output_file.write(output_bytes) - logger.info("Stored thumbnail in file %r", output_path) - return len(output_bytes) + return output_bytes_io diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py index 4ab33f73bf..f6f498cdc5 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py @@ -93,7 +93,7 @@ class UploadResource(Resource): # TODO(markjh): parse content-dispostion content_uri = yield self.media_repo.create_content( - media_type, upload_name, request.content.read(), + media_type, upload_name, request.content, content_length, requester.user ) diff --git a/synapse/server.pyi b/synapse/server.pyi index 9570df5537..e8c0386b7f 100644 --- a/synapse/server.pyi +++ b/synapse/server.pyi @@ -1,4 +1,6 @@ import synapse.api.auth +import synapse.federation.transaction_queue +import synapse.federation.transport.client import synapse.handlers import synapse.handlers.auth import synapse.handlers.device @@ -27,3 +29,9 @@ class HomeServer(object): def get_state_handler(self) -> synapse.state.StateHandler: pass + + def get_federation_sender(self) -> synapse.federation.transaction_queue.TransactionQueue: + pass + + def get_federation_transport_client(self) -> synapse.federation.transport.client.TransportLayerClient: + pass diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 752f172058..ad5c6f7c65 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -21,7 +21,7 @@ from synapse.events.utils import prune_event from synapse.util.async import ObservableDeferred from synapse.util.logcontext import ( - preserve_fn, PreserveLoggingContext, preserve_context_over_deferred + preserve_fn, PreserveLoggingContext, make_deferred_yieldable ) from synapse.util.logutils import log_function from synapse.util.metrics import Measure @@ -88,13 +88,23 @@ class _EventPeristenceQueue(object): def add_to_queue(self, room_id, events_and_contexts, backfilled): """Add events to the queue, with the given persist_event options. + NB: due to the normal usage pattern of this method, it does *not* + follow the synapse logcontext rules, and leaves the logcontext in + place whether or not the returned deferred is ready. + Args: room_id (str): events_and_contexts (list[(EventBase, EventContext)]): backfilled (bool): + + Returns: + defer.Deferred: a deferred which will resolve once the events are + persisted. Runs its callbacks *without* a logcontext. """ queue = self._event_persist_queues.setdefault(room_id, deque()) if queue: + # if the last item in the queue has the same `backfilled` setting, + # we can just add these new events to that item. end_item = queue[-1] if end_item.backfilled == backfilled: end_item.events_and_contexts.extend(events_and_contexts) @@ -113,11 +123,11 @@ class _EventPeristenceQueue(object): def handle_queue(self, room_id, per_item_callback): """Attempts to handle the queue for a room if not already being handled. - The given callback will be invoked with for each item in the queue,1 + The given callback will be invoked with for each item in the queue, of type _EventPersistQueueItem. The per_item_callback will continuously be called with new items, unless the queue becomnes empty. The return value of the function will be given to the deferreds waiting on the item, - exceptions will be passed to the deferres as well. + exceptions will be passed to the deferreds as well. This function should therefore be called whenever anything is added to the queue. @@ -233,7 +243,7 @@ class EventsStore(SQLBaseStore): deferreds = [] for room_id, evs_ctxs in partitioned.iteritems(): - d = preserve_fn(self._event_persist_queue.add_to_queue)( + d = self._event_persist_queue.add_to_queue( room_id, evs_ctxs, backfilled=backfilled, ) @@ -242,7 +252,7 @@ class EventsStore(SQLBaseStore): for room_id in partitioned: self._maybe_start_persisting(room_id) - return preserve_context_over_deferred( + return make_deferred_yieldable( defer.gatherResults(deferreds, consumeErrors=True) ) @@ -267,7 +277,7 @@ class EventsStore(SQLBaseStore): self._maybe_start_persisting(event.room_id) - yield preserve_context_over_deferred(deferred) + yield make_deferred_yieldable(deferred) max_persisted_id = yield self._stream_id_gen.get_current_token() defer.returnValue((event.internal_metadata.stream_ordering, max_persisted_id)) @@ -1525,7 +1535,7 @@ class EventsStore(SQLBaseStore): if not allow_rejected: rows[:] = [r for r in rows if not r["rejects"]] - res = yield preserve_context_over_deferred(defer.gatherResults( + res = yield make_deferred_yieldable(defer.gatherResults( [ preserve_fn(self._get_event_from_row)( row["internal_metadata"], row["json"], row["redacts"], diff --git a/synapse/storage/group_server.py b/synapse/storage/group_server.py index 3af372de59..9e63db5c6c 100644 --- a/synapse/storage/group_server.py +++ b/synapse/storage/group_server.py @@ -56,6 +56,18 @@ class GroupServerStore(SQLBaseStore): desc="get_users_in_group", ) + def get_invited_users_in_group(self, group_id): + # TODO: Pagination + + return self._simple_select_onecol( + table="group_invites", + keyvalues={ + "group_id": group_id, + }, + retcol="user_id", + desc="get_invited_users_in_group", + ) + def get_rooms_in_group(self, group_id, include_private=False): # TODO: Pagination diff --git a/synapse/util/async.py b/synapse/util/async.py index 0fd5b42523..a0a9039475 100644 --- a/synapse/util/async.py +++ b/synapse/util/async.py @@ -53,6 +53,11 @@ class ObservableDeferred(object): Cancelling or otherwise resolving an observer will not affect the original ObservableDeferred. + + NB that it does not attempt to do anything with logcontexts; in general + you should probably make_deferred_yieldable the deferreds + returned by `observe`, and ensure that the original deferred runs its + callbacks in the sentinel logcontext. """ __slots__ = ["_deferred", "_observers", "_result"] diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index b6173ab2ee..821c735528 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -47,6 +47,7 @@ class RegisterRestServletTestCase(unittest.TestCase): self.hs.get_auth_handler = Mock(return_value=self.auth_handler) self.hs.get_device_handler = Mock(return_value=self.device_handler) self.hs.config.enable_registration = True + self.hs.config.auto_join_rooms = [] # init the thing we're testing self.servlet = RegisterRestServlet(self.hs) diff --git a/tests/storage/event_injector.py b/tests/storage/event_injector.py deleted file mode 100644 index 024ac15069..0000000000 --- a/tests/storage/event_injector.py +++ /dev/null @@ -1,76 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from twisted.internet import defer - -from synapse.api.constants import EventTypes - - -class EventInjector: - def __init__(self, hs): - self.hs = hs - self.store = hs.get_datastore() - self.message_handler = hs.get_handlers().message_handler - self.event_builder_factory = hs.get_event_builder_factory() - - @defer.inlineCallbacks - def create_room(self, room, user): - builder = self.event_builder_factory.new({ - "type": EventTypes.Create, - "sender": user.to_string(), - "room_id": room.to_string(), - "content": {}, - }) - - event, context = yield self.message_handler._create_new_client_event( - builder - ) - - yield self.store.persist_event(event, context) - - @defer.inlineCallbacks - def inject_room_member(self, room, user, membership): - builder = self.event_builder_factory.new({ - "type": EventTypes.Member, - "sender": user.to_string(), - "state_key": user.to_string(), - "room_id": room.to_string(), - "content": {"membership": membership}, - }) - - event, context = yield self.message_handler._create_new_client_event( - builder - ) - - yield self.store.persist_event(event, context) - - defer.returnValue(event) - - @defer.inlineCallbacks - def inject_message(self, room, user, body): - builder = self.event_builder_factory.new({ - "type": EventTypes.Message, - "sender": user.to_string(), - "state_key": user.to_string(), - "room_id": room.to_string(), - "content": {"body": body, "msgtype": u"message"}, - }) - - event, context = yield self.message_handler._create_new_client_event( - builder - ) - - yield self.store.persist_event(event, context) diff --git a/tests/util/test_log_context.py b/tests/util/test_logcontext.py similarity index 69% rename from tests/util/test_log_context.py rename to tests/util/test_logcontext.py index 9ffe209c4d..e2f7765f49 100644 --- a/tests/util/test_log_context.py +++ b/tests/util/test_logcontext.py @@ -94,3 +94,41 @@ class LoggingContextTestCase(unittest.TestCase): yield defer.succeed(None) return self._test_preserve_fn(nonblocking_function) + + @defer.inlineCallbacks + def test_make_deferred_yieldable(self): + # a function which retuns an incomplete deferred, but doesn't follow + # the synapse rules. + def blocking_function(): + d = defer.Deferred() + reactor.callLater(0, d.callback, None) + return d + + sentinel_context = LoggingContext.current_context() + + with LoggingContext() as context_one: + context_one.test_key = "one" + + d1 = logcontext.make_deferred_yieldable(blocking_function()) + # make sure that the context was reset by make_deferred_yieldable + self.assertIs(LoggingContext.current_context(), sentinel_context) + + yield d1 + + # now it should be restored + self._check_test_key("one") + + @defer.inlineCallbacks + def test_make_deferred_yieldable_on_non_deferred(self): + """Check that make_deferred_yieldable does the right thing when its + argument isn't actually a deferred""" + + with LoggingContext() as context_one: + context_one.test_key = "one" + + d1 = logcontext.make_deferred_yieldable("bum") + self._check_test_key("one") + + r = yield d1 + self.assertEqual(r, "bum") + self._check_test_key("one")