From d1c7c2a98a133bdae7747601699a6b9ae0f90c8b Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Wed, 24 Jul 2019 23:21:52 -0400 Subject: [PATCH 01/55] allow devices to be marked as "hidden" This is a prerequisite for cross-signing, as it allows us to create other things that live within the device namespace, so they can be used for signatures. --- synapse/storage/devices.py | 63 ++++++++++++++----- .../storage/schema/delta/56/signing_keys.sql | 18 ++++++ 2 files changed, 65 insertions(+), 16 deletions(-) create mode 100644 synapse/storage/schema/delta/56/signing_keys.sql diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index d2b113a4e7..b73401bc26 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -1,5 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd +# Copyright 2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,7 +22,7 @@ from canonicaljson import json from twisted.internet import defer -from synapse.api.errors import StoreError +from synapse.api.errors import Codes, StoreError from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import Cache, SQLBaseStore, db_to_json from synapse.storage.background_updates import BackgroundUpdateStore @@ -35,6 +37,7 @@ DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = ( class DeviceWorkerStore(SQLBaseStore): + @defer.inlineCallbacks def get_device(self, user_id, device_id): """Retrieve a device. @@ -46,12 +49,15 @@ class DeviceWorkerStore(SQLBaseStore): Raises: StoreError: if the device is not found """ - return self._simple_select_one( + ret = yield self._simple_select_one( table="devices", keyvalues={"user_id": user_id, "device_id": device_id}, - retcols=("user_id", "device_id", "display_name"), + retcols=("user_id", "device_id", "display_name", "hidden"), desc="get_device", ) + if ret["hidden"]: + raise StoreError(404, "No row found (devices)") + return ret @defer.inlineCallbacks def get_devices_by_user(self, user_id): @@ -67,11 +73,11 @@ class DeviceWorkerStore(SQLBaseStore): devices = yield self._simple_select_list( table="devices", keyvalues={"user_id": user_id}, - retcols=("user_id", "device_id", "display_name"), + retcols=("user_id", "device_id", "display_name", "hidden"), desc="get_devices_by_user", ) - defer.returnValue({d["device_id"]: d for d in devices}) + defer.returnValue({d["device_id"]: d for d in devices if not d["hidden"]}) @defer.inlineCallbacks def get_devices_by_remote(self, destination, from_stream_id, limit): @@ -540,6 +546,8 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): Returns: defer.Deferred: boolean whether the device was inserted or an existing device existed with that ID. + Raises: + StoreError: if the device is already in use """ key = (user_id, device_id) if self.device_id_exists_cache.get(key, None): @@ -552,12 +560,25 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): "user_id": user_id, "device_id": device_id, "display_name": initial_device_display_name, + "hidden": False, }, desc="store_device", or_ignore=True, ) + if not inserted: + # if the device already exists, check if it's a real device, or + # if the device ID is reserved by something else + hidden = yield self._simple_select_one_onecol( + "devices", + keyvalues={"user_id": user_id, "device_id": device_id}, + retcol="hidden", + ) + if hidden: + raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN) self.device_id_exists_cache.prefill(key, True) defer.returnValue(inserted) + except StoreError: + raise except Exception as e: logger.error( "store_device with device_id=%s(%r) user_id=%s(%r)" @@ -582,11 +603,11 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): Returns: defer.Deferred """ - yield self._simple_delete_one( - table="devices", - keyvalues={"user_id": user_id, "device_id": device_id}, - desc="delete_device", - ) + sql = """ + DELETE FROM devices + WHERE user_id = ? AND device_id = ? AND NOT COALESCE(hidden, ?) + """ + yield self._execute("delete_device", None, sql, user_id, device_id, False) self.device_id_exists_cache.invalidate((user_id, device_id)) @@ -600,13 +621,21 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): Returns: defer.Deferred """ - yield self._simple_delete_many( - table="devices", - column="device_id", - iterable=device_ids, - keyvalues={"user_id": user_id}, - desc="delete_devices", + + if not device_ids or len(device_ids) == 0: + return + sql = """ + DELETE FROM devices + WHERE user_id = ? AND device_id IN (%s) AND NOT COALESCE(hidden, ?) + """ % ( + ",".join("?" for _ in device_ids) ) + values = [user_id] + values.extend(device_ids) + values.append(False) + + yield self._execute("delete_devices", None, sql, *values) + for device_id in device_ids: self.device_id_exists_cache.invalidate((user_id, device_id)) @@ -628,6 +657,8 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): updates["display_name"] = new_display_name if not updates: return defer.succeed(None) + # FIXME: should only update if hidden is not True. But updating the + # display name of a hidden device should be harmless return self._simple_update_one( table="devices", keyvalues={"user_id": user_id, "device_id": device_id}, diff --git a/synapse/storage/schema/delta/56/signing_keys.sql b/synapse/storage/schema/delta/56/signing_keys.sql new file mode 100644 index 0000000000..51c96d3116 --- /dev/null +++ b/synapse/storage/schema/delta/56/signing_keys.sql @@ -0,0 +1,18 @@ +/* Copyright 2019 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- device list needs to know which ones are "real" devices, and which ones are +-- just used to avoid collisions +ALTER TABLE devices ADD COLUMN hidden BOOLEAN NULLABLE; From c659b9f94fff29adfb2abe4f6b345710b65e8741 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Thu, 25 Jul 2019 11:08:24 -0400 Subject: [PATCH 02/55] allow uploading keys for cross-signing --- synapse/api/errors.py | 1 + synapse/handlers/device.py | 17 ++ synapse/handlers/e2e_keys.py | 198 +++++++++++++++++- synapse/handlers/sync.py | 7 +- synapse/rest/client/v2_alpha/keys.py | 46 +++- synapse/storage/__init__.py | 5 +- synapse/storage/devices.py | 57 +++++ synapse/storage/end_to_end_keys.py | 174 ++++++++++++++- .../storage/schema/delta/56/signing_keys.sql | 41 ++++ synapse/types.py | 24 +++ tests/handlers/test_e2e_keys.py | 63 ++++++ 11 files changed, 621 insertions(+), 12 deletions(-) diff --git a/synapse/api/errors.py b/synapse/api/errors.py index ad3e262041..be15921bc6 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -61,6 +61,7 @@ class Codes(object): INCOMPATIBLE_ROOM_VERSION = "M_INCOMPATIBLE_ROOM_VERSION" WRONG_ROOM_KEYS_VERSION = "M_WRONG_ROOM_KEYS_VERSION" EXPIRED_ACCOUNT = "ORG_MATRIX_EXPIRED_ACCOUNT" + INVALID_SIGNATURE = "M_INVALID_SIGNATURE" class CodeMessageException(RuntimeError): diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 99e8413092..2a8fa9c818 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -1,5 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd +# Copyright 2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -408,6 +410,21 @@ class DeviceHandler(DeviceWorkerHandler): for host in hosts: self.federation_sender.send_device_messages(host) + @defer.inlineCallbacks + def notify_user_signature_update(self, from_user_id, user_ids): + """Notify a user that they have made new signatures of other users. + + Args: + from_user_id (str): the user who made the signature + user_ids (list[str]): the users IDs that have new signatures + """ + + position = yield self.store.add_user_signature_change_to_streams( + from_user_id, user_ids + ) + + self.notifier.on_new_event("device_list_key", position, users=[from_user_id]) + @defer.inlineCallbacks def on_federation_query_user_devices(self, user_id): stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index fdfe8611b6..6187f879ef 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd -# Copyright 2018 New Vector Ltd +# Copyright 2018-2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,12 +20,17 @@ import logging from six import iteritems from canonicaljson import encode_canonical_json, json +from signedjson.sign import SignatureVerifyException, verify_signed_json from twisted.internet import defer -from synapse.api.errors import CodeMessageException, SynapseError +from synapse.api.errors import CodeMessageException, Codes, SynapseError from synapse.logging.context import make_deferred_yieldable, run_in_background -from synapse.types import UserID, get_domain_from_id +from synapse.types import ( + UserID, + get_domain_from_id, + get_verify_key_from_cross_signing_key, +) from synapse.util.retryutils import NotRetryingDestination logger = logging.getLogger(__name__) @@ -46,7 +52,7 @@ class E2eKeysHandler(object): ) @defer.inlineCallbacks - def query_devices(self, query_body, timeout): + def query_devices(self, query_body, timeout, from_user_id): """ Handle a device key query from a client { @@ -64,6 +70,11 @@ class E2eKeysHandler(object): } } } + + Args: + from_user_id (str): the user making the query. This is used when + adding cross-signing signatures to limit what signatures users + can see. """ device_keys_query = query_body.get("device_keys", {}) @@ -118,6 +129,11 @@ class E2eKeysHandler(object): r = remote_queries_not_in_cache.setdefault(domain, {}) r[user_id] = remote_queries[user_id] + # Get cached cross-signing keys + cross_signing_keys = yield self.query_cross_signing_keys( + device_keys_query, from_user_id + ) + # Now fetch any devices that we don't have in our cache @defer.inlineCallbacks def do_remote_query(destination): @@ -131,6 +147,14 @@ class E2eKeysHandler(object): if user_id in destination_query: results[user_id] = keys + for user_id, key in remote_result["master_keys"].items(): + if user_id in destination_query: + cross_signing_keys["master"][user_id] = key + + for user_id, key in remote_result["self_signing_keys"].items(): + if user_id in destination_query: + cross_signing_keys["self_signing"][user_id] = key + except Exception as e: failures[destination] = _exception_to_failure(e) @@ -144,7 +168,73 @@ class E2eKeysHandler(object): ) ) - defer.returnValue({"device_keys": results, "failures": failures}) + ret = {"device_keys": results, "failures": failures} + + for key, value in iteritems(cross_signing_keys): + ret[key + "_keys"] = value + + defer.returnValue(ret) + + @defer.inlineCallbacks + def query_cross_signing_keys(self, query, from_user_id): + """Get cross-signing keys for users + + Args: + query (Iterable[string]) an iterable of user IDs. A dict whose keys + are user IDs satisfies this, so the query format used for + query_devices can be used here. + from_user_id (str): the user making the query. This is used when + adding cross-signing signatures to limit what signatures users + can see. + + Returns: + defer.Deferred[dict[str, dict[str, dict]]]: map from + (master|self_signing|user_signing) -> user_id -> key + """ + master_keys = {} + self_signing_keys = {} + user_signing_keys = {} + + for user_id in query: + # XXX: consider changing the store functions to allow querying + # multiple users simultaneously. + try: + key = yield self.store.get_e2e_cross_signing_key( + user_id, "master", from_user_id + ) + if key: + master_keys[user_id] = key + except Exception as e: + logger.info("Error getting master key: %s", e) + + try: + key = yield self.store.get_e2e_cross_signing_key( + user_id, "self_signing", from_user_id + ) + if key: + self_signing_keys[user_id] = key + except Exception as e: + logger.info("Error getting self-signing key: %s", e) + + # users can see other users' master and self-signing keys, but can + # only see their own user-signing keys + if from_user_id == user_id: + try: + key = yield self.store.get_e2e_cross_signing_key( + user_id, "user_signing", from_user_id + ) + if key: + user_signing_keys[user_id] = key + except Exception as e: + logger.info("Error getting user-signing key: %s", e) + + defer.returnValue( + { + "master": master_keys, + "self_signing": self_signing_keys, + "user_signing": user_signing_keys, + } + ) @defer.inlineCallbacks def query_local_devices(self, query): @@ -342,6 +432,104 @@ class E2eKeysHandler(object): yield self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys) + @defer.inlineCallbacks + def upload_signing_keys_for_user(self, user_id, keys): + """Upload signing keys for cross-signing + + Args: + user_id (string): the user uploading the keys + keys (dict[string, dict]): the signing keys + """ + + # if a master key is uploaded, then check it. Otherwise, load the + # stored master key, to check signatures on other keys + if "master_key" in keys: + master_key = keys["master_key"] + + _check_cross_signing_key(master_key, user_id, "master") + else: + master_key = yield self.store.get_e2e_cross_signing_key(user_id, "master") + + # if there is no master key, then we can't do anything, because all the + # other cross-signing keys need to be signed by the master key + if not master_key: + raise SynapseError(400, "No master key available", Codes.MISSING_PARAM) + + master_key_id, master_verify_key = get_verify_key_from_cross_signing_key( + master_key + ) + + # for the other cross-signing keys, make sure that they have valid + # signatures from the master key + if "self_signing_key" in keys: + self_signing_key = keys["self_signing_key"] + + _check_cross_signing_key( + self_signing_key, user_id, "self_signing", master_verify_key + ) + + if "user_signing_key" in keys: + user_signing_key = keys["user_signing_key"] + + _check_cross_signing_key( + user_signing_key, user_id, "user_signing", master_verify_key + ) + + # if everything checks out, then store the keys and send notifications + deviceids = [] + if "master_key" in keys: + yield self.store.set_e2e_cross_signing_key(user_id, "master", master_key) + deviceids.append(master_verify_key.version) + if "self_signing_key" in keys: + yield self.store.set_e2e_cross_signing_key( + user_id, "self_signing", self_signing_key + ) + deviceids.append( + get_verify_key_from_cross_signing_key(self_signing_key)[1].version + ) + if "user_signing_key" in keys: + yield self.store.set_e2e_cross_signing_key( + user_id, "user_signing", user_signing_key + ) + # the signature stream matches the semantics that we want for + # user-signing key updates: only the user themselves is notified of + # their own user-signing key updates + yield self.device_handler.notify_user_signature_update(user_id, [user_id]) + + # master key and self-signing key updates match the semantics of device + # list updates: all users who share an encrypted room are notified + if len(deviceids): + yield self.device_handler.notify_device_update(user_id, deviceids) + + defer.returnValue({}) + + +def _check_cross_signing_key(key, user_id, key_type, signing_key=None): + """Check a cross-signing key uploaded by a user. Performs some basic sanity + checking, and ensures that it is signed, if a signature is required. + + Args: + key (dict): the key data to verify + user_id (str): the user whose key is being checked + key_type (str): the type of key that the key should be + signing_key (VerifyKey): (optional) the signing key that the key should + be signed with. If omitted, signatures will not be checked. + """ + if ( + key.get("user_id") != user_id + or key_type not in key.get("usage", []) + or len(key.get("keys", {})) != 1 + ): + raise SynapseError(400, ("Invalid %s key" % (key_type,)), Codes.INVALID_PARAM) + + if signing_key: + try: + verify_signed_json(key, user_id, signing_key) + except SignatureVerifyException: + raise SynapseError( + 400, ("Invalid signature on %s key" % key_type), Codes.INVALID_SIGNATURE + ) + def _exception_to_failure(e): if isinstance(e, CodeMessageException): diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index cd1ac0a27a..c1c28a5fa1 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd -# Copyright 2018 New Vector Ltd +# Copyright 2018, 2019 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -1116,6 +1116,11 @@ class SyncHandler(object): # weren't in the previous sync *or* they left and rejoined. users_that_have_changed.update(newly_joined_or_invited_users) + user_signatures_changed = yield self.store.get_users_whose_signatures_changed( + user_id, since_token.device_list_key + ) + users_that_have_changed.update(user_signatures_changed) + # Now find users that we no longer track for room_id in newly_left_rooms: left_users = yield self.state.get_current_users_in_room(room_id) diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 45c9928b65..3eaf1fd8a4 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2019 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -26,7 +27,7 @@ from synapse.http.servlet import ( ) from synapse.types import StreamToken -from ._base import client_patterns +from ._base import client_patterns, interactive_auth_handler logger = logging.getLogger(__name__) @@ -145,10 +146,11 @@ class KeyQueryServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request): - yield self.auth.get_user_by_req(request, allow_guest=True) + requester = yield self.auth.get_user_by_req(request, allow_guest=True) + user_id = requester.user.to_string() timeout = parse_integer(request, "timeout", 10 * 1000) body = parse_json_object_from_request(request) - result = yield self.e2e_keys_handler.query_devices(body, timeout) + result = yield self.e2e_keys_handler.query_devices(body, timeout, user_id) defer.returnValue((200, result)) @@ -227,8 +229,46 @@ class OneTimeKeyServlet(RestServlet): defer.returnValue((200, result)) +class SigningKeyUploadServlet(RestServlet): + """ + POST /keys/device_signing/upload HTTP/1.1 + Content-Type: application/json + + { + } + """ + + PATTERNS = client_patterns("/keys/device_signing/upload$", releases=()) + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super(SigningKeyUploadServlet, self).__init__() + self.hs = hs + self.auth = hs.get_auth() + self.e2e_keys_handler = hs.get_e2e_keys_handler() + self.auth_handler = hs.get_auth_handler() + + @interactive_auth_handler + @defer.inlineCallbacks + def on_POST(self, request): + requester = yield self.auth.get_user_by_req(request) + user_id = requester.user.to_string() + body = parse_json_object_from_request(request) + + yield self.auth_handler.validate_user_via_ui_auth( + requester, body, self.hs.get_ip_from_request(request) + ) + + result = yield self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body) + defer.returnValue((200, result)) + + def register_servlets(hs, http_server): KeyUploadServlet(hs).register(http_server) KeyQueryServlet(hs).register(http_server) KeyChangesServlet(hs).register(http_server) OneTimeKeyServlet(hs).register(http_server) + SigningKeyUploadServlet(hs).register(http_server) diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 6b0ca80087..c20ba1001c 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd -# Copyright 2018 New Vector Ltd +# Copyright 2018,2019 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -207,6 +207,9 @@ class DataStore( self._device_list_stream_cache = StreamChangeCache( "DeviceListStreamChangeCache", device_list_max ) + self._user_signature_stream_cache = StreamChangeCache( + "UserSignatureStreamChangeCache", device_list_max + ) self._device_list_federation_stream_cache = StreamChangeCache( "DeviceListFederationStreamChangeCache", device_list_max ) diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index b73401bc26..ed372e2fc4 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -302,6 +302,41 @@ class DeviceWorkerStore(SQLBaseStore): """ txn.execute(sql, (destination, stream_id)) + @defer.inlineCallbacks + def add_user_signature_change_to_streams(self, from_user_id, user_ids): + """Persist that a user has made new signatures + + Args: + from_user_id (str): the user who made the signatures + user_ids (list[str]): the users who were signed + """ + + with self._device_list_id_gen.get_next() as stream_id: + yield self.runInteraction( + "add_user_sig_change_to_streams", + self._add_user_signature_change_txn, + from_user_id, + user_ids, + stream_id, + ) + defer.returnValue(stream_id) + + def _add_user_signature_change_txn(self, txn, from_user_id, user_ids, stream_id): + txn.call_after( + self._user_signature_stream_cache.entity_has_changed, + from_user_id, + stream_id, + ) + self._simple_insert_txn( + txn, + "user_signature_stream", + values={ + "stream_id": stream_id, + "from_user_id": from_user_id, + "user_ids": json.dumps(user_ids), + }, + ) + def get_device_stream_token(self): return self._device_list_id_gen.get_current_token() @@ -440,6 +475,28 @@ class DeviceWorkerStore(SQLBaseStore): "get_users_whose_devices_changed", _get_users_whose_devices_changed_txn ) + @defer.inlineCallbacks + def get_users_whose_signatures_changed(self, user_id, from_key): + """Get the users who have new cross-signing signatures made by `user_id` since + `from_key`. + + Args: + user_id (str): the user who made the signatures + from_key (str): The device lists stream token + """ + from_key = int(from_key) + if self._user_signature_stream_cache.has_entity_changed(user_id, from_key): + sql = """ + SELECT DISTINCT user_ids FROM user_signature_stream + WHERE from_user_id = ? AND stream_id > ? + """ + rows = yield self._execute( + "get_users_whose_signatures_changed", None, sql, user_id, from_key + ) + defer.returnValue(set(user for row in rows for user in json.loads(row[0]))) + else: + defer.returnValue(set()) + def get_all_device_list_changes_for_remotes(self, from_key, to_key): """Return a list of `(stream_id, user_id, destination)` which is the combined list of changes to devices, and which destinations need to be diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index 2fabb9e2cb..bb5f7d94eb 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -1,5 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,9 +14,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import time + from six import iteritems -from canonicaljson import encode_canonical_json +from canonicaljson import encode_canonical_json, json from twisted.internet import defer @@ -85,11 +89,12 @@ class EndToEndKeyWorkerStore(SQLBaseStore): " k.key_json" " FROM devices d" " %s JOIN e2e_device_keys_json k USING (user_id, device_id)" - " WHERE %s" + " WHERE (%s) AND NOT COALESCE(d.hidden, ?)" ) % ( "LEFT" if include_all_devices else "INNER", " OR ".join("(" + q + ")" for q in query_clauses), ) + query_params.append(False) txn.execute(sql, query_params) rows = self.cursor_to_dict(txn) @@ -281,3 +286,168 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): return self.runInteraction( "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn ) + + def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key): + """Set a user's cross-signing key. + + Args: + txn (twisted.enterprise.adbapi.Connection): db connection + user_id (str): the user to set the signing key for + key_type (str): the type of key that is being set: either 'master' + for a master key, 'self_signing' for a self-signing key, or + 'user_signing' for a user-signing key + key (dict): the key data + """ + # the cross-signing keys need to occupy the same namespace as devices, + # since signatures are identified by device ID. So add an entry to the + # device table to make sure that we don't have a collision with device + # IDs + + # the 'key' dict will look something like: + # { + # "user_id": "@alice:example.com", + # "usage": ["self_signing"], + # "keys": { + # "ed25519:base64+self+signing+public+key": "base64+self+signing+public+key", + # }, + # "signatures": { + # "@alice:example.com": { + # "ed25519:base64+master+public+key": "base64+signature" + # } + # } + # } + # The "keys" property must only have one entry, which will be the public + # key, so we just grab the first value in there + pubkey = next(iter(key["keys"].values())) + self._simple_insert( + "devices", + values={ + "user_id": user_id, + "device_id": pubkey, + "display_name": key_type + " signing key", + "hidden": True, + }, + desc="store_master_key_device", + ) + + # and finally, store the key itself + self._simple_insert( + "e2e_cross_signing_keys", + values={ + "user_id": user_id, + "keytype": key_type, + "keydata": json.dumps(key), + "added_ts": time.time() * 1000, + }, + desc="store_master_key", + ) + + def set_e2e_cross_signing_key(self, user_id, key_type, key): + """Set a user's cross-signing key. + + Args: + user_id (str): the user to set the user-signing key for + key_type (str): the type of cross-signing key to set + key (dict): the key data + """ + return self.runInteraction( + "add_e2e_cross_signing_key", + self._set_e2e_cross_signing_key_txn, + user_id, + key_type, + key, + ) + + def _get_e2e_cross_signing_key_txn(self, txn, user_id, key_type, from_user_id=None): + """Returns a user's cross-signing key. + + Args: + txn (twisted.enterprise.adbapi.Connection): db connection + user_id (str): the user whose key is being requested + key_type (str): the type of key that is being set: either 'master' + for a master key, 'self_signing' for a self-signing key, or + 'user_signing' for a user-signing key + from_user_id (str): if specified, signatures made by this user on + the key will be included in the result + + Returns: + dict of the key data + """ + sql = ( + "SELECT keydata " + " FROM e2e_cross_signing_keys " + " WHERE user_id = ? AND keytype = ? ORDER BY added_ts DESC LIMIT 1" + ) + txn.execute(sql, (user_id, key_type)) + row = txn.fetchone() + if not row: + return None + key = json.loads(row[0]) + + device_id = None + for k in key["keys"].values(): + device_id = k + + if from_user_id is not None: + sql = ( + "SELECT key_id, signature " + " FROM e2e_cross_signing_signatures " + " WHERE user_id = ? " + " AND target_user_id = ? " + " AND target_device_id = ? " + ) + txn.execute(sql, (from_user_id, user_id, device_id)) + row = txn.fetchone() + if row: + key.setdefault("signatures", {}).setdefault(from_user_id, {})[ + row[0] + ] = row[1] + + return key + + def get_e2e_cross_signing_key(self, user_id, key_type, from_user_id=None): + """Returns a user's cross-signing key. + + Args: + user_id (str): the user whose self-signing key is being requested + key_type (str): the type of cross-signing key to get + from_user_id (str): if specified, signatures made by this user on + the self-signing key will be included in the result + + Returns: + dict of the key data + """ + return self.runInteraction( + "get_e2e_cross_signing_key", + self._get_e2e_cross_signing_key_txn, + user_id, + key_type, + from_user_id, + ) + + def store_e2e_cross_signing_signatures(self, user_id, signatures): + """Stores cross-signing signatures. + + Args: + user_id (str): the user who made the signatures + signatures (iterable[(str, str, str, str)]): signatures to add - each + a tuple of (key_id, target_user_id, target_device_id, signature), + where key_id is the ID of the key (including the signature + algorithm) that made the signature, target_user_id and + target_device_id indicate the device being signed, and signature + is the signature of the device + """ + return self._simple_insert_many( + "e2e_cross_signing_signatures", + [ + { + "user_id": user_id, + "key_id": key_id, + "target_user_id": target_user_id, + "target_device_id": target_device_id, + "signature": signature, + } + for (key_id, target_user_id, target_device_id, signature) in signatures + ], + "add_e2e_signing_key", + ) diff --git a/synapse/storage/schema/delta/56/signing_keys.sql b/synapse/storage/schema/delta/56/signing_keys.sql index 51c96d3116..771740e970 100644 --- a/synapse/storage/schema/delta/56/signing_keys.sql +++ b/synapse/storage/schema/delta/56/signing_keys.sql @@ -13,6 +13,47 @@ * limitations under the License. */ +-- cross-signing keys +CREATE TABLE IF NOT EXISTS e2e_cross_signing_keys ( + user_id TEXT NOT NULL, + -- the type of cross-signing key (master, user_signing, or self_signing) + keytype TEXT NOT NULL, + -- the full key information, as a json-encoded dict + keydata TEXT NOT NULL, + -- time that the key was added + added_ts BIGINT NOT NULL +); + +CREATE UNIQUE INDEX e2e_cross_signing_keys_idx ON e2e_cross_signing_keys(user_id, keytype, added_ts); + +-- cross-signing signatures +CREATE TABLE IF NOT EXISTS e2e_cross_signing_signatures ( + -- user who did the signing + user_id TEXT NOT NULL, + -- key used to sign + key_id TEXT NOT NULL, + -- user who was signed + target_user_id TEXT NOT NULL, + -- device/key that was signed + target_device_id TEXT NOT NULL, + -- the actual signature + signature TEXT NOT NULL +); + +CREATE UNIQUE INDEX e2e_cross_signing_signatures_idx ON e2e_cross_signing_signatures(user_id, target_user_id, target_device_id); + +-- stream of user signature updates +CREATE TABLE IF NOT EXISTS user_signature_stream ( + -- uses the same stream ID as device list stream + stream_id BIGINT NOT NULL, + -- user who did the signing + from_user_id TEXT NOT NULL, + -- list of users who were signed, as a JSON array + user_ids TEXT NOT NULL +); + +CREATE UNIQUE INDEX user_signature_stream_idx ON user_signature_stream(stream_id); + -- device list needs to know which ones are "real" devices, and which ones are -- just used to avoid collisions ALTER TABLE devices ADD COLUMN hidden BOOLEAN NULLABLE; diff --git a/synapse/types.py b/synapse/types.py index 51eadb6ad4..7a80471a0c 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,6 +18,8 @@ import string from collections import namedtuple import attr +from signedjson.key import decode_verify_key_bytes +from unpaddedbase64 import decode_base64 from synapse.api.errors import SynapseError @@ -475,3 +478,24 @@ class ReadReceipt(object): user_id = attr.ib() event_ids = attr.ib() data = attr.ib() + + +def get_verify_key_from_cross_signing_key(key_info): + """Get the key ID and signedjson verify key from a cross-signing key dict + + Args: + key_info (dict): a cross-signing key dict, which must have a "keys" + property that has exactly one item in it + + Returns: + (str, VerifyKey): the key ID and verify key for the cross-signing key + """ + # make sure that exactly one key is provided + if "keys" not in key_info: + raise SynapseError(400, "Invalid key") + keys = key_info["keys"] + if len(keys) != 1: + raise SynapseError(400, "Invalid key") + # and return that one key + for key_id, key_data in keys.items(): + return (key_id, decode_verify_key_bytes(key_id, decode_base64(key_data))) diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 8dccc6826e..9ae4cb6ea2 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -1,5 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd +# Copyright 2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -145,3 +147,64 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "one_time_keys": {local_user: {device_id: {"alg1:k1": "key1"}}}, }, ) + + @defer.inlineCallbacks + def test_replace_master_key(self): + """uploading a new signing key should make the old signing key unavailable""" + local_user = "@boris:" + self.hs.hostname + keys1 = { + "master_key": { + # private key: 2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0 + "user_id": local_user, + "usage": ["master"], + "keys": { + "ed25519:nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk": "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk" + }, + } + } + yield self.handler.upload_signing_keys_for_user(local_user, keys1) + + keys2 = { + "master_key": { + # private key: 4TL4AjRYwDVwD3pqQzcor+ez/euOB1/q78aTJ+czDNs + "user_id": local_user, + "usage": ["master"], + "keys": { + "ed25519:Hq6gL+utB4ET+UvD5ci0kgAwsX6qP/zvf8v6OInU5iw": "Hq6gL+utB4ET+UvD5ci0kgAwsX6qP/zvf8v6OInU5iw" + }, + } + } + yield self.handler.upload_signing_keys_for_user(local_user, keys2) + + devices = yield self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user) + self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]}) + + @defer.inlineCallbacks + def test_self_signing_key_doesnt_show_up_as_device(self): + """signing keys should be hidden when fetching a user's devices""" + local_user = "@boris:" + self.hs.hostname + keys1 = { + "master_key": { + # private key: 2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0 + "user_id": local_user, + "usage": ["master"], + "keys": { + "ed25519:nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk": "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk" + }, + } + } + yield self.handler.upload_signing_keys_for_user(local_user, keys1) + + res = None + try: + yield self.hs.get_device_handler().check_device_registered( + user_id=local_user, + device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk", + initial_device_display_name="new display name", + ) + except errors.SynapseError as e: + res = e.code + self.assertEqual(res, 400) + + res = yield self.handler.query_local_devices({local_user: None}) + self.assertDictEqual(res, {local_user: {}}) From 781ade836b05b7e327baa7f927c553941edcc368 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Tue, 30 Jul 2019 23:09:50 -0400 Subject: [PATCH 03/55] apply changes from PR review --- synapse/storage/devices.py | 33 +++++++++---------- synapse/storage/end_to_end_keys.py | 2 +- .../{signing_keys.sql => hidden_devices.sql} | 2 +- 3 files changed, 17 insertions(+), 20 deletions(-) rename synapse/storage/schema/delta/56/{signing_keys.sql => hidden_devices.sql} (92%) diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index b73401bc26..f62ed12386 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -37,9 +37,9 @@ DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = ( class DeviceWorkerStore(SQLBaseStore): - @defer.inlineCallbacks def get_device(self, user_id, device_id): - """Retrieve a device. + """Retrieve a device. Only returns devices that are not marked as + hidden. Args: user_id (str): The ID of the user which owns the device @@ -49,19 +49,17 @@ class DeviceWorkerStore(SQLBaseStore): Raises: StoreError: if the device is not found """ - ret = yield self._simple_select_one( + return self._simple_select_one( table="devices", - keyvalues={"user_id": user_id, "device_id": device_id}, + keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, retcols=("user_id", "device_id", "display_name", "hidden"), desc="get_device", ) - if ret["hidden"]: - raise StoreError(404, "No row found (devices)") - return ret @defer.inlineCallbacks def get_devices_by_user(self, user_id): - """Retrieve all of a user's registered devices. + """Retrieve all of a user's registered devices. Only returns devices + that are not marked as hidden. Args: user_id (str): @@ -72,12 +70,12 @@ class DeviceWorkerStore(SQLBaseStore): """ devices = yield self._simple_select_list( table="devices", - keyvalues={"user_id": user_id}, - retcols=("user_id", "device_id", "display_name", "hidden"), + keyvalues={"user_id": user_id, "hidden": False}, + retcols=("user_id", "device_id", "display_name"), desc="get_devices_by_user", ) - defer.returnValue({d["device_id"]: d for d in devices if not d["hidden"]}) + defer.returnValue({d["device_id"]: d for d in devices}) @defer.inlineCallbacks def get_devices_by_remote(self, destination, from_stream_id, limit): @@ -605,9 +603,9 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): """ sql = """ DELETE FROM devices - WHERE user_id = ? AND device_id = ? AND NOT COALESCE(hidden, ?) + WHERE user_id = ? AND device_id = ? AND NOT hidden """ - yield self._execute("delete_device", None, sql, user_id, device_id, False) + yield self._execute("delete_device", None, sql, user_id, device_id) self.device_id_exists_cache.invalidate((user_id, device_id)) @@ -626,7 +624,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): return sql = """ DELETE FROM devices - WHERE user_id = ? AND device_id IN (%s) AND NOT COALESCE(hidden, ?) + WHERE user_id = ? AND device_id IN (%s) AND NOT hidden """ % ( ",".join("?" for _ in device_ids) ) @@ -640,7 +638,8 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): self.device_id_exists_cache.invalidate((user_id, device_id)) def update_device(self, user_id, device_id, new_display_name=None): - """Update a device. + """Update a device. Only updates the device if it is not marked as + hidden. Args: user_id (str): The ID of the user which owns the device @@ -657,11 +656,9 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): updates["display_name"] = new_display_name if not updates: return defer.succeed(None) - # FIXME: should only update if hidden is not True. But updating the - # display name of a hidden device should be harmless return self._simple_update_one( table="devices", - keyvalues={"user_id": user_id, "device_id": device_id}, + keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, updatevalues=updates, desc="update_device", ) diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index 2fabb9e2cb..66eb509588 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -85,7 +85,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): " k.key_json" " FROM devices d" " %s JOIN e2e_device_keys_json k USING (user_id, device_id)" - " WHERE %s" + " WHERE %s AND NOT d.hidden" ) % ( "LEFT" if include_all_devices else "INNER", " OR ".join("(" + q + ")" for q in query_clauses), diff --git a/synapse/storage/schema/delta/56/signing_keys.sql b/synapse/storage/schema/delta/56/hidden_devices.sql similarity index 92% rename from synapse/storage/schema/delta/56/signing_keys.sql rename to synapse/storage/schema/delta/56/hidden_devices.sql index 51c96d3116..67f8b20297 100644 --- a/synapse/storage/schema/delta/56/signing_keys.sql +++ b/synapse/storage/schema/delta/56/hidden_devices.sql @@ -15,4 +15,4 @@ -- device list needs to know which ones are "real" devices, and which ones are -- just used to avoid collisions -ALTER TABLE devices ADD COLUMN hidden BOOLEAN NULLABLE; +ALTER TABLE devices ADD COLUMN hidden BOOLEAN DEFAULT FALSE; From 2997a91250c4af915be70d2be38df1b3889c4c2d Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Tue, 30 Jul 2019 23:14:00 -0400 Subject: [PATCH 04/55] add changelog file --- changelog.d/5759.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/5759.misc diff --git a/changelog.d/5759.misc b/changelog.d/5759.misc new file mode 100644 index 0000000000..c0bc566c4c --- /dev/null +++ b/changelog.d/5759.misc @@ -0,0 +1 @@ +Allow devices to be marked as hidden, for use by features such as cross-signing. \ No newline at end of file From 185188be03e278a5a9a24f7e206e7fc2410415a7 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Wed, 31 Jul 2019 15:18:15 -0400 Subject: [PATCH 05/55] remove extra SQL query param --- synapse/storage/devices.py | 1 - 1 file changed, 1 deletion(-) diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index e7800da4f7..b3e8c7396d 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -630,7 +630,6 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): ) values = [user_id] values.extend(device_ids) - values.append(False) yield self._execute("delete_devices", None, sql, *values) From 430ea08186750ef67899bc302c0b6bb32c2f111c Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Wed, 31 Jul 2019 15:38:11 -0400 Subject: [PATCH 06/55] PostgreSQL, Y U no like? --- synapse/storage/devices.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index b3e8c7396d..a1f12df907 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -603,9 +603,9 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): """ sql = """ DELETE FROM devices - WHERE user_id = ? AND device_id = ? AND NOT hidden + WHERE user_id = ? AND device_id = ? AND hidden = ? """ - yield self._execute("delete_device", None, sql, user_id, device_id) + yield self._execute("delete_device", None, sql, user_id, device_id, False) self.device_id_exists_cache.invalidate((user_id, device_id)) @@ -624,12 +624,13 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): return sql = """ DELETE FROM devices - WHERE user_id = ? AND device_id IN (%s) AND NOT hidden + WHERE user_id = ? AND device_id IN (%s) AND hidden = ? """ % ( ",".join("?" for _ in device_ids) ) values = [user_id] values.extend(device_ids) + values.append(False) yield self._execute("delete_devices", None, sql, *values) From 73b26f827ccb96a10629ecb0737bd3db4915bb14 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Wed, 31 Jul 2019 18:37:05 -0400 Subject: [PATCH 07/55] really fix queries to work with Postgres (by going back to not using SQL directly) --- synapse/storage/devices.py | 30 +++++++++++------------------- 1 file changed, 11 insertions(+), 19 deletions(-) diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index a1f12df907..9f2bb40834 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -601,11 +601,11 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): Returns: defer.Deferred """ - sql = """ - DELETE FROM devices - WHERE user_id = ? AND device_id = ? AND hidden = ? - """ - yield self._execute("delete_device", None, sql, user_id, device_id, False) + yield self._simple_delete_one( + table="devices", + keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, + desc="delete_device", + ) self.device_id_exists_cache.invalidate((user_id, device_id)) @@ -619,21 +619,13 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): Returns: defer.Deferred """ - - if not device_ids or len(device_ids) == 0: - return - sql = """ - DELETE FROM devices - WHERE user_id = ? AND device_id IN (%s) AND hidden = ? - """ % ( - ",".join("?" for _ in device_ids) + yield self._simple_delete_many( + table="devices", + column="device_id", + iterable=device_ids, + keyvalues={"user_id": user_id, "hidden": False}, + desc="delete_devices", ) - values = [user_id] - values.extend(device_ids) - values.append(False) - - yield self._execute("delete_devices", None, sql, *values) - for device_id in device_ids: self.device_id_exists_cache.invalidate((user_id, device_id)) From d78a4fe156e574ad1f28cf3b12ed0bb11bc077f4 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Thu, 1 Aug 2019 02:16:09 -0400 Subject: [PATCH 08/55] don't need to return the hidden column any more --- synapse/storage/devices.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index 9f2bb40834..991e28ea24 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -52,7 +52,7 @@ class DeviceWorkerStore(SQLBaseStore): return self._simple_select_one( table="devices", keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, - retcols=("user_id", "device_id", "display_name", "hidden"), + retcols=("user_id", "device_id", "display_name"), desc="get_device", ) From fac1cdc5626ab2d59861a6aead8a44e7638934ba Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Thu, 1 Aug 2019 21:51:19 -0400 Subject: [PATCH 09/55] make changes from PR review --- synapse/handlers/e2e_keys.py | 24 ++++++-- .../schema/delta/56/hidden_devices.sql | 41 -------------- .../storage/schema/delta/56/signing_keys.sql | 55 +++++++++++++++++++ synapse/types.py | 4 +- 4 files changed, 75 insertions(+), 49 deletions(-) create mode 100644 synapse/storage/schema/delta/56/signing_keys.sql diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 39f4ec8e60..9081c3f64c 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -510,9 +510,18 @@ class E2eKeysHandler(object): if not master_key: raise SynapseError(400, "No master key available", Codes.MISSING_PARAM) - master_key_id, master_verify_key = get_verify_key_from_cross_signing_key( - master_key - ) + try: + master_key_id, master_verify_key = get_verify_key_from_cross_signing_key( + master_key + ) + except ValueError: + if "master_key" in keys: + # the invalid key came from the request + raise SynapseError(400, "Invalid master key", Codes.INVALID_PARAM) + else: + # the invalid key came from the database + logger.error("Invalid master key found for user %s", user_id) + raise SynapseError(500, "Invalid master key") # for the other cross-signing keys, make sure that they have valid # signatures from the master key @@ -539,9 +548,12 @@ class E2eKeysHandler(object): yield self.store.set_e2e_cross_signing_key( user_id, "self_signing", self_signing_key ) - deviceids.append( - get_verify_key_from_cross_signing_key(self_signing_key)[1].version - ) + try: + deviceids.append( + get_verify_key_from_cross_signing_key(self_signing_key)[1].version + ) + except ValueError: + raise SynapseError(400, "Invalid self-signing key", Codes.INVALID_PARAM) if "user_signing_key" in keys: yield self.store.set_e2e_cross_signing_key( user_id, "user_signing", user_signing_key diff --git a/synapse/storage/schema/delta/56/hidden_devices.sql b/synapse/storage/schema/delta/56/hidden_devices.sql index e1cd8cc2c1..67f8b20297 100644 --- a/synapse/storage/schema/delta/56/hidden_devices.sql +++ b/synapse/storage/schema/delta/56/hidden_devices.sql @@ -13,47 +13,6 @@ * limitations under the License. */ --- cross-signing keys -CREATE TABLE IF NOT EXISTS e2e_cross_signing_keys ( - user_id TEXT NOT NULL, - -- the type of cross-signing key (master, user_signing, or self_signing) - keytype TEXT NOT NULL, - -- the full key information, as a json-encoded dict - keydata TEXT NOT NULL, - -- time that the key was added - added_ts BIGINT NOT NULL -); - -CREATE UNIQUE INDEX e2e_cross_signing_keys_idx ON e2e_cross_signing_keys(user_id, keytype, added_ts); - --- cross-signing signatures -CREATE TABLE IF NOT EXISTS e2e_cross_signing_signatures ( - -- user who did the signing - user_id TEXT NOT NULL, - -- key used to sign - key_id TEXT NOT NULL, - -- user who was signed - target_user_id TEXT NOT NULL, - -- device/key that was signed - target_device_id TEXT NOT NULL, - -- the actual signature - signature TEXT NOT NULL -); - -CREATE UNIQUE INDEX e2e_cross_signing_signatures_idx ON e2e_cross_signing_signatures(user_id, target_user_id, target_device_id); - --- stream of user signature updates -CREATE TABLE IF NOT EXISTS user_signature_stream ( - -- uses the same stream ID as device list stream - stream_id BIGINT NOT NULL, - -- user who did the signing - from_user_id TEXT NOT NULL, - -- list of users who were signed, as a JSON array - user_ids TEXT NOT NULL -); - -CREATE UNIQUE INDEX user_signature_stream_idx ON user_signature_stream(stream_id); - -- device list needs to know which ones are "real" devices, and which ones are -- just used to avoid collisions ALTER TABLE devices ADD COLUMN hidden BOOLEAN DEFAULT FALSE; diff --git a/synapse/storage/schema/delta/56/signing_keys.sql b/synapse/storage/schema/delta/56/signing_keys.sql new file mode 100644 index 0000000000..6a9ef1782e --- /dev/null +++ b/synapse/storage/schema/delta/56/signing_keys.sql @@ -0,0 +1,55 @@ +/* Copyright 2019 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- cross-signing keys +CREATE TABLE IF NOT EXISTS e2e_cross_signing_keys ( + user_id TEXT NOT NULL, + -- the type of cross-signing key (master, user_signing, or self_signing) + keytype TEXT NOT NULL, + -- the full key information, as a json-encoded dict + keydata TEXT NOT NULL, + -- time that the key was added + added_ts BIGINT NOT NULL +); + +CREATE UNIQUE INDEX e2e_cross_signing_keys_idx ON e2e_cross_signing_keys(user_id, keytype, added_ts); + +-- cross-signing signatures +CREATE TABLE IF NOT EXISTS e2e_cross_signing_signatures ( + -- user who did the signing + user_id TEXT NOT NULL, + -- key used to sign + key_id TEXT NOT NULL, + -- user who was signed + target_user_id TEXT NOT NULL, + -- device/key that was signed + target_device_id TEXT NOT NULL, + -- the actual signature + signature TEXT NOT NULL +); + +CREATE UNIQUE INDEX e2e_cross_signing_signatures_idx ON e2e_cross_signing_signatures(user_id, target_user_id, target_device_id); + +-- stream of user signature updates +CREATE TABLE IF NOT EXISTS user_signature_stream ( + -- uses the same stream ID as device list stream + stream_id BIGINT NOT NULL, + -- user who did the signing + from_user_id TEXT NOT NULL, + -- list of users who were signed, as a JSON array + user_ids TEXT NOT NULL +); + +CREATE UNIQUE INDEX user_signature_stream_idx ON user_signature_stream(stream_id); diff --git a/synapse/types.py b/synapse/types.py index 7a80471a0c..00bb0743ff 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -492,10 +492,10 @@ def get_verify_key_from_cross_signing_key(key_info): """ # make sure that exactly one key is provided if "keys" not in key_info: - raise SynapseError(400, "Invalid key") + raise ValueError("Invalid key") keys = key_info["keys"] if len(keys) != 1: - raise SynapseError(400, "Invalid key") + raise ValueError("Invalid key") # and return that one key for key_id, key_data in keys.items(): return (key_id, decode_verify_key_bytes(key_id, decode_base64(key_data))) From d28d1e2d1b056a0c9e2b9f2c92013515a56dd9fb Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Thu, 1 Aug 2019 21:52:35 -0400 Subject: [PATCH 10/55] add changelog --- changelog.d/5769.feature | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/5769.feature diff --git a/changelog.d/5769.feature b/changelog.d/5769.feature new file mode 100644 index 0000000000..c34257cb8f --- /dev/null +++ b/changelog.d/5769.feature @@ -0,0 +1 @@ +allow uploading of cross-signing keys \ No newline at end of file From 8c9adcc95dee892f90d6acbbe5c54acbf621720b Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Thu, 1 Aug 2019 22:09:05 -0400 Subject: [PATCH 11/55] fix formatting --- changelog.d/5769.feature | 2 +- tests/handlers/test_e2e_keys.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/changelog.d/5769.feature b/changelog.d/5769.feature index c34257cb8f..bf994ca327 100644 --- a/changelog.d/5769.feature +++ b/changelog.d/5769.feature @@ -1 +1 @@ -allow uploading of cross-signing keys \ No newline at end of file +Allow uploading of cross-signing keys. \ No newline at end of file diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 9ae4cb6ea2..a62c52eefa 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -176,7 +176,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase): } yield self.handler.upload_signing_keys_for_user(local_user, keys2) - devices = yield self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user) + devices = yield self.handler.query_devices( + {"device_keys": {local_user: []}}, 0, local_user + ) self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]}) @defer.inlineCallbacks From f63ba7a7955d077224d4d602cd33bb31fad92fbc Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Mon, 12 Aug 2019 15:14:37 -0700 Subject: [PATCH 12/55] Cross-signing [1/4] -- hidden devices (#5759) * allow devices to be marked as "hidden" This is a prerequisite for cross-signing, as it allows us to create other things that live within the device namespace, so they can be used for signatures. --- changelog.d/5759.misc | 1 + synapse/storage/devices.py | 38 ++++++++++++++----- synapse/storage/end_to_end_keys.py | 2 +- .../schema/delta/56/hidden_devices.sql | 18 +++++++++ 4 files changed, 49 insertions(+), 10 deletions(-) create mode 100644 changelog.d/5759.misc create mode 100644 synapse/storage/schema/delta/56/hidden_devices.sql diff --git a/changelog.d/5759.misc b/changelog.d/5759.misc new file mode 100644 index 0000000000..c0bc566c4c --- /dev/null +++ b/changelog.d/5759.misc @@ -0,0 +1 @@ +Allow devices to be marked as hidden, for use by features such as cross-signing. \ No newline at end of file diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index 8f72d92895..991e28ea24 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -1,5 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2016 OpenMarket Ltd +# Copyright 2019 New Vector Ltd +# Copyright 2019 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,7 +22,7 @@ from canonicaljson import json from twisted.internet import defer -from synapse.api.errors import StoreError +from synapse.api.errors import Codes, StoreError from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage._base import Cache, SQLBaseStore, db_to_json from synapse.storage.background_updates import BackgroundUpdateStore @@ -36,7 +38,8 @@ DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = ( class DeviceWorkerStore(SQLBaseStore): def get_device(self, user_id, device_id): - """Retrieve a device. + """Retrieve a device. Only returns devices that are not marked as + hidden. Args: user_id (str): The ID of the user which owns the device @@ -48,14 +51,15 @@ class DeviceWorkerStore(SQLBaseStore): """ return self._simple_select_one( table="devices", - keyvalues={"user_id": user_id, "device_id": device_id}, + keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, retcols=("user_id", "device_id", "display_name"), desc="get_device", ) @defer.inlineCallbacks def get_devices_by_user(self, user_id): - """Retrieve all of a user's registered devices. + """Retrieve all of a user's registered devices. Only returns devices + that are not marked as hidden. Args: user_id (str): @@ -66,7 +70,7 @@ class DeviceWorkerStore(SQLBaseStore): """ devices = yield self._simple_select_list( table="devices", - keyvalues={"user_id": user_id}, + keyvalues={"user_id": user_id, "hidden": False}, retcols=("user_id", "device_id", "display_name"), desc="get_devices_by_user", ) @@ -540,6 +544,8 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): Returns: defer.Deferred: boolean whether the device was inserted or an existing device existed with that ID. + Raises: + StoreError: if the device is already in use """ key = (user_id, device_id) if self.device_id_exists_cache.get(key, None): @@ -552,12 +558,25 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): "user_id": user_id, "device_id": device_id, "display_name": initial_device_display_name, + "hidden": False, }, desc="store_device", or_ignore=True, ) + if not inserted: + # if the device already exists, check if it's a real device, or + # if the device ID is reserved by something else + hidden = yield self._simple_select_one_onecol( + "devices", + keyvalues={"user_id": user_id, "device_id": device_id}, + retcol="hidden", + ) + if hidden: + raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN) self.device_id_exists_cache.prefill(key, True) return inserted + except StoreError: + raise except Exception as e: logger.error( "store_device with device_id=%s(%r) user_id=%s(%r)" @@ -584,7 +603,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): """ yield self._simple_delete_one( table="devices", - keyvalues={"user_id": user_id, "device_id": device_id}, + keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, desc="delete_device", ) @@ -604,14 +623,15 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): table="devices", column="device_id", iterable=device_ids, - keyvalues={"user_id": user_id}, + keyvalues={"user_id": user_id, "hidden": False}, desc="delete_devices", ) for device_id in device_ids: self.device_id_exists_cache.invalidate((user_id, device_id)) def update_device(self, user_id, device_id, new_display_name=None): - """Update a device. + """Update a device. Only updates the device if it is not marked as + hidden. Args: user_id (str): The ID of the user which owns the device @@ -630,7 +650,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore): return defer.succeed(None) return self._simple_update_one( table="devices", - keyvalues={"user_id": user_id, "device_id": device_id}, + keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False}, updatevalues=updates, desc="update_device", ) diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index 1e07474e70..6f524cedd9 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -85,7 +85,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): " k.key_json" " FROM devices d" " %s JOIN e2e_device_keys_json k USING (user_id, device_id)" - " WHERE %s" + " WHERE %s AND NOT d.hidden" ) % ( "LEFT" if include_all_devices else "INNER", " OR ".join("(" + q + ")" for q in query_clauses), diff --git a/synapse/storage/schema/delta/56/hidden_devices.sql b/synapse/storage/schema/delta/56/hidden_devices.sql new file mode 100644 index 0000000000..67f8b20297 --- /dev/null +++ b/synapse/storage/schema/delta/56/hidden_devices.sql @@ -0,0 +1,18 @@ +/* Copyright 2019 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- device list needs to know which ones are "real" devices, and which ones are +-- just used to avoid collisions +ALTER TABLE devices ADD COLUMN hidden BOOLEAN DEFAULT FALSE; From 7c3abc65728af052b0d484f9669b1c084cd2faf5 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Wed, 21 Aug 2019 13:19:35 -0700 Subject: [PATCH 13/55] apply PR review suggestions --- synapse/handlers/e2e_keys.py | 76 +++++++++++++--------------- synapse/rest/client/v2_alpha/keys.py | 2 +- synapse/storage/devices.py | 6 +-- synapse/storage/end_to_end_keys.py | 15 +++--- 4 files changed, 46 insertions(+), 53 deletions(-) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 9081c3f64c..53ca8330ad 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -17,6 +17,8 @@ import logging +import time + from six import iteritems from canonicaljson import encode_canonical_json, json @@ -132,7 +134,7 @@ class E2eKeysHandler(object): r[user_id] = remote_queries[user_id] # Get cached cross-signing keys - cross_signing_keys = yield self.query_cross_signing_keys( + cross_signing_keys = yield self.get_cross_signing_keys_from_cache( device_keys_query, from_user_id ) @@ -200,11 +202,11 @@ class E2eKeysHandler(object): for user_id, key in remote_result["master_keys"].items(): if user_id in destination_query: - cross_signing_keys["master"][user_id] = key + cross_signing_keys["master_keys"][user_id] = key for user_id, key in remote_result["self_signing_keys"].items(): if user_id in destination_query: - cross_signing_keys["self_signing"][user_id] = key + cross_signing_keys["self_signing_keys"][user_id] = key except Exception as e: failure = _exception_to_failure(e) @@ -222,14 +224,13 @@ class E2eKeysHandler(object): ret = {"device_keys": results, "failures": failures} - for key, value in iteritems(cross_signing_keys): - ret[key + "_keys"] = value + ret.update(cross_signing_keys) return ret @defer.inlineCallbacks - def query_cross_signing_keys(self, query, from_user_id): - """Get cross-signing keys for users + def get_cross_signing_keys_from_cache(self, query, from_user_id): + """Get cross-signing keys for users from the database Args: query (Iterable[string]) an iterable of user IDs. A dict whose keys @@ -250,43 +251,32 @@ class E2eKeysHandler(object): for user_id in query: # XXX: consider changing the store functions to allow querying # multiple users simultaneously. - try: - key = yield self.store.get_e2e_cross_signing_key( - user_id, "master", from_user_id - ) - if key: - master_keys[user_id] = key - except Exception as e: - logger.info("Error getting master key: %s", e) + key = yield self.store.get_e2e_cross_signing_key( + user_id, "master", from_user_id + ) + if key: + master_keys[user_id] = key - try: - key = yield self.store.get_e2e_cross_signing_key( - user_id, "self_signing", from_user_id - ) - if key: - self_signing_keys[user_id] = key - except Exception as e: - logger.info("Error getting self-signing key: %s", e) + key = yield self.store.get_e2e_cross_signing_key( + user_id, "self_signing", from_user_id + ) + if key: + self_signing_keys[user_id] = key # users can see other users' master and self-signing keys, but can # only see their own user-signing keys if from_user_id == user_id: - try: - key = yield self.store.get_e2e_cross_signing_key( - user_id, "user_signing", from_user_id - ) - if key: - user_signing_keys[user_id] = key - except Exception as e: - logger.info("Error getting user-signing key: %s", e) + key = yield self.store.get_e2e_cross_signing_key( + user_id, "user_signing", from_user_id + ) + if key: + user_signing_keys[user_id] = key - defer.returnValue( - { - "master": master_keys, - "self_signing": self_signing_keys, - "user_signing": user_signing_keys, - } - ) + return { + "master_keys": master_keys, + "self_signing_keys": self_signing_keys, + "user_signing_keys": user_signing_keys, + } @defer.inlineCallbacks def query_local_devices(self, query): @@ -542,11 +532,13 @@ class E2eKeysHandler(object): # if everything checks out, then store the keys and send notifications deviceids = [] if "master_key" in keys: - yield self.store.set_e2e_cross_signing_key(user_id, "master", master_key) + yield self.store.set_e2e_cross_signing_key( + user_id, "master", master_key, time.time() * 1000 + ) deviceids.append(master_verify_key.version) if "self_signing_key" in keys: yield self.store.set_e2e_cross_signing_key( - user_id, "self_signing", self_signing_key + user_id, "self_signing", self_signing_key, time.time() * 1000 ) try: deviceids.append( @@ -556,7 +548,7 @@ class E2eKeysHandler(object): raise SynapseError(400, "Invalid self-signing key", Codes.INVALID_PARAM) if "user_signing_key" in keys: yield self.store.set_e2e_cross_signing_key( - user_id, "user_signing", user_signing_key + user_id, "user_signing", user_signing_key, time.time() * 1000 ) # the signature stream matches the semantics that we want for # user-signing key updates: only the user themselves is notified of @@ -568,7 +560,7 @@ class E2eKeysHandler(object): if len(deviceids): yield self.device_handler.notify_device_update(user_id, deviceids) - defer.returnValue({}) + return {} def _check_cross_signing_key(key, user_id, key_type, signing_key=None): diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index f40a785598..1340d2c80d 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -263,7 +263,7 @@ class SigningKeyUploadServlet(RestServlet): ) result = yield self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body) - defer.returnValue((200, result)) + return (200, result) def register_servlets(hs, http_server): diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index da23a350a1..6a5572e001 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -317,7 +317,7 @@ class DeviceWorkerStore(SQLBaseStore): user_ids, stream_id, ) - defer.returnValue(stream_id) + return stream_id def _add_user_signature_change_txn(self, txn, from_user_id, user_ids, stream_id): txn.call_after( @@ -491,9 +491,9 @@ class DeviceWorkerStore(SQLBaseStore): rows = yield self._execute( "get_users_whose_signatures_changed", None, sql, user_id, from_key ) - defer.returnValue(set(user for row in rows for user in json.loads(row[0]))) + return set(user for row in rows for user in json.loads(row[0])) else: - defer.returnValue(set()) + return set() def get_all_device_list_changes_for_remotes(self, from_key, to_key): """Return a list of `(stream_id, user_id, destination)` which is the diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index de2d1bbb9f..b218b7b2e8 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -14,8 +14,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import time - from six import iteritems from canonicaljson import encode_canonical_json, json @@ -284,7 +282,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn ) - def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key): + def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key, added_ts): """Set a user's cross-signing key. Args: @@ -294,6 +292,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): for a master key, 'self_signing' for a self-signing key, or 'user_signing' for a user-signing key key (dict): the key data + added_ts (int): the timestamp for when the key was added """ # the cross-signing keys need to occupy the same namespace as devices, # since signatures are identified by device ID. So add an entry to the @@ -334,18 +333,19 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): "user_id": user_id, "keytype": key_type, "keydata": json.dumps(key), - "added_ts": time.time() * 1000, + "added_ts": added_ts, }, desc="store_master_key", ) - def set_e2e_cross_signing_key(self, user_id, key_type, key): + def set_e2e_cross_signing_key(self, user_id, key_type, key, added_ts): """Set a user's cross-signing key. Args: user_id (str): the user to set the user-signing key for key_type (str): the type of cross-signing key to set key (dict): the key data + added_ts (int): the timestamp for when the key was added """ return self.runInteraction( "add_e2e_cross_signing_key", @@ -353,6 +353,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): user_id, key_type, key, + added_ts, ) def _get_e2e_cross_signing_key_txn(self, txn, user_id, key_type, from_user_id=None): @@ -368,7 +369,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): the key will be included in the result Returns: - dict of the key data + dict of the key data or None if not found """ sql = ( "SELECT keydata " @@ -412,7 +413,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): the self-signing key will be included in the result Returns: - dict of the key data + dict of the key data or None if not found """ return self.runInteraction( "get_e2e_cross_signing_key", From 814f253f1b102475fe0baace8b65e2281e7b6a89 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Wed, 21 Aug 2019 13:22:15 -0700 Subject: [PATCH 14/55] make isort happy --- synapse/handlers/e2e_keys.py | 1 - 1 file changed, 1 deletion(-) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 53ca8330ad..be15597ee8 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -16,7 +16,6 @@ # limitations under the License. import logging - import time from six import iteritems From 3b0b22cb059f7dfd1d7a7878fe391be38ee91d71 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Wed, 28 Aug 2019 17:17:21 -0700 Subject: [PATCH 15/55] use stream ID generator instead of timestamp --- synapse/handlers/e2e_keys.py | 7 ++--- synapse/storage/__init__.py | 3 ++ synapse/storage/end_to_end_keys.py | 30 +++++++++---------- .../storage/schema/delta/56/signing_keys.sql | 6 ++-- 4 files changed, 23 insertions(+), 23 deletions(-) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index be15597ee8..d2d9bef1fe 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -16,7 +16,6 @@ # limitations under the License. import logging -import time from six import iteritems @@ -532,12 +531,12 @@ class E2eKeysHandler(object): deviceids = [] if "master_key" in keys: yield self.store.set_e2e_cross_signing_key( - user_id, "master", master_key, time.time() * 1000 + user_id, "master", master_key ) deviceids.append(master_verify_key.version) if "self_signing_key" in keys: yield self.store.set_e2e_cross_signing_key( - user_id, "self_signing", self_signing_key, time.time() * 1000 + user_id, "self_signing", self_signing_key ) try: deviceids.append( @@ -547,7 +546,7 @@ class E2eKeysHandler(object): raise SynapseError(400, "Invalid self-signing key", Codes.INVALID_PARAM) if "user_signing_key" in keys: yield self.store.set_e2e_cross_signing_key( - user_id, "user_signing", user_signing_key, time.time() * 1000 + user_id, "user_signing", user_signing_key ) # the signature stream matches the semantics that we want for # user-signing key updates: only the user themselves is notified of diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 0a64f90624..e9a9c2cd8d 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -136,6 +136,9 @@ class DataStore( self._device_list_id_gen = StreamIdGenerator( db_conn, "device_lists_stream", "stream_id" ) + self._cross_signing_id_gen = StreamIdGenerator( + db_conn, "e2e_cross_signing_keys", "stream_id" + ) self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index b218b7b2e8..4b37bffb0b 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -282,7 +282,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn ) - def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key, added_ts): + def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key): """Set a user's cross-signing key. Args: @@ -292,7 +292,6 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): for a master key, 'self_signing' for a self-signing key, or 'user_signing' for a user-signing key key (dict): the key data - added_ts (int): the timestamp for when the key was added """ # the cross-signing keys need to occupy the same namespace as devices, # since signatures are identified by device ID. So add an entry to the @@ -327,25 +326,25 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): ) # and finally, store the key itself - self._simple_insert( - "e2e_cross_signing_keys", - values={ - "user_id": user_id, - "keytype": key_type, - "keydata": json.dumps(key), - "added_ts": added_ts, - }, - desc="store_master_key", - ) + with self._cross_signing_id_gen.get_next() as stream_id: + self._simple_insert( + "e2e_cross_signing_keys", + values={ + "user_id": user_id, + "keytype": key_type, + "keydata": json.dumps(key), + "stream_id": stream_id, + }, + desc="store_master_key", + ) - def set_e2e_cross_signing_key(self, user_id, key_type, key, added_ts): + def set_e2e_cross_signing_key(self, user_id, key_type, key): """Set a user's cross-signing key. Args: user_id (str): the user to set the user-signing key for key_type (str): the type of cross-signing key to set key (dict): the key data - added_ts (int): the timestamp for when the key was added """ return self.runInteraction( "add_e2e_cross_signing_key", @@ -353,7 +352,6 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): user_id, key_type, key, - added_ts, ) def _get_e2e_cross_signing_key_txn(self, txn, user_id, key_type, from_user_id=None): @@ -374,7 +372,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): sql = ( "SELECT keydata " " FROM e2e_cross_signing_keys " - " WHERE user_id = ? AND keytype = ? ORDER BY added_ts DESC LIMIT 1" + " WHERE user_id = ? AND keytype = ? ORDER BY stream_id DESC LIMIT 1" ) txn.execute(sql, (user_id, key_type)) row = txn.fetchone() diff --git a/synapse/storage/schema/delta/56/signing_keys.sql b/synapse/storage/schema/delta/56/signing_keys.sql index 6a9ef1782e..27a96123e3 100644 --- a/synapse/storage/schema/delta/56/signing_keys.sql +++ b/synapse/storage/schema/delta/56/signing_keys.sql @@ -20,11 +20,11 @@ CREATE TABLE IF NOT EXISTS e2e_cross_signing_keys ( keytype TEXT NOT NULL, -- the full key information, as a json-encoded dict keydata TEXT NOT NULL, - -- time that the key was added - added_ts BIGINT NOT NULL + -- for keeping the keys in order, so that we can fetch the latest one + stream_id BIGINT NOT NULL ); -CREATE UNIQUE INDEX e2e_cross_signing_keys_idx ON e2e_cross_signing_keys(user_id, keytype, added_ts); +CREATE UNIQUE INDEX e2e_cross_signing_keys_idx ON e2e_cross_signing_keys(user_id, keytype, stream_id); -- cross-signing signatures CREATE TABLE IF NOT EXISTS e2e_cross_signing_signatures ( From 96bda563701795537c39d56d82869d953a6bf167 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Wed, 28 Aug 2019 17:18:40 -0700 Subject: [PATCH 16/55] black --- synapse/handlers/e2e_keys.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index d2d9bef1fe..870810e6ea 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -530,9 +530,7 @@ class E2eKeysHandler(object): # if everything checks out, then store the keys and send notifications deviceids = [] if "master_key" in keys: - yield self.store.set_e2e_cross_signing_key( - user_id, "master", master_key - ) + yield self.store.set_e2e_cross_signing_key(user_id, "master", master_key) deviceids.append(master_verify_key.version) if "self_signing_key" in keys: yield self.store.set_e2e_cross_signing_key( From a22d58c96c714e5f97b3e68f3ec7f2aeee854a81 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Wed, 4 Sep 2019 19:32:35 -0400 Subject: [PATCH 17/55] add user signature stream change cache to slaved device store --- synapse/replication/slave/storage/devices.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index d9300fce33..f045e1b937 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -33,6 +33,9 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto self._device_list_stream_cache = StreamChangeCache( "DeviceListStreamChangeCache", device_list_max ) + self._user_signature_stream_cache = StreamChangeCache( + "UserSignatureStreamChangeCache", device_list_max + ) self._device_list_federation_stream_cache = StreamChangeCache( "DeviceListFederationStreamChangeCache", device_list_max ) From 4bb454478470c6b707d33292113ac3a23010db8b Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Wed, 22 May 2019 16:41:24 -0400 Subject: [PATCH 18/55] implement device signature uploading/fetching --- synapse/handlers/e2e_keys.py | 250 +++++++++++++++++++++++++++ synapse/rest/client/v2_alpha/keys.py | 50 ++++++ synapse/storage/end_to_end_keys.py | 38 ++++ 3 files changed, 338 insertions(+) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 997ad66f8f..9747b517ff 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -608,6 +608,194 @@ class E2eKeysHandler(object): return {} + @defer.inlineCallbacks + def upload_signatures_for_device_keys(self, user_id, signatures): + """Upload device signatures for cross-signing + + Args: + user_id (string): the user uploading the signatures + signatures (dict[string, dict[string, dict]]): map of users to + devices to signed keys + """ + failures = {} + + signature_list = [] # signatures to be stored + self_device_ids = [] # what devices have been updated, for notifying + + # split between checking signatures for own user and signatures for + # other users, since we verify them with different keys + if user_id in signatures: + self_signatures = signatures[user_id] + del signatures[user_id] + self_device_ids = list(self_signatures.keys()) + try: + # get our self-signing key to verify the signatures + self_signing_key = yield self.store.get_e2e_cross_signing_key( + user_id, "self_signing" + ) + if self_signing_key is None: + raise SynapseError( + 404, + "No self-signing key found", + Codes.NOT_FOUND + ) + + self_signing_key_id, self_signing_verify_key \ + = get_verify_key_from_cross_signing_key(self_signing_key) + + # fetch our stored devices so that we can compare with what was sent + user_devices = [] + for device in self_signatures.keys(): + user_devices.append((user_id, device)) + devices = yield self.store.get_e2e_device_keys(user_devices) + + if user_id not in devices: + raise SynapseError( + 404, + "No device key found", + Codes.NOT_FOUND + ) + + devices = devices[user_id] + for device_id, device in self_signatures.items(): + try: + if ("signatures" not in device or + user_id not in device["signatures"] or + self_signing_key_id not in device["signatures"][user_id]): + # no signature was sent + raise SynapseError( + 400, + "Invalid signature", + Codes.INVALID_SIGNATURE + ) + + stored_device = devices[device_id]["keys"] + if self_signing_key_id in stored_device.get("signatures", {}) \ + .get(user_id, {}): + # we already have a signature on this device, so we + # can skip it, since it should be exactly the same + continue + + _check_device_signature( + user_id, self_signing_verify_key, device, stored_device + ) + + signature = device["signatures"][user_id][self_signing_key_id] + signature_list.append( + (self_signing_key_id, user_id, device_id, signature) + ) + except SynapseError as e: + failures.setdefault(user_id, {})[device_id] \ + = _exception_to_failure(e) + except SynapseError as e: + failures[user_id] = { + device: _exception_to_failure(e) + for device in self_signatures.keys() + } + + signed_users = [] # what user have been signed, for notifying + if len(signatures): + # if signatures isn't empty, then we have signatures for other + # users. These signatures will be signed by the user signing key + + # get our user-signing key to verify the signatures + user_signing_key = yield self.store.get_e2e_cross_signing_key( + user_id, "user_signing" + ) + if user_signing_key is None: + for user, devicemap in signatures.items(): + failures[user] = { + device: _exception_to_failure(SynapseError( + 404, + "No user-signing key found", + Codes.NOT_FOUND + )) + for device in devicemap.keys() + } + else: + user_signing_key_id, user_signing_verify_key \ + = get_verify_key_from_cross_signing_key(user_signing_key) + + for user, devicemap in signatures.items(): + device_id = None + try: + # get the user's master key, to make sure it matches + # what was sent + stored_key = yield self.store.get_e2e_cross_signing_key( + user, "master", user_id + ) + if stored_key is None: + logger.error( + "upload signature: no user key found for %s", user + ) + raise SynapseError( + 404, + "User's master key not found", + Codes.NOT_FOUND + ) + + # make sure that the user's master key is the one that + # was signed (and no others) + device_id = get_verify_key_from_cross_signing_key(stored_key)[0] \ + .split(":", 1)[1] + if device_id not in devicemap: + logger.error( + "upload signature: wrong device: %s vs %s", + device, devicemap + ) + raise SynapseError( + 404, + "Unknown device", + Codes.NOT_FOUND + ) + if len(devicemap) > 1: + logger.error("upload signature: too many devices specified") + failures[user] = { + device: _exception_to_failure(SynapseError( + 404, + "Unknown device", + Codes.NOT_FOUND + )) + for device in devicemap.keys() + } + + key = devicemap[device_id] + + if user_signing_key_id in stored_key.get("signatures", {}) \ + .get(user_id, {}): + # we already have the signature, so we can skip it + continue + + _check_device_signature( + user_id, user_signing_verify_key, key, stored_key + ) + + signature = key["signatures"][user_id][user_signing_key_id] + + signed_users.append(user) + signature_list.append( + (user_signing_key_id, user, device_id, signature) + ) + except SynapseError as e: + if device_id is None: + failures[user] = { + device_id: _exception_to_failure(e) + for device_id in devicemap.keys() + } + else: + failures.setdefault(user, {})[device_id] \ + = _exception_to_failure(e) + + # store the signature, and send the appropriate notifications for sync + logger.debug("upload signature failures: %r", failures) + yield self.store.store_e2e_device_signatures(user_id, signature_list) + + if len(self_device_ids): + yield self.device_handler.notify_device_update(user_id, self_device_ids) + if len(signed_users): + yield self.device_handler.notify_user_signature_update(user_id, signed_users) + + defer.returnValue({"failures": failures}) def _check_cross_signing_key(key, user_id, key_type, signing_key=None): """Check a cross-signing key uploaded by a user. Performs some basic sanity @@ -636,6 +824,68 @@ def _check_cross_signing_key(key, user_id, key_type, signing_key=None): ) +def _check_device_signature(user_id, verify_key, signed_device, stored_device): + """Check that a device signature is correct and matches the copy of the device + that we have. Throws an exception if an error is detected. + + Args: + user_id (str): the user ID whose signature is being checked + verify_key (VerifyKey): the key to verify the device with + signed_device (dict): the signed device data + stored_device (dict): our previous copy of the device + """ + + key_id = "%s:%s" % (verify_key.alg, verify_key.version) + + # make sure the device is signed + if ("signatures" not in signed_device or user_id not in signed_device["signatures"] + or key_id not in signed_device["signatures"][user_id]): + logger.error("upload signature: user not found in signatures") + raise SynapseError( + 400, + "Invalid signature", + Codes.INVALID_SIGNATURE + ) + + signature = signed_device["signatures"][user_id][key_id] + + # make sure that the device submitted matches what we have stored + del signed_device["signatures"] + if "unsigned" in signed_device: + del signed_device["unsigned"] + if "signatures" in stored_device: + del stored_device["signatures"] + if "unsigned" in stored_device: + del stored_device["unsigned"] + if signed_device != stored_device: + logger.error( + "upload signatures: key does not match %s vs %s", + signed_device, stored_device + ) + raise SynapseError( + 400, + "Key does not match", + "M_MISMATCHED_KEY" + ) + + # check the signature + signed_device["signatures"] = { + user_id: { + key_id: signature + } + } + + try: + verify_signed_json(signed_device, user_id, verify_key) + except SignatureVerifyException: + logger.error("invalid signature on key") + raise SynapseError( + 400, + "Invalid signature", + Codes.INVALID_SIGNATURE + ) + + def _exception_to_failure(e): if isinstance(e, CodeMessageException): return {"status": e.code, "message": str(e)} diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 151a70d449..5c288d48b7 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -277,9 +277,59 @@ class SigningKeyUploadServlet(RestServlet): return (200, result) +class SignaturesUploadServlet(RestServlet): + """ + POST /keys/signatures/upload HTTP/1.1 + Content-Type: application/json + + { + "@alice:example.com": { + "": { + "user_id": "", + "device_id": "", + "algorithms": [ + "m.olm.curve25519-aes-sha256", + "m.megolm.v1.aes-sha" + ], + "keys": { + ":": "", + }, + "signatures": { + "": { + ":": ">" + } + } + } + } + } + """ + PATTERNS = client_v2_patterns("/keys/signatures/upload$") + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super(SignaturesUploadServlet, self).__init__() + self.auth = hs.get_auth() + self.e2e_keys_handler = hs.get_e2e_keys_handler() + + @defer.inlineCallbacks + def on_POST(self, request): + requester = yield self.auth.get_user_by_req(request, allow_guest=True) + user_id = requester.user.to_string() + body = parse_json_object_from_request(request) + + result = yield self.e2e_keys_handler.upload_signatures_for_device_keys( + user_id, body + ) + defer.returnValue((200, result)) + + def register_servlets(hs, http_server): KeyUploadServlet(hs).register(http_server) KeyQueryServlet(hs).register(http_server) KeyChangesServlet(hs).register(http_server) OneTimeKeyServlet(hs).register(http_server) SigningKeyUploadServlet(hs).register(http_server) + SignaturesUploadServlet(hs).register(http_server) diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index 8ce5dd8bf9..fe786f3093 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -59,6 +59,12 @@ class EndToEndKeyWorkerStore(SQLBaseStore): for user_id, device_keys in iteritems(results): for device_id, device_info in iteritems(device_keys): device_info["keys"] = db_to_json(device_info.pop("key_json")) + # add cross-signing signatures to the keys + if "signatures" in device_info: + for sig_user_id, sigs in device_info["signatures"].items(): + device_info["keys"].setdefault("signatures", {}) \ + .setdefault(sig_user_id, {}) \ + .update(sigs) return results @@ -71,6 +77,8 @@ class EndToEndKeyWorkerStore(SQLBaseStore): query_clauses = [] query_params = [] + signature_query_clauses = [] + signature_query_params = [] if include_all_devices is False: include_deleted_devices = False @@ -81,12 +89,20 @@ class EndToEndKeyWorkerStore(SQLBaseStore): for (user_id, device_id) in query_list: query_clause = "user_id = ?" query_params.append(user_id) + signature_query_clause = "target_user_id = ?" + signature_query_params.append(user_id) if device_id is not None: query_clause += " AND device_id = ?" query_params.append(device_id) + signature_query_clause += " AND target_device_id = ?" + signature_query_params.append(device_id) + + signature_query_clause += " AND user_id = ?" + signature_query_params.append(user_id) query_clauses.append(query_clause) + signature_query_clauses.append(signature_query_clause) sql = ( "SELECT user_id, device_id, " @@ -113,6 +129,28 @@ class EndToEndKeyWorkerStore(SQLBaseStore): for user_id, device_id in deleted_devices: result.setdefault(user_id, {})[device_id] = None + # get signatures on the device + signature_sql = ( + "SELECT * " + " FROM e2e_device_signatures " + " WHERE %s" + ) % ( + " OR ".join("(" + q + ")" for q in signature_query_clauses) + ) + + txn.execute(signature_sql, signature_query_params) + rows = self.cursor_to_dict(txn) + + for row in rows: + target_user_id = row["target_user_id"] + target_device_id = row["target_device_id"] + if target_user_id in result \ + and target_device_id in result[target_user_id]: + result[target_user_id][target_device_id] \ + .setdefault("signatures", {}) \ + .setdefault(row["user_id"], {})[row["key_id"]] \ + = row["signature"] + log_kv(result) return result From ac4746ac4bb4d9371c5a25e94ecccd83effb8b9a Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Wed, 17 Jul 2019 22:11:31 -0400 Subject: [PATCH 19/55] allow uploading signatures of master key signed by devices --- synapse/handlers/e2e_keys.py | 232 +++++++++++++++++---------- synapse/rest/client/v2_alpha/keys.py | 2 +- synapse/storage/end_to_end_keys.py | 2 +- tests/handlers/test_e2e_keys.py | 227 +++++++++++++++++++++++++- 4 files changed, 378 insertions(+), 85 deletions(-) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 9747b517ff..1148803c1e 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -20,7 +20,9 @@ import logging from six import iteritems from canonicaljson import encode_canonical_json, json +from signedjson.key import decode_verify_key_bytes from signedjson.sign import SignatureVerifyException, verify_signed_json +from unpaddedbase64 import decode_base64 from twisted.internet import defer @@ -619,8 +621,11 @@ class E2eKeysHandler(object): """ failures = {} - signature_list = [] # signatures to be stored - self_device_ids = [] # what devices have been updated, for notifying + # signatures to be stored. Each item will be a tuple of + # (signing_key_id, target_user_id, target_device_id, signature) + signature_list = [] + # what devices have been updated, for notifying + self_device_ids = [] # split between checking signatures for own user and signatures for # other users, since we verify them with different keys @@ -630,46 +635,107 @@ class E2eKeysHandler(object): self_device_ids = list(self_signatures.keys()) try: # get our self-signing key to verify the signatures - self_signing_key = yield self.store.get_e2e_cross_signing_key( - user_id, "self_signing" - ) - if self_signing_key is None: - raise SynapseError( - 404, - "No self-signing key found", - Codes.NOT_FOUND + self_signing_key, self_signing_key_id, self_signing_verify_key \ + = yield self._get_e2e_cross_signing_verify_key( + user_id, "self_signing" ) - self_signing_key_id, self_signing_verify_key \ - = get_verify_key_from_cross_signing_key(self_signing_key) + # get our master key, since it may be signed + master_key, master_key_id, master_verify_key \ + = yield self._get_e2e_cross_signing_verify_key( + user_id, "master" + ) - # fetch our stored devices so that we can compare with what was sent - user_devices = [] - for device in self_signatures.keys(): - user_devices.append((user_id, device)) - devices = yield self.store.get_e2e_device_keys(user_devices) + # fetch our stored devices. This is used to 1. verify + # signatures on the master key, and 2. to can compare with what + # was sent if the device was signed + devices = yield self.store.get_e2e_device_keys([(user_id, None)]) if user_id not in devices: raise SynapseError( - 404, - "No device key found", - Codes.NOT_FOUND + 404, "No device keys found", Codes.NOT_FOUND ) devices = devices[user_id] for device_id, device in self_signatures.items(): try: if ("signatures" not in device or - user_id not in device["signatures"] or - self_signing_key_id not in device["signatures"][user_id]): + user_id not in device["signatures"]): # no signature was sent raise SynapseError( - 400, - "Invalid signature", - Codes.INVALID_SIGNATURE + 400, "Invalid signature", Codes.INVALID_SIGNATURE ) - stored_device = devices[device_id]["keys"] + if device_id == master_verify_key.version: + # we have master key signed by devices: for each + # device that signed, check the signature. Since + # the "failures" property in the response only has + # granularity up to the signed device, either all + # of the signatures on the master key succeed, or + # all fail. So loop over the signatures and add + # them to a separate signature list. If everything + # works out, then add them all to the main + # signature list. (In practice, we're likely to + # only have only one signature anyways.) + master_key_signature_list = [] + for signing_key_id, signature in device["signatures"][user_id].items(): + alg, signing_device_id = signing_key_id.split(":", 1) + if (signing_device_id not in devices or + signing_key_id not in + devices[signing_device_id]["keys"]["keys"]): + # signed by an unknown device, or the + # device does not have the key + raise SynapseError( + 400, "Invalid signature", Codes.INVALID_SIGNATURE + ) + + sigs = device["signatures"] + del device["signatures"] + # use pop to avoid exception if key doesn't exist + device.pop("unsigned", None) + master_key.pop("signature", None) + master_key.pop("unsigned", None) + + if master_key != device: + raise SynapseError( + 400, "Key does not match" + ) + + # get the key and check the signature + pubkey = devices[signing_device_id]["keys"]["keys"][signing_key_id] + verify_key = decode_verify_key_bytes( + signing_key_id, decode_base64(pubkey) + ) + device["signatures"] = sigs + try: + verify_signed_json(device, user_id, verify_key) + except SignatureVerifyException: + raise SynapseError( + 400, "Invalid signature", Codes.INVALID_SIGNATURE + ) + + master_key_signature_list.append( + (signing_key_id, user_id, device_id, signature) + ) + + signature_list.extend(master_key_signature_list) + continue + + # at this point, we have a device that should be signed + # by the self-signing key + if self_signing_key_id not in device["signatures"][user_id]: + # no signature was sent + raise SynapseError( + 400, "Invalid signature", Codes.INVALID_SIGNATURE + ) + + stored_device = None + try: + stored_device = devices[device_id]["keys"] + except KeyError: + raise SynapseError( + 404, "Unknown device", Codes.NOT_FOUND + ) if self_signing_key_id in stored_device.get("signatures", {}) \ .get(user_id, {}): # we already have a signature on this device, so we @@ -698,69 +764,50 @@ class E2eKeysHandler(object): # if signatures isn't empty, then we have signatures for other # users. These signatures will be signed by the user signing key - # get our user-signing key to verify the signatures - user_signing_key = yield self.store.get_e2e_cross_signing_key( - user_id, "user_signing" - ) - if user_signing_key is None: - for user, devicemap in signatures.items(): - failures[user] = { - device: _exception_to_failure(SynapseError( - 404, - "No user-signing key found", - Codes.NOT_FOUND - )) - for device in devicemap.keys() - } - else: - user_signing_key_id, user_signing_verify_key \ - = get_verify_key_from_cross_signing_key(user_signing_key) + try: + # get our user-signing key to verify the signatures + user_signing_key, user_signing_key_id, user_signing_verify_key \ + = yield self._get_e2e_cross_signing_verify_key( + user_id, "user_signing" + ) for user, devicemap in signatures.items(): device_id = None try: # get the user's master key, to make sure it matches # what was sent - stored_key = yield self.store.get_e2e_cross_signing_key( - user, "master", user_id - ) - if stored_key is None: - logger.error( - "upload signature: no user key found for %s", user - ) - raise SynapseError( - 404, - "User's master key not found", - Codes.NOT_FOUND + stored_key, stored_key_id, _ \ + = yield self._get_e2e_cross_signing_verify_key( + user, "master", user_id ) # make sure that the user's master key is the one that # was signed (and no others) - device_id = get_verify_key_from_cross_signing_key(stored_key)[0] \ - .split(":", 1)[1] + device_id = stored_key_id.split(":", 1)[1] if device_id not in devicemap: + # set device to None so that the failure gets + # marked on all the signatures + device_id = None logger.error( "upload signature: wrong device: %s vs %s", device, devicemap ) raise SynapseError( - 404, - "Unknown device", - Codes.NOT_FOUND + 404, "Unknown device", Codes.NOT_FOUND ) - if len(devicemap) > 1: + key = devicemap[device_id] + del devicemap[device_id] + if len(devicemap) > 0: + # other devices were signed -- mark those as failures logger.error("upload signature: too many devices specified") + failure = _exception_to_failure(SynapseError( + 404, "Unknown device", Codes.NOT_FOUND + )) failures[user] = { - device: _exception_to_failure(SynapseError( - 404, - "Unknown device", - Codes.NOT_FOUND - )) + device: failure for device in devicemap.keys() } - key = devicemap[device_id] - if user_signing_key_id in stored_key.get("signatures", {}) \ .get(user_id, {}): # we already have the signature, so we can skip it @@ -770,25 +817,31 @@ class E2eKeysHandler(object): user_id, user_signing_verify_key, key, stored_key ) - signature = key["signatures"][user_id][user_signing_key_id] - signed_users.append(user) + signature = key["signatures"][user_id][user_signing_key_id] signature_list.append( (user_signing_key_id, user, device_id, signature) ) except SynapseError as e: + failure = _exception_to_failure(e) if device_id is None: failures[user] = { - device_id: _exception_to_failure(e) + device_id: failure for device_id in devicemap.keys() } else: - failures.setdefault(user, {})[device_id] \ - = _exception_to_failure(e) + failures.setdefault(user, {})[device_id] = failure + except SynapseError as e: + failure = _exception_to_failure(e) + for user, devicemap in signature.items(): + failures[user] = { + device_id: failure + for device_id in devicemap.keys() + } # store the signature, and send the appropriate notifications for sync logger.debug("upload signature failures: %r", failures) - yield self.store.store_e2e_device_signatures(user_id, signature_list) + yield self.store.store_e2e_cross_signing_signatures(user_id, signature_list) if len(self_device_ids): yield self.device_handler.notify_device_update(user_id, self_device_ids) @@ -797,6 +850,22 @@ class E2eKeysHandler(object): defer.returnValue({"failures": failures}) + @defer.inlineCallbacks + def _get_e2e_cross_signing_verify_key(self, user_id, key_type, from_user_id=None): + key = yield self.store.get_e2e_cross_signing_key( + user_id, key_type, from_user_id + ) + if key is None: + logger.error("no %s key found for %s", key_type, user_id) + raise SynapseError( + 404, + "No %s key found for %s" % (key_type, user_id), + Codes.NOT_FOUND + ) + key_id, verify_key = get_verify_key_from_cross_signing_key(key) + return key, key_id, verify_key + + def _check_cross_signing_key(key, user_id, key_type, signing_key=None): """Check a cross-signing key uploaded by a user. Performs some basic sanity checking, and ensures that it is signed, if a signature is required. @@ -851,21 +920,17 @@ def _check_device_signature(user_id, verify_key, signed_device, stored_device): # make sure that the device submitted matches what we have stored del signed_device["signatures"] - if "unsigned" in signed_device: - del signed_device["unsigned"] - if "signatures" in stored_device: - del stored_device["signatures"] - if "unsigned" in stored_device: - del stored_device["unsigned"] + # use pop to avoid exception if key doesn't exist + signed_device.pop("unsigned", None) + stored_device.pop("signatures", None) + stored_device.pop("unsigned", None) if signed_device != stored_device: logger.error( "upload signatures: key does not match %s vs %s", signed_device, stored_device ) raise SynapseError( - 400, - "Key does not match", - "M_MISMATCHED_KEY" + 400, "Key does not match", ) # check the signature @@ -887,6 +952,9 @@ def _check_device_signature(user_id, verify_key, signed_device, stored_device): def _exception_to_failure(e): + if isinstance(e, SynapseError): + return {"status": e.code, "errcode": e.errcode, "message": str(e)} + if isinstance(e, CodeMessageException): return {"status": e.code, "message": str(e)} diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 5c288d48b7..cb3c52cb8e 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -303,7 +303,7 @@ class SignaturesUploadServlet(RestServlet): } } """ - PATTERNS = client_v2_patterns("/keys/signatures/upload$") + PATTERNS = client_patterns("/keys/signatures/upload$") def __init__(self, hs): """ diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index fe786f3093..e68ce318af 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -132,7 +132,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore): # get signatures on the device signature_sql = ( "SELECT * " - " FROM e2e_device_signatures " + " FROM e2e_cross_signing_signatures " " WHERE %s" ) % ( " OR ".join("(" + q + ")" for q in signature_query_clauses) diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index a62c52eefa..b1d3a4cfae 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -17,9 +17,10 @@ import mock +import signedjson.key as key +import signedjson.sign as sign from twisted.internet import defer -import synapse.api.errors import synapse.handlers.e2e_keys import synapse.storage from synapse.api import errors @@ -210,3 +211,227 @@ class E2eKeysHandlerTestCase(unittest.TestCase): res = yield self.handler.query_local_devices({local_user: None}) self.assertDictEqual(res, {local_user: {}}) + + @defer.inlineCallbacks + def test_upload_signatures(self): + """should check signatures that are uploaded""" + # set up a user with cross-signing keys and a device. This user will + # try uploading signatures + local_user = "@boris:" + self.hs.hostname + device_id = "xyz" + # private key: OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA + device_pubkey = "NnHhnqiMFQkq969szYkooLaBAXW244ZOxgukCvm2ZeY" + device_key = { + "user_id": local_user, + "device_id": device_id, + "algorithms": ["m.olm.curve25519-aes-sha256", "m.megolm.v1.aes-sha"], + "keys": { + "curve25519:xyz": "curve25519+key", + "ed25519:xyz": device_pubkey + }, + "signatures": { + local_user: { + "ed25519:xyz": "something" + } + } + } + device_signing_key = key.decode_signing_key_base64( + "ed25519", + "xyz", + "OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA" + ) + + yield self.handler.upload_keys_for_user( + local_user, device_id, {"device_keys": device_key} + ) + + # private key: 2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0 + master_pubkey = "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk" + master_key = { + "user_id": local_user, + "usage": ["master"], + "keys": { + "ed25519:" + master_pubkey: master_pubkey + } + } + master_signing_key = key.decode_signing_key_base64( + "ed25519", master_pubkey, + "2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0" + ) + usersigning_pubkey = "Hq6gL+utB4ET+UvD5ci0kgAwsX6qP/zvf8v6OInU5iw" + usersigning_key = { + # private key: 4TL4AjRYwDVwD3pqQzcor+ez/euOB1/q78aTJ+czDNs + "user_id": local_user, + "usage": ["user_signing"], + "keys": { + "ed25519:" + usersigning_pubkey: usersigning_pubkey, + } + } + usersigning_signing_key = key.decode_signing_key_base64( + "ed25519", usersigning_pubkey, + "4TL4AjRYwDVwD3pqQzcor+ez/euOB1/q78aTJ+czDNs" + ) + sign.sign_json(usersigning_key, local_user, master_signing_key) + # private key: HvQBbU+hc2Zr+JP1sE0XwBe1pfZZEYtJNPJLZJtS+F8 + selfsigning_pubkey = "EmkqvokUn8p+vQAGZitOk4PWjp7Ukp3txV2TbMPEiBQ" + selfsigning_key = { + "user_id": local_user, + "usage": ["self_signing"], + "keys": { + "ed25519:" + selfsigning_pubkey: selfsigning_pubkey, + } + } + selfsigning_signing_key = key.decode_signing_key_base64( + "ed25519", selfsigning_pubkey, + "HvQBbU+hc2Zr+JP1sE0XwBe1pfZZEYtJNPJLZJtS+F8" + ) + sign.sign_json(selfsigning_key, local_user, master_signing_key) + cross_signing_keys = { + "master_key": master_key, + "user_signing_key": usersigning_key, + "self_signing_key": selfsigning_key, + } + yield self.handler.upload_signing_keys_for_user(local_user, cross_signing_keys) + + # set up another user with a master key. This user will be signed by + # the first user + other_user = "@otherboris:" + self.hs.hostname + other_master_pubkey = "fHZ3NPiKxoLQm5OoZbKa99SYxprOjNs4TwJUKP+twCM" + other_master_key = { + # private key: oyw2ZUx0O4GifbfFYM0nQvj9CL0b8B7cyN4FprtK8OI + "user_id": other_user, + "usage": ["master"], + "keys": { + "ed25519:" + other_master_pubkey: other_master_pubkey + } + } + yield self.handler.upload_signing_keys_for_user(other_user, { + "master_key": other_master_key + }) + + # test various signature failures (see below) + ret = yield self.handler.upload_signatures_for_device_keys( + local_user, + { + local_user: { + # fails because the signature is invalid + # should fail with INVALID_SIGNATURE + device_id: { + "user_id": local_user, + "device_id": device_id, + "algorithms": ["m.olm.curve25519-aes-sha256", "m.megolm.v1.aes-sha"], + "keys": { + "curve25519:xyz": "curve25519+key", + # private key: OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA + "ed25519:xyz": device_pubkey + }, + "signatures": { + local_user: { + "ed25519:" + selfsigning_pubkey: "something", + } + } + }, + # fails because device is unknown + # should fail with NOT_FOUND + "unknown": { + "user_id": local_user, + "device_id": "unknown", + "signatures": { + local_user: { + "ed25519:" + selfsigning_pubkey: "something", + } + } + }, + # fails because the signature is invalid + # should fail with INVALID_SIGNATURE + master_pubkey: { + "user_id": local_user, + "usage": ["master"], + "keys": { + "ed25519:" + master_pubkey: master_pubkey + }, + "signatures": { + local_user: { + "ed25519:" + device_pubkey: "something", + } + } + } + }, + other_user: { + # fails because the device is not the user's master-signing key + # should fail with NOT_FOUND + "unknown": { + "user_id": other_user, + "device_id": "unknown", + "signatures": { + local_user: { + "ed25519:" + usersigning_pubkey: "something", + } + } + }, + other_master_pubkey: { + # fails because the key doesn't match what the server has + # should fail with UNKNOWN + "user_id": other_user, + "usage": ["master"], + "keys": { + "ed25519:" + other_master_pubkey: other_master_pubkey + }, + "something": "random", + "signatures": { + local_user: { + "ed25519:" + usersigning_pubkey: "something", + } + } + } + } + } + ) + + user_failures = ret["failures"][local_user] + self.assertEqual(user_failures[device_id]["errcode"], errors.Codes.INVALID_SIGNATURE) + self.assertEqual(user_failures[master_pubkey]["errcode"], errors.Codes.INVALID_SIGNATURE) + self.assertEqual(user_failures["unknown"]["errcode"], errors.Codes.NOT_FOUND) + + other_user_failures = ret["failures"][other_user] + self.assertEqual(other_user_failures["unknown"]["errcode"], errors.Codes.NOT_FOUND) + self.assertEqual(other_user_failures[other_master_pubkey]["errcode"], errors.Codes.UNKNOWN) + + # test successful signatures + del device_key["signatures"] + sign.sign_json(device_key, local_user, selfsigning_signing_key) + sign.sign_json(master_key, local_user, device_signing_key) + sign.sign_json(other_master_key, local_user, usersigning_signing_key) + ret = yield self.handler.upload_signatures_for_device_keys( + local_user, + { + local_user: { + device_id: device_key, + master_pubkey: master_key + }, + other_user: { + other_master_pubkey: other_master_key + } + } + ) + + self.assertEqual(ret["failures"], {}) + + # fetch the signed keys/devices and make sure that the signatures are there + ret = yield self.handler.query_devices( + {"device_keys": {local_user: [], other_user: []}}, + 0, local_user + ) + + self.assertEqual( + ret["device_keys"][local_user]["xyz"]["signatures"][local_user]["ed25519:" + selfsigning_pubkey], + device_key["signatures"][local_user]["ed25519:" + selfsigning_pubkey] + ) + self.assertEqual( + ret["master_keys"][local_user]["signatures"][local_user]["ed25519:" + device_id], + master_key["signatures"][local_user]["ed25519:" + device_id] + ) + self.assertEqual( + ret["master_keys"][other_user]["signatures"][local_user]["ed25519:" + usersigning_pubkey], + other_master_key["signatures"][local_user]["ed25519:" + usersigning_pubkey] + ) From 7d6c70fc7ad08b94b8b577c537953a8d9b568562 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Mon, 22 Jul 2019 12:52:39 -0400 Subject: [PATCH 20/55] make black happy --- synapse/handlers/e2e_keys.py | 147 +++++++++++++-------------- synapse/rest/client/v2_alpha/keys.py | 1 + synapse/storage/end_to_end_keys.py | 24 ++--- tests/handlers/test_e2e_keys.py | 147 +++++++++++---------------- 4 files changed, 141 insertions(+), 178 deletions(-) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 1148803c1e..74bceddc46 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -635,16 +635,14 @@ class E2eKeysHandler(object): self_device_ids = list(self_signatures.keys()) try: # get our self-signing key to verify the signatures - self_signing_key, self_signing_key_id, self_signing_verify_key \ - = yield self._get_e2e_cross_signing_verify_key( - user_id, "self_signing" - ) + self_signing_key, self_signing_key_id, self_signing_verify_key = yield self._get_e2e_cross_signing_verify_key( + user_id, "self_signing" + ) # get our master key, since it may be signed - master_key, master_key_id, master_verify_key \ - = yield self._get_e2e_cross_signing_verify_key( - user_id, "master" - ) + master_key, master_key_id, master_verify_key = yield self._get_e2e_cross_signing_verify_key( + user_id, "master" + ) # fetch our stored devices. This is used to 1. verify # signatures on the master key, and 2. to can compare with what @@ -652,15 +650,15 @@ class E2eKeysHandler(object): devices = yield self.store.get_e2e_device_keys([(user_id, None)]) if user_id not in devices: - raise SynapseError( - 404, "No device keys found", Codes.NOT_FOUND - ) + raise SynapseError(404, "No device keys found", Codes.NOT_FOUND) devices = devices[user_id] for device_id, device in self_signatures.items(): try: - if ("signatures" not in device or - user_id not in device["signatures"]): + if ( + "signatures" not in device + or user_id not in device["signatures"] + ): # no signature was sent raise SynapseError( 400, "Invalid signature", Codes.INVALID_SIGNATURE @@ -678,15 +676,21 @@ class E2eKeysHandler(object): # signature list. (In practice, we're likely to # only have only one signature anyways.) master_key_signature_list = [] - for signing_key_id, signature in device["signatures"][user_id].items(): + for signing_key_id, signature in device["signatures"][ + user_id + ].items(): alg, signing_device_id = signing_key_id.split(":", 1) - if (signing_device_id not in devices or - signing_key_id not in - devices[signing_device_id]["keys"]["keys"]): + if ( + signing_device_id not in devices + or signing_key_id + not in devices[signing_device_id]["keys"]["keys"] + ): # signed by an unknown device, or the # device does not have the key raise SynapseError( - 400, "Invalid signature", Codes.INVALID_SIGNATURE + 400, + "Invalid signature", + Codes.INVALID_SIGNATURE, ) sigs = device["signatures"] @@ -697,12 +701,12 @@ class E2eKeysHandler(object): master_key.pop("unsigned", None) if master_key != device: - raise SynapseError( - 400, "Key does not match" - ) + raise SynapseError(400, "Key does not match") # get the key and check the signature - pubkey = devices[signing_device_id]["keys"]["keys"][signing_key_id] + pubkey = devices[signing_device_id]["keys"]["keys"][ + signing_key_id + ] verify_key = decode_verify_key_bytes( signing_key_id, decode_base64(pubkey) ) @@ -711,7 +715,9 @@ class E2eKeysHandler(object): verify_signed_json(device, user_id, verify_key) except SignatureVerifyException: raise SynapseError( - 400, "Invalid signature", Codes.INVALID_SIGNATURE + 400, + "Invalid signature", + Codes.INVALID_SIGNATURE, ) master_key_signature_list.append( @@ -733,11 +739,10 @@ class E2eKeysHandler(object): try: stored_device = devices[device_id]["keys"] except KeyError: - raise SynapseError( - 404, "Unknown device", Codes.NOT_FOUND - ) - if self_signing_key_id in stored_device.get("signatures", {}) \ - .get(user_id, {}): + raise SynapseError(404, "Unknown device", Codes.NOT_FOUND) + if self_signing_key_id in stored_device.get( + "signatures", {} + ).get(user_id, {}): # we already have a signature on this device, so we # can skip it, since it should be exactly the same continue @@ -751,8 +756,9 @@ class E2eKeysHandler(object): (self_signing_key_id, user_id, device_id, signature) ) except SynapseError as e: - failures.setdefault(user_id, {})[device_id] \ - = _exception_to_failure(e) + failures.setdefault(user_id, {})[ + device_id + ] = _exception_to_failure(e) except SynapseError as e: failures[user_id] = { device: _exception_to_failure(e) @@ -766,20 +772,18 @@ class E2eKeysHandler(object): try: # get our user-signing key to verify the signatures - user_signing_key, user_signing_key_id, user_signing_verify_key \ - = yield self._get_e2e_cross_signing_verify_key( - user_id, "user_signing" - ) + user_signing_key, user_signing_key_id, user_signing_verify_key = yield self._get_e2e_cross_signing_verify_key( + user_id, "user_signing" + ) for user, devicemap in signatures.items(): device_id = None try: # get the user's master key, to make sure it matches # what was sent - stored_key, stored_key_id, _ \ - = yield self._get_e2e_cross_signing_verify_key( - user, "master", user_id - ) + stored_key, stored_key_id, _ = yield self._get_e2e_cross_signing_verify_key( + user, "master", user_id + ) # make sure that the user's master key is the one that # was signed (and no others) @@ -790,26 +794,25 @@ class E2eKeysHandler(object): device_id = None logger.error( "upload signature: wrong device: %s vs %s", - device, devicemap - ) - raise SynapseError( - 404, "Unknown device", Codes.NOT_FOUND + device, + devicemap, ) + raise SynapseError(404, "Unknown device", Codes.NOT_FOUND) key = devicemap[device_id] del devicemap[device_id] if len(devicemap) > 0: # other devices were signed -- mark those as failures logger.error("upload signature: too many devices specified") - failure = _exception_to_failure(SynapseError( - 404, "Unknown device", Codes.NOT_FOUND - )) + failure = _exception_to_failure( + SynapseError(404, "Unknown device", Codes.NOT_FOUND) + ) failures[user] = { - device: failure - for device in devicemap.keys() + device: failure for device in devicemap.keys() } - if user_signing_key_id in stored_key.get("signatures", {}) \ - .get(user_id, {}): + if user_signing_key_id in stored_key.get("signatures", {}).get( + user_id, {} + ): # we already have the signature, so we can skip it continue @@ -826,8 +829,7 @@ class E2eKeysHandler(object): failure = _exception_to_failure(e) if device_id is None: failures[user] = { - device_id: failure - for device_id in devicemap.keys() + device_id: failure for device_id in devicemap.keys() } else: failures.setdefault(user, {})[device_id] = failure @@ -835,8 +837,7 @@ class E2eKeysHandler(object): failure = _exception_to_failure(e) for user, devicemap in signature.items(): failures[user] = { - device_id: failure - for device_id in devicemap.keys() + device_id: failure for device_id in devicemap.keys() } # store the signature, and send the appropriate notifications for sync @@ -846,7 +847,9 @@ class E2eKeysHandler(object): if len(self_device_ids): yield self.device_handler.notify_device_update(user_id, self_device_ids) if len(signed_users): - yield self.device_handler.notify_user_signature_update(user_id, signed_users) + yield self.device_handler.notify_user_signature_update( + user_id, signed_users + ) defer.returnValue({"failures": failures}) @@ -858,9 +861,7 @@ class E2eKeysHandler(object): if key is None: logger.error("no %s key found for %s", key_type, user_id) raise SynapseError( - 404, - "No %s key found for %s" % (key_type, user_id), - Codes.NOT_FOUND + 404, "No %s key found for %s" % (key_type, user_id), Codes.NOT_FOUND ) key_id, verify_key = get_verify_key_from_cross_signing_key(key) return key, key_id, verify_key @@ -907,14 +908,13 @@ def _check_device_signature(user_id, verify_key, signed_device, stored_device): key_id = "%s:%s" % (verify_key.alg, verify_key.version) # make sure the device is signed - if ("signatures" not in signed_device or user_id not in signed_device["signatures"] - or key_id not in signed_device["signatures"][user_id]): + if ( + "signatures" not in signed_device + or user_id not in signed_device["signatures"] + or key_id not in signed_device["signatures"][user_id] + ): logger.error("upload signature: user not found in signatures") - raise SynapseError( - 400, - "Invalid signature", - Codes.INVALID_SIGNATURE - ) + raise SynapseError(400, "Invalid signature", Codes.INVALID_SIGNATURE) signature = signed_device["signatures"][user_id][key_id] @@ -927,28 +927,19 @@ def _check_device_signature(user_id, verify_key, signed_device, stored_device): if signed_device != stored_device: logger.error( "upload signatures: key does not match %s vs %s", - signed_device, stored_device - ) - raise SynapseError( - 400, "Key does not match", + signed_device, + stored_device, ) + raise SynapseError(400, "Key does not match") # check the signature - signed_device["signatures"] = { - user_id: { - key_id: signature - } - } + signed_device["signatures"] = {user_id: {key_id: signature}} try: verify_signed_json(signed_device, user_id, verify_key) except SignatureVerifyException: logger.error("invalid signature on key") - raise SynapseError( - 400, - "Invalid signature", - Codes.INVALID_SIGNATURE - ) + raise SynapseError(400, "Invalid signature", Codes.INVALID_SIGNATURE) def _exception_to_failure(e): diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index cb3c52cb8e..a205281830 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -303,6 +303,7 @@ class SignaturesUploadServlet(RestServlet): } } """ + PATTERNS = client_patterns("/keys/signatures/upload$") def __init__(self, hs): diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index e68ce318af..258e8dcb47 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -62,9 +62,9 @@ class EndToEndKeyWorkerStore(SQLBaseStore): # add cross-signing signatures to the keys if "signatures" in device_info: for sig_user_id, sigs in device_info["signatures"].items(): - device_info["keys"].setdefault("signatures", {}) \ - .setdefault(sig_user_id, {}) \ - .update(sigs) + device_info["keys"].setdefault("signatures", {}).setdefault( + sig_user_id, {} + ).update(sigs) return results @@ -131,12 +131,8 @@ class EndToEndKeyWorkerStore(SQLBaseStore): # get signatures on the device signature_sql = ( - "SELECT * " - " FROM e2e_cross_signing_signatures " - " WHERE %s" - ) % ( - " OR ".join("(" + q + ")" for q in signature_query_clauses) - ) + "SELECT * " " FROM e2e_cross_signing_signatures " " WHERE %s" + ) % (" OR ".join("(" + q + ")" for q in signature_query_clauses)) txn.execute(signature_sql, signature_query_params) rows = self.cursor_to_dict(txn) @@ -144,12 +140,10 @@ class EndToEndKeyWorkerStore(SQLBaseStore): for row in rows: target_user_id = row["target_user_id"] target_device_id = row["target_device_id"] - if target_user_id in result \ - and target_device_id in result[target_user_id]: - result[target_user_id][target_device_id] \ - .setdefault("signatures", {}) \ - .setdefault(row["user_id"], {})[row["key_id"]] \ - = row["signature"] + if target_user_id in result and target_device_id in result[target_user_id]: + result[target_user_id][target_device_id].setdefault( + "signatures", {} + ).setdefault(row["user_id"], {})[row["key_id"]] = row["signature"] log_kv(result) return result diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index b1d3a4cfae..8c0ee3f7d3 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -225,20 +225,11 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "user_id": local_user, "device_id": device_id, "algorithms": ["m.olm.curve25519-aes-sha256", "m.megolm.v1.aes-sha"], - "keys": { - "curve25519:xyz": "curve25519+key", - "ed25519:xyz": device_pubkey - }, - "signatures": { - local_user: { - "ed25519:xyz": "something" - } - } + "keys": {"curve25519:xyz": "curve25519+key", "ed25519:xyz": device_pubkey}, + "signatures": {local_user: {"ed25519:xyz": "something"}}, } device_signing_key = key.decode_signing_key_base64( - "ed25519", - "xyz", - "OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA" + "ed25519", "xyz", "OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA" ) yield self.handler.upload_keys_for_user( @@ -250,26 +241,20 @@ class E2eKeysHandlerTestCase(unittest.TestCase): master_key = { "user_id": local_user, "usage": ["master"], - "keys": { - "ed25519:" + master_pubkey: master_pubkey - } + "keys": {"ed25519:" + master_pubkey: master_pubkey}, } master_signing_key = key.decode_signing_key_base64( - "ed25519", master_pubkey, - "2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0" + "ed25519", master_pubkey, "2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0" ) usersigning_pubkey = "Hq6gL+utB4ET+UvD5ci0kgAwsX6qP/zvf8v6OInU5iw" usersigning_key = { # private key: 4TL4AjRYwDVwD3pqQzcor+ez/euOB1/q78aTJ+czDNs "user_id": local_user, "usage": ["user_signing"], - "keys": { - "ed25519:" + usersigning_pubkey: usersigning_pubkey, - } + "keys": {"ed25519:" + usersigning_pubkey: usersigning_pubkey}, } usersigning_signing_key = key.decode_signing_key_base64( - "ed25519", usersigning_pubkey, - "4TL4AjRYwDVwD3pqQzcor+ez/euOB1/q78aTJ+czDNs" + "ed25519", usersigning_pubkey, "4TL4AjRYwDVwD3pqQzcor+ez/euOB1/q78aTJ+czDNs" ) sign.sign_json(usersigning_key, local_user, master_signing_key) # private key: HvQBbU+hc2Zr+JP1sE0XwBe1pfZZEYtJNPJLZJtS+F8 @@ -277,13 +262,10 @@ class E2eKeysHandlerTestCase(unittest.TestCase): selfsigning_key = { "user_id": local_user, "usage": ["self_signing"], - "keys": { - "ed25519:" + selfsigning_pubkey: selfsigning_pubkey, - } + "keys": {"ed25519:" + selfsigning_pubkey: selfsigning_pubkey}, } selfsigning_signing_key = key.decode_signing_key_base64( - "ed25519", selfsigning_pubkey, - "HvQBbU+hc2Zr+JP1sE0XwBe1pfZZEYtJNPJLZJtS+F8" + "ed25519", selfsigning_pubkey, "HvQBbU+hc2Zr+JP1sE0XwBe1pfZZEYtJNPJLZJtS+F8" ) sign.sign_json(selfsigning_key, local_user, master_signing_key) cross_signing_keys = { @@ -301,13 +283,11 @@ class E2eKeysHandlerTestCase(unittest.TestCase): # private key: oyw2ZUx0O4GifbfFYM0nQvj9CL0b8B7cyN4FprtK8OI "user_id": other_user, "usage": ["master"], - "keys": { - "ed25519:" + other_master_pubkey: other_master_pubkey - } + "keys": {"ed25519:" + other_master_pubkey: other_master_pubkey}, } - yield self.handler.upload_signing_keys_for_user(other_user, { - "master_key": other_master_key - }) + yield self.handler.upload_signing_keys_for_user( + other_user, {"master_key": other_master_key} + ) # test various signature failures (see below) ret = yield self.handler.upload_signatures_for_device_keys( @@ -319,17 +299,18 @@ class E2eKeysHandlerTestCase(unittest.TestCase): device_id: { "user_id": local_user, "device_id": device_id, - "algorithms": ["m.olm.curve25519-aes-sha256", "m.megolm.v1.aes-sha"], + "algorithms": [ + "m.olm.curve25519-aes-sha256", + "m.megolm.v1.aes-sha", + ], "keys": { "curve25519:xyz": "curve25519+key", # private key: OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA - "ed25519:xyz": device_pubkey + "ed25519:xyz": device_pubkey, }, "signatures": { - local_user: { - "ed25519:" + selfsigning_pubkey: "something", - } - } + local_user: {"ed25519:" + selfsigning_pubkey: "something"} + }, }, # fails because device is unknown # should fail with NOT_FOUND @@ -337,25 +318,19 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "user_id": local_user, "device_id": "unknown", "signatures": { - local_user: { - "ed25519:" + selfsigning_pubkey: "something", - } - } + local_user: {"ed25519:" + selfsigning_pubkey: "something"} + }, }, # fails because the signature is invalid # should fail with INVALID_SIGNATURE master_pubkey: { "user_id": local_user, "usage": ["master"], - "keys": { - "ed25519:" + master_pubkey: master_pubkey - }, + "keys": {"ed25519:" + master_pubkey: master_pubkey}, "signatures": { - local_user: { - "ed25519:" + device_pubkey: "something", - } - } - } + local_user: {"ed25519:" + device_pubkey: "something"} + }, + }, }, other_user: { # fails because the device is not the user's master-signing key @@ -364,38 +339,40 @@ class E2eKeysHandlerTestCase(unittest.TestCase): "user_id": other_user, "device_id": "unknown", "signatures": { - local_user: { - "ed25519:" + usersigning_pubkey: "something", - } - } + local_user: {"ed25519:" + usersigning_pubkey: "something"} + }, }, other_master_pubkey: { # fails because the key doesn't match what the server has # should fail with UNKNOWN "user_id": other_user, "usage": ["master"], - "keys": { - "ed25519:" + other_master_pubkey: other_master_pubkey - }, + "keys": {"ed25519:" + other_master_pubkey: other_master_pubkey}, "something": "random", "signatures": { - local_user: { - "ed25519:" + usersigning_pubkey: "something", - } - } - } - } - } + local_user: {"ed25519:" + usersigning_pubkey: "something"} + }, + }, + }, + }, ) user_failures = ret["failures"][local_user] - self.assertEqual(user_failures[device_id]["errcode"], errors.Codes.INVALID_SIGNATURE) - self.assertEqual(user_failures[master_pubkey]["errcode"], errors.Codes.INVALID_SIGNATURE) + self.assertEqual( + user_failures[device_id]["errcode"], errors.Codes.INVALID_SIGNATURE + ) + self.assertEqual( + user_failures[master_pubkey]["errcode"], errors.Codes.INVALID_SIGNATURE + ) self.assertEqual(user_failures["unknown"]["errcode"], errors.Codes.NOT_FOUND) other_user_failures = ret["failures"][other_user] - self.assertEqual(other_user_failures["unknown"]["errcode"], errors.Codes.NOT_FOUND) - self.assertEqual(other_user_failures[other_master_pubkey]["errcode"], errors.Codes.UNKNOWN) + self.assertEqual( + other_user_failures["unknown"]["errcode"], errors.Codes.NOT_FOUND + ) + self.assertEqual( + other_user_failures[other_master_pubkey]["errcode"], errors.Codes.UNKNOWN + ) # test successful signatures del device_key["signatures"] @@ -405,33 +382,33 @@ class E2eKeysHandlerTestCase(unittest.TestCase): ret = yield self.handler.upload_signatures_for_device_keys( local_user, { - local_user: { - device_id: device_key, - master_pubkey: master_key - }, - other_user: { - other_master_pubkey: other_master_key - } - } + local_user: {device_id: device_key, master_pubkey: master_key}, + other_user: {other_master_pubkey: other_master_key}, + }, ) self.assertEqual(ret["failures"], {}) # fetch the signed keys/devices and make sure that the signatures are there ret = yield self.handler.query_devices( - {"device_keys": {local_user: [], other_user: []}}, - 0, local_user + {"device_keys": {local_user: [], other_user: []}}, 0, local_user ) self.assertEqual( - ret["device_keys"][local_user]["xyz"]["signatures"][local_user]["ed25519:" + selfsigning_pubkey], - device_key["signatures"][local_user]["ed25519:" + selfsigning_pubkey] + ret["device_keys"][local_user]["xyz"]["signatures"][local_user][ + "ed25519:" + selfsigning_pubkey + ], + device_key["signatures"][local_user]["ed25519:" + selfsigning_pubkey], ) self.assertEqual( - ret["master_keys"][local_user]["signatures"][local_user]["ed25519:" + device_id], - master_key["signatures"][local_user]["ed25519:" + device_id] + ret["master_keys"][local_user]["signatures"][local_user][ + "ed25519:" + device_id + ], + master_key["signatures"][local_user]["ed25519:" + device_id], ) self.assertEqual( - ret["master_keys"][other_user]["signatures"][local_user]["ed25519:" + usersigning_pubkey], - other_master_key["signatures"][local_user]["ed25519:" + usersigning_pubkey] + ret["master_keys"][other_user]["signatures"][local_user][ + "ed25519:" + usersigning_pubkey + ], + other_master_key["signatures"][local_user]["ed25519:" + usersigning_pubkey], ) From 9061b4198af4b30bb99d98aab7ad227f8ed636f8 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Mon, 22 Jul 2019 12:58:04 -0400 Subject: [PATCH 21/55] make isort happy --- tests/handlers/test_e2e_keys.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 8c0ee3f7d3..c900451e03 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -19,6 +19,7 @@ import mock import signedjson.key as key import signedjson.sign as sign + from twisted.internet import defer import synapse.handlers.e2e_keys From 5914fd09c725342d03f702a50ec1da6290e946a9 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Mon, 22 Jul 2019 13:01:10 -0400 Subject: [PATCH 22/55] add test --- tests/handlers/test_e2e_keys.py | 88 +++++++++++++++++++++++++++++++++ 1 file changed, 88 insertions(+) diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index c900451e03..316dd6259d 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -183,6 +183,94 @@ class E2eKeysHandlerTestCase(unittest.TestCase): ) self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]}) + @defer.inlineCallbacks + def test_reupload_signatures(self): + """re-uploading a signature should not fail""" + local_user = "@boris:" + self.hs.hostname + keys1 = { + "master_key": { + # private key: HvQBbU+hc2Zr+JP1sE0XwBe1pfZZEYtJNPJLZJtS+F8 + "user_id": local_user, + "usage": ["master"], + "keys": { + "ed25519:EmkqvokUn8p+vQAGZitOk4PWjp7Ukp3txV2TbMPEiBQ": "EmkqvokUn8p+vQAGZitOk4PWjp7Ukp3txV2TbMPEiBQ" + }, + }, + "self_signing_key": { + # private key: 2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0 + "user_id": local_user, + "usage": ["self_signing"], + "keys": { + "ed25519:nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk": "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk" + }, + }, + } + master_signing_key = key.decode_signing_key_base64( + "ed25519", + "EmkqvokUn8p+vQAGZitOk4PWjp7Ukp3txV2TbMPEiBQ", + "HvQBbU+hc2Zr+JP1sE0XwBe1pfZZEYtJNPJLZJtS+F8", + ) + sign.sign_json(keys1["self_signing_key"], local_user, master_signing_key) + signing_key = key.decode_signing_key_base64( + "ed25519", + "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk", + "2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0", + ) + yield self.handler.upload_signing_keys_for_user(local_user, keys1) + + # upload two device keys, which will be signed later by the self-signing key + device_key_1 = { + "user_id": local_user, + "device_id": "abc", + "algorithms": ["m.olm.curve25519-aes-sha256", "m.megolm.v1.aes-sha"], + "keys": { + "ed25519:abc": "base64+ed25519+key", + "curve25519:abc": "base64+curve25519+key", + }, + "signatures": {local_user: {"ed25519:abc": "base64+signature"}}, + } + device_key_2 = { + "user_id": local_user, + "device_id": "def", + "algorithms": ["m.olm.curve25519-aes-sha256", "m.megolm.v1.aes-sha"], + "keys": { + "ed25519:def": "base64+ed25519+key", + "curve25519:def": "base64+curve25519+key", + }, + "signatures": {local_user: {"ed25519:def": "base64+signature"}}, + } + + yield self.handler.upload_keys_for_user( + local_user, "abc", {"device_keys": device_key_1} + ) + yield self.handler.upload_keys_for_user( + local_user, "def", {"device_keys": device_key_2} + ) + + # sign the first device key and upload it + del device_key_1["signatures"] + sign.sign_json(device_key_1, local_user, signing_key) + yield self.handler.upload_signatures_for_device_keys( + local_user, {local_user: {"abc": device_key_1}} + ) + + # sign the second device key and upload both device keys. The server + # should ignore the first device key since it already has a valid + # signature for it + del device_key_2["signatures"] + sign.sign_json(device_key_2, local_user, signing_key) + yield self.handler.upload_signatures_for_device_keys( + local_user, {local_user: {"abc": device_key_1, "def": device_key_2}} + ) + + device_key_1["signatures"][local_user]["ed25519:abc"] = "base64+signature" + device_key_2["signatures"][local_user]["ed25519:def"] = "base64+signature" + devices = yield self.handler.query_devices({"device_keys": {local_user: []}}, 0) + del devices["device_keys"][local_user]["abc"]["unsigned"] + del devices["device_keys"][local_user]["def"]["unsigned"] + self.assertDictEqual(devices["device_keys"][local_user]["abc"], device_key_1) + self.assertDictEqual(devices["device_keys"][local_user]["def"], device_key_2) + @defer.inlineCallbacks def test_self_signing_key_doesnt_show_up_as_device(self): """signing keys should be hidden when fetching a user's devices""" From c8dc740a94f20c0bca9aaa30b9d0fd211361a21e Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Wed, 4 Sep 2019 22:30:45 -0400 Subject: [PATCH 23/55] update with newer coding style --- synapse/handlers/e2e_keys.py | 2 +- synapse/rest/client/v2_alpha/keys.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 74bceddc46..d5d6e6e027 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -851,7 +851,7 @@ class E2eKeysHandler(object): user_id, signed_users ) - defer.returnValue({"failures": failures}) + return {"failures": failures} @defer.inlineCallbacks def _get_e2e_cross_signing_verify_key(self, user_id, key_type, from_user_id=None): diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index a205281830..341567ae21 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -274,7 +274,7 @@ class SigningKeyUploadServlet(RestServlet): ) result = yield self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body) - return (200, result) + return 200, result class SignaturesUploadServlet(RestServlet): @@ -324,7 +324,7 @@ class SignaturesUploadServlet(RestServlet): result = yield self.e2e_keys_handler.upload_signatures_for_device_keys( user_id, body ) - defer.returnValue((200, result)) + return 200, result def register_servlets(hs, http_server): From e47af0f086839c5d22a0de87a32a49386abef8df Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Thu, 5 Sep 2019 17:03:14 -0400 Subject: [PATCH 24/55] fix test --- tests/handlers/test_e2e_keys.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 316dd6259d..7a59ec5085 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -265,7 +265,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase): device_key_1["signatures"][local_user]["ed25519:abc"] = "base64+signature" device_key_2["signatures"][local_user]["ed25519:def"] = "base64+signature" - devices = yield self.handler.query_devices({"device_keys": {local_user: []}}, 0) + devices = yield self.handler.query_devices( + {"device_keys": {local_user: []}}, 0, 0 + ) del devices["device_keys"][local_user]["abc"]["unsigned"] del devices["device_keys"][local_user]["def"]["unsigned"] self.assertDictEqual(devices["device_keys"][local_user]["abc"], device_key_1) From 369462da7488772ea6d2fdd076ff355bc09db28c Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Thu, 5 Sep 2019 17:03:31 -0400 Subject: [PATCH 25/55] avoid modifying input parameter --- synapse/handlers/e2e_keys.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index d5d6e6e027..2c21cb9828 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -629,9 +629,9 @@ class E2eKeysHandler(object): # split between checking signatures for own user and signatures for # other users, since we verify them with different keys - if user_id in signatures: - self_signatures = signatures[user_id] - del signatures[user_id] + self_signatures = signatures.get(user_id, {}) + other_signatures = {k: v for k, v in signatures.items() if k != user_id} + if self_signatures: self_device_ids = list(self_signatures.keys()) try: # get our self-signing key to verify the signatures @@ -766,9 +766,9 @@ class E2eKeysHandler(object): } signed_users = [] # what user have been signed, for notifying - if len(signatures): - # if signatures isn't empty, then we have signatures for other - # users. These signatures will be signed by the user signing key + if other_signatures: + # now check non-self signatures. These signatures will be signed + # by the user-signing key try: # get our user-signing key to verify the signatures @@ -776,7 +776,7 @@ class E2eKeysHandler(object): user_id, "user_signing" ) - for user, devicemap in signatures.items(): + for user, devicemap in other_signatures.items(): device_id = None try: # get the user's master key, to make sure it matches From 561cbba0577b63f340050362144bef8527c1fc0e Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Fri, 6 Sep 2019 16:44:24 -0400 Subject: [PATCH 26/55] split out signature processing into separate functions --- synapse/handlers/e2e_keys.py | 429 ++++++++++++++++++----------------- 1 file changed, 219 insertions(+), 210 deletions(-) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 2c21cb9828..6500bf3e16 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -624,235 +624,244 @@ class E2eKeysHandler(object): # signatures to be stored. Each item will be a tuple of # (signing_key_id, target_user_id, target_device_id, signature) signature_list = [] - # what devices have been updated, for notifying - self_device_ids = [] # split between checking signatures for own user and signatures for # other users, since we verify them with different keys self_signatures = signatures.get(user_id, {}) other_signatures = {k: v for k, v in signatures.items() if k != user_id} - if self_signatures: - self_device_ids = list(self_signatures.keys()) - try: - # get our self-signing key to verify the signatures - self_signing_key, self_signing_key_id, self_signing_verify_key = yield self._get_e2e_cross_signing_verify_key( - user_id, "self_signing" - ) - # get our master key, since it may be signed - master_key, master_key_id, master_verify_key = yield self._get_e2e_cross_signing_verify_key( - user_id, "master" - ) + self_signature_list, self_failures = yield self._process_self_signatures( + user_id, self_signatures + ) + signature_list.extend(self_signature_list) + failures.update(self_failures) - # fetch our stored devices. This is used to 1. verify - # signatures on the master key, and 2. to can compare with what - # was sent if the device was signed - devices = yield self.store.get_e2e_device_keys([(user_id, None)]) - - if user_id not in devices: - raise SynapseError(404, "No device keys found", Codes.NOT_FOUND) - - devices = devices[user_id] - for device_id, device in self_signatures.items(): - try: - if ( - "signatures" not in device - or user_id not in device["signatures"] - ): - # no signature was sent - raise SynapseError( - 400, "Invalid signature", Codes.INVALID_SIGNATURE - ) - - if device_id == master_verify_key.version: - # we have master key signed by devices: for each - # device that signed, check the signature. Since - # the "failures" property in the response only has - # granularity up to the signed device, either all - # of the signatures on the master key succeed, or - # all fail. So loop over the signatures and add - # them to a separate signature list. If everything - # works out, then add them all to the main - # signature list. (In practice, we're likely to - # only have only one signature anyways.) - master_key_signature_list = [] - for signing_key_id, signature in device["signatures"][ - user_id - ].items(): - alg, signing_device_id = signing_key_id.split(":", 1) - if ( - signing_device_id not in devices - or signing_key_id - not in devices[signing_device_id]["keys"]["keys"] - ): - # signed by an unknown device, or the - # device does not have the key - raise SynapseError( - 400, - "Invalid signature", - Codes.INVALID_SIGNATURE, - ) - - sigs = device["signatures"] - del device["signatures"] - # use pop to avoid exception if key doesn't exist - device.pop("unsigned", None) - master_key.pop("signature", None) - master_key.pop("unsigned", None) - - if master_key != device: - raise SynapseError(400, "Key does not match") - - # get the key and check the signature - pubkey = devices[signing_device_id]["keys"]["keys"][ - signing_key_id - ] - verify_key = decode_verify_key_bytes( - signing_key_id, decode_base64(pubkey) - ) - device["signatures"] = sigs - try: - verify_signed_json(device, user_id, verify_key) - except SignatureVerifyException: - raise SynapseError( - 400, - "Invalid signature", - Codes.INVALID_SIGNATURE, - ) - - master_key_signature_list.append( - (signing_key_id, user_id, device_id, signature) - ) - - signature_list.extend(master_key_signature_list) - continue - - # at this point, we have a device that should be signed - # by the self-signing key - if self_signing_key_id not in device["signatures"][user_id]: - # no signature was sent - raise SynapseError( - 400, "Invalid signature", Codes.INVALID_SIGNATURE - ) - - stored_device = None - try: - stored_device = devices[device_id]["keys"] - except KeyError: - raise SynapseError(404, "Unknown device", Codes.NOT_FOUND) - if self_signing_key_id in stored_device.get( - "signatures", {} - ).get(user_id, {}): - # we already have a signature on this device, so we - # can skip it, since it should be exactly the same - continue - - _check_device_signature( - user_id, self_signing_verify_key, device, stored_device - ) - - signature = device["signatures"][user_id][self_signing_key_id] - signature_list.append( - (self_signing_key_id, user_id, device_id, signature) - ) - except SynapseError as e: - failures.setdefault(user_id, {})[ - device_id - ] = _exception_to_failure(e) - except SynapseError as e: - failures[user_id] = { - device: _exception_to_failure(e) - for device in self_signatures.keys() - } - - signed_users = [] # what user have been signed, for notifying - if other_signatures: - # now check non-self signatures. These signatures will be signed - # by the user-signing key - - try: - # get our user-signing key to verify the signatures - user_signing_key, user_signing_key_id, user_signing_verify_key = yield self._get_e2e_cross_signing_verify_key( - user_id, "user_signing" - ) - - for user, devicemap in other_signatures.items(): - device_id = None - try: - # get the user's master key, to make sure it matches - # what was sent - stored_key, stored_key_id, _ = yield self._get_e2e_cross_signing_verify_key( - user, "master", user_id - ) - - # make sure that the user's master key is the one that - # was signed (and no others) - device_id = stored_key_id.split(":", 1)[1] - if device_id not in devicemap: - # set device to None so that the failure gets - # marked on all the signatures - device_id = None - logger.error( - "upload signature: wrong device: %s vs %s", - device, - devicemap, - ) - raise SynapseError(404, "Unknown device", Codes.NOT_FOUND) - key = devicemap[device_id] - del devicemap[device_id] - if len(devicemap) > 0: - # other devices were signed -- mark those as failures - logger.error("upload signature: too many devices specified") - failure = _exception_to_failure( - SynapseError(404, "Unknown device", Codes.NOT_FOUND) - ) - failures[user] = { - device: failure for device in devicemap.keys() - } - - if user_signing_key_id in stored_key.get("signatures", {}).get( - user_id, {} - ): - # we already have the signature, so we can skip it - continue - - _check_device_signature( - user_id, user_signing_verify_key, key, stored_key - ) - - signed_users.append(user) - signature = key["signatures"][user_id][user_signing_key_id] - signature_list.append( - (user_signing_key_id, user, device_id, signature) - ) - except SynapseError as e: - failure = _exception_to_failure(e) - if device_id is None: - failures[user] = { - device_id: failure for device_id in devicemap.keys() - } - else: - failures.setdefault(user, {})[device_id] = failure - except SynapseError as e: - failure = _exception_to_failure(e) - for user, devicemap in signature.items(): - failures[user] = { - device_id: failure for device_id in devicemap.keys() - } + other_signature_list, other_failures = yield self._process_other_signatures( + user_id, other_signatures + ) + signature_list.extend(other_signature_list) + failures.update(other_failures) # store the signature, and send the appropriate notifications for sync logger.debug("upload signature failures: %r", failures) yield self.store.store_e2e_cross_signing_signatures(user_id, signature_list) - if len(self_device_ids): + self_device_ids = [device_id for (_, _, device_id, _) in self_signature_list] + if self_device_ids: yield self.device_handler.notify_device_update(user_id, self_device_ids) - if len(signed_users): + signed_users = [user_id for (_, user_id, _, _) in other_signature_list] + if signed_users: yield self.device_handler.notify_user_signature_update( user_id, signed_users ) return {"failures": failures} + @defer.inlineCallbacks + def _process_self_signatures(self, user_id, signatures): + signature_list = [] + failures = {} + if not signatures: + return signature_list, failures + + try: + # get our self-signing key to verify the signatures + self_signing_key, self_signing_key_id, self_signing_verify_key = yield self._get_e2e_cross_signing_verify_key( + user_id, "self_signing" + ) + + # get our master key, since it may be signed + master_key, master_key_id, master_verify_key = yield self._get_e2e_cross_signing_verify_key( + user_id, "master" + ) + + # fetch our stored devices. This is used to 1. verify + # signatures on the master key, and 2. to can compare with what + # was sent if the device was signed + devices = yield self.store.get_e2e_device_keys([(user_id, None)]) + + if user_id not in devices: + raise SynapseError(404, "No device keys found", Codes.NOT_FOUND) + + devices = devices[user_id] + except SynapseError as e: + failures[user_id] = { + device: _exception_to_failure(e) + for device in signatures.keys() + } + return signature_list, failures + + for device_id, device in signatures.items(): + try: + if ( + "signatures" not in device + or user_id not in device["signatures"] + ): + # no signature was sent + raise SynapseError( + 400, "Invalid signature", Codes.INVALID_SIGNATURE + ) + + if device_id == master_verify_key.version: + # we have master key signed by devices: for each + # device that signed, check the signature. Since + # the "failures" property in the response only has + # granularity up to the signed device, either all + # of the signatures on the master key succeed, or + # all fail. So loop over the signatures and add + # them to a separate signature list. If everything + # works out, then add them all to the main + # signature list. (In practice, we're likely to + # only have only one signature anyways.) + master_key_signature_list = [] + sigs = device["signatures"] + for signing_key_id, signature in sigs[user_id].items(): + alg, signing_device_id = signing_key_id.split(":", 1) + if ( + signing_device_id not in devices + or signing_key_id + not in devices[signing_device_id]["keys"]["keys"] + ): + # signed by an unknown device, or the + # device does not have the key + raise SynapseError( + 400, + "Invalid signature", + Codes.INVALID_SIGNATURE, + ) + + # get the key and check the signature + pubkey = devices[signing_device_id]["keys"]["keys"][ + signing_key_id + ] + verify_key = decode_verify_key_bytes( + signing_key_id, decode_base64(pubkey) + ) + _check_device_signature(user_id, verify_key, device, master_key) + device["signatures"] = sigs + + master_key_signature_list.append( + (signing_key_id, user_id, device_id, signature) + ) + + signature_list.extend(master_key_signature_list) + continue + + # at this point, we have a device that should be signed + # by the self-signing key + if self_signing_key_id not in device["signatures"][user_id]: + # no signature was sent + raise SynapseError( + 400, "Invalid signature", Codes.INVALID_SIGNATURE + ) + + stored_device = None + try: + stored_device = devices[device_id]["keys"] + except KeyError: + raise SynapseError(404, "Unknown device", Codes.NOT_FOUND) + if self_signing_key_id in stored_device.get( + "signatures", {} + ).get(user_id, {}): + # we already have a signature on this device, so we + # can skip it, since it should be exactly the same + continue + + _check_device_signature( + user_id, self_signing_verify_key, device, stored_device + ) + + signature = device["signatures"][user_id][self_signing_key_id] + signature_list.append( + (self_signing_key_id, user_id, device_id, signature) + ) + except SynapseError as e: + failures.setdefault(user_id, {})[ + device_id + ] = _exception_to_failure(e) + + return signature_list, failures + + @defer.inlineCallbacks + def _process_other_signatures(self, user_id, signatures): + # now check non-self signatures. These signatures will be signed + # by the user-signing key + signature_list = [] + failures = {} + if not signatures: + return signature_list, failures + + try: + # get our user-signing key to verify the signatures + user_signing_key, user_signing_key_id, user_signing_verify_key = yield self._get_e2e_cross_signing_verify_key( + user_id, "user_signing" + ) + except SynapseError as e: + failure = _exception_to_failure(e) + for user, devicemap in signatures.items(): + failures[user] = { + device_id: failure for device_id in devicemap.keys() + } + return signature_list, failures + + for user, devicemap in signatures.items(): + device_id = None + try: + # get the user's master key, to make sure it matches + # what was sent + stored_key, stored_key_id, _ = yield self._get_e2e_cross_signing_verify_key( + user, "master", user_id + ) + + # make sure that the user's master key is the one that + # was signed (and no others) + device_id = stored_key_id.split(":", 1)[1] + if device_id not in devicemap: + logger.error( + "upload signature: could not find signature for device %s", + device_id, + ) + # set device to None so that the failure gets + # marked on all the signatures + device_id = None + raise SynapseError(404, "Unknown device", Codes.NOT_FOUND) + key = devicemap[device_id] + other_devices = [k for k in devicemap.keys() if k != device_id] + if other_devices: + # other devices were signed -- mark those as failures + logger.error("upload signature: too many devices specified") + failure = _exception_to_failure( + SynapseError(404, "Unknown device", Codes.NOT_FOUND) + ) + failures[user] = { + device: failure for device in other_devices + } + + if user_signing_key_id in stored_key.get("signatures", {}).get( + user_id, {} + ): + # we already have the signature, so we can skip it + continue + + _check_device_signature( + user_id, user_signing_verify_key, key, stored_key + ) + + signature = key["signatures"][user_id][user_signing_key_id] + signature_list.append( + (user_signing_key_id, user, device_id, signature) + ) + except SynapseError as e: + failure = _exception_to_failure(e) + if device_id is None: + failures[user] = { + device_id: failure for device_id in devicemap.keys() + } + else: + failures.setdefault(user, {})[device_id] = failure + + return signature_list, failures + @defer.inlineCallbacks def _get_e2e_cross_signing_verify_key(self, user_id, key_type, from_user_id=None): key = yield self.store.get_e2e_cross_signing_key( From 415d0a00e0845654b34542b9914ea01224dd8ed6 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Fri, 6 Sep 2019 16:46:45 -0400 Subject: [PATCH 27/55] run black --- synapse/handlers/e2e_keys.py | 34 ++++++++++------------------------ 1 file changed, 10 insertions(+), 24 deletions(-) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 6500bf3e16..95f3cc891b 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -686,17 +686,13 @@ class E2eKeysHandler(object): devices = devices[user_id] except SynapseError as e: failures[user_id] = { - device: _exception_to_failure(e) - for device in signatures.keys() + device: _exception_to_failure(e) for device in signatures.keys() } return signature_list, failures for device_id, device in signatures.items(): try: - if ( - "signatures" not in device - or user_id not in device["signatures"] - ): + if "signatures" not in device or user_id not in device["signatures"]: # no signature was sent raise SynapseError( 400, "Invalid signature", Codes.INVALID_SIGNATURE @@ -725,9 +721,7 @@ class E2eKeysHandler(object): # signed by an unknown device, or the # device does not have the key raise SynapseError( - 400, - "Invalid signature", - Codes.INVALID_SIGNATURE, + 400, "Invalid signature", Codes.INVALID_SIGNATURE ) # get the key and check the signature @@ -760,9 +754,9 @@ class E2eKeysHandler(object): stored_device = devices[device_id]["keys"] except KeyError: raise SynapseError(404, "Unknown device", Codes.NOT_FOUND) - if self_signing_key_id in stored_device.get( - "signatures", {} - ).get(user_id, {}): + if self_signing_key_id in stored_device.get("signatures", {}).get( + user_id, {} + ): # we already have a signature on this device, so we # can skip it, since it should be exactly the same continue @@ -776,9 +770,7 @@ class E2eKeysHandler(object): (self_signing_key_id, user_id, device_id, signature) ) except SynapseError as e: - failures.setdefault(user_id, {})[ - device_id - ] = _exception_to_failure(e) + failures.setdefault(user_id, {})[device_id] = _exception_to_failure(e) return signature_list, failures @@ -799,9 +791,7 @@ class E2eKeysHandler(object): except SynapseError as e: failure = _exception_to_failure(e) for user, devicemap in signatures.items(): - failures[user] = { - device_id: failure for device_id in devicemap.keys() - } + failures[user] = {device_id: failure for device_id in devicemap.keys()} return signature_list, failures for user, devicemap in signatures.items(): @@ -833,9 +823,7 @@ class E2eKeysHandler(object): failure = _exception_to_failure( SynapseError(404, "Unknown device", Codes.NOT_FOUND) ) - failures[user] = { - device: failure for device in other_devices - } + failures[user] = {device: failure for device in other_devices} if user_signing_key_id in stored_key.get("signatures", {}).get( user_id, {} @@ -848,9 +836,7 @@ class E2eKeysHandler(object): ) signature = key["signatures"][user_id][user_signing_key_id] - signature_list.append( - (user_signing_key_id, user, device_id, signature) - ) + signature_list.append((user_signing_key_id, user, device_id, signature)) except SynapseError as e: failure = _exception_to_failure(e) if device_id is None: From ab729e31cfca4d1a958937bb576010271b9c8044 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Fri, 6 Sep 2019 17:52:37 -0400 Subject: [PATCH 28/55] use something that's the right type for user_id --- tests/handlers/test_e2e_keys.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 7a59ec5085..854eb6c024 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -266,7 +266,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase): device_key_1["signatures"][local_user]["ed25519:abc"] = "base64+signature" device_key_2["signatures"][local_user]["ed25519:def"] = "base64+signature" devices = yield self.handler.query_devices( - {"device_keys": {local_user: []}}, 0, 0 + {"device_keys": {local_user: []}}, 0, local_user ) del devices["device_keys"][local_user]["abc"]["unsigned"] del devices["device_keys"][local_user]["def"]["unsigned"] From d3f2fbcfe577f42d0208d15a57bd66e56186742a Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Sat, 7 Sep 2019 14:13:18 -0400 Subject: [PATCH 29/55] add function docs --- synapse/handlers/e2e_keys.py | 26 ++++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 95f3cc891b..cca361b15b 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -659,6 +659,18 @@ class E2eKeysHandler(object): @defer.inlineCallbacks def _process_self_signatures(self, user_id, signatures): + """Process uploaded signatures of the user's own keys. + + Args: + user_id (string): the user uploading the keys + signatures (dict[string, dict]): map of devices to signed keys + + Returns: + (list[(string, string, string, string)], dict[string, dict[string, dict]]): + a list of signatures to upload, in the form (signing_key_id, target_user_id, + target_device_id, signature), and a map of users to devices to failure + reasons + """ signature_list = [] failures = {} if not signatures: @@ -776,8 +788,18 @@ class E2eKeysHandler(object): @defer.inlineCallbacks def _process_other_signatures(self, user_id, signatures): - # now check non-self signatures. These signatures will be signed - # by the user-signing key + """Process uploaded signatures of other users' keys. + + Args: + user_id (string): the user uploading the keys + signatures (dict[string, dict]): map of users to devices to signed keys + + Returns: + (list[(string, string, string, string)], dict[string, dict[string, dict]]): + a list of signatures to upload, in the form (signing_key_id, target_user_id, + target_device_id, signature), and a map of users to devices to failure + reasons + """ signature_list = [] failures = {} if not signatures: From 26113fb7de98ba09fed4ce687dbef8c4cfb07dc0 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Tue, 24 Sep 2019 14:12:20 -0400 Subject: [PATCH 30/55] make changes based on PR feedback --- synapse/handlers/e2e_keys.py | 272 +++++++++++++++++------------ synapse/storage/end_to_end_keys.py | 17 +- 2 files changed, 168 insertions(+), 121 deletions(-) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index cca361b15b..352c8ee93b 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -19,6 +19,8 @@ import logging from six import iteritems +import attr + from canonicaljson import encode_canonical_json, json from signedjson.key import decode_verify_key_bytes from signedjson.sign import SignatureVerifyException, verify_signed_json @@ -26,7 +28,7 @@ from unpaddedbase64 import decode_base64 from twisted.internet import defer -from synapse.api.errors import CodeMessageException, Codes, SynapseError +from synapse.api.errors import CodeMessageException, Codes, NotFoundError, SynapseError from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace from synapse.types import ( @@ -617,12 +619,18 @@ class E2eKeysHandler(object): Args: user_id (string): the user uploading the signatures signatures (dict[string, dict[string, dict]]): map of users to - devices to signed keys + devices to signed keys. This is the submission from the user; an + exception will be raised if it is malformed. + Returns: + dict: response to be sent back to the client. The response will have + a "failures" key, which will be a dict mapping users to devices + to errors for the signatures that failed. + Raises: + SynapseError: if the signatures dict is not valid. """ failures = {} - # signatures to be stored. Each item will be a tuple of - # (signing_key_id, target_user_id, target_device_id, signature) + # signatures to be stored. Each item will be a SignatureListItem signature_list = [] # split between checking signatures for own user and signatures for @@ -646,10 +654,10 @@ class E2eKeysHandler(object): logger.debug("upload signature failures: %r", failures) yield self.store.store_e2e_cross_signing_signatures(user_id, signature_list) - self_device_ids = [device_id for (_, _, device_id, _) in self_signature_list] + self_device_ids = [item.target_device_id for item in self_signature_list] if self_device_ids: yield self.device_handler.notify_device_update(user_id, self_device_ids) - signed_users = [user_id for (_, user_id, _, _) in other_signature_list] + signed_users = [item.target_user_id for item in other_signature_list] if signed_users: yield self.device_handler.notify_user_signature_update( user_id, signed_users @@ -661,48 +669,58 @@ class E2eKeysHandler(object): def _process_self_signatures(self, user_id, signatures): """Process uploaded signatures of the user's own keys. + Signatures of the user's own keys from this API come in two forms: + - signatures of the user's devices by the user's self-signing key, + - signatures of the user's master key by the user's devices. + Args: user_id (string): the user uploading the keys signatures (dict[string, dict]): map of devices to signed keys Returns: - (list[(string, string, string, string)], dict[string, dict[string, dict]]): - a list of signatures to upload, in the form (signing_key_id, target_user_id, - target_device_id, signature), and a map of users to devices to failure + (list[SignatureListItem], dict[string, dict[string, dict]]): + a list of signatures to upload, and a map of users to devices to failure reasons + + Raises: + SynapseError: if the input is malformed """ signature_list = [] failures = {} if not signatures: return signature_list, failures + if not isinstance(signatures, dict): + raise SynapseError(400, "Invalid parameter", Codes.INVALID_PARAM) + try: # get our self-signing key to verify the signatures - self_signing_key, self_signing_key_id, self_signing_verify_key = yield self._get_e2e_cross_signing_verify_key( + _, self_signing_key_id, self_signing_verify_key = yield self._get_e2e_cross_signing_verify_key( user_id, "self_signing" ) # get our master key, since it may be signed - master_key, master_key_id, master_verify_key = yield self._get_e2e_cross_signing_verify_key( + master_key, _, master_verify_key = yield self._get_e2e_cross_signing_verify_key( user_id, "master" ) # fetch our stored devices. This is used to 1. verify - # signatures on the master key, and 2. to can compare with what + # signatures on the master key, and 2. to compare with what # was sent if the device was signed devices = yield self.store.get_e2e_device_keys([(user_id, None)]) if user_id not in devices: - raise SynapseError(404, "No device keys found", Codes.NOT_FOUND) + raise NotFoundError("No device keys found") devices = devices[user_id] except SynapseError as e: - failures[user_id] = { - device: _exception_to_failure(e) for device in signatures.keys() - } + failure = _exception_to_failure(e) + failures[user_id] = {device: failure for device in signatures.keys()} return signature_list, failures for device_id, device in signatures.items(): + if not isinstance(device, dict): + raise SynapseError(400, "Invalid parameter", Codes.INVALID_PARAM) try: if "signatures" not in device or user_id not in device["signatures"]: # no signature was sent @@ -711,45 +729,9 @@ class E2eKeysHandler(object): ) if device_id == master_verify_key.version: - # we have master key signed by devices: for each - # device that signed, check the signature. Since - # the "failures" property in the response only has - # granularity up to the signed device, either all - # of the signatures on the master key succeed, or - # all fail. So loop over the signatures and add - # them to a separate signature list. If everything - # works out, then add them all to the main - # signature list. (In practice, we're likely to - # only have only one signature anyways.) - master_key_signature_list = [] - sigs = device["signatures"] - for signing_key_id, signature in sigs[user_id].items(): - alg, signing_device_id = signing_key_id.split(":", 1) - if ( - signing_device_id not in devices - or signing_key_id - not in devices[signing_device_id]["keys"]["keys"] - ): - # signed by an unknown device, or the - # device does not have the key - raise SynapseError( - 400, "Invalid signature", Codes.INVALID_SIGNATURE - ) - - # get the key and check the signature - pubkey = devices[signing_device_id]["keys"]["keys"][ - signing_key_id - ] - verify_key = decode_verify_key_bytes( - signing_key_id, decode_base64(pubkey) - ) - _check_device_signature(user_id, verify_key, device, master_key) - device["signatures"] = sigs - - master_key_signature_list.append( - (signing_key_id, user_id, device_id, signature) - ) - + master_key_signature_list = self._check_master_key_signature( + user_id, device_id, device, master_key, devices + ) signature_list.extend(master_key_signature_list) continue @@ -765,7 +747,7 @@ class E2eKeysHandler(object): try: stored_device = devices[device_id]["keys"] except KeyError: - raise SynapseError(404, "Unknown device", Codes.NOT_FOUND) + raise NotFoundError("Unknown device") if self_signing_key_id in stored_device.get("signatures", {}).get( user_id, {} ): @@ -779,26 +761,75 @@ class E2eKeysHandler(object): signature = device["signatures"][user_id][self_signing_key_id] signature_list.append( - (self_signing_key_id, user_id, device_id, signature) + SignatureListItem( + self_signing_key_id, user_id, device_id, signature + ) ) except SynapseError as e: failures.setdefault(user_id, {})[device_id] = _exception_to_failure(e) return signature_list, failures - @defer.inlineCallbacks - def _process_other_signatures(self, user_id, signatures): - """Process uploaded signatures of other users' keys. + def _check_master_key_signature( + self, user_id, master_key_id, signed_master_key, stored_master_key, devices + ): + """Check signatures of the user's master key made by their devices. Args: user_id (string): the user uploading the keys signatures (dict[string, dict]): map of users to devices to signed keys Returns: - (list[(string, string, string, string)], dict[string, dict[string, dict]]): - a list of signatures to upload, in the form (signing_key_id, target_user_id, - target_device_id, signature), and a map of users to devices to failure + (list[SignatureListItem], dict[string, dict[string, dict]]): + a list of signatures to upload, and a map of users to devices to failure reasons + + Raises: + SynapseError: if the input is malformed + """ + # for each device that signed the master key, check the signature. + master_key_signature_list = [] + sigs = signed_master_key["signatures"] + for signing_key_id, signature in sigs[user_id].items(): + _, signing_device_id = signing_key_id.split(":", 1) + if ( + signing_device_id not in devices + or signing_key_id not in devices[signing_device_id]["keys"]["keys"] + ): + # signed by an unknown device, or the + # device does not have the key + raise SynapseError(400, "Invalid signature", Codes.INVALID_SIGNATURE) + + # get the key and check the signature + pubkey = devices[signing_device_id]["keys"]["keys"][signing_key_id] + verify_key = decode_verify_key_bytes(signing_key_id, decode_base64(pubkey)) + _check_device_signature( + user_id, verify_key, signed_master_key, stored_master_key + ) + + master_key_signature_list.append( + SignatureListItem(signing_key_id, user_id, master_key_id, signature) + ) + + return master_key_signature_list + + @defer.inlineCallbacks + def _process_other_signatures(self, user_id, signatures): + """Process uploaded signatures of other users' keys. These will be the + target user's master keys, signed by the uploading user's user-signing + key. + + Args: + user_id (string): the user uploading the keys + signatures (dict[string, dict]): map of users to devices to signed keys + + Returns: + (list[SignatureListItem], dict[string, dict[string, dict]]): + a list of signatures to upload, and a map of users to devices to failure + reasons + + Raises: + SynapseError: if the input is malformed """ signature_list = [] failures = {} @@ -816,70 +847,89 @@ class E2eKeysHandler(object): failures[user] = {device_id: failure for device_id in devicemap.keys()} return signature_list, failures - for user, devicemap in signatures.items(): + for target_user, devicemap in signatures.items(): + if not isinstance(devicemap, dict): + raise SynapseError(400, "Invalid parameter", Codes.INVALID_PARAM) + for device in devicemap.values(): + if not isinstance(device, dict): + raise SynapseError(400, "Invalid parameter", Codes.INVALID_PARAM) device_id = None try: - # get the user's master key, to make sure it matches + # get the target user's master key, to make sure it matches # what was sent - stored_key, stored_key_id, _ = yield self._get_e2e_cross_signing_verify_key( - user, "master", user_id + master_key, master_key_id, _ = yield self._get_e2e_cross_signing_verify_key( + target_user, "master", user_id ) - # make sure that the user's master key is the one that + # make sure that the target user's master key is the one that # was signed (and no others) - device_id = stored_key_id.split(":", 1)[1] + device_id = master_key_id.split(":", 1)[1] if device_id not in devicemap: - logger.error( + logger.debug( "upload signature: could not find signature for device %s", device_id, ) # set device to None so that the failure gets # marked on all the signatures device_id = None - raise SynapseError(404, "Unknown device", Codes.NOT_FOUND) + raise NotFoundError("Unknown device") key = devicemap[device_id] other_devices = [k for k in devicemap.keys() if k != device_id] if other_devices: # other devices were signed -- mark those as failures - logger.error("upload signature: too many devices specified") - failure = _exception_to_failure( - SynapseError(404, "Unknown device", Codes.NOT_FOUND) - ) - failures[user] = {device: failure for device in other_devices} + logger.debug("upload signature: too many devices specified") + failure = _exception_to_failure(NotFoundError("Unknown device")) + failures[target_user] = { + device: failure for device in other_devices + } - if user_signing_key_id in stored_key.get("signatures", {}).get( + if user_signing_key_id in master_key.get("signatures", {}).get( user_id, {} ): # we already have the signature, so we can skip it continue _check_device_signature( - user_id, user_signing_verify_key, key, stored_key + user_id, user_signing_verify_key, key, master_key ) signature = key["signatures"][user_id][user_signing_key_id] - signature_list.append((user_signing_key_id, user, device_id, signature)) + signature_list.append( + SignatureListItem( + user_signing_key_id, target_user, device_id, signature + ) + ) except SynapseError as e: failure = _exception_to_failure(e) if device_id is None: - failures[user] = { + failures[target_user] = { device_id: failure for device_id in devicemap.keys() } else: - failures.setdefault(user, {})[device_id] = failure + failures.setdefault(target_user, {})[device_id] = failure return signature_list, failures @defer.inlineCallbacks def _get_e2e_cross_signing_verify_key(self, user_id, key_type, from_user_id=None): + """Fetch the cross-signing public key from storage and interpret it. + + Args: + user_id (str): the user whose key should be fetched + key_type (str): the type of key to fetch + from_user_id (str): the user that we are fetching the keys for. + This affects what signatures are fetched. + + Returns: + dict, str, VerifyKey: the raw key data, the key ID, and the + signedjson verify key + """ key = yield self.store.get_e2e_cross_signing_key( user_id, key_type, from_user_id ) if key is None: logger.error("no %s key found for %s", key_type, user_id) - raise SynapseError( - 404, "No %s key found for %s" % (key_type, user_id), Codes.NOT_FOUND - ) + raise NotFoundError("No %s key found for %s" % (key_type, user_id)) key_id, verify_key = get_verify_key_from_cross_signing_key(key) return key, key_id, verify_key @@ -912,36 +962,30 @@ def _check_cross_signing_key(key, user_id, key_type, signing_key=None): def _check_device_signature(user_id, verify_key, signed_device, stored_device): - """Check that a device signature is correct and matches the copy of the device - that we have. Throws an exception if an error is detected. + """Check that a signature on a device or cross-signing key is correct and + matches the copy of the device/key that we have stored. Throws an + exception if an error is detected. Args: user_id (str): the user ID whose signature is being checked verify_key (VerifyKey): the key to verify the device with - signed_device (dict): the signed device data - stored_device (dict): our previous copy of the device + signed_device (dict): the uploaded signed device data + stored_device (dict): our previously stored copy of the device + + Raises: + SynapseError: if the signature was invalid or the sent device is not the + same as the stored device + """ - key_id = "%s:%s" % (verify_key.alg, verify_key.version) - - # make sure the device is signed - if ( - "signatures" not in signed_device - or user_id not in signed_device["signatures"] - or key_id not in signed_device["signatures"][user_id] - ): - logger.error("upload signature: user not found in signatures") - raise SynapseError(400, "Invalid signature", Codes.INVALID_SIGNATURE) - - signature = signed_device["signatures"][user_id][key_id] - # make sure that the device submitted matches what we have stored - del signed_device["signatures"] - # use pop to avoid exception if key doesn't exist - signed_device.pop("unsigned", None) - stored_device.pop("signatures", None) - stored_device.pop("unsigned", None) - if signed_device != stored_device: + stripped_signed_device = { + k: v for k, v in signed_device.items() if k not in ["signatures", "unsigned"] + } + stripped_stored_device = { + k: v for k, v in stored_device.items() if k not in ["signatures", "unsigned"] + } + if stripped_signed_device != stripped_stored_device: logger.error( "upload signatures: key does not match %s vs %s", signed_device, @@ -949,9 +993,6 @@ def _check_device_signature(user_id, verify_key, signed_device, stored_device): ) raise SynapseError(400, "Key does not match") - # check the signature - signed_device["signatures"] = {user_id: {key_id: signature}} - try: verify_signed_json(signed_device, user_id, verify_key) except SignatureVerifyException: @@ -990,3 +1031,14 @@ def _one_time_keys_match(old_key_json, new_key): new_key_copy.pop("signatures", None) return old_key == new_key_copy + + +@attr.s +class SignatureListItem: + """An item in the signature list as used by upload_signatures_for_device_keys. + """ + + signing_key_id = attr.ib() + target_user_id = attr.ib() + target_device_id = attr.ib() + signature = attr.ib() diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index 258e8dcb47..625f95234f 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -490,24 +490,19 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): Args: user_id (str): the user who made the signatures - signatures (iterable[(str, str, str, str)]): signatures to add - each - a tuple of (key_id, target_user_id, target_device_id, signature), - where key_id is the ID of the key (including the signature - algorithm) that made the signature, target_user_id and - target_device_id indicate the device being signed, and signature - is the signature of the device + signatures (iterable[SignatureListItem]): signatures to add """ return self._simple_insert_many( "e2e_cross_signing_signatures", [ { "user_id": user_id, - "key_id": key_id, - "target_user_id": target_user_id, - "target_device_id": target_device_id, - "signature": signature, + "key_id": item.signing_key_id, + "target_user_id": item.target_user_id, + "target_device_id": item.target_device_id, + "signature": item.signature, } - for (key_id, target_user_id, target_device_id, signature) in signatures + for item in signatures ], "add_e2e_signing_key", ) From 39864f45ec1a5c2c65d4cb03744d4d9452505c0d Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Tue, 24 Sep 2019 15:26:45 -0400 Subject: [PATCH 31/55] drop some logger lines to debug --- synapse/handlers/e2e_keys.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 352c8ee93b..ff32fdaccc 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -928,7 +928,7 @@ class E2eKeysHandler(object): user_id, key_type, from_user_id ) if key is None: - logger.error("no %s key found for %s", key_type, user_id) + logger.debug("no %s key found for %s", key_type, user_id) raise NotFoundError("No %s key found for %s" % (key_type, user_id)) key_id, verify_key = get_verify_key_from_cross_signing_key(key) return key, key_id, verify_key @@ -986,7 +986,7 @@ def _check_device_signature(user_id, verify_key, signed_device, stored_device): k: v for k, v in stored_device.items() if k not in ["signatures", "unsigned"] } if stripped_signed_device != stripped_stored_device: - logger.error( + logger.debug( "upload signatures: key does not match %s vs %s", signed_device, stored_device, @@ -996,7 +996,7 @@ def _check_device_signature(user_id, verify_key, signed_device, stored_device): try: verify_signed_json(signed_device, user_id, verify_key) except SignatureVerifyException: - logger.error("invalid signature on key") + logger.debug("invalid signature on key") raise SynapseError(400, "Invalid signature", Codes.INVALID_SIGNATURE) From f4b6d43ec31ca93ee5e1b25c43a831c6b52df3bf Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Tue, 24 Sep 2019 16:19:54 -0400 Subject: [PATCH 32/55] add some comments --- synapse/handlers/e2e_keys.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index ff32fdaccc..85d7047f67 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -699,7 +699,10 @@ class E2eKeysHandler(object): user_id, "self_signing" ) - # get our master key, since it may be signed + # get our master key, since we may have received a signature of it. + # We need to fetch it here so that we know what its key ID is, so + # that we can check if a signature that was sent is a signature of + # the master key or of a device master_key, _, master_verify_key = yield self._get_e2e_cross_signing_verify_key( user_id, "master" ) @@ -719,8 +722,10 @@ class E2eKeysHandler(object): return signature_list, failures for device_id, device in signatures.items(): + # make sure submitted data is in the right form if not isinstance(device, dict): raise SynapseError(400, "Invalid parameter", Codes.INVALID_PARAM) + try: if "signatures" not in device or user_id not in device["signatures"]: # no signature was sent @@ -729,6 +734,8 @@ class E2eKeysHandler(object): ) if device_id == master_verify_key.version: + # The signature is of the master key. This needs to be + # handled differently from signatures of normal devices. master_key_signature_list = self._check_master_key_signature( user_id, device_id, device, master_key, devices ) @@ -743,7 +750,6 @@ class E2eKeysHandler(object): 400, "Invalid signature", Codes.INVALID_SIGNATURE ) - stored_device = None try: stored_device = devices[device_id]["keys"] except KeyError: @@ -848,11 +854,13 @@ class E2eKeysHandler(object): return signature_list, failures for target_user, devicemap in signatures.items(): + # make sure submitted data is in the right form if not isinstance(devicemap, dict): raise SynapseError(400, "Invalid parameter", Codes.INVALID_PARAM) for device in devicemap.values(): if not isinstance(device, dict): raise SynapseError(400, "Invalid parameter", Codes.INVALID_PARAM) + device_id = None try: # get the target user's master key, to make sure it matches From c3635c94597d0ff188d1609af6b5f3a4464c91d6 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Tue, 24 Sep 2019 16:21:03 -0400 Subject: [PATCH 33/55] make isort happy --- synapse/handlers/e2e_keys.py | 1 - 1 file changed, 1 deletion(-) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 85d7047f67..786fbfb596 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -20,7 +20,6 @@ import logging from six import iteritems import attr - from canonicaljson import encode_canonical_json, json from signedjson.key import decode_verify_key_bytes from signedjson.sign import SignatureVerifyException, verify_signed_json From cad0132fb590aec10a398a5d961896c6348fae4f Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 11 Oct 2019 11:24:03 +0100 Subject: [PATCH 34/55] Remove dead check_auth script This doesn't work, and afaict hasn't been used since 2015. --- scripts-dev/check_auth.py | 58 --------------------------------------- 1 file changed, 58 deletions(-) delete mode 100644 scripts-dev/check_auth.py diff --git a/scripts-dev/check_auth.py b/scripts-dev/check_auth.py deleted file mode 100644 index 2a1c5f39d4..0000000000 --- a/scripts-dev/check_auth.py +++ /dev/null @@ -1,58 +0,0 @@ -from __future__ import print_function - -import argparse -import itertools -import json -import sys - -from mock import Mock - -from synapse.api.auth import Auth -from synapse.events import FrozenEvent - - -def check_auth(auth, auth_chain, events): - auth_chain.sort(key=lambda e: e.depth) - - auth_map = {e.event_id: e for e in auth_chain} - - create_events = {} - for e in auth_chain: - if e.type == "m.room.create": - create_events[e.room_id] = e - - for e in itertools.chain(auth_chain, events): - auth_events_list = [auth_map[i] for i, _ in e.auth_events] - - auth_events = {(e.type, e.state_key): e for e in auth_events_list} - - auth_events[("m.room.create", "")] = create_events[e.room_id] - - try: - auth.check(e, auth_events=auth_events) - except Exception as ex: - print("Failed:", e.event_id, e.type, e.state_key) - print("Auth_events:", auth_events) - print(ex) - print(json.dumps(e.get_dict(), sort_keys=True, indent=4)) - # raise - print("Success:", e.event_id, e.type, e.state_key) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - - parser.add_argument( - "json", nargs="?", type=argparse.FileType("r"), default=sys.stdin - ) - - args = parser.parse_args() - - js = json.load(args.json) - - auth = Auth(Mock()) - check_auth( - auth, - [FrozenEvent(d) for d in js["auth_chain"]], - [FrozenEvent(d) for d in js.get("pdus", [])], - ) From 1ba359a11f238fa8d9b6319067d1b0acefdba20a Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 11 Oct 2019 14:50:48 +0100 Subject: [PATCH 35/55] rip out some unreachable code The only possible rejection reason is AUTH_ERROR, so all of this is unreachable. --- synapse/api/constants.py | 2 - synapse/federation/federation_client.py | 38 --------- synapse/federation/transport/client.py | 11 --- synapse/handlers/federation.py | 102 ------------------------ 4 files changed, 153 deletions(-) diff --git a/synapse/api/constants.py b/synapse/api/constants.py index f29bce560c..60e99e4663 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -97,8 +97,6 @@ class EventTypes(object): class RejectedReason(object): AUTH_ERROR = "auth_error" - REPLACED = "replaced" - NOT_ANCESTOR = "not_ancestor" class RoomCreationPreset(object): diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 6ee6216660..5b22a39b7f 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -878,44 +878,6 @@ class FederationClient(FederationBase): third_party_instance_id=third_party_instance_id, ) - @defer.inlineCallbacks - def query_auth(self, destination, room_id, event_id, local_auth): - """ - Params: - destination (str) - event_it (str) - local_auth (list) - """ - time_now = self._clock.time_msec() - - send_content = {"auth_chain": [e.get_pdu_json(time_now) for e in local_auth]} - - code, content = yield self.transport_layer.send_query_auth( - destination=destination, - room_id=room_id, - event_id=event_id, - content=send_content, - ) - - room_version = yield self.store.get_room_version(room_id) - format_ver = room_version_to_event_format(room_version) - - auth_chain = [event_from_pdu_json(e, format_ver) for e in content["auth_chain"]] - - signed_auth = yield self._check_sigs_and_hash_and_fetch( - destination, auth_chain, outlier=True, room_version=room_version - ) - - signed_auth.sort(key=lambda e: e.depth) - - ret = { - "auth_chain": signed_auth, - "rejects": content.get("rejects", []), - "missing": content.get("missing", []), - } - - return ret - @defer.inlineCallbacks def get_missing_events( self, diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 482a101c09..7b18408144 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -381,17 +381,6 @@ class TransportLayerClient(object): return content - @defer.inlineCallbacks - @log_function - def send_query_auth(self, destination, room_id, event_id, content): - path = _create_v1_path("/query_auth/%s/%s", room_id, event_id) - - content = yield self.client.post_json( - destination=destination, path=path, data=content - ) - - return content - @defer.inlineCallbacks @log_function def query_client_keys(self, destination, query_content, timeout): diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 50fc0fde2a..57f661f16e 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -2181,103 +2181,10 @@ class FederationHandler(BaseHandler): auth_events.update(new_state) - different_auth = event_auth_events.difference( - e.event_id for e in auth_events.values() - ) - yield self._update_context_for_auth_events( event, context, auth_events, event_key ) - if not different_auth: - # we're done - return - - logger.info( - "auth_events still refers to events which are not in the calculated auth " - "chain after state resolution: %s", - different_auth, - ) - - # Only do auth resolution if we have something new to say. - # We can't prove an auth failure. - do_resolution = False - - for e_id in different_auth: - if e_id in have_events: - if have_events[e_id] == RejectedReason.NOT_ANCESTOR: - do_resolution = True - break - - if not do_resolution: - logger.info( - "Skipping auth resolution due to lack of provable rejection reasons" - ) - return - - logger.info("Doing auth resolution") - - prev_state_ids = yield context.get_prev_state_ids(self.store) - - # 1. Get what we think is the auth chain. - auth_ids = yield self.auth.compute_auth_events(event, prev_state_ids) - local_auth_chain = yield self.store.get_auth_chain(auth_ids, include_given=True) - - try: - # 2. Get remote difference. - try: - result = yield self.federation_client.query_auth( - origin, event.room_id, event.event_id, local_auth_chain - ) - except RequestSendFailed as e: - # The other side isn't around or doesn't implement the - # endpoint, so lets just bail out. - logger.info("Failed to query auth from remote: %s", e) - return - - seen_remotes = yield self.store.have_seen_events( - [e.event_id for e in result["auth_chain"]] - ) - - # 3. Process any remote auth chain events we haven't seen. - for ev in result["auth_chain"]: - if ev.event_id in seen_remotes: - continue - - if ev.event_id == event.event_id: - continue - - try: - auth_ids = ev.auth_event_ids() - auth = { - (e.type, e.state_key): e - for e in result["auth_chain"] - if e.event_id in auth_ids or event.type == EventTypes.Create - } - ev.internal_metadata.outlier = True - - logger.debug( - "do_auth %s different_auth: %s", event.event_id, e.event_id - ) - - yield self._handle_new_event(origin, ev, auth_events=auth) - - if ev.event_id in event_auth_events: - auth_events[(ev.type, ev.state_key)] = ev - except AuthError: - pass - - except Exception: - # FIXME: - logger.exception("Failed to query auth chain") - - # 4. Look at rejects and their proofs. - # TODO. - - yield self._update_context_for_auth_events( - event, context, auth_events, event_key - ) - @defer.inlineCallbacks def _update_context_for_auth_events(self, event, context, auth_events, event_key): """Update the state_ids in an event context after auth event resolution, @@ -2444,15 +2351,6 @@ class FederationHandler(BaseHandler): reason_map[e.event_id] = reason - if reason == RejectedReason.AUTH_ERROR: - pass - elif reason == RejectedReason.REPLACED: - # TODO: Get proof - pass - elif reason == RejectedReason.NOT_ANCESTOR: - # TODO: Get proof. - pass - logger.debug("construct_auth_difference returning") return { From 1594de856c78d3e10b965b4c8ac121fb6a1083d1 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Thu, 17 Oct 2019 21:44:44 +0100 Subject: [PATCH 36/55] changelog --- changelog.d/6214.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/6214.misc diff --git a/changelog.d/6214.misc b/changelog.d/6214.misc new file mode 100644 index 0000000000..c3fd04d0d8 --- /dev/null +++ b/changelog.d/6214.misc @@ -0,0 +1 @@ +Remove some unused event-auth code. From f0f6a2b360829e0ba13dec239c586e95d46d60b4 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Fri, 18 Oct 2019 10:56:54 +0100 Subject: [PATCH 37/55] use the right function for when we're already in runInteraction --- synapse/storage/end_to_end_keys.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index 8ce5dd8bf9..3c82f789fa 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -346,7 +346,8 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): # The "keys" property must only have one entry, which will be the public # key, so we just grab the first value in there pubkey = next(iter(key["keys"].values())) - self._simple_insert( + self._simple_insert_txn( + txn, "devices", values={ "user_id": user_id, @@ -354,12 +355,12 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): "display_name": key_type + " signing key", "hidden": True, }, - desc="store_master_key_device", ) # and finally, store the key itself with self._cross_signing_id_gen.get_next() as stream_id: - self._simple_insert( + self._simple_insert_txn( + txn, "e2e_cross_signing_keys", values={ "user_id": user_id, @@ -367,7 +368,6 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): "keydata": json.dumps(key), "stream_id": stream_id, }, - desc="store_master_key", ) def set_e2e_cross_signing_key(self, user_id, key_type, key): From 770a6053a09812d23f6761442bddeaf6ef219057 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Fri, 18 Oct 2019 11:38:27 +0100 Subject: [PATCH 38/55] add note about database upgrade --- changelog.d/5759.misc | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/changelog.d/5759.misc b/changelog.d/5759.misc index c0bc566c4c..dc7e2c01bf 100644 --- a/changelog.d/5759.misc +++ b/changelog.d/5759.misc @@ -1 +1,4 @@ -Allow devices to be marked as hidden, for use by features such as cross-signing. \ No newline at end of file +Allow devices to be marked as hidden, for use by features such as cross-signing. +This adds a new field with a default value to the devices field in the database, +and so the database upgrade may take a long time depending on how many devices +are in the database. From 125eb45e19e5a3bd0e6e4f9ef429f62eb9255ce4 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Fri, 18 Oct 2019 16:56:16 +0100 Subject: [PATCH 39/55] fix doc strings --- synapse/handlers/e2e_keys.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 786fbfb596..6bf3ef49a8 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -678,7 +678,7 @@ class E2eKeysHandler(object): Returns: (list[SignatureListItem], dict[string, dict[string, dict]]): - a list of signatures to upload, and a map of users to devices to failure + a list of signatures to store, and a map of users to devices to failure reasons Raises: @@ -778,19 +778,20 @@ class E2eKeysHandler(object): def _check_master_key_signature( self, user_id, master_key_id, signed_master_key, stored_master_key, devices ): - """Check signatures of the user's master key made by their devices. + """Check signatures of a user's master key made by their devices. Args: - user_id (string): the user uploading the keys - signatures (dict[string, dict]): map of users to devices to signed keys + user_id (string): the user whose master key is being checked + master_key_id (string): the ID of the user's master key + signed_master_key (dict): the user's signed master key that was uploaded + stored_master_key (dict): our previously-stored copy of the user's master key + devices (iterable(dict)): the user's devices Returns: - (list[SignatureListItem], dict[string, dict[string, dict]]): - a list of signatures to upload, and a map of users to devices to failure - reasons + list[SignatureListItem]: a list of signatures to store Raises: - SynapseError: if the input is malformed + SynapseError: if a signature is invalid """ # for each device that signed the master key, check the signature. master_key_signature_list = [] @@ -830,7 +831,7 @@ class E2eKeysHandler(object): Returns: (list[SignatureListItem], dict[string, dict[string, dict]]): - a list of signatures to upload, and a map of users to devices to failure + a list of signatures to store, and a map of users to devices to failure reasons Raises: @@ -930,6 +931,9 @@ class E2eKeysHandler(object): Returns: dict, str, VerifyKey: the raw key data, the key ID, and the signedjson verify key + + Raises: + NotFoundError: if the key is not found """ key = yield self.store.get_e2e_cross_signing_key( user_id, key_type, from_user_id From 93eaeec75a2d3be89df0040b1374d339e92bb9b9 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 18 Oct 2019 19:43:36 +0200 Subject: [PATCH 40/55] Remove Auth.check method (#6217) This method was somewhat redundant, and confusing. --- changelog.d/6217.misc | 1 + synapse/api/auth.py | 19 +------------------ synapse/handlers/federation.py | 7 ++++--- 3 files changed, 6 insertions(+), 21 deletions(-) create mode 100644 changelog.d/6217.misc diff --git a/changelog.d/6217.misc b/changelog.d/6217.misc new file mode 100644 index 0000000000..503352ee0b --- /dev/null +++ b/changelog.d/6217.misc @@ -0,0 +1 @@ +Remove Auth.check method. diff --git a/synapse/api/auth.py b/synapse/api/auth.py index cb50579fd2..cd347fbe1b 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -84,27 +84,10 @@ class Auth(object): ) auth_events = yield self.store.get_events(auth_events_ids) auth_events = {(e.type, e.state_key): e for e in itervalues(auth_events)} - self.check( + event_auth.check( room_version, event, auth_events=auth_events, do_sig_check=do_sig_check ) - def check(self, room_version, event, auth_events, do_sig_check=True): - """ Checks if this event is correctly authed. - - Args: - room_version (str): version of the room - event: the event being checked. - auth_events (dict: event-key -> event): the existing room state. - - - Returns: - True if the auth checks pass. - """ - with Measure(self.clock, "auth.check"): - event_auth.check( - room_version, event, auth_events, do_sig_check=do_sig_check - ) - @defer.inlineCallbacks def check_joined_room(self, room_id, user_id, current_state=None): """Check if the user is currently joined in the room diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 57f661f16e..4b4c6c15f9 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -30,6 +30,7 @@ from unpaddedbase64 import decode_base64 from twisted.internet import defer +from synapse import event_auth from synapse.api.constants import EventTypes, Membership, RejectedReason from synapse.api.errors import ( AuthError, @@ -1763,7 +1764,7 @@ class FederationHandler(BaseHandler): auth_for_e[(EventTypes.Create, "")] = create_event try: - self.auth.check(room_version, e, auth_events=auth_for_e) + event_auth.check(room_version, e, auth_events=auth_for_e) except SynapseError as err: # we may get SynapseErrors here as well as AuthErrors. For # instance, there are a couple of (ancient) events in some @@ -1919,7 +1920,7 @@ class FederationHandler(BaseHandler): } try: - self.auth.check(room_version, event, auth_events=current_auth_events) + event_auth.check(room_version, event, auth_events=current_auth_events) except AuthError as e: logger.warn("Soft-failing %r because %s", event, e) event.internal_metadata.soft_failed = True @@ -2018,7 +2019,7 @@ class FederationHandler(BaseHandler): ) try: - self.auth.check(room_version, event, auth_events=auth_events) + event_auth.check(room_version, event, auth_events=auth_events) except AuthError as e: logger.warn("Failed auth resolution for %r because %s", event, e) raise e From 6493ed572380828dfa9ed4c900deada30ceb0604 Mon Sep 17 00:00:00 2001 From: Hubert Chathi Date: Fri, 18 Oct 2019 18:45:36 +0100 Subject: [PATCH 41/55] Add changelog entry ... again? How did you make it disappear, git? --- changelog.d/5726.feature | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/5726.feature diff --git a/changelog.d/5726.feature b/changelog.d/5726.feature new file mode 100644 index 0000000000..d3c669aec0 --- /dev/null +++ b/changelog.d/5726.feature @@ -0,0 +1 @@ +Add ability to upload cross-signing signatures. From 22a9f75097bb51d17d3b1f824665b51607f2b95e Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Sat, 19 Oct 2019 19:42:10 +0200 Subject: [PATCH 42/55] Delete format_tap.py (#6219) * Delete format_tap.py This python implementation of a tap formatting library for buildkite has been replaced with a perl implementation as part of the matrix-org/sytest repo, which is specific to sytest's language, not that of any one homeserver's. --- .buildkite/format_tap.py | 48 ---------------------------------------- changelog.d/6219.misc | 1 + 2 files changed, 1 insertion(+), 48 deletions(-) delete mode 100644 .buildkite/format_tap.py create mode 100644 changelog.d/6219.misc diff --git a/.buildkite/format_tap.py b/.buildkite/format_tap.py deleted file mode 100644 index b557a9c38e..0000000000 --- a/.buildkite/format_tap.py +++ /dev/null @@ -1,48 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2019 The Matrix.org Foundation C.I.C. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import sys -from tap.parser import Parser -from tap.line import Result, Unknown, Diagnostic - -out = ["### TAP Output for " + sys.argv[2]] - -p = Parser() - -in_error = False - -for line in p.parse_file(sys.argv[1]): - if isinstance(line, Result): - if in_error: - out.append("") - out.append("") - out.append("") - out.append("----") - out.append("") - in_error = False - - if not line.ok and not line.todo: - in_error = True - - out.append("FAILURE Test #%d: ``%s``" % (line.number, line.description)) - out.append("") - out.append("
Show log
")
-
-    elif isinstance(line, Diagnostic) and in_error:
-        out.append(line.text)
-
-if out:
-    for line in out[:-3]:
-        print(line)
diff --git a/changelog.d/6219.misc b/changelog.d/6219.misc
new file mode 100644
index 0000000000..296406246d
--- /dev/null
+++ b/changelog.d/6219.misc
@@ -0,0 +1 @@
+Remove `format_tap.py` script in favour of a perl reimplementation in Sytest's repo.
\ No newline at end of file

From c66a06ac6b69b0a03f5c6284ded980399e9df94e Mon Sep 17 00:00:00 2001
From: Erik Johnston 
Date: Mon, 21 Oct 2019 12:56:42 +0100
Subject: [PATCH 43/55] Move storage classes into a main "data store".

This is in preparation for having multiple data stores that offer
different functionality, e.g. splitting out state or event storage.
---
 synapse/app/event_creator.py                  |    2 +-
 synapse/app/media_repository.py               |    2 +-
 synapse/app/synchrotron.py                    |    2 +-
 synapse/app/user_dir.py                       |    2 +-
 .../sender/per_destination_queue.py           |    2 +-
 .../replication/slave/storage/account_data.py |    4 +-
 .../replication/slave/storage/appservice.py   |    2 +-
 .../replication/slave/storage/client_ips.py   |    2 +-
 .../replication/slave/storage/deviceinbox.py  |    2 +-
 synapse/replication/slave/storage/devices.py  |    4 +-
 .../replication/slave/storage/directory.py    |    2 +-
 synapse/replication/slave/storage/events.py   |   20 +-
 .../replication/slave/storage/filtering.py    |    2 +-
 synapse/replication/slave/storage/keys.py     |    2 +-
 synapse/replication/slave/storage/presence.py |    2 +-
 synapse/replication/slave/storage/profile.py  |    2 +-
 .../replication/slave/storage/push_rule.py    |    2 +-
 synapse/replication/slave/storage/pushers.py  |    2 +-
 synapse/replication/slave/storage/receipts.py |    2 +-
 .../replication/slave/storage/registration.py |    2 +-
 synapse/replication/slave/storage/room.py     |    2 +-
 .../replication/slave/storage/transactions.py |    2 +-
 synapse/storage/__init__.py                   |  504 +------
 synapse/storage/data_stores/__init__.py       |   14 +
 synapse/storage/data_stores/main/__init__.py  |  524 +++++++
 .../{ => data_stores/main}/account_data.py    |    0
 .../{ => data_stores/main}/appservice.py      |    5 +-
 .../{ => data_stores/main}/client_ips.py      |    5 +-
 .../{ => data_stores/main}/deviceinbox.py     |    0
 .../storage/{ => data_stores/main}/devices.py |    0
 .../{ => data_stores/main}/directory.py       |    3 +-
 .../{ => data_stores/main}/e2e_room_keys.py   |    3 +-
 .../{ => data_stores/main}/end_to_end_keys.py |    3 +-
 .../main}/event_federation.py                 |    4 +-
 .../main}/event_push_actions.py               |    0
 .../storage/{ => data_stores/main}/events.py  |    6 +-
 .../main}/events_bg_updates.py                |    0
 .../{ => data_stores/main}/events_worker.py   |    0
 .../{ => data_stores/main}/filtering.py       |    3 +-
 .../{ => data_stores/main}/group_server.py    |    3 +-
 synapse/storage/data_stores/main/keys.py      |  214 +++
 .../main}/media_repository.py                 |    0
 .../main}/monthly_active_users.py             |    3 +-
 .../storage/{ => data_stores/main}/openid.py  |    2 +-
 synapse/storage/data_stores/main/presence.py  |  150 ++
 .../storage/{ => data_stores/main}/profile.py |    5 +-
 synapse/storage/data_stores/main/push_rule.py |  713 ++++++++++
 .../storage/{ => data_stores/main}/pusher.py  |    3 +-
 .../{ => data_stores/main}/receipts.py        |    0
 .../{ => data_stores/main}/registration.py    |    0
 .../{ => data_stores/main}/rejections.py      |    2 +-
 synapse/storage/data_stores/main/relations.py |  385 +++++
 .../storage/{ => data_stores/main}/room.py    |    2 +-
 .../storage/data_stores/main/roommember.py    | 1145 +++++++++++++++
 .../main}/schema/delta/12/v12.sql             |    0
 .../main}/schema/delta/13/v13.sql             |    0
 .../main}/schema/delta/14/v14.sql             |    0
 .../main}/schema/delta/15/appservice_txns.sql |    0
 .../schema/delta/15/presence_indices.sql      |    0
 .../main}/schema/delta/15/v15.sql             |    0
 .../schema/delta/16/events_order_index.sql    |    0
 .../delta/16/remote_media_cache_index.sql     |    0
 .../schema/delta/16/remove_duplicates.sql     |    0
 .../schema/delta/16/room_alias_index.sql      |    0
 .../schema/delta/16/unique_constraints.sql    |    0
 .../main}/schema/delta/16/users.sql           |    0
 .../main}/schema/delta/17/drop_indexes.sql    |    0
 .../main}/schema/delta/17/server_keys.sql     |    0
 .../main}/schema/delta/17/user_threepids.sql  |    0
 .../delta/18/server_keys_bigger_ints.sql      |    0
 .../main}/schema/delta/19/event_index.sql     |    0
 .../main}/schema/delta/20/dummy.sql           |    0
 .../main}/schema/delta/20/pushers.py          |    0
 .../main}/schema/delta/21/end_to_end_keys.sql |    0
 .../main}/schema/delta/21/receipts.sql        |    0
 .../main}/schema/delta/22/receipts_index.sql  |    0
 .../schema/delta/22/user_threepids_unique.sql |    0
 .../schema/delta/23/drop_state_index.sql      |    0
 .../main}/schema/delta/24/stats_reporting.sql |    0
 .../schema/delta/25/00background_updates.sql  |   21 +
 .../main}/schema/delta/25/fts.py              |    0
 .../main}/schema/delta/25/guest_access.sql    |    0
 .../schema/delta/25/history_visibility.sql    |    0
 .../main}/schema/delta/25/tags.sql            |    0
 .../main}/schema/delta/26/account_data.sql    |    0
 .../main}/schema/delta/27/account_data.sql    |    0
 .../schema/delta/27/forgotten_memberships.sql |    0
 .../main}/schema/delta/27/ts.py               |    0
 .../schema/delta/28/event_push_actions.sql    |    0
 .../schema/delta/28/events_room_stream.sql    |    0
 .../schema/delta/28/public_roms_index.sql     |    0
 .../delta/28/receipts_user_id_index.sql       |    0
 .../main}/schema/delta/28/upgrade_times.sql   |    0
 .../main}/schema/delta/28/users_is_guest.sql  |    0
 .../main}/schema/delta/29/push_actions.sql    |    0
 .../main}/schema/delta/30/alias_creator.sql   |    0
 .../main}/schema/delta/30/as_users.py         |    0
 .../main}/schema/delta/30/deleted_pushers.sql |    0
 .../main}/schema/delta/30/presence_stream.sql |    0
 .../main}/schema/delta/30/public_rooms.sql    |    0
 .../schema/delta/30/push_rule_stream.sql      |    0
 .../main}/schema/delta/30/state_stream.sql    |    0
 .../delta/30/threepid_guest_access_tokens.sql |    0
 .../main}/schema/delta/31/invites.sql         |    0
 .../31/local_media_repository_url_cache.sql   |    0
 .../main}/schema/delta/31/pushers.py          |    0
 .../main}/schema/delta/31/pushers_index.sql   |    0
 .../main}/schema/delta/31/search_update.py    |    0
 .../main}/schema/delta/32/events.sql          |    0
 .../main}/schema/delta/32/openid.sql          |    0
 .../main}/schema/delta/32/pusher_throttle.sql |    0
 .../main}/schema/delta/32/remove_indices.sql  |    0
 .../main}/schema/delta/32/reports.sql         |    0
 .../delta/33/access_tokens_device_index.sql   |    0
 .../main}/schema/delta/33/devices.sql         |    0
 .../schema/delta/33/devices_for_e2e_keys.sql  |    0
 ...ices_for_e2e_keys_clear_unknown_device.sql |    0
 .../main}/schema/delta/33/event_fields.py     |    0
 .../main}/schema/delta/33/remote_media_ts.py  |    0
 .../main}/schema/delta/33/user_ips_index.sql  |    0
 .../schema/delta/34/appservice_stream.sql     |    0
 .../main}/schema/delta/34/cache_stream.py     |    0
 .../main}/schema/delta/34/device_inbox.sql    |    0
 .../delta/34/push_display_name_rename.sql     |    0
 .../schema/delta/34/received_txn_purge.py     |    0
 .../main}/schema/delta/35/add_state_index.sql |    3 -
 .../main}/schema/delta/35/contains_url.sql    |    0
 .../main}/schema/delta/35/device_outbox.sql   |    0
 .../schema/delta/35/device_stream_id.sql      |    0
 .../delta/35/event_push_actions_index.sql     |    0
 .../35/public_room_list_change_stream.sql     |    0
 .../main}/schema/delta/35/state.sql           |    0
 .../main}/schema/delta/35/state_dedupe.sql    |    0
 .../delta/35/stream_order_to_extrem.sql       |    0
 .../schema/delta/36/readd_public_rooms.sql    |    0
 .../main}/schema/delta/37/remove_auth_idx.py  |    0
 .../main}/schema/delta/37/user_threepids.sql  |    0
 .../schema/delta/38/postgres_fts_gist.sql     |    0
 .../schema/delta/39/appservice_room_list.sql  |    0
 .../delta/39/device_federation_stream_idx.sql |    0
 .../schema/delta/39/event_push_index.sql      |    0
 .../delta/39/federation_out_position.sql      |    0
 .../schema/delta/39/membership_profile.sql    |    0
 .../schema/delta/40/current_state_idx.sql     |    0
 .../main}/schema/delta/40/device_inbox.sql    |    0
 .../schema/delta/40/device_list_streams.sql   |    0
 .../schema/delta/40/event_push_summary.sql    |    0
 .../main}/schema/delta/40/pushers.sql         |    0
 .../delta/41/device_list_stream_idx.sql       |    0
 .../schema/delta/41/device_outbound_index.sql |    0
 .../delta/41/event_search_event_id_idx.sql    |    0
 .../main}/schema/delta/41/ratelimit.sql       |    0
 .../schema/delta/42/current_state_delta.sql   |    0
 .../schema/delta/42/device_list_last_id.sql   |    0
 .../schema/delta/42/event_auth_state_only.sql |    0
 .../main}/schema/delta/42/user_dir.py         |    0
 .../main}/schema/delta/43/blocked_rooms.sql   |    0
 .../schema/delta/43/quarantine_media.sql      |    0
 .../main}/schema/delta/43/url_cache.sql       |    0
 .../main}/schema/delta/43/user_share.sql      |    0
 .../schema/delta/44/expire_url_cache.sql      |    0
 .../main}/schema/delta/45/group_server.sql    |    0
 .../main}/schema/delta/45/profile_cache.sql   |    0
 .../schema/delta/46/drop_refresh_tokens.sql   |    0
 .../delta/46/drop_unique_deleted_pushers.sql  |    0
 .../main}/schema/delta/46/group_server.sql    |    0
 .../46/local_media_repository_url_idx.sql     |    0
 .../delta/46/user_dir_null_room_ids.sql       |    0
 .../main}/schema/delta/46/user_dir_typos.sql  |    0
 .../schema/delta/47/last_access_media.sql     |    0
 .../schema/delta/47/postgres_fts_gin.sql      |    0
 .../schema/delta/47/push_actions_staging.sql  |    0
 .../main}/schema/delta/47/state_group_seq.py  |    0
 .../schema/delta/48/add_user_consent.sql      |    0
 .../delta/48/add_user_ips_last_seen_index.sql |    0
 .../schema/delta/48/deactivated_users.sql     |    0
 .../schema/delta/48/group_unique_indexes.py   |    0
 .../main}/schema/delta/48/groups_joinable.sql |    0
 .../add_user_consent_server_notice_sent.sql   |    0
 .../schema/delta/49/add_user_daily_visits.sql |    0
 .../49/add_user_ips_last_seen_only_index.sql  |    0
 .../delta/50/add_creation_ts_users_index.sql  |    0
 .../main}/schema/delta/50/erasure_store.sql   |    0
 .../delta/50/make_event_content_nullable.py   |    0
 .../main}/schema/delta/51/e2e_room_keys.sql   |    0
 .../schema/delta/51/monthly_active_users.sql  |    0
 .../52/add_event_to_state_group_index.sql     |    0
 .../52/device_list_streams_unique_idx.sql     |    0
 .../main}/schema/delta/52/e2e_room_keys.sql   |    0
 .../delta/53/add_user_type_to_users.sql       |    0
 .../delta/53/drop_sent_transactions.sql       |    0
 .../schema/delta/53/event_format_version.sql  |    0
 .../schema/delta/53/user_dir_populate.sql     |    0
 .../main}/schema/delta/53/user_ips_index.sql  |    0
 .../main}/schema/delta/53/user_share.sql      |    0
 .../schema/delta/53/user_threepid_id.sql      |    0
 .../schema/delta/53/users_in_public_rooms.sql |    0
 .../54/account_validity_with_renewal.sql      |    0
 .../delta/54/add_validity_to_server_keys.sql  |    0
 .../delta/54/delete_forward_extremities.sql   |    0
 .../schema/delta/54/drop_legacy_tables.sql    |    0
 .../schema/delta/54/drop_presence_list.sql    |    0
 .../main}/schema/delta/54/relations.sql       |    0
 .../main}/schema/delta/54/stats.sql           |    0
 .../main}/schema/delta/54/stats2.sql          |    0
 .../schema/delta/55/access_token_expiry.sql   |    0
 .../delta/55/track_threepid_validations.sql   |    0
 .../delta/55/users_alter_deactivated.sql      |    0
 .../delta/56/add_spans_to_device_lists.sql    |    0
 .../56/current_state_events_membership.sql    |    0
 .../current_state_events_membership_mk2.sql   |    0
 .../delta/56/destinations_failure_ts.sql      |    0
 ...tinations_retry_interval_type.sql.postgres |    0
 .../schema/delta/56/devices_last_seen.sql     |    0
 .../delta/56/drop_unused_event_tables.sql     |    0
 .../schema/delta/56/fix_room_keys_index.sql   |    0
 .../schema/delta/56/public_room_list_idx.sql  |    0
 .../schema/delta/56/redaction_censor.sql      |    0
 .../schema/delta/56/redaction_censor2.sql     |    0
 .../redaction_censor3_fix_update.sql.postgres |    0
 .../schema/delta/56/room_membership_idx.sql   |    0
 .../main}/schema/delta/56/stats_separated.sql |    0
 .../delta/56/unique_user_filter_index.py      |    0
 .../schema/delta/56/user_external_ids.sql     |    0
 .../delta/56/users_in_public_rooms_idx.sql    |    0
 .../full_schemas/16/application_services.sql  |    0
 .../schema/full_schemas/16/event_edges.sql    |    0
 .../full_schemas/16/event_signatures.sql      |    0
 .../main}/schema/full_schemas/16/im.sql       |    0
 .../main}/schema/full_schemas/16/keys.sql     |    0
 .../full_schemas/16/media_repository.sql      |    0
 .../main}/schema/full_schemas/16/presence.sql |    0
 .../main}/schema/full_schemas/16/profiles.sql |    0
 .../main}/schema/full_schemas/16/push.sql     |    0
 .../schema/full_schemas/16/redactions.sql     |    0
 .../schema/full_schemas/16/room_aliases.sql   |    0
 .../main}/schema/full_schemas/16/state.sql    |    0
 .../schema/full_schemas/16/transactions.sql   |    0
 .../main}/schema/full_schemas/16/users.sql    |    0
 .../schema/full_schemas/54/full.sql.postgres  |   17 -
 .../schema/full_schemas/54/full.sql.sqlite    |    1 -
 .../full_schemas/54/stream_positions.sql      |    0
 .../main}/schema/full_schemas/README.txt      |    0
 .../storage/{ => data_stores/main}/search.py  |    3 +-
 .../{ => data_stores/main}/signatures.py      |    3 +-
 synapse/storage/data_stores/main/state.py     | 1244 +++++++++++++++++
 .../{ => data_stores/main}/state_deltas.py    |    0
 .../storage/{ => data_stores/main}/stats.py   |    4 +-
 .../storage/{ => data_stores/main}/stream.py  |    2 +-
 .../storage/{ => data_stores/main}/tags.py    |    2 +-
 .../{ => data_stores/main}/transactions.py    |    3 +-
 .../{ => data_stores/main}/user_directory.py  |    4 +-
 .../main}/user_erasure_store.py               |    0
 synapse/storage/keys.py                       |  194 ---
 synapse/storage/presence.py                   |  134 --
 synapse/storage/push_rule.py                  |  698 ---------
 synapse/storage/relations.py                  |  359 -----
 synapse/storage/roommember.py                 | 1119 ---------------
 .../delta/35/00background_updates_add_col.sql |   17 +
 .../storage/schema/full_schemas/54/full.sql   |    8 +
 synapse/storage/state.py                      | 1221 ----------------
 tests/handlers/test_stats.py                  |    8 +-
 tests/storage/test_appservice.py              |    2 +-
 tests/storage/test_cleanup_extrems.py         |    2 +
 tests/storage/test_profile.py                 |    2 +-
 tests/storage/test_user_directory.py          |    2 +-
 266 files changed, 4509 insertions(+), 4331 deletions(-)
 create mode 100644 synapse/storage/data_stores/__init__.py
 create mode 100644 synapse/storage/data_stores/main/__init__.py
 rename synapse/storage/{ => data_stores/main}/account_data.py (100%)
 rename synapse/storage/{ => data_stores/main}/appservice.py (99%)
 rename synapse/storage/{ => data_stores/main}/client_ips.py (99%)
 rename synapse/storage/{ => data_stores/main}/deviceinbox.py (100%)
 rename synapse/storage/{ => data_stores/main}/devices.py (100%)
 rename synapse/storage/{ => data_stores/main}/directory.py (99%)
 rename synapse/storage/{ => data_stores/main}/e2e_room_keys.py (99%)
 rename synapse/storage/{ => data_stores/main}/end_to_end_keys.py (99%)
 rename synapse/storage/{ => data_stores/main}/event_federation.py (99%)
 rename synapse/storage/{ => data_stores/main}/event_push_actions.py (100%)
 rename synapse/storage/{ => data_stores/main}/events.py (99%)
 rename synapse/storage/{ => data_stores/main}/events_bg_updates.py (100%)
 rename synapse/storage/{ => data_stores/main}/events_worker.py (100%)
 rename synapse/storage/{ => data_stores/main}/filtering.py (97%)
 rename synapse/storage/{ => data_stores/main}/group_server.py (99%)
 create mode 100644 synapse/storage/data_stores/main/keys.py
 rename synapse/storage/{ => data_stores/main}/media_repository.py (100%)
 rename synapse/storage/{ => data_stores/main}/monthly_active_users.py (99%)
 rename synapse/storage/{ => data_stores/main}/openid.py (95%)
 create mode 100644 synapse/storage/data_stores/main/presence.py
 rename synapse/storage/{ => data_stores/main}/profile.py (98%)
 create mode 100644 synapse/storage/data_stores/main/push_rule.py
 rename synapse/storage/{ => data_stores/main}/pusher.py (99%)
 rename synapse/storage/{ => data_stores/main}/receipts.py (100%)
 rename synapse/storage/{ => data_stores/main}/registration.py (100%)
 rename synapse/storage/{ => data_stores/main}/rejections.py (96%)
 create mode 100644 synapse/storage/data_stores/main/relations.py
 rename synapse/storage/{ => data_stores/main}/room.py (99%)
 create mode 100644 synapse/storage/data_stores/main/roommember.py
 rename synapse/storage/{ => data_stores/main}/schema/delta/12/v12.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/13/v13.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/14/v14.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/15/appservice_txns.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/15/presence_indices.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/15/v15.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/16/events_order_index.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/16/remote_media_cache_index.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/16/remove_duplicates.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/16/room_alias_index.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/16/unique_constraints.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/16/users.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/17/drop_indexes.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/17/server_keys.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/17/user_threepids.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/18/server_keys_bigger_ints.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/19/event_index.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/20/dummy.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/20/pushers.py (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/21/end_to_end_keys.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/21/receipts.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/22/receipts_index.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/22/user_threepids_unique.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/23/drop_state_index.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/24/stats_reporting.sql (100%)
 create mode 100644 synapse/storage/data_stores/main/schema/delta/25/00background_updates.sql
 rename synapse/storage/{ => data_stores/main}/schema/delta/25/fts.py (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/25/guest_access.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/25/history_visibility.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/25/tags.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/26/account_data.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/27/account_data.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/27/forgotten_memberships.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/27/ts.py (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/28/event_push_actions.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/28/events_room_stream.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/28/public_roms_index.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/28/receipts_user_id_index.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/28/upgrade_times.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/28/users_is_guest.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/29/push_actions.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/30/alias_creator.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/30/as_users.py (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/30/deleted_pushers.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/30/presence_stream.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/30/public_rooms.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/30/push_rule_stream.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/30/state_stream.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/30/threepid_guest_access_tokens.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/31/invites.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/31/local_media_repository_url_cache.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/31/pushers.py (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/31/pushers_index.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/31/search_update.py (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/32/events.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/32/openid.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/32/pusher_throttle.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/32/remove_indices.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/32/reports.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/33/access_tokens_device_index.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/33/devices.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/33/devices_for_e2e_keys.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/33/event_fields.py (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/33/remote_media_ts.py (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/33/user_ips_index.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/34/appservice_stream.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/34/cache_stream.py (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/34/device_inbox.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/34/push_display_name_rename.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/34/received_txn_purge.py (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/35/add_state_index.sql (92%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/35/contains_url.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/35/device_outbox.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/35/device_stream_id.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/35/event_push_actions_index.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/35/public_room_list_change_stream.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/35/state.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/35/state_dedupe.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/35/stream_order_to_extrem.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/36/readd_public_rooms.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/37/remove_auth_idx.py (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/37/user_threepids.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/38/postgres_fts_gist.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/39/appservice_room_list.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/39/device_federation_stream_idx.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/39/event_push_index.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/39/federation_out_position.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/39/membership_profile.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/40/current_state_idx.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/40/device_inbox.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/40/device_list_streams.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/40/event_push_summary.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/40/pushers.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/41/device_list_stream_idx.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/41/device_outbound_index.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/41/event_search_event_id_idx.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/41/ratelimit.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/42/current_state_delta.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/42/device_list_last_id.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/42/event_auth_state_only.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/42/user_dir.py (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/43/blocked_rooms.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/43/quarantine_media.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/43/url_cache.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/43/user_share.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/44/expire_url_cache.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/45/group_server.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/45/profile_cache.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/46/drop_refresh_tokens.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/46/drop_unique_deleted_pushers.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/46/group_server.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/46/local_media_repository_url_idx.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/46/user_dir_null_room_ids.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/46/user_dir_typos.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/47/last_access_media.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/47/postgres_fts_gin.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/47/push_actions_staging.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/47/state_group_seq.py (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/48/add_user_consent.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/48/add_user_ips_last_seen_index.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/48/deactivated_users.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/48/group_unique_indexes.py (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/48/groups_joinable.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/49/add_user_consent_server_notice_sent.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/49/add_user_daily_visits.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/49/add_user_ips_last_seen_only_index.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/50/add_creation_ts_users_index.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/50/erasure_store.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/50/make_event_content_nullable.py (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/51/e2e_room_keys.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/51/monthly_active_users.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/52/add_event_to_state_group_index.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/52/device_list_streams_unique_idx.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/52/e2e_room_keys.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/53/add_user_type_to_users.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/53/drop_sent_transactions.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/53/event_format_version.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/53/user_dir_populate.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/53/user_ips_index.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/53/user_share.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/53/user_threepid_id.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/53/users_in_public_rooms.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/54/account_validity_with_renewal.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/54/add_validity_to_server_keys.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/54/delete_forward_extremities.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/54/drop_legacy_tables.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/54/drop_presence_list.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/54/relations.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/54/stats.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/54/stats2.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/55/access_token_expiry.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/55/track_threepid_validations.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/55/users_alter_deactivated.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/56/add_spans_to_device_lists.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/56/current_state_events_membership.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/56/current_state_events_membership_mk2.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/56/destinations_failure_ts.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/56/destinations_retry_interval_type.sql.postgres (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/56/devices_last_seen.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/56/drop_unused_event_tables.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/56/fix_room_keys_index.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/56/public_room_list_idx.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/56/redaction_censor.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/56/redaction_censor2.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/56/redaction_censor3_fix_update.sql.postgres (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/56/room_membership_idx.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/56/stats_separated.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/56/unique_user_filter_index.py (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/56/user_external_ids.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/delta/56/users_in_public_rooms_idx.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/full_schemas/16/application_services.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/full_schemas/16/event_edges.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/full_schemas/16/event_signatures.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/full_schemas/16/im.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/full_schemas/16/keys.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/full_schemas/16/media_repository.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/full_schemas/16/presence.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/full_schemas/16/profiles.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/full_schemas/16/push.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/full_schemas/16/redactions.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/full_schemas/16/room_aliases.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/full_schemas/16/state.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/full_schemas/16/transactions.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/full_schemas/16/users.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/full_schemas/54/full.sql.postgres (99%)
 rename synapse/storage/{ => data_stores/main}/schema/full_schemas/54/full.sql.sqlite (99%)
 rename synapse/storage/{ => data_stores/main}/schema/full_schemas/54/stream_positions.sql (100%)
 rename synapse/storage/{ => data_stores/main}/schema/full_schemas/README.txt (100%)
 rename synapse/storage/{ => data_stores/main}/search.py (99%)
 rename synapse/storage/{ => data_stores/main}/signatures.py (98%)
 create mode 100644 synapse/storage/data_stores/main/state.py
 rename synapse/storage/{ => data_stores/main}/state_deltas.py (100%)
 rename synapse/storage/{ => data_stores/main}/stats.py (99%)
 rename synapse/storage/{ => data_stores/main}/stream.py (99%)
 rename synapse/storage/{ => data_stores/main}/tags.py (99%)
 rename synapse/storage/{ => data_stores/main}/transactions.py (99%)
 rename synapse/storage/{ => data_stores/main}/user_directory.py (99%)
 rename synapse/storage/{ => data_stores/main}/user_erasure_store.py (100%)
 create mode 100644 synapse/storage/schema/delta/35/00background_updates_add_col.sql
 create mode 100644 synapse/storage/schema/full_schemas/54/full.sql

diff --git a/synapse/app/event_creator.py b/synapse/app/event_creator.py
index c67fe69a50..f20d810ece 100644
--- a/synapse/app/event_creator.py
+++ b/synapse/app/event_creator.py
@@ -56,8 +56,8 @@ from synapse.rest.client.v1.room import (
     RoomStateEventRestServlet,
 )
 from synapse.server import HomeServer
+from synapse.storage.data_stores.main.user_directory import UserDirectoryStore
 from synapse.storage.engines import create_engine
-from synapse.storage.user_directory import UserDirectoryStore
 from synapse.util.httpresourcetree import create_resource_tree
 from synapse.util.manhole import manhole
 from synapse.util.versionstring import get_version_string
diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py
index 2ac783ffa3..6bc7202f33 100644
--- a/synapse/app/media_repository.py
+++ b/synapse/app/media_repository.py
@@ -39,8 +39,8 @@ from synapse.replication.tcp.client import ReplicationClientHandler
 from synapse.rest.admin import register_servlets_for_media_repo
 from synapse.rest.media.v0.content_repository import ContentRepoResource
 from synapse.server import HomeServer
+from synapse.storage.data_stores.main.media_repository import MediaRepositoryStore
 from synapse.storage.engines import create_engine
-from synapse.storage.media_repository import MediaRepositoryStore
 from synapse.util.httpresourcetree import create_resource_tree
 from synapse.util.manhole import manhole
 from synapse.util.versionstring import get_version_string
diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py
index 473026fce5..6a7e2fa707 100644
--- a/synapse/app/synchrotron.py
+++ b/synapse/app/synchrotron.py
@@ -54,8 +54,8 @@ from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet
 from synapse.rest.client.v1.room import RoomInitialSyncRestServlet
 from synapse.rest.client.v2_alpha import sync
 from synapse.server import HomeServer
+from synapse.storage.data_stores.main.presence import UserPresenceState
 from synapse.storage.engines import create_engine
-from synapse.storage.presence import UserPresenceState
 from synapse.util.httpresourcetree import create_resource_tree
 from synapse.util.manhole import manhole
 from synapse.util.stringutils import random_string
diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py
index e01afb39f2..a5d6dc7915 100644
--- a/synapse/app/user_dir.py
+++ b/synapse/app/user_dir.py
@@ -42,8 +42,8 @@ from synapse.replication.tcp.streams.events import (
 )
 from synapse.rest.client.v2_alpha import user_directory
 from synapse.server import HomeServer
+from synapse.storage.data_stores.main.user_directory import UserDirectoryStore
 from synapse.storage.engines import create_engine
-from synapse.storage.user_directory import UserDirectoryStore
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.util.httpresourcetree import create_resource_tree
 from synapse.util.manhole import manhole
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index fad980b893..cc75c39476 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -30,7 +30,7 @@ from synapse.federation.units import Edu
 from synapse.handlers.presence import format_user_presence_state
 from synapse.metrics import sent_transactions_counter
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage import UserPresenceState
+from synapse.storage.presence import UserPresenceState
 from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
 
 # This is defined in the Matrix spec and enforced by the receiver.
diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py
index 3c44d1d48d..bc2f6a12ae 100644
--- a/synapse/replication/slave/storage/account_data.py
+++ b/synapse/replication/slave/storage/account_data.py
@@ -16,8 +16,8 @@
 
 from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
-from synapse.storage.account_data import AccountDataWorkerStore
-from synapse.storage.tags import TagsWorkerStore
+from synapse.storage.data_stores.main.account_data import AccountDataWorkerStore
+from synapse.storage.data_stores.main.tags import TagsWorkerStore
 
 
 class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore):
diff --git a/synapse/replication/slave/storage/appservice.py b/synapse/replication/slave/storage/appservice.py
index cda12ea70d..a67fbeffb7 100644
--- a/synapse/replication/slave/storage/appservice.py
+++ b/synapse/replication/slave/storage/appservice.py
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.storage.appservice import (
+from synapse.storage.data_stores.main.appservice import (
     ApplicationServiceTransactionWorkerStore,
     ApplicationServiceWorkerStore,
 )
diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py
index 14ced32333..b4f58cea19 100644
--- a/synapse/replication/slave/storage/client_ips.py
+++ b/synapse/replication/slave/storage/client_ips.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.storage.client_ips import LAST_SEEN_GRANULARITY
+from synapse.storage.data_stores.main.client_ips import LAST_SEEN_GRANULARITY
 from synapse.util.caches import CACHE_SIZE_FACTOR
 from synapse.util.caches.descriptors import Cache
 
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index 284fd30d89..9fb6c5c6ff 100644
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -15,7 +15,7 @@
 
 from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
-from synapse.storage.deviceinbox import DeviceInboxWorkerStore
+from synapse.storage.data_stores.main.deviceinbox import DeviceInboxWorkerStore
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index d9300fce33..f856c72d84 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -15,8 +15,8 @@
 
 from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
-from synapse.storage.devices import DeviceWorkerStore
-from synapse.storage.end_to_end_keys import EndToEndKeyWorkerStore
+from synapse.storage.data_stores.main.devices import DeviceWorkerStore
+from synapse.storage.data_stores.main.end_to_end_keys import EndToEndKeyWorkerStore
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 
diff --git a/synapse/replication/slave/storage/directory.py b/synapse/replication/slave/storage/directory.py
index 1d1d48709a..8b9717c46f 100644
--- a/synapse/replication/slave/storage/directory.py
+++ b/synapse/replication/slave/storage/directory.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.storage.directory import DirectoryWorkerStore
+from synapse.storage.data_stores.main.directory import DirectoryWorkerStore
 
 from ._base import BaseSlavedStore
 
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index ab5937e638..d0a0eaf75b 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -20,15 +20,17 @@ from synapse.replication.tcp.streams.events import (
     EventsStreamCurrentStateRow,
     EventsStreamEventRow,
 )
-from synapse.storage.event_federation import EventFederationWorkerStore
-from synapse.storage.event_push_actions import EventPushActionsWorkerStore
-from synapse.storage.events_worker import EventsWorkerStore
-from synapse.storage.relations import RelationsWorkerStore
-from synapse.storage.roommember import RoomMemberWorkerStore
-from synapse.storage.signatures import SignatureWorkerStore
-from synapse.storage.state import StateGroupWorkerStore
-from synapse.storage.stream import StreamWorkerStore
-from synapse.storage.user_erasure_store import UserErasureWorkerStore
+from synapse.storage.data_stores.main.event_federation import EventFederationWorkerStore
+from synapse.storage.data_stores.main.event_push_actions import (
+    EventPushActionsWorkerStore,
+)
+from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
+from synapse.storage.data_stores.main.relations import RelationsWorkerStore
+from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore
+from synapse.storage.data_stores.main.signatures import SignatureWorkerStore
+from synapse.storage.data_stores.main.state import StateGroupWorkerStore
+from synapse.storage.data_stores.main.stream import StreamWorkerStore
+from synapse.storage.data_stores.main.user_erasure_store import UserErasureWorkerStore
 
 from ._base import BaseSlavedStore
 from ._slaved_id_tracker import SlavedIdTracker
diff --git a/synapse/replication/slave/storage/filtering.py b/synapse/replication/slave/storage/filtering.py
index 456a14cd5c..5c84ebd125 100644
--- a/synapse/replication/slave/storage/filtering.py
+++ b/synapse/replication/slave/storage/filtering.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.storage.filtering import FilteringStore
+from synapse.storage.data_stores.main.filtering import FilteringStore
 
 from ._base import BaseSlavedStore
 
diff --git a/synapse/replication/slave/storage/keys.py b/synapse/replication/slave/storage/keys.py
index cc6f7f009f..3def367ae9 100644
--- a/synapse/replication/slave/storage/keys.py
+++ b/synapse/replication/slave/storage/keys.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.storage import KeyStore
+from synapse.storage.data_stores.main.keys import KeyStore
 
 # KeyStore isn't really safe to use from a worker, but for now we do so and hope that
 # the races it creates aren't too bad.
diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py
index 82d808af4c..747ced0c84 100644
--- a/synapse/replication/slave/storage/presence.py
+++ b/synapse/replication/slave/storage/presence.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 from synapse.storage import DataStore
-from synapse.storage.presence import PresenceStore
+from synapse.storage.data_stores.main.presence import PresenceStore
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 from ._base import BaseSlavedStore, __func__
diff --git a/synapse/replication/slave/storage/profile.py b/synapse/replication/slave/storage/profile.py
index 46c28d4171..28c508aad3 100644
--- a/synapse/replication/slave/storage/profile.py
+++ b/synapse/replication/slave/storage/profile.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.storage.profile import ProfileWorkerStore
+from synapse.storage.data_stores.main.profile import ProfileWorkerStore
 
 
 class SlavedProfileStore(ProfileWorkerStore, BaseSlavedStore):
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index af7012702e..3655f05e54 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.storage.push_rule import PushRulesWorkerStore
+from synapse.storage.data_stores.main.push_rule import PushRulesWorkerStore
 
 from ._slaved_id_tracker import SlavedIdTracker
 from .events import SlavedEventStore
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index 8eeb267d61..b4331d0799 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.storage.pusher import PusherWorkerStore
+from synapse.storage.data_stores.main.pusher import PusherWorkerStore
 
 from ._base import BaseSlavedStore
 from ._slaved_id_tracker import SlavedIdTracker
diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py
index 91afa5a72b..43d823c601 100644
--- a/synapse/replication/slave/storage/receipts.py
+++ b/synapse/replication/slave/storage/receipts.py
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.storage.receipts import ReceiptsWorkerStore
+from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore
 
 from ._base import BaseSlavedStore
 from ._slaved_id_tracker import SlavedIdTracker
diff --git a/synapse/replication/slave/storage/registration.py b/synapse/replication/slave/storage/registration.py
index 408d91df1c..4b8553e250 100644
--- a/synapse/replication/slave/storage/registration.py
+++ b/synapse/replication/slave/storage/registration.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.storage.registration import RegistrationWorkerStore
+from synapse.storage.data_stores.main.registration import RegistrationWorkerStore
 
 from ._base import BaseSlavedStore
 
diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py
index f68b3378e3..d9ad386b28 100644
--- a/synapse/replication/slave/storage/room.py
+++ b/synapse/replication/slave/storage/room.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.storage.room import RoomWorkerStore
+from synapse.storage.data_stores.main.room import RoomWorkerStore
 
 from ._base import BaseSlavedStore
 from ._slaved_id_tracker import SlavedIdTracker
diff --git a/synapse/replication/slave/storage/transactions.py b/synapse/replication/slave/storage/transactions.py
index 3527beb3c9..ac88e6b8c3 100644
--- a/synapse/replication/slave/storage/transactions.py
+++ b/synapse/replication/slave/storage/transactions.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.storage.transactions import TransactionStore
+from synapse.storage.data_stores.main.transactions import TransactionStore
 
 from ._base import BaseSlavedStore
 
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index e7f6ea7286..e42fba45a1 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -14,509 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import calendar
-import logging
-import time
-
-from twisted.internet import defer
-
-from synapse.api.constants import PresenceState
-from synapse.storage.devices import DeviceStore
-from synapse.storage.user_erasure_store import UserErasureStore
-from synapse.util.caches.stream_change_cache import StreamChangeCache
-
-from .account_data import AccountDataStore
-from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore
-from .client_ips import ClientIpStore
-from .deviceinbox import DeviceInboxStore
-from .directory import DirectoryStore
-from .e2e_room_keys import EndToEndRoomKeyStore
-from .end_to_end_keys import EndToEndKeyStore
-from .engines import PostgresEngine
-from .event_federation import EventFederationStore
-from .event_push_actions import EventPushActionsStore
-from .events import EventsStore
-from .events_bg_updates import EventsBackgroundUpdatesStore
-from .filtering import FilteringStore
-from .group_server import GroupServerStore
-from .keys import KeyStore
-from .media_repository import MediaRepositoryStore
-from .monthly_active_users import MonthlyActiveUsersStore
-from .openid import OpenIdStore
-from .presence import PresenceStore, UserPresenceState
-from .profile import ProfileStore
-from .push_rule import PushRuleStore
-from .pusher import PusherStore
-from .receipts import ReceiptsStore
-from .registration import RegistrationStore
-from .rejections import RejectionsStore
-from .relations import RelationsStore
-from .room import RoomStore
-from .roommember import RoomMemberStore
-from .search import SearchStore
-from .signatures import SignatureStore
-from .state import StateStore
-from .stats import StatsStore
-from .stream import StreamStore
-from .tags import TagsStore
-from .transactions import TransactionStore
-from .user_directory import UserDirectoryStore
-from .util.id_generators import ChainedIdGenerator, IdGenerator, StreamIdGenerator
-
-logger = logging.getLogger(__name__)
-
-
-class DataStore(
-    EventsBackgroundUpdatesStore,
-    RoomMemberStore,
-    RoomStore,
-    RegistrationStore,
-    StreamStore,
-    ProfileStore,
-    PresenceStore,
-    TransactionStore,
-    DirectoryStore,
-    KeyStore,
-    StateStore,
-    SignatureStore,
-    ApplicationServiceStore,
-    EventsStore,
-    EventFederationStore,
-    MediaRepositoryStore,
-    RejectionsStore,
-    FilteringStore,
-    PusherStore,
-    PushRuleStore,
-    ApplicationServiceTransactionStore,
-    ReceiptsStore,
-    EndToEndKeyStore,
-    EndToEndRoomKeyStore,
-    SearchStore,
-    TagsStore,
-    AccountDataStore,
-    EventPushActionsStore,
-    OpenIdStore,
-    ClientIpStore,
-    DeviceStore,
-    DeviceInboxStore,
-    UserDirectoryStore,
-    GroupServerStore,
-    UserErasureStore,
-    MonthlyActiveUsersStore,
-    StatsStore,
-    RelationsStore,
-):
-    def __init__(self, db_conn, hs):
-        self.hs = hs
-        self._clock = hs.get_clock()
-        self.database_engine = hs.database_engine
-
-        self._stream_id_gen = StreamIdGenerator(
-            db_conn,
-            "events",
-            "stream_ordering",
-            extra_tables=[("local_invites", "stream_id")],
-        )
-        self._backfill_id_gen = StreamIdGenerator(
-            db_conn,
-            "events",
-            "stream_ordering",
-            step=-1,
-            extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
-        )
-        self._presence_id_gen = StreamIdGenerator(
-            db_conn, "presence_stream", "stream_id"
-        )
-        self._device_inbox_id_gen = StreamIdGenerator(
-            db_conn, "device_max_stream_id", "stream_id"
-        )
-        self._public_room_id_gen = StreamIdGenerator(
-            db_conn, "public_room_list_stream", "stream_id"
-        )
-        self._device_list_id_gen = StreamIdGenerator(
-            db_conn, "device_lists_stream", "stream_id"
-        )
-
-        self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
-        self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
-        self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
-        self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
-        self._push_rules_stream_id_gen = ChainedIdGenerator(
-            self._stream_id_gen, db_conn, "push_rules_stream", "stream_id"
-        )
-        self._pushers_id_gen = StreamIdGenerator(
-            db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
-        )
-        self._group_updates_id_gen = StreamIdGenerator(
-            db_conn, "local_group_updates", "stream_id"
-        )
-
-        if isinstance(self.database_engine, PostgresEngine):
-            self._cache_id_gen = StreamIdGenerator(
-                db_conn, "cache_invalidation_stream", "stream_id"
-            )
-        else:
-            self._cache_id_gen = None
-
-        self._presence_on_startup = self._get_active_presence(db_conn)
-
-        presence_cache_prefill, min_presence_val = self._get_cache_dict(
-            db_conn,
-            "presence_stream",
-            entity_column="user_id",
-            stream_column="stream_id",
-            max_value=self._presence_id_gen.get_current_token(),
-        )
-        self.presence_stream_cache = StreamChangeCache(
-            "PresenceStreamChangeCache",
-            min_presence_val,
-            prefilled_cache=presence_cache_prefill,
-        )
-
-        max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
-        device_inbox_prefill, min_device_inbox_id = self._get_cache_dict(
-            db_conn,
-            "device_inbox",
-            entity_column="user_id",
-            stream_column="stream_id",
-            max_value=max_device_inbox_id,
-            limit=1000,
-        )
-        self._device_inbox_stream_cache = StreamChangeCache(
-            "DeviceInboxStreamChangeCache",
-            min_device_inbox_id,
-            prefilled_cache=device_inbox_prefill,
-        )
-        # The federation outbox and the local device inbox uses the same
-        # stream_id generator.
-        device_outbox_prefill, min_device_outbox_id = self._get_cache_dict(
-            db_conn,
-            "device_federation_outbox",
-            entity_column="destination",
-            stream_column="stream_id",
-            max_value=max_device_inbox_id,
-            limit=1000,
-        )
-        self._device_federation_outbox_stream_cache = StreamChangeCache(
-            "DeviceFederationOutboxStreamChangeCache",
-            min_device_outbox_id,
-            prefilled_cache=device_outbox_prefill,
-        )
-
-        device_list_max = self._device_list_id_gen.get_current_token()
-        self._device_list_stream_cache = StreamChangeCache(
-            "DeviceListStreamChangeCache", device_list_max
-        )
-        self._device_list_federation_stream_cache = StreamChangeCache(
-            "DeviceListFederationStreamChangeCache", device_list_max
-        )
-
-        events_max = self._stream_id_gen.get_current_token()
-        curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict(
-            db_conn,
-            "current_state_delta_stream",
-            entity_column="room_id",
-            stream_column="stream_id",
-            max_value=events_max,  # As we share the stream id with events token
-            limit=1000,
-        )
-        self._curr_state_delta_stream_cache = StreamChangeCache(
-            "_curr_state_delta_stream_cache",
-            min_curr_state_delta_id,
-            prefilled_cache=curr_state_delta_prefill,
-        )
-
-        _group_updates_prefill, min_group_updates_id = self._get_cache_dict(
-            db_conn,
-            "local_group_updates",
-            entity_column="user_id",
-            stream_column="stream_id",
-            max_value=self._group_updates_id_gen.get_current_token(),
-            limit=1000,
-        )
-        self._group_updates_stream_cache = StreamChangeCache(
-            "_group_updates_stream_cache",
-            min_group_updates_id,
-            prefilled_cache=_group_updates_prefill,
-        )
-
-        self._stream_order_on_start = self.get_room_max_stream_ordering()
-        self._min_stream_order_on_start = self.get_room_min_stream_ordering()
-
-        # Used in _generate_user_daily_visits to keep track of progress
-        self._last_user_visit_update = self._get_start_of_day()
-
-        super(DataStore, self).__init__(db_conn, hs)
-
-    def take_presence_startup_info(self):
-        active_on_startup = self._presence_on_startup
-        self._presence_on_startup = None
-        return active_on_startup
-
-    def _get_active_presence(self, db_conn):
-        """Fetch non-offline presence from the database so that we can register
-        the appropriate time outs.
-        """
-
-        sql = (
-            "SELECT user_id, state, last_active_ts, last_federation_update_ts,"
-            " last_user_sync_ts, status_msg, currently_active FROM presence_stream"
-            " WHERE state != ?"
-        )
-        sql = self.database_engine.convert_param_style(sql)
-
-        txn = db_conn.cursor()
-        txn.execute(sql, (PresenceState.OFFLINE,))
-        rows = self.cursor_to_dict(txn)
-        txn.close()
-
-        for row in rows:
-            row["currently_active"] = bool(row["currently_active"])
-
-        return [UserPresenceState(**row) for row in rows]
-
-    def count_daily_users(self):
-        """
-        Counts the number of users who used this homeserver in the last 24 hours.
-        """
-        yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24)
-        return self.runInteraction("count_daily_users", self._count_users, yesterday)
-
-    def count_monthly_users(self):
-        """
-        Counts the number of users who used this homeserver in the last 30 days.
-        Note this method is intended for phonehome metrics only and is different
-        from the mau figure in synapse.storage.monthly_active_users which,
-        amongst other things, includes a 3 day grace period before a user counts.
-        """
-        thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
-        return self.runInteraction(
-            "count_monthly_users", self._count_users, thirty_days_ago
-        )
-
-    def _count_users(self, txn, time_from):
-        """
-        Returns number of users seen in the past time_from period
-        """
-        sql = """
-            SELECT COALESCE(count(*), 0) FROM (
-                SELECT user_id FROM user_ips
-                WHERE last_seen > ?
-                GROUP BY user_id
-            ) u
-        """
-        txn.execute(sql, (time_from,))
-        count, = txn.fetchone()
-        return count
-
-    def count_r30_users(self):
-        """
-        Counts the number of 30 day retained users, defined as:-
-         * Users who have created their accounts more than 30 days ago
-         * Where last seen at most 30 days ago
-         * Where account creation and last_seen are > 30 days apart
-
-         Returns counts globaly for a given user as well as breaking
-         by platform
-        """
-
-        def _count_r30_users(txn):
-            thirty_days_in_secs = 86400 * 30
-            now = int(self._clock.time())
-            thirty_days_ago_in_secs = now - thirty_days_in_secs
-
-            sql = """
-                SELECT platform, COALESCE(count(*), 0) FROM (
-                     SELECT
-                        users.name, platform, users.creation_ts * 1000,
-                        MAX(uip.last_seen)
-                     FROM users
-                     INNER JOIN (
-                         SELECT
-                         user_id,
-                         last_seen,
-                         CASE
-                             WHEN user_agent LIKE '%%Android%%' THEN 'android'
-                             WHEN user_agent LIKE '%%iOS%%' THEN 'ios'
-                             WHEN user_agent LIKE '%%Electron%%' THEN 'electron'
-                             WHEN user_agent LIKE '%%Mozilla%%' THEN 'web'
-                             WHEN user_agent LIKE '%%Gecko%%' THEN 'web'
-                             ELSE 'unknown'
-                         END
-                         AS platform
-                         FROM user_ips
-                     ) uip
-                     ON users.name = uip.user_id
-                     AND users.appservice_id is NULL
-                     AND users.creation_ts < ?
-                     AND uip.last_seen/1000 > ?
-                     AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
-                     GROUP BY users.name, platform, users.creation_ts
-                ) u GROUP BY platform
-            """
-
-            results = {}
-            txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
-
-            for row in txn:
-                if row[0] == "unknown":
-                    pass
-                results[row[0]] = row[1]
-
-            sql = """
-                SELECT COALESCE(count(*), 0) FROM (
-                    SELECT users.name, users.creation_ts * 1000,
-                                                        MAX(uip.last_seen)
-                    FROM users
-                    INNER JOIN (
-                        SELECT
-                        user_id,
-                        last_seen
-                        FROM user_ips
-                    ) uip
-                    ON users.name = uip.user_id
-                    AND appservice_id is NULL
-                    AND users.creation_ts < ?
-                    AND uip.last_seen/1000 > ?
-                    AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
-                    GROUP BY users.name, users.creation_ts
-                ) u
-            """
-
-            txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
-
-            count, = txn.fetchone()
-            results["all"] = count
-
-            return results
-
-        return self.runInteraction("count_r30_users", _count_r30_users)
-
-    def _get_start_of_day(self):
-        """
-        Returns millisecond unixtime for start of UTC day.
-        """
-        now = time.gmtime()
-        today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0))
-        return today_start * 1000
-
-    def generate_user_daily_visits(self):
-        """
-        Generates daily visit data for use in cohort/ retention analysis
-        """
-
-        def _generate_user_daily_visits(txn):
-            logger.info("Calling _generate_user_daily_visits")
-            today_start = self._get_start_of_day()
-            a_day_in_milliseconds = 24 * 60 * 60 * 1000
-            now = self.clock.time_msec()
-
-            sql = """
-                INSERT INTO user_daily_visits (user_id, device_id, timestamp)
-                    SELECT u.user_id, u.device_id, ?
-                    FROM user_ips AS u
-                    LEFT JOIN (
-                      SELECT user_id, device_id, timestamp FROM user_daily_visits
-                      WHERE timestamp = ?
-                    ) udv
-                    ON u.user_id = udv.user_id AND u.device_id=udv.device_id
-                    INNER JOIN users ON users.name=u.user_id
-                    WHERE last_seen > ? AND last_seen <= ?
-                    AND udv.timestamp IS NULL AND users.is_guest=0
-                    AND users.appservice_id IS NULL
-                    GROUP BY u.user_id, u.device_id
-            """
-
-            # This means that the day has rolled over but there could still
-            # be entries from the previous day. There is an edge case
-            # where if the user logs in at 23:59 and overwrites their
-            # last_seen at 00:01 then they will not be counted in the
-            # previous day's stats - it is important that the query is run
-            # often to minimise this case.
-            if today_start > self._last_user_visit_update:
-                yesterday_start = today_start - a_day_in_milliseconds
-                txn.execute(
-                    sql,
-                    (
-                        yesterday_start,
-                        yesterday_start,
-                        self._last_user_visit_update,
-                        today_start,
-                    ),
-                )
-                self._last_user_visit_update = today_start
-
-            txn.execute(
-                sql, (today_start, today_start, self._last_user_visit_update, now)
-            )
-            # Update _last_user_visit_update to now. The reason to do this
-            # rather just clamping to the beginning of the day is to limit
-            # the size of the join - meaning that the query can be run more
-            # frequently
-            self._last_user_visit_update = now
-
-        return self.runInteraction(
-            "generate_user_daily_visits", _generate_user_daily_visits
-        )
-
-    def get_users(self):
-        """Function to reterive a list of users in users table.
-
-        Args:
-        Returns:
-            defer.Deferred: resolves to list[dict[str, Any]]
-        """
-        return self._simple_select_list(
-            table="users",
-            keyvalues={},
-            retcols=["name", "password_hash", "is_guest", "admin", "user_type"],
-            desc="get_users",
-        )
-
-    @defer.inlineCallbacks
-    def get_users_paginate(self, order, start, limit):
-        """Function to reterive a paginated list of users from
-        users list. This will return a json object, which contains
-        list of users and the total number of users in users table.
-
-        Args:
-            order (str): column name to order the select by this column
-            start (int): start number to begin the query from
-            limit (int): number of rows to reterive
-        Returns:
-            defer.Deferred: resolves to json object {list[dict[str, Any]], count}
-        """
-        users = yield self.runInteraction(
-            "get_users_paginate",
-            self._simple_select_list_paginate_txn,
-            table="users",
-            keyvalues={"is_guest": False},
-            orderby=order,
-            start=start,
-            limit=limit,
-            retcols=["name", "password_hash", "is_guest", "admin", "user_type"],
-        )
-        count = yield self.runInteraction("get_users_paginate", self.get_user_count_txn)
-        retval = {"users": users, "total": count}
-        return retval
-
-    def search_users(self, term):
-        """Function to search users list for one or more users with
-        the matched term.
-
-        Args:
-            term (str): search term
-            col (str): column to query term should be matched to
-        Returns:
-            defer.Deferred: resolves to list[dict[str, Any]]
-        """
-        return self._simple_search_list(
-            table="users",
-            term=term,
-            col="name",
-            retcols=["name", "password_hash", "is_guest", "admin", "user_type"],
-            desc="search_users",
-        )
+from synapse.storage.data_stores.main import DataStore  # noqa: F401
 
 
 def are_all_users_on_domain(txn, database_engine, domain):
diff --git a/synapse/storage/data_stores/__init__.py b/synapse/storage/data_stores/__init__.py
new file mode 100644
index 0000000000..56094078ed
--- /dev/null
+++ b/synapse/storage/data_stores/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py
new file mode 100644
index 0000000000..d29135588f
--- /dev/null
+++ b/synapse/storage/data_stores/main/__init__.py
@@ -0,0 +1,524 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import calendar
+import logging
+import time
+
+from twisted.internet import defer
+
+from synapse.api.constants import PresenceState
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.util.id_generators import (
+    ChainedIdGenerator,
+    IdGenerator,
+    StreamIdGenerator,
+)
+from synapse.util.caches.stream_change_cache import StreamChangeCache
+
+from .account_data import AccountDataStore
+from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore
+from .client_ips import ClientIpStore
+from .deviceinbox import DeviceInboxStore
+from .devices import DeviceStore
+from .directory import DirectoryStore
+from .e2e_room_keys import EndToEndRoomKeyStore
+from .end_to_end_keys import EndToEndKeyStore
+from .event_federation import EventFederationStore
+from .event_push_actions import EventPushActionsStore
+from .events import EventsStore
+from .events_bg_updates import EventsBackgroundUpdatesStore
+from .filtering import FilteringStore
+from .group_server import GroupServerStore
+from .keys import KeyStore
+from .media_repository import MediaRepositoryStore
+from .monthly_active_users import MonthlyActiveUsersStore
+from .openid import OpenIdStore
+from .presence import PresenceStore, UserPresenceState
+from .profile import ProfileStore
+from .push_rule import PushRuleStore
+from .pusher import PusherStore
+from .receipts import ReceiptsStore
+from .registration import RegistrationStore
+from .rejections import RejectionsStore
+from .relations import RelationsStore
+from .room import RoomStore
+from .roommember import RoomMemberStore
+from .search import SearchStore
+from .signatures import SignatureStore
+from .state import StateStore
+from .stats import StatsStore
+from .stream import StreamStore
+from .tags import TagsStore
+from .transactions import TransactionStore
+from .user_directory import UserDirectoryStore
+from .user_erasure_store import UserErasureStore
+
+logger = logging.getLogger(__name__)
+
+
+class DataStore(
+    EventsBackgroundUpdatesStore,
+    RoomMemberStore,
+    RoomStore,
+    RegistrationStore,
+    StreamStore,
+    ProfileStore,
+    PresenceStore,
+    TransactionStore,
+    DirectoryStore,
+    KeyStore,
+    StateStore,
+    SignatureStore,
+    ApplicationServiceStore,
+    EventsStore,
+    EventFederationStore,
+    MediaRepositoryStore,
+    RejectionsStore,
+    FilteringStore,
+    PusherStore,
+    PushRuleStore,
+    ApplicationServiceTransactionStore,
+    ReceiptsStore,
+    EndToEndKeyStore,
+    EndToEndRoomKeyStore,
+    SearchStore,
+    TagsStore,
+    AccountDataStore,
+    EventPushActionsStore,
+    OpenIdStore,
+    ClientIpStore,
+    DeviceStore,
+    DeviceInboxStore,
+    UserDirectoryStore,
+    GroupServerStore,
+    UserErasureStore,
+    MonthlyActiveUsersStore,
+    StatsStore,
+    RelationsStore,
+):
+    def __init__(self, db_conn, hs):
+        self.hs = hs
+        self._clock = hs.get_clock()
+        self.database_engine = hs.database_engine
+
+        self._stream_id_gen = StreamIdGenerator(
+            db_conn,
+            "events",
+            "stream_ordering",
+            extra_tables=[("local_invites", "stream_id")],
+        )
+        self._backfill_id_gen = StreamIdGenerator(
+            db_conn,
+            "events",
+            "stream_ordering",
+            step=-1,
+            extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
+        )
+        self._presence_id_gen = StreamIdGenerator(
+            db_conn, "presence_stream", "stream_id"
+        )
+        self._device_inbox_id_gen = StreamIdGenerator(
+            db_conn, "device_max_stream_id", "stream_id"
+        )
+        self._public_room_id_gen = StreamIdGenerator(
+            db_conn, "public_room_list_stream", "stream_id"
+        )
+        self._device_list_id_gen = StreamIdGenerator(
+            db_conn, "device_lists_stream", "stream_id"
+        )
+
+        self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
+        self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
+        self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
+        self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
+        self._push_rules_stream_id_gen = ChainedIdGenerator(
+            self._stream_id_gen, db_conn, "push_rules_stream", "stream_id"
+        )
+        self._pushers_id_gen = StreamIdGenerator(
+            db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
+        )
+        self._group_updates_id_gen = StreamIdGenerator(
+            db_conn, "local_group_updates", "stream_id"
+        )
+
+        if isinstance(self.database_engine, PostgresEngine):
+            self._cache_id_gen = StreamIdGenerator(
+                db_conn, "cache_invalidation_stream", "stream_id"
+            )
+        else:
+            self._cache_id_gen = None
+
+        self._presence_on_startup = self._get_active_presence(db_conn)
+
+        presence_cache_prefill, min_presence_val = self._get_cache_dict(
+            db_conn,
+            "presence_stream",
+            entity_column="user_id",
+            stream_column="stream_id",
+            max_value=self._presence_id_gen.get_current_token(),
+        )
+        self.presence_stream_cache = StreamChangeCache(
+            "PresenceStreamChangeCache",
+            min_presence_val,
+            prefilled_cache=presence_cache_prefill,
+        )
+
+        max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
+        device_inbox_prefill, min_device_inbox_id = self._get_cache_dict(
+            db_conn,
+            "device_inbox",
+            entity_column="user_id",
+            stream_column="stream_id",
+            max_value=max_device_inbox_id,
+            limit=1000,
+        )
+        self._device_inbox_stream_cache = StreamChangeCache(
+            "DeviceInboxStreamChangeCache",
+            min_device_inbox_id,
+            prefilled_cache=device_inbox_prefill,
+        )
+        # The federation outbox and the local device inbox uses the same
+        # stream_id generator.
+        device_outbox_prefill, min_device_outbox_id = self._get_cache_dict(
+            db_conn,
+            "device_federation_outbox",
+            entity_column="destination",
+            stream_column="stream_id",
+            max_value=max_device_inbox_id,
+            limit=1000,
+        )
+        self._device_federation_outbox_stream_cache = StreamChangeCache(
+            "DeviceFederationOutboxStreamChangeCache",
+            min_device_outbox_id,
+            prefilled_cache=device_outbox_prefill,
+        )
+
+        device_list_max = self._device_list_id_gen.get_current_token()
+        self._device_list_stream_cache = StreamChangeCache(
+            "DeviceListStreamChangeCache", device_list_max
+        )
+        self._device_list_federation_stream_cache = StreamChangeCache(
+            "DeviceListFederationStreamChangeCache", device_list_max
+        )
+
+        events_max = self._stream_id_gen.get_current_token()
+        curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict(
+            db_conn,
+            "current_state_delta_stream",
+            entity_column="room_id",
+            stream_column="stream_id",
+            max_value=events_max,  # As we share the stream id with events token
+            limit=1000,
+        )
+        self._curr_state_delta_stream_cache = StreamChangeCache(
+            "_curr_state_delta_stream_cache",
+            min_curr_state_delta_id,
+            prefilled_cache=curr_state_delta_prefill,
+        )
+
+        _group_updates_prefill, min_group_updates_id = self._get_cache_dict(
+            db_conn,
+            "local_group_updates",
+            entity_column="user_id",
+            stream_column="stream_id",
+            max_value=self._group_updates_id_gen.get_current_token(),
+            limit=1000,
+        )
+        self._group_updates_stream_cache = StreamChangeCache(
+            "_group_updates_stream_cache",
+            min_group_updates_id,
+            prefilled_cache=_group_updates_prefill,
+        )
+
+        self._stream_order_on_start = self.get_room_max_stream_ordering()
+        self._min_stream_order_on_start = self.get_room_min_stream_ordering()
+
+        # Used in _generate_user_daily_visits to keep track of progress
+        self._last_user_visit_update = self._get_start_of_day()
+
+        super(DataStore, self).__init__(db_conn, hs)
+
+    def take_presence_startup_info(self):
+        active_on_startup = self._presence_on_startup
+        self._presence_on_startup = None
+        return active_on_startup
+
+    def _get_active_presence(self, db_conn):
+        """Fetch non-offline presence from the database so that we can register
+        the appropriate time outs.
+        """
+
+        sql = (
+            "SELECT user_id, state, last_active_ts, last_federation_update_ts,"
+            " last_user_sync_ts, status_msg, currently_active FROM presence_stream"
+            " WHERE state != ?"
+        )
+        sql = self.database_engine.convert_param_style(sql)
+
+        txn = db_conn.cursor()
+        txn.execute(sql, (PresenceState.OFFLINE,))
+        rows = self.cursor_to_dict(txn)
+        txn.close()
+
+        for row in rows:
+            row["currently_active"] = bool(row["currently_active"])
+
+        return [UserPresenceState(**row) for row in rows]
+
+    def count_daily_users(self):
+        """
+        Counts the number of users who used this homeserver in the last 24 hours.
+        """
+        yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24)
+        return self.runInteraction("count_daily_users", self._count_users, yesterday)
+
+    def count_monthly_users(self):
+        """
+        Counts the number of users who used this homeserver in the last 30 days.
+        Note this method is intended for phonehome metrics only and is different
+        from the mau figure in synapse.storage.monthly_active_users which,
+        amongst other things, includes a 3 day grace period before a user counts.
+        """
+        thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
+        return self.runInteraction(
+            "count_monthly_users", self._count_users, thirty_days_ago
+        )
+
+    def _count_users(self, txn, time_from):
+        """
+        Returns number of users seen in the past time_from period
+        """
+        sql = """
+            SELECT COALESCE(count(*), 0) FROM (
+                SELECT user_id FROM user_ips
+                WHERE last_seen > ?
+                GROUP BY user_id
+            ) u
+        """
+        txn.execute(sql, (time_from,))
+        count, = txn.fetchone()
+        return count
+
+    def count_r30_users(self):
+        """
+        Counts the number of 30 day retained users, defined as:-
+         * Users who have created their accounts more than 30 days ago
+         * Where last seen at most 30 days ago
+         * Where account creation and last_seen are > 30 days apart
+
+         Returns counts globaly for a given user as well as breaking
+         by platform
+        """
+
+        def _count_r30_users(txn):
+            thirty_days_in_secs = 86400 * 30
+            now = int(self._clock.time())
+            thirty_days_ago_in_secs = now - thirty_days_in_secs
+
+            sql = """
+                SELECT platform, COALESCE(count(*), 0) FROM (
+                     SELECT
+                        users.name, platform, users.creation_ts * 1000,
+                        MAX(uip.last_seen)
+                     FROM users
+                     INNER JOIN (
+                         SELECT
+                         user_id,
+                         last_seen,
+                         CASE
+                             WHEN user_agent LIKE '%%Android%%' THEN 'android'
+                             WHEN user_agent LIKE '%%iOS%%' THEN 'ios'
+                             WHEN user_agent LIKE '%%Electron%%' THEN 'electron'
+                             WHEN user_agent LIKE '%%Mozilla%%' THEN 'web'
+                             WHEN user_agent LIKE '%%Gecko%%' THEN 'web'
+                             ELSE 'unknown'
+                         END
+                         AS platform
+                         FROM user_ips
+                     ) uip
+                     ON users.name = uip.user_id
+                     AND users.appservice_id is NULL
+                     AND users.creation_ts < ?
+                     AND uip.last_seen/1000 > ?
+                     AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
+                     GROUP BY users.name, platform, users.creation_ts
+                ) u GROUP BY platform
+            """
+
+            results = {}
+            txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
+
+            for row in txn:
+                if row[0] == "unknown":
+                    pass
+                results[row[0]] = row[1]
+
+            sql = """
+                SELECT COALESCE(count(*), 0) FROM (
+                    SELECT users.name, users.creation_ts * 1000,
+                                                        MAX(uip.last_seen)
+                    FROM users
+                    INNER JOIN (
+                        SELECT
+                        user_id,
+                        last_seen
+                        FROM user_ips
+                    ) uip
+                    ON users.name = uip.user_id
+                    AND appservice_id is NULL
+                    AND users.creation_ts < ?
+                    AND uip.last_seen/1000 > ?
+                    AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
+                    GROUP BY users.name, users.creation_ts
+                ) u
+            """
+
+            txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
+
+            count, = txn.fetchone()
+            results["all"] = count
+
+            return results
+
+        return self.runInteraction("count_r30_users", _count_r30_users)
+
+    def _get_start_of_day(self):
+        """
+        Returns millisecond unixtime for start of UTC day.
+        """
+        now = time.gmtime()
+        today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0))
+        return today_start * 1000
+
+    def generate_user_daily_visits(self):
+        """
+        Generates daily visit data for use in cohort/ retention analysis
+        """
+
+        def _generate_user_daily_visits(txn):
+            logger.info("Calling _generate_user_daily_visits")
+            today_start = self._get_start_of_day()
+            a_day_in_milliseconds = 24 * 60 * 60 * 1000
+            now = self.clock.time_msec()
+
+            sql = """
+                INSERT INTO user_daily_visits (user_id, device_id, timestamp)
+                    SELECT u.user_id, u.device_id, ?
+                    FROM user_ips AS u
+                    LEFT JOIN (
+                      SELECT user_id, device_id, timestamp FROM user_daily_visits
+                      WHERE timestamp = ?
+                    ) udv
+                    ON u.user_id = udv.user_id AND u.device_id=udv.device_id
+                    INNER JOIN users ON users.name=u.user_id
+                    WHERE last_seen > ? AND last_seen <= ?
+                    AND udv.timestamp IS NULL AND users.is_guest=0
+                    AND users.appservice_id IS NULL
+                    GROUP BY u.user_id, u.device_id
+            """
+
+            # This means that the day has rolled over but there could still
+            # be entries from the previous day. There is an edge case
+            # where if the user logs in at 23:59 and overwrites their
+            # last_seen at 00:01 then they will not be counted in the
+            # previous day's stats - it is important that the query is run
+            # often to minimise this case.
+            if today_start > self._last_user_visit_update:
+                yesterday_start = today_start - a_day_in_milliseconds
+                txn.execute(
+                    sql,
+                    (
+                        yesterday_start,
+                        yesterday_start,
+                        self._last_user_visit_update,
+                        today_start,
+                    ),
+                )
+                self._last_user_visit_update = today_start
+
+            txn.execute(
+                sql, (today_start, today_start, self._last_user_visit_update, now)
+            )
+            # Update _last_user_visit_update to now. The reason to do this
+            # rather just clamping to the beginning of the day is to limit
+            # the size of the join - meaning that the query can be run more
+            # frequently
+            self._last_user_visit_update = now
+
+        return self.runInteraction(
+            "generate_user_daily_visits", _generate_user_daily_visits
+        )
+
+    def get_users(self):
+        """Function to reterive a list of users in users table.
+
+        Args:
+        Returns:
+            defer.Deferred: resolves to list[dict[str, Any]]
+        """
+        return self._simple_select_list(
+            table="users",
+            keyvalues={},
+            retcols=["name", "password_hash", "is_guest", "admin", "user_type"],
+            desc="get_users",
+        )
+
+    @defer.inlineCallbacks
+    def get_users_paginate(self, order, start, limit):
+        """Function to reterive a paginated list of users from
+        users list. This will return a json object, which contains
+        list of users and the total number of users in users table.
+
+        Args:
+            order (str): column name to order the select by this column
+            start (int): start number to begin the query from
+            limit (int): number of rows to reterive
+        Returns:
+            defer.Deferred: resolves to json object {list[dict[str, Any]], count}
+        """
+        users = yield self.runInteraction(
+            "get_users_paginate",
+            self._simple_select_list_paginate_txn,
+            table="users",
+            keyvalues={"is_guest": False},
+            orderby=order,
+            start=start,
+            limit=limit,
+            retcols=["name", "password_hash", "is_guest", "admin", "user_type"],
+        )
+        count = yield self.runInteraction("get_users_paginate", self.get_user_count_txn)
+        retval = {"users": users, "total": count}
+        return retval
+
+    def search_users(self, term):
+        """Function to search users list for one or more users with
+        the matched term.
+
+        Args:
+            term (str): search term
+            col (str): column to query term should be matched to
+        Returns:
+            defer.Deferred: resolves to list[dict[str, Any]]
+        """
+        return self._simple_search_list(
+            table="users",
+            term=term,
+            col="name",
+            retcols=["name", "password_hash", "is_guest", "admin", "user_type"],
+            desc="search_users",
+        )
diff --git a/synapse/storage/account_data.py b/synapse/storage/data_stores/main/account_data.py
similarity index 100%
rename from synapse/storage/account_data.py
rename to synapse/storage/data_stores/main/account_data.py
diff --git a/synapse/storage/appservice.py b/synapse/storage/data_stores/main/appservice.py
similarity index 99%
rename from synapse/storage/appservice.py
rename to synapse/storage/data_stores/main/appservice.py
index 435b2acd4d..81babf2029 100644
--- a/synapse/storage/appservice.py
+++ b/synapse/storage/data_stores/main/appservice.py
@@ -22,9 +22,8 @@ from twisted.internet import defer
 
 from synapse.appservice import AppServiceTransaction
 from synapse.config.appservice import load_appservices
-from synapse.storage.events_worker import EventsWorkerStore
-
-from ._base import SQLBaseStore
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
 
 logger = logging.getLogger(__name__)
 
diff --git a/synapse/storage/client_ips.py b/synapse/storage/data_stores/main/client_ips.py
similarity index 99%
rename from synapse/storage/client_ips.py
rename to synapse/storage/data_stores/main/client_ips.py
index 067820a5da..706c6a1f3f 100644
--- a/synapse/storage/client_ips.py
+++ b/synapse/storage/data_stores/main/client_ips.py
@@ -20,11 +20,10 @@ from six import iteritems
 from twisted.internet import defer
 
 from synapse.metrics.background_process_metrics import wrap_as_background_process
+from synapse.storage import background_updates
+from synapse.storage._base import Cache
 from synapse.util.caches import CACHE_SIZE_FACTOR
 
-from . import background_updates
-from ._base import Cache
-
 logger = logging.getLogger(__name__)
 
 # Number of msec of granularity to store the user IP 'last seen' time. Smaller
diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/data_stores/main/deviceinbox.py
similarity index 100%
rename from synapse/storage/deviceinbox.py
rename to synapse/storage/data_stores/main/deviceinbox.py
diff --git a/synapse/storage/devices.py b/synapse/storage/data_stores/main/devices.py
similarity index 100%
rename from synapse/storage/devices.py
rename to synapse/storage/data_stores/main/devices.py
diff --git a/synapse/storage/directory.py b/synapse/storage/data_stores/main/directory.py
similarity index 99%
rename from synapse/storage/directory.py
rename to synapse/storage/data_stores/main/directory.py
index eed7757ed5..297966d9f4 100644
--- a/synapse/storage/directory.py
+++ b/synapse/storage/data_stores/main/directory.py
@@ -18,10 +18,9 @@ from collections import namedtuple
 from twisted.internet import defer
 
 from synapse.api.errors import SynapseError
+from synapse.storage._base import SQLBaseStore
 from synapse.util.caches.descriptors import cached
 
-from ._base import SQLBaseStore
-
 RoomAliasMapping = namedtuple("RoomAliasMapping", ("room_id", "room_alias", "servers"))
 
 
diff --git a/synapse/storage/e2e_room_keys.py b/synapse/storage/data_stores/main/e2e_room_keys.py
similarity index 99%
rename from synapse/storage/e2e_room_keys.py
rename to synapse/storage/data_stores/main/e2e_room_keys.py
index be2fe2bab6..ef88e79293 100644
--- a/synapse/storage/e2e_room_keys.py
+++ b/synapse/storage/data_stores/main/e2e_room_keys.py
@@ -19,8 +19,7 @@ from twisted.internet import defer
 
 from synapse.api.errors import StoreError
 from synapse.logging.opentracing import log_kv, trace
-
-from ._base import SQLBaseStore
+from synapse.storage._base import SQLBaseStore
 
 
 class EndToEndRoomKeyStore(SQLBaseStore):
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py
similarity index 99%
rename from synapse/storage/end_to_end_keys.py
rename to synapse/storage/data_stores/main/end_to_end_keys.py
index 872bc75490..7e44f41046 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/data_stores/main/end_to_end_keys.py
@@ -19,10 +19,9 @@ from canonicaljson import encode_canonical_json
 from twisted.internet import defer
 
 from synapse.logging.opentracing import log_kv, set_tag, trace
+from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.util.caches.descriptors import cached
 
-from ._base import SQLBaseStore, db_to_json
-
 
 class EndToEndKeyWorkerStore(SQLBaseStore):
     @trace
diff --git a/synapse/storage/event_federation.py b/synapse/storage/data_stores/main/event_federation.py
similarity index 99%
rename from synapse/storage/event_federation.py
rename to synapse/storage/data_stores/main/event_federation.py
index 47cc10d32a..a470a48e0f 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/data_stores/main/event_federation.py
@@ -26,8 +26,8 @@ from twisted.internet import defer
 from synapse.api.errors import StoreError
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
-from synapse.storage.events_worker import EventsWorkerStore
-from synapse.storage.signatures import SignatureWorkerStore
+from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
+from synapse.storage.data_stores.main.signatures import SignatureWorkerStore
 from synapse.util.caches.descriptors import cached
 
 logger = logging.getLogger(__name__)
diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py
similarity index 100%
rename from synapse/storage/event_push_actions.py
rename to synapse/storage/data_stores/main/event_push_actions.py
diff --git a/synapse/storage/events.py b/synapse/storage/data_stores/main/events.py
similarity index 99%
rename from synapse/storage/events.py
rename to synapse/storage/data_stores/main/events.py
index ee49ef235d..03b5111c5d 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/data_stores/main/events.py
@@ -41,9 +41,9 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.state import StateResolutionStore
 from synapse.storage._base import make_in_list_sql_clause
 from synapse.storage.background_updates import BackgroundUpdateStore
-from synapse.storage.event_federation import EventFederationStore
-from synapse.storage.events_worker import EventsWorkerStore
-from synapse.storage.state import StateGroupWorkerStore
+from synapse.storage.data_stores.main.event_federation import EventFederationStore
+from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
+from synapse.storage.data_stores.main.state import StateGroupWorkerStore
 from synapse.types import RoomStreamToken, get_domain_from_id
 from synapse.util import batch_iter
 from synapse.util.async_helpers import ObservableDeferred
diff --git a/synapse/storage/events_bg_updates.py b/synapse/storage/data_stores/main/events_bg_updates.py
similarity index 100%
rename from synapse/storage/events_bg_updates.py
rename to synapse/storage/data_stores/main/events_bg_updates.py
diff --git a/synapse/storage/events_worker.py b/synapse/storage/data_stores/main/events_worker.py
similarity index 100%
rename from synapse/storage/events_worker.py
rename to synapse/storage/data_stores/main/events_worker.py
diff --git a/synapse/storage/filtering.py b/synapse/storage/data_stores/main/filtering.py
similarity index 97%
rename from synapse/storage/filtering.py
rename to synapse/storage/data_stores/main/filtering.py
index 7c2a7da836..a2a2a67927 100644
--- a/synapse/storage/filtering.py
+++ b/synapse/storage/data_stores/main/filtering.py
@@ -16,10 +16,9 @@
 from canonicaljson import encode_canonical_json
 
 from synapse.api.errors import Codes, SynapseError
+from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.util.caches.descriptors import cachedInlineCallbacks
 
-from ._base import SQLBaseStore, db_to_json
-
 
 class FilteringStore(SQLBaseStore):
     @cachedInlineCallbacks(num_args=2)
diff --git a/synapse/storage/group_server.py b/synapse/storage/data_stores/main/group_server.py
similarity index 99%
rename from synapse/storage/group_server.py
rename to synapse/storage/data_stores/main/group_server.py
index 15b01c6958..aeae5a2b28 100644
--- a/synapse/storage/group_server.py
+++ b/synapse/storage/data_stores/main/group_server.py
@@ -19,8 +19,7 @@ from canonicaljson import json
 from twisted.internet import defer
 
 from synapse.api.errors import SynapseError
-
-from ._base import SQLBaseStore
+from synapse.storage._base import SQLBaseStore
 
 # The category ID for the "default" category. We don't store as null in the
 # database to avoid the fun of null != null
diff --git a/synapse/storage/data_stores/main/keys.py b/synapse/storage/data_stores/main/keys.py
new file mode 100644
index 0000000000..ebc7db3ed6
--- /dev/null
+++ b/synapse/storage/data_stores/main/keys.py
@@ -0,0 +1,214 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2019 New Vector Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import itertools
+import logging
+
+import six
+
+from signedjson.key import decode_verify_key_bytes
+
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.keys import FetchKeyResult
+from synapse.util import batch_iter
+from synapse.util.caches.descriptors import cached, cachedList
+
+logger = logging.getLogger(__name__)
+
+# py2 sqlite has buffer hardcoded as only binary type, so we must use it,
+# despite being deprecated and removed in favor of memoryview
+if six.PY2:
+    db_binary_type = six.moves.builtins.buffer
+else:
+    db_binary_type = memoryview
+
+
+class KeyStore(SQLBaseStore):
+    """Persistence for signature verification keys
+    """
+
+    @cached()
+    def _get_server_verify_key(self, server_name_and_key_id):
+        raise NotImplementedError()
+
+    @cachedList(
+        cached_method_name="_get_server_verify_key", list_name="server_name_and_key_ids"
+    )
+    def get_server_verify_keys(self, server_name_and_key_ids):
+        """
+        Args:
+            server_name_and_key_ids (iterable[Tuple[str, str]]):
+                iterable of (server_name, key-id) tuples to fetch keys for
+
+        Returns:
+            Deferred: resolves to dict[Tuple[str, str], FetchKeyResult|None]:
+                map from (server_name, key_id) -> FetchKeyResult, or None if the key is
+                unknown
+        """
+        keys = {}
+
+        def _get_keys(txn, batch):
+            """Processes a batch of keys to fetch, and adds the result to `keys`."""
+
+            # batch_iter always returns tuples so it's safe to do len(batch)
+            sql = (
+                "SELECT server_name, key_id, verify_key, ts_valid_until_ms "
+                "FROM server_signature_keys WHERE 1=0"
+            ) + " OR (server_name=? AND key_id=?)" * len(batch)
+
+            txn.execute(sql, tuple(itertools.chain.from_iterable(batch)))
+
+            for row in txn:
+                server_name, key_id, key_bytes, ts_valid_until_ms = row
+
+                if ts_valid_until_ms is None:
+                    # Old keys may be stored with a ts_valid_until_ms of null,
+                    # in which case we treat this as if it was set to `0`, i.e.
+                    # it won't match key requests that define a minimum
+                    # `ts_valid_until_ms`.
+                    ts_valid_until_ms = 0
+
+                res = FetchKeyResult(
+                    verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)),
+                    valid_until_ts=ts_valid_until_ms,
+                )
+                keys[(server_name, key_id)] = res
+
+        def _txn(txn):
+            for batch in batch_iter(server_name_and_key_ids, 50):
+                _get_keys(txn, batch)
+            return keys
+
+        return self.runInteraction("get_server_verify_keys", _txn)
+
+    def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys):
+        """Stores NACL verification keys for remote servers.
+        Args:
+            from_server (str): Where the verification keys were looked up
+            ts_added_ms (int): The time to record that the key was added
+            verify_keys (iterable[tuple[str, str, FetchKeyResult]]):
+                keys to be stored. Each entry is a triplet of
+                (server_name, key_id, key).
+        """
+        key_values = []
+        value_values = []
+        invalidations = []
+        for server_name, key_id, fetch_result in verify_keys:
+            key_values.append((server_name, key_id))
+            value_values.append(
+                (
+                    from_server,
+                    ts_added_ms,
+                    fetch_result.valid_until_ts,
+                    db_binary_type(fetch_result.verify_key.encode()),
+                )
+            )
+            # invalidate takes a tuple corresponding to the params of
+            # _get_server_verify_key. _get_server_verify_key only takes one
+            # param, which is itself the 2-tuple (server_name, key_id).
+            invalidations.append((server_name, key_id))
+
+        def _invalidate(res):
+            f = self._get_server_verify_key.invalidate
+            for i in invalidations:
+                f((i,))
+            return res
+
+        return self.runInteraction(
+            "store_server_verify_keys",
+            self._simple_upsert_many_txn,
+            table="server_signature_keys",
+            key_names=("server_name", "key_id"),
+            key_values=key_values,
+            value_names=(
+                "from_server",
+                "ts_added_ms",
+                "ts_valid_until_ms",
+                "verify_key",
+            ),
+            value_values=value_values,
+        ).addCallback(_invalidate)
+
+    def store_server_keys_json(
+        self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes
+    ):
+        """Stores the JSON bytes for a set of keys from a server
+        The JSON should be signed by the originating server, the intermediate
+        server, and by this server. Updates the value for the
+        (server_name, key_id, from_server) triplet if one already existed.
+        Args:
+            server_name (str): The name of the server.
+            key_id (str): The identifer of the key this JSON is for.
+            from_server (str): The server this JSON was fetched from.
+            ts_now_ms (int): The time now in milliseconds.
+            ts_valid_until_ms (int): The time when this json stops being valid.
+            key_json (bytes): The encoded JSON.
+        """
+        return self._simple_upsert(
+            table="server_keys_json",
+            keyvalues={
+                "server_name": server_name,
+                "key_id": key_id,
+                "from_server": from_server,
+            },
+            values={
+                "server_name": server_name,
+                "key_id": key_id,
+                "from_server": from_server,
+                "ts_added_ms": ts_now_ms,
+                "ts_valid_until_ms": ts_expires_ms,
+                "key_json": db_binary_type(key_json_bytes),
+            },
+            desc="store_server_keys_json",
+        )
+
+    def get_server_keys_json(self, server_keys):
+        """Retrive the key json for a list of server_keys and key ids.
+        If no keys are found for a given server, key_id and source then
+        that server, key_id, and source triplet entry will be an empty list.
+        The JSON is returned as a byte array so that it can be efficiently
+        used in an HTTP response.
+        Args:
+            server_keys (list): List of (server_name, key_id, source) triplets.
+        Returns:
+            Deferred[dict[Tuple[str, str, str|None], list[dict]]]:
+                Dict mapping (server_name, key_id, source) triplets to lists of dicts
+        """
+
+        def _get_server_keys_json_txn(txn):
+            results = {}
+            for server_name, key_id, from_server in server_keys:
+                keyvalues = {"server_name": server_name}
+                if key_id is not None:
+                    keyvalues["key_id"] = key_id
+                if from_server is not None:
+                    keyvalues["from_server"] = from_server
+                rows = self._simple_select_list_txn(
+                    txn,
+                    "server_keys_json",
+                    keyvalues=keyvalues,
+                    retcols=(
+                        "key_id",
+                        "from_server",
+                        "ts_added_ms",
+                        "ts_valid_until_ms",
+                        "key_json",
+                    ),
+                )
+                results[(server_name, key_id, from_server)] = rows
+            return results
+
+        return self.runInteraction("get_server_keys_json", _get_server_keys_json_txn)
diff --git a/synapse/storage/media_repository.py b/synapse/storage/data_stores/main/media_repository.py
similarity index 100%
rename from synapse/storage/media_repository.py
rename to synapse/storage/data_stores/main/media_repository.py
diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/data_stores/main/monthly_active_users.py
similarity index 99%
rename from synapse/storage/monthly_active_users.py
rename to synapse/storage/data_stores/main/monthly_active_users.py
index 3803604be7..e6ee1e4aaa 100644
--- a/synapse/storage/monthly_active_users.py
+++ b/synapse/storage/data_stores/main/monthly_active_users.py
@@ -16,10 +16,9 @@ import logging
 
 from twisted.internet import defer
 
+from synapse.storage._base import SQLBaseStore
 from synapse.util.caches.descriptors import cached
 
-from ._base import SQLBaseStore
-
 logger = logging.getLogger(__name__)
 
 # Number of msec of granularity to store the monthly_active_user timestamp
diff --git a/synapse/storage/openid.py b/synapse/storage/data_stores/main/openid.py
similarity index 95%
rename from synapse/storage/openid.py
rename to synapse/storage/data_stores/main/openid.py
index b3318045ee..79b40044d9 100644
--- a/synapse/storage/openid.py
+++ b/synapse/storage/data_stores/main/openid.py
@@ -1,4 +1,4 @@
-from ._base import SQLBaseStore
+from synapse.storage._base import SQLBaseStore
 
 
 class OpenIdStore(SQLBaseStore):
diff --git a/synapse/storage/data_stores/main/presence.py b/synapse/storage/data_stores/main/presence.py
new file mode 100644
index 0000000000..523ed6575e
--- /dev/null
+++ b/synapse/storage/data_stores/main/presence.py
@@ -0,0 +1,150 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage.presence import UserPresenceState
+from synapse.util import batch_iter
+from synapse.util.caches.descriptors import cached, cachedList
+
+
+class PresenceStore(SQLBaseStore):
+    @defer.inlineCallbacks
+    def update_presence(self, presence_states):
+        stream_ordering_manager = self._presence_id_gen.get_next_mult(
+            len(presence_states)
+        )
+
+        with stream_ordering_manager as stream_orderings:
+            yield self.runInteraction(
+                "update_presence",
+                self._update_presence_txn,
+                stream_orderings,
+                presence_states,
+            )
+
+        return stream_orderings[-1], self._presence_id_gen.get_current_token()
+
+    def _update_presence_txn(self, txn, stream_orderings, presence_states):
+        for stream_id, state in zip(stream_orderings, presence_states):
+            txn.call_after(
+                self.presence_stream_cache.entity_has_changed, state.user_id, stream_id
+            )
+            txn.call_after(self._get_presence_for_user.invalidate, (state.user_id,))
+
+        # Actually insert new rows
+        self._simple_insert_many_txn(
+            txn,
+            table="presence_stream",
+            values=[
+                {
+                    "stream_id": stream_id,
+                    "user_id": state.user_id,
+                    "state": state.state,
+                    "last_active_ts": state.last_active_ts,
+                    "last_federation_update_ts": state.last_federation_update_ts,
+                    "last_user_sync_ts": state.last_user_sync_ts,
+                    "status_msg": state.status_msg,
+                    "currently_active": state.currently_active,
+                }
+                for state in presence_states
+            ],
+        )
+
+        # Delete old rows to stop database from getting really big
+        sql = "DELETE FROM presence_stream WHERE stream_id < ? AND "
+
+        for states in batch_iter(presence_states, 50):
+            clause, args = make_in_list_sql_clause(
+                self.database_engine, "user_id", [s.user_id for s in states]
+            )
+            txn.execute(sql + clause, [stream_id] + list(args))
+
+    def get_all_presence_updates(self, last_id, current_id):
+        if last_id == current_id:
+            return defer.succeed([])
+
+        def get_all_presence_updates_txn(txn):
+            sql = (
+                "SELECT stream_id, user_id, state, last_active_ts,"
+                " last_federation_update_ts, last_user_sync_ts, status_msg,"
+                " currently_active"
+                " FROM presence_stream"
+                " WHERE ? < stream_id AND stream_id <= ?"
+            )
+            txn.execute(sql, (last_id, current_id))
+            return txn.fetchall()
+
+        return self.runInteraction(
+            "get_all_presence_updates", get_all_presence_updates_txn
+        )
+
+    @cached()
+    def _get_presence_for_user(self, user_id):
+        raise NotImplementedError()
+
+    @cachedList(
+        cached_method_name="_get_presence_for_user",
+        list_name="user_ids",
+        num_args=1,
+        inlineCallbacks=True,
+    )
+    def get_presence_for_users(self, user_ids):
+        rows = yield self._simple_select_many_batch(
+            table="presence_stream",
+            column="user_id",
+            iterable=user_ids,
+            keyvalues={},
+            retcols=(
+                "user_id",
+                "state",
+                "last_active_ts",
+                "last_federation_update_ts",
+                "last_user_sync_ts",
+                "status_msg",
+                "currently_active",
+            ),
+            desc="get_presence_for_users",
+        )
+
+        for row in rows:
+            row["currently_active"] = bool(row["currently_active"])
+
+        return {row["user_id"]: UserPresenceState(**row) for row in rows}
+
+    def get_current_presence_token(self):
+        return self._presence_id_gen.get_current_token()
+
+    def allow_presence_visible(self, observed_localpart, observer_userid):
+        return self._simple_insert(
+            table="presence_allow_inbound",
+            values={
+                "observed_user_id": observed_localpart,
+                "observer_user_id": observer_userid,
+            },
+            desc="allow_presence_visible",
+            or_ignore=True,
+        )
+
+    def disallow_presence_visible(self, observed_localpart, observer_userid):
+        return self._simple_delete_one(
+            table="presence_allow_inbound",
+            keyvalues={
+                "observed_user_id": observed_localpart,
+                "observer_user_id": observer_userid,
+            },
+            desc="disallow_presence_visible",
+        )
diff --git a/synapse/storage/profile.py b/synapse/storage/data_stores/main/profile.py
similarity index 98%
rename from synapse/storage/profile.py
rename to synapse/storage/data_stores/main/profile.py
index 912c1df6be..e4e8a1c1d6 100644
--- a/synapse/storage/profile.py
+++ b/synapse/storage/data_stores/main/profile.py
@@ -16,9 +16,8 @@
 from twisted.internet import defer
 
 from synapse.api.errors import StoreError
-from synapse.storage.roommember import ProfileInfo
-
-from ._base import SQLBaseStore
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.data_stores.main.roommember import ProfileInfo
 
 
 class ProfileWorkerStore(SQLBaseStore):
diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py
new file mode 100644
index 0000000000..cd95f1ce60
--- /dev/null
+++ b/synapse/storage/data_stores/main/push_rule.py
@@ -0,0 +1,713 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import abc
+import logging
+
+from canonicaljson import json
+
+from twisted.internet import defer
+
+from synapse.push.baserules import list_with_base_rules
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.data_stores.main.appservice import ApplicationServiceWorkerStore
+from synapse.storage.data_stores.main.pusher import PusherWorkerStore
+from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore
+from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore
+from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
+from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
+from synapse.util.caches.stream_change_cache import StreamChangeCache
+
+logger = logging.getLogger(__name__)
+
+
+def _load_rules(rawrules, enabled_map):
+    ruleslist = []
+    for rawrule in rawrules:
+        rule = dict(rawrule)
+        rule["conditions"] = json.loads(rawrule["conditions"])
+        rule["actions"] = json.loads(rawrule["actions"])
+        ruleslist.append(rule)
+
+    # We're going to be mutating this a lot, so do a deep copy
+    rules = list(list_with_base_rules(ruleslist))
+
+    for i, rule in enumerate(rules):
+        rule_id = rule["rule_id"]
+        if rule_id in enabled_map:
+            if rule.get("enabled", True) != bool(enabled_map[rule_id]):
+                # Rules are cached across users.
+                rule = dict(rule)
+                rule["enabled"] = bool(enabled_map[rule_id])
+                rules[i] = rule
+
+    return rules
+
+
+class PushRulesWorkerStore(
+    ApplicationServiceWorkerStore,
+    ReceiptsWorkerStore,
+    PusherWorkerStore,
+    RoomMemberWorkerStore,
+    SQLBaseStore,
+):
+    """This is an abstract base class where subclasses must implement
+    `get_max_push_rules_stream_id` which can be called in the initializer.
+    """
+
+    # This ABCMeta metaclass ensures that we cannot be instantiated without
+    # the abstract methods being implemented.
+    __metaclass__ = abc.ABCMeta
+
+    def __init__(self, db_conn, hs):
+        super(PushRulesWorkerStore, self).__init__(db_conn, hs)
+
+        push_rules_prefill, push_rules_id = self._get_cache_dict(
+            db_conn,
+            "push_rules_stream",
+            entity_column="user_id",
+            stream_column="stream_id",
+            max_value=self.get_max_push_rules_stream_id(),
+        )
+
+        self.push_rules_stream_cache = StreamChangeCache(
+            "PushRulesStreamChangeCache",
+            push_rules_id,
+            prefilled_cache=push_rules_prefill,
+        )
+
+    @abc.abstractmethod
+    def get_max_push_rules_stream_id(self):
+        """Get the position of the push rules stream.
+
+        Returns:
+            int
+        """
+        raise NotImplementedError()
+
+    @cachedInlineCallbacks(max_entries=5000)
+    def get_push_rules_for_user(self, user_id):
+        rows = yield self._simple_select_list(
+            table="push_rules",
+            keyvalues={"user_name": user_id},
+            retcols=(
+                "user_name",
+                "rule_id",
+                "priority_class",
+                "priority",
+                "conditions",
+                "actions",
+            ),
+            desc="get_push_rules_enabled_for_user",
+        )
+
+        rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
+
+        enabled_map = yield self.get_push_rules_enabled_for_user(user_id)
+
+        rules = _load_rules(rows, enabled_map)
+
+        return rules
+
+    @cachedInlineCallbacks(max_entries=5000)
+    def get_push_rules_enabled_for_user(self, user_id):
+        results = yield self._simple_select_list(
+            table="push_rules_enable",
+            keyvalues={"user_name": user_id},
+            retcols=("user_name", "rule_id", "enabled"),
+            desc="get_push_rules_enabled_for_user",
+        )
+        return {r["rule_id"]: False if r["enabled"] == 0 else True for r in results}
+
+    def have_push_rules_changed_for_user(self, user_id, last_id):
+        if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
+            return defer.succeed(False)
+        else:
+
+            def have_push_rules_changed_txn(txn):
+                sql = (
+                    "SELECT COUNT(stream_id) FROM push_rules_stream"
+                    " WHERE user_id = ? AND ? < stream_id"
+                )
+                txn.execute(sql, (user_id, last_id))
+                count, = txn.fetchone()
+                return bool(count)
+
+            return self.runInteraction(
+                "have_push_rules_changed", have_push_rules_changed_txn
+            )
+
+    @cachedList(
+        cached_method_name="get_push_rules_for_user",
+        list_name="user_ids",
+        num_args=1,
+        inlineCallbacks=True,
+    )
+    def bulk_get_push_rules(self, user_ids):
+        if not user_ids:
+            return {}
+
+        results = {user_id: [] for user_id in user_ids}
+
+        rows = yield self._simple_select_many_batch(
+            table="push_rules",
+            column="user_name",
+            iterable=user_ids,
+            retcols=("*",),
+            desc="bulk_get_push_rules",
+        )
+
+        rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
+
+        for row in rows:
+            results.setdefault(row["user_name"], []).append(row)
+
+        enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids)
+
+        for user_id, rules in results.items():
+            results[user_id] = _load_rules(rules, enabled_map_by_user.get(user_id, {}))
+
+        return results
+
+    @defer.inlineCallbacks
+    def copy_push_rule_from_room_to_room(self, new_room_id, user_id, rule):
+        """Copy a single push rule from one room to another for a specific user.
+
+        Args:
+            new_room_id (str): ID of the new room.
+            user_id (str): ID of user the push rule belongs to.
+            rule (Dict): A push rule.
+        """
+        # Create new rule id
+        rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1])
+        new_rule_id = rule_id_scope + "/" + new_room_id
+
+        # Change room id in each condition
+        for condition in rule.get("conditions", []):
+            if condition.get("key") == "room_id":
+                condition["pattern"] = new_room_id
+
+        # Add the rule for the new room
+        yield self.add_push_rule(
+            user_id=user_id,
+            rule_id=new_rule_id,
+            priority_class=rule["priority_class"],
+            conditions=rule["conditions"],
+            actions=rule["actions"],
+        )
+
+    @defer.inlineCallbacks
+    def copy_push_rules_from_room_to_room_for_user(
+        self, old_room_id, new_room_id, user_id
+    ):
+        """Copy all of the push rules from one room to another for a specific
+        user.
+
+        Args:
+            old_room_id (str): ID of the old room.
+            new_room_id (str): ID of the new room.
+            user_id (str): ID of user to copy push rules for.
+        """
+        # Retrieve push rules for this user
+        user_push_rules = yield self.get_push_rules_for_user(user_id)
+
+        # Get rules relating to the old room and copy them to the new room
+        for rule in user_push_rules:
+            conditions = rule.get("conditions", [])
+            if any(
+                (c.get("key") == "room_id" and c.get("pattern") == old_room_id)
+                for c in conditions
+            ):
+                yield self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule)
+
+    @defer.inlineCallbacks
+    def bulk_get_push_rules_for_room(self, event, context):
+        state_group = context.state_group
+        if not state_group:
+            # If state_group is None it means it has yet to be assigned a
+            # state group, i.e. we need to make sure that calls with a state_group
+            # of None don't hit previous cached calls with a None state_group.
+            # To do this we set the state_group to a new object as object() != object()
+            state_group = object()
+
+        current_state_ids = yield context.get_current_state_ids(self)
+        result = yield self._bulk_get_push_rules_for_room(
+            event.room_id, state_group, current_state_ids, event=event
+        )
+        return result
+
+    @cachedInlineCallbacks(num_args=2, cache_context=True)
+    def _bulk_get_push_rules_for_room(
+        self, room_id, state_group, current_state_ids, cache_context, event=None
+    ):
+        # We don't use `state_group`, its there so that we can cache based
+        # on it. However, its important that its never None, since two current_state's
+        # with a state_group of None are likely to be different.
+        # See bulk_get_push_rules_for_room for how we work around this.
+        assert state_group is not None
+
+        # We also will want to generate notifs for other people in the room so
+        # their unread countss are correct in the event stream, but to avoid
+        # generating them for bot / AS users etc, we only do so for people who've
+        # sent a read receipt into the room.
+
+        users_in_room = yield self._get_joined_users_from_context(
+            room_id,
+            state_group,
+            current_state_ids,
+            on_invalidate=cache_context.invalidate,
+            event=event,
+        )
+
+        # We ignore app service users for now. This is so that we don't fill
+        # up the `get_if_users_have_pushers` cache with AS entries that we
+        # know don't have pushers, nor even read receipts.
+        local_users_in_room = set(
+            u
+            for u in users_in_room
+            if self.hs.is_mine_id(u)
+            and not self.get_if_app_services_interested_in_user(u)
+        )
+
+        # users in the room who have pushers need to get push rules run because
+        # that's how their pushers work
+        if_users_with_pushers = yield self.get_if_users_have_pushers(
+            local_users_in_room, on_invalidate=cache_context.invalidate
+        )
+        user_ids = set(
+            uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
+        )
+
+        users_with_receipts = yield self.get_users_with_read_receipts_in_room(
+            room_id, on_invalidate=cache_context.invalidate
+        )
+
+        # any users with pushers must be ours: they have pushers
+        for uid in users_with_receipts:
+            if uid in local_users_in_room:
+                user_ids.add(uid)
+
+        rules_by_user = yield self.bulk_get_push_rules(
+            user_ids, on_invalidate=cache_context.invalidate
+        )
+
+        rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None}
+
+        return rules_by_user
+
+    @cachedList(
+        cached_method_name="get_push_rules_enabled_for_user",
+        list_name="user_ids",
+        num_args=1,
+        inlineCallbacks=True,
+    )
+    def bulk_get_push_rules_enabled(self, user_ids):
+        if not user_ids:
+            return {}
+
+        results = {user_id: {} for user_id in user_ids}
+
+        rows = yield self._simple_select_many_batch(
+            table="push_rules_enable",
+            column="user_name",
+            iterable=user_ids,
+            retcols=("user_name", "rule_id", "enabled"),
+            desc="bulk_get_push_rules_enabled",
+        )
+        for row in rows:
+            enabled = bool(row["enabled"])
+            results.setdefault(row["user_name"], {})[row["rule_id"]] = enabled
+        return results
+
+
+class PushRuleStore(PushRulesWorkerStore):
+    @defer.inlineCallbacks
+    def add_push_rule(
+        self,
+        user_id,
+        rule_id,
+        priority_class,
+        conditions,
+        actions,
+        before=None,
+        after=None,
+    ):
+        conditions_json = json.dumps(conditions)
+        actions_json = json.dumps(actions)
+        with self._push_rules_stream_id_gen.get_next() as ids:
+            stream_id, event_stream_ordering = ids
+            if before or after:
+                yield self.runInteraction(
+                    "_add_push_rule_relative_txn",
+                    self._add_push_rule_relative_txn,
+                    stream_id,
+                    event_stream_ordering,
+                    user_id,
+                    rule_id,
+                    priority_class,
+                    conditions_json,
+                    actions_json,
+                    before,
+                    after,
+                )
+            else:
+                yield self.runInteraction(
+                    "_add_push_rule_highest_priority_txn",
+                    self._add_push_rule_highest_priority_txn,
+                    stream_id,
+                    event_stream_ordering,
+                    user_id,
+                    rule_id,
+                    priority_class,
+                    conditions_json,
+                    actions_json,
+                )
+
+    def _add_push_rule_relative_txn(
+        self,
+        txn,
+        stream_id,
+        event_stream_ordering,
+        user_id,
+        rule_id,
+        priority_class,
+        conditions_json,
+        actions_json,
+        before,
+        after,
+    ):
+        # Lock the table since otherwise we'll have annoying races between the
+        # SELECT here and the UPSERT below.
+        self.database_engine.lock_table(txn, "push_rules")
+
+        relative_to_rule = before or after
+
+        res = self._simple_select_one_txn(
+            txn,
+            table="push_rules",
+            keyvalues={"user_name": user_id, "rule_id": relative_to_rule},
+            retcols=["priority_class", "priority"],
+            allow_none=True,
+        )
+
+        if not res:
+            raise RuleNotFoundException(
+                "before/after rule not found: %s" % (relative_to_rule,)
+            )
+
+        base_priority_class = res["priority_class"]
+        base_rule_priority = res["priority"]
+
+        if base_priority_class != priority_class:
+            raise InconsistentRuleException(
+                "Given priority class does not match class of relative rule"
+            )
+
+        if before:
+            # Higher priority rules are executed first, So adding a rule before
+            # a rule means giving it a higher priority than that rule.
+            new_rule_priority = base_rule_priority + 1
+        else:
+            # We increment the priority of the existing rules to make space for
+            # the new rule. Therefore if we want this rule to appear after
+            # an existing rule we give it the priority of the existing rule,
+            # and then increment the priority of the existing rule.
+            new_rule_priority = base_rule_priority
+
+        sql = (
+            "UPDATE push_rules SET priority = priority + 1"
+            " WHERE user_name = ? AND priority_class = ? AND priority >= ?"
+        )
+
+        txn.execute(sql, (user_id, priority_class, new_rule_priority))
+
+        self._upsert_push_rule_txn(
+            txn,
+            stream_id,
+            event_stream_ordering,
+            user_id,
+            rule_id,
+            priority_class,
+            new_rule_priority,
+            conditions_json,
+            actions_json,
+        )
+
+    def _add_push_rule_highest_priority_txn(
+        self,
+        txn,
+        stream_id,
+        event_stream_ordering,
+        user_id,
+        rule_id,
+        priority_class,
+        conditions_json,
+        actions_json,
+    ):
+        # Lock the table since otherwise we'll have annoying races between the
+        # SELECT here and the UPSERT below.
+        self.database_engine.lock_table(txn, "push_rules")
+
+        # find the highest priority rule in that class
+        sql = (
+            "SELECT COUNT(*), MAX(priority) FROM push_rules"
+            " WHERE user_name = ? and priority_class = ?"
+        )
+        txn.execute(sql, (user_id, priority_class))
+        res = txn.fetchall()
+        (how_many, highest_prio) = res[0]
+
+        new_prio = 0
+        if how_many > 0:
+            new_prio = highest_prio + 1
+
+        self._upsert_push_rule_txn(
+            txn,
+            stream_id,
+            event_stream_ordering,
+            user_id,
+            rule_id,
+            priority_class,
+            new_prio,
+            conditions_json,
+            actions_json,
+        )
+
+    def _upsert_push_rule_txn(
+        self,
+        txn,
+        stream_id,
+        event_stream_ordering,
+        user_id,
+        rule_id,
+        priority_class,
+        priority,
+        conditions_json,
+        actions_json,
+        update_stream=True,
+    ):
+        """Specialised version of _simple_upsert_txn that picks a push_rule_id
+        using the _push_rule_id_gen if it needs to insert the rule. It assumes
+        that the "push_rules" table is locked"""
+
+        sql = (
+            "UPDATE push_rules"
+            " SET priority_class = ?, priority = ?, conditions = ?, actions = ?"
+            " WHERE user_name = ? AND rule_id = ?"
+        )
+
+        txn.execute(
+            sql,
+            (priority_class, priority, conditions_json, actions_json, user_id, rule_id),
+        )
+
+        if txn.rowcount == 0:
+            # We didn't update a row with the given rule_id so insert one
+            push_rule_id = self._push_rule_id_gen.get_next()
+
+            self._simple_insert_txn(
+                txn,
+                table="push_rules",
+                values={
+                    "id": push_rule_id,
+                    "user_name": user_id,
+                    "rule_id": rule_id,
+                    "priority_class": priority_class,
+                    "priority": priority,
+                    "conditions": conditions_json,
+                    "actions": actions_json,
+                },
+            )
+
+        if update_stream:
+            self._insert_push_rules_update_txn(
+                txn,
+                stream_id,
+                event_stream_ordering,
+                user_id,
+                rule_id,
+                op="ADD",
+                data={
+                    "priority_class": priority_class,
+                    "priority": priority,
+                    "conditions": conditions_json,
+                    "actions": actions_json,
+                },
+            )
+
+    @defer.inlineCallbacks
+    def delete_push_rule(self, user_id, rule_id):
+        """
+        Delete a push rule. Args specify the row to be deleted and can be
+        any of the columns in the push_rule table, but below are the
+        standard ones
+
+        Args:
+            user_id (str): The matrix ID of the push rule owner
+            rule_id (str): The rule_id of the rule to be deleted
+        """
+
+        def delete_push_rule_txn(txn, stream_id, event_stream_ordering):
+            self._simple_delete_one_txn(
+                txn, "push_rules", {"user_name": user_id, "rule_id": rule_id}
+            )
+
+            self._insert_push_rules_update_txn(
+                txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE"
+            )
+
+        with self._push_rules_stream_id_gen.get_next() as ids:
+            stream_id, event_stream_ordering = ids
+            yield self.runInteraction(
+                "delete_push_rule",
+                delete_push_rule_txn,
+                stream_id,
+                event_stream_ordering,
+            )
+
+    @defer.inlineCallbacks
+    def set_push_rule_enabled(self, user_id, rule_id, enabled):
+        with self._push_rules_stream_id_gen.get_next() as ids:
+            stream_id, event_stream_ordering = ids
+            yield self.runInteraction(
+                "_set_push_rule_enabled_txn",
+                self._set_push_rule_enabled_txn,
+                stream_id,
+                event_stream_ordering,
+                user_id,
+                rule_id,
+                enabled,
+            )
+
+    def _set_push_rule_enabled_txn(
+        self, txn, stream_id, event_stream_ordering, user_id, rule_id, enabled
+    ):
+        new_id = self._push_rules_enable_id_gen.get_next()
+        self._simple_upsert_txn(
+            txn,
+            "push_rules_enable",
+            {"user_name": user_id, "rule_id": rule_id},
+            {"enabled": 1 if enabled else 0},
+            {"id": new_id},
+        )
+
+        self._insert_push_rules_update_txn(
+            txn,
+            stream_id,
+            event_stream_ordering,
+            user_id,
+            rule_id,
+            op="ENABLE" if enabled else "DISABLE",
+        )
+
+    @defer.inlineCallbacks
+    def set_push_rule_actions(self, user_id, rule_id, actions, is_default_rule):
+        actions_json = json.dumps(actions)
+
+        def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering):
+            if is_default_rule:
+                # Add a dummy rule to the rules table with the user specified
+                # actions.
+                priority_class = -1
+                priority = 1
+                self._upsert_push_rule_txn(
+                    txn,
+                    stream_id,
+                    event_stream_ordering,
+                    user_id,
+                    rule_id,
+                    priority_class,
+                    priority,
+                    "[]",
+                    actions_json,
+                    update_stream=False,
+                )
+            else:
+                self._simple_update_one_txn(
+                    txn,
+                    "push_rules",
+                    {"user_name": user_id, "rule_id": rule_id},
+                    {"actions": actions_json},
+                )
+
+            self._insert_push_rules_update_txn(
+                txn,
+                stream_id,
+                event_stream_ordering,
+                user_id,
+                rule_id,
+                op="ACTIONS",
+                data={"actions": actions_json},
+            )
+
+        with self._push_rules_stream_id_gen.get_next() as ids:
+            stream_id, event_stream_ordering = ids
+            yield self.runInteraction(
+                "set_push_rule_actions",
+                set_push_rule_actions_txn,
+                stream_id,
+                event_stream_ordering,
+            )
+
+    def _insert_push_rules_update_txn(
+        self, txn, stream_id, event_stream_ordering, user_id, rule_id, op, data=None
+    ):
+        values = {
+            "stream_id": stream_id,
+            "event_stream_ordering": event_stream_ordering,
+            "user_id": user_id,
+            "rule_id": rule_id,
+            "op": op,
+        }
+        if data is not None:
+            values.update(data)
+
+        self._simple_insert_txn(txn, "push_rules_stream", values=values)
+
+        txn.call_after(self.get_push_rules_for_user.invalidate, (user_id,))
+        txn.call_after(self.get_push_rules_enabled_for_user.invalidate, (user_id,))
+        txn.call_after(
+            self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
+        )
+
+    def get_all_push_rule_updates(self, last_id, current_id, limit):
+        """Get all the push rules changes that have happend on the server"""
+        if last_id == current_id:
+            return defer.succeed([])
+
+        def get_all_push_rule_updates_txn(txn):
+            sql = (
+                "SELECT stream_id, event_stream_ordering, user_id, rule_id,"
+                " op, priority_class, priority, conditions, actions"
+                " FROM push_rules_stream"
+                " WHERE ? < stream_id AND stream_id <= ?"
+                " ORDER BY stream_id ASC LIMIT ?"
+            )
+            txn.execute(sql, (last_id, current_id, limit))
+            return txn.fetchall()
+
+        return self.runInteraction(
+            "get_all_push_rule_updates", get_all_push_rule_updates_txn
+        )
+
+    def get_push_rules_stream_token(self):
+        """Get the position of the push rules stream.
+        Returns a pair of a stream id for the push_rules stream and the
+        room stream ordering it corresponds to."""
+        return self._push_rules_stream_id_gen.get_current_token()
+
+    def get_max_push_rules_stream_id(self):
+        return self.get_push_rules_stream_token()[0]
diff --git a/synapse/storage/pusher.py b/synapse/storage/data_stores/main/pusher.py
similarity index 99%
rename from synapse/storage/pusher.py
rename to synapse/storage/data_stores/main/pusher.py
index b12e80440a..f005c1ae0a 100644
--- a/synapse/storage/pusher.py
+++ b/synapse/storage/data_stores/main/pusher.py
@@ -22,10 +22,9 @@ from canonicaljson import encode_canonical_json, json
 
 from twisted.internet import defer
 
+from synapse.storage._base import SQLBaseStore
 from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
 
-from ._base import SQLBaseStore
-
 logger = logging.getLogger(__name__)
 
 if six.PY2:
diff --git a/synapse/storage/receipts.py b/synapse/storage/data_stores/main/receipts.py
similarity index 100%
rename from synapse/storage/receipts.py
rename to synapse/storage/data_stores/main/receipts.py
diff --git a/synapse/storage/registration.py b/synapse/storage/data_stores/main/registration.py
similarity index 100%
rename from synapse/storage/registration.py
rename to synapse/storage/data_stores/main/registration.py
diff --git a/synapse/storage/rejections.py b/synapse/storage/data_stores/main/rejections.py
similarity index 96%
rename from synapse/storage/rejections.py
rename to synapse/storage/data_stores/main/rejections.py
index f4c1c2a457..7d5de0ea2e 100644
--- a/synapse/storage/rejections.py
+++ b/synapse/storage/data_stores/main/rejections.py
@@ -15,7 +15,7 @@
 
 import logging
 
-from ._base import SQLBaseStore
+from synapse.storage._base import SQLBaseStore
 
 logger = logging.getLogger(__name__)
 
diff --git a/synapse/storage/data_stores/main/relations.py b/synapse/storage/data_stores/main/relations.py
new file mode 100644
index 0000000000..858f65582b
--- /dev/null
+++ b/synapse/storage/data_stores/main/relations.py
@@ -0,0 +1,385 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+import attr
+
+from synapse.api.constants import RelationTypes
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.data_stores.main.stream import generate_pagination_where_clause
+from synapse.storage.relations import (
+    AggregationPaginationToken,
+    PaginationChunk,
+    RelationPaginationToken,
+)
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
+
+logger = logging.getLogger(__name__)
+
+
+class RelationsWorkerStore(SQLBaseStore):
+    @cached(tree=True)
+    def get_relations_for_event(
+        self,
+        event_id,
+        relation_type=None,
+        event_type=None,
+        aggregation_key=None,
+        limit=5,
+        direction="b",
+        from_token=None,
+        to_token=None,
+    ):
+        """Get a list of relations for an event, ordered by topological ordering.
+
+        Args:
+            event_id (str): Fetch events that relate to this event ID.
+            relation_type (str|None): Only fetch events with this relation
+                type, if given.
+            event_type (str|None): Only fetch events with this event type, if
+                given.
+            aggregation_key (str|None): Only fetch events with this aggregation
+                key, if given.
+            limit (int): Only fetch the most recent `limit` events.
+            direction (str): Whether to fetch the most recent first (`"b"`) or
+                the oldest first (`"f"`).
+            from_token (RelationPaginationToken|None): Fetch rows from the given
+                token, or from the start if None.
+            to_token (RelationPaginationToken|None): Fetch rows up to the given
+                token, or up to the end if None.
+
+        Returns:
+            Deferred[PaginationChunk]: List of event IDs that match relations
+            requested. The rows are of the form `{"event_id": "..."}`.
+        """
+
+        where_clause = ["relates_to_id = ?"]
+        where_args = [event_id]
+
+        if relation_type is not None:
+            where_clause.append("relation_type = ?")
+            where_args.append(relation_type)
+
+        if event_type is not None:
+            where_clause.append("type = ?")
+            where_args.append(event_type)
+
+        if aggregation_key:
+            where_clause.append("aggregation_key = ?")
+            where_args.append(aggregation_key)
+
+        pagination_clause = generate_pagination_where_clause(
+            direction=direction,
+            column_names=("topological_ordering", "stream_ordering"),
+            from_token=attr.astuple(from_token) if from_token else None,
+            to_token=attr.astuple(to_token) if to_token else None,
+            engine=self.database_engine,
+        )
+
+        if pagination_clause:
+            where_clause.append(pagination_clause)
+
+        if direction == "b":
+            order = "DESC"
+        else:
+            order = "ASC"
+
+        sql = """
+            SELECT event_id, topological_ordering, stream_ordering
+            FROM event_relations
+            INNER JOIN events USING (event_id)
+            WHERE %s
+            ORDER BY topological_ordering %s, stream_ordering %s
+            LIMIT ?
+        """ % (
+            " AND ".join(where_clause),
+            order,
+            order,
+        )
+
+        def _get_recent_references_for_event_txn(txn):
+            txn.execute(sql, where_args + [limit + 1])
+
+            last_topo_id = None
+            last_stream_id = None
+            events = []
+            for row in txn:
+                events.append({"event_id": row[0]})
+                last_topo_id = row[1]
+                last_stream_id = row[2]
+
+            next_batch = None
+            if len(events) > limit and last_topo_id and last_stream_id:
+                next_batch = RelationPaginationToken(last_topo_id, last_stream_id)
+
+            return PaginationChunk(
+                chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
+            )
+
+        return self.runInteraction(
+            "get_recent_references_for_event", _get_recent_references_for_event_txn
+        )
+
+    @cached(tree=True)
+    def get_aggregation_groups_for_event(
+        self,
+        event_id,
+        event_type=None,
+        limit=5,
+        direction="b",
+        from_token=None,
+        to_token=None,
+    ):
+        """Get a list of annotations on the event, grouped by event type and
+        aggregation key, sorted by count.
+
+        This is used e.g. to get the what and how many reactions have happend
+        on an event.
+
+        Args:
+            event_id (str): Fetch events that relate to this event ID.
+            event_type (str|None): Only fetch events with this event type, if
+                given.
+            limit (int): Only fetch the `limit` groups.
+            direction (str): Whether to fetch the highest count first (`"b"`) or
+                the lowest count first (`"f"`).
+            from_token (AggregationPaginationToken|None): Fetch rows from the
+                given token, or from the start if None.
+            to_token (AggregationPaginationToken|None): Fetch rows up to the
+                given token, or up to the end if None.
+
+
+        Returns:
+            Deferred[PaginationChunk]: List of groups of annotations that
+            match. Each row is a dict with `type`, `key` and `count` fields.
+        """
+
+        where_clause = ["relates_to_id = ?", "relation_type = ?"]
+        where_args = [event_id, RelationTypes.ANNOTATION]
+
+        if event_type:
+            where_clause.append("type = ?")
+            where_args.append(event_type)
+
+        having_clause = generate_pagination_where_clause(
+            direction=direction,
+            column_names=("COUNT(*)", "MAX(stream_ordering)"),
+            from_token=attr.astuple(from_token) if from_token else None,
+            to_token=attr.astuple(to_token) if to_token else None,
+            engine=self.database_engine,
+        )
+
+        if direction == "b":
+            order = "DESC"
+        else:
+            order = "ASC"
+
+        if having_clause:
+            having_clause = "HAVING " + having_clause
+        else:
+            having_clause = ""
+
+        sql = """
+            SELECT type, aggregation_key, COUNT(DISTINCT sender), MAX(stream_ordering)
+            FROM event_relations
+            INNER JOIN events USING (event_id)
+            WHERE {where_clause}
+            GROUP BY relation_type, type, aggregation_key
+            {having_clause}
+            ORDER BY COUNT(*) {order}, MAX(stream_ordering) {order}
+            LIMIT ?
+        """.format(
+            where_clause=" AND ".join(where_clause),
+            order=order,
+            having_clause=having_clause,
+        )
+
+        def _get_aggregation_groups_for_event_txn(txn):
+            txn.execute(sql, where_args + [limit + 1])
+
+            next_batch = None
+            events = []
+            for row in txn:
+                events.append({"type": row[0], "key": row[1], "count": row[2]})
+                next_batch = AggregationPaginationToken(row[2], row[3])
+
+            if len(events) <= limit:
+                next_batch = None
+
+            return PaginationChunk(
+                chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
+            )
+
+        return self.runInteraction(
+            "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
+        )
+
+    @cachedInlineCallbacks()
+    def get_applicable_edit(self, event_id):
+        """Get the most recent edit (if any) that has happened for the given
+        event.
+
+        Correctly handles checking whether edits were allowed to happen.
+
+        Args:
+            event_id (str): The original event ID
+
+        Returns:
+            Deferred[EventBase|None]: Returns the most recent edit, if any.
+        """
+
+        # We only allow edits for `m.room.message` events that have the same sender
+        # and event type. We can't assert these things during regular event auth so
+        # we have to do the checks post hoc.
+
+        # Fetches latest edit that has the same type and sender as the
+        # original, and is an `m.room.message`.
+        sql = """
+            SELECT edit.event_id FROM events AS edit
+            INNER JOIN event_relations USING (event_id)
+            INNER JOIN events AS original ON
+                original.event_id = relates_to_id
+                AND edit.type = original.type
+                AND edit.sender = original.sender
+            WHERE
+                relates_to_id = ?
+                AND relation_type = ?
+                AND edit.type = 'm.room.message'
+            ORDER by edit.origin_server_ts DESC, edit.event_id DESC
+            LIMIT 1
+        """
+
+        def _get_applicable_edit_txn(txn):
+            txn.execute(sql, (event_id, RelationTypes.REPLACE))
+            row = txn.fetchone()
+            if row:
+                return row[0]
+
+        edit_id = yield self.runInteraction(
+            "get_applicable_edit", _get_applicable_edit_txn
+        )
+
+        if not edit_id:
+            return
+
+        edit_event = yield self.get_event(edit_id, allow_none=True)
+        return edit_event
+
+    def has_user_annotated_event(self, parent_id, event_type, aggregation_key, sender):
+        """Check if a user has already annotated an event with the same key
+        (e.g. already liked an event).
+
+        Args:
+            parent_id (str): The event being annotated
+            event_type (str): The event type of the annotation
+            aggregation_key (str): The aggregation key of the annotation
+            sender (str): The sender of the annotation
+
+        Returns:
+            Deferred[bool]
+        """
+
+        sql = """
+            SELECT 1 FROM event_relations
+            INNER JOIN events USING (event_id)
+            WHERE
+                relates_to_id = ?
+                AND relation_type = ?
+                AND type = ?
+                AND sender = ?
+                AND aggregation_key = ?
+            LIMIT 1;
+        """
+
+        def _get_if_user_has_annotated_event(txn):
+            txn.execute(
+                sql,
+                (
+                    parent_id,
+                    RelationTypes.ANNOTATION,
+                    event_type,
+                    sender,
+                    aggregation_key,
+                ),
+            )
+
+            return bool(txn.fetchone())
+
+        return self.runInteraction(
+            "get_if_user_has_annotated_event", _get_if_user_has_annotated_event
+        )
+
+
+class RelationsStore(RelationsWorkerStore):
+    def _handle_event_relations(self, txn, event):
+        """Handles inserting relation data during peristence of events
+
+        Args:
+            txn
+            event (EventBase)
+        """
+        relation = event.content.get("m.relates_to")
+        if not relation:
+            # No relations
+            return
+
+        rel_type = relation.get("rel_type")
+        if rel_type not in (
+            RelationTypes.ANNOTATION,
+            RelationTypes.REFERENCE,
+            RelationTypes.REPLACE,
+        ):
+            # Unknown relation type
+            return
+
+        parent_id = relation.get("event_id")
+        if not parent_id:
+            # Invalid relation
+            return
+
+        aggregation_key = relation.get("key")
+
+        self._simple_insert_txn(
+            txn,
+            table="event_relations",
+            values={
+                "event_id": event.event_id,
+                "relates_to_id": parent_id,
+                "relation_type": rel_type,
+                "aggregation_key": aggregation_key,
+            },
+        )
+
+        txn.call_after(self.get_relations_for_event.invalidate_many, (parent_id,))
+        txn.call_after(
+            self.get_aggregation_groups_for_event.invalidate_many, (parent_id,)
+        )
+
+        if rel_type == RelationTypes.REPLACE:
+            txn.call_after(self.get_applicable_edit.invalidate, (parent_id,))
+
+    def _handle_redaction(self, txn, redacted_event_id):
+        """Handles receiving a redaction and checking whether we need to remove
+        any redacted relations from the database.
+
+        Args:
+            txn
+            redacted_event_id (str): The event that was redacted.
+        """
+
+        self._simple_delete_txn(
+            txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
+        )
diff --git a/synapse/storage/room.py b/synapse/storage/data_stores/main/room.py
similarity index 99%
rename from synapse/storage/room.py
rename to synapse/storage/data_stores/main/room.py
index 43cc56fa6f..4428e5c55d 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/data_stores/main/room.py
@@ -25,7 +25,7 @@ from twisted.internet import defer
 
 from synapse.api.errors import StoreError
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.search import SearchStore
+from synapse.storage.data_stores.main.search import SearchStore
 from synapse.types import ThirdPartyInstanceID
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
 
diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py
new file mode 100644
index 0000000000..e47ab604dd
--- /dev/null
+++ b/synapse/storage/data_stores/main/roommember.py
@@ -0,0 +1,1145 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from six import iteritems, itervalues
+
+from canonicaljson import json
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, Membership
+from synapse.metrics import LaterGauge
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage._base import LoggingTransaction, make_in_list_sql_clause
+from synapse.storage.background_updates import BackgroundUpdateStore
+from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
+from synapse.storage.engines import Sqlite3Engine
+from synapse.storage.roommember import (
+    GetRoomsForUserWithStreamOrdering,
+    MemberSummary,
+    ProfileInfo,
+    RoomsForUser,
+)
+from synapse.types import get_domain_from_id
+from synapse.util.async_helpers import Linearizer
+from synapse.util.caches import intern_string
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
+from synapse.util.metrics import Measure
+from synapse.util.stringutils import to_ascii
+
+logger = logging.getLogger(__name__)
+
+
+_MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update"
+_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership"
+
+
+class RoomMemberWorkerStore(EventsWorkerStore):
+    def __init__(self, db_conn, hs):
+        super(RoomMemberWorkerStore, self).__init__(db_conn, hs)
+
+        # Is the current_state_events.membership up to date? Or is the
+        # background update still running?
+        self._current_state_events_membership_up_to_date = False
+
+        txn = LoggingTransaction(
+            db_conn.cursor(),
+            name="_check_safe_current_state_events_membership_updated",
+            database_engine=self.database_engine,
+        )
+        self._check_safe_current_state_events_membership_updated_txn(txn)
+        txn.close()
+
+        if self.hs.config.metrics_flags.known_servers:
+            self._known_servers_count = 1
+            self.hs.get_clock().looping_call(
+                run_as_background_process,
+                60 * 1000,
+                "_count_known_servers",
+                self._count_known_servers,
+            )
+            self.hs.get_clock().call_later(
+                1000,
+                run_as_background_process,
+                "_count_known_servers",
+                self._count_known_servers,
+            )
+            LaterGauge(
+                "synapse_federation_known_servers",
+                "",
+                [],
+                lambda: self._known_servers_count,
+            )
+
+    @defer.inlineCallbacks
+    def _count_known_servers(self):
+        """
+        Count the servers that this server knows about.
+
+        The statistic is stored on the class for the
+        `synapse_federation_known_servers` LaterGauge to collect.
+        """
+
+        def _transact(txn):
+            if isinstance(self.database_engine, Sqlite3Engine):
+                query = """
+                    SELECT COUNT(DISTINCT substr(out.user_id, pos+1))
+                    FROM (
+                        SELECT rm.user_id as user_id, instr(rm.user_id, ':')
+                            AS pos FROM room_memberships as rm
+                        INNER JOIN current_state_events as c ON rm.event_id = c.event_id
+                        WHERE c.type = 'm.room.member'
+                    ) as out
+                """
+            else:
+                query = """
+                    SELECT COUNT(DISTINCT split_part(state_key, ':', 2))
+                    FROM current_state_events
+                    WHERE type = 'm.room.member' AND membership = 'join';
+                """
+            txn.execute(query)
+            return list(txn)[0][0]
+
+        count = yield self.runInteraction("get_known_servers", _transact)
+
+        # We always know about ourselves, even if we have nothing in
+        # room_memberships (for example, the server is new).
+        self._known_servers_count = max([count, 1])
+        return self._known_servers_count
+
+    def _check_safe_current_state_events_membership_updated_txn(self, txn):
+        """Checks if it is safe to assume the new current_state_events
+        membership column is up to date
+        """
+
+        pending_update = self._simple_select_one_txn(
+            txn,
+            table="background_updates",
+            keyvalues={"update_name": _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME},
+            retcols=["update_name"],
+            allow_none=True,
+        )
+
+        self._current_state_events_membership_up_to_date = not pending_update
+
+        # If the update is still running, reschedule to run.
+        if pending_update:
+            self._clock.call_later(
+                15.0,
+                run_as_background_process,
+                "_check_safe_current_state_events_membership_updated",
+                self.runInteraction,
+                "_check_safe_current_state_events_membership_updated",
+                self._check_safe_current_state_events_membership_updated_txn,
+            )
+
+    @cachedInlineCallbacks(max_entries=100000, iterable=True, cache_context=True)
+    def get_hosts_in_room(self, room_id, cache_context):
+        """Returns the set of all hosts currently in the room
+        """
+        user_ids = yield self.get_users_in_room(
+            room_id, on_invalidate=cache_context.invalidate
+        )
+        hosts = frozenset(get_domain_from_id(user_id) for user_id in user_ids)
+        return hosts
+
+    @cached(max_entries=100000, iterable=True)
+    def get_users_in_room(self, room_id):
+        return self.runInteraction(
+            "get_users_in_room", self.get_users_in_room_txn, room_id
+        )
+
+    def get_users_in_room_txn(self, txn, room_id):
+        # If we can assume current_state_events.membership is up to date
+        # then we can avoid a join, which is a Very Good Thing given how
+        # frequently this function gets called.
+        if self._current_state_events_membership_up_to_date:
+            sql = """
+                SELECT state_key FROM current_state_events
+                WHERE type = 'm.room.member' AND room_id = ? AND membership = ?
+            """
+        else:
+            sql = """
+                SELECT state_key FROM room_memberships as m
+                INNER JOIN current_state_events as c
+                ON m.event_id = c.event_id
+                AND m.room_id = c.room_id
+                AND m.user_id = c.state_key
+                WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ?
+            """
+
+        txn.execute(sql, (room_id, Membership.JOIN))
+        return [to_ascii(r[0]) for r in txn]
+
+    @cached(max_entries=100000)
+    def get_room_summary(self, room_id):
+        """ Get the details of a room roughly suitable for use by the room
+        summary extension to /sync. Useful when lazy loading room members.
+        Args:
+            room_id (str): The room ID to query
+        Returns:
+            Deferred[dict[str, MemberSummary]:
+                dict of membership states, pointing to a MemberSummary named tuple.
+        """
+
+        def _get_room_summary_txn(txn):
+            # first get counts.
+            # We do this all in one transaction to keep the cache small.
+            # FIXME: get rid of this when we have room_stats
+
+            # If we can assume current_state_events.membership is up to date
+            # then we can avoid a join, which is a Very Good Thing given how
+            # frequently this function gets called.
+            if self._current_state_events_membership_up_to_date:
+                # Note, rejected events will have a null membership field, so
+                # we we manually filter them out.
+                sql = """
+                    SELECT count(*), membership FROM current_state_events
+                    WHERE type = 'm.room.member' AND room_id = ?
+                        AND membership IS NOT NULL
+                    GROUP BY membership
+                """
+            else:
+                sql = """
+                    SELECT count(*), m.membership FROM room_memberships as m
+                    INNER JOIN current_state_events as c
+                    ON m.event_id = c.event_id
+                    AND m.room_id = c.room_id
+                    AND m.user_id = c.state_key
+                    WHERE c.type = 'm.room.member' AND c.room_id = ?
+                    GROUP BY m.membership
+                """
+
+            txn.execute(sql, (room_id,))
+            res = {}
+            for count, membership in txn:
+                summary = res.setdefault(to_ascii(membership), MemberSummary([], count))
+
+            # we order by membership and then fairly arbitrarily by event_id so
+            # heroes are consistent
+            if self._current_state_events_membership_up_to_date:
+                # Note, rejected events will have a null membership field, so
+                # we we manually filter them out.
+                sql = """
+                    SELECT state_key, membership, event_id
+                    FROM current_state_events
+                    WHERE type = 'm.room.member' AND room_id = ?
+                        AND membership IS NOT NULL
+                    ORDER BY
+                        CASE membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC,
+                        event_id ASC
+                    LIMIT ?
+                """
+            else:
+                sql = """
+                    SELECT c.state_key, m.membership, c.event_id
+                    FROM room_memberships as m
+                    INNER JOIN current_state_events as c USING (room_id, event_id)
+                    WHERE c.type = 'm.room.member' AND c.room_id = ?
+                    ORDER BY
+                        CASE m.membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC,
+                        c.event_id ASC
+                    LIMIT ?
+                """
+
+            # 6 is 5 (number of heroes) plus 1, in case one of them is the calling user.
+            txn.execute(sql, (room_id, Membership.JOIN, Membership.INVITE, 6))
+            for user_id, membership, event_id in txn:
+                summary = res[to_ascii(membership)]
+                # we will always have a summary for this membership type at this
+                # point given the summary currently contains the counts.
+                members = summary.members
+                members.append((to_ascii(user_id), to_ascii(event_id)))
+
+            return res
+
+        return self.runInteraction("get_room_summary", _get_room_summary_txn)
+
+    def _get_user_counts_in_room_txn(self, txn, room_id):
+        """
+        Get the user count in a room by membership.
+
+        Args:
+            room_id (str)
+            membership (Membership)
+
+        Returns:
+            Deferred[int]
+        """
+        sql = """
+        SELECT m.membership, count(*) FROM room_memberships as m
+            INNER JOIN current_state_events as c USING(event_id)
+            WHERE c.type = 'm.room.member' AND c.room_id = ?
+            GROUP BY m.membership
+        """
+
+        txn.execute(sql, (room_id,))
+        return {row[0]: row[1] for row in txn}
+
+    @cached()
+    def get_invited_rooms_for_user(self, user_id):
+        """ Get all the rooms the user is invited to
+        Args:
+            user_id (str): The user ID.
+        Returns:
+            A deferred list of RoomsForUser.
+        """
+
+        return self.get_rooms_for_user_where_membership_is(user_id, [Membership.INVITE])
+
+    @defer.inlineCallbacks
+    def get_invite_for_user_in_room(self, user_id, room_id):
+        """Gets the invite for the given user and room
+
+        Args:
+            user_id (str)
+            room_id (str)
+
+        Returns:
+            Deferred: Resolves to either a RoomsForUser or None if no invite was
+                found.
+        """
+        invites = yield self.get_invited_rooms_for_user(user_id)
+        for invite in invites:
+            if invite.room_id == room_id:
+                return invite
+        return None
+
+    @defer.inlineCallbacks
+    def get_rooms_for_user_where_membership_is(self, user_id, membership_list):
+        """ Get all the rooms for this user where the membership for this user
+        matches one in the membership list.
+
+        Filters out forgotten rooms.
+
+        Args:
+            user_id (str): The user ID.
+            membership_list (list): A list of synapse.api.constants.Membership
+            values which the user must be in.
+
+        Returns:
+            Deferred[list[RoomsForUser]]
+        """
+        if not membership_list:
+            return defer.succeed(None)
+
+        rooms = yield self.runInteraction(
+            "get_rooms_for_user_where_membership_is",
+            self._get_rooms_for_user_where_membership_is_txn,
+            user_id,
+            membership_list,
+        )
+
+        # Now we filter out forgotten rooms
+        forgotten_rooms = yield self.get_forgotten_rooms_for_user(user_id)
+        return [room for room in rooms if room.room_id not in forgotten_rooms]
+
+    def _get_rooms_for_user_where_membership_is_txn(
+        self, txn, user_id, membership_list
+    ):
+
+        do_invite = Membership.INVITE in membership_list
+        membership_list = [m for m in membership_list if m != Membership.INVITE]
+
+        results = []
+        if membership_list:
+            if self._current_state_events_membership_up_to_date:
+                clause, args = make_in_list_sql_clause(
+                    self.database_engine, "c.membership", membership_list
+                )
+                sql = """
+                    SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering
+                    FROM current_state_events AS c
+                    INNER JOIN events AS e USING (room_id, event_id)
+                    WHERE
+                        c.type = 'm.room.member'
+                        AND state_key = ?
+                        AND %s
+                """ % (
+                    clause,
+                )
+            else:
+                clause, args = make_in_list_sql_clause(
+                    self.database_engine, "m.membership", membership_list
+                )
+                sql = """
+                    SELECT room_id, e.sender, m.membership, event_id, e.stream_ordering
+                    FROM current_state_events AS c
+                    INNER JOIN room_memberships AS m USING (room_id, event_id)
+                    INNER JOIN events AS e USING (room_id, event_id)
+                    WHERE
+                        c.type = 'm.room.member'
+                        AND state_key = ?
+                        AND %s
+                """ % (
+                    clause,
+                )
+
+            txn.execute(sql, (user_id, *args))
+            results = [RoomsForUser(**r) for r in self.cursor_to_dict(txn)]
+
+        if do_invite:
+            sql = (
+                "SELECT i.room_id, inviter, i.event_id, e.stream_ordering"
+                " FROM local_invites as i"
+                " INNER JOIN events as e USING (event_id)"
+                " WHERE invitee = ? AND locally_rejected is NULL"
+                " AND replaced_by is NULL"
+            )
+
+            txn.execute(sql, (user_id,))
+            results.extend(
+                RoomsForUser(
+                    room_id=r["room_id"],
+                    sender=r["inviter"],
+                    event_id=r["event_id"],
+                    stream_ordering=r["stream_ordering"],
+                    membership=Membership.INVITE,
+                )
+                for r in self.cursor_to_dict(txn)
+            )
+
+        return results
+
+    @cachedInlineCallbacks(max_entries=500000, iterable=True)
+    def get_rooms_for_user_with_stream_ordering(self, user_id):
+        """Returns a set of room_ids the user is currently joined to
+
+        Args:
+            user_id (str)
+
+        Returns:
+            Deferred[frozenset[GetRoomsForUserWithStreamOrdering]]: Returns
+            the rooms the user is in currently, along with the stream ordering
+            of the most recent join for that user and room.
+        """
+        rooms = yield self.get_rooms_for_user_where_membership_is(
+            user_id, membership_list=[Membership.JOIN]
+        )
+        return frozenset(
+            GetRoomsForUserWithStreamOrdering(r.room_id, r.stream_ordering)
+            for r in rooms
+        )
+
+    @defer.inlineCallbacks
+    def get_rooms_for_user(self, user_id, on_invalidate=None):
+        """Returns a set of room_ids the user is currently joined to
+        """
+        rooms = yield self.get_rooms_for_user_with_stream_ordering(
+            user_id, on_invalidate=on_invalidate
+        )
+        return frozenset(r.room_id for r in rooms)
+
+    @cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True)
+    def get_users_who_share_room_with_user(self, user_id, cache_context):
+        """Returns the set of users who share a room with `user_id`
+        """
+        room_ids = yield self.get_rooms_for_user(
+            user_id, on_invalidate=cache_context.invalidate
+        )
+
+        user_who_share_room = set()
+        for room_id in room_ids:
+            user_ids = yield self.get_users_in_room(
+                room_id, on_invalidate=cache_context.invalidate
+            )
+            user_who_share_room.update(user_ids)
+
+        return user_who_share_room
+
+    @defer.inlineCallbacks
+    def get_joined_users_from_context(self, event, context):
+        state_group = context.state_group
+        if not state_group:
+            # If state_group is None it means it has yet to be assigned a
+            # state group, i.e. we need to make sure that calls with a state_group
+            # of None don't hit previous cached calls with a None state_group.
+            # To do this we set the state_group to a new object as object() != object()
+            state_group = object()
+
+        current_state_ids = yield context.get_current_state_ids(self)
+        result = yield self._get_joined_users_from_context(
+            event.room_id, state_group, current_state_ids, event=event, context=context
+        )
+        return result
+
+    @defer.inlineCallbacks
+    def get_joined_users_from_state(self, room_id, state_entry):
+        state_group = state_entry.state_group
+        if not state_group:
+            # If state_group is None it means it has yet to be assigned a
+            # state group, i.e. we need to make sure that calls with a state_group
+            # of None don't hit previous cached calls with a None state_group.
+            # To do this we set the state_group to a new object as object() != object()
+            state_group = object()
+
+        with Measure(self._clock, "get_joined_users_from_state"):
+            return (
+                yield self._get_joined_users_from_context(
+                    room_id, state_group, state_entry.state, context=state_entry
+                )
+            )
+
+    @cachedInlineCallbacks(
+        num_args=2, cache_context=True, iterable=True, max_entries=100000
+    )
+    def _get_joined_users_from_context(
+        self,
+        room_id,
+        state_group,
+        current_state_ids,
+        cache_context,
+        event=None,
+        context=None,
+    ):
+        # We don't use `state_group`, it's there so that we can cache based
+        # on it. However, it's important that it's never None, since two current_states
+        # with a state_group of None are likely to be different.
+        # See bulk_get_push_rules_for_room for how we work around this.
+        assert state_group is not None
+
+        users_in_room = {}
+        member_event_ids = [
+            e_id
+            for key, e_id in iteritems(current_state_ids)
+            if key[0] == EventTypes.Member
+        ]
+
+        if context is not None:
+            # If we have a context with a delta from a previous state group,
+            # check if we also have the result from the previous group in cache.
+            # If we do then we can reuse that result and simply update it with
+            # any membership changes in `delta_ids`
+            if context.prev_group and context.delta_ids:
+                prev_res = self._get_joined_users_from_context.cache.get(
+                    (room_id, context.prev_group), None
+                )
+                if prev_res and isinstance(prev_res, dict):
+                    users_in_room = dict(prev_res)
+                    member_event_ids = [
+                        e_id
+                        for key, e_id in iteritems(context.delta_ids)
+                        if key[0] == EventTypes.Member
+                    ]
+                    for etype, state_key in context.delta_ids:
+                        users_in_room.pop(state_key, None)
+
+        # We check if we have any of the member event ids in the event cache
+        # before we ask the DB
+
+        # We don't update the event cache hit ratio as it completely throws off
+        # the hit ratio counts. After all, we don't populate the cache if we
+        # miss it here
+        event_map = self._get_events_from_cache(
+            member_event_ids, allow_rejected=False, update_metrics=False
+        )
+
+        missing_member_event_ids = []
+        for event_id in member_event_ids:
+            ev_entry = event_map.get(event_id)
+            if ev_entry:
+                if ev_entry.event.membership == Membership.JOIN:
+                    users_in_room[to_ascii(ev_entry.event.state_key)] = ProfileInfo(
+                        display_name=to_ascii(
+                            ev_entry.event.content.get("displayname", None)
+                        ),
+                        avatar_url=to_ascii(
+                            ev_entry.event.content.get("avatar_url", None)
+                        ),
+                    )
+            else:
+                missing_member_event_ids.append(event_id)
+
+        if missing_member_event_ids:
+            event_to_memberships = yield self._get_joined_profiles_from_event_ids(
+                missing_member_event_ids
+            )
+            users_in_room.update((row for row in event_to_memberships.values() if row))
+
+        if event is not None and event.type == EventTypes.Member:
+            if event.membership == Membership.JOIN:
+                if event.event_id in member_event_ids:
+                    users_in_room[to_ascii(event.state_key)] = ProfileInfo(
+                        display_name=to_ascii(event.content.get("displayname", None)),
+                        avatar_url=to_ascii(event.content.get("avatar_url", None)),
+                    )
+
+        return users_in_room
+
+    @cached(max_entries=10000)
+    def _get_joined_profile_from_event_id(self, event_id):
+        raise NotImplementedError()
+
+    @cachedList(
+        cached_method_name="_get_joined_profile_from_event_id",
+        list_name="event_ids",
+        inlineCallbacks=True,
+    )
+    def _get_joined_profiles_from_event_ids(self, event_ids):
+        """For given set of member event_ids check if they point to a join
+        event and if so return the associated user and profile info.
+
+        Args:
+            event_ids (Iterable[str]): The member event IDs to lookup
+
+        Returns:
+            Deferred[dict[str, Tuple[str, ProfileInfo]|None]]: Map from event ID
+            to `user_id` and ProfileInfo (or None if not join event).
+        """
+
+        rows = yield self._simple_select_many_batch(
+            table="room_memberships",
+            column="event_id",
+            iterable=event_ids,
+            retcols=("user_id", "display_name", "avatar_url", "event_id"),
+            keyvalues={"membership": Membership.JOIN},
+            batch_size=500,
+            desc="_get_membership_from_event_ids",
+        )
+
+        return {
+            row["event_id"]: (
+                row["user_id"],
+                ProfileInfo(
+                    avatar_url=row["avatar_url"], display_name=row["display_name"]
+                ),
+            )
+            for row in rows
+        }
+
+    @cachedInlineCallbacks(max_entries=10000)
+    def is_host_joined(self, room_id, host):
+        if "%" in host or "_" in host:
+            raise Exception("Invalid host name")
+
+        sql = """
+            SELECT state_key FROM current_state_events AS c
+            INNER JOIN room_memberships AS m USING (event_id)
+            WHERE m.membership = 'join'
+                AND type = 'm.room.member'
+                AND c.room_id = ?
+                AND state_key LIKE ?
+            LIMIT 1
+        """
+
+        # We do need to be careful to ensure that host doesn't have any wild cards
+        # in it, but we checked above for known ones and we'll check below that
+        # the returned user actually has the correct domain.
+        like_clause = "%:" + host
+
+        rows = yield self._execute("is_host_joined", None, sql, room_id, like_clause)
+
+        if not rows:
+            return False
+
+        user_id = rows[0][0]
+        if get_domain_from_id(user_id) != host:
+            # This can only happen if the host name has something funky in it
+            raise Exception("Invalid host name")
+
+        return True
+
+    @cachedInlineCallbacks()
+    def was_host_joined(self, room_id, host):
+        """Check whether the server is or ever was in the room.
+
+        Args:
+            room_id (str)
+            host (str)
+
+        Returns:
+            Deferred: Resolves to True if the host is/was in the room, otherwise
+            False.
+        """
+        if "%" in host or "_" in host:
+            raise Exception("Invalid host name")
+
+        sql = """
+            SELECT user_id FROM room_memberships
+            WHERE room_id = ?
+                AND user_id LIKE ?
+                AND membership = 'join'
+            LIMIT 1
+        """
+
+        # We do need to be careful to ensure that host doesn't have any wild cards
+        # in it, but we checked above for known ones and we'll check below that
+        # the returned user actually has the correct domain.
+        like_clause = "%:" + host
+
+        rows = yield self._execute("was_host_joined", None, sql, room_id, like_clause)
+
+        if not rows:
+            return False
+
+        user_id = rows[0][0]
+        if get_domain_from_id(user_id) != host:
+            # This can only happen if the host name has something funky in it
+            raise Exception("Invalid host name")
+
+        return True
+
+    @defer.inlineCallbacks
+    def get_joined_hosts(self, room_id, state_entry):
+        state_group = state_entry.state_group
+        if not state_group:
+            # If state_group is None it means it has yet to be assigned a
+            # state group, i.e. we need to make sure that calls with a state_group
+            # of None don't hit previous cached calls with a None state_group.
+            # To do this we set the state_group to a new object as object() != object()
+            state_group = object()
+
+        with Measure(self._clock, "get_joined_hosts"):
+            return (
+                yield self._get_joined_hosts(
+                    room_id, state_group, state_entry.state, state_entry=state_entry
+                )
+            )
+
+    @cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True)
+    # @defer.inlineCallbacks
+    def _get_joined_hosts(self, room_id, state_group, current_state_ids, state_entry):
+        # We don't use `state_group`, its there so that we can cache based
+        # on it. However, its important that its never None, since two current_state's
+        # with a state_group of None are likely to be different.
+        # See bulk_get_push_rules_for_room for how we work around this.
+        assert state_group is not None
+
+        cache = self._get_joined_hosts_cache(room_id)
+        joined_hosts = yield cache.get_destinations(state_entry)
+
+        return joined_hosts
+
+    @cached(max_entries=10000)
+    def _get_joined_hosts_cache(self, room_id):
+        return _JoinedHostsCache(self, room_id)
+
+    @cachedInlineCallbacks(num_args=2)
+    def did_forget(self, user_id, room_id):
+        """Returns whether user_id has elected to discard history for room_id.
+
+        Returns False if they have since re-joined."""
+
+        def f(txn):
+            sql = (
+                "SELECT"
+                "  COUNT(*)"
+                " FROM"
+                "  room_memberships"
+                " WHERE"
+                "  user_id = ?"
+                " AND"
+                "  room_id = ?"
+                " AND"
+                "  forgotten = 0"
+            )
+            txn.execute(sql, (user_id, room_id))
+            rows = txn.fetchall()
+            return rows[0][0]
+
+        count = yield self.runInteraction("did_forget_membership", f)
+        return count == 0
+
+    @cached()
+    def get_forgotten_rooms_for_user(self, user_id):
+        """Gets all rooms the user has forgotten.
+
+        Args:
+            user_id (str)
+
+        Returns:
+            Deferred[set[str]]
+        """
+
+        def _get_forgotten_rooms_for_user_txn(txn):
+            # This is a slightly convoluted query that first looks up all rooms
+            # that the user has forgotten in the past, then rechecks that list
+            # to see if any have subsequently been updated. This is done so that
+            # we can use a partial index on `forgotten = 1` on the assumption
+            # that few users will actually forget many rooms.
+            #
+            # Note that a room is considered "forgotten" if *all* membership
+            # events for that user and room have the forgotten field set (as
+            # when a user forgets a room we update all rows for that user and
+            # room, not just the current one).
+            sql = """
+                SELECT room_id, (
+                    SELECT count(*) FROM room_memberships
+                    WHERE room_id = m.room_id AND user_id = m.user_id AND forgotten = 0
+                ) AS count
+                FROM room_memberships AS m
+                WHERE user_id = ? AND forgotten = 1
+                GROUP BY room_id, user_id;
+            """
+            txn.execute(sql, (user_id,))
+            return set(row[0] for row in txn if row[1] == 0)
+
+        return self.runInteraction(
+            "get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn
+        )
+
+    @defer.inlineCallbacks
+    def get_rooms_user_has_been_in(self, user_id):
+        """Get all rooms that the user has ever been in.
+
+        Args:
+            user_id (str)
+
+        Returns:
+            Deferred[set[str]]: Set of room IDs.
+        """
+
+        room_ids = yield self._simple_select_onecol(
+            table="room_memberships",
+            keyvalues={"membership": Membership.JOIN, "user_id": user_id},
+            retcol="room_id",
+            desc="get_rooms_user_has_been_in",
+        )
+
+        return set(room_ids)
+
+
+class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore):
+    def __init__(self, db_conn, hs):
+        super(RoomMemberBackgroundUpdateStore, self).__init__(db_conn, hs)
+        self.register_background_update_handler(
+            _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
+        )
+        self.register_background_update_handler(
+            _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME,
+            self._background_current_state_membership,
+        )
+        self.register_background_index_update(
+            "room_membership_forgotten_idx",
+            index_name="room_memberships_user_room_forgotten",
+            table="room_memberships",
+            columns=["user_id", "room_id"],
+            where_clause="forgotten = 1",
+        )
+
+    @defer.inlineCallbacks
+    def _background_add_membership_profile(self, progress, batch_size):
+        target_min_stream_id = progress.get(
+            "target_min_stream_id_inclusive", self._min_stream_order_on_start
+        )
+        max_stream_id = progress.get(
+            "max_stream_id_exclusive", self._stream_order_on_start + 1
+        )
+
+        INSERT_CLUMP_SIZE = 1000
+
+        def add_membership_profile_txn(txn):
+            sql = """
+                SELECT stream_ordering, event_id, events.room_id, event_json.json
+                FROM events
+                INNER JOIN event_json USING (event_id)
+                INNER JOIN room_memberships USING (event_id)
+                WHERE ? <= stream_ordering AND stream_ordering < ?
+                AND type = 'm.room.member'
+                ORDER BY stream_ordering DESC
+                LIMIT ?
+            """
+
+            txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
+
+            rows = self.cursor_to_dict(txn)
+            if not rows:
+                return 0
+
+            min_stream_id = rows[-1]["stream_ordering"]
+
+            to_update = []
+            for row in rows:
+                event_id = row["event_id"]
+                room_id = row["room_id"]
+                try:
+                    event_json = json.loads(row["json"])
+                    content = event_json["content"]
+                except Exception:
+                    continue
+
+                display_name = content.get("displayname", None)
+                avatar_url = content.get("avatar_url", None)
+
+                if display_name or avatar_url:
+                    to_update.append((display_name, avatar_url, event_id, room_id))
+
+            to_update_sql = """
+                UPDATE room_memberships SET display_name = ?, avatar_url = ?
+                WHERE event_id = ? AND room_id = ?
+            """
+            for index in range(0, len(to_update), INSERT_CLUMP_SIZE):
+                clump = to_update[index : index + INSERT_CLUMP_SIZE]
+                txn.executemany(to_update_sql, clump)
+
+            progress = {
+                "target_min_stream_id_inclusive": target_min_stream_id,
+                "max_stream_id_exclusive": min_stream_id,
+            }
+
+            self._background_update_progress_txn(
+                txn, _MEMBERSHIP_PROFILE_UPDATE_NAME, progress
+            )
+
+            return len(rows)
+
+        result = yield self.runInteraction(
+            _MEMBERSHIP_PROFILE_UPDATE_NAME, add_membership_profile_txn
+        )
+
+        if not result:
+            yield self._end_background_update(_MEMBERSHIP_PROFILE_UPDATE_NAME)
+
+        return result
+
+    @defer.inlineCallbacks
+    def _background_current_state_membership(self, progress, batch_size):
+        """Update the new membership column on current_state_events.
+
+        This works by iterating over all rooms in alphebetical order.
+        """
+
+        def _background_current_state_membership_txn(txn, last_processed_room):
+            processed = 0
+            while processed < batch_size:
+                txn.execute(
+                    """
+                        SELECT MIN(room_id) FROM current_state_events WHERE room_id > ?
+                    """,
+                    (last_processed_room,),
+                )
+                row = txn.fetchone()
+                if not row or not row[0]:
+                    return processed, True
+
+                next_room, = row
+
+                sql = """
+                    UPDATE current_state_events
+                    SET membership = (
+                        SELECT membership FROM room_memberships
+                        WHERE event_id = current_state_events.event_id
+                    )
+                    WHERE room_id = ?
+                """
+                txn.execute(sql, (next_room,))
+                processed += txn.rowcount
+
+                last_processed_room = next_room
+
+            self._background_update_progress_txn(
+                txn,
+                _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME,
+                {"last_processed_room": last_processed_room},
+            )
+
+            return processed, False
+
+        # If we haven't got a last processed room then just use the empty
+        # string, which will compare before all room IDs correctly.
+        last_processed_room = progress.get("last_processed_room", "")
+
+        row_count, finished = yield self.runInteraction(
+            "_background_current_state_membership_update",
+            _background_current_state_membership_txn,
+            last_processed_room,
+        )
+
+        if finished:
+            yield self._end_background_update(_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME)
+
+        return row_count
+
+
+class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
+    def __init__(self, db_conn, hs):
+        super(RoomMemberStore, self).__init__(db_conn, hs)
+
+    def _store_room_members_txn(self, txn, events, backfilled):
+        """Store a room member in the database.
+        """
+        self._simple_insert_many_txn(
+            txn,
+            table="room_memberships",
+            values=[
+                {
+                    "event_id": event.event_id,
+                    "user_id": event.state_key,
+                    "sender": event.user_id,
+                    "room_id": event.room_id,
+                    "membership": event.membership,
+                    "display_name": event.content.get("displayname", None),
+                    "avatar_url": event.content.get("avatar_url", None),
+                }
+                for event in events
+            ],
+        )
+
+        for event in events:
+            txn.call_after(
+                self._membership_stream_cache.entity_has_changed,
+                event.state_key,
+                event.internal_metadata.stream_ordering,
+            )
+            txn.call_after(
+                self.get_invited_rooms_for_user.invalidate, (event.state_key,)
+            )
+
+            # We update the local_invites table only if the event is "current",
+            # i.e., its something that has just happened. If the event is an
+            # outlier it is only current if its an "out of band membership",
+            # like a remote invite or a rejection of a remote invite.
+            is_new_state = not backfilled and (
+                not event.internal_metadata.is_outlier()
+                or event.internal_metadata.is_out_of_band_membership()
+            )
+            is_mine = self.hs.is_mine_id(event.state_key)
+            if is_new_state and is_mine:
+                if event.membership == Membership.INVITE:
+                    self._simple_insert_txn(
+                        txn,
+                        table="local_invites",
+                        values={
+                            "event_id": event.event_id,
+                            "invitee": event.state_key,
+                            "inviter": event.sender,
+                            "room_id": event.room_id,
+                            "stream_id": event.internal_metadata.stream_ordering,
+                        },
+                    )
+                else:
+                    sql = (
+                        "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE"
+                        " room_id = ? AND invitee = ? AND locally_rejected is NULL"
+                        " AND replaced_by is NULL"
+                    )
+
+                    txn.execute(
+                        sql,
+                        (
+                            event.internal_metadata.stream_ordering,
+                            event.event_id,
+                            event.room_id,
+                            event.state_key,
+                        ),
+                    )
+
+    @defer.inlineCallbacks
+    def locally_reject_invite(self, user_id, room_id):
+        sql = (
+            "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
+            " room_id = ? AND invitee = ? AND locally_rejected is NULL"
+            " AND replaced_by is NULL"
+        )
+
+        def f(txn, stream_ordering):
+            txn.execute(sql, (stream_ordering, True, room_id, user_id))
+
+        with self._stream_id_gen.get_next() as stream_ordering:
+            yield self.runInteraction("locally_reject_invite", f, stream_ordering)
+
+    def forget(self, user_id, room_id):
+        """Indicate that user_id wishes to discard history for room_id."""
+
+        def f(txn):
+            sql = (
+                "UPDATE"
+                "  room_memberships"
+                " SET"
+                "  forgotten = 1"
+                " WHERE"
+                "  user_id = ?"
+                " AND"
+                "  room_id = ?"
+            )
+            txn.execute(sql, (user_id, room_id))
+
+            self._invalidate_cache_and_stream(txn, self.did_forget, (user_id, room_id))
+            self._invalidate_cache_and_stream(
+                txn, self.get_forgotten_rooms_for_user, (user_id,)
+            )
+
+        return self.runInteraction("forget_membership", f)
+
+
+class _JoinedHostsCache(object):
+    """Cache for joined hosts in a room that is optimised to handle updates
+    via state deltas.
+    """
+
+    def __init__(self, store, room_id):
+        self.store = store
+        self.room_id = room_id
+
+        self.hosts_to_joined_users = {}
+
+        self.state_group = object()
+
+        self.linearizer = Linearizer("_JoinedHostsCache")
+
+        self._len = 0
+
+    @defer.inlineCallbacks
+    def get_destinations(self, state_entry):
+        """Get set of destinations for a state entry
+
+        Args:
+            state_entry(synapse.state._StateCacheEntry)
+        """
+        if state_entry.state_group == self.state_group:
+            return frozenset(self.hosts_to_joined_users)
+
+        with (yield self.linearizer.queue(())):
+            if state_entry.state_group == self.state_group:
+                pass
+            elif state_entry.prev_group == self.state_group:
+                for (typ, state_key), event_id in iteritems(state_entry.delta_ids):
+                    if typ != EventTypes.Member:
+                        continue
+
+                    host = intern_string(get_domain_from_id(state_key))
+                    user_id = state_key
+                    known_joins = self.hosts_to_joined_users.setdefault(host, set())
+
+                    event = yield self.store.get_event(event_id)
+                    if event.membership == Membership.JOIN:
+                        known_joins.add(user_id)
+                    else:
+                        known_joins.discard(user_id)
+
+                        if not known_joins:
+                            self.hosts_to_joined_users.pop(host, None)
+            else:
+                joined_users = yield self.store.get_joined_users_from_state(
+                    self.room_id, state_entry
+                )
+
+                self.hosts_to_joined_users = {}
+                for user_id in joined_users:
+                    host = intern_string(get_domain_from_id(user_id))
+                    self.hosts_to_joined_users.setdefault(host, set()).add(user_id)
+
+            if state_entry.state_group:
+                self.state_group = state_entry.state_group
+            else:
+                self.state_group = object()
+            self._len = sum(len(v) for v in itervalues(self.hosts_to_joined_users))
+        return frozenset(self.hosts_to_joined_users)
+
+    def __len__(self):
+        return self._len
diff --git a/synapse/storage/schema/delta/12/v12.sql b/synapse/storage/data_stores/main/schema/delta/12/v12.sql
similarity index 100%
rename from synapse/storage/schema/delta/12/v12.sql
rename to synapse/storage/data_stores/main/schema/delta/12/v12.sql
diff --git a/synapse/storage/schema/delta/13/v13.sql b/synapse/storage/data_stores/main/schema/delta/13/v13.sql
similarity index 100%
rename from synapse/storage/schema/delta/13/v13.sql
rename to synapse/storage/data_stores/main/schema/delta/13/v13.sql
diff --git a/synapse/storage/schema/delta/14/v14.sql b/synapse/storage/data_stores/main/schema/delta/14/v14.sql
similarity index 100%
rename from synapse/storage/schema/delta/14/v14.sql
rename to synapse/storage/data_stores/main/schema/delta/14/v14.sql
diff --git a/synapse/storage/schema/delta/15/appservice_txns.sql b/synapse/storage/data_stores/main/schema/delta/15/appservice_txns.sql
similarity index 100%
rename from synapse/storage/schema/delta/15/appservice_txns.sql
rename to synapse/storage/data_stores/main/schema/delta/15/appservice_txns.sql
diff --git a/synapse/storage/schema/delta/15/presence_indices.sql b/synapse/storage/data_stores/main/schema/delta/15/presence_indices.sql
similarity index 100%
rename from synapse/storage/schema/delta/15/presence_indices.sql
rename to synapse/storage/data_stores/main/schema/delta/15/presence_indices.sql
diff --git a/synapse/storage/schema/delta/15/v15.sql b/synapse/storage/data_stores/main/schema/delta/15/v15.sql
similarity index 100%
rename from synapse/storage/schema/delta/15/v15.sql
rename to synapse/storage/data_stores/main/schema/delta/15/v15.sql
diff --git a/synapse/storage/schema/delta/16/events_order_index.sql b/synapse/storage/data_stores/main/schema/delta/16/events_order_index.sql
similarity index 100%
rename from synapse/storage/schema/delta/16/events_order_index.sql
rename to synapse/storage/data_stores/main/schema/delta/16/events_order_index.sql
diff --git a/synapse/storage/schema/delta/16/remote_media_cache_index.sql b/synapse/storage/data_stores/main/schema/delta/16/remote_media_cache_index.sql
similarity index 100%
rename from synapse/storage/schema/delta/16/remote_media_cache_index.sql
rename to synapse/storage/data_stores/main/schema/delta/16/remote_media_cache_index.sql
diff --git a/synapse/storage/schema/delta/16/remove_duplicates.sql b/synapse/storage/data_stores/main/schema/delta/16/remove_duplicates.sql
similarity index 100%
rename from synapse/storage/schema/delta/16/remove_duplicates.sql
rename to synapse/storage/data_stores/main/schema/delta/16/remove_duplicates.sql
diff --git a/synapse/storage/schema/delta/16/room_alias_index.sql b/synapse/storage/data_stores/main/schema/delta/16/room_alias_index.sql
similarity index 100%
rename from synapse/storage/schema/delta/16/room_alias_index.sql
rename to synapse/storage/data_stores/main/schema/delta/16/room_alias_index.sql
diff --git a/synapse/storage/schema/delta/16/unique_constraints.sql b/synapse/storage/data_stores/main/schema/delta/16/unique_constraints.sql
similarity index 100%
rename from synapse/storage/schema/delta/16/unique_constraints.sql
rename to synapse/storage/data_stores/main/schema/delta/16/unique_constraints.sql
diff --git a/synapse/storage/schema/delta/16/users.sql b/synapse/storage/data_stores/main/schema/delta/16/users.sql
similarity index 100%
rename from synapse/storage/schema/delta/16/users.sql
rename to synapse/storage/data_stores/main/schema/delta/16/users.sql
diff --git a/synapse/storage/schema/delta/17/drop_indexes.sql b/synapse/storage/data_stores/main/schema/delta/17/drop_indexes.sql
similarity index 100%
rename from synapse/storage/schema/delta/17/drop_indexes.sql
rename to synapse/storage/data_stores/main/schema/delta/17/drop_indexes.sql
diff --git a/synapse/storage/schema/delta/17/server_keys.sql b/synapse/storage/data_stores/main/schema/delta/17/server_keys.sql
similarity index 100%
rename from synapse/storage/schema/delta/17/server_keys.sql
rename to synapse/storage/data_stores/main/schema/delta/17/server_keys.sql
diff --git a/synapse/storage/schema/delta/17/user_threepids.sql b/synapse/storage/data_stores/main/schema/delta/17/user_threepids.sql
similarity index 100%
rename from synapse/storage/schema/delta/17/user_threepids.sql
rename to synapse/storage/data_stores/main/schema/delta/17/user_threepids.sql
diff --git a/synapse/storage/schema/delta/18/server_keys_bigger_ints.sql b/synapse/storage/data_stores/main/schema/delta/18/server_keys_bigger_ints.sql
similarity index 100%
rename from synapse/storage/schema/delta/18/server_keys_bigger_ints.sql
rename to synapse/storage/data_stores/main/schema/delta/18/server_keys_bigger_ints.sql
diff --git a/synapse/storage/schema/delta/19/event_index.sql b/synapse/storage/data_stores/main/schema/delta/19/event_index.sql
similarity index 100%
rename from synapse/storage/schema/delta/19/event_index.sql
rename to synapse/storage/data_stores/main/schema/delta/19/event_index.sql
diff --git a/synapse/storage/schema/delta/20/dummy.sql b/synapse/storage/data_stores/main/schema/delta/20/dummy.sql
similarity index 100%
rename from synapse/storage/schema/delta/20/dummy.sql
rename to synapse/storage/data_stores/main/schema/delta/20/dummy.sql
diff --git a/synapse/storage/schema/delta/20/pushers.py b/synapse/storage/data_stores/main/schema/delta/20/pushers.py
similarity index 100%
rename from synapse/storage/schema/delta/20/pushers.py
rename to synapse/storage/data_stores/main/schema/delta/20/pushers.py
diff --git a/synapse/storage/schema/delta/21/end_to_end_keys.sql b/synapse/storage/data_stores/main/schema/delta/21/end_to_end_keys.sql
similarity index 100%
rename from synapse/storage/schema/delta/21/end_to_end_keys.sql
rename to synapse/storage/data_stores/main/schema/delta/21/end_to_end_keys.sql
diff --git a/synapse/storage/schema/delta/21/receipts.sql b/synapse/storage/data_stores/main/schema/delta/21/receipts.sql
similarity index 100%
rename from synapse/storage/schema/delta/21/receipts.sql
rename to synapse/storage/data_stores/main/schema/delta/21/receipts.sql
diff --git a/synapse/storage/schema/delta/22/receipts_index.sql b/synapse/storage/data_stores/main/schema/delta/22/receipts_index.sql
similarity index 100%
rename from synapse/storage/schema/delta/22/receipts_index.sql
rename to synapse/storage/data_stores/main/schema/delta/22/receipts_index.sql
diff --git a/synapse/storage/schema/delta/22/user_threepids_unique.sql b/synapse/storage/data_stores/main/schema/delta/22/user_threepids_unique.sql
similarity index 100%
rename from synapse/storage/schema/delta/22/user_threepids_unique.sql
rename to synapse/storage/data_stores/main/schema/delta/22/user_threepids_unique.sql
diff --git a/synapse/storage/schema/delta/23/drop_state_index.sql b/synapse/storage/data_stores/main/schema/delta/23/drop_state_index.sql
similarity index 100%
rename from synapse/storage/schema/delta/23/drop_state_index.sql
rename to synapse/storage/data_stores/main/schema/delta/23/drop_state_index.sql
diff --git a/synapse/storage/schema/delta/24/stats_reporting.sql b/synapse/storage/data_stores/main/schema/delta/24/stats_reporting.sql
similarity index 100%
rename from synapse/storage/schema/delta/24/stats_reporting.sql
rename to synapse/storage/data_stores/main/schema/delta/24/stats_reporting.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/25/00background_updates.sql b/synapse/storage/data_stores/main/schema/delta/25/00background_updates.sql
new file mode 100644
index 0000000000..2ad9e8fa56
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/25/00background_updates.sql
@@ -0,0 +1,21 @@
+/* Copyright 2015, 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+CREATE TABLE IF NOT EXISTS background_updates(
+    update_name TEXT NOT NULL, -- The name of the background update.
+    progress_json TEXT NOT NULL, -- The current progress of the update as JSON.
+    CONSTRAINT background_updates_uniqueness UNIQUE (update_name)
+);
diff --git a/synapse/storage/schema/delta/25/fts.py b/synapse/storage/data_stores/main/schema/delta/25/fts.py
similarity index 100%
rename from synapse/storage/schema/delta/25/fts.py
rename to synapse/storage/data_stores/main/schema/delta/25/fts.py
diff --git a/synapse/storage/schema/delta/25/guest_access.sql b/synapse/storage/data_stores/main/schema/delta/25/guest_access.sql
similarity index 100%
rename from synapse/storage/schema/delta/25/guest_access.sql
rename to synapse/storage/data_stores/main/schema/delta/25/guest_access.sql
diff --git a/synapse/storage/schema/delta/25/history_visibility.sql b/synapse/storage/data_stores/main/schema/delta/25/history_visibility.sql
similarity index 100%
rename from synapse/storage/schema/delta/25/history_visibility.sql
rename to synapse/storage/data_stores/main/schema/delta/25/history_visibility.sql
diff --git a/synapse/storage/schema/delta/25/tags.sql b/synapse/storage/data_stores/main/schema/delta/25/tags.sql
similarity index 100%
rename from synapse/storage/schema/delta/25/tags.sql
rename to synapse/storage/data_stores/main/schema/delta/25/tags.sql
diff --git a/synapse/storage/schema/delta/26/account_data.sql b/synapse/storage/data_stores/main/schema/delta/26/account_data.sql
similarity index 100%
rename from synapse/storage/schema/delta/26/account_data.sql
rename to synapse/storage/data_stores/main/schema/delta/26/account_data.sql
diff --git a/synapse/storage/schema/delta/27/account_data.sql b/synapse/storage/data_stores/main/schema/delta/27/account_data.sql
similarity index 100%
rename from synapse/storage/schema/delta/27/account_data.sql
rename to synapse/storage/data_stores/main/schema/delta/27/account_data.sql
diff --git a/synapse/storage/schema/delta/27/forgotten_memberships.sql b/synapse/storage/data_stores/main/schema/delta/27/forgotten_memberships.sql
similarity index 100%
rename from synapse/storage/schema/delta/27/forgotten_memberships.sql
rename to synapse/storage/data_stores/main/schema/delta/27/forgotten_memberships.sql
diff --git a/synapse/storage/schema/delta/27/ts.py b/synapse/storage/data_stores/main/schema/delta/27/ts.py
similarity index 100%
rename from synapse/storage/schema/delta/27/ts.py
rename to synapse/storage/data_stores/main/schema/delta/27/ts.py
diff --git a/synapse/storage/schema/delta/28/event_push_actions.sql b/synapse/storage/data_stores/main/schema/delta/28/event_push_actions.sql
similarity index 100%
rename from synapse/storage/schema/delta/28/event_push_actions.sql
rename to synapse/storage/data_stores/main/schema/delta/28/event_push_actions.sql
diff --git a/synapse/storage/schema/delta/28/events_room_stream.sql b/synapse/storage/data_stores/main/schema/delta/28/events_room_stream.sql
similarity index 100%
rename from synapse/storage/schema/delta/28/events_room_stream.sql
rename to synapse/storage/data_stores/main/schema/delta/28/events_room_stream.sql
diff --git a/synapse/storage/schema/delta/28/public_roms_index.sql b/synapse/storage/data_stores/main/schema/delta/28/public_roms_index.sql
similarity index 100%
rename from synapse/storage/schema/delta/28/public_roms_index.sql
rename to synapse/storage/data_stores/main/schema/delta/28/public_roms_index.sql
diff --git a/synapse/storage/schema/delta/28/receipts_user_id_index.sql b/synapse/storage/data_stores/main/schema/delta/28/receipts_user_id_index.sql
similarity index 100%
rename from synapse/storage/schema/delta/28/receipts_user_id_index.sql
rename to synapse/storage/data_stores/main/schema/delta/28/receipts_user_id_index.sql
diff --git a/synapse/storage/schema/delta/28/upgrade_times.sql b/synapse/storage/data_stores/main/schema/delta/28/upgrade_times.sql
similarity index 100%
rename from synapse/storage/schema/delta/28/upgrade_times.sql
rename to synapse/storage/data_stores/main/schema/delta/28/upgrade_times.sql
diff --git a/synapse/storage/schema/delta/28/users_is_guest.sql b/synapse/storage/data_stores/main/schema/delta/28/users_is_guest.sql
similarity index 100%
rename from synapse/storage/schema/delta/28/users_is_guest.sql
rename to synapse/storage/data_stores/main/schema/delta/28/users_is_guest.sql
diff --git a/synapse/storage/schema/delta/29/push_actions.sql b/synapse/storage/data_stores/main/schema/delta/29/push_actions.sql
similarity index 100%
rename from synapse/storage/schema/delta/29/push_actions.sql
rename to synapse/storage/data_stores/main/schema/delta/29/push_actions.sql
diff --git a/synapse/storage/schema/delta/30/alias_creator.sql b/synapse/storage/data_stores/main/schema/delta/30/alias_creator.sql
similarity index 100%
rename from synapse/storage/schema/delta/30/alias_creator.sql
rename to synapse/storage/data_stores/main/schema/delta/30/alias_creator.sql
diff --git a/synapse/storage/schema/delta/30/as_users.py b/synapse/storage/data_stores/main/schema/delta/30/as_users.py
similarity index 100%
rename from synapse/storage/schema/delta/30/as_users.py
rename to synapse/storage/data_stores/main/schema/delta/30/as_users.py
diff --git a/synapse/storage/schema/delta/30/deleted_pushers.sql b/synapse/storage/data_stores/main/schema/delta/30/deleted_pushers.sql
similarity index 100%
rename from synapse/storage/schema/delta/30/deleted_pushers.sql
rename to synapse/storage/data_stores/main/schema/delta/30/deleted_pushers.sql
diff --git a/synapse/storage/schema/delta/30/presence_stream.sql b/synapse/storage/data_stores/main/schema/delta/30/presence_stream.sql
similarity index 100%
rename from synapse/storage/schema/delta/30/presence_stream.sql
rename to synapse/storage/data_stores/main/schema/delta/30/presence_stream.sql
diff --git a/synapse/storage/schema/delta/30/public_rooms.sql b/synapse/storage/data_stores/main/schema/delta/30/public_rooms.sql
similarity index 100%
rename from synapse/storage/schema/delta/30/public_rooms.sql
rename to synapse/storage/data_stores/main/schema/delta/30/public_rooms.sql
diff --git a/synapse/storage/schema/delta/30/push_rule_stream.sql b/synapse/storage/data_stores/main/schema/delta/30/push_rule_stream.sql
similarity index 100%
rename from synapse/storage/schema/delta/30/push_rule_stream.sql
rename to synapse/storage/data_stores/main/schema/delta/30/push_rule_stream.sql
diff --git a/synapse/storage/schema/delta/30/state_stream.sql b/synapse/storage/data_stores/main/schema/delta/30/state_stream.sql
similarity index 100%
rename from synapse/storage/schema/delta/30/state_stream.sql
rename to synapse/storage/data_stores/main/schema/delta/30/state_stream.sql
diff --git a/synapse/storage/schema/delta/30/threepid_guest_access_tokens.sql b/synapse/storage/data_stores/main/schema/delta/30/threepid_guest_access_tokens.sql
similarity index 100%
rename from synapse/storage/schema/delta/30/threepid_guest_access_tokens.sql
rename to synapse/storage/data_stores/main/schema/delta/30/threepid_guest_access_tokens.sql
diff --git a/synapse/storage/schema/delta/31/invites.sql b/synapse/storage/data_stores/main/schema/delta/31/invites.sql
similarity index 100%
rename from synapse/storage/schema/delta/31/invites.sql
rename to synapse/storage/data_stores/main/schema/delta/31/invites.sql
diff --git a/synapse/storage/schema/delta/31/local_media_repository_url_cache.sql b/synapse/storage/data_stores/main/schema/delta/31/local_media_repository_url_cache.sql
similarity index 100%
rename from synapse/storage/schema/delta/31/local_media_repository_url_cache.sql
rename to synapse/storage/data_stores/main/schema/delta/31/local_media_repository_url_cache.sql
diff --git a/synapse/storage/schema/delta/31/pushers.py b/synapse/storage/data_stores/main/schema/delta/31/pushers.py
similarity index 100%
rename from synapse/storage/schema/delta/31/pushers.py
rename to synapse/storage/data_stores/main/schema/delta/31/pushers.py
diff --git a/synapse/storage/schema/delta/31/pushers_index.sql b/synapse/storage/data_stores/main/schema/delta/31/pushers_index.sql
similarity index 100%
rename from synapse/storage/schema/delta/31/pushers_index.sql
rename to synapse/storage/data_stores/main/schema/delta/31/pushers_index.sql
diff --git a/synapse/storage/schema/delta/31/search_update.py b/synapse/storage/data_stores/main/schema/delta/31/search_update.py
similarity index 100%
rename from synapse/storage/schema/delta/31/search_update.py
rename to synapse/storage/data_stores/main/schema/delta/31/search_update.py
diff --git a/synapse/storage/schema/delta/32/events.sql b/synapse/storage/data_stores/main/schema/delta/32/events.sql
similarity index 100%
rename from synapse/storage/schema/delta/32/events.sql
rename to synapse/storage/data_stores/main/schema/delta/32/events.sql
diff --git a/synapse/storage/schema/delta/32/openid.sql b/synapse/storage/data_stores/main/schema/delta/32/openid.sql
similarity index 100%
rename from synapse/storage/schema/delta/32/openid.sql
rename to synapse/storage/data_stores/main/schema/delta/32/openid.sql
diff --git a/synapse/storage/schema/delta/32/pusher_throttle.sql b/synapse/storage/data_stores/main/schema/delta/32/pusher_throttle.sql
similarity index 100%
rename from synapse/storage/schema/delta/32/pusher_throttle.sql
rename to synapse/storage/data_stores/main/schema/delta/32/pusher_throttle.sql
diff --git a/synapse/storage/schema/delta/32/remove_indices.sql b/synapse/storage/data_stores/main/schema/delta/32/remove_indices.sql
similarity index 100%
rename from synapse/storage/schema/delta/32/remove_indices.sql
rename to synapse/storage/data_stores/main/schema/delta/32/remove_indices.sql
diff --git a/synapse/storage/schema/delta/32/reports.sql b/synapse/storage/data_stores/main/schema/delta/32/reports.sql
similarity index 100%
rename from synapse/storage/schema/delta/32/reports.sql
rename to synapse/storage/data_stores/main/schema/delta/32/reports.sql
diff --git a/synapse/storage/schema/delta/33/access_tokens_device_index.sql b/synapse/storage/data_stores/main/schema/delta/33/access_tokens_device_index.sql
similarity index 100%
rename from synapse/storage/schema/delta/33/access_tokens_device_index.sql
rename to synapse/storage/data_stores/main/schema/delta/33/access_tokens_device_index.sql
diff --git a/synapse/storage/schema/delta/33/devices.sql b/synapse/storage/data_stores/main/schema/delta/33/devices.sql
similarity index 100%
rename from synapse/storage/schema/delta/33/devices.sql
rename to synapse/storage/data_stores/main/schema/delta/33/devices.sql
diff --git a/synapse/storage/schema/delta/33/devices_for_e2e_keys.sql b/synapse/storage/data_stores/main/schema/delta/33/devices_for_e2e_keys.sql
similarity index 100%
rename from synapse/storage/schema/delta/33/devices_for_e2e_keys.sql
rename to synapse/storage/data_stores/main/schema/delta/33/devices_for_e2e_keys.sql
diff --git a/synapse/storage/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql b/synapse/storage/data_stores/main/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql
similarity index 100%
rename from synapse/storage/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql
rename to synapse/storage/data_stores/main/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql
diff --git a/synapse/storage/schema/delta/33/event_fields.py b/synapse/storage/data_stores/main/schema/delta/33/event_fields.py
similarity index 100%
rename from synapse/storage/schema/delta/33/event_fields.py
rename to synapse/storage/data_stores/main/schema/delta/33/event_fields.py
diff --git a/synapse/storage/schema/delta/33/remote_media_ts.py b/synapse/storage/data_stores/main/schema/delta/33/remote_media_ts.py
similarity index 100%
rename from synapse/storage/schema/delta/33/remote_media_ts.py
rename to synapse/storage/data_stores/main/schema/delta/33/remote_media_ts.py
diff --git a/synapse/storage/schema/delta/33/user_ips_index.sql b/synapse/storage/data_stores/main/schema/delta/33/user_ips_index.sql
similarity index 100%
rename from synapse/storage/schema/delta/33/user_ips_index.sql
rename to synapse/storage/data_stores/main/schema/delta/33/user_ips_index.sql
diff --git a/synapse/storage/schema/delta/34/appservice_stream.sql b/synapse/storage/data_stores/main/schema/delta/34/appservice_stream.sql
similarity index 100%
rename from synapse/storage/schema/delta/34/appservice_stream.sql
rename to synapse/storage/data_stores/main/schema/delta/34/appservice_stream.sql
diff --git a/synapse/storage/schema/delta/34/cache_stream.py b/synapse/storage/data_stores/main/schema/delta/34/cache_stream.py
similarity index 100%
rename from synapse/storage/schema/delta/34/cache_stream.py
rename to synapse/storage/data_stores/main/schema/delta/34/cache_stream.py
diff --git a/synapse/storage/schema/delta/34/device_inbox.sql b/synapse/storage/data_stores/main/schema/delta/34/device_inbox.sql
similarity index 100%
rename from synapse/storage/schema/delta/34/device_inbox.sql
rename to synapse/storage/data_stores/main/schema/delta/34/device_inbox.sql
diff --git a/synapse/storage/schema/delta/34/push_display_name_rename.sql b/synapse/storage/data_stores/main/schema/delta/34/push_display_name_rename.sql
similarity index 100%
rename from synapse/storage/schema/delta/34/push_display_name_rename.sql
rename to synapse/storage/data_stores/main/schema/delta/34/push_display_name_rename.sql
diff --git a/synapse/storage/schema/delta/34/received_txn_purge.py b/synapse/storage/data_stores/main/schema/delta/34/received_txn_purge.py
similarity index 100%
rename from synapse/storage/schema/delta/34/received_txn_purge.py
rename to synapse/storage/data_stores/main/schema/delta/34/received_txn_purge.py
diff --git a/synapse/storage/schema/delta/35/add_state_index.sql b/synapse/storage/data_stores/main/schema/delta/35/add_state_index.sql
similarity index 92%
rename from synapse/storage/schema/delta/35/add_state_index.sql
rename to synapse/storage/data_stores/main/schema/delta/35/add_state_index.sql
index 0fce26345b..33980d02f0 100644
--- a/synapse/storage/schema/delta/35/add_state_index.sql
+++ b/synapse/storage/data_stores/main/schema/delta/35/add_state_index.sql
@@ -13,8 +13,5 @@
  * limitations under the License.
  */
 
-
-ALTER TABLE background_updates ADD COLUMN depends_on TEXT;
-
 INSERT into background_updates (update_name, progress_json, depends_on)
     VALUES ('state_group_state_type_index', '{}', 'state_group_state_deduplication');
diff --git a/synapse/storage/schema/delta/35/contains_url.sql b/synapse/storage/data_stores/main/schema/delta/35/contains_url.sql
similarity index 100%
rename from synapse/storage/schema/delta/35/contains_url.sql
rename to synapse/storage/data_stores/main/schema/delta/35/contains_url.sql
diff --git a/synapse/storage/schema/delta/35/device_outbox.sql b/synapse/storage/data_stores/main/schema/delta/35/device_outbox.sql
similarity index 100%
rename from synapse/storage/schema/delta/35/device_outbox.sql
rename to synapse/storage/data_stores/main/schema/delta/35/device_outbox.sql
diff --git a/synapse/storage/schema/delta/35/device_stream_id.sql b/synapse/storage/data_stores/main/schema/delta/35/device_stream_id.sql
similarity index 100%
rename from synapse/storage/schema/delta/35/device_stream_id.sql
rename to synapse/storage/data_stores/main/schema/delta/35/device_stream_id.sql
diff --git a/synapse/storage/schema/delta/35/event_push_actions_index.sql b/synapse/storage/data_stores/main/schema/delta/35/event_push_actions_index.sql
similarity index 100%
rename from synapse/storage/schema/delta/35/event_push_actions_index.sql
rename to synapse/storage/data_stores/main/schema/delta/35/event_push_actions_index.sql
diff --git a/synapse/storage/schema/delta/35/public_room_list_change_stream.sql b/synapse/storage/data_stores/main/schema/delta/35/public_room_list_change_stream.sql
similarity index 100%
rename from synapse/storage/schema/delta/35/public_room_list_change_stream.sql
rename to synapse/storage/data_stores/main/schema/delta/35/public_room_list_change_stream.sql
diff --git a/synapse/storage/schema/delta/35/state.sql b/synapse/storage/data_stores/main/schema/delta/35/state.sql
similarity index 100%
rename from synapse/storage/schema/delta/35/state.sql
rename to synapse/storage/data_stores/main/schema/delta/35/state.sql
diff --git a/synapse/storage/schema/delta/35/state_dedupe.sql b/synapse/storage/data_stores/main/schema/delta/35/state_dedupe.sql
similarity index 100%
rename from synapse/storage/schema/delta/35/state_dedupe.sql
rename to synapse/storage/data_stores/main/schema/delta/35/state_dedupe.sql
diff --git a/synapse/storage/schema/delta/35/stream_order_to_extrem.sql b/synapse/storage/data_stores/main/schema/delta/35/stream_order_to_extrem.sql
similarity index 100%
rename from synapse/storage/schema/delta/35/stream_order_to_extrem.sql
rename to synapse/storage/data_stores/main/schema/delta/35/stream_order_to_extrem.sql
diff --git a/synapse/storage/schema/delta/36/readd_public_rooms.sql b/synapse/storage/data_stores/main/schema/delta/36/readd_public_rooms.sql
similarity index 100%
rename from synapse/storage/schema/delta/36/readd_public_rooms.sql
rename to synapse/storage/data_stores/main/schema/delta/36/readd_public_rooms.sql
diff --git a/synapse/storage/schema/delta/37/remove_auth_idx.py b/synapse/storage/data_stores/main/schema/delta/37/remove_auth_idx.py
similarity index 100%
rename from synapse/storage/schema/delta/37/remove_auth_idx.py
rename to synapse/storage/data_stores/main/schema/delta/37/remove_auth_idx.py
diff --git a/synapse/storage/schema/delta/37/user_threepids.sql b/synapse/storage/data_stores/main/schema/delta/37/user_threepids.sql
similarity index 100%
rename from synapse/storage/schema/delta/37/user_threepids.sql
rename to synapse/storage/data_stores/main/schema/delta/37/user_threepids.sql
diff --git a/synapse/storage/schema/delta/38/postgres_fts_gist.sql b/synapse/storage/data_stores/main/schema/delta/38/postgres_fts_gist.sql
similarity index 100%
rename from synapse/storage/schema/delta/38/postgres_fts_gist.sql
rename to synapse/storage/data_stores/main/schema/delta/38/postgres_fts_gist.sql
diff --git a/synapse/storage/schema/delta/39/appservice_room_list.sql b/synapse/storage/data_stores/main/schema/delta/39/appservice_room_list.sql
similarity index 100%
rename from synapse/storage/schema/delta/39/appservice_room_list.sql
rename to synapse/storage/data_stores/main/schema/delta/39/appservice_room_list.sql
diff --git a/synapse/storage/schema/delta/39/device_federation_stream_idx.sql b/synapse/storage/data_stores/main/schema/delta/39/device_federation_stream_idx.sql
similarity index 100%
rename from synapse/storage/schema/delta/39/device_federation_stream_idx.sql
rename to synapse/storage/data_stores/main/schema/delta/39/device_federation_stream_idx.sql
diff --git a/synapse/storage/schema/delta/39/event_push_index.sql b/synapse/storage/data_stores/main/schema/delta/39/event_push_index.sql
similarity index 100%
rename from synapse/storage/schema/delta/39/event_push_index.sql
rename to synapse/storage/data_stores/main/schema/delta/39/event_push_index.sql
diff --git a/synapse/storage/schema/delta/39/federation_out_position.sql b/synapse/storage/data_stores/main/schema/delta/39/federation_out_position.sql
similarity index 100%
rename from synapse/storage/schema/delta/39/federation_out_position.sql
rename to synapse/storage/data_stores/main/schema/delta/39/federation_out_position.sql
diff --git a/synapse/storage/schema/delta/39/membership_profile.sql b/synapse/storage/data_stores/main/schema/delta/39/membership_profile.sql
similarity index 100%
rename from synapse/storage/schema/delta/39/membership_profile.sql
rename to synapse/storage/data_stores/main/schema/delta/39/membership_profile.sql
diff --git a/synapse/storage/schema/delta/40/current_state_idx.sql b/synapse/storage/data_stores/main/schema/delta/40/current_state_idx.sql
similarity index 100%
rename from synapse/storage/schema/delta/40/current_state_idx.sql
rename to synapse/storage/data_stores/main/schema/delta/40/current_state_idx.sql
diff --git a/synapse/storage/schema/delta/40/device_inbox.sql b/synapse/storage/data_stores/main/schema/delta/40/device_inbox.sql
similarity index 100%
rename from synapse/storage/schema/delta/40/device_inbox.sql
rename to synapse/storage/data_stores/main/schema/delta/40/device_inbox.sql
diff --git a/synapse/storage/schema/delta/40/device_list_streams.sql b/synapse/storage/data_stores/main/schema/delta/40/device_list_streams.sql
similarity index 100%
rename from synapse/storage/schema/delta/40/device_list_streams.sql
rename to synapse/storage/data_stores/main/schema/delta/40/device_list_streams.sql
diff --git a/synapse/storage/schema/delta/40/event_push_summary.sql b/synapse/storage/data_stores/main/schema/delta/40/event_push_summary.sql
similarity index 100%
rename from synapse/storage/schema/delta/40/event_push_summary.sql
rename to synapse/storage/data_stores/main/schema/delta/40/event_push_summary.sql
diff --git a/synapse/storage/schema/delta/40/pushers.sql b/synapse/storage/data_stores/main/schema/delta/40/pushers.sql
similarity index 100%
rename from synapse/storage/schema/delta/40/pushers.sql
rename to synapse/storage/data_stores/main/schema/delta/40/pushers.sql
diff --git a/synapse/storage/schema/delta/41/device_list_stream_idx.sql b/synapse/storage/data_stores/main/schema/delta/41/device_list_stream_idx.sql
similarity index 100%
rename from synapse/storage/schema/delta/41/device_list_stream_idx.sql
rename to synapse/storage/data_stores/main/schema/delta/41/device_list_stream_idx.sql
diff --git a/synapse/storage/schema/delta/41/device_outbound_index.sql b/synapse/storage/data_stores/main/schema/delta/41/device_outbound_index.sql
similarity index 100%
rename from synapse/storage/schema/delta/41/device_outbound_index.sql
rename to synapse/storage/data_stores/main/schema/delta/41/device_outbound_index.sql
diff --git a/synapse/storage/schema/delta/41/event_search_event_id_idx.sql b/synapse/storage/data_stores/main/schema/delta/41/event_search_event_id_idx.sql
similarity index 100%
rename from synapse/storage/schema/delta/41/event_search_event_id_idx.sql
rename to synapse/storage/data_stores/main/schema/delta/41/event_search_event_id_idx.sql
diff --git a/synapse/storage/schema/delta/41/ratelimit.sql b/synapse/storage/data_stores/main/schema/delta/41/ratelimit.sql
similarity index 100%
rename from synapse/storage/schema/delta/41/ratelimit.sql
rename to synapse/storage/data_stores/main/schema/delta/41/ratelimit.sql
diff --git a/synapse/storage/schema/delta/42/current_state_delta.sql b/synapse/storage/data_stores/main/schema/delta/42/current_state_delta.sql
similarity index 100%
rename from synapse/storage/schema/delta/42/current_state_delta.sql
rename to synapse/storage/data_stores/main/schema/delta/42/current_state_delta.sql
diff --git a/synapse/storage/schema/delta/42/device_list_last_id.sql b/synapse/storage/data_stores/main/schema/delta/42/device_list_last_id.sql
similarity index 100%
rename from synapse/storage/schema/delta/42/device_list_last_id.sql
rename to synapse/storage/data_stores/main/schema/delta/42/device_list_last_id.sql
diff --git a/synapse/storage/schema/delta/42/event_auth_state_only.sql b/synapse/storage/data_stores/main/schema/delta/42/event_auth_state_only.sql
similarity index 100%
rename from synapse/storage/schema/delta/42/event_auth_state_only.sql
rename to synapse/storage/data_stores/main/schema/delta/42/event_auth_state_only.sql
diff --git a/synapse/storage/schema/delta/42/user_dir.py b/synapse/storage/data_stores/main/schema/delta/42/user_dir.py
similarity index 100%
rename from synapse/storage/schema/delta/42/user_dir.py
rename to synapse/storage/data_stores/main/schema/delta/42/user_dir.py
diff --git a/synapse/storage/schema/delta/43/blocked_rooms.sql b/synapse/storage/data_stores/main/schema/delta/43/blocked_rooms.sql
similarity index 100%
rename from synapse/storage/schema/delta/43/blocked_rooms.sql
rename to synapse/storage/data_stores/main/schema/delta/43/blocked_rooms.sql
diff --git a/synapse/storage/schema/delta/43/quarantine_media.sql b/synapse/storage/data_stores/main/schema/delta/43/quarantine_media.sql
similarity index 100%
rename from synapse/storage/schema/delta/43/quarantine_media.sql
rename to synapse/storage/data_stores/main/schema/delta/43/quarantine_media.sql
diff --git a/synapse/storage/schema/delta/43/url_cache.sql b/synapse/storage/data_stores/main/schema/delta/43/url_cache.sql
similarity index 100%
rename from synapse/storage/schema/delta/43/url_cache.sql
rename to synapse/storage/data_stores/main/schema/delta/43/url_cache.sql
diff --git a/synapse/storage/schema/delta/43/user_share.sql b/synapse/storage/data_stores/main/schema/delta/43/user_share.sql
similarity index 100%
rename from synapse/storage/schema/delta/43/user_share.sql
rename to synapse/storage/data_stores/main/schema/delta/43/user_share.sql
diff --git a/synapse/storage/schema/delta/44/expire_url_cache.sql b/synapse/storage/data_stores/main/schema/delta/44/expire_url_cache.sql
similarity index 100%
rename from synapse/storage/schema/delta/44/expire_url_cache.sql
rename to synapse/storage/data_stores/main/schema/delta/44/expire_url_cache.sql
diff --git a/synapse/storage/schema/delta/45/group_server.sql b/synapse/storage/data_stores/main/schema/delta/45/group_server.sql
similarity index 100%
rename from synapse/storage/schema/delta/45/group_server.sql
rename to synapse/storage/data_stores/main/schema/delta/45/group_server.sql
diff --git a/synapse/storage/schema/delta/45/profile_cache.sql b/synapse/storage/data_stores/main/schema/delta/45/profile_cache.sql
similarity index 100%
rename from synapse/storage/schema/delta/45/profile_cache.sql
rename to synapse/storage/data_stores/main/schema/delta/45/profile_cache.sql
diff --git a/synapse/storage/schema/delta/46/drop_refresh_tokens.sql b/synapse/storage/data_stores/main/schema/delta/46/drop_refresh_tokens.sql
similarity index 100%
rename from synapse/storage/schema/delta/46/drop_refresh_tokens.sql
rename to synapse/storage/data_stores/main/schema/delta/46/drop_refresh_tokens.sql
diff --git a/synapse/storage/schema/delta/46/drop_unique_deleted_pushers.sql b/synapse/storage/data_stores/main/schema/delta/46/drop_unique_deleted_pushers.sql
similarity index 100%
rename from synapse/storage/schema/delta/46/drop_unique_deleted_pushers.sql
rename to synapse/storage/data_stores/main/schema/delta/46/drop_unique_deleted_pushers.sql
diff --git a/synapse/storage/schema/delta/46/group_server.sql b/synapse/storage/data_stores/main/schema/delta/46/group_server.sql
similarity index 100%
rename from synapse/storage/schema/delta/46/group_server.sql
rename to synapse/storage/data_stores/main/schema/delta/46/group_server.sql
diff --git a/synapse/storage/schema/delta/46/local_media_repository_url_idx.sql b/synapse/storage/data_stores/main/schema/delta/46/local_media_repository_url_idx.sql
similarity index 100%
rename from synapse/storage/schema/delta/46/local_media_repository_url_idx.sql
rename to synapse/storage/data_stores/main/schema/delta/46/local_media_repository_url_idx.sql
diff --git a/synapse/storage/schema/delta/46/user_dir_null_room_ids.sql b/synapse/storage/data_stores/main/schema/delta/46/user_dir_null_room_ids.sql
similarity index 100%
rename from synapse/storage/schema/delta/46/user_dir_null_room_ids.sql
rename to synapse/storage/data_stores/main/schema/delta/46/user_dir_null_room_ids.sql
diff --git a/synapse/storage/schema/delta/46/user_dir_typos.sql b/synapse/storage/data_stores/main/schema/delta/46/user_dir_typos.sql
similarity index 100%
rename from synapse/storage/schema/delta/46/user_dir_typos.sql
rename to synapse/storage/data_stores/main/schema/delta/46/user_dir_typos.sql
diff --git a/synapse/storage/schema/delta/47/last_access_media.sql b/synapse/storage/data_stores/main/schema/delta/47/last_access_media.sql
similarity index 100%
rename from synapse/storage/schema/delta/47/last_access_media.sql
rename to synapse/storage/data_stores/main/schema/delta/47/last_access_media.sql
diff --git a/synapse/storage/schema/delta/47/postgres_fts_gin.sql b/synapse/storage/data_stores/main/schema/delta/47/postgres_fts_gin.sql
similarity index 100%
rename from synapse/storage/schema/delta/47/postgres_fts_gin.sql
rename to synapse/storage/data_stores/main/schema/delta/47/postgres_fts_gin.sql
diff --git a/synapse/storage/schema/delta/47/push_actions_staging.sql b/synapse/storage/data_stores/main/schema/delta/47/push_actions_staging.sql
similarity index 100%
rename from synapse/storage/schema/delta/47/push_actions_staging.sql
rename to synapse/storage/data_stores/main/schema/delta/47/push_actions_staging.sql
diff --git a/synapse/storage/schema/delta/47/state_group_seq.py b/synapse/storage/data_stores/main/schema/delta/47/state_group_seq.py
similarity index 100%
rename from synapse/storage/schema/delta/47/state_group_seq.py
rename to synapse/storage/data_stores/main/schema/delta/47/state_group_seq.py
diff --git a/synapse/storage/schema/delta/48/add_user_consent.sql b/synapse/storage/data_stores/main/schema/delta/48/add_user_consent.sql
similarity index 100%
rename from synapse/storage/schema/delta/48/add_user_consent.sql
rename to synapse/storage/data_stores/main/schema/delta/48/add_user_consent.sql
diff --git a/synapse/storage/schema/delta/48/add_user_ips_last_seen_index.sql b/synapse/storage/data_stores/main/schema/delta/48/add_user_ips_last_seen_index.sql
similarity index 100%
rename from synapse/storage/schema/delta/48/add_user_ips_last_seen_index.sql
rename to synapse/storage/data_stores/main/schema/delta/48/add_user_ips_last_seen_index.sql
diff --git a/synapse/storage/schema/delta/48/deactivated_users.sql b/synapse/storage/data_stores/main/schema/delta/48/deactivated_users.sql
similarity index 100%
rename from synapse/storage/schema/delta/48/deactivated_users.sql
rename to synapse/storage/data_stores/main/schema/delta/48/deactivated_users.sql
diff --git a/synapse/storage/schema/delta/48/group_unique_indexes.py b/synapse/storage/data_stores/main/schema/delta/48/group_unique_indexes.py
similarity index 100%
rename from synapse/storage/schema/delta/48/group_unique_indexes.py
rename to synapse/storage/data_stores/main/schema/delta/48/group_unique_indexes.py
diff --git a/synapse/storage/schema/delta/48/groups_joinable.sql b/synapse/storage/data_stores/main/schema/delta/48/groups_joinable.sql
similarity index 100%
rename from synapse/storage/schema/delta/48/groups_joinable.sql
rename to synapse/storage/data_stores/main/schema/delta/48/groups_joinable.sql
diff --git a/synapse/storage/schema/delta/49/add_user_consent_server_notice_sent.sql b/synapse/storage/data_stores/main/schema/delta/49/add_user_consent_server_notice_sent.sql
similarity index 100%
rename from synapse/storage/schema/delta/49/add_user_consent_server_notice_sent.sql
rename to synapse/storage/data_stores/main/schema/delta/49/add_user_consent_server_notice_sent.sql
diff --git a/synapse/storage/schema/delta/49/add_user_daily_visits.sql b/synapse/storage/data_stores/main/schema/delta/49/add_user_daily_visits.sql
similarity index 100%
rename from synapse/storage/schema/delta/49/add_user_daily_visits.sql
rename to synapse/storage/data_stores/main/schema/delta/49/add_user_daily_visits.sql
diff --git a/synapse/storage/schema/delta/49/add_user_ips_last_seen_only_index.sql b/synapse/storage/data_stores/main/schema/delta/49/add_user_ips_last_seen_only_index.sql
similarity index 100%
rename from synapse/storage/schema/delta/49/add_user_ips_last_seen_only_index.sql
rename to synapse/storage/data_stores/main/schema/delta/49/add_user_ips_last_seen_only_index.sql
diff --git a/synapse/storage/schema/delta/50/add_creation_ts_users_index.sql b/synapse/storage/data_stores/main/schema/delta/50/add_creation_ts_users_index.sql
similarity index 100%
rename from synapse/storage/schema/delta/50/add_creation_ts_users_index.sql
rename to synapse/storage/data_stores/main/schema/delta/50/add_creation_ts_users_index.sql
diff --git a/synapse/storage/schema/delta/50/erasure_store.sql b/synapse/storage/data_stores/main/schema/delta/50/erasure_store.sql
similarity index 100%
rename from synapse/storage/schema/delta/50/erasure_store.sql
rename to synapse/storage/data_stores/main/schema/delta/50/erasure_store.sql
diff --git a/synapse/storage/schema/delta/50/make_event_content_nullable.py b/synapse/storage/data_stores/main/schema/delta/50/make_event_content_nullable.py
similarity index 100%
rename from synapse/storage/schema/delta/50/make_event_content_nullable.py
rename to synapse/storage/data_stores/main/schema/delta/50/make_event_content_nullable.py
diff --git a/synapse/storage/schema/delta/51/e2e_room_keys.sql b/synapse/storage/data_stores/main/schema/delta/51/e2e_room_keys.sql
similarity index 100%
rename from synapse/storage/schema/delta/51/e2e_room_keys.sql
rename to synapse/storage/data_stores/main/schema/delta/51/e2e_room_keys.sql
diff --git a/synapse/storage/schema/delta/51/monthly_active_users.sql b/synapse/storage/data_stores/main/schema/delta/51/monthly_active_users.sql
similarity index 100%
rename from synapse/storage/schema/delta/51/monthly_active_users.sql
rename to synapse/storage/data_stores/main/schema/delta/51/monthly_active_users.sql
diff --git a/synapse/storage/schema/delta/52/add_event_to_state_group_index.sql b/synapse/storage/data_stores/main/schema/delta/52/add_event_to_state_group_index.sql
similarity index 100%
rename from synapse/storage/schema/delta/52/add_event_to_state_group_index.sql
rename to synapse/storage/data_stores/main/schema/delta/52/add_event_to_state_group_index.sql
diff --git a/synapse/storage/schema/delta/52/device_list_streams_unique_idx.sql b/synapse/storage/data_stores/main/schema/delta/52/device_list_streams_unique_idx.sql
similarity index 100%
rename from synapse/storage/schema/delta/52/device_list_streams_unique_idx.sql
rename to synapse/storage/data_stores/main/schema/delta/52/device_list_streams_unique_idx.sql
diff --git a/synapse/storage/schema/delta/52/e2e_room_keys.sql b/synapse/storage/data_stores/main/schema/delta/52/e2e_room_keys.sql
similarity index 100%
rename from synapse/storage/schema/delta/52/e2e_room_keys.sql
rename to synapse/storage/data_stores/main/schema/delta/52/e2e_room_keys.sql
diff --git a/synapse/storage/schema/delta/53/add_user_type_to_users.sql b/synapse/storage/data_stores/main/schema/delta/53/add_user_type_to_users.sql
similarity index 100%
rename from synapse/storage/schema/delta/53/add_user_type_to_users.sql
rename to synapse/storage/data_stores/main/schema/delta/53/add_user_type_to_users.sql
diff --git a/synapse/storage/schema/delta/53/drop_sent_transactions.sql b/synapse/storage/data_stores/main/schema/delta/53/drop_sent_transactions.sql
similarity index 100%
rename from synapse/storage/schema/delta/53/drop_sent_transactions.sql
rename to synapse/storage/data_stores/main/schema/delta/53/drop_sent_transactions.sql
diff --git a/synapse/storage/schema/delta/53/event_format_version.sql b/synapse/storage/data_stores/main/schema/delta/53/event_format_version.sql
similarity index 100%
rename from synapse/storage/schema/delta/53/event_format_version.sql
rename to synapse/storage/data_stores/main/schema/delta/53/event_format_version.sql
diff --git a/synapse/storage/schema/delta/53/user_dir_populate.sql b/synapse/storage/data_stores/main/schema/delta/53/user_dir_populate.sql
similarity index 100%
rename from synapse/storage/schema/delta/53/user_dir_populate.sql
rename to synapse/storage/data_stores/main/schema/delta/53/user_dir_populate.sql
diff --git a/synapse/storage/schema/delta/53/user_ips_index.sql b/synapse/storage/data_stores/main/schema/delta/53/user_ips_index.sql
similarity index 100%
rename from synapse/storage/schema/delta/53/user_ips_index.sql
rename to synapse/storage/data_stores/main/schema/delta/53/user_ips_index.sql
diff --git a/synapse/storage/schema/delta/53/user_share.sql b/synapse/storage/data_stores/main/schema/delta/53/user_share.sql
similarity index 100%
rename from synapse/storage/schema/delta/53/user_share.sql
rename to synapse/storage/data_stores/main/schema/delta/53/user_share.sql
diff --git a/synapse/storage/schema/delta/53/user_threepid_id.sql b/synapse/storage/data_stores/main/schema/delta/53/user_threepid_id.sql
similarity index 100%
rename from synapse/storage/schema/delta/53/user_threepid_id.sql
rename to synapse/storage/data_stores/main/schema/delta/53/user_threepid_id.sql
diff --git a/synapse/storage/schema/delta/53/users_in_public_rooms.sql b/synapse/storage/data_stores/main/schema/delta/53/users_in_public_rooms.sql
similarity index 100%
rename from synapse/storage/schema/delta/53/users_in_public_rooms.sql
rename to synapse/storage/data_stores/main/schema/delta/53/users_in_public_rooms.sql
diff --git a/synapse/storage/schema/delta/54/account_validity_with_renewal.sql b/synapse/storage/data_stores/main/schema/delta/54/account_validity_with_renewal.sql
similarity index 100%
rename from synapse/storage/schema/delta/54/account_validity_with_renewal.sql
rename to synapse/storage/data_stores/main/schema/delta/54/account_validity_with_renewal.sql
diff --git a/synapse/storage/schema/delta/54/add_validity_to_server_keys.sql b/synapse/storage/data_stores/main/schema/delta/54/add_validity_to_server_keys.sql
similarity index 100%
rename from synapse/storage/schema/delta/54/add_validity_to_server_keys.sql
rename to synapse/storage/data_stores/main/schema/delta/54/add_validity_to_server_keys.sql
diff --git a/synapse/storage/schema/delta/54/delete_forward_extremities.sql b/synapse/storage/data_stores/main/schema/delta/54/delete_forward_extremities.sql
similarity index 100%
rename from synapse/storage/schema/delta/54/delete_forward_extremities.sql
rename to synapse/storage/data_stores/main/schema/delta/54/delete_forward_extremities.sql
diff --git a/synapse/storage/schema/delta/54/drop_legacy_tables.sql b/synapse/storage/data_stores/main/schema/delta/54/drop_legacy_tables.sql
similarity index 100%
rename from synapse/storage/schema/delta/54/drop_legacy_tables.sql
rename to synapse/storage/data_stores/main/schema/delta/54/drop_legacy_tables.sql
diff --git a/synapse/storage/schema/delta/54/drop_presence_list.sql b/synapse/storage/data_stores/main/schema/delta/54/drop_presence_list.sql
similarity index 100%
rename from synapse/storage/schema/delta/54/drop_presence_list.sql
rename to synapse/storage/data_stores/main/schema/delta/54/drop_presence_list.sql
diff --git a/synapse/storage/schema/delta/54/relations.sql b/synapse/storage/data_stores/main/schema/delta/54/relations.sql
similarity index 100%
rename from synapse/storage/schema/delta/54/relations.sql
rename to synapse/storage/data_stores/main/schema/delta/54/relations.sql
diff --git a/synapse/storage/schema/delta/54/stats.sql b/synapse/storage/data_stores/main/schema/delta/54/stats.sql
similarity index 100%
rename from synapse/storage/schema/delta/54/stats.sql
rename to synapse/storage/data_stores/main/schema/delta/54/stats.sql
diff --git a/synapse/storage/schema/delta/54/stats2.sql b/synapse/storage/data_stores/main/schema/delta/54/stats2.sql
similarity index 100%
rename from synapse/storage/schema/delta/54/stats2.sql
rename to synapse/storage/data_stores/main/schema/delta/54/stats2.sql
diff --git a/synapse/storage/schema/delta/55/access_token_expiry.sql b/synapse/storage/data_stores/main/schema/delta/55/access_token_expiry.sql
similarity index 100%
rename from synapse/storage/schema/delta/55/access_token_expiry.sql
rename to synapse/storage/data_stores/main/schema/delta/55/access_token_expiry.sql
diff --git a/synapse/storage/schema/delta/55/track_threepid_validations.sql b/synapse/storage/data_stores/main/schema/delta/55/track_threepid_validations.sql
similarity index 100%
rename from synapse/storage/schema/delta/55/track_threepid_validations.sql
rename to synapse/storage/data_stores/main/schema/delta/55/track_threepid_validations.sql
diff --git a/synapse/storage/schema/delta/55/users_alter_deactivated.sql b/synapse/storage/data_stores/main/schema/delta/55/users_alter_deactivated.sql
similarity index 100%
rename from synapse/storage/schema/delta/55/users_alter_deactivated.sql
rename to synapse/storage/data_stores/main/schema/delta/55/users_alter_deactivated.sql
diff --git a/synapse/storage/schema/delta/56/add_spans_to_device_lists.sql b/synapse/storage/data_stores/main/schema/delta/56/add_spans_to_device_lists.sql
similarity index 100%
rename from synapse/storage/schema/delta/56/add_spans_to_device_lists.sql
rename to synapse/storage/data_stores/main/schema/delta/56/add_spans_to_device_lists.sql
diff --git a/synapse/storage/schema/delta/56/current_state_events_membership.sql b/synapse/storage/data_stores/main/schema/delta/56/current_state_events_membership.sql
similarity index 100%
rename from synapse/storage/schema/delta/56/current_state_events_membership.sql
rename to synapse/storage/data_stores/main/schema/delta/56/current_state_events_membership.sql
diff --git a/synapse/storage/schema/delta/56/current_state_events_membership_mk2.sql b/synapse/storage/data_stores/main/schema/delta/56/current_state_events_membership_mk2.sql
similarity index 100%
rename from synapse/storage/schema/delta/56/current_state_events_membership_mk2.sql
rename to synapse/storage/data_stores/main/schema/delta/56/current_state_events_membership_mk2.sql
diff --git a/synapse/storage/schema/delta/56/destinations_failure_ts.sql b/synapse/storage/data_stores/main/schema/delta/56/destinations_failure_ts.sql
similarity index 100%
rename from synapse/storage/schema/delta/56/destinations_failure_ts.sql
rename to synapse/storage/data_stores/main/schema/delta/56/destinations_failure_ts.sql
diff --git a/synapse/storage/schema/delta/56/destinations_retry_interval_type.sql.postgres b/synapse/storage/data_stores/main/schema/delta/56/destinations_retry_interval_type.sql.postgres
similarity index 100%
rename from synapse/storage/schema/delta/56/destinations_retry_interval_type.sql.postgres
rename to synapse/storage/data_stores/main/schema/delta/56/destinations_retry_interval_type.sql.postgres
diff --git a/synapse/storage/schema/delta/56/devices_last_seen.sql b/synapse/storage/data_stores/main/schema/delta/56/devices_last_seen.sql
similarity index 100%
rename from synapse/storage/schema/delta/56/devices_last_seen.sql
rename to synapse/storage/data_stores/main/schema/delta/56/devices_last_seen.sql
diff --git a/synapse/storage/schema/delta/56/drop_unused_event_tables.sql b/synapse/storage/data_stores/main/schema/delta/56/drop_unused_event_tables.sql
similarity index 100%
rename from synapse/storage/schema/delta/56/drop_unused_event_tables.sql
rename to synapse/storage/data_stores/main/schema/delta/56/drop_unused_event_tables.sql
diff --git a/synapse/storage/schema/delta/56/fix_room_keys_index.sql b/synapse/storage/data_stores/main/schema/delta/56/fix_room_keys_index.sql
similarity index 100%
rename from synapse/storage/schema/delta/56/fix_room_keys_index.sql
rename to synapse/storage/data_stores/main/schema/delta/56/fix_room_keys_index.sql
diff --git a/synapse/storage/schema/delta/56/public_room_list_idx.sql b/synapse/storage/data_stores/main/schema/delta/56/public_room_list_idx.sql
similarity index 100%
rename from synapse/storage/schema/delta/56/public_room_list_idx.sql
rename to synapse/storage/data_stores/main/schema/delta/56/public_room_list_idx.sql
diff --git a/synapse/storage/schema/delta/56/redaction_censor.sql b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql
similarity index 100%
rename from synapse/storage/schema/delta/56/redaction_censor.sql
rename to synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql
diff --git a/synapse/storage/schema/delta/56/redaction_censor2.sql b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql
similarity index 100%
rename from synapse/storage/schema/delta/56/redaction_censor2.sql
rename to synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql
diff --git a/synapse/storage/schema/delta/56/redaction_censor3_fix_update.sql.postgres b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor3_fix_update.sql.postgres
similarity index 100%
rename from synapse/storage/schema/delta/56/redaction_censor3_fix_update.sql.postgres
rename to synapse/storage/data_stores/main/schema/delta/56/redaction_censor3_fix_update.sql.postgres
diff --git a/synapse/storage/schema/delta/56/room_membership_idx.sql b/synapse/storage/data_stores/main/schema/delta/56/room_membership_idx.sql
similarity index 100%
rename from synapse/storage/schema/delta/56/room_membership_idx.sql
rename to synapse/storage/data_stores/main/schema/delta/56/room_membership_idx.sql
diff --git a/synapse/storage/schema/delta/56/stats_separated.sql b/synapse/storage/data_stores/main/schema/delta/56/stats_separated.sql
similarity index 100%
rename from synapse/storage/schema/delta/56/stats_separated.sql
rename to synapse/storage/data_stores/main/schema/delta/56/stats_separated.sql
diff --git a/synapse/storage/schema/delta/56/unique_user_filter_index.py b/synapse/storage/data_stores/main/schema/delta/56/unique_user_filter_index.py
similarity index 100%
rename from synapse/storage/schema/delta/56/unique_user_filter_index.py
rename to synapse/storage/data_stores/main/schema/delta/56/unique_user_filter_index.py
diff --git a/synapse/storage/schema/delta/56/user_external_ids.sql b/synapse/storage/data_stores/main/schema/delta/56/user_external_ids.sql
similarity index 100%
rename from synapse/storage/schema/delta/56/user_external_ids.sql
rename to synapse/storage/data_stores/main/schema/delta/56/user_external_ids.sql
diff --git a/synapse/storage/schema/delta/56/users_in_public_rooms_idx.sql b/synapse/storage/data_stores/main/schema/delta/56/users_in_public_rooms_idx.sql
similarity index 100%
rename from synapse/storage/schema/delta/56/users_in_public_rooms_idx.sql
rename to synapse/storage/data_stores/main/schema/delta/56/users_in_public_rooms_idx.sql
diff --git a/synapse/storage/schema/full_schemas/16/application_services.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/application_services.sql
similarity index 100%
rename from synapse/storage/schema/full_schemas/16/application_services.sql
rename to synapse/storage/data_stores/main/schema/full_schemas/16/application_services.sql
diff --git a/synapse/storage/schema/full_schemas/16/event_edges.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/event_edges.sql
similarity index 100%
rename from synapse/storage/schema/full_schemas/16/event_edges.sql
rename to synapse/storage/data_stores/main/schema/full_schemas/16/event_edges.sql
diff --git a/synapse/storage/schema/full_schemas/16/event_signatures.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/event_signatures.sql
similarity index 100%
rename from synapse/storage/schema/full_schemas/16/event_signatures.sql
rename to synapse/storage/data_stores/main/schema/full_schemas/16/event_signatures.sql
diff --git a/synapse/storage/schema/full_schemas/16/im.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/im.sql
similarity index 100%
rename from synapse/storage/schema/full_schemas/16/im.sql
rename to synapse/storage/data_stores/main/schema/full_schemas/16/im.sql
diff --git a/synapse/storage/schema/full_schemas/16/keys.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/keys.sql
similarity index 100%
rename from synapse/storage/schema/full_schemas/16/keys.sql
rename to synapse/storage/data_stores/main/schema/full_schemas/16/keys.sql
diff --git a/synapse/storage/schema/full_schemas/16/media_repository.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/media_repository.sql
similarity index 100%
rename from synapse/storage/schema/full_schemas/16/media_repository.sql
rename to synapse/storage/data_stores/main/schema/full_schemas/16/media_repository.sql
diff --git a/synapse/storage/schema/full_schemas/16/presence.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/presence.sql
similarity index 100%
rename from synapse/storage/schema/full_schemas/16/presence.sql
rename to synapse/storage/data_stores/main/schema/full_schemas/16/presence.sql
diff --git a/synapse/storage/schema/full_schemas/16/profiles.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/profiles.sql
similarity index 100%
rename from synapse/storage/schema/full_schemas/16/profiles.sql
rename to synapse/storage/data_stores/main/schema/full_schemas/16/profiles.sql
diff --git a/synapse/storage/schema/full_schemas/16/push.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/push.sql
similarity index 100%
rename from synapse/storage/schema/full_schemas/16/push.sql
rename to synapse/storage/data_stores/main/schema/full_schemas/16/push.sql
diff --git a/synapse/storage/schema/full_schemas/16/redactions.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/redactions.sql
similarity index 100%
rename from synapse/storage/schema/full_schemas/16/redactions.sql
rename to synapse/storage/data_stores/main/schema/full_schemas/16/redactions.sql
diff --git a/synapse/storage/schema/full_schemas/16/room_aliases.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/room_aliases.sql
similarity index 100%
rename from synapse/storage/schema/full_schemas/16/room_aliases.sql
rename to synapse/storage/data_stores/main/schema/full_schemas/16/room_aliases.sql
diff --git a/synapse/storage/schema/full_schemas/16/state.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/state.sql
similarity index 100%
rename from synapse/storage/schema/full_schemas/16/state.sql
rename to synapse/storage/data_stores/main/schema/full_schemas/16/state.sql
diff --git a/synapse/storage/schema/full_schemas/16/transactions.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/transactions.sql
similarity index 100%
rename from synapse/storage/schema/full_schemas/16/transactions.sql
rename to synapse/storage/data_stores/main/schema/full_schemas/16/transactions.sql
diff --git a/synapse/storage/schema/full_schemas/16/users.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/users.sql
similarity index 100%
rename from synapse/storage/schema/full_schemas/16/users.sql
rename to synapse/storage/data_stores/main/schema/full_schemas/16/users.sql
diff --git a/synapse/storage/schema/full_schemas/54/full.sql.postgres b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres
similarity index 99%
rename from synapse/storage/schema/full_schemas/54/full.sql.postgres
rename to synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres
index 098434356f..4ad2929f32 100644
--- a/synapse/storage/schema/full_schemas/54/full.sql.postgres
+++ b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres
@@ -70,15 +70,6 @@ CREATE TABLE appservice_stream_position (
 );
 
 
-
-CREATE TABLE background_updates (
-    update_name text NOT NULL,
-    progress_json text NOT NULL,
-    depends_on text
-);
-
-
-
 CREATE TABLE blocked_rooms (
     room_id text NOT NULL,
     user_id text NOT NULL
@@ -1202,11 +1193,6 @@ ALTER TABLE ONLY appservice_stream_position
 
 
 
-ALTER TABLE ONLY background_updates
-    ADD CONSTRAINT background_updates_uniqueness UNIQUE (update_name);
-
-
-
 ALTER TABLE ONLY current_state_events
     ADD CONSTRAINT current_state_events_event_id_key UNIQUE (event_id);
 
@@ -2047,6 +2033,3 @@ CREATE INDEX users_who_share_private_rooms_r_idx ON users_who_share_private_room
 
 
 CREATE UNIQUE INDEX users_who_share_private_rooms_u_idx ON users_who_share_private_rooms USING btree (user_id, other_user_id, room_id);
-
-
-
diff --git a/synapse/storage/schema/full_schemas/54/full.sql.sqlite b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite
similarity index 99%
rename from synapse/storage/schema/full_schemas/54/full.sql.sqlite
rename to synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite
index be9295e4c9..bad33291e7 100644
--- a/synapse/storage/schema/full_schemas/54/full.sql.sqlite
+++ b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite
@@ -67,7 +67,6 @@ CREATE INDEX receipts_linearized_id ON receipts_linearized( stream_id );
 CREATE INDEX receipts_linearized_room_stream ON receipts_linearized( room_id, stream_id );
 CREATE TABLE IF NOT EXISTS "user_threepids" ( user_id TEXT NOT NULL, medium TEXT NOT NULL, address TEXT NOT NULL, validated_at BIGINT NOT NULL, added_at BIGINT NOT NULL, CONSTRAINT medium_address UNIQUE (medium, address) );
 CREATE INDEX user_threepids_user_id ON user_threepids(user_id);
-CREATE TABLE background_updates( update_name TEXT NOT NULL, progress_json TEXT NOT NULL, depends_on TEXT, CONSTRAINT background_updates_uniqueness UNIQUE (update_name) );
 CREATE VIRTUAL TABLE event_search USING fts4 ( event_id, room_id, sender, key, value )
 /* event_search(event_id,room_id,sender,"key",value) */;
 CREATE TABLE IF NOT EXISTS 'event_search_content'(docid INTEGER PRIMARY KEY, 'c0event_id', 'c1room_id', 'c2sender', 'c3key', 'c4value');
diff --git a/synapse/storage/schema/full_schemas/54/stream_positions.sql b/synapse/storage/data_stores/main/schema/full_schemas/54/stream_positions.sql
similarity index 100%
rename from synapse/storage/schema/full_schemas/54/stream_positions.sql
rename to synapse/storage/data_stores/main/schema/full_schemas/54/stream_positions.sql
diff --git a/synapse/storage/schema/full_schemas/README.txt b/synapse/storage/data_stores/main/schema/full_schemas/README.txt
similarity index 100%
rename from synapse/storage/schema/full_schemas/README.txt
rename to synapse/storage/data_stores/main/schema/full_schemas/README.txt
diff --git a/synapse/storage/search.py b/synapse/storage/data_stores/main/search.py
similarity index 99%
rename from synapse/storage/search.py
rename to synapse/storage/data_stores/main/search.py
index 7695bf09fc..0e08497452 100644
--- a/synapse/storage/search.py
+++ b/synapse/storage/data_stores/main/search.py
@@ -25,10 +25,9 @@ from twisted.internet import defer
 
 from synapse.api.errors import SynapseError
 from synapse.storage._base import make_in_list_sql_clause
+from synapse.storage.background_updates import BackgroundUpdateStore
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 
-from .background_updates import BackgroundUpdateStore
-
 logger = logging.getLogger(__name__)
 
 SearchEntry = namedtuple(
diff --git a/synapse/storage/signatures.py b/synapse/storage/data_stores/main/signatures.py
similarity index 98%
rename from synapse/storage/signatures.py
rename to synapse/storage/data_stores/main/signatures.py
index fb83218f90..556191b76f 100644
--- a/synapse/storage/signatures.py
+++ b/synapse/storage/data_stores/main/signatures.py
@@ -20,10 +20,9 @@ from unpaddedbase64 import encode_base64
 from twisted.internet import defer
 
 from synapse.crypto.event_signing import compute_event_reference_hash
+from synapse.storage._base import SQLBaseStore
 from synapse.util.caches.descriptors import cached, cachedList
 
-from ._base import SQLBaseStore
-
 # py2 sqlite has buffer hardcoded as only binary type, so we must use it,
 # despite being deprecated and removed in favor of memoryview
 if six.PY2:
diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py
new file mode 100644
index 0000000000..d54442e5fa
--- /dev/null
+++ b/synapse/storage/data_stores/main/state.py
@@ -0,0 +1,1244 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from collections import namedtuple
+
+from six import iteritems, itervalues
+from six.moves import range
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes
+from synapse.api.errors import NotFoundError
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.background_updates import BackgroundUpdateStore
+from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.state import StateFilter
+from synapse.util.caches import get_cache_factor_for, intern_string
+from synapse.util.caches.descriptors import cached, cachedList
+from synapse.util.caches.dictionary_cache import DictionaryCache
+from synapse.util.stringutils import to_ascii
+
+logger = logging.getLogger(__name__)
+
+
+MAX_STATE_DELTA_HOPS = 100
+
+
+class _GetStateGroupDelta(
+    namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids"))
+):
+    """Return type of get_state_group_delta that implements __len__, which lets
+    us use the itrable flag when caching
+    """
+
+    __slots__ = []
+
+    def __len__(self):
+        return len(self.delta_ids) if self.delta_ids else 0
+
+
+class StateGroupBackgroundUpdateStore(SQLBaseStore):
+    """Defines functions related to state groups needed to run the state backgroud
+    updates.
+    """
+
+    def _count_state_group_hops_txn(self, txn, state_group):
+        """Given a state group, count how many hops there are in the tree.
+
+        This is used to ensure the delta chains don't get too long.
+        """
+        if isinstance(self.database_engine, PostgresEngine):
+            sql = """
+                WITH RECURSIVE state(state_group) AS (
+                    VALUES(?::bigint)
+                    UNION ALL
+                    SELECT prev_state_group FROM state_group_edges e, state s
+                    WHERE s.state_group = e.state_group
+                )
+                SELECT count(*) FROM state;
+            """
+
+            txn.execute(sql, (state_group,))
+            row = txn.fetchone()
+            if row and row[0]:
+                return row[0]
+            else:
+                return 0
+        else:
+            # We don't use WITH RECURSIVE on sqlite3 as there are distributions
+            # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
+            next_group = state_group
+            count = 0
+
+            while next_group:
+                next_group = self._simple_select_one_onecol_txn(
+                    txn,
+                    table="state_group_edges",
+                    keyvalues={"state_group": next_group},
+                    retcol="prev_state_group",
+                    allow_none=True,
+                )
+                if next_group:
+                    count += 1
+
+            return count
+
+    def _get_state_groups_from_groups_txn(
+        self, txn, groups, state_filter=StateFilter.all()
+    ):
+        results = {group: {} for group in groups}
+
+        where_clause, where_args = state_filter.make_sql_filter_clause()
+
+        # Unless the filter clause is empty, we're going to append it after an
+        # existing where clause
+        if where_clause:
+            where_clause = " AND (%s)" % (where_clause,)
+
+        if isinstance(self.database_engine, PostgresEngine):
+            # Temporarily disable sequential scans in this transaction. This is
+            # a temporary hack until we can add the right indices in
+            txn.execute("SET LOCAL enable_seqscan=off")
+
+            # The below query walks the state_group tree so that the "state"
+            # table includes all state_groups in the tree. It then joins
+            # against `state_groups_state` to fetch the latest state.
+            # It assumes that previous state groups are always numerically
+            # lesser.
+            # The PARTITION is used to get the event_id in the greatest state
+            # group for the given type, state_key.
+            # This may return multiple rows per (type, state_key), but last_value
+            # should be the same.
+            sql = """
+                WITH RECURSIVE state(state_group) AS (
+                    VALUES(?::bigint)
+                    UNION ALL
+                    SELECT prev_state_group FROM state_group_edges e, state s
+                    WHERE s.state_group = e.state_group
+                )
+                SELECT DISTINCT type, state_key, last_value(event_id) OVER (
+                    PARTITION BY type, state_key ORDER BY state_group ASC
+                    ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
+                ) AS event_id FROM state_groups_state
+                WHERE state_group IN (
+                    SELECT state_group FROM state
+                )
+            """
+
+            for group in groups:
+                args = [group]
+                args.extend(where_args)
+
+                txn.execute(sql + where_clause, args)
+                for row in txn:
+                    typ, state_key, event_id = row
+                    key = (typ, state_key)
+                    results[group][key] = event_id
+        else:
+            max_entries_returned = state_filter.max_entries_returned()
+
+            # We don't use WITH RECURSIVE on sqlite3 as there are distributions
+            # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
+            for group in groups:
+                next_group = group
+
+                while next_group:
+                    # We did this before by getting the list of group ids, and
+                    # then passing that list to sqlite to get latest event for
+                    # each (type, state_key). However, that was terribly slow
+                    # without the right indices (which we can't add until
+                    # after we finish deduping state, which requires this func)
+                    args = [next_group]
+                    args.extend(where_args)
+
+                    txn.execute(
+                        "SELECT type, state_key, event_id FROM state_groups_state"
+                        " WHERE state_group = ? " + where_clause,
+                        args,
+                    )
+                    results[group].update(
+                        ((typ, state_key), event_id)
+                        for typ, state_key, event_id in txn
+                        if (typ, state_key) not in results[group]
+                    )
+
+                    # If the number of entries in the (type,state_key)->event_id dict
+                    # matches the number of (type,state_keys) types we were searching
+                    # for, then we must have found them all, so no need to go walk
+                    # further down the tree... UNLESS our types filter contained
+                    # wildcards (i.e. Nones) in which case we have to do an exhaustive
+                    # search
+                    if (
+                        max_entries_returned is not None
+                        and len(results[group]) == max_entries_returned
+                    ):
+                        break
+
+                    next_group = self._simple_select_one_onecol_txn(
+                        txn,
+                        table="state_group_edges",
+                        keyvalues={"state_group": next_group},
+                        retcol="prev_state_group",
+                        allow_none=True,
+                    )
+
+        return results
+
+
+# this inherits from EventsWorkerStore because it calls self.get_events
+class StateGroupWorkerStore(
+    EventsWorkerStore, StateGroupBackgroundUpdateStore, SQLBaseStore
+):
+    """The parts of StateGroupStore that can be called from workers.
+    """
+
+    STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
+    STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
+    CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
+
+    def __init__(self, db_conn, hs):
+        super(StateGroupWorkerStore, self).__init__(db_conn, hs)
+
+        # Originally the state store used a single DictionaryCache to cache the
+        # event IDs for the state types in a given state group to avoid hammering
+        # on the state_group* tables.
+        #
+        # The point of using a DictionaryCache is that it can cache a subset
+        # of the state events for a given state group (i.e. a subset of the keys for a
+        # given dict which is an entry in the cache for a given state group ID).
+        #
+        # However, this poses problems when performing complicated queries
+        # on the store - for instance: "give me all the state for this group, but
+        # limit members to this subset of users", as DictionaryCache's API isn't
+        # rich enough to say "please cache any of these fields, apart from this subset".
+        # This is problematic when lazy loading members, which requires this behaviour,
+        # as without it the cache has no choice but to speculatively load all
+        # state events for the group, which negates the efficiency being sought.
+        #
+        # Rather than overcomplicating DictionaryCache's API, we instead split the
+        # state_group_cache into two halves - one for tracking non-member events,
+        # and the other for tracking member_events.  This means that lazy loading
+        # queries can be made in a cache-friendly manner by querying both caches
+        # separately and then merging the result.  So for the example above, you
+        # would query the members cache for a specific subset of state keys
+        # (which DictionaryCache will handle efficiently and fine) and the non-members
+        # cache for all state (which DictionaryCache will similarly handle fine)
+        # and then just merge the results together.
+        #
+        # We size the non-members cache to be smaller than the members cache as the
+        # vast majority of state in Matrix (today) is member events.
+
+        self._state_group_cache = DictionaryCache(
+            "*stateGroupCache*",
+            # TODO: this hasn't been tuned yet
+            50000 * get_cache_factor_for("stateGroupCache"),
+        )
+        self._state_group_members_cache = DictionaryCache(
+            "*stateGroupMembersCache*",
+            500000 * get_cache_factor_for("stateGroupMembersCache"),
+        )
+
+    @defer.inlineCallbacks
+    def get_room_version(self, room_id):
+        """Get the room_version of a given room
+
+        Args:
+            room_id (str)
+
+        Returns:
+            Deferred[str]
+
+        Raises:
+            NotFoundError if the room is unknown
+        """
+        # for now we do this by looking at the create event. We may want to cache this
+        # more intelligently in future.
+
+        # Retrieve the room's create event
+        create_event = yield self.get_create_event_for_room(room_id)
+        return create_event.content.get("room_version", "1")
+
+    @defer.inlineCallbacks
+    def get_room_predecessor(self, room_id):
+        """Get the predecessor room of an upgraded room if one exists.
+        Otherwise return None.
+
+        Args:
+            room_id (str)
+
+        Returns:
+            Deferred[unicode|None]: predecessor room id
+
+        Raises:
+            NotFoundError if the room is unknown
+        """
+        # Retrieve the room's create event
+        create_event = yield self.get_create_event_for_room(room_id)
+
+        # Return predecessor if present
+        return create_event.content.get("predecessor", None)
+
+    @defer.inlineCallbacks
+    def get_create_event_for_room(self, room_id):
+        """Get the create state event for a room.
+
+        Args:
+            room_id (str)
+
+        Returns:
+            Deferred[EventBase]: The room creation event.
+
+        Raises:
+            NotFoundError if the room is unknown
+        """
+        state_ids = yield self.get_current_state_ids(room_id)
+        create_id = state_ids.get((EventTypes.Create, ""))
+
+        # If we can't find the create event, assume we've hit a dead end
+        if not create_id:
+            raise NotFoundError("Unknown room %s" % (room_id))
+
+        # Retrieve the room's create event and return
+        create_event = yield self.get_event(create_id)
+        return create_event
+
+    @cached(max_entries=100000, iterable=True)
+    def get_current_state_ids(self, room_id):
+        """Get the current state event ids for a room based on the
+        current_state_events table.
+
+        Args:
+            room_id (str)
+
+        Returns:
+            deferred: dict of (type, state_key) -> event_id
+        """
+
+        def _get_current_state_ids_txn(txn):
+            txn.execute(
+                """SELECT type, state_key, event_id FROM current_state_events
+                WHERE room_id = ?
+                """,
+                (room_id,),
+            )
+
+            return {
+                (intern_string(r[0]), intern_string(r[1])): to_ascii(r[2]) for r in txn
+            }
+
+        return self.runInteraction("get_current_state_ids", _get_current_state_ids_txn)
+
+    # FIXME: how should this be cached?
+    def get_filtered_current_state_ids(self, room_id, state_filter=StateFilter.all()):
+        """Get the current state event of a given type for a room based on the
+        current_state_events table.  This may not be as up-to-date as the result
+        of doing a fresh state resolution as per state_handler.get_current_state
+
+        Args:
+            room_id (str)
+            state_filter (StateFilter): The state filter used to fetch state
+                from the database.
+
+        Returns:
+            Deferred[dict[tuple[str, str], str]]: Map from type/state_key to
+            event ID.
+        """
+
+        where_clause, where_args = state_filter.make_sql_filter_clause()
+
+        if not where_clause:
+            # We delegate to the cached version
+            return self.get_current_state_ids(room_id)
+
+        def _get_filtered_current_state_ids_txn(txn):
+            results = {}
+            sql = """
+                SELECT type, state_key, event_id FROM current_state_events
+                WHERE room_id = ?
+            """
+
+            if where_clause:
+                sql += " AND (%s)" % (where_clause,)
+
+            args = [room_id]
+            args.extend(where_args)
+            txn.execute(sql, args)
+            for row in txn:
+                typ, state_key, event_id = row
+                key = (intern_string(typ), intern_string(state_key))
+                results[key] = event_id
+
+            return results
+
+        return self.runInteraction(
+            "get_filtered_current_state_ids", _get_filtered_current_state_ids_txn
+        )
+
+    @defer.inlineCallbacks
+    def get_canonical_alias_for_room(self, room_id):
+        """Get canonical alias for room, if any
+
+        Args:
+            room_id (str)
+
+        Returns:
+            Deferred[str|None]: The canonical alias, if any
+        """
+
+        state = yield self.get_filtered_current_state_ids(
+            room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")])
+        )
+
+        event_id = state.get((EventTypes.CanonicalAlias, ""))
+        if not event_id:
+            return
+
+        event = yield self.get_event(event_id, allow_none=True)
+        if not event:
+            return
+
+        return event.content.get("canonical_alias")
+
+    @cached(max_entries=10000, iterable=True)
+    def get_state_group_delta(self, state_group):
+        """Given a state group try to return a previous group and a delta between
+        the old and the new.
+
+        Returns:
+            (prev_group, delta_ids), where both may be None.
+        """
+
+        def _get_state_group_delta_txn(txn):
+            prev_group = self._simple_select_one_onecol_txn(
+                txn,
+                table="state_group_edges",
+                keyvalues={"state_group": state_group},
+                retcol="prev_state_group",
+                allow_none=True,
+            )
+
+            if not prev_group:
+                return _GetStateGroupDelta(None, None)
+
+            delta_ids = self._simple_select_list_txn(
+                txn,
+                table="state_groups_state",
+                keyvalues={"state_group": state_group},
+                retcols=("type", "state_key", "event_id"),
+            )
+
+            return _GetStateGroupDelta(
+                prev_group,
+                {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids},
+            )
+
+        return self.runInteraction("get_state_group_delta", _get_state_group_delta_txn)
+
+    @defer.inlineCallbacks
+    def get_state_groups_ids(self, _room_id, event_ids):
+        """Get the event IDs of all the state for the state groups for the given events
+
+        Args:
+            _room_id (str): id of the room for these events
+            event_ids (iterable[str]): ids of the events
+
+        Returns:
+            Deferred[dict[int, dict[tuple[str, str], str]]]:
+                dict of state_group_id -> (dict of (type, state_key) -> event id)
+        """
+        if not event_ids:
+            return {}
+
+        event_to_groups = yield self._get_state_group_for_events(event_ids)
+
+        groups = set(itervalues(event_to_groups))
+        group_to_state = yield self._get_state_for_groups(groups)
+
+        return group_to_state
+
+    @defer.inlineCallbacks
+    def get_state_ids_for_group(self, state_group):
+        """Get the event IDs of all the state in the given state group
+
+        Args:
+            state_group (int)
+
+        Returns:
+            Deferred[dict]: Resolves to a map of (type, state_key) -> event_id
+        """
+        group_to_state = yield self._get_state_for_groups((state_group,))
+
+        return group_to_state[state_group]
+
+    @defer.inlineCallbacks
+    def get_state_groups(self, room_id, event_ids):
+        """ Get the state groups for the given list of event_ids
+
+        Returns:
+            Deferred[dict[int, list[EventBase]]]:
+                dict of state_group_id -> list of state events.
+        """
+        if not event_ids:
+            return {}
+
+        group_to_ids = yield self.get_state_groups_ids(room_id, event_ids)
+
+        state_event_map = yield self.get_events(
+            [
+                ev_id
+                for group_ids in itervalues(group_to_ids)
+                for ev_id in itervalues(group_ids)
+            ],
+            get_prev_content=False,
+        )
+
+        return {
+            group: [
+                state_event_map[v]
+                for v in itervalues(event_id_map)
+                if v in state_event_map
+            ]
+            for group, event_id_map in iteritems(group_to_ids)
+        }
+
+    @defer.inlineCallbacks
+    def _get_state_groups_from_groups(self, groups, state_filter):
+        """Returns the state groups for a given set of groups, filtering on
+        types of state events.
+
+        Args:
+            groups(list[int]): list of state group IDs to query
+            state_filter (StateFilter): The state filter used to fetch state
+                from the database.
+        Returns:
+            Deferred[dict[int, dict[tuple[str, str], str]]]:
+                dict of state_group_id -> (dict of (type, state_key) -> event id)
+        """
+        results = {}
+
+        chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)]
+        for chunk in chunks:
+            res = yield self.runInteraction(
+                "_get_state_groups_from_groups",
+                self._get_state_groups_from_groups_txn,
+                chunk,
+                state_filter,
+            )
+            results.update(res)
+
+        return results
+
+    @defer.inlineCallbacks
+    def get_state_for_events(self, event_ids, state_filter=StateFilter.all()):
+        """Given a list of event_ids and type tuples, return a list of state
+        dicts for each event.
+
+        Args:
+            event_ids (list[string])
+            state_filter (StateFilter): The state filter used to fetch state
+                from the database.
+
+        Returns:
+            deferred: A dict of (event_id) -> (type, state_key) -> [state_events]
+        """
+        event_to_groups = yield self._get_state_group_for_events(event_ids)
+
+        groups = set(itervalues(event_to_groups))
+        group_to_state = yield self._get_state_for_groups(groups, state_filter)
+
+        state_event_map = yield self.get_events(
+            [ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)],
+            get_prev_content=False,
+        )
+
+        event_to_state = {
+            event_id: {
+                k: state_event_map[v]
+                for k, v in iteritems(group_to_state[group])
+                if v in state_event_map
+            }
+            for event_id, group in iteritems(event_to_groups)
+        }
+
+        return {event: event_to_state[event] for event in event_ids}
+
+    @defer.inlineCallbacks
+    def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()):
+        """
+        Get the state dicts corresponding to a list of events, containing the event_ids
+        of the state events (as opposed to the events themselves)
+
+        Args:
+            event_ids(list(str)): events whose state should be returned
+            state_filter (StateFilter): The state filter used to fetch state
+                from the database.
+
+        Returns:
+            A deferred dict from event_id -> (type, state_key) -> event_id
+        """
+        event_to_groups = yield self._get_state_group_for_events(event_ids)
+
+        groups = set(itervalues(event_to_groups))
+        group_to_state = yield self._get_state_for_groups(groups, state_filter)
+
+        event_to_state = {
+            event_id: group_to_state[group]
+            for event_id, group in iteritems(event_to_groups)
+        }
+
+        return {event: event_to_state[event] for event in event_ids}
+
+    @defer.inlineCallbacks
+    def get_state_for_event(self, event_id, state_filter=StateFilter.all()):
+        """
+        Get the state dict corresponding to a particular event
+
+        Args:
+            event_id(str): event whose state should be returned
+            state_filter (StateFilter): The state filter used to fetch state
+                from the database.
+
+        Returns:
+            A deferred dict from (type, state_key) -> state_event
+        """
+        state_map = yield self.get_state_for_events([event_id], state_filter)
+        return state_map[event_id]
+
+    @defer.inlineCallbacks
+    def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()):
+        """
+        Get the state dict corresponding to a particular event
+
+        Args:
+            event_id(str): event whose state should be returned
+            state_filter (StateFilter): The state filter used to fetch state
+                from the database.
+
+        Returns:
+            A deferred dict from (type, state_key) -> state_event
+        """
+        state_map = yield self.get_state_ids_for_events([event_id], state_filter)
+        return state_map[event_id]
+
+    @cached(max_entries=50000)
+    def _get_state_group_for_event(self, event_id):
+        return self._simple_select_one_onecol(
+            table="event_to_state_groups",
+            keyvalues={"event_id": event_id},
+            retcol="state_group",
+            allow_none=True,
+            desc="_get_state_group_for_event",
+        )
+
+    @cachedList(
+        cached_method_name="_get_state_group_for_event",
+        list_name="event_ids",
+        num_args=1,
+        inlineCallbacks=True,
+    )
+    def _get_state_group_for_events(self, event_ids):
+        """Returns mapping event_id -> state_group
+        """
+        rows = yield self._simple_select_many_batch(
+            table="event_to_state_groups",
+            column="event_id",
+            iterable=event_ids,
+            keyvalues={},
+            retcols=("event_id", "state_group"),
+            desc="_get_state_group_for_events",
+        )
+
+        return {row["event_id"]: row["state_group"] for row in rows}
+
+    def _get_state_for_group_using_cache(self, cache, group, state_filter):
+        """Checks if group is in cache. See `_get_state_for_groups`
+
+        Args:
+            cache(DictionaryCache): the state group cache to use
+            group(int): The state group to lookup
+            state_filter (StateFilter): The state filter used to fetch state
+                from the database.
+
+        Returns 2-tuple (`state_dict`, `got_all`).
+        `got_all` is a bool indicating if we successfully retrieved all
+        requests state from the cache, if False we need to query the DB for the
+        missing state.
+        """
+        is_all, known_absent, state_dict_ids = cache.get(group)
+
+        if is_all or state_filter.is_full():
+            # Either we have everything or want everything, either way
+            # `is_all` tells us whether we've gotten everything.
+            return state_filter.filter_state(state_dict_ids), is_all
+
+        # tracks whether any of our requested types are missing from the cache
+        missing_types = False
+
+        if state_filter.has_wildcards():
+            # We don't know if we fetched all the state keys for the types in
+            # the filter that are wildcards, so we have to assume that we may
+            # have missed some.
+            missing_types = True
+        else:
+            # There aren't any wild cards, so `concrete_types()` returns the
+            # complete list of event types we're wanting.
+            for key in state_filter.concrete_types():
+                if key not in state_dict_ids and key not in known_absent:
+                    missing_types = True
+                    break
+
+        return state_filter.filter_state(state_dict_ids), not missing_types
+
+    @defer.inlineCallbacks
+    def _get_state_for_groups(self, groups, state_filter=StateFilter.all()):
+        """Gets the state at each of a list of state groups, optionally
+        filtering by type/state_key
+
+        Args:
+            groups (iterable[int]): list of state groups for which we want
+                to get the state.
+            state_filter (StateFilter): The state filter used to fetch state
+                from the database.
+        Returns:
+            Deferred[dict[int, dict[tuple[str, str], str]]]:
+                dict of state_group_id -> (dict of (type, state_key) -> event id)
+        """
+
+        member_filter, non_member_filter = state_filter.get_member_split()
+
+        # Now we look them up in the member and non-member caches
+        non_member_state, incomplete_groups_nm, = (
+            yield self._get_state_for_groups_using_cache(
+                groups, self._state_group_cache, state_filter=non_member_filter
+            )
+        )
+
+        member_state, incomplete_groups_m, = (
+            yield self._get_state_for_groups_using_cache(
+                groups, self._state_group_members_cache, state_filter=member_filter
+            )
+        )
+
+        state = dict(non_member_state)
+        for group in groups:
+            state[group].update(member_state[group])
+
+        # Now fetch any missing groups from the database
+
+        incomplete_groups = incomplete_groups_m | incomplete_groups_nm
+
+        if not incomplete_groups:
+            return state
+
+        cache_sequence_nm = self._state_group_cache.sequence
+        cache_sequence_m = self._state_group_members_cache.sequence
+
+        # Help the cache hit ratio by expanding the filter a bit
+        db_state_filter = state_filter.return_expanded()
+
+        group_to_state_dict = yield self._get_state_groups_from_groups(
+            list(incomplete_groups), state_filter=db_state_filter
+        )
+
+        # Now lets update the caches
+        self._insert_into_cache(
+            group_to_state_dict,
+            db_state_filter,
+            cache_seq_num_members=cache_sequence_m,
+            cache_seq_num_non_members=cache_sequence_nm,
+        )
+
+        # And finally update the result dict, by filtering out any extra
+        # stuff we pulled out of the database.
+        for group, group_state_dict in iteritems(group_to_state_dict):
+            # We just replace any existing entries, as we will have loaded
+            # everything we need from the database anyway.
+            state[group] = state_filter.filter_state(group_state_dict)
+
+        return state
+
+    def _get_state_for_groups_using_cache(self, groups, cache, state_filter):
+        """Gets the state at each of a list of state groups, optionally
+        filtering by type/state_key, querying from a specific cache.
+
+        Args:
+            groups (iterable[int]): list of state groups for which we want
+                to get the state.
+            cache (DictionaryCache): the cache of group ids to state dicts which
+                we will pass through - either the normal state cache or the specific
+                members state cache.
+            state_filter (StateFilter): The state filter used to fetch state
+                from the database.
+
+        Returns:
+            tuple[dict[int, dict[tuple[str, str], str]], set[int]]: Tuple of
+            dict of state_group_id -> (dict of (type, state_key) -> event id)
+            of entries in the cache, and the state group ids either missing
+            from the cache or incomplete.
+        """
+        results = {}
+        incomplete_groups = set()
+        for group in set(groups):
+            state_dict_ids, got_all = self._get_state_for_group_using_cache(
+                cache, group, state_filter
+            )
+            results[group] = state_dict_ids
+
+            if not got_all:
+                incomplete_groups.add(group)
+
+        return results, incomplete_groups
+
+    def _insert_into_cache(
+        self,
+        group_to_state_dict,
+        state_filter,
+        cache_seq_num_members,
+        cache_seq_num_non_members,
+    ):
+        """Inserts results from querying the database into the relevant cache.
+
+        Args:
+            group_to_state_dict (dict): The new entries pulled from database.
+                Map from state group to state dict
+            state_filter (StateFilter): The state filter used to fetch state
+                from the database.
+            cache_seq_num_members (int): Sequence number of member cache since
+                last lookup in cache
+            cache_seq_num_non_members (int): Sequence number of member cache since
+                last lookup in cache
+        """
+
+        # We need to work out which types we've fetched from the DB for the
+        # member vs non-member caches. This should be as accurate as possible,
+        # but can be an underestimate (e.g. when we have wild cards)
+
+        member_filter, non_member_filter = state_filter.get_member_split()
+        if member_filter.is_full():
+            # We fetched all member events
+            member_types = None
+        else:
+            # `concrete_types()` will only return a subset when there are wild
+            # cards in the filter, but that's fine.
+            member_types = member_filter.concrete_types()
+
+        if non_member_filter.is_full():
+            # We fetched all non member events
+            non_member_types = None
+        else:
+            non_member_types = non_member_filter.concrete_types()
+
+        for group, group_state_dict in iteritems(group_to_state_dict):
+            state_dict_members = {}
+            state_dict_non_members = {}
+
+            for k, v in iteritems(group_state_dict):
+                if k[0] == EventTypes.Member:
+                    state_dict_members[k] = v
+                else:
+                    state_dict_non_members[k] = v
+
+            self._state_group_members_cache.update(
+                cache_seq_num_members,
+                key=group,
+                value=state_dict_members,
+                fetched_keys=member_types,
+            )
+
+            self._state_group_cache.update(
+                cache_seq_num_non_members,
+                key=group,
+                value=state_dict_non_members,
+                fetched_keys=non_member_types,
+            )
+
+    def store_state_group(
+        self, event_id, room_id, prev_group, delta_ids, current_state_ids
+    ):
+        """Store a new set of state, returning a newly assigned state group.
+
+        Args:
+            event_id (str): The event ID for which the state was calculated
+            room_id (str)
+            prev_group (int|None): A previous state group for the room, optional.
+            delta_ids (dict|None): The delta between state at `prev_group` and
+                `current_state_ids`, if `prev_group` was given. Same format as
+                `current_state_ids`.
+            current_state_ids (dict): The state to store. Map of (type, state_key)
+                to event_id.
+
+        Returns:
+            Deferred[int]: The state group ID
+        """
+
+        def _store_state_group_txn(txn):
+            if current_state_ids is None:
+                # AFAIK, this can never happen
+                raise Exception("current_state_ids cannot be None")
+
+            state_group = self.database_engine.get_next_state_group_id(txn)
+
+            self._simple_insert_txn(
+                txn,
+                table="state_groups",
+                values={"id": state_group, "room_id": room_id, "event_id": event_id},
+            )
+
+            # We persist as a delta if we can, while also ensuring the chain
+            # of deltas isn't tooo long, as otherwise read performance degrades.
+            if prev_group:
+                is_in_db = self._simple_select_one_onecol_txn(
+                    txn,
+                    table="state_groups",
+                    keyvalues={"id": prev_group},
+                    retcol="id",
+                    allow_none=True,
+                )
+                if not is_in_db:
+                    raise Exception(
+                        "Trying to persist state with unpersisted prev_group: %r"
+                        % (prev_group,)
+                    )
+
+                potential_hops = self._count_state_group_hops_txn(txn, prev_group)
+            if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
+                self._simple_insert_txn(
+                    txn,
+                    table="state_group_edges",
+                    values={"state_group": state_group, "prev_state_group": prev_group},
+                )
+
+                self._simple_insert_many_txn(
+                    txn,
+                    table="state_groups_state",
+                    values=[
+                        {
+                            "state_group": state_group,
+                            "room_id": room_id,
+                            "type": key[0],
+                            "state_key": key[1],
+                            "event_id": state_id,
+                        }
+                        for key, state_id in iteritems(delta_ids)
+                    ],
+                )
+            else:
+                self._simple_insert_many_txn(
+                    txn,
+                    table="state_groups_state",
+                    values=[
+                        {
+                            "state_group": state_group,
+                            "room_id": room_id,
+                            "type": key[0],
+                            "state_key": key[1],
+                            "event_id": state_id,
+                        }
+                        for key, state_id in iteritems(current_state_ids)
+                    ],
+                )
+
+            # Prefill the state group caches with this group.
+            # It's fine to use the sequence like this as the state group map
+            # is immutable. (If the map wasn't immutable then this prefill could
+            # race with another update)
+
+            current_member_state_ids = {
+                s: ev
+                for (s, ev) in iteritems(current_state_ids)
+                if s[0] == EventTypes.Member
+            }
+            txn.call_after(
+                self._state_group_members_cache.update,
+                self._state_group_members_cache.sequence,
+                key=state_group,
+                value=dict(current_member_state_ids),
+            )
+
+            current_non_member_state_ids = {
+                s: ev
+                for (s, ev) in iteritems(current_state_ids)
+                if s[0] != EventTypes.Member
+            }
+            txn.call_after(
+                self._state_group_cache.update,
+                self._state_group_cache.sequence,
+                key=state_group,
+                value=dict(current_non_member_state_ids),
+            )
+
+            return state_group
+
+        return self.runInteraction("store_state_group", _store_state_group_txn)
+
+
+class StateBackgroundUpdateStore(
+    StateGroupBackgroundUpdateStore, BackgroundUpdateStore
+):
+
+    STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
+    STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
+    CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
+    EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index"
+
+    def __init__(self, db_conn, hs):
+        super(StateBackgroundUpdateStore, self).__init__(db_conn, hs)
+        self.register_background_update_handler(
+            self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
+            self._background_deduplicate_state,
+        )
+        self.register_background_update_handler(
+            self.STATE_GROUP_INDEX_UPDATE_NAME, self._background_index_state
+        )
+        self.register_background_index_update(
+            self.CURRENT_STATE_INDEX_UPDATE_NAME,
+            index_name="current_state_events_member_index",
+            table="current_state_events",
+            columns=["state_key"],
+            where_clause="type='m.room.member'",
+        )
+        self.register_background_index_update(
+            self.EVENT_STATE_GROUP_INDEX_UPDATE_NAME,
+            index_name="event_to_state_groups_sg_index",
+            table="event_to_state_groups",
+            columns=["state_group"],
+        )
+
+    @defer.inlineCallbacks
+    def _background_deduplicate_state(self, progress, batch_size):
+        """This background update will slowly deduplicate state by reencoding
+        them as deltas.
+        """
+        last_state_group = progress.get("last_state_group", 0)
+        rows_inserted = progress.get("rows_inserted", 0)
+        max_group = progress.get("max_group", None)
+
+        BATCH_SIZE_SCALE_FACTOR = 100
+
+        batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR))
+
+        if max_group is None:
+            rows = yield self._execute(
+                "_background_deduplicate_state",
+                None,
+                "SELECT coalesce(max(id), 0) FROM state_groups",
+            )
+            max_group = rows[0][0]
+
+        def reindex_txn(txn):
+            new_last_state_group = last_state_group
+            for count in range(batch_size):
+                txn.execute(
+                    "SELECT id, room_id FROM state_groups"
+                    " WHERE ? < id AND id <= ?"
+                    " ORDER BY id ASC"
+                    " LIMIT 1",
+                    (new_last_state_group, max_group),
+                )
+                row = txn.fetchone()
+                if row:
+                    state_group, room_id = row
+
+                if not row or not state_group:
+                    return True, count
+
+                txn.execute(
+                    "SELECT state_group FROM state_group_edges"
+                    " WHERE state_group = ?",
+                    (state_group,),
+                )
+
+                # If we reach a point where we've already started inserting
+                # edges we should stop.
+                if txn.fetchall():
+                    return True, count
+
+                txn.execute(
+                    "SELECT coalesce(max(id), 0) FROM state_groups"
+                    " WHERE id < ? AND room_id = ?",
+                    (state_group, room_id),
+                )
+                prev_group, = txn.fetchone()
+                new_last_state_group = state_group
+
+                if prev_group:
+                    potential_hops = self._count_state_group_hops_txn(txn, prev_group)
+                    if potential_hops >= MAX_STATE_DELTA_HOPS:
+                        # We want to ensure chains are at most this long,#
+                        # otherwise read performance degrades.
+                        continue
+
+                    prev_state = self._get_state_groups_from_groups_txn(
+                        txn, [prev_group]
+                    )
+                    prev_state = prev_state[prev_group]
+
+                    curr_state = self._get_state_groups_from_groups_txn(
+                        txn, [state_group]
+                    )
+                    curr_state = curr_state[state_group]
+
+                    if not set(prev_state.keys()) - set(curr_state.keys()):
+                        # We can only do a delta if the current has a strict super set
+                        # of keys
+
+                        delta_state = {
+                            key: value
+                            for key, value in iteritems(curr_state)
+                            if prev_state.get(key, None) != value
+                        }
+
+                        self._simple_delete_txn(
+                            txn,
+                            table="state_group_edges",
+                            keyvalues={"state_group": state_group},
+                        )
+
+                        self._simple_insert_txn(
+                            txn,
+                            table="state_group_edges",
+                            values={
+                                "state_group": state_group,
+                                "prev_state_group": prev_group,
+                            },
+                        )
+
+                        self._simple_delete_txn(
+                            txn,
+                            table="state_groups_state",
+                            keyvalues={"state_group": state_group},
+                        )
+
+                        self._simple_insert_many_txn(
+                            txn,
+                            table="state_groups_state",
+                            values=[
+                                {
+                                    "state_group": state_group,
+                                    "room_id": room_id,
+                                    "type": key[0],
+                                    "state_key": key[1],
+                                    "event_id": state_id,
+                                }
+                                for key, state_id in iteritems(delta_state)
+                            ],
+                        )
+
+            progress = {
+                "last_state_group": state_group,
+                "rows_inserted": rows_inserted + batch_size,
+                "max_group": max_group,
+            }
+
+            self._background_update_progress_txn(
+                txn, self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, progress
+            )
+
+            return False, batch_size
+
+        finished, result = yield self.runInteraction(
+            self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, reindex_txn
+        )
+
+        if finished:
+            yield self._end_background_update(
+                self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME
+            )
+
+        return result * BATCH_SIZE_SCALE_FACTOR
+
+    @defer.inlineCallbacks
+    def _background_index_state(self, progress, batch_size):
+        def reindex_txn(conn):
+            conn.rollback()
+            if isinstance(self.database_engine, PostgresEngine):
+                # postgres insists on autocommit for the index
+                conn.set_session(autocommit=True)
+                try:
+                    txn = conn.cursor()
+                    txn.execute(
+                        "CREATE INDEX CONCURRENTLY state_groups_state_type_idx"
+                        " ON state_groups_state(state_group, type, state_key)"
+                    )
+                    txn.execute("DROP INDEX IF EXISTS state_groups_state_id")
+                finally:
+                    conn.set_session(autocommit=False)
+            else:
+                txn = conn.cursor()
+                txn.execute(
+                    "CREATE INDEX state_groups_state_type_idx"
+                    " ON state_groups_state(state_group, type, state_key)"
+                )
+                txn.execute("DROP INDEX IF EXISTS state_groups_state_id")
+
+        yield self.runWithConnection(reindex_txn)
+
+        yield self._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME)
+
+        return 1
+
+
+class StateStore(StateGroupWorkerStore, StateBackgroundUpdateStore):
+    """ Keeps track of the state at a given event.
+
+    This is done by the concept of `state groups`. Every event is a assigned
+    a state group (identified by an arbitrary string), which references a
+    collection of state events. The current state of an event is then the
+    collection of state events referenced by the event's state group.
+
+    Hence, every change in the current state causes a new state group to be
+    generated. However, if no change happens (e.g., if we get a message event
+    with only one parent it inherits the state group from its parent.)
+
+    There are three tables:
+      * `state_groups`: Stores group name, first event with in the group and
+        room id.
+      * `event_to_state_groups`: Maps events to state groups.
+      * `state_groups_state`: Maps state group to state events.
+    """
+
+    def __init__(self, db_conn, hs):
+        super(StateStore, self).__init__(db_conn, hs)
+
+    def _store_event_state_mappings_txn(self, txn, events_and_contexts):
+        state_groups = {}
+        for event, context in events_and_contexts:
+            if event.internal_metadata.is_outlier():
+                continue
+
+            # if the event was rejected, just give it the same state as its
+            # predecessor.
+            if context.rejected:
+                state_groups[event.event_id] = context.prev_group
+                continue
+
+            state_groups[event.event_id] = context.state_group
+
+        self._simple_insert_many_txn(
+            txn,
+            table="event_to_state_groups",
+            values=[
+                {"state_group": state_group_id, "event_id": event_id}
+                for event_id, state_group_id in iteritems(state_groups)
+            ],
+        )
+
+        for event_id, state_group_id in iteritems(state_groups):
+            txn.call_after(
+                self._get_state_group_for_event.prefill, (event_id,), state_group_id
+            )
diff --git a/synapse/storage/state_deltas.py b/synapse/storage/data_stores/main/state_deltas.py
similarity index 100%
rename from synapse/storage/state_deltas.py
rename to synapse/storage/data_stores/main/state_deltas.py
diff --git a/synapse/storage/stats.py b/synapse/storage/data_stores/main/stats.py
similarity index 99%
rename from synapse/storage/stats.py
rename to synapse/storage/data_stores/main/stats.py
index 7c224cd3d9..5ab639b2ad 100644
--- a/synapse/storage/stats.py
+++ b/synapse/storage/data_stores/main/stats.py
@@ -21,8 +21,8 @@ from twisted.internet import defer
 from twisted.internet.defer import DeferredLock
 
 from synapse.api.constants import EventTypes, Membership
-from synapse.storage import PostgresEngine
-from synapse.storage.state_deltas import StateDeltasStore
+from synapse.storage.data_stores.main.state_deltas import StateDeltasStore
+from synapse.storage.engines import PostgresEngine
 from synapse.util.caches.descriptors import cached
 
 logger = logging.getLogger(__name__)
diff --git a/synapse/storage/stream.py b/synapse/storage/data_stores/main/stream.py
similarity index 99%
rename from synapse/storage/stream.py
rename to synapse/storage/data_stores/main/stream.py
index 490454f19a..263999dfca 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/data_stores/main/stream.py
@@ -43,8 +43,8 @@ from twisted.internet import defer
 
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.storage._base import SQLBaseStore
+from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
 from synapse.storage.engines import PostgresEngine
-from synapse.storage.events_worker import EventsWorkerStore
 from synapse.types import RoomStreamToken
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
diff --git a/synapse/storage/tags.py b/synapse/storage/data_stores/main/tags.py
similarity index 99%
rename from synapse/storage/tags.py
rename to synapse/storage/data_stores/main/tags.py
index 20dd6bd53d..10d1887f75 100644
--- a/synapse/storage/tags.py
+++ b/synapse/storage/data_stores/main/tags.py
@@ -22,7 +22,7 @@ from canonicaljson import json
 
 from twisted.internet import defer
 
-from synapse.storage.account_data import AccountDataWorkerStore
+from synapse.storage.data_stores.main.account_data import AccountDataWorkerStore
 from synapse.util.caches.descriptors import cached
 
 logger = logging.getLogger(__name__)
diff --git a/synapse/storage/transactions.py b/synapse/storage/data_stores/main/transactions.py
similarity index 99%
rename from synapse/storage/transactions.py
rename to synapse/storage/data_stores/main/transactions.py
index 289c117396..01b1be5e14 100644
--- a/synapse/storage/transactions.py
+++ b/synapse/storage/data_stores/main/transactions.py
@@ -23,10 +23,9 @@ from canonicaljson import encode_canonical_json
 from twisted.internet import defer
 
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.util.caches.expiringcache import ExpiringCache
 
-from ._base import SQLBaseStore, db_to_json
-
 # py2 sqlite has buffer hardcoded as only binary type, so we must use it,
 # despite being deprecated and removed in favor of memoryview
 if six.PY2:
diff --git a/synapse/storage/user_directory.py b/synapse/storage/data_stores/main/user_directory.py
similarity index 99%
rename from synapse/storage/user_directory.py
rename to synapse/storage/data_stores/main/user_directory.py
index 1b1e4751b9..652abe0e6a 100644
--- a/synapse/storage/user_directory.py
+++ b/synapse/storage/data_stores/main/user_directory.py
@@ -20,9 +20,9 @@ from twisted.internet import defer
 
 from synapse.api.constants import EventTypes, JoinRules
 from synapse.storage.background_updates import BackgroundUpdateStore
+from synapse.storage.data_stores.main.state import StateFilter
+from synapse.storage.data_stores.main.state_deltas import StateDeltasStore
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
-from synapse.storage.state import StateFilter
-from synapse.storage.state_deltas import StateDeltasStore
 from synapse.types import get_domain_from_id, get_localpart_from_id
 from synapse.util.caches.descriptors import cached
 
diff --git a/synapse/storage/user_erasure_store.py b/synapse/storage/data_stores/main/user_erasure_store.py
similarity index 100%
rename from synapse/storage/user_erasure_store.py
rename to synapse/storage/data_stores/main/user_erasure_store.py
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index e72f89e446..4769b21529 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -14,208 +14,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import itertools
 import logging
 
-import six
-
 import attr
-from signedjson.key import decode_verify_key_bytes
-
-from synapse.util import batch_iter
-from synapse.util.caches.descriptors import cached, cachedList
-
-from ._base import SQLBaseStore
 
 logger = logging.getLogger(__name__)
 
-# py2 sqlite has buffer hardcoded as only binary type, so we must use it,
-# despite being deprecated and removed in favor of memoryview
-if six.PY2:
-    db_binary_type = six.moves.builtins.buffer
-else:
-    db_binary_type = memoryview
-
 
 @attr.s(slots=True, frozen=True)
 class FetchKeyResult(object):
     verify_key = attr.ib()  # VerifyKey: the key itself
     valid_until_ts = attr.ib()  # int: how long we can use this key for
-
-
-class KeyStore(SQLBaseStore):
-    """Persistence for signature verification keys
-    """
-
-    @cached()
-    def _get_server_verify_key(self, server_name_and_key_id):
-        raise NotImplementedError()
-
-    @cachedList(
-        cached_method_name="_get_server_verify_key", list_name="server_name_and_key_ids"
-    )
-    def get_server_verify_keys(self, server_name_and_key_ids):
-        """
-        Args:
-            server_name_and_key_ids (iterable[Tuple[str, str]]):
-                iterable of (server_name, key-id) tuples to fetch keys for
-
-        Returns:
-            Deferred: resolves to dict[Tuple[str, str], FetchKeyResult|None]:
-                map from (server_name, key_id) -> FetchKeyResult, or None if the key is
-                unknown
-        """
-        keys = {}
-
-        def _get_keys(txn, batch):
-            """Processes a batch of keys to fetch, and adds the result to `keys`."""
-
-            # batch_iter always returns tuples so it's safe to do len(batch)
-            sql = (
-                "SELECT server_name, key_id, verify_key, ts_valid_until_ms "
-                "FROM server_signature_keys WHERE 1=0"
-            ) + " OR (server_name=? AND key_id=?)" * len(batch)
-
-            txn.execute(sql, tuple(itertools.chain.from_iterable(batch)))
-
-            for row in txn:
-                server_name, key_id, key_bytes, ts_valid_until_ms = row
-
-                if ts_valid_until_ms is None:
-                    # Old keys may be stored with a ts_valid_until_ms of null,
-                    # in which case we treat this as if it was set to `0`, i.e.
-                    # it won't match key requests that define a minimum
-                    # `ts_valid_until_ms`.
-                    ts_valid_until_ms = 0
-
-                res = FetchKeyResult(
-                    verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)),
-                    valid_until_ts=ts_valid_until_ms,
-                )
-                keys[(server_name, key_id)] = res
-
-        def _txn(txn):
-            for batch in batch_iter(server_name_and_key_ids, 50):
-                _get_keys(txn, batch)
-            return keys
-
-        return self.runInteraction("get_server_verify_keys", _txn)
-
-    def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys):
-        """Stores NACL verification keys for remote servers.
-        Args:
-            from_server (str): Where the verification keys were looked up
-            ts_added_ms (int): The time to record that the key was added
-            verify_keys (iterable[tuple[str, str, FetchKeyResult]]):
-                keys to be stored. Each entry is a triplet of
-                (server_name, key_id, key).
-        """
-        key_values = []
-        value_values = []
-        invalidations = []
-        for server_name, key_id, fetch_result in verify_keys:
-            key_values.append((server_name, key_id))
-            value_values.append(
-                (
-                    from_server,
-                    ts_added_ms,
-                    fetch_result.valid_until_ts,
-                    db_binary_type(fetch_result.verify_key.encode()),
-                )
-            )
-            # invalidate takes a tuple corresponding to the params of
-            # _get_server_verify_key. _get_server_verify_key only takes one
-            # param, which is itself the 2-tuple (server_name, key_id).
-            invalidations.append((server_name, key_id))
-
-        def _invalidate(res):
-            f = self._get_server_verify_key.invalidate
-            for i in invalidations:
-                f((i,))
-            return res
-
-        return self.runInteraction(
-            "store_server_verify_keys",
-            self._simple_upsert_many_txn,
-            table="server_signature_keys",
-            key_names=("server_name", "key_id"),
-            key_values=key_values,
-            value_names=(
-                "from_server",
-                "ts_added_ms",
-                "ts_valid_until_ms",
-                "verify_key",
-            ),
-            value_values=value_values,
-        ).addCallback(_invalidate)
-
-    def store_server_keys_json(
-        self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes
-    ):
-        """Stores the JSON bytes for a set of keys from a server
-        The JSON should be signed by the originating server, the intermediate
-        server, and by this server. Updates the value for the
-        (server_name, key_id, from_server) triplet if one already existed.
-        Args:
-            server_name (str): The name of the server.
-            key_id (str): The identifer of the key this JSON is for.
-            from_server (str): The server this JSON was fetched from.
-            ts_now_ms (int): The time now in milliseconds.
-            ts_valid_until_ms (int): The time when this json stops being valid.
-            key_json (bytes): The encoded JSON.
-        """
-        return self._simple_upsert(
-            table="server_keys_json",
-            keyvalues={
-                "server_name": server_name,
-                "key_id": key_id,
-                "from_server": from_server,
-            },
-            values={
-                "server_name": server_name,
-                "key_id": key_id,
-                "from_server": from_server,
-                "ts_added_ms": ts_now_ms,
-                "ts_valid_until_ms": ts_expires_ms,
-                "key_json": db_binary_type(key_json_bytes),
-            },
-            desc="store_server_keys_json",
-        )
-
-    def get_server_keys_json(self, server_keys):
-        """Retrive the key json for a list of server_keys and key ids.
-        If no keys are found for a given server, key_id and source then
-        that server, key_id, and source triplet entry will be an empty list.
-        The JSON is returned as a byte array so that it can be efficiently
-        used in an HTTP response.
-        Args:
-            server_keys (list): List of (server_name, key_id, source) triplets.
-        Returns:
-            Deferred[dict[Tuple[str, str, str|None], list[dict]]]:
-                Dict mapping (server_name, key_id, source) triplets to lists of dicts
-        """
-
-        def _get_server_keys_json_txn(txn):
-            results = {}
-            for server_name, key_id, from_server in server_keys:
-                keyvalues = {"server_name": server_name}
-                if key_id is not None:
-                    keyvalues["key_id"] = key_id
-                if from_server is not None:
-                    keyvalues["from_server"] = from_server
-                rows = self._simple_select_list_txn(
-                    txn,
-                    "server_keys_json",
-                    keyvalues=keyvalues,
-                    retcols=(
-                        "key_id",
-                        "from_server",
-                        "ts_added_ms",
-                        "ts_valid_until_ms",
-                        "key_json",
-                    ),
-                )
-                results[(server_name, key_id, from_server)] = rows
-            return results
-
-        return self.runInteraction("get_server_keys_json", _get_server_keys_json_txn)
diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py
index 3a641f538b..18a462f0ee 100644
--- a/synapse/storage/presence.py
+++ b/synapse/storage/presence.py
@@ -15,12 +15,7 @@
 
 from collections import namedtuple
 
-from twisted.internet import defer
-
 from synapse.api.constants import PresenceState
-from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
-from synapse.util import batch_iter
-from synapse.util.caches.descriptors import cached, cachedList
 
 
 class UserPresenceState(
@@ -72,132 +67,3 @@ class UserPresenceState(
             status_msg=None,
             currently_active=False,
         )
-
-
-class PresenceStore(SQLBaseStore):
-    @defer.inlineCallbacks
-    def update_presence(self, presence_states):
-        stream_ordering_manager = self._presence_id_gen.get_next_mult(
-            len(presence_states)
-        )
-
-        with stream_ordering_manager as stream_orderings:
-            yield self.runInteraction(
-                "update_presence",
-                self._update_presence_txn,
-                stream_orderings,
-                presence_states,
-            )
-
-        return stream_orderings[-1], self._presence_id_gen.get_current_token()
-
-    def _update_presence_txn(self, txn, stream_orderings, presence_states):
-        for stream_id, state in zip(stream_orderings, presence_states):
-            txn.call_after(
-                self.presence_stream_cache.entity_has_changed, state.user_id, stream_id
-            )
-            txn.call_after(self._get_presence_for_user.invalidate, (state.user_id,))
-
-        # Actually insert new rows
-        self._simple_insert_many_txn(
-            txn,
-            table="presence_stream",
-            values=[
-                {
-                    "stream_id": stream_id,
-                    "user_id": state.user_id,
-                    "state": state.state,
-                    "last_active_ts": state.last_active_ts,
-                    "last_federation_update_ts": state.last_federation_update_ts,
-                    "last_user_sync_ts": state.last_user_sync_ts,
-                    "status_msg": state.status_msg,
-                    "currently_active": state.currently_active,
-                }
-                for state in presence_states
-            ],
-        )
-
-        # Delete old rows to stop database from getting really big
-        sql = "DELETE FROM presence_stream WHERE stream_id < ? AND "
-
-        for states in batch_iter(presence_states, 50):
-            clause, args = make_in_list_sql_clause(
-                self.database_engine, "user_id", [s.user_id for s in states]
-            )
-            txn.execute(sql + clause, [stream_id] + list(args))
-
-    def get_all_presence_updates(self, last_id, current_id):
-        if last_id == current_id:
-            return defer.succeed([])
-
-        def get_all_presence_updates_txn(txn):
-            sql = (
-                "SELECT stream_id, user_id, state, last_active_ts,"
-                " last_federation_update_ts, last_user_sync_ts, status_msg,"
-                " currently_active"
-                " FROM presence_stream"
-                " WHERE ? < stream_id AND stream_id <= ?"
-            )
-            txn.execute(sql, (last_id, current_id))
-            return txn.fetchall()
-
-        return self.runInteraction(
-            "get_all_presence_updates", get_all_presence_updates_txn
-        )
-
-    @cached()
-    def _get_presence_for_user(self, user_id):
-        raise NotImplementedError()
-
-    @cachedList(
-        cached_method_name="_get_presence_for_user",
-        list_name="user_ids",
-        num_args=1,
-        inlineCallbacks=True,
-    )
-    def get_presence_for_users(self, user_ids):
-        rows = yield self._simple_select_many_batch(
-            table="presence_stream",
-            column="user_id",
-            iterable=user_ids,
-            keyvalues={},
-            retcols=(
-                "user_id",
-                "state",
-                "last_active_ts",
-                "last_federation_update_ts",
-                "last_user_sync_ts",
-                "status_msg",
-                "currently_active",
-            ),
-            desc="get_presence_for_users",
-        )
-
-        for row in rows:
-            row["currently_active"] = bool(row["currently_active"])
-
-        return {row["user_id"]: UserPresenceState(**row) for row in rows}
-
-    def get_current_presence_token(self):
-        return self._presence_id_gen.get_current_token()
-
-    def allow_presence_visible(self, observed_localpart, observer_userid):
-        return self._simple_insert(
-            table="presence_allow_inbound",
-            values={
-                "observed_user_id": observed_localpart,
-                "observer_user_id": observer_userid,
-            },
-            desc="allow_presence_visible",
-            or_ignore=True,
-        )
-
-    def disallow_presence_visible(self, observed_localpart, observer_userid):
-        return self._simple_delete_one(
-            table="presence_allow_inbound",
-            keyvalues={
-                "observed_user_id": observed_localpart,
-                "observer_user_id": observer_userid,
-            },
-            desc="disallow_presence_visible",
-        )
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index c4e24edff2..f47cec0d86 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -14,704 +14,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import abc
-import logging
-
-from canonicaljson import json
-
-from twisted.internet import defer
-
-from synapse.push.baserules import list_with_base_rules
-from synapse.storage.appservice import ApplicationServiceWorkerStore
-from synapse.storage.pusher import PusherWorkerStore
-from synapse.storage.receipts import ReceiptsWorkerStore
-from synapse.storage.roommember import RoomMemberWorkerStore
-from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
-from synapse.util.caches.stream_change_cache import StreamChangeCache
-
-from ._base import SQLBaseStore
-
-logger = logging.getLogger(__name__)
-
-
-def _load_rules(rawrules, enabled_map):
-    ruleslist = []
-    for rawrule in rawrules:
-        rule = dict(rawrule)
-        rule["conditions"] = json.loads(rawrule["conditions"])
-        rule["actions"] = json.loads(rawrule["actions"])
-        ruleslist.append(rule)
-
-    # We're going to be mutating this a lot, so do a deep copy
-    rules = list(list_with_base_rules(ruleslist))
-
-    for i, rule in enumerate(rules):
-        rule_id = rule["rule_id"]
-        if rule_id in enabled_map:
-            if rule.get("enabled", True) != bool(enabled_map[rule_id]):
-                # Rules are cached across users.
-                rule = dict(rule)
-                rule["enabled"] = bool(enabled_map[rule_id])
-                rules[i] = rule
-
-    return rules
-
-
-class PushRulesWorkerStore(
-    ApplicationServiceWorkerStore,
-    ReceiptsWorkerStore,
-    PusherWorkerStore,
-    RoomMemberWorkerStore,
-    SQLBaseStore,
-):
-    """This is an abstract base class where subclasses must implement
-    `get_max_push_rules_stream_id` which can be called in the initializer.
-    """
-
-    # This ABCMeta metaclass ensures that we cannot be instantiated without
-    # the abstract methods being implemented.
-    __metaclass__ = abc.ABCMeta
-
-    def __init__(self, db_conn, hs):
-        super(PushRulesWorkerStore, self).__init__(db_conn, hs)
-
-        push_rules_prefill, push_rules_id = self._get_cache_dict(
-            db_conn,
-            "push_rules_stream",
-            entity_column="user_id",
-            stream_column="stream_id",
-            max_value=self.get_max_push_rules_stream_id(),
-        )
-
-        self.push_rules_stream_cache = StreamChangeCache(
-            "PushRulesStreamChangeCache",
-            push_rules_id,
-            prefilled_cache=push_rules_prefill,
-        )
-
-    @abc.abstractmethod
-    def get_max_push_rules_stream_id(self):
-        """Get the position of the push rules stream.
-
-        Returns:
-            int
-        """
-        raise NotImplementedError()
-
-    @cachedInlineCallbacks(max_entries=5000)
-    def get_push_rules_for_user(self, user_id):
-        rows = yield self._simple_select_list(
-            table="push_rules",
-            keyvalues={"user_name": user_id},
-            retcols=(
-                "user_name",
-                "rule_id",
-                "priority_class",
-                "priority",
-                "conditions",
-                "actions",
-            ),
-            desc="get_push_rules_enabled_for_user",
-        )
-
-        rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
-
-        enabled_map = yield self.get_push_rules_enabled_for_user(user_id)
-
-        rules = _load_rules(rows, enabled_map)
-
-        return rules
-
-    @cachedInlineCallbacks(max_entries=5000)
-    def get_push_rules_enabled_for_user(self, user_id):
-        results = yield self._simple_select_list(
-            table="push_rules_enable",
-            keyvalues={"user_name": user_id},
-            retcols=("user_name", "rule_id", "enabled"),
-            desc="get_push_rules_enabled_for_user",
-        )
-        return {r["rule_id"]: False if r["enabled"] == 0 else True for r in results}
-
-    def have_push_rules_changed_for_user(self, user_id, last_id):
-        if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
-            return defer.succeed(False)
-        else:
-
-            def have_push_rules_changed_txn(txn):
-                sql = (
-                    "SELECT COUNT(stream_id) FROM push_rules_stream"
-                    " WHERE user_id = ? AND ? < stream_id"
-                )
-                txn.execute(sql, (user_id, last_id))
-                count, = txn.fetchone()
-                return bool(count)
-
-            return self.runInteraction(
-                "have_push_rules_changed", have_push_rules_changed_txn
-            )
-
-    @cachedList(
-        cached_method_name="get_push_rules_for_user",
-        list_name="user_ids",
-        num_args=1,
-        inlineCallbacks=True,
-    )
-    def bulk_get_push_rules(self, user_ids):
-        if not user_ids:
-            return {}
-
-        results = {user_id: [] for user_id in user_ids}
-
-        rows = yield self._simple_select_many_batch(
-            table="push_rules",
-            column="user_name",
-            iterable=user_ids,
-            retcols=("*",),
-            desc="bulk_get_push_rules",
-        )
-
-        rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
-
-        for row in rows:
-            results.setdefault(row["user_name"], []).append(row)
-
-        enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids)
-
-        for user_id, rules in results.items():
-            results[user_id] = _load_rules(rules, enabled_map_by_user.get(user_id, {}))
-
-        return results
-
-    @defer.inlineCallbacks
-    def copy_push_rule_from_room_to_room(self, new_room_id, user_id, rule):
-        """Copy a single push rule from one room to another for a specific user.
-
-        Args:
-            new_room_id (str): ID of the new room.
-            user_id (str): ID of user the push rule belongs to.
-            rule (Dict): A push rule.
-        """
-        # Create new rule id
-        rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1])
-        new_rule_id = rule_id_scope + "/" + new_room_id
-
-        # Change room id in each condition
-        for condition in rule.get("conditions", []):
-            if condition.get("key") == "room_id":
-                condition["pattern"] = new_room_id
-
-        # Add the rule for the new room
-        yield self.add_push_rule(
-            user_id=user_id,
-            rule_id=new_rule_id,
-            priority_class=rule["priority_class"],
-            conditions=rule["conditions"],
-            actions=rule["actions"],
-        )
-
-    @defer.inlineCallbacks
-    def copy_push_rules_from_room_to_room_for_user(
-        self, old_room_id, new_room_id, user_id
-    ):
-        """Copy all of the push rules from one room to another for a specific
-        user.
-
-        Args:
-            old_room_id (str): ID of the old room.
-            new_room_id (str): ID of the new room.
-            user_id (str): ID of user to copy push rules for.
-        """
-        # Retrieve push rules for this user
-        user_push_rules = yield self.get_push_rules_for_user(user_id)
-
-        # Get rules relating to the old room and copy them to the new room
-        for rule in user_push_rules:
-            conditions = rule.get("conditions", [])
-            if any(
-                (c.get("key") == "room_id" and c.get("pattern") == old_room_id)
-                for c in conditions
-            ):
-                yield self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule)
-
-    @defer.inlineCallbacks
-    def bulk_get_push_rules_for_room(self, event, context):
-        state_group = context.state_group
-        if not state_group:
-            # If state_group is None it means it has yet to be assigned a
-            # state group, i.e. we need to make sure that calls with a state_group
-            # of None don't hit previous cached calls with a None state_group.
-            # To do this we set the state_group to a new object as object() != object()
-            state_group = object()
-
-        current_state_ids = yield context.get_current_state_ids(self)
-        result = yield self._bulk_get_push_rules_for_room(
-            event.room_id, state_group, current_state_ids, event=event
-        )
-        return result
-
-    @cachedInlineCallbacks(num_args=2, cache_context=True)
-    def _bulk_get_push_rules_for_room(
-        self, room_id, state_group, current_state_ids, cache_context, event=None
-    ):
-        # We don't use `state_group`, its there so that we can cache based
-        # on it. However, its important that its never None, since two current_state's
-        # with a state_group of None are likely to be different.
-        # See bulk_get_push_rules_for_room for how we work around this.
-        assert state_group is not None
-
-        # We also will want to generate notifs for other people in the room so
-        # their unread countss are correct in the event stream, but to avoid
-        # generating them for bot / AS users etc, we only do so for people who've
-        # sent a read receipt into the room.
-
-        users_in_room = yield self._get_joined_users_from_context(
-            room_id,
-            state_group,
-            current_state_ids,
-            on_invalidate=cache_context.invalidate,
-            event=event,
-        )
-
-        # We ignore app service users for now. This is so that we don't fill
-        # up the `get_if_users_have_pushers` cache with AS entries that we
-        # know don't have pushers, nor even read receipts.
-        local_users_in_room = set(
-            u
-            for u in users_in_room
-            if self.hs.is_mine_id(u)
-            and not self.get_if_app_services_interested_in_user(u)
-        )
-
-        # users in the room who have pushers need to get push rules run because
-        # that's how their pushers work
-        if_users_with_pushers = yield self.get_if_users_have_pushers(
-            local_users_in_room, on_invalidate=cache_context.invalidate
-        )
-        user_ids = set(
-            uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
-        )
-
-        users_with_receipts = yield self.get_users_with_read_receipts_in_room(
-            room_id, on_invalidate=cache_context.invalidate
-        )
-
-        # any users with pushers must be ours: they have pushers
-        for uid in users_with_receipts:
-            if uid in local_users_in_room:
-                user_ids.add(uid)
-
-        rules_by_user = yield self.bulk_get_push_rules(
-            user_ids, on_invalidate=cache_context.invalidate
-        )
-
-        rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None}
-
-        return rules_by_user
-
-    @cachedList(
-        cached_method_name="get_push_rules_enabled_for_user",
-        list_name="user_ids",
-        num_args=1,
-        inlineCallbacks=True,
-    )
-    def bulk_get_push_rules_enabled(self, user_ids):
-        if not user_ids:
-            return {}
-
-        results = {user_id: {} for user_id in user_ids}
-
-        rows = yield self._simple_select_many_batch(
-            table="push_rules_enable",
-            column="user_name",
-            iterable=user_ids,
-            retcols=("user_name", "rule_id", "enabled"),
-            desc="bulk_get_push_rules_enabled",
-        )
-        for row in rows:
-            enabled = bool(row["enabled"])
-            results.setdefault(row["user_name"], {})[row["rule_id"]] = enabled
-        return results
-
-
-class PushRuleStore(PushRulesWorkerStore):
-    @defer.inlineCallbacks
-    def add_push_rule(
-        self,
-        user_id,
-        rule_id,
-        priority_class,
-        conditions,
-        actions,
-        before=None,
-        after=None,
-    ):
-        conditions_json = json.dumps(conditions)
-        actions_json = json.dumps(actions)
-        with self._push_rules_stream_id_gen.get_next() as ids:
-            stream_id, event_stream_ordering = ids
-            if before or after:
-                yield self.runInteraction(
-                    "_add_push_rule_relative_txn",
-                    self._add_push_rule_relative_txn,
-                    stream_id,
-                    event_stream_ordering,
-                    user_id,
-                    rule_id,
-                    priority_class,
-                    conditions_json,
-                    actions_json,
-                    before,
-                    after,
-                )
-            else:
-                yield self.runInteraction(
-                    "_add_push_rule_highest_priority_txn",
-                    self._add_push_rule_highest_priority_txn,
-                    stream_id,
-                    event_stream_ordering,
-                    user_id,
-                    rule_id,
-                    priority_class,
-                    conditions_json,
-                    actions_json,
-                )
-
-    def _add_push_rule_relative_txn(
-        self,
-        txn,
-        stream_id,
-        event_stream_ordering,
-        user_id,
-        rule_id,
-        priority_class,
-        conditions_json,
-        actions_json,
-        before,
-        after,
-    ):
-        # Lock the table since otherwise we'll have annoying races between the
-        # SELECT here and the UPSERT below.
-        self.database_engine.lock_table(txn, "push_rules")
-
-        relative_to_rule = before or after
-
-        res = self._simple_select_one_txn(
-            txn,
-            table="push_rules",
-            keyvalues={"user_name": user_id, "rule_id": relative_to_rule},
-            retcols=["priority_class", "priority"],
-            allow_none=True,
-        )
-
-        if not res:
-            raise RuleNotFoundException(
-                "before/after rule not found: %s" % (relative_to_rule,)
-            )
-
-        base_priority_class = res["priority_class"]
-        base_rule_priority = res["priority"]
-
-        if base_priority_class != priority_class:
-            raise InconsistentRuleException(
-                "Given priority class does not match class of relative rule"
-            )
-
-        if before:
-            # Higher priority rules are executed first, So adding a rule before
-            # a rule means giving it a higher priority than that rule.
-            new_rule_priority = base_rule_priority + 1
-        else:
-            # We increment the priority of the existing rules to make space for
-            # the new rule. Therefore if we want this rule to appear after
-            # an existing rule we give it the priority of the existing rule,
-            # and then increment the priority of the existing rule.
-            new_rule_priority = base_rule_priority
-
-        sql = (
-            "UPDATE push_rules SET priority = priority + 1"
-            " WHERE user_name = ? AND priority_class = ? AND priority >= ?"
-        )
-
-        txn.execute(sql, (user_id, priority_class, new_rule_priority))
-
-        self._upsert_push_rule_txn(
-            txn,
-            stream_id,
-            event_stream_ordering,
-            user_id,
-            rule_id,
-            priority_class,
-            new_rule_priority,
-            conditions_json,
-            actions_json,
-        )
-
-    def _add_push_rule_highest_priority_txn(
-        self,
-        txn,
-        stream_id,
-        event_stream_ordering,
-        user_id,
-        rule_id,
-        priority_class,
-        conditions_json,
-        actions_json,
-    ):
-        # Lock the table since otherwise we'll have annoying races between the
-        # SELECT here and the UPSERT below.
-        self.database_engine.lock_table(txn, "push_rules")
-
-        # find the highest priority rule in that class
-        sql = (
-            "SELECT COUNT(*), MAX(priority) FROM push_rules"
-            " WHERE user_name = ? and priority_class = ?"
-        )
-        txn.execute(sql, (user_id, priority_class))
-        res = txn.fetchall()
-        (how_many, highest_prio) = res[0]
-
-        new_prio = 0
-        if how_many > 0:
-            new_prio = highest_prio + 1
-
-        self._upsert_push_rule_txn(
-            txn,
-            stream_id,
-            event_stream_ordering,
-            user_id,
-            rule_id,
-            priority_class,
-            new_prio,
-            conditions_json,
-            actions_json,
-        )
-
-    def _upsert_push_rule_txn(
-        self,
-        txn,
-        stream_id,
-        event_stream_ordering,
-        user_id,
-        rule_id,
-        priority_class,
-        priority,
-        conditions_json,
-        actions_json,
-        update_stream=True,
-    ):
-        """Specialised version of _simple_upsert_txn that picks a push_rule_id
-        using the _push_rule_id_gen if it needs to insert the rule. It assumes
-        that the "push_rules" table is locked"""
-
-        sql = (
-            "UPDATE push_rules"
-            " SET priority_class = ?, priority = ?, conditions = ?, actions = ?"
-            " WHERE user_name = ? AND rule_id = ?"
-        )
-
-        txn.execute(
-            sql,
-            (priority_class, priority, conditions_json, actions_json, user_id, rule_id),
-        )
-
-        if txn.rowcount == 0:
-            # We didn't update a row with the given rule_id so insert one
-            push_rule_id = self._push_rule_id_gen.get_next()
-
-            self._simple_insert_txn(
-                txn,
-                table="push_rules",
-                values={
-                    "id": push_rule_id,
-                    "user_name": user_id,
-                    "rule_id": rule_id,
-                    "priority_class": priority_class,
-                    "priority": priority,
-                    "conditions": conditions_json,
-                    "actions": actions_json,
-                },
-            )
-
-        if update_stream:
-            self._insert_push_rules_update_txn(
-                txn,
-                stream_id,
-                event_stream_ordering,
-                user_id,
-                rule_id,
-                op="ADD",
-                data={
-                    "priority_class": priority_class,
-                    "priority": priority,
-                    "conditions": conditions_json,
-                    "actions": actions_json,
-                },
-            )
-
-    @defer.inlineCallbacks
-    def delete_push_rule(self, user_id, rule_id):
-        """
-        Delete a push rule. Args specify the row to be deleted and can be
-        any of the columns in the push_rule table, but below are the
-        standard ones
-
-        Args:
-            user_id (str): The matrix ID of the push rule owner
-            rule_id (str): The rule_id of the rule to be deleted
-        """
-
-        def delete_push_rule_txn(txn, stream_id, event_stream_ordering):
-            self._simple_delete_one_txn(
-                txn, "push_rules", {"user_name": user_id, "rule_id": rule_id}
-            )
-
-            self._insert_push_rules_update_txn(
-                txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE"
-            )
-
-        with self._push_rules_stream_id_gen.get_next() as ids:
-            stream_id, event_stream_ordering = ids
-            yield self.runInteraction(
-                "delete_push_rule",
-                delete_push_rule_txn,
-                stream_id,
-                event_stream_ordering,
-            )
-
-    @defer.inlineCallbacks
-    def set_push_rule_enabled(self, user_id, rule_id, enabled):
-        with self._push_rules_stream_id_gen.get_next() as ids:
-            stream_id, event_stream_ordering = ids
-            yield self.runInteraction(
-                "_set_push_rule_enabled_txn",
-                self._set_push_rule_enabled_txn,
-                stream_id,
-                event_stream_ordering,
-                user_id,
-                rule_id,
-                enabled,
-            )
-
-    def _set_push_rule_enabled_txn(
-        self, txn, stream_id, event_stream_ordering, user_id, rule_id, enabled
-    ):
-        new_id = self._push_rules_enable_id_gen.get_next()
-        self._simple_upsert_txn(
-            txn,
-            "push_rules_enable",
-            {"user_name": user_id, "rule_id": rule_id},
-            {"enabled": 1 if enabled else 0},
-            {"id": new_id},
-        )
-
-        self._insert_push_rules_update_txn(
-            txn,
-            stream_id,
-            event_stream_ordering,
-            user_id,
-            rule_id,
-            op="ENABLE" if enabled else "DISABLE",
-        )
-
-    @defer.inlineCallbacks
-    def set_push_rule_actions(self, user_id, rule_id, actions, is_default_rule):
-        actions_json = json.dumps(actions)
-
-        def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering):
-            if is_default_rule:
-                # Add a dummy rule to the rules table with the user specified
-                # actions.
-                priority_class = -1
-                priority = 1
-                self._upsert_push_rule_txn(
-                    txn,
-                    stream_id,
-                    event_stream_ordering,
-                    user_id,
-                    rule_id,
-                    priority_class,
-                    priority,
-                    "[]",
-                    actions_json,
-                    update_stream=False,
-                )
-            else:
-                self._simple_update_one_txn(
-                    txn,
-                    "push_rules",
-                    {"user_name": user_id, "rule_id": rule_id},
-                    {"actions": actions_json},
-                )
-
-            self._insert_push_rules_update_txn(
-                txn,
-                stream_id,
-                event_stream_ordering,
-                user_id,
-                rule_id,
-                op="ACTIONS",
-                data={"actions": actions_json},
-            )
-
-        with self._push_rules_stream_id_gen.get_next() as ids:
-            stream_id, event_stream_ordering = ids
-            yield self.runInteraction(
-                "set_push_rule_actions",
-                set_push_rule_actions_txn,
-                stream_id,
-                event_stream_ordering,
-            )
-
-    def _insert_push_rules_update_txn(
-        self, txn, stream_id, event_stream_ordering, user_id, rule_id, op, data=None
-    ):
-        values = {
-            "stream_id": stream_id,
-            "event_stream_ordering": event_stream_ordering,
-            "user_id": user_id,
-            "rule_id": rule_id,
-            "op": op,
-        }
-        if data is not None:
-            values.update(data)
-
-        self._simple_insert_txn(txn, "push_rules_stream", values=values)
-
-        txn.call_after(self.get_push_rules_for_user.invalidate, (user_id,))
-        txn.call_after(self.get_push_rules_enabled_for_user.invalidate, (user_id,))
-        txn.call_after(
-            self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
-        )
-
-    def get_all_push_rule_updates(self, last_id, current_id, limit):
-        """Get all the push rules changes that have happend on the server"""
-        if last_id == current_id:
-            return defer.succeed([])
-
-        def get_all_push_rule_updates_txn(txn):
-            sql = (
-                "SELECT stream_id, event_stream_ordering, user_id, rule_id,"
-                " op, priority_class, priority, conditions, actions"
-                " FROM push_rules_stream"
-                " WHERE ? < stream_id AND stream_id <= ?"
-                " ORDER BY stream_id ASC LIMIT ?"
-            )
-            txn.execute(sql, (last_id, current_id, limit))
-            return txn.fetchall()
-
-        return self.runInteraction(
-            "get_all_push_rule_updates", get_all_push_rule_updates_txn
-        )
-
-    def get_push_rules_stream_token(self):
-        """Get the position of the push rules stream.
-        Returns a pair of a stream id for the push_rules stream and the
-        room stream ordering it corresponds to."""
-        return self._push_rules_stream_id_gen.get_current_token()
-
-    def get_max_push_rules_stream_id(self):
-        return self.get_push_rules_stream_token()[0]
-
 
 class RuleNotFoundException(Exception):
     pass
diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py
index fcb5f2f23a..d471ec9860 100644
--- a/synapse/storage/relations.py
+++ b/synapse/storage/relations.py
@@ -17,11 +17,7 @@ import logging
 
 import attr
 
-from synapse.api.constants import RelationTypes
 from synapse.api.errors import SynapseError
-from synapse.storage._base import SQLBaseStore
-from synapse.storage.stream import generate_pagination_where_clause
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
 
 logger = logging.getLogger(__name__)
 
@@ -113,358 +109,3 @@ class AggregationPaginationToken(object):
 
     def as_tuple(self):
         return attr.astuple(self)
-
-
-class RelationsWorkerStore(SQLBaseStore):
-    @cached(tree=True)
-    def get_relations_for_event(
-        self,
-        event_id,
-        relation_type=None,
-        event_type=None,
-        aggregation_key=None,
-        limit=5,
-        direction="b",
-        from_token=None,
-        to_token=None,
-    ):
-        """Get a list of relations for an event, ordered by topological ordering.
-
-        Args:
-            event_id (str): Fetch events that relate to this event ID.
-            relation_type (str|None): Only fetch events with this relation
-                type, if given.
-            event_type (str|None): Only fetch events with this event type, if
-                given.
-            aggregation_key (str|None): Only fetch events with this aggregation
-                key, if given.
-            limit (int): Only fetch the most recent `limit` events.
-            direction (str): Whether to fetch the most recent first (`"b"`) or
-                the oldest first (`"f"`).
-            from_token (RelationPaginationToken|None): Fetch rows from the given
-                token, or from the start if None.
-            to_token (RelationPaginationToken|None): Fetch rows up to the given
-                token, or up to the end if None.
-
-        Returns:
-            Deferred[PaginationChunk]: List of event IDs that match relations
-            requested. The rows are of the form `{"event_id": "..."}`.
-        """
-
-        where_clause = ["relates_to_id = ?"]
-        where_args = [event_id]
-
-        if relation_type is not None:
-            where_clause.append("relation_type = ?")
-            where_args.append(relation_type)
-
-        if event_type is not None:
-            where_clause.append("type = ?")
-            where_args.append(event_type)
-
-        if aggregation_key:
-            where_clause.append("aggregation_key = ?")
-            where_args.append(aggregation_key)
-
-        pagination_clause = generate_pagination_where_clause(
-            direction=direction,
-            column_names=("topological_ordering", "stream_ordering"),
-            from_token=attr.astuple(from_token) if from_token else None,
-            to_token=attr.astuple(to_token) if to_token else None,
-            engine=self.database_engine,
-        )
-
-        if pagination_clause:
-            where_clause.append(pagination_clause)
-
-        if direction == "b":
-            order = "DESC"
-        else:
-            order = "ASC"
-
-        sql = """
-            SELECT event_id, topological_ordering, stream_ordering
-            FROM event_relations
-            INNER JOIN events USING (event_id)
-            WHERE %s
-            ORDER BY topological_ordering %s, stream_ordering %s
-            LIMIT ?
-        """ % (
-            " AND ".join(where_clause),
-            order,
-            order,
-        )
-
-        def _get_recent_references_for_event_txn(txn):
-            txn.execute(sql, where_args + [limit + 1])
-
-            last_topo_id = None
-            last_stream_id = None
-            events = []
-            for row in txn:
-                events.append({"event_id": row[0]})
-                last_topo_id = row[1]
-                last_stream_id = row[2]
-
-            next_batch = None
-            if len(events) > limit and last_topo_id and last_stream_id:
-                next_batch = RelationPaginationToken(last_topo_id, last_stream_id)
-
-            return PaginationChunk(
-                chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
-            )
-
-        return self.runInteraction(
-            "get_recent_references_for_event", _get_recent_references_for_event_txn
-        )
-
-    @cached(tree=True)
-    def get_aggregation_groups_for_event(
-        self,
-        event_id,
-        event_type=None,
-        limit=5,
-        direction="b",
-        from_token=None,
-        to_token=None,
-    ):
-        """Get a list of annotations on the event, grouped by event type and
-        aggregation key, sorted by count.
-
-        This is used e.g. to get the what and how many reactions have happend
-        on an event.
-
-        Args:
-            event_id (str): Fetch events that relate to this event ID.
-            event_type (str|None): Only fetch events with this event type, if
-                given.
-            limit (int): Only fetch the `limit` groups.
-            direction (str): Whether to fetch the highest count first (`"b"`) or
-                the lowest count first (`"f"`).
-            from_token (AggregationPaginationToken|None): Fetch rows from the
-                given token, or from the start if None.
-            to_token (AggregationPaginationToken|None): Fetch rows up to the
-                given token, or up to the end if None.
-
-
-        Returns:
-            Deferred[PaginationChunk]: List of groups of annotations that
-            match. Each row is a dict with `type`, `key` and `count` fields.
-        """
-
-        where_clause = ["relates_to_id = ?", "relation_type = ?"]
-        where_args = [event_id, RelationTypes.ANNOTATION]
-
-        if event_type:
-            where_clause.append("type = ?")
-            where_args.append(event_type)
-
-        having_clause = generate_pagination_where_clause(
-            direction=direction,
-            column_names=("COUNT(*)", "MAX(stream_ordering)"),
-            from_token=attr.astuple(from_token) if from_token else None,
-            to_token=attr.astuple(to_token) if to_token else None,
-            engine=self.database_engine,
-        )
-
-        if direction == "b":
-            order = "DESC"
-        else:
-            order = "ASC"
-
-        if having_clause:
-            having_clause = "HAVING " + having_clause
-        else:
-            having_clause = ""
-
-        sql = """
-            SELECT type, aggregation_key, COUNT(DISTINCT sender), MAX(stream_ordering)
-            FROM event_relations
-            INNER JOIN events USING (event_id)
-            WHERE {where_clause}
-            GROUP BY relation_type, type, aggregation_key
-            {having_clause}
-            ORDER BY COUNT(*) {order}, MAX(stream_ordering) {order}
-            LIMIT ?
-        """.format(
-            where_clause=" AND ".join(where_clause),
-            order=order,
-            having_clause=having_clause,
-        )
-
-        def _get_aggregation_groups_for_event_txn(txn):
-            txn.execute(sql, where_args + [limit + 1])
-
-            next_batch = None
-            events = []
-            for row in txn:
-                events.append({"type": row[0], "key": row[1], "count": row[2]})
-                next_batch = AggregationPaginationToken(row[2], row[3])
-
-            if len(events) <= limit:
-                next_batch = None
-
-            return PaginationChunk(
-                chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
-            )
-
-        return self.runInteraction(
-            "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
-        )
-
-    @cachedInlineCallbacks()
-    def get_applicable_edit(self, event_id):
-        """Get the most recent edit (if any) that has happened for the given
-        event.
-
-        Correctly handles checking whether edits were allowed to happen.
-
-        Args:
-            event_id (str): The original event ID
-
-        Returns:
-            Deferred[EventBase|None]: Returns the most recent edit, if any.
-        """
-
-        # We only allow edits for `m.room.message` events that have the same sender
-        # and event type. We can't assert these things during regular event auth so
-        # we have to do the checks post hoc.
-
-        # Fetches latest edit that has the same type and sender as the
-        # original, and is an `m.room.message`.
-        sql = """
-            SELECT edit.event_id FROM events AS edit
-            INNER JOIN event_relations USING (event_id)
-            INNER JOIN events AS original ON
-                original.event_id = relates_to_id
-                AND edit.type = original.type
-                AND edit.sender = original.sender
-            WHERE
-                relates_to_id = ?
-                AND relation_type = ?
-                AND edit.type = 'm.room.message'
-            ORDER by edit.origin_server_ts DESC, edit.event_id DESC
-            LIMIT 1
-        """
-
-        def _get_applicable_edit_txn(txn):
-            txn.execute(sql, (event_id, RelationTypes.REPLACE))
-            row = txn.fetchone()
-            if row:
-                return row[0]
-
-        edit_id = yield self.runInteraction(
-            "get_applicable_edit", _get_applicable_edit_txn
-        )
-
-        if not edit_id:
-            return
-
-        edit_event = yield self.get_event(edit_id, allow_none=True)
-        return edit_event
-
-    def has_user_annotated_event(self, parent_id, event_type, aggregation_key, sender):
-        """Check if a user has already annotated an event with the same key
-        (e.g. already liked an event).
-
-        Args:
-            parent_id (str): The event being annotated
-            event_type (str): The event type of the annotation
-            aggregation_key (str): The aggregation key of the annotation
-            sender (str): The sender of the annotation
-
-        Returns:
-            Deferred[bool]
-        """
-
-        sql = """
-            SELECT 1 FROM event_relations
-            INNER JOIN events USING (event_id)
-            WHERE
-                relates_to_id = ?
-                AND relation_type = ?
-                AND type = ?
-                AND sender = ?
-                AND aggregation_key = ?
-            LIMIT 1;
-        """
-
-        def _get_if_user_has_annotated_event(txn):
-            txn.execute(
-                sql,
-                (
-                    parent_id,
-                    RelationTypes.ANNOTATION,
-                    event_type,
-                    sender,
-                    aggregation_key,
-                ),
-            )
-
-            return bool(txn.fetchone())
-
-        return self.runInteraction(
-            "get_if_user_has_annotated_event", _get_if_user_has_annotated_event
-        )
-
-
-class RelationsStore(RelationsWorkerStore):
-    def _handle_event_relations(self, txn, event):
-        """Handles inserting relation data during peristence of events
-
-        Args:
-            txn
-            event (EventBase)
-        """
-        relation = event.content.get("m.relates_to")
-        if not relation:
-            # No relations
-            return
-
-        rel_type = relation.get("rel_type")
-        if rel_type not in (
-            RelationTypes.ANNOTATION,
-            RelationTypes.REFERENCE,
-            RelationTypes.REPLACE,
-        ):
-            # Unknown relation type
-            return
-
-        parent_id = relation.get("event_id")
-        if not parent_id:
-            # Invalid relation
-            return
-
-        aggregation_key = relation.get("key")
-
-        self._simple_insert_txn(
-            txn,
-            table="event_relations",
-            values={
-                "event_id": event.event_id,
-                "relates_to_id": parent_id,
-                "relation_type": rel_type,
-                "aggregation_key": aggregation_key,
-            },
-        )
-
-        txn.call_after(self.get_relations_for_event.invalidate_many, (parent_id,))
-        txn.call_after(
-            self.get_aggregation_groups_for_event.invalidate_many, (parent_id,)
-        )
-
-        if rel_type == RelationTypes.REPLACE:
-            txn.call_after(self.get_applicable_edit.invalidate, (parent_id,))
-
-    def _handle_redaction(self, txn, redacted_event_id):
-        """Handles receiving a redaction and checking whether we need to remove
-        any redacted relations from the database.
-
-        Args:
-            txn
-            redacted_event_id (str): The event that was redacted.
-        """
-
-        self._simple_delete_txn(
-            txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
-        )
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index ff63487823..8c4a83a840 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -17,26 +17,6 @@
 import logging
 from collections import namedtuple
 
-from six import iteritems, itervalues
-
-from canonicaljson import json
-
-from twisted.internet import defer
-
-from synapse.api.constants import EventTypes, Membership
-from synapse.metrics import LaterGauge
-from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage._base import LoggingTransaction, make_in_list_sql_clause
-from synapse.storage.background_updates import BackgroundUpdateStore
-from synapse.storage.engines import Sqlite3Engine
-from synapse.storage.events_worker import EventsWorkerStore
-from synapse.types import get_domain_from_id
-from synapse.util.async_helpers import Linearizer
-from synapse.util.caches import intern_string
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
-from synapse.util.metrics import Measure
-from synapse.util.stringutils import to_ascii
-
 logger = logging.getLogger(__name__)
 
 
@@ -57,1102 +37,3 @@ ProfileInfo = namedtuple("ProfileInfo", ("avatar_url", "display_name"))
 # a given membership type, suitable for use in calculating heroes for a room.
 # "count" points to the total numberr of users of a given membership type.
 MemberSummary = namedtuple("MemberSummary", ("members", "count"))
-
-_MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update"
-_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership"
-
-
-class RoomMemberWorkerStore(EventsWorkerStore):
-    def __init__(self, db_conn, hs):
-        super(RoomMemberWorkerStore, self).__init__(db_conn, hs)
-
-        # Is the current_state_events.membership up to date? Or is the
-        # background update still running?
-        self._current_state_events_membership_up_to_date = False
-
-        txn = LoggingTransaction(
-            db_conn.cursor(),
-            name="_check_safe_current_state_events_membership_updated",
-            database_engine=self.database_engine,
-        )
-        self._check_safe_current_state_events_membership_updated_txn(txn)
-        txn.close()
-
-        if self.hs.config.metrics_flags.known_servers:
-            self._known_servers_count = 1
-            self.hs.get_clock().looping_call(
-                run_as_background_process,
-                60 * 1000,
-                "_count_known_servers",
-                self._count_known_servers,
-            )
-            self.hs.get_clock().call_later(
-                1000,
-                run_as_background_process,
-                "_count_known_servers",
-                self._count_known_servers,
-            )
-            LaterGauge(
-                "synapse_federation_known_servers",
-                "",
-                [],
-                lambda: self._known_servers_count,
-            )
-
-    @defer.inlineCallbacks
-    def _count_known_servers(self):
-        """
-        Count the servers that this server knows about.
-
-        The statistic is stored on the class for the
-        `synapse_federation_known_servers` LaterGauge to collect.
-        """
-
-        def _transact(txn):
-            if isinstance(self.database_engine, Sqlite3Engine):
-                query = """
-                    SELECT COUNT(DISTINCT substr(out.user_id, pos+1))
-                    FROM (
-                        SELECT rm.user_id as user_id, instr(rm.user_id, ':')
-                            AS pos FROM room_memberships as rm
-                        INNER JOIN current_state_events as c ON rm.event_id = c.event_id
-                        WHERE c.type = 'm.room.member'
-                    ) as out
-                """
-            else:
-                query = """
-                    SELECT COUNT(DISTINCT split_part(state_key, ':', 2))
-                    FROM current_state_events
-                    WHERE type = 'm.room.member' AND membership = 'join';
-                """
-            txn.execute(query)
-            return list(txn)[0][0]
-
-        count = yield self.runInteraction("get_known_servers", _transact)
-
-        # We always know about ourselves, even if we have nothing in
-        # room_memberships (for example, the server is new).
-        self._known_servers_count = max([count, 1])
-        return self._known_servers_count
-
-    def _check_safe_current_state_events_membership_updated_txn(self, txn):
-        """Checks if it is safe to assume the new current_state_events
-        membership column is up to date
-        """
-
-        pending_update = self._simple_select_one_txn(
-            txn,
-            table="background_updates",
-            keyvalues={"update_name": _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME},
-            retcols=["update_name"],
-            allow_none=True,
-        )
-
-        self._current_state_events_membership_up_to_date = not pending_update
-
-        # If the update is still running, reschedule to run.
-        if pending_update:
-            self._clock.call_later(
-                15.0,
-                run_as_background_process,
-                "_check_safe_current_state_events_membership_updated",
-                self.runInteraction,
-                "_check_safe_current_state_events_membership_updated",
-                self._check_safe_current_state_events_membership_updated_txn,
-            )
-
-    @cachedInlineCallbacks(max_entries=100000, iterable=True, cache_context=True)
-    def get_hosts_in_room(self, room_id, cache_context):
-        """Returns the set of all hosts currently in the room
-        """
-        user_ids = yield self.get_users_in_room(
-            room_id, on_invalidate=cache_context.invalidate
-        )
-        hosts = frozenset(get_domain_from_id(user_id) for user_id in user_ids)
-        return hosts
-
-    @cached(max_entries=100000, iterable=True)
-    def get_users_in_room(self, room_id):
-        return self.runInteraction(
-            "get_users_in_room", self.get_users_in_room_txn, room_id
-        )
-
-    def get_users_in_room_txn(self, txn, room_id):
-        # If we can assume current_state_events.membership is up to date
-        # then we can avoid a join, which is a Very Good Thing given how
-        # frequently this function gets called.
-        if self._current_state_events_membership_up_to_date:
-            sql = """
-                SELECT state_key FROM current_state_events
-                WHERE type = 'm.room.member' AND room_id = ? AND membership = ?
-            """
-        else:
-            sql = """
-                SELECT state_key FROM room_memberships as m
-                INNER JOIN current_state_events as c
-                ON m.event_id = c.event_id
-                AND m.room_id = c.room_id
-                AND m.user_id = c.state_key
-                WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ?
-            """
-
-        txn.execute(sql, (room_id, Membership.JOIN))
-        return [to_ascii(r[0]) for r in txn]
-
-    @cached(max_entries=100000)
-    def get_room_summary(self, room_id):
-        """ Get the details of a room roughly suitable for use by the room
-        summary extension to /sync. Useful when lazy loading room members.
-        Args:
-            room_id (str): The room ID to query
-        Returns:
-            Deferred[dict[str, MemberSummary]:
-                dict of membership states, pointing to a MemberSummary named tuple.
-        """
-
-        def _get_room_summary_txn(txn):
-            # first get counts.
-            # We do this all in one transaction to keep the cache small.
-            # FIXME: get rid of this when we have room_stats
-
-            # If we can assume current_state_events.membership is up to date
-            # then we can avoid a join, which is a Very Good Thing given how
-            # frequently this function gets called.
-            if self._current_state_events_membership_up_to_date:
-                # Note, rejected events will have a null membership field, so
-                # we we manually filter them out.
-                sql = """
-                    SELECT count(*), membership FROM current_state_events
-                    WHERE type = 'm.room.member' AND room_id = ?
-                        AND membership IS NOT NULL
-                    GROUP BY membership
-                """
-            else:
-                sql = """
-                    SELECT count(*), m.membership FROM room_memberships as m
-                    INNER JOIN current_state_events as c
-                    ON m.event_id = c.event_id
-                    AND m.room_id = c.room_id
-                    AND m.user_id = c.state_key
-                    WHERE c.type = 'm.room.member' AND c.room_id = ?
-                    GROUP BY m.membership
-                """
-
-            txn.execute(sql, (room_id,))
-            res = {}
-            for count, membership in txn:
-                summary = res.setdefault(to_ascii(membership), MemberSummary([], count))
-
-            # we order by membership and then fairly arbitrarily by event_id so
-            # heroes are consistent
-            if self._current_state_events_membership_up_to_date:
-                # Note, rejected events will have a null membership field, so
-                # we we manually filter them out.
-                sql = """
-                    SELECT state_key, membership, event_id
-                    FROM current_state_events
-                    WHERE type = 'm.room.member' AND room_id = ?
-                        AND membership IS NOT NULL
-                    ORDER BY
-                        CASE membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC,
-                        event_id ASC
-                    LIMIT ?
-                """
-            else:
-                sql = """
-                    SELECT c.state_key, m.membership, c.event_id
-                    FROM room_memberships as m
-                    INNER JOIN current_state_events as c USING (room_id, event_id)
-                    WHERE c.type = 'm.room.member' AND c.room_id = ?
-                    ORDER BY
-                        CASE m.membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC,
-                        c.event_id ASC
-                    LIMIT ?
-                """
-
-            # 6 is 5 (number of heroes) plus 1, in case one of them is the calling user.
-            txn.execute(sql, (room_id, Membership.JOIN, Membership.INVITE, 6))
-            for user_id, membership, event_id in txn:
-                summary = res[to_ascii(membership)]
-                # we will always have a summary for this membership type at this
-                # point given the summary currently contains the counts.
-                members = summary.members
-                members.append((to_ascii(user_id), to_ascii(event_id)))
-
-            return res
-
-        return self.runInteraction("get_room_summary", _get_room_summary_txn)
-
-    def _get_user_counts_in_room_txn(self, txn, room_id):
-        """
-        Get the user count in a room by membership.
-
-        Args:
-            room_id (str)
-            membership (Membership)
-
-        Returns:
-            Deferred[int]
-        """
-        sql = """
-        SELECT m.membership, count(*) FROM room_memberships as m
-            INNER JOIN current_state_events as c USING(event_id)
-            WHERE c.type = 'm.room.member' AND c.room_id = ?
-            GROUP BY m.membership
-        """
-
-        txn.execute(sql, (room_id,))
-        return {row[0]: row[1] for row in txn}
-
-    @cached()
-    def get_invited_rooms_for_user(self, user_id):
-        """ Get all the rooms the user is invited to
-        Args:
-            user_id (str): The user ID.
-        Returns:
-            A deferred list of RoomsForUser.
-        """
-
-        return self.get_rooms_for_user_where_membership_is(user_id, [Membership.INVITE])
-
-    @defer.inlineCallbacks
-    def get_invite_for_user_in_room(self, user_id, room_id):
-        """Gets the invite for the given user and room
-
-        Args:
-            user_id (str)
-            room_id (str)
-
-        Returns:
-            Deferred: Resolves to either a RoomsForUser or None if no invite was
-                found.
-        """
-        invites = yield self.get_invited_rooms_for_user(user_id)
-        for invite in invites:
-            if invite.room_id == room_id:
-                return invite
-        return None
-
-    @defer.inlineCallbacks
-    def get_rooms_for_user_where_membership_is(self, user_id, membership_list):
-        """ Get all the rooms for this user where the membership for this user
-        matches one in the membership list.
-
-        Filters out forgotten rooms.
-
-        Args:
-            user_id (str): The user ID.
-            membership_list (list): A list of synapse.api.constants.Membership
-            values which the user must be in.
-
-        Returns:
-            Deferred[list[RoomsForUser]]
-        """
-        if not membership_list:
-            return defer.succeed(None)
-
-        rooms = yield self.runInteraction(
-            "get_rooms_for_user_where_membership_is",
-            self._get_rooms_for_user_where_membership_is_txn,
-            user_id,
-            membership_list,
-        )
-
-        # Now we filter out forgotten rooms
-        forgotten_rooms = yield self.get_forgotten_rooms_for_user(user_id)
-        return [room for room in rooms if room.room_id not in forgotten_rooms]
-
-    def _get_rooms_for_user_where_membership_is_txn(
-        self, txn, user_id, membership_list
-    ):
-
-        do_invite = Membership.INVITE in membership_list
-        membership_list = [m for m in membership_list if m != Membership.INVITE]
-
-        results = []
-        if membership_list:
-            if self._current_state_events_membership_up_to_date:
-                clause, args = make_in_list_sql_clause(
-                    self.database_engine, "c.membership", membership_list
-                )
-                sql = """
-                    SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering
-                    FROM current_state_events AS c
-                    INNER JOIN events AS e USING (room_id, event_id)
-                    WHERE
-                        c.type = 'm.room.member'
-                        AND state_key = ?
-                        AND %s
-                """ % (
-                    clause,
-                )
-            else:
-                clause, args = make_in_list_sql_clause(
-                    self.database_engine, "m.membership", membership_list
-                )
-                sql = """
-                    SELECT room_id, e.sender, m.membership, event_id, e.stream_ordering
-                    FROM current_state_events AS c
-                    INNER JOIN room_memberships AS m USING (room_id, event_id)
-                    INNER JOIN events AS e USING (room_id, event_id)
-                    WHERE
-                        c.type = 'm.room.member'
-                        AND state_key = ?
-                        AND %s
-                """ % (
-                    clause,
-                )
-
-            txn.execute(sql, (user_id, *args))
-            results = [RoomsForUser(**r) for r in self.cursor_to_dict(txn)]
-
-        if do_invite:
-            sql = (
-                "SELECT i.room_id, inviter, i.event_id, e.stream_ordering"
-                " FROM local_invites as i"
-                " INNER JOIN events as e USING (event_id)"
-                " WHERE invitee = ? AND locally_rejected is NULL"
-                " AND replaced_by is NULL"
-            )
-
-            txn.execute(sql, (user_id,))
-            results.extend(
-                RoomsForUser(
-                    room_id=r["room_id"],
-                    sender=r["inviter"],
-                    event_id=r["event_id"],
-                    stream_ordering=r["stream_ordering"],
-                    membership=Membership.INVITE,
-                )
-                for r in self.cursor_to_dict(txn)
-            )
-
-        return results
-
-    @cachedInlineCallbacks(max_entries=500000, iterable=True)
-    def get_rooms_for_user_with_stream_ordering(self, user_id):
-        """Returns a set of room_ids the user is currently joined to
-
-        Args:
-            user_id (str)
-
-        Returns:
-            Deferred[frozenset[GetRoomsForUserWithStreamOrdering]]: Returns
-            the rooms the user is in currently, along with the stream ordering
-            of the most recent join for that user and room.
-        """
-        rooms = yield self.get_rooms_for_user_where_membership_is(
-            user_id, membership_list=[Membership.JOIN]
-        )
-        return frozenset(
-            GetRoomsForUserWithStreamOrdering(r.room_id, r.stream_ordering)
-            for r in rooms
-        )
-
-    @defer.inlineCallbacks
-    def get_rooms_for_user(self, user_id, on_invalidate=None):
-        """Returns a set of room_ids the user is currently joined to
-        """
-        rooms = yield self.get_rooms_for_user_with_stream_ordering(
-            user_id, on_invalidate=on_invalidate
-        )
-        return frozenset(r.room_id for r in rooms)
-
-    @cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True)
-    def get_users_who_share_room_with_user(self, user_id, cache_context):
-        """Returns the set of users who share a room with `user_id`
-        """
-        room_ids = yield self.get_rooms_for_user(
-            user_id, on_invalidate=cache_context.invalidate
-        )
-
-        user_who_share_room = set()
-        for room_id in room_ids:
-            user_ids = yield self.get_users_in_room(
-                room_id, on_invalidate=cache_context.invalidate
-            )
-            user_who_share_room.update(user_ids)
-
-        return user_who_share_room
-
-    @defer.inlineCallbacks
-    def get_joined_users_from_context(self, event, context):
-        state_group = context.state_group
-        if not state_group:
-            # If state_group is None it means it has yet to be assigned a
-            # state group, i.e. we need to make sure that calls with a state_group
-            # of None don't hit previous cached calls with a None state_group.
-            # To do this we set the state_group to a new object as object() != object()
-            state_group = object()
-
-        current_state_ids = yield context.get_current_state_ids(self)
-        result = yield self._get_joined_users_from_context(
-            event.room_id, state_group, current_state_ids, event=event, context=context
-        )
-        return result
-
-    @defer.inlineCallbacks
-    def get_joined_users_from_state(self, room_id, state_entry):
-        state_group = state_entry.state_group
-        if not state_group:
-            # If state_group is None it means it has yet to be assigned a
-            # state group, i.e. we need to make sure that calls with a state_group
-            # of None don't hit previous cached calls with a None state_group.
-            # To do this we set the state_group to a new object as object() != object()
-            state_group = object()
-
-        with Measure(self._clock, "get_joined_users_from_state"):
-            return (
-                yield self._get_joined_users_from_context(
-                    room_id, state_group, state_entry.state, context=state_entry
-                )
-            )
-
-    @cachedInlineCallbacks(
-        num_args=2, cache_context=True, iterable=True, max_entries=100000
-    )
-    def _get_joined_users_from_context(
-        self,
-        room_id,
-        state_group,
-        current_state_ids,
-        cache_context,
-        event=None,
-        context=None,
-    ):
-        # We don't use `state_group`, it's there so that we can cache based
-        # on it. However, it's important that it's never None, since two current_states
-        # with a state_group of None are likely to be different.
-        # See bulk_get_push_rules_for_room for how we work around this.
-        assert state_group is not None
-
-        users_in_room = {}
-        member_event_ids = [
-            e_id
-            for key, e_id in iteritems(current_state_ids)
-            if key[0] == EventTypes.Member
-        ]
-
-        if context is not None:
-            # If we have a context with a delta from a previous state group,
-            # check if we also have the result from the previous group in cache.
-            # If we do then we can reuse that result and simply update it with
-            # any membership changes in `delta_ids`
-            if context.prev_group and context.delta_ids:
-                prev_res = self._get_joined_users_from_context.cache.get(
-                    (room_id, context.prev_group), None
-                )
-                if prev_res and isinstance(prev_res, dict):
-                    users_in_room = dict(prev_res)
-                    member_event_ids = [
-                        e_id
-                        for key, e_id in iteritems(context.delta_ids)
-                        if key[0] == EventTypes.Member
-                    ]
-                    for etype, state_key in context.delta_ids:
-                        users_in_room.pop(state_key, None)
-
-        # We check if we have any of the member event ids in the event cache
-        # before we ask the DB
-
-        # We don't update the event cache hit ratio as it completely throws off
-        # the hit ratio counts. After all, we don't populate the cache if we
-        # miss it here
-        event_map = self._get_events_from_cache(
-            member_event_ids, allow_rejected=False, update_metrics=False
-        )
-
-        missing_member_event_ids = []
-        for event_id in member_event_ids:
-            ev_entry = event_map.get(event_id)
-            if ev_entry:
-                if ev_entry.event.membership == Membership.JOIN:
-                    users_in_room[to_ascii(ev_entry.event.state_key)] = ProfileInfo(
-                        display_name=to_ascii(
-                            ev_entry.event.content.get("displayname", None)
-                        ),
-                        avatar_url=to_ascii(
-                            ev_entry.event.content.get("avatar_url", None)
-                        ),
-                    )
-            else:
-                missing_member_event_ids.append(event_id)
-
-        if missing_member_event_ids:
-            event_to_memberships = yield self._get_joined_profiles_from_event_ids(
-                missing_member_event_ids
-            )
-            users_in_room.update((row for row in event_to_memberships.values() if row))
-
-        if event is not None and event.type == EventTypes.Member:
-            if event.membership == Membership.JOIN:
-                if event.event_id in member_event_ids:
-                    users_in_room[to_ascii(event.state_key)] = ProfileInfo(
-                        display_name=to_ascii(event.content.get("displayname", None)),
-                        avatar_url=to_ascii(event.content.get("avatar_url", None)),
-                    )
-
-        return users_in_room
-
-    @cached(max_entries=10000)
-    def _get_joined_profile_from_event_id(self, event_id):
-        raise NotImplementedError()
-
-    @cachedList(
-        cached_method_name="_get_joined_profile_from_event_id",
-        list_name="event_ids",
-        inlineCallbacks=True,
-    )
-    def _get_joined_profiles_from_event_ids(self, event_ids):
-        """For given set of member event_ids check if they point to a join
-        event and if so return the associated user and profile info.
-
-        Args:
-            event_ids (Iterable[str]): The member event IDs to lookup
-
-        Returns:
-            Deferred[dict[str, Tuple[str, ProfileInfo]|None]]: Map from event ID
-            to `user_id` and ProfileInfo (or None if not join event).
-        """
-
-        rows = yield self._simple_select_many_batch(
-            table="room_memberships",
-            column="event_id",
-            iterable=event_ids,
-            retcols=("user_id", "display_name", "avatar_url", "event_id"),
-            keyvalues={"membership": Membership.JOIN},
-            batch_size=500,
-            desc="_get_membership_from_event_ids",
-        )
-
-        return {
-            row["event_id"]: (
-                row["user_id"],
-                ProfileInfo(
-                    avatar_url=row["avatar_url"], display_name=row["display_name"]
-                ),
-            )
-            for row in rows
-        }
-
-    @cachedInlineCallbacks(max_entries=10000)
-    def is_host_joined(self, room_id, host):
-        if "%" in host or "_" in host:
-            raise Exception("Invalid host name")
-
-        sql = """
-            SELECT state_key FROM current_state_events AS c
-            INNER JOIN room_memberships AS m USING (event_id)
-            WHERE m.membership = 'join'
-                AND type = 'm.room.member'
-                AND c.room_id = ?
-                AND state_key LIKE ?
-            LIMIT 1
-        """
-
-        # We do need to be careful to ensure that host doesn't have any wild cards
-        # in it, but we checked above for known ones and we'll check below that
-        # the returned user actually has the correct domain.
-        like_clause = "%:" + host
-
-        rows = yield self._execute("is_host_joined", None, sql, room_id, like_clause)
-
-        if not rows:
-            return False
-
-        user_id = rows[0][0]
-        if get_domain_from_id(user_id) != host:
-            # This can only happen if the host name has something funky in it
-            raise Exception("Invalid host name")
-
-        return True
-
-    @cachedInlineCallbacks()
-    def was_host_joined(self, room_id, host):
-        """Check whether the server is or ever was in the room.
-
-        Args:
-            room_id (str)
-            host (str)
-
-        Returns:
-            Deferred: Resolves to True if the host is/was in the room, otherwise
-            False.
-        """
-        if "%" in host or "_" in host:
-            raise Exception("Invalid host name")
-
-        sql = """
-            SELECT user_id FROM room_memberships
-            WHERE room_id = ?
-                AND user_id LIKE ?
-                AND membership = 'join'
-            LIMIT 1
-        """
-
-        # We do need to be careful to ensure that host doesn't have any wild cards
-        # in it, but we checked above for known ones and we'll check below that
-        # the returned user actually has the correct domain.
-        like_clause = "%:" + host
-
-        rows = yield self._execute("was_host_joined", None, sql, room_id, like_clause)
-
-        if not rows:
-            return False
-
-        user_id = rows[0][0]
-        if get_domain_from_id(user_id) != host:
-            # This can only happen if the host name has something funky in it
-            raise Exception("Invalid host name")
-
-        return True
-
-    @defer.inlineCallbacks
-    def get_joined_hosts(self, room_id, state_entry):
-        state_group = state_entry.state_group
-        if not state_group:
-            # If state_group is None it means it has yet to be assigned a
-            # state group, i.e. we need to make sure that calls with a state_group
-            # of None don't hit previous cached calls with a None state_group.
-            # To do this we set the state_group to a new object as object() != object()
-            state_group = object()
-
-        with Measure(self._clock, "get_joined_hosts"):
-            return (
-                yield self._get_joined_hosts(
-                    room_id, state_group, state_entry.state, state_entry=state_entry
-                )
-            )
-
-    @cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True)
-    # @defer.inlineCallbacks
-    def _get_joined_hosts(self, room_id, state_group, current_state_ids, state_entry):
-        # We don't use `state_group`, its there so that we can cache based
-        # on it. However, its important that its never None, since two current_state's
-        # with a state_group of None are likely to be different.
-        # See bulk_get_push_rules_for_room for how we work around this.
-        assert state_group is not None
-
-        cache = self._get_joined_hosts_cache(room_id)
-        joined_hosts = yield cache.get_destinations(state_entry)
-
-        return joined_hosts
-
-    @cached(max_entries=10000)
-    def _get_joined_hosts_cache(self, room_id):
-        return _JoinedHostsCache(self, room_id)
-
-    @cachedInlineCallbacks(num_args=2)
-    def did_forget(self, user_id, room_id):
-        """Returns whether user_id has elected to discard history for room_id.
-
-        Returns False if they have since re-joined."""
-
-        def f(txn):
-            sql = (
-                "SELECT"
-                "  COUNT(*)"
-                " FROM"
-                "  room_memberships"
-                " WHERE"
-                "  user_id = ?"
-                " AND"
-                "  room_id = ?"
-                " AND"
-                "  forgotten = 0"
-            )
-            txn.execute(sql, (user_id, room_id))
-            rows = txn.fetchall()
-            return rows[0][0]
-
-        count = yield self.runInteraction("did_forget_membership", f)
-        return count == 0
-
-    @cached()
-    def get_forgotten_rooms_for_user(self, user_id):
-        """Gets all rooms the user has forgotten.
-
-        Args:
-            user_id (str)
-
-        Returns:
-            Deferred[set[str]]
-        """
-
-        def _get_forgotten_rooms_for_user_txn(txn):
-            # This is a slightly convoluted query that first looks up all rooms
-            # that the user has forgotten in the past, then rechecks that list
-            # to see if any have subsequently been updated. This is done so that
-            # we can use a partial index on `forgotten = 1` on the assumption
-            # that few users will actually forget many rooms.
-            #
-            # Note that a room is considered "forgotten" if *all* membership
-            # events for that user and room have the forgotten field set (as
-            # when a user forgets a room we update all rows for that user and
-            # room, not just the current one).
-            sql = """
-                SELECT room_id, (
-                    SELECT count(*) FROM room_memberships
-                    WHERE room_id = m.room_id AND user_id = m.user_id AND forgotten = 0
-                ) AS count
-                FROM room_memberships AS m
-                WHERE user_id = ? AND forgotten = 1
-                GROUP BY room_id, user_id;
-            """
-            txn.execute(sql, (user_id,))
-            return set(row[0] for row in txn if row[1] == 0)
-
-        return self.runInteraction(
-            "get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn
-        )
-
-    @defer.inlineCallbacks
-    def get_rooms_user_has_been_in(self, user_id):
-        """Get all rooms that the user has ever been in.
-
-        Args:
-            user_id (str)
-
-        Returns:
-            Deferred[set[str]]: Set of room IDs.
-        """
-
-        room_ids = yield self._simple_select_onecol(
-            table="room_memberships",
-            keyvalues={"membership": Membership.JOIN, "user_id": user_id},
-            retcol="room_id",
-            desc="get_rooms_user_has_been_in",
-        )
-
-        return set(room_ids)
-
-
-class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore):
-    def __init__(self, db_conn, hs):
-        super(RoomMemberBackgroundUpdateStore, self).__init__(db_conn, hs)
-        self.register_background_update_handler(
-            _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
-        )
-        self.register_background_update_handler(
-            _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME,
-            self._background_current_state_membership,
-        )
-        self.register_background_index_update(
-            "room_membership_forgotten_idx",
-            index_name="room_memberships_user_room_forgotten",
-            table="room_memberships",
-            columns=["user_id", "room_id"],
-            where_clause="forgotten = 1",
-        )
-
-    @defer.inlineCallbacks
-    def _background_add_membership_profile(self, progress, batch_size):
-        target_min_stream_id = progress.get(
-            "target_min_stream_id_inclusive", self._min_stream_order_on_start
-        )
-        max_stream_id = progress.get(
-            "max_stream_id_exclusive", self._stream_order_on_start + 1
-        )
-
-        INSERT_CLUMP_SIZE = 1000
-
-        def add_membership_profile_txn(txn):
-            sql = """
-                SELECT stream_ordering, event_id, events.room_id, event_json.json
-                FROM events
-                INNER JOIN event_json USING (event_id)
-                INNER JOIN room_memberships USING (event_id)
-                WHERE ? <= stream_ordering AND stream_ordering < ?
-                AND type = 'm.room.member'
-                ORDER BY stream_ordering DESC
-                LIMIT ?
-            """
-
-            txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
-
-            rows = self.cursor_to_dict(txn)
-            if not rows:
-                return 0
-
-            min_stream_id = rows[-1]["stream_ordering"]
-
-            to_update = []
-            for row in rows:
-                event_id = row["event_id"]
-                room_id = row["room_id"]
-                try:
-                    event_json = json.loads(row["json"])
-                    content = event_json["content"]
-                except Exception:
-                    continue
-
-                display_name = content.get("displayname", None)
-                avatar_url = content.get("avatar_url", None)
-
-                if display_name or avatar_url:
-                    to_update.append((display_name, avatar_url, event_id, room_id))
-
-            to_update_sql = """
-                UPDATE room_memberships SET display_name = ?, avatar_url = ?
-                WHERE event_id = ? AND room_id = ?
-            """
-            for index in range(0, len(to_update), INSERT_CLUMP_SIZE):
-                clump = to_update[index : index + INSERT_CLUMP_SIZE]
-                txn.executemany(to_update_sql, clump)
-
-            progress = {
-                "target_min_stream_id_inclusive": target_min_stream_id,
-                "max_stream_id_exclusive": min_stream_id,
-            }
-
-            self._background_update_progress_txn(
-                txn, _MEMBERSHIP_PROFILE_UPDATE_NAME, progress
-            )
-
-            return len(rows)
-
-        result = yield self.runInteraction(
-            _MEMBERSHIP_PROFILE_UPDATE_NAME, add_membership_profile_txn
-        )
-
-        if not result:
-            yield self._end_background_update(_MEMBERSHIP_PROFILE_UPDATE_NAME)
-
-        return result
-
-    @defer.inlineCallbacks
-    def _background_current_state_membership(self, progress, batch_size):
-        """Update the new membership column on current_state_events.
-
-        This works by iterating over all rooms in alphebetical order.
-        """
-
-        def _background_current_state_membership_txn(txn, last_processed_room):
-            processed = 0
-            while processed < batch_size:
-                txn.execute(
-                    """
-                        SELECT MIN(room_id) FROM current_state_events WHERE room_id > ?
-                    """,
-                    (last_processed_room,),
-                )
-                row = txn.fetchone()
-                if not row or not row[0]:
-                    return processed, True
-
-                next_room, = row
-
-                sql = """
-                    UPDATE current_state_events
-                    SET membership = (
-                        SELECT membership FROM room_memberships
-                        WHERE event_id = current_state_events.event_id
-                    )
-                    WHERE room_id = ?
-                """
-                txn.execute(sql, (next_room,))
-                processed += txn.rowcount
-
-                last_processed_room = next_room
-
-            self._background_update_progress_txn(
-                txn,
-                _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME,
-                {"last_processed_room": last_processed_room},
-            )
-
-            return processed, False
-
-        # If we haven't got a last processed room then just use the empty
-        # string, which will compare before all room IDs correctly.
-        last_processed_room = progress.get("last_processed_room", "")
-
-        row_count, finished = yield self.runInteraction(
-            "_background_current_state_membership_update",
-            _background_current_state_membership_txn,
-            last_processed_room,
-        )
-
-        if finished:
-            yield self._end_background_update(_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME)
-
-        return row_count
-
-
-class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
-    def __init__(self, db_conn, hs):
-        super(RoomMemberStore, self).__init__(db_conn, hs)
-
-    def _store_room_members_txn(self, txn, events, backfilled):
-        """Store a room member in the database.
-        """
-        self._simple_insert_many_txn(
-            txn,
-            table="room_memberships",
-            values=[
-                {
-                    "event_id": event.event_id,
-                    "user_id": event.state_key,
-                    "sender": event.user_id,
-                    "room_id": event.room_id,
-                    "membership": event.membership,
-                    "display_name": event.content.get("displayname", None),
-                    "avatar_url": event.content.get("avatar_url", None),
-                }
-                for event in events
-            ],
-        )
-
-        for event in events:
-            txn.call_after(
-                self._membership_stream_cache.entity_has_changed,
-                event.state_key,
-                event.internal_metadata.stream_ordering,
-            )
-            txn.call_after(
-                self.get_invited_rooms_for_user.invalidate, (event.state_key,)
-            )
-
-            # We update the local_invites table only if the event is "current",
-            # i.e., its something that has just happened. If the event is an
-            # outlier it is only current if its an "out of band membership",
-            # like a remote invite or a rejection of a remote invite.
-            is_new_state = not backfilled and (
-                not event.internal_metadata.is_outlier()
-                or event.internal_metadata.is_out_of_band_membership()
-            )
-            is_mine = self.hs.is_mine_id(event.state_key)
-            if is_new_state and is_mine:
-                if event.membership == Membership.INVITE:
-                    self._simple_insert_txn(
-                        txn,
-                        table="local_invites",
-                        values={
-                            "event_id": event.event_id,
-                            "invitee": event.state_key,
-                            "inviter": event.sender,
-                            "room_id": event.room_id,
-                            "stream_id": event.internal_metadata.stream_ordering,
-                        },
-                    )
-                else:
-                    sql = (
-                        "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE"
-                        " room_id = ? AND invitee = ? AND locally_rejected is NULL"
-                        " AND replaced_by is NULL"
-                    )
-
-                    txn.execute(
-                        sql,
-                        (
-                            event.internal_metadata.stream_ordering,
-                            event.event_id,
-                            event.room_id,
-                            event.state_key,
-                        ),
-                    )
-
-    @defer.inlineCallbacks
-    def locally_reject_invite(self, user_id, room_id):
-        sql = (
-            "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
-            " room_id = ? AND invitee = ? AND locally_rejected is NULL"
-            " AND replaced_by is NULL"
-        )
-
-        def f(txn, stream_ordering):
-            txn.execute(sql, (stream_ordering, True, room_id, user_id))
-
-        with self._stream_id_gen.get_next() as stream_ordering:
-            yield self.runInteraction("locally_reject_invite", f, stream_ordering)
-
-    def forget(self, user_id, room_id):
-        """Indicate that user_id wishes to discard history for room_id."""
-
-        def f(txn):
-            sql = (
-                "UPDATE"
-                "  room_memberships"
-                " SET"
-                "  forgotten = 1"
-                " WHERE"
-                "  user_id = ?"
-                " AND"
-                "  room_id = ?"
-            )
-            txn.execute(sql, (user_id, room_id))
-
-            self._invalidate_cache_and_stream(txn, self.did_forget, (user_id, room_id))
-            self._invalidate_cache_and_stream(
-                txn, self.get_forgotten_rooms_for_user, (user_id,)
-            )
-
-        return self.runInteraction("forget_membership", f)
-
-
-class _JoinedHostsCache(object):
-    """Cache for joined hosts in a room that is optimised to handle updates
-    via state deltas.
-    """
-
-    def __init__(self, store, room_id):
-        self.store = store
-        self.room_id = room_id
-
-        self.hosts_to_joined_users = {}
-
-        self.state_group = object()
-
-        self.linearizer = Linearizer("_JoinedHostsCache")
-
-        self._len = 0
-
-    @defer.inlineCallbacks
-    def get_destinations(self, state_entry):
-        """Get set of destinations for a state entry
-
-        Args:
-            state_entry(synapse.state._StateCacheEntry)
-        """
-        if state_entry.state_group == self.state_group:
-            return frozenset(self.hosts_to_joined_users)
-
-        with (yield self.linearizer.queue(())):
-            if state_entry.state_group == self.state_group:
-                pass
-            elif state_entry.prev_group == self.state_group:
-                for (typ, state_key), event_id in iteritems(state_entry.delta_ids):
-                    if typ != EventTypes.Member:
-                        continue
-
-                    host = intern_string(get_domain_from_id(state_key))
-                    user_id = state_key
-                    known_joins = self.hosts_to_joined_users.setdefault(host, set())
-
-                    event = yield self.store.get_event(event_id)
-                    if event.membership == Membership.JOIN:
-                        known_joins.add(user_id)
-                    else:
-                        known_joins.discard(user_id)
-
-                        if not known_joins:
-                            self.hosts_to_joined_users.pop(host, None)
-            else:
-                joined_users = yield self.store.get_joined_users_from_state(
-                    self.room_id, state_entry
-                )
-
-                self.hosts_to_joined_users = {}
-                for user_id in joined_users:
-                    host = intern_string(get_domain_from_id(user_id))
-                    self.hosts_to_joined_users.setdefault(host, set()).add(user_id)
-
-            if state_entry.state_group:
-                self.state_group = state_entry.state_group
-            else:
-                self.state_group = object()
-            self._len = sum(len(v) for v in itervalues(self.hosts_to_joined_users))
-        return frozenset(self.hosts_to_joined_users)
-
-    def __len__(self):
-        return self._len
diff --git a/synapse/storage/schema/delta/35/00background_updates_add_col.sql b/synapse/storage/schema/delta/35/00background_updates_add_col.sql
new file mode 100644
index 0000000000..c2d2a4f836
--- /dev/null
+++ b/synapse/storage/schema/delta/35/00background_updates_add_col.sql
@@ -0,0 +1,17 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+ALTER TABLE background_updates ADD COLUMN depends_on TEXT;
diff --git a/synapse/storage/schema/full_schemas/54/full.sql b/synapse/storage/schema/full_schemas/54/full.sql
new file mode 100644
index 0000000000..1005880466
--- /dev/null
+++ b/synapse/storage/schema/full_schemas/54/full.sql
@@ -0,0 +1,8 @@
+
+
+CREATE TABLE background_updates (
+    update_name text NOT NULL,
+    progress_json text NOT NULL,
+    depends_on text,
+    CONSTRAINT background_updates_uniqueness UNIQUE (update_name)
+);
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index a941a5ae3f..a2df8fa827 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -14,45 +14,16 @@
 # limitations under the License.
 
 import logging
-from collections import namedtuple
 
 from six import iteritems, itervalues
-from six.moves import range
 
 import attr
 
-from twisted.internet import defer
-
 from synapse.api.constants import EventTypes
-from synapse.api.errors import NotFoundError
-from synapse.storage._base import SQLBaseStore
-from synapse.storage.background_updates import BackgroundUpdateStore
-from synapse.storage.engines import PostgresEngine
-from synapse.storage.events_worker import EventsWorkerStore
-from synapse.util.caches import get_cache_factor_for, intern_string
-from synapse.util.caches.descriptors import cached, cachedList
-from synapse.util.caches.dictionary_cache import DictionaryCache
-from synapse.util.stringutils import to_ascii
 
 logger = logging.getLogger(__name__)
 
 
-MAX_STATE_DELTA_HOPS = 100
-
-
-class _GetStateGroupDelta(
-    namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids"))
-):
-    """Return type of get_state_group_delta that implements __len__, which lets
-    us use the itrable flag when caching
-    """
-
-    __slots__ = []
-
-    def __len__(self):
-        return len(self.delta_ids) if self.delta_ids else 0
-
-
 @attr.s(slots=True)
 class StateFilter(object):
     """A filter used when querying for state.
@@ -351,1195 +322,3 @@ class StateFilter(object):
         )
 
         return member_filter, non_member_filter
-
-
-class StateGroupBackgroundUpdateStore(SQLBaseStore):
-    """Defines functions related to state groups needed to run the state backgroud
-    updates.
-    """
-
-    def _count_state_group_hops_txn(self, txn, state_group):
-        """Given a state group, count how many hops there are in the tree.
-
-        This is used to ensure the delta chains don't get too long.
-        """
-        if isinstance(self.database_engine, PostgresEngine):
-            sql = """
-                WITH RECURSIVE state(state_group) AS (
-                    VALUES(?::bigint)
-                    UNION ALL
-                    SELECT prev_state_group FROM state_group_edges e, state s
-                    WHERE s.state_group = e.state_group
-                )
-                SELECT count(*) FROM state;
-            """
-
-            txn.execute(sql, (state_group,))
-            row = txn.fetchone()
-            if row and row[0]:
-                return row[0]
-            else:
-                return 0
-        else:
-            # We don't use WITH RECURSIVE on sqlite3 as there are distributions
-            # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
-            next_group = state_group
-            count = 0
-
-            while next_group:
-                next_group = self._simple_select_one_onecol_txn(
-                    txn,
-                    table="state_group_edges",
-                    keyvalues={"state_group": next_group},
-                    retcol="prev_state_group",
-                    allow_none=True,
-                )
-                if next_group:
-                    count += 1
-
-            return count
-
-    def _get_state_groups_from_groups_txn(
-        self, txn, groups, state_filter=StateFilter.all()
-    ):
-        results = {group: {} for group in groups}
-
-        where_clause, where_args = state_filter.make_sql_filter_clause()
-
-        # Unless the filter clause is empty, we're going to append it after an
-        # existing where clause
-        if where_clause:
-            where_clause = " AND (%s)" % (where_clause,)
-
-        if isinstance(self.database_engine, PostgresEngine):
-            # Temporarily disable sequential scans in this transaction. This is
-            # a temporary hack until we can add the right indices in
-            txn.execute("SET LOCAL enable_seqscan=off")
-
-            # The below query walks the state_group tree so that the "state"
-            # table includes all state_groups in the tree. It then joins
-            # against `state_groups_state` to fetch the latest state.
-            # It assumes that previous state groups are always numerically
-            # lesser.
-            # The PARTITION is used to get the event_id in the greatest state
-            # group for the given type, state_key.
-            # This may return multiple rows per (type, state_key), but last_value
-            # should be the same.
-            sql = """
-                WITH RECURSIVE state(state_group) AS (
-                    VALUES(?::bigint)
-                    UNION ALL
-                    SELECT prev_state_group FROM state_group_edges e, state s
-                    WHERE s.state_group = e.state_group
-                )
-                SELECT DISTINCT type, state_key, last_value(event_id) OVER (
-                    PARTITION BY type, state_key ORDER BY state_group ASC
-                    ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
-                ) AS event_id FROM state_groups_state
-                WHERE state_group IN (
-                    SELECT state_group FROM state
-                )
-            """
-
-            for group in groups:
-                args = [group]
-                args.extend(where_args)
-
-                txn.execute(sql + where_clause, args)
-                for row in txn:
-                    typ, state_key, event_id = row
-                    key = (typ, state_key)
-                    results[group][key] = event_id
-        else:
-            max_entries_returned = state_filter.max_entries_returned()
-
-            # We don't use WITH RECURSIVE on sqlite3 as there are distributions
-            # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
-            for group in groups:
-                next_group = group
-
-                while next_group:
-                    # We did this before by getting the list of group ids, and
-                    # then passing that list to sqlite to get latest event for
-                    # each (type, state_key). However, that was terribly slow
-                    # without the right indices (which we can't add until
-                    # after we finish deduping state, which requires this func)
-                    args = [next_group]
-                    args.extend(where_args)
-
-                    txn.execute(
-                        "SELECT type, state_key, event_id FROM state_groups_state"
-                        " WHERE state_group = ? " + where_clause,
-                        args,
-                    )
-                    results[group].update(
-                        ((typ, state_key), event_id)
-                        for typ, state_key, event_id in txn
-                        if (typ, state_key) not in results[group]
-                    )
-
-                    # If the number of entries in the (type,state_key)->event_id dict
-                    # matches the number of (type,state_keys) types we were searching
-                    # for, then we must have found them all, so no need to go walk
-                    # further down the tree... UNLESS our types filter contained
-                    # wildcards (i.e. Nones) in which case we have to do an exhaustive
-                    # search
-                    if (
-                        max_entries_returned is not None
-                        and len(results[group]) == max_entries_returned
-                    ):
-                        break
-
-                    next_group = self._simple_select_one_onecol_txn(
-                        txn,
-                        table="state_group_edges",
-                        keyvalues={"state_group": next_group},
-                        retcol="prev_state_group",
-                        allow_none=True,
-                    )
-
-        return results
-
-
-# this inherits from EventsWorkerStore because it calls self.get_events
-class StateGroupWorkerStore(
-    EventsWorkerStore, StateGroupBackgroundUpdateStore, SQLBaseStore
-):
-    """The parts of StateGroupStore that can be called from workers.
-    """
-
-    STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
-    STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
-    CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
-
-    def __init__(self, db_conn, hs):
-        super(StateGroupWorkerStore, self).__init__(db_conn, hs)
-
-        # Originally the state store used a single DictionaryCache to cache the
-        # event IDs for the state types in a given state group to avoid hammering
-        # on the state_group* tables.
-        #
-        # The point of using a DictionaryCache is that it can cache a subset
-        # of the state events for a given state group (i.e. a subset of the keys for a
-        # given dict which is an entry in the cache for a given state group ID).
-        #
-        # However, this poses problems when performing complicated queries
-        # on the store - for instance: "give me all the state for this group, but
-        # limit members to this subset of users", as DictionaryCache's API isn't
-        # rich enough to say "please cache any of these fields, apart from this subset".
-        # This is problematic when lazy loading members, which requires this behaviour,
-        # as without it the cache has no choice but to speculatively load all
-        # state events for the group, which negates the efficiency being sought.
-        #
-        # Rather than overcomplicating DictionaryCache's API, we instead split the
-        # state_group_cache into two halves - one for tracking non-member events,
-        # and the other for tracking member_events.  This means that lazy loading
-        # queries can be made in a cache-friendly manner by querying both caches
-        # separately and then merging the result.  So for the example above, you
-        # would query the members cache for a specific subset of state keys
-        # (which DictionaryCache will handle efficiently and fine) and the non-members
-        # cache for all state (which DictionaryCache will similarly handle fine)
-        # and then just merge the results together.
-        #
-        # We size the non-members cache to be smaller than the members cache as the
-        # vast majority of state in Matrix (today) is member events.
-
-        self._state_group_cache = DictionaryCache(
-            "*stateGroupCache*",
-            # TODO: this hasn't been tuned yet
-            50000 * get_cache_factor_for("stateGroupCache"),
-        )
-        self._state_group_members_cache = DictionaryCache(
-            "*stateGroupMembersCache*",
-            500000 * get_cache_factor_for("stateGroupMembersCache"),
-        )
-
-    @defer.inlineCallbacks
-    def get_room_version(self, room_id):
-        """Get the room_version of a given room
-
-        Args:
-            room_id (str)
-
-        Returns:
-            Deferred[str]
-
-        Raises:
-            NotFoundError if the room is unknown
-        """
-        # for now we do this by looking at the create event. We may want to cache this
-        # more intelligently in future.
-
-        # Retrieve the room's create event
-        create_event = yield self.get_create_event_for_room(room_id)
-        return create_event.content.get("room_version", "1")
-
-    @defer.inlineCallbacks
-    def get_room_predecessor(self, room_id):
-        """Get the predecessor room of an upgraded room if one exists.
-        Otherwise return None.
-
-        Args:
-            room_id (str)
-
-        Returns:
-            Deferred[unicode|None]: predecessor room id
-
-        Raises:
-            NotFoundError if the room is unknown
-        """
-        # Retrieve the room's create event
-        create_event = yield self.get_create_event_for_room(room_id)
-
-        # Return predecessor if present
-        return create_event.content.get("predecessor", None)
-
-    @defer.inlineCallbacks
-    def get_create_event_for_room(self, room_id):
-        """Get the create state event for a room.
-
-        Args:
-            room_id (str)
-
-        Returns:
-            Deferred[EventBase]: The room creation event.
-
-        Raises:
-            NotFoundError if the room is unknown
-        """
-        state_ids = yield self.get_current_state_ids(room_id)
-        create_id = state_ids.get((EventTypes.Create, ""))
-
-        # If we can't find the create event, assume we've hit a dead end
-        if not create_id:
-            raise NotFoundError("Unknown room %s" % (room_id))
-
-        # Retrieve the room's create event and return
-        create_event = yield self.get_event(create_id)
-        return create_event
-
-    @cached(max_entries=100000, iterable=True)
-    def get_current_state_ids(self, room_id):
-        """Get the current state event ids for a room based on the
-        current_state_events table.
-
-        Args:
-            room_id (str)
-
-        Returns:
-            deferred: dict of (type, state_key) -> event_id
-        """
-
-        def _get_current_state_ids_txn(txn):
-            txn.execute(
-                """SELECT type, state_key, event_id FROM current_state_events
-                WHERE room_id = ?
-                """,
-                (room_id,),
-            )
-
-            return {
-                (intern_string(r[0]), intern_string(r[1])): to_ascii(r[2]) for r in txn
-            }
-
-        return self.runInteraction("get_current_state_ids", _get_current_state_ids_txn)
-
-    # FIXME: how should this be cached?
-    def get_filtered_current_state_ids(self, room_id, state_filter=StateFilter.all()):
-        """Get the current state event of a given type for a room based on the
-        current_state_events table.  This may not be as up-to-date as the result
-        of doing a fresh state resolution as per state_handler.get_current_state
-
-        Args:
-            room_id (str)
-            state_filter (StateFilter): The state filter used to fetch state
-                from the database.
-
-        Returns:
-            Deferred[dict[tuple[str, str], str]]: Map from type/state_key to
-            event ID.
-        """
-
-        where_clause, where_args = state_filter.make_sql_filter_clause()
-
-        if not where_clause:
-            # We delegate to the cached version
-            return self.get_current_state_ids(room_id)
-
-        def _get_filtered_current_state_ids_txn(txn):
-            results = {}
-            sql = """
-                SELECT type, state_key, event_id FROM current_state_events
-                WHERE room_id = ?
-            """
-
-            if where_clause:
-                sql += " AND (%s)" % (where_clause,)
-
-            args = [room_id]
-            args.extend(where_args)
-            txn.execute(sql, args)
-            for row in txn:
-                typ, state_key, event_id = row
-                key = (intern_string(typ), intern_string(state_key))
-                results[key] = event_id
-
-            return results
-
-        return self.runInteraction(
-            "get_filtered_current_state_ids", _get_filtered_current_state_ids_txn
-        )
-
-    @defer.inlineCallbacks
-    def get_canonical_alias_for_room(self, room_id):
-        """Get canonical alias for room, if any
-
-        Args:
-            room_id (str)
-
-        Returns:
-            Deferred[str|None]: The canonical alias, if any
-        """
-
-        state = yield self.get_filtered_current_state_ids(
-            room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")])
-        )
-
-        event_id = state.get((EventTypes.CanonicalAlias, ""))
-        if not event_id:
-            return
-
-        event = yield self.get_event(event_id, allow_none=True)
-        if not event:
-            return
-
-        return event.content.get("canonical_alias")
-
-    @cached(max_entries=10000, iterable=True)
-    def get_state_group_delta(self, state_group):
-        """Given a state group try to return a previous group and a delta between
-        the old and the new.
-
-        Returns:
-            (prev_group, delta_ids), where both may be None.
-        """
-
-        def _get_state_group_delta_txn(txn):
-            prev_group = self._simple_select_one_onecol_txn(
-                txn,
-                table="state_group_edges",
-                keyvalues={"state_group": state_group},
-                retcol="prev_state_group",
-                allow_none=True,
-            )
-
-            if not prev_group:
-                return _GetStateGroupDelta(None, None)
-
-            delta_ids = self._simple_select_list_txn(
-                txn,
-                table="state_groups_state",
-                keyvalues={"state_group": state_group},
-                retcols=("type", "state_key", "event_id"),
-            )
-
-            return _GetStateGroupDelta(
-                prev_group,
-                {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids},
-            )
-
-        return self.runInteraction("get_state_group_delta", _get_state_group_delta_txn)
-
-    @defer.inlineCallbacks
-    def get_state_groups_ids(self, _room_id, event_ids):
-        """Get the event IDs of all the state for the state groups for the given events
-
-        Args:
-            _room_id (str): id of the room for these events
-            event_ids (iterable[str]): ids of the events
-
-        Returns:
-            Deferred[dict[int, dict[tuple[str, str], str]]]:
-                dict of state_group_id -> (dict of (type, state_key) -> event id)
-        """
-        if not event_ids:
-            return {}
-
-        event_to_groups = yield self._get_state_group_for_events(event_ids)
-
-        groups = set(itervalues(event_to_groups))
-        group_to_state = yield self._get_state_for_groups(groups)
-
-        return group_to_state
-
-    @defer.inlineCallbacks
-    def get_state_ids_for_group(self, state_group):
-        """Get the event IDs of all the state in the given state group
-
-        Args:
-            state_group (int)
-
-        Returns:
-            Deferred[dict]: Resolves to a map of (type, state_key) -> event_id
-        """
-        group_to_state = yield self._get_state_for_groups((state_group,))
-
-        return group_to_state[state_group]
-
-    @defer.inlineCallbacks
-    def get_state_groups(self, room_id, event_ids):
-        """ Get the state groups for the given list of event_ids
-
-        Returns:
-            Deferred[dict[int, list[EventBase]]]:
-                dict of state_group_id -> list of state events.
-        """
-        if not event_ids:
-            return {}
-
-        group_to_ids = yield self.get_state_groups_ids(room_id, event_ids)
-
-        state_event_map = yield self.get_events(
-            [
-                ev_id
-                for group_ids in itervalues(group_to_ids)
-                for ev_id in itervalues(group_ids)
-            ],
-            get_prev_content=False,
-        )
-
-        return {
-            group: [
-                state_event_map[v]
-                for v in itervalues(event_id_map)
-                if v in state_event_map
-            ]
-            for group, event_id_map in iteritems(group_to_ids)
-        }
-
-    @defer.inlineCallbacks
-    def _get_state_groups_from_groups(self, groups, state_filter):
-        """Returns the state groups for a given set of groups, filtering on
-        types of state events.
-
-        Args:
-            groups(list[int]): list of state group IDs to query
-            state_filter (StateFilter): The state filter used to fetch state
-                from the database.
-        Returns:
-            Deferred[dict[int, dict[tuple[str, str], str]]]:
-                dict of state_group_id -> (dict of (type, state_key) -> event id)
-        """
-        results = {}
-
-        chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)]
-        for chunk in chunks:
-            res = yield self.runInteraction(
-                "_get_state_groups_from_groups",
-                self._get_state_groups_from_groups_txn,
-                chunk,
-                state_filter,
-            )
-            results.update(res)
-
-        return results
-
-    @defer.inlineCallbacks
-    def get_state_for_events(self, event_ids, state_filter=StateFilter.all()):
-        """Given a list of event_ids and type tuples, return a list of state
-        dicts for each event.
-
-        Args:
-            event_ids (list[string])
-            state_filter (StateFilter): The state filter used to fetch state
-                from the database.
-
-        Returns:
-            deferred: A dict of (event_id) -> (type, state_key) -> [state_events]
-        """
-        event_to_groups = yield self._get_state_group_for_events(event_ids)
-
-        groups = set(itervalues(event_to_groups))
-        group_to_state = yield self._get_state_for_groups(groups, state_filter)
-
-        state_event_map = yield self.get_events(
-            [ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)],
-            get_prev_content=False,
-        )
-
-        event_to_state = {
-            event_id: {
-                k: state_event_map[v]
-                for k, v in iteritems(group_to_state[group])
-                if v in state_event_map
-            }
-            for event_id, group in iteritems(event_to_groups)
-        }
-
-        return {event: event_to_state[event] for event in event_ids}
-
-    @defer.inlineCallbacks
-    def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()):
-        """
-        Get the state dicts corresponding to a list of events, containing the event_ids
-        of the state events (as opposed to the events themselves)
-
-        Args:
-            event_ids(list(str)): events whose state should be returned
-            state_filter (StateFilter): The state filter used to fetch state
-                from the database.
-
-        Returns:
-            A deferred dict from event_id -> (type, state_key) -> event_id
-        """
-        event_to_groups = yield self._get_state_group_for_events(event_ids)
-
-        groups = set(itervalues(event_to_groups))
-        group_to_state = yield self._get_state_for_groups(groups, state_filter)
-
-        event_to_state = {
-            event_id: group_to_state[group]
-            for event_id, group in iteritems(event_to_groups)
-        }
-
-        return {event: event_to_state[event] for event in event_ids}
-
-    @defer.inlineCallbacks
-    def get_state_for_event(self, event_id, state_filter=StateFilter.all()):
-        """
-        Get the state dict corresponding to a particular event
-
-        Args:
-            event_id(str): event whose state should be returned
-            state_filter (StateFilter): The state filter used to fetch state
-                from the database.
-
-        Returns:
-            A deferred dict from (type, state_key) -> state_event
-        """
-        state_map = yield self.get_state_for_events([event_id], state_filter)
-        return state_map[event_id]
-
-    @defer.inlineCallbacks
-    def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()):
-        """
-        Get the state dict corresponding to a particular event
-
-        Args:
-            event_id(str): event whose state should be returned
-            state_filter (StateFilter): The state filter used to fetch state
-                from the database.
-
-        Returns:
-            A deferred dict from (type, state_key) -> state_event
-        """
-        state_map = yield self.get_state_ids_for_events([event_id], state_filter)
-        return state_map[event_id]
-
-    @cached(max_entries=50000)
-    def _get_state_group_for_event(self, event_id):
-        return self._simple_select_one_onecol(
-            table="event_to_state_groups",
-            keyvalues={"event_id": event_id},
-            retcol="state_group",
-            allow_none=True,
-            desc="_get_state_group_for_event",
-        )
-
-    @cachedList(
-        cached_method_name="_get_state_group_for_event",
-        list_name="event_ids",
-        num_args=1,
-        inlineCallbacks=True,
-    )
-    def _get_state_group_for_events(self, event_ids):
-        """Returns mapping event_id -> state_group
-        """
-        rows = yield self._simple_select_many_batch(
-            table="event_to_state_groups",
-            column="event_id",
-            iterable=event_ids,
-            keyvalues={},
-            retcols=("event_id", "state_group"),
-            desc="_get_state_group_for_events",
-        )
-
-        return {row["event_id"]: row["state_group"] for row in rows}
-
-    def _get_state_for_group_using_cache(self, cache, group, state_filter):
-        """Checks if group is in cache. See `_get_state_for_groups`
-
-        Args:
-            cache(DictionaryCache): the state group cache to use
-            group(int): The state group to lookup
-            state_filter (StateFilter): The state filter used to fetch state
-                from the database.
-
-        Returns 2-tuple (`state_dict`, `got_all`).
-        `got_all` is a bool indicating if we successfully retrieved all
-        requests state from the cache, if False we need to query the DB for the
-        missing state.
-        """
-        is_all, known_absent, state_dict_ids = cache.get(group)
-
-        if is_all or state_filter.is_full():
-            # Either we have everything or want everything, either way
-            # `is_all` tells us whether we've gotten everything.
-            return state_filter.filter_state(state_dict_ids), is_all
-
-        # tracks whether any of our requested types are missing from the cache
-        missing_types = False
-
-        if state_filter.has_wildcards():
-            # We don't know if we fetched all the state keys for the types in
-            # the filter that are wildcards, so we have to assume that we may
-            # have missed some.
-            missing_types = True
-        else:
-            # There aren't any wild cards, so `concrete_types()` returns the
-            # complete list of event types we're wanting.
-            for key in state_filter.concrete_types():
-                if key not in state_dict_ids and key not in known_absent:
-                    missing_types = True
-                    break
-
-        return state_filter.filter_state(state_dict_ids), not missing_types
-
-    @defer.inlineCallbacks
-    def _get_state_for_groups(self, groups, state_filter=StateFilter.all()):
-        """Gets the state at each of a list of state groups, optionally
-        filtering by type/state_key
-
-        Args:
-            groups (iterable[int]): list of state groups for which we want
-                to get the state.
-            state_filter (StateFilter): The state filter used to fetch state
-                from the database.
-        Returns:
-            Deferred[dict[int, dict[tuple[str, str], str]]]:
-                dict of state_group_id -> (dict of (type, state_key) -> event id)
-        """
-
-        member_filter, non_member_filter = state_filter.get_member_split()
-
-        # Now we look them up in the member and non-member caches
-        non_member_state, incomplete_groups_nm, = (
-            yield self._get_state_for_groups_using_cache(
-                groups, self._state_group_cache, state_filter=non_member_filter
-            )
-        )
-
-        member_state, incomplete_groups_m, = (
-            yield self._get_state_for_groups_using_cache(
-                groups, self._state_group_members_cache, state_filter=member_filter
-            )
-        )
-
-        state = dict(non_member_state)
-        for group in groups:
-            state[group].update(member_state[group])
-
-        # Now fetch any missing groups from the database
-
-        incomplete_groups = incomplete_groups_m | incomplete_groups_nm
-
-        if not incomplete_groups:
-            return state
-
-        cache_sequence_nm = self._state_group_cache.sequence
-        cache_sequence_m = self._state_group_members_cache.sequence
-
-        # Help the cache hit ratio by expanding the filter a bit
-        db_state_filter = state_filter.return_expanded()
-
-        group_to_state_dict = yield self._get_state_groups_from_groups(
-            list(incomplete_groups), state_filter=db_state_filter
-        )
-
-        # Now lets update the caches
-        self._insert_into_cache(
-            group_to_state_dict,
-            db_state_filter,
-            cache_seq_num_members=cache_sequence_m,
-            cache_seq_num_non_members=cache_sequence_nm,
-        )
-
-        # And finally update the result dict, by filtering out any extra
-        # stuff we pulled out of the database.
-        for group, group_state_dict in iteritems(group_to_state_dict):
-            # We just replace any existing entries, as we will have loaded
-            # everything we need from the database anyway.
-            state[group] = state_filter.filter_state(group_state_dict)
-
-        return state
-
-    def _get_state_for_groups_using_cache(self, groups, cache, state_filter):
-        """Gets the state at each of a list of state groups, optionally
-        filtering by type/state_key, querying from a specific cache.
-
-        Args:
-            groups (iterable[int]): list of state groups for which we want
-                to get the state.
-            cache (DictionaryCache): the cache of group ids to state dicts which
-                we will pass through - either the normal state cache or the specific
-                members state cache.
-            state_filter (StateFilter): The state filter used to fetch state
-                from the database.
-
-        Returns:
-            tuple[dict[int, dict[tuple[str, str], str]], set[int]]: Tuple of
-            dict of state_group_id -> (dict of (type, state_key) -> event id)
-            of entries in the cache, and the state group ids either missing
-            from the cache or incomplete.
-        """
-        results = {}
-        incomplete_groups = set()
-        for group in set(groups):
-            state_dict_ids, got_all = self._get_state_for_group_using_cache(
-                cache, group, state_filter
-            )
-            results[group] = state_dict_ids
-
-            if not got_all:
-                incomplete_groups.add(group)
-
-        return results, incomplete_groups
-
-    def _insert_into_cache(
-        self,
-        group_to_state_dict,
-        state_filter,
-        cache_seq_num_members,
-        cache_seq_num_non_members,
-    ):
-        """Inserts results from querying the database into the relevant cache.
-
-        Args:
-            group_to_state_dict (dict): The new entries pulled from database.
-                Map from state group to state dict
-            state_filter (StateFilter): The state filter used to fetch state
-                from the database.
-            cache_seq_num_members (int): Sequence number of member cache since
-                last lookup in cache
-            cache_seq_num_non_members (int): Sequence number of member cache since
-                last lookup in cache
-        """
-
-        # We need to work out which types we've fetched from the DB for the
-        # member vs non-member caches. This should be as accurate as possible,
-        # but can be an underestimate (e.g. when we have wild cards)
-
-        member_filter, non_member_filter = state_filter.get_member_split()
-        if member_filter.is_full():
-            # We fetched all member events
-            member_types = None
-        else:
-            # `concrete_types()` will only return a subset when there are wild
-            # cards in the filter, but that's fine.
-            member_types = member_filter.concrete_types()
-
-        if non_member_filter.is_full():
-            # We fetched all non member events
-            non_member_types = None
-        else:
-            non_member_types = non_member_filter.concrete_types()
-
-        for group, group_state_dict in iteritems(group_to_state_dict):
-            state_dict_members = {}
-            state_dict_non_members = {}
-
-            for k, v in iteritems(group_state_dict):
-                if k[0] == EventTypes.Member:
-                    state_dict_members[k] = v
-                else:
-                    state_dict_non_members[k] = v
-
-            self._state_group_members_cache.update(
-                cache_seq_num_members,
-                key=group,
-                value=state_dict_members,
-                fetched_keys=member_types,
-            )
-
-            self._state_group_cache.update(
-                cache_seq_num_non_members,
-                key=group,
-                value=state_dict_non_members,
-                fetched_keys=non_member_types,
-            )
-
-    def store_state_group(
-        self, event_id, room_id, prev_group, delta_ids, current_state_ids
-    ):
-        """Store a new set of state, returning a newly assigned state group.
-
-        Args:
-            event_id (str): The event ID for which the state was calculated
-            room_id (str)
-            prev_group (int|None): A previous state group for the room, optional.
-            delta_ids (dict|None): The delta between state at `prev_group` and
-                `current_state_ids`, if `prev_group` was given. Same format as
-                `current_state_ids`.
-            current_state_ids (dict): The state to store. Map of (type, state_key)
-                to event_id.
-
-        Returns:
-            Deferred[int]: The state group ID
-        """
-
-        def _store_state_group_txn(txn):
-            if current_state_ids is None:
-                # AFAIK, this can never happen
-                raise Exception("current_state_ids cannot be None")
-
-            state_group = self.database_engine.get_next_state_group_id(txn)
-
-            self._simple_insert_txn(
-                txn,
-                table="state_groups",
-                values={"id": state_group, "room_id": room_id, "event_id": event_id},
-            )
-
-            # We persist as a delta if we can, while also ensuring the chain
-            # of deltas isn't tooo long, as otherwise read performance degrades.
-            if prev_group:
-                is_in_db = self._simple_select_one_onecol_txn(
-                    txn,
-                    table="state_groups",
-                    keyvalues={"id": prev_group},
-                    retcol="id",
-                    allow_none=True,
-                )
-                if not is_in_db:
-                    raise Exception(
-                        "Trying to persist state with unpersisted prev_group: %r"
-                        % (prev_group,)
-                    )
-
-                potential_hops = self._count_state_group_hops_txn(txn, prev_group)
-            if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
-                self._simple_insert_txn(
-                    txn,
-                    table="state_group_edges",
-                    values={"state_group": state_group, "prev_state_group": prev_group},
-                )
-
-                self._simple_insert_many_txn(
-                    txn,
-                    table="state_groups_state",
-                    values=[
-                        {
-                            "state_group": state_group,
-                            "room_id": room_id,
-                            "type": key[0],
-                            "state_key": key[1],
-                            "event_id": state_id,
-                        }
-                        for key, state_id in iteritems(delta_ids)
-                    ],
-                )
-            else:
-                self._simple_insert_many_txn(
-                    txn,
-                    table="state_groups_state",
-                    values=[
-                        {
-                            "state_group": state_group,
-                            "room_id": room_id,
-                            "type": key[0],
-                            "state_key": key[1],
-                            "event_id": state_id,
-                        }
-                        for key, state_id in iteritems(current_state_ids)
-                    ],
-                )
-
-            # Prefill the state group caches with this group.
-            # It's fine to use the sequence like this as the state group map
-            # is immutable. (If the map wasn't immutable then this prefill could
-            # race with another update)
-
-            current_member_state_ids = {
-                s: ev
-                for (s, ev) in iteritems(current_state_ids)
-                if s[0] == EventTypes.Member
-            }
-            txn.call_after(
-                self._state_group_members_cache.update,
-                self._state_group_members_cache.sequence,
-                key=state_group,
-                value=dict(current_member_state_ids),
-            )
-
-            current_non_member_state_ids = {
-                s: ev
-                for (s, ev) in iteritems(current_state_ids)
-                if s[0] != EventTypes.Member
-            }
-            txn.call_after(
-                self._state_group_cache.update,
-                self._state_group_cache.sequence,
-                key=state_group,
-                value=dict(current_non_member_state_ids),
-            )
-
-            return state_group
-
-        return self.runInteraction("store_state_group", _store_state_group_txn)
-
-
-class StateBackgroundUpdateStore(
-    StateGroupBackgroundUpdateStore, BackgroundUpdateStore
-):
-
-    STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
-    STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
-    CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
-    EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index"
-
-    def __init__(self, db_conn, hs):
-        super(StateBackgroundUpdateStore, self).__init__(db_conn, hs)
-        self.register_background_update_handler(
-            self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
-            self._background_deduplicate_state,
-        )
-        self.register_background_update_handler(
-            self.STATE_GROUP_INDEX_UPDATE_NAME, self._background_index_state
-        )
-        self.register_background_index_update(
-            self.CURRENT_STATE_INDEX_UPDATE_NAME,
-            index_name="current_state_events_member_index",
-            table="current_state_events",
-            columns=["state_key"],
-            where_clause="type='m.room.member'",
-        )
-        self.register_background_index_update(
-            self.EVENT_STATE_GROUP_INDEX_UPDATE_NAME,
-            index_name="event_to_state_groups_sg_index",
-            table="event_to_state_groups",
-            columns=["state_group"],
-        )
-
-    @defer.inlineCallbacks
-    def _background_deduplicate_state(self, progress, batch_size):
-        """This background update will slowly deduplicate state by reencoding
-        them as deltas.
-        """
-        last_state_group = progress.get("last_state_group", 0)
-        rows_inserted = progress.get("rows_inserted", 0)
-        max_group = progress.get("max_group", None)
-
-        BATCH_SIZE_SCALE_FACTOR = 100
-
-        batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR))
-
-        if max_group is None:
-            rows = yield self._execute(
-                "_background_deduplicate_state",
-                None,
-                "SELECT coalesce(max(id), 0) FROM state_groups",
-            )
-            max_group = rows[0][0]
-
-        def reindex_txn(txn):
-            new_last_state_group = last_state_group
-            for count in range(batch_size):
-                txn.execute(
-                    "SELECT id, room_id FROM state_groups"
-                    " WHERE ? < id AND id <= ?"
-                    " ORDER BY id ASC"
-                    " LIMIT 1",
-                    (new_last_state_group, max_group),
-                )
-                row = txn.fetchone()
-                if row:
-                    state_group, room_id = row
-
-                if not row or not state_group:
-                    return True, count
-
-                txn.execute(
-                    "SELECT state_group FROM state_group_edges"
-                    " WHERE state_group = ?",
-                    (state_group,),
-                )
-
-                # If we reach a point where we've already started inserting
-                # edges we should stop.
-                if txn.fetchall():
-                    return True, count
-
-                txn.execute(
-                    "SELECT coalesce(max(id), 0) FROM state_groups"
-                    " WHERE id < ? AND room_id = ?",
-                    (state_group, room_id),
-                )
-                prev_group, = txn.fetchone()
-                new_last_state_group = state_group
-
-                if prev_group:
-                    potential_hops = self._count_state_group_hops_txn(txn, prev_group)
-                    if potential_hops >= MAX_STATE_DELTA_HOPS:
-                        # We want to ensure chains are at most this long,#
-                        # otherwise read performance degrades.
-                        continue
-
-                    prev_state = self._get_state_groups_from_groups_txn(
-                        txn, [prev_group]
-                    )
-                    prev_state = prev_state[prev_group]
-
-                    curr_state = self._get_state_groups_from_groups_txn(
-                        txn, [state_group]
-                    )
-                    curr_state = curr_state[state_group]
-
-                    if not set(prev_state.keys()) - set(curr_state.keys()):
-                        # We can only do a delta if the current has a strict super set
-                        # of keys
-
-                        delta_state = {
-                            key: value
-                            for key, value in iteritems(curr_state)
-                            if prev_state.get(key, None) != value
-                        }
-
-                        self._simple_delete_txn(
-                            txn,
-                            table="state_group_edges",
-                            keyvalues={"state_group": state_group},
-                        )
-
-                        self._simple_insert_txn(
-                            txn,
-                            table="state_group_edges",
-                            values={
-                                "state_group": state_group,
-                                "prev_state_group": prev_group,
-                            },
-                        )
-
-                        self._simple_delete_txn(
-                            txn,
-                            table="state_groups_state",
-                            keyvalues={"state_group": state_group},
-                        )
-
-                        self._simple_insert_many_txn(
-                            txn,
-                            table="state_groups_state",
-                            values=[
-                                {
-                                    "state_group": state_group,
-                                    "room_id": room_id,
-                                    "type": key[0],
-                                    "state_key": key[1],
-                                    "event_id": state_id,
-                                }
-                                for key, state_id in iteritems(delta_state)
-                            ],
-                        )
-
-            progress = {
-                "last_state_group": state_group,
-                "rows_inserted": rows_inserted + batch_size,
-                "max_group": max_group,
-            }
-
-            self._background_update_progress_txn(
-                txn, self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, progress
-            )
-
-            return False, batch_size
-
-        finished, result = yield self.runInteraction(
-            self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, reindex_txn
-        )
-
-        if finished:
-            yield self._end_background_update(
-                self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME
-            )
-
-        return result * BATCH_SIZE_SCALE_FACTOR
-
-    @defer.inlineCallbacks
-    def _background_index_state(self, progress, batch_size):
-        def reindex_txn(conn):
-            conn.rollback()
-            if isinstance(self.database_engine, PostgresEngine):
-                # postgres insists on autocommit for the index
-                conn.set_session(autocommit=True)
-                try:
-                    txn = conn.cursor()
-                    txn.execute(
-                        "CREATE INDEX CONCURRENTLY state_groups_state_type_idx"
-                        " ON state_groups_state(state_group, type, state_key)"
-                    )
-                    txn.execute("DROP INDEX IF EXISTS state_groups_state_id")
-                finally:
-                    conn.set_session(autocommit=False)
-            else:
-                txn = conn.cursor()
-                txn.execute(
-                    "CREATE INDEX state_groups_state_type_idx"
-                    " ON state_groups_state(state_group, type, state_key)"
-                )
-                txn.execute("DROP INDEX IF EXISTS state_groups_state_id")
-
-        yield self.runWithConnection(reindex_txn)
-
-        yield self._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME)
-
-        return 1
-
-
-class StateStore(StateGroupWorkerStore, StateBackgroundUpdateStore):
-    """ Keeps track of the state at a given event.
-
-    This is done by the concept of `state groups`. Every event is a assigned
-    a state group (identified by an arbitrary string), which references a
-    collection of state events. The current state of an event is then the
-    collection of state events referenced by the event's state group.
-
-    Hence, every change in the current state causes a new state group to be
-    generated. However, if no change happens (e.g., if we get a message event
-    with only one parent it inherits the state group from its parent.)
-
-    There are three tables:
-      * `state_groups`: Stores group name, first event with in the group and
-        room id.
-      * `event_to_state_groups`: Maps events to state groups.
-      * `state_groups_state`: Maps state group to state events.
-    """
-
-    def __init__(self, db_conn, hs):
-        super(StateStore, self).__init__(db_conn, hs)
-
-    def _store_event_state_mappings_txn(self, txn, events_and_contexts):
-        state_groups = {}
-        for event, context in events_and_contexts:
-            if event.internal_metadata.is_outlier():
-                continue
-
-            # if the event was rejected, just give it the same state as its
-            # predecessor.
-            if context.rejected:
-                state_groups[event.event_id] = context.prev_group
-                continue
-
-            state_groups[event.event_id] = context.state_group
-
-        self._simple_insert_many_txn(
-            txn,
-            table="event_to_state_groups",
-            values=[
-                {"state_group": state_group_id, "event_id": event_id}
-                for event_id, state_group_id in iteritems(state_groups)
-            ],
-        )
-
-        for event_id, state_group_id in iteritems(state_groups):
-            txn.call_after(
-                self._get_state_group_for_event.prefill, (event_id,), state_group_id
-            )
diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py
index 7569b6fab5..d5c8bd7612 100644
--- a/tests/handlers/test_stats.py
+++ b/tests/handlers/test_stats.py
@@ -13,9 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse import storage
 from synapse.rest import admin
 from synapse.rest.client.v1 import login, room
+from synapse.storage.data_stores.main import stats
 
 from tests import unittest
 
@@ -87,10 +87,10 @@ class StatsRoomTests(unittest.HomeserverTestCase):
         )
 
     def _get_current_stats(self, stats_type, stat_id):
-        table, id_col = storage.stats.TYPE_TO_TABLE[stats_type]
+        table, id_col = stats.TYPE_TO_TABLE[stats_type]
 
-        cols = list(storage.stats.ABSOLUTE_STATS_FIELDS[stats_type]) + list(
-            storage.stats.PER_SLICE_FIELDS[stats_type]
+        cols = list(stats.ABSOLUTE_STATS_FIELDS[stats_type]) + list(
+            stats.PER_SLICE_FIELDS[stats_type]
         )
 
         end_ts = self.store.quantise_stats_time(self.reactor.seconds() * 1000)
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 622b16a071..dfeea24599 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -24,7 +24,7 @@ from twisted.internet import defer
 
 from synapse.appservice import ApplicationService, ApplicationServiceState
 from synapse.config._base import ConfigError
-from synapse.storage.appservice import (
+from synapse.storage.data_stores.main.appservice import (
     ApplicationServiceStore,
     ApplicationServiceTransactionStore,
 )
diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py
index 34f9c72709..69dcaa63d5 100644
--- a/tests/storage/test_cleanup_extrems.py
+++ b/tests/storage/test_cleanup_extrems.py
@@ -50,6 +50,8 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
 
         schema_path = os.path.join(
             prepare_database.dir_path,
+            "data_stores",
+            "main",
             "schema",
             "delta",
             "54",
diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py
index 45824bd3b2..24c7fe16c3 100644
--- a/tests/storage/test_profile.py
+++ b/tests/storage/test_profile.py
@@ -16,7 +16,7 @@
 
 from twisted.internet import defer
 
-from synapse.storage.profile import ProfileStore
+from synapse.storage.data_stores.main.profile import ProfileStore
 from synapse.types import UserID
 
 from tests import unittest
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index d7d244ce97..7eea57c0e2 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -15,7 +15,7 @@
 
 from twisted.internet import defer
 
-from synapse.storage import UserDirectoryStore
+from synapse.storage.data_stores.main.user_directory import UserDirectoryStore
 
 from tests import unittest
 from tests.utils import setup_test_homeserver

From ffd24545bbde327179a76bb1d6ed02b70bd2b93b Mon Sep 17 00:00:00 2001
From: Erik Johnston 
Date: Mon, 21 Oct 2019 16:08:40 +0100
Subject: [PATCH 44/55] Fix schema management to work with multiple data
 stores.

---
 synapse/storage/prepare_database.py | 159 ++++++++++++++++++++--------
 1 file changed, 113 insertions(+), 46 deletions(-)

diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index e96eed8a6d..0bb970a296 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -14,12 +14,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import fnmatch
 import imp
 import logging
 import os
 import re
 
+import attr
+
 from synapse.storage.engines.postgres import PostgresEngine
 
 logger = logging.getLogger(__name__)
@@ -54,6 +55,10 @@ def prepare_database(db_conn, database_engine, config):
             application config, or None if we are connecting to an existing
             database which we expect to be configured already
     """
+
+    # For now we only have the one datastore.
+    data_stores = ["main"]
+
     try:
         cur = db_conn.cursor()
         version_info = _get_or_create_schema_state(cur, database_engine)
@@ -68,10 +73,16 @@ def prepare_database(db_conn, database_engine, config):
                     raise UpgradeDatabaseException("Database needs to be upgraded")
             else:
                 _upgrade_existing_database(
-                    cur, user_version, delta_files, upgraded, database_engine, config
+                    cur,
+                    user_version,
+                    delta_files,
+                    upgraded,
+                    database_engine,
+                    config,
+                    data_stores=data_stores,
                 )
         else:
-            _setup_new_database(cur, database_engine)
+            _setup_new_database(cur, database_engine, data_stores=data_stores)
 
         # check if any of our configured dynamic modules want a database
         if config is not None:
@@ -84,7 +95,7 @@ def prepare_database(db_conn, database_engine, config):
         raise
 
 
-def _setup_new_database(cur, database_engine):
+def _setup_new_database(cur, database_engine, data_stores):
     """Sets up the database by finding a base set of "full schemas" and then
     applying any necessary deltas.
 
@@ -115,48 +126,65 @@ def _setup_new_database(cur, database_engine):
     current_dir = os.path.join(dir_path, "schema", "full_schemas")
     directory_entries = os.listdir(current_dir)
 
-    valid_dirs = []
-    pattern = re.compile(r"^\d+(\.sql)?$")
+    # First we find the highest full schema version we have
+    valid_versions = []
+
+    for filename in directory_entries:
+        try:
+            ver = int(filename)
+        except ValueError:
+            continue
+
+        if ver <= SCHEMA_VERSION:
+            valid_versions.append(ver)
+
+    if not valid_versions:
+        raise PrepareDatabaseException(
+            "Could not find a suitable base set of full schemas"
+        )
+
+    max_current_ver = max(valid_versions)
+
+    logger.debug("Initialising schema v%d", max_current_ver)
+
+    # Now lets find all the full schema files, both in the global schema and
+    # in data store schemas.
+    directories = [os.path.join(current_dir, str(max_current_ver))]
+    directories.extend(
+        os.path.join(
+            dir_path,
+            "data_stores",
+            data_store,
+            "schema",
+            "full_schemas",
+            str(max_current_ver),
+        )
+        for data_store in data_stores
+    )
+
+    directory_entries = []
+    for directory in directories:
+        directory_entries.extend(
+            _DirectoryListing(file_name, os.path.join(directory, file_name))
+            for file_name in os.listdir(directory)
+        )
 
     if isinstance(database_engine, PostgresEngine):
         specific = "postgres"
     else:
         specific = "sqlite"
 
-    specific_pattern = re.compile(r"^\d+(\.sql." + specific + r")?$")
-
-    for filename in directory_entries:
-        match = pattern.match(filename) or specific_pattern.match(filename)
-        abs_path = os.path.join(current_dir, filename)
-        if match and os.path.isdir(abs_path):
-            ver = int(match.group(0))
-            if ver <= SCHEMA_VERSION:
-                valid_dirs.append((ver, abs_path))
-        else:
-            logger.debug("Ignoring entry '%s' in 'full_schemas'", filename)
-
-    if not valid_dirs:
-        raise PrepareDatabaseException(
-            "Could not find a suitable base set of full schemas"
-        )
-
-    max_current_ver, sql_dir = max(valid_dirs, key=lambda x: x[0])
-
-    logger.debug("Initialising schema v%d", max_current_ver)
-
-    directory_entries = os.listdir(sql_dir)
-
-    for filename in sorted(
-        fnmatch.filter(directory_entries, "*.sql")
-        + fnmatch.filter(directory_entries, "*.sql." + specific)
-    ):
-        sql_loc = os.path.join(sql_dir, filename)
-        logger.debug("Applying schema %s", sql_loc)
-        executescript(cur, sql_loc)
+    directory_entries.sort()
+    for entry in directory_entries:
+        if entry.file_name.endswith(".sql") or entry.file_name.endswith(
+            ".sql." + specific
+        ):
+            logger.debug("Applying schema %s", entry.absolute_path)
+            executescript(cur, entry.absolute_path)
 
     cur.execute(
         database_engine.convert_param_style(
-            "INSERT INTO schema_version (version, upgraded)" " VALUES (?,?)"
+            "INSERT INTO schema_version (version, upgraded) VALUES (?,?)"
         ),
         (max_current_ver, False),
     )
@@ -168,6 +196,7 @@ def _setup_new_database(cur, database_engine):
         upgraded=False,
         database_engine=database_engine,
         config=None,
+        data_stores=data_stores,
         is_empty=True,
     )
 
@@ -179,6 +208,7 @@ def _upgrade_existing_database(
     upgraded,
     database_engine,
     config,
+    data_stores,
     is_empty=False,
 ):
     """Upgrades an existing database.
@@ -248,24 +278,51 @@ def _upgrade_existing_database(
     for v in range(start_ver, SCHEMA_VERSION + 1):
         logger.info("Upgrading schema to v%d", v)
 
-        delta_dir = os.path.join(dir_path, "schema", "delta", str(v))
+        # We need to search both the global and per data store schema
+        # directories for schema updates.
 
-        try:
-            directory_entries = os.listdir(delta_dir)
-        except OSError:
-            logger.exception("Could not open delta dir for version %d", v)
-            raise UpgradeDatabaseException(
-                "Could not open delta dir for version %d" % (v,)
+        # First we find the directories to search in
+        delta_dir = os.path.join(dir_path, "schema", "delta", str(v))
+        directories = [delta_dir]
+        for data_store in data_stores:
+            directories.append(
+                os.path.join(
+                    dir_path, "data_stores", data_store, "schema", "delta", str(v)
+                )
             )
 
+        # Now find which directories have anything of interest.
+        directory_entries = []
+        for directory in directories:
+            logger.debug("Looking for schema deltas in %s", directory)
+            try:
+                file_names = os.listdir(directory)
+                directory_entries.extend(
+                    _DirectoryListing(file_name, os.path.join(directory, file_name))
+                    for file_name in file_names
+                )
+            except FileNotFoundError:
+                # Data stores can have empty entries for a given version delta.
+                pass
+            except OSError:
+                logger.exception("Could not open delta dir for version %d", v)
+                raise UpgradeDatabaseException(
+                    "Could not open delta dir for version %d" % (v,)
+                )
+
+        if not directory_entries:
+            continue
+
         directory_entries.sort()
-        for file_name in directory_entries:
+        for entry in directory_entries:
+            file_name = entry.file_name
             relative_path = os.path.join(str(v), file_name)
+            absolute_path = entry.absolute_path
+
             logger.debug("Found file: %s", relative_path)
             if relative_path in applied_delta_files:
                 continue
 
-            absolute_path = os.path.join(dir_path, "schema", "delta", relative_path)
             root_name, ext = os.path.splitext(file_name)
             if ext == ".py":
                 # This is a python upgrade module. We need to import into some
@@ -448,3 +505,13 @@ def _get_or_create_schema_state(txn, database_engine):
         return current_version, applied_deltas, upgraded
 
     return None
+
+
+@attr.s()
+class _DirectoryListing(object):
+    """Helper class to store schema file name and the
+    absolute path to it.
+    """
+
+    file_name = attr.ib()
+    absolute_path = attr.ib()

From 3c304aaaebfd95f6cc17b1f0677183df9fe6b735 Mon Sep 17 00:00:00 2001
From: Erik Johnston 
Date: Mon, 21 Oct 2019 16:10:37 +0100
Subject: [PATCH 45/55] Newsfile

---
 changelog.d/6231.misc | 1 +
 1 file changed, 1 insertion(+)
 create mode 100644 changelog.d/6231.misc

diff --git a/changelog.d/6231.misc b/changelog.d/6231.misc
new file mode 100644
index 0000000000..89b8297794
--- /dev/null
+++ b/changelog.d/6231.misc
@@ -0,0 +1 @@
+Refactor storage layer in preparation to support having multiple databases.

From 4b5163d521eb36f5e18607e9427a7805c34cf89c Mon Sep 17 00:00:00 2001
From: Erik Johnston 
Date: Mon, 21 Oct 2019 16:13:16 +0100
Subject: [PATCH 46/55] Fix packaging

---
 MANIFEST.in | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/MANIFEST.in b/MANIFEST.in
index b22be58f3d..2b8244f9c5 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -8,11 +8,11 @@ include demo/demo.tls.dh
 include demo/*.py
 include demo/*.sh
 
-recursive-include synapse/storage/schema *.sql
-recursive-include synapse/storage/schema *.sql.postgres
-recursive-include synapse/storage/schema *.sql.sqlite
-recursive-include synapse/storage/schema *.py
-recursive-include synapse/storage/schema *.txt
+recursive-include synapse/storage *.sql
+recursive-include synapse/storage *.sql.postgres
+recursive-include synapse/storage *.sql.sqlite
+recursive-include synapse/storage *.py
+recursive-include synapse/storage *.txt
 
 recursive-include docs *
 recursive-include scripts *

From 336eeea3ffd14dbd879459cde50f1b7f32e9a325 Mon Sep 17 00:00:00 2001
From: Erik Johnston 
Date: Tue, 22 Oct 2019 11:02:01 +0100
Subject: [PATCH 47/55] Fix postgres unit tests to use prepare_database

---
 tests/utils.py | 12 ++----------
 1 file changed, 2 insertions(+), 10 deletions(-)

diff --git a/tests/utils.py b/tests/utils.py
index 46ef2959f2..0a64f75d04 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -38,11 +38,7 @@ from synapse.logging.context import LoggingContext
 from synapse.server import HomeServer
 from synapse.storage import DataStore
 from synapse.storage.engines import PostgresEngine, create_engine
-from synapse.storage.prepare_database import (
-    _get_or_create_schema_state,
-    _setup_new_database,
-    prepare_database,
-)
+from synapse.storage.prepare_database import prepare_database
 from synapse.util.ratelimitutils import FederationRateLimiter
 
 # set this to True to run the tests against postgres instead of sqlite.
@@ -88,11 +84,7 @@ def setupdb():
             host=POSTGRES_HOST,
             password=POSTGRES_PASSWORD,
         )
-        cur = db_conn.cursor()
-        _get_or_create_schema_state(cur, db_engine)
-        _setup_new_database(cur, db_engine)
-        db_conn.commit()
-        cur.close()
+        prepare_database(db_conn, db_engine, None)
         db_conn.close()
 
         def _cleanup():

From acf47c76989de39baaa55acc6cc86871b5048601 Mon Sep 17 00:00:00 2001
From: Erik Johnston 
Date: Tue, 22 Oct 2019 11:49:00 +0100
Subject: [PATCH 48/55] Add a basic README to synapse.storage

---
 MANIFEST.in               |  1 +
 synapse/storage/README.md | 13 +++++++++++++
 2 files changed, 14 insertions(+)
 create mode 100644 synapse/storage/README.md

diff --git a/MANIFEST.in b/MANIFEST.in
index 2b8244f9c5..156d6f04f7 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -13,6 +13,7 @@ recursive-include synapse/storage *.sql.postgres
 recursive-include synapse/storage *.sql.sqlite
 recursive-include synapse/storage *.py
 recursive-include synapse/storage *.txt
+recursive-include synapse/storage *.md
 
 recursive-include docs *
 recursive-include scripts *
diff --git a/synapse/storage/README.md b/synapse/storage/README.md
new file mode 100644
index 0000000000..567ae785a7
--- /dev/null
+++ b/synapse/storage/README.md
@@ -0,0 +1,13 @@
+Storage Layer
+=============
+
+The storage layer is split up into multiple parts to allow Synapse to run
+against different configurations of databases (e.g. single or multiple
+databases). The `data_stores` are classes that talk directly to a single
+database and have associated schemas, background updates, etc. On top of those
+there are (or will be) classes that provide high level interfaces that combine
+calls to multiple `data_stores`.
+
+There are also schemas that get applied to every database, regardless of the
+data stores associated with them (e.g. the schema version tables), which are
+stored in `synapse.storage.schema`.

From 0327a00a3716d8f96ab1353613eca3a0eb813c65 Mon Sep 17 00:00:00 2001
From: Adrien Luxey 
Date: Tue, 22 Oct 2019 13:48:02 +0200
Subject: [PATCH 49/55] Update postgres.md (#6234)

Added database owner authentication with `sudo` when `su` does not work
---
 docs/postgres.md | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/docs/postgres.md b/docs/postgres.md
index 29cf762858..7cb1ad18d4 100644
--- a/docs/postgres.md
+++ b/docs/postgres.md
@@ -27,17 +27,21 @@ connect to a postgres database.
 
 ## Set up database
 
-Assuming your PostgreSQL database user is called `postgres`, create a
-user `synapse_user` with:
+Assuming your PostgreSQL database user is called `postgres`, first authenticate as the database user with:
 
     su - postgres
+    # Or, if your system uses sudo to get administrative rights
+    sudo -u postgres bash
+  
+Then, create a user ``synapse_user`` with:
+
     createuser --pwprompt synapse_user
 
 Before you can authenticate with the `synapse_user`, you must create a
 database that it can access. To create a database, first connect to the
 database with your database user:
 
-    su - postgres
+    su - postgres # Or: sudo -u postgres bash
     psql
 
 and then run:

From b2945d26727520378f1f80bf96eed24638c360cf Mon Sep 17 00:00:00 2001
From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com>
Date: Tue, 22 Oct 2019 13:52:25 +0100
Subject: [PATCH 50/55] Fix demo script on ipv6-supported boxes (#6229)

The synapse demo was a bit flakey in terms of supporting federation. It turns out that if your computer resolved `localhost` to `::1` instead of `127.0.0.1`, the built-in federation blacklist specified in `start.sh` would still block it, since it contained an entry for `::/127`. Removing this no longer prevents Synapse from contacting `::1`, federation works again on these boxes.
---
 changelog.d/6229.bugfix | 1 +
 demo/start.sh           | 3 +--
 2 files changed, 2 insertions(+), 2 deletions(-)
 create mode 100644 changelog.d/6229.bugfix

diff --git a/changelog.d/6229.bugfix b/changelog.d/6229.bugfix
new file mode 100644
index 0000000000..bced3304d0
--- /dev/null
+++ b/changelog.d/6229.bugfix
@@ -0,0 +1 @@
+Prevent the demo Synapse's from blacklisting `::1`.
\ No newline at end of file
diff --git a/demo/start.sh b/demo/start.sh
index eccaa2abeb..83396e5c33 100755
--- a/demo/start.sh
+++ b/demo/start.sh
@@ -77,14 +77,13 @@ for port in 8080 8081 8082; do
 
         # Reduce the blacklist
         blacklist=$(cat <<-BLACK
-		# Set the blacklist so that it doesn't include 127.0.0.1
+		# Set the blacklist so that it doesn't include 127.0.0.1, ::1
 		federation_ip_range_blacklist:
 		  - '10.0.0.0/8'
 		  - '172.16.0.0/12'
 		  - '192.168.0.0/16'
 		  - '100.64.0.0/10'
 		  - '169.254.0.0/16'
-		  - '::1/128'
 		  - 'fe80::/64'
 		  - 'fc00::/7'
 		BLACK

From 1bbc5444a88e7069817433a518df50ac1a4c1811 Mon Sep 17 00:00:00 2001
From: Erik Johnston 
Date: Tue, 22 Oct 2019 17:59:31 +0100
Subject: [PATCH 51/55] Move README into synapse/storage/__init__.py

---
 synapse/storage/README.md   | 13 -------------
 synapse/storage/__init__.py | 12 ++++++++++++
 2 files changed, 12 insertions(+), 13 deletions(-)
 delete mode 100644 synapse/storage/README.md

diff --git a/synapse/storage/README.md b/synapse/storage/README.md
deleted file mode 100644
index 567ae785a7..0000000000
--- a/synapse/storage/README.md
+++ /dev/null
@@ -1,13 +0,0 @@
-Storage Layer
-=============
-
-The storage layer is split up into multiple parts to allow Synapse to run
-against different configurations of databases (e.g. single or multiple
-databases). The `data_stores` are classes that talk directly to a single
-database and have associated schemas, background updates, etc. On top of those
-there are (or will be) classes that provide high level interfaces that combine
-calls to multiple `data_stores`.
-
-There are also schemas that get applied to every database, regardless of the
-data stores associated with them (e.g. the schema version tables), which are
-stored in `synapse.storage.schema`.
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 2f29c4a112..a249ecd219 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -14,6 +14,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+"""
+The storage layer is split up into multiple parts to allow Synapse to run
+against different configurations of databases (e.g. single or multiple
+databases). The `data_stores` are classes that talk directly to a single
+database and have associated schemas, background updates, etc. On top of those
+there are (or will be) classes that provide high level interfaces that combine
+calls to multiple `data_stores`.
+
+There are also schemas that get applied to every database, regardless of the
+data stores associated with them (e.g. the schema version tables), which are
+stored in `synapse.storage.schema`.
+"""
 
 from synapse.storage.data_stores.main import DataStore  # noqa: F401
 

From 6cc497f99b804cfef91ba559247db476054c0d35 Mon Sep 17 00:00:00 2001
From: Erik Johnston 
Date: Tue, 22 Oct 2019 18:02:50 +0100
Subject: [PATCH 52/55] Delete background_update table creation in main
 data_store

---
 .../schema/delta/25/00background_updates.sql  | 21 -------------------
 1 file changed, 21 deletions(-)
 delete mode 100644 synapse/storage/data_stores/main/schema/delta/25/00background_updates.sql

diff --git a/synapse/storage/data_stores/main/schema/delta/25/00background_updates.sql b/synapse/storage/data_stores/main/schema/delta/25/00background_updates.sql
deleted file mode 100644
index 2ad9e8fa56..0000000000
--- a/synapse/storage/data_stores/main/schema/delta/25/00background_updates.sql
+++ /dev/null
@@ -1,21 +0,0 @@
-/* Copyright 2015, 2016 OpenMarket Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-
-CREATE TABLE IF NOT EXISTS background_updates(
-    update_name TEXT NOT NULL, -- The name of the background update.
-    progress_json TEXT NOT NULL, -- The current progress of the update as JSON.
-    CONSTRAINT background_updates_uniqueness UNIQUE (update_name)
-);

From 23d62eded235d26dd582f6a6c5aa48aa16c6fba6 Mon Sep 17 00:00:00 2001
From: Erik Johnston 
Date: Tue, 22 Oct 2019 18:43:31 +0100
Subject: [PATCH 53/55] Clean up prepare_database.py a bit and add comments

---
 synapse/storage/prepare_database.py | 37 +++++++++++++++++++++++------
 1 file changed, 30 insertions(+), 7 deletions(-)

diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 0bb970a296..2e7753820e 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -97,7 +97,8 @@ def prepare_database(db_conn, database_engine, config):
 
 def _setup_new_database(cur, database_engine, data_stores):
     """Sets up the database by finding a base set of "full schemas" and then
-    applying any necessary deltas.
+    applying any necessary deltas, including schemas from the given data
+    stores.
 
     The "full_schemas" directory has subdirectories named after versions. This
     function searches for the highest version less than or equal to
@@ -122,6 +123,15 @@ def _setup_new_database(cur, database_engine, data_stores):
 
     In the example foo.sql and bar.sql would be run, and then any delta files
     for versions strictly greater than 11.
+
+    Note: we apply the full schemas and deltas from the top level `schema/`
+    folder as well those in the data stores specified.
+
+    Args:
+        cur (Cursor): a database cursor
+        database_engine (DatabaseEngine)
+        data_stores (list[str]): The names of the data stores to instantiate
+            on the given database.
     """
     current_dir = os.path.join(dir_path, "schema", "full_schemas")
     directory_entries = os.listdir(current_dir)
@@ -245,6 +255,10 @@ def _upgrade_existing_database(
     only if `upgraded` is True. Then `foo.sql` and `bar.py` would be run in
     some arbitrary order.
 
+    Note: we apply the delta files from the specified data stores as well as
+    those in the top-level schema. We apply all delta files across data stores
+    for a version before applying those in the next version.
+
     Args:
         cur (Cursor)
         current_version (int): The current version of the schema.
@@ -254,6 +268,14 @@ def _upgrade_existing_database(
             applied deltas or from full schema file. If `True` the function
             will never apply delta files for the given `current_version`, since
             the current_version wasn't generated by applying those delta files.
+        database_engine (DatabaseEngine)
+        config (synapse.config.homeserver.HomeServerConfig|None):
+            application config, or None if we are connecting to an existing
+            database which we expect to be configured already
+        data_stores (list[str]): The names of the data stores to instantiate
+            on the given database.
+        is_empty (bool): Is this a blank database? I.e. do we need to run the
+            upgrade portions of the delta scripts.
     """
 
     if current_version > SCHEMA_VERSION:
@@ -305,21 +327,19 @@ def _upgrade_existing_database(
                 # Data stores can have empty entries for a given version delta.
                 pass
             except OSError:
-                logger.exception("Could not open delta dir for version %d", v)
                 raise UpgradeDatabaseException(
-                    "Could not open delta dir for version %d" % (v,)
+                    "Could not open delta dir for version %d: %s" % (v, directory)
                 )
 
-        if not directory_entries:
-            continue
-
+        # We sort to ensure that we apply the delta files in a consistent
+        # order (to avoid bugs caused by inconsistent directory listing order)
         directory_entries.sort()
         for entry in directory_entries:
             file_name = entry.file_name
             relative_path = os.path.join(str(v), file_name)
             absolute_path = entry.absolute_path
 
-            logger.debug("Found file: %s", relative_path)
+            logger.debug("Found file: %s (%s)", relative_path, absolute_path)
             if relative_path in applied_delta_files:
                 continue
 
@@ -511,6 +531,9 @@ def _get_or_create_schema_state(txn, database_engine):
 class _DirectoryListing(object):
     """Helper class to store schema file name and the
     absolute path to it.
+
+    These entries get sorted, so for consistency we want to ensure that
+    `file_name` attr is kept first.
     """
 
     file_name = attr.ib()

From 409c62b27bca5df1c1f147e85ac1432376054d1c Mon Sep 17 00:00:00 2001
From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com>
Date: Wed, 23 Oct 2019 13:22:54 +0100
Subject: [PATCH 54/55] Add config linting script that checks for bool casing
 (#6203)

Add a linting script that enforces all boolean values in the default config be lowercase.

This has annoyed me for a while so I decided to fix it.
---
 changelog.d/6203.misc            |  1 +
 docs/sample_config.yaml          | 30 +++++++++++++++---------------
 scripts-dev/config-lint.sh       |  9 +++++++++
 scripts-dev/lint.sh              |  1 +
 synapse/config/appservice.py     |  2 +-
 synapse/config/consent_config.py |  4 ++--
 synapse/config/emailconfig.py    |  4 ++--
 synapse/config/metrics.py        |  2 +-
 synapse/config/registration.py   |  2 +-
 synapse/config/saml2_config.py   |  2 +-
 synapse/config/server.py         | 10 +++++-----
 synapse/config/tls.py            |  9 ++++++++-
 synapse/config/voip.py           |  2 +-
 tox.ini                          |  1 +
 14 files changed, 49 insertions(+), 30 deletions(-)
 create mode 100644 changelog.d/6203.misc
 create mode 100755 scripts-dev/config-lint.sh

diff --git a/changelog.d/6203.misc b/changelog.d/6203.misc
new file mode 100644
index 0000000000..c1d8276d45
--- /dev/null
+++ b/changelog.d/6203.misc
@@ -0,0 +1 @@
+Enforce that all boolean configuration values are lowercase in CI.
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 8226978ba6..b4dd146f06 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -86,7 +86,7 @@ pid_file: DATADIR/homeserver.pid
 # Whether room invites to users on this server should be blocked
 # (except those sent by local server admins). The default is False.
 #
-#block_non_admin_invites: True
+#block_non_admin_invites: true
 
 # Room searching
 #
@@ -239,7 +239,7 @@ listeners:
 
 # Global blocking
 #
-#hs_disabled: False
+#hs_disabled: false
 #hs_disabled_message: 'Human readable reason for why the HS is blocked'
 #hs_disabled_limit_type: 'error code(str), to help clients decode reason'
 
@@ -261,7 +261,7 @@ listeners:
 # sign up in a short space of time never to return after their initial
 # session.
 #
-#limit_usage_by_mau: False
+#limit_usage_by_mau: false
 #max_mau_value: 50
 #mau_trial_days: 2
 
@@ -269,7 +269,7 @@ listeners:
 # be populated, however no one will be limited. If limit_usage_by_mau
 # is true, this is implied to be true.
 #
-#mau_stats_only: False
+#mau_stats_only: false
 
 # Sometimes the server admin will want to ensure certain accounts are
 # never blocked by mau checking. These accounts are specified here.
@@ -294,7 +294,7 @@ listeners:
 #
 # Uncomment the below lines to enable:
 #limit_remote_rooms:
-#  enabled: True
+#  enabled: true
 #  complexity: 1.0
 #  complexity_error: "This room is too complex."
 
@@ -411,7 +411,7 @@ acme:
     # ACME support is disabled by default. Set this to `true` and uncomment
     # tls_certificate_path and tls_private_key_path above to enable it.
     #
-    enabled: False
+    enabled: false
 
     # Endpoint to use to request certificates. If you only want to test,
     # use Let's Encrypt's staging url:
@@ -786,7 +786,7 @@ uploads_path: "DATADIR/uploads"
 # connect to arbitrary endpoints without having first signed up for a
 # valid account (e.g. by passing a CAPTCHA).
 #
-#turn_allow_guests: True
+#turn_allow_guests: true
 
 
 ## Registration ##
@@ -829,7 +829,7 @@ uploads_path: "DATADIR/uploads"
 # where d is equal to 10% of the validity period.
 #
 #account_validity:
-#  enabled: True
+#  enabled: true
 #  period: 6w
 #  renew_at: 1w
 #  renew_email_subject: "Renew your %(app)s account"
@@ -971,7 +971,7 @@ account_threepid_delegates:
 
 # Enable collection and rendering of performance metrics
 #
-#enable_metrics: False
+#enable_metrics: false
 
 # Enable sentry integration
 # NOTE: While attempts are made to ensure that the logs don't contain
@@ -1023,7 +1023,7 @@ metrics_flags:
 # Uncomment to enable tracking of application service IP addresses. Implicitly
 # enables MAU tracking for application service users.
 #
-#track_appservice_user_ips: True
+#track_appservice_user_ips: true
 
 
 # a secret which is used to sign access tokens. If none is specified,
@@ -1149,7 +1149,7 @@ saml2_config:
   #      - url: https://our_idp/metadata.xml
   #
   #    # By default, the user has to go to our login page first. If you'd like
-  #    # to allow IdP-initiated login, set 'allow_unsolicited: True' in a
+  #    # to allow IdP-initiated login, set 'allow_unsolicited: true' in a
   #    # 'service.sp' section:
   #    #
   #    #service:
@@ -1263,13 +1263,13 @@ password_config:
 #   smtp_port: 25 # SSL: 465, STARTTLS: 587
 #   smtp_user: "exampleusername"
 #   smtp_pass: "examplepassword"
-#   require_transport_security: False
+#   require_transport_security: false
 #   notif_from: "Your Friendly %(app)s Home Server "
 #   app_name: Matrix
 #
 #   # Enable email notifications by default
 #   #
-#   notif_for_new_users: True
+#   notif_for_new_users: true
 #
 #   # Defining a custom URL for Riot is only needed if email notifications
 #   # should contain links to a self-hosted installation of Riot; when set
@@ -1447,11 +1447,11 @@ password_config:
 #    body: >-
 #      To continue using this homeserver you must review and agree to the
 #      terms and conditions at %(consent_uri)s
-#  send_server_notice_to_guests: True
+#  send_server_notice_to_guests: true
 #  block_events_error: >-
 #    To continue using this homeserver you must review and agree to the
 #    terms and conditions at %(consent_uri)s
-#  require_at_registration: False
+#  require_at_registration: false
 #  policy_name: Privacy Policy
 #
 
diff --git a/scripts-dev/config-lint.sh b/scripts-dev/config-lint.sh
new file mode 100755
index 0000000000..677a854c85
--- /dev/null
+++ b/scripts-dev/config-lint.sh
@@ -0,0 +1,9 @@
+#!/bin/bash
+# Find linting errors in Synapse's default config file.
+# Exits with 0 if there are no problems, or another code otherwise.
+
+# Fix non-lowercase true/false values
+sed -i -E "s/: +True/: true/g; s/: +False/: false/g;" docs/sample_config.yaml
+
+# Check if anything changed
+git diff --exit-code docs/sample_config.yaml
diff --git a/scripts-dev/lint.sh b/scripts-dev/lint.sh
index ebb4d69f86..02a2ca39e5 100755
--- a/scripts-dev/lint.sh
+++ b/scripts-dev/lint.sh
@@ -10,3 +10,4 @@ set -e
 isort -y -rc synapse tests scripts-dev scripts
 flake8 synapse tests
 python3 -m black synapse tests scripts-dev scripts
+./scripts-dev/config-lint.sh
diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py
index 9b4682222d..e77d3387ff 100644
--- a/synapse/config/appservice.py
+++ b/synapse/config/appservice.py
@@ -48,7 +48,7 @@ class AppServiceConfig(Config):
         # Uncomment to enable tracking of application service IP addresses. Implicitly
         # enables MAU tracking for application service users.
         #
-        #track_appservice_user_ips: True
+        #track_appservice_user_ips: true
         """
 
 
diff --git a/synapse/config/consent_config.py b/synapse/config/consent_config.py
index 62c4c44d60..aec9c4bbce 100644
--- a/synapse/config/consent_config.py
+++ b/synapse/config/consent_config.py
@@ -62,11 +62,11 @@ DEFAULT_CONFIG = """\
 #    body: >-
 #      To continue using this homeserver you must review and agree to the
 #      terms and conditions at %(consent_uri)s
-#  send_server_notice_to_guests: True
+#  send_server_notice_to_guests: true
 #  block_events_error: >-
 #    To continue using this homeserver you must review and agree to the
 #    terms and conditions at %(consent_uri)s
-#  require_at_registration: False
+#  require_at_registration: false
 #  policy_name: Privacy Policy
 #
 """
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index 658897a77e..39e7a1dddb 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -304,13 +304,13 @@ class EmailConfig(Config):
         #   smtp_port: 25 # SSL: 465, STARTTLS: 587
         #   smtp_user: "exampleusername"
         #   smtp_pass: "examplepassword"
-        #   require_transport_security: False
+        #   require_transport_security: false
         #   notif_from: "Your Friendly %(app)s Home Server "
         #   app_name: Matrix
         #
         #   # Enable email notifications by default
         #   #
-        #   notif_for_new_users: True
+        #   notif_for_new_users: true
         #
         #   # Defining a custom URL for Riot is only needed if email notifications
         #   # should contain links to a self-hosted installation of Riot; when set
diff --git a/synapse/config/metrics.py b/synapse/config/metrics.py
index 282a43bddb..22538153e1 100644
--- a/synapse/config/metrics.py
+++ b/synapse/config/metrics.py
@@ -70,7 +70,7 @@ class MetricsConfig(Config):
 
         # Enable collection and rendering of performance metrics
         #
-        #enable_metrics: False
+        #enable_metrics: false
 
         # Enable sentry integration
         # NOTE: While attempts are made to ensure that the logs don't contain
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index b3e3e6dda2..ab41623b2b 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -180,7 +180,7 @@ class RegistrationConfig(Config):
         # where d is equal to 10%% of the validity period.
         #
         #account_validity:
-        #  enabled: True
+        #  enabled: true
         #  period: 6w
         #  renew_at: 1w
         #  renew_email_subject: "Renew your %%(app)s account"
diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py
index c407e13680..c5ea2d43a1 100644
--- a/synapse/config/saml2_config.py
+++ b/synapse/config/saml2_config.py
@@ -176,7 +176,7 @@ class SAML2Config(Config):
           #      - url: https://our_idp/metadata.xml
           #
           #    # By default, the user has to go to our login page first. If you'd like
-          #    # to allow IdP-initiated login, set 'allow_unsolicited: True' in a
+          #    # to allow IdP-initiated login, set 'allow_unsolicited: true' in a
           #    # 'service.sp' section:
           #    #
           #    #service:
diff --git a/synapse/config/server.py b/synapse/config/server.py
index afc4d6a4ab..c942841578 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -532,7 +532,7 @@ class ServerConfig(Config):
         # Whether room invites to users on this server should be blocked
         # (except those sent by local server admins). The default is False.
         #
-        #block_non_admin_invites: True
+        #block_non_admin_invites: true
 
         # Room searching
         #
@@ -673,7 +673,7 @@ class ServerConfig(Config):
 
         # Global blocking
         #
-        #hs_disabled: False
+        #hs_disabled: false
         #hs_disabled_message: 'Human readable reason for why the HS is blocked'
         #hs_disabled_limit_type: 'error code(str), to help clients decode reason'
 
@@ -695,7 +695,7 @@ class ServerConfig(Config):
         # sign up in a short space of time never to return after their initial
         # session.
         #
-        #limit_usage_by_mau: False
+        #limit_usage_by_mau: false
         #max_mau_value: 50
         #mau_trial_days: 2
 
@@ -703,7 +703,7 @@ class ServerConfig(Config):
         # be populated, however no one will be limited. If limit_usage_by_mau
         # is true, this is implied to be true.
         #
-        #mau_stats_only: False
+        #mau_stats_only: false
 
         # Sometimes the server admin will want to ensure certain accounts are
         # never blocked by mau checking. These accounts are specified here.
@@ -728,7 +728,7 @@ class ServerConfig(Config):
         #
         # Uncomment the below lines to enable:
         #limit_remote_rooms:
-        #  enabled: True
+        #  enabled: true
         #  complexity: 1.0
         #  complexity_error: "This room is too complex."
 
diff --git a/synapse/config/tls.py b/synapse/config/tls.py
index f06341eb67..2e9e478a2a 100644
--- a/synapse/config/tls.py
+++ b/synapse/config/tls.py
@@ -289,6 +289,9 @@ class TlsConfig(Config):
             "http://localhost:8009/.well-known/acme-challenge"
         )
 
+        # flake8 doesn't recognise that variables are used in the below string
+        _ = tls_enabled, proxypassline, acme_enabled, default_acme_account_file
+
         return (
             """\
         ## TLS ##
@@ -451,7 +454,11 @@ class TlsConfig(Config):
         #tls_fingerprints: [{"sha256": ""}]
 
         """
-            % locals()
+            # Lowercase the string representation of boolean values
+            % {
+                x[0]: str(x[1]).lower() if isinstance(x[1], bool) else x[1]
+                for x in locals().items()
+            }
         )
 
     def read_tls_certificate(self):
diff --git a/synapse/config/voip.py b/synapse/config/voip.py
index a68a3068aa..b313bff140 100644
--- a/synapse/config/voip.py
+++ b/synapse/config/voip.py
@@ -56,5 +56,5 @@ class VoipConfig(Config):
         # connect to arbitrary endpoints without having first signed up for a
         # valid account (e.g. by passing a CAPTCHA).
         #
-        #turn_allow_guests: True
+        #turn_allow_guests: true
         """
diff --git a/tox.ini b/tox.ini
index 7ba6f6339f..3cd2c5e633 100644
--- a/tox.ini
+++ b/tox.ini
@@ -118,6 +118,7 @@ deps =
 commands =
     python -m black --check --diff .
     /bin/sh -c "flake8 synapse tests scripts scripts-dev scripts/hash_password scripts/register_new_matrix_user scripts/synapse_port_db synctl {env:PEP8SUFFIX:}"
+    {toxinidir}/scripts-dev/config-lint.sh
 
 [testenv:check_isort]
 skip_install = True

From c97ed64db3d99680819ec4dcd88ea76f3d0c7537 Mon Sep 17 00:00:00 2001
From: Brendan Abolivier 
Date: Wed, 23 Oct 2019 15:31:59 +0100
Subject: [PATCH 55/55] Make synapse_port_db correctly create indexes (#6102)

Make `synapse_port_db` correctly create indexes in the PostgreSQL database, by having it run the background updates on the database before migrating the data.

To ensure we're migrating the right data, also block the port if the SQLite3 database still has pending or ongoing background updates.

Fixes #4877
---
 changelog.d/6102.bugfix |   1 +
 scripts/synapse_port_db | 186 ++++++++++++++++++++++++++++------------
 2 files changed, 133 insertions(+), 54 deletions(-)
 create mode 100644 changelog.d/6102.bugfix

diff --git a/changelog.d/6102.bugfix b/changelog.d/6102.bugfix
new file mode 100644
index 0000000000..cd288c2a44
--- /dev/null
+++ b/changelog.d/6102.bugfix
@@ -0,0 +1 @@
+Make the `synapse_port_db` script create the right indexes on a new PostgreSQL database.
diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db
index 3f942abdb6..5a34d6f2f5 100755
--- a/scripts/synapse_port_db
+++ b/scripts/synapse_port_db
@@ -2,6 +2,7 @@
 # -*- coding: utf-8 -*-
 # Copyright 2015, 2016 OpenMarket Ltd
 # Copyright 2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -29,9 +30,23 @@ import yaml
 from twisted.enterprise import adbapi
 from twisted.internet import defer, reactor
 
-from synapse.storage._base import LoggingTransaction, SQLBaseStore
+from synapse.config.homeserver import HomeServerConfig
+from synapse.logging.context import PreserveLoggingContext
+from synapse.storage._base import LoggingTransaction
+from synapse.storage.client_ips import ClientIpBackgroundUpdateStore
+from synapse.storage.deviceinbox import DeviceInboxBackgroundUpdateStore
+from synapse.storage.devices import DeviceBackgroundUpdateStore
 from synapse.storage.engines import create_engine
+from synapse.storage.events_bg_updates import EventsBackgroundUpdatesStore
+from synapse.storage.media_repository import MediaRepositoryBackgroundUpdateStore
 from synapse.storage.prepare_database import prepare_database
+from synapse.storage.registration import RegistrationBackgroundUpdateStore
+from synapse.storage.roommember import RoomMemberBackgroundUpdateStore
+from synapse.storage.search import SearchBackgroundUpdateStore
+from synapse.storage.state import StateBackgroundUpdateStore
+from synapse.storage.stats import StatsStore
+from synapse.storage.user_directory import UserDirectoryBackgroundUpdateStore
+from synapse.util import Clock
 
 logger = logging.getLogger("synapse_port_db")
 
@@ -98,33 +113,24 @@ APPEND_ONLY_TABLES = [
 end_error_exec_info = None
 
 
-class Store(object):
-    """This object is used to pull out some of the convenience API from the
-    Storage layer.
-
-    *All* database interactions should go through this object.
-    """
-
-    def __init__(self, db_pool, engine):
-        self.db_pool = db_pool
-        self.database_engine = engine
-
-    _simple_insert_txn = SQLBaseStore.__dict__["_simple_insert_txn"]
-    _simple_insert = SQLBaseStore.__dict__["_simple_insert"]
-
-    _simple_select_onecol_txn = SQLBaseStore.__dict__["_simple_select_onecol_txn"]
-    _simple_select_onecol = SQLBaseStore.__dict__["_simple_select_onecol"]
-    _simple_select_one = SQLBaseStore.__dict__["_simple_select_one"]
-    _simple_select_one_txn = SQLBaseStore.__dict__["_simple_select_one_txn"]
-    _simple_select_one_onecol = SQLBaseStore.__dict__["_simple_select_one_onecol"]
-    _simple_select_one_onecol_txn = SQLBaseStore.__dict__[
-        "_simple_select_one_onecol_txn"
-    ]
-
-    _simple_update_one = SQLBaseStore.__dict__["_simple_update_one"]
-    _simple_update_one_txn = SQLBaseStore.__dict__["_simple_update_one_txn"]
-    _simple_update_txn = SQLBaseStore.__dict__["_simple_update_txn"]
+class Store(
+    ClientIpBackgroundUpdateStore,
+    DeviceInboxBackgroundUpdateStore,
+    DeviceBackgroundUpdateStore,
+    EventsBackgroundUpdatesStore,
+    MediaRepositoryBackgroundUpdateStore,
+    RegistrationBackgroundUpdateStore,
+    RoomMemberBackgroundUpdateStore,
+    SearchBackgroundUpdateStore,
+    StateBackgroundUpdateStore,
+    UserDirectoryBackgroundUpdateStore,
+    StatsStore,
+):
+    def __init__(self, db_conn, hs):
+        super().__init__(db_conn, hs)
+        self.db_pool = hs.get_db_pool()
 
+    @defer.inlineCallbacks
     def runInteraction(self, desc, func, *args, **kwargs):
         def r(conn):
             try:
@@ -150,7 +156,8 @@ class Store(object):
                 logger.debug("[TXN FAIL] {%s} %s", desc, e)
                 raise
 
-        return self.db_pool.runWithConnection(r)
+        with PreserveLoggingContext():
+            return (yield self.db_pool.runWithConnection(r))
 
     def execute(self, f, *args, **kwargs):
         return self.runInteraction(f.__name__, f, *args, **kwargs)
@@ -176,6 +183,25 @@ class Store(object):
             raise
 
 
+class MockHomeserver:
+    def __init__(self, config, database_engine, db_conn, db_pool):
+        self.database_engine = database_engine
+        self.db_conn = db_conn
+        self.db_pool = db_pool
+        self.clock = Clock(reactor)
+        self.config = config
+        self.hostname = config.server_name
+
+    def get_db_conn(self):
+        return self.db_conn
+
+    def get_db_pool(self):
+        return self.db_pool
+
+    def get_clock(self):
+        return self.clock
+
+
 class Porter(object):
     def __init__(self, **kwargs):
         self.__dict__.update(kwargs)
@@ -447,31 +473,75 @@ class Porter(object):
 
         db_conn.commit()
 
+        return db_conn
+
+    @defer.inlineCallbacks
+    def build_db_store(self, config):
+        """Builds and returns a database store using the provided configuration.
+
+        Args:
+            config: The database configuration, i.e. a dict following the structure of
+                the "database" section of Synapse's configuration file.
+
+        Returns:
+            The built Store object.
+        """
+        engine = create_engine(config)
+
+        self.progress.set_state("Preparing %s" % config["name"])
+        conn = self.setup_db(config, engine)
+
+        db_pool = adbapi.ConnectionPool(
+            config["name"], **config["args"]
+        )
+
+        hs = MockHomeserver(self.hs_config, engine, conn, db_pool)
+
+        store = Store(conn, hs)
+
+        yield store.runInteraction(
+            "%s_engine.check_database" % config["name"],
+            engine.check_database,
+        )
+
+        return store
+
+    @defer.inlineCallbacks
+    def run_background_updates_on_postgres(self):
+        # Manually apply all background updates on the PostgreSQL database.
+        postgres_ready = yield self.postgres_store.has_completed_background_updates()
+
+        if not postgres_ready:
+            # Only say that we're running background updates when there are background
+            # updates to run.
+            self.progress.set_state("Running background updates on PostgreSQL")
+
+        while not postgres_ready:
+            yield self.postgres_store.do_next_background_update(100)
+            postgres_ready = yield (
+                self.postgres_store.has_completed_background_updates()
+            )
+
     @defer.inlineCallbacks
     def run(self):
         try:
-            sqlite_db_pool = adbapi.ConnectionPool(
-                self.sqlite_config["name"], **self.sqlite_config["args"]
+            self.sqlite_store = yield self.build_db_store(self.sqlite_config)
+
+            # Check if all background updates are done, abort if not.
+            updates_complete = yield self.sqlite_store.has_completed_background_updates()
+            if not updates_complete:
+                sys.stderr.write(
+                    "Pending background updates exist in the SQLite3 database."
+                    " Please start Synapse again and wait until every update has finished"
+                    " before running this script.\n"
+                )
+                defer.returnValue(None)
+
+            self.postgres_store = yield self.build_db_store(
+                self.hs_config.database_config
             )
 
-            postgres_db_pool = adbapi.ConnectionPool(
-                self.postgres_config["name"], **self.postgres_config["args"]
-            )
-
-            sqlite_engine = create_engine(sqlite_config)
-            postgres_engine = create_engine(postgres_config)
-
-            self.sqlite_store = Store(sqlite_db_pool, sqlite_engine)
-            self.postgres_store = Store(postgres_db_pool, postgres_engine)
-
-            yield self.postgres_store.execute(postgres_engine.check_database)
-
-            # Step 1. Set up databases.
-            self.progress.set_state("Preparing SQLite3")
-            self.setup_db(sqlite_config, sqlite_engine)
-
-            self.progress.set_state("Preparing PostgreSQL")
-            self.setup_db(postgres_config, postgres_engine)
+            yield self.run_background_updates_on_postgres()
 
             self.progress.set_state("Creating port tables")
 
@@ -563,6 +633,8 @@ class Porter(object):
         def conv(j, col):
             if j in bool_cols:
                 return bool(col)
+            if isinstance(col, bytes):
+                return bytearray(col)
             elif isinstance(col, string_types) and "\0" in col:
                 logger.warn(
                     "DROPPING ROW: NUL value in table %s col %s: %r",
@@ -926,18 +998,24 @@ if __name__ == "__main__":
         },
     }
 
-    postgres_config = yaml.safe_load(args.postgres_config)
+    hs_config = yaml.safe_load(args.postgres_config)
 
-    if "database" in postgres_config:
-        postgres_config = postgres_config["database"]
+    if "database" not in hs_config:
+        sys.stderr.write("The configuration file must have a 'database' section.\n")
+        sys.exit(4)
+
+    postgres_config = hs_config["database"]
 
     if "name" not in postgres_config:
-        sys.stderr.write("Malformed database config: no 'name'")
+        sys.stderr.write("Malformed database config: no 'name'\n")
         sys.exit(2)
     if postgres_config["name"] != "psycopg2":
-        sys.stderr.write("Database must use 'psycopg2' connector.")
+        sys.stderr.write("Database must use the 'psycopg2' connector.\n")
         sys.exit(3)
 
+    config = HomeServerConfig()
+    config.parse_config_dict(hs_config, "", "")
+
     def start(stdscr=None):
         if stdscr:
             progress = CursesProgress(stdscr)
@@ -946,9 +1024,9 @@ if __name__ == "__main__":
 
         porter = Porter(
             sqlite_config=sqlite_config,
-            postgres_config=postgres_config,
             progress=progress,
             batch_size=args.batch_size,
+            hs_config=config,
         )
 
         reactor.callWhenRunning(porter.run)