
Compare commits


5 Commits

Author | SHA1 | Message | Date
Richard van der Hoff | d0dc0014c1 | update sample config | 2019-03-05 17:32:34 +00:00
Richard van der Hoff | 054ed1ab5b | tweak changelog | 2019-03-05 17:28:52 +00:00
Matthew Hodgson | b556fcda72 | changelog | 2019-03-05 00:01:17 +00:00
Matthew Hodgson | 11d64ec4a9 | remove trailing space | 2019-03-04 23:50:14 +00:00
Matthew Hodgson | d70c2748af | reword the sample config header to be less scary | 2019-03-04 23:49:54 +00:00
21 changed files with 858 additions and 909 deletions

View File

@@ -1 +0,0 @@
Fix attempting to paginate in rooms where server cannot see any events, to avoid unnecessarily pulling in lots of redacted events.

View File

@@ -1 +0,0 @@
Add support for /keys/query and /keys/changes REST endpoints to client_reader worker.

View File

@@ -1 +0,0 @@
Clean up some replication code.

changelog.d/4801.feature Normal file
View File

@@ -0,0 +1 @@
Include a default configuration file in the 'docs' directory.

View File

@@ -1,7 +1,12 @@
# This file is a reference to the configuration options which can be set in
# homeserver.yaml.
# The config is maintained as an up-to-date snapshot of the default
# homeserver.yaml configuration generated by Synapse.
#
# Note that it is not quite ready to be used as-is. If you are starting from
# scratch, it is easier to generate the config files following the instructions
# in INSTALL.md.
# It is intended to act as a reference for the default configuration,
# helping admins keep track of new options and other changes, and compare
# their configs with the current default. As such, many of the actual
# config values shown are placeholders.
#
# It is *not* intended to be copied and used as the basis for a real
# homeserver.yaml. Instead, if you are starting from scratch, please generate
# a fresh config using Synapse by following the instructions in INSTALL.md.

View File

@@ -1,9 +1,14 @@
# This file is a reference to the configuration options which can be set in
# homeserver.yaml.
# The config is maintained as an up-to-date snapshot of the default
# homeserver.yaml configuration generated by Synapse.
#
# Note that it is not quite ready to be used as-is. If you are starting from
# scratch, it is easier to generate the config files following the instructions
# in INSTALL.md.
# It is intended to act as a reference for the default configuration,
# helping admins keep track of new options and other changes, and compare
# their configs with the current default. As such, many of the actual
# config values shown are placeholders.
#
# It is *not* intended to be copied and used as the basis for a real
# homeserver.yaml. Instead, if you are starting from scratch, please generate
# a fresh config using Synapse by following the instructions in INSTALL.md.
## Server ##
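
Aside: the new header positions this file as a reference against which admins can compare their own homeserver.yaml. A loose sketch of such a comparison (the file paths and the use of PyYAML are assumptions, not part of this change):

# Illustrative sketch only: report top-level options documented in the
# reference sample config but not set in a local homeserver.yaml.
import yaml  # PyYAML

with open("docs/sample_config.yaml") as f:
    reference = yaml.safe_load(f) or {}
with open("homeserver.yaml") as f:
    local = yaml.safe_load(f) or {}

for option in sorted(set(reference) - set(local)):
    print("documented in the sample config but not set locally:", option)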

View File

@@ -225,8 +225,6 @@ following regular expressions::
^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/state$
^/_matrix/client/(api/v1|r0|unstable)/login$
^/_matrix/client/(api/v1|r0|unstable)/account/3pid$
^/_matrix/client/(api/v1|r0|unstable)/keys/query$
^/_matrix/client/(api/v1|r0|unstable)/keys/changes$
Additionally, the following REST endpoints can be handled, but all requests must
be routed to the same instance::
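
Aside: a minimal sketch (not part of this documentation change) of matching a request path against the client_reader patterns listed above, using Python's re module; note that the two /keys/ patterns are removed by this change:

# Illustrative only: decide whether a path can be served by the client_reader
# worker, per the remaining documented patterns.
import re

CLIENT_READER_PATTERNS = [
    r"^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/state$",
    r"^/_matrix/client/(api/v1|r0|unstable)/login$",
    r"^/_matrix/client/(api/v1|r0|unstable)/account/3pid$",
]

def routable_to_client_reader(path):
    return any(re.match(pattern, path) for pattern in CLIENT_READER_PATTERNS)

print(routable_to_client_reader("/_matrix/client/r0/login"))       # True
print(routable_to_client_reader("/_matrix/client/r0/keys/query"))  # False after this change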

View File

@@ -27,4 +27,4 @@ try:
except ImportError:
pass
__version__ = "0.99.2.post1"
__version__ = "0.99.2"

View File

@@ -33,13 +33,9 @@ from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
from synapse.replication.slave.storage.devices import SlavedDeviceStore
from synapse.replication.slave.storage.directory import DirectoryStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.keys import SlavedKeyStore
from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.slave.storage.room import RoomStore
from synapse.replication.slave.storage.transactions import SlavedTransactionStore
@@ -53,7 +49,6 @@ from synapse.rest.client.v1.room import (
RoomStateRestServlet,
)
from synapse.rest.client.v2_alpha.account import ThreepidRestServlet
from synapse.rest.client.v2_alpha.keys import KeyChangesServlet, KeyQueryServlet
from synapse.rest.client.v2_alpha.register import RegisterRestServlet
from synapse.server import HomeServer
from synapse.storage.engines import create_engine
@@ -66,10 +61,6 @@ logger = logging.getLogger("synapse.app.client_reader")
class ClientReaderSlavedStore(
SlavedDeviceInboxStore,
SlavedDeviceStore,
SlavedReceiptsStore,
SlavedPushRuleStore,
SlavedAccountDataStore,
SlavedEventStore,
SlavedKeyStore,
@@ -107,8 +98,6 @@ class ClientReaderServer(HomeServer):
RegisterRestServlet(self).register(resource)
LoginRestServlet(self).register(resource)
ThreepidRestServlet(self).register(resource)
KeyQueryServlet(self).register(resource)
KeyChangesServlet(self).register(resource)
resources.update({
"/_matrix/client/r0": resource,

View File

@@ -37,185 +37,13 @@ from ._base import BaseHandler
logger = logging.getLogger(__name__)
class DeviceWorkerHandler(BaseHandler):
class DeviceHandler(BaseHandler):
def __init__(self, hs):
super(DeviceWorkerHandler, self).__init__(hs)
super(DeviceHandler, self).__init__(hs)
self.hs = hs
self.state = hs.get_state_handler()
self._auth_handler = hs.get_auth_handler()
@defer.inlineCallbacks
def get_devices_by_user(self, user_id):
"""
Retrieve the given user's devices
Args:
user_id (str):
Returns:
defer.Deferred: list[dict[str, X]]: info on each device
"""
device_map = yield self.store.get_devices_by_user(user_id)
ips = yield self.store.get_last_client_ip_by_device(
user_id, device_id=None
)
devices = list(device_map.values())
for device in devices:
_update_device_from_client_ips(device, ips)
defer.returnValue(devices)
@defer.inlineCallbacks
def get_device(self, user_id, device_id):
""" Retrieve the given device
Args:
user_id (str):
device_id (str):
Returns:
defer.Deferred: dict[str, X]: info on the device
Raises:
errors.NotFoundError: if the device was not found
"""
try:
device = yield self.store.get_device(user_id, device_id)
except errors.StoreError:
raise errors.NotFoundError
ips = yield self.store.get_last_client_ip_by_device(
user_id, device_id,
)
_update_device_from_client_ips(device, ips)
defer.returnValue(device)
@measure_func("device.get_user_ids_changed")
@defer.inlineCallbacks
def get_user_ids_changed(self, user_id, from_token):
"""Get list of users that have had the devices updated, or have newly
joined a room, that `user_id` may be interested in.
Args:
user_id (str)
from_token (StreamToken)
"""
now_room_key = yield self.store.get_room_events_max_id()
room_ids = yield self.store.get_rooms_for_user(user_id)
# First we check if any devices have changed
changed = yield self.store.get_user_whose_devices_changed(
from_token.device_list_key
)
# Then work out if any users have since joined
rooms_changed = self.store.get_rooms_that_changed(room_ids, from_token.room_key)
member_events = yield self.store.get_membership_changes_for_user(
user_id, from_token.room_key, now_room_key,
)
rooms_changed.update(event.room_id for event in member_events)
stream_ordering = RoomStreamToken.parse_stream_token(
from_token.room_key
).stream
possibly_changed = set(changed)
possibly_left = set()
for room_id in rooms_changed:
current_state_ids = yield self.store.get_current_state_ids(room_id)
# The user may have left the room
# TODO: Check if they actually did or if we were just invited.
if room_id not in room_ids:
for key, event_id in iteritems(current_state_ids):
etype, state_key = key
if etype != EventTypes.Member:
continue
possibly_left.add(state_key)
continue
# Fetch the current state at the time.
try:
event_ids = yield self.store.get_forward_extremeties_for_room(
room_id, stream_ordering=stream_ordering
)
except errors.StoreError:
# we have purged the stream_ordering index since the stream
# ordering: treat it the same as a new room
event_ids = []
# special-case for an empty prev state: include all members
# in the changed list
if not event_ids:
for key, event_id in iteritems(current_state_ids):
etype, state_key = key
if etype != EventTypes.Member:
continue
possibly_changed.add(state_key)
continue
current_member_id = current_state_ids.get((EventTypes.Member, user_id))
if not current_member_id:
continue
# mapping from event_id -> state_dict
prev_state_ids = yield self.store.get_state_ids_for_events(event_ids)
# Check if we've joined the room? If so we just blindly add all the users to
# the "possibly changed" users.
for state_dict in itervalues(prev_state_ids):
member_event = state_dict.get((EventTypes.Member, user_id), None)
if not member_event or member_event != current_member_id:
for key, event_id in iteritems(current_state_ids):
etype, state_key = key
if etype != EventTypes.Member:
continue
possibly_changed.add(state_key)
break
# If there has been any change in membership, include them in the
# possibly changed list. We'll check if they are joined below,
# and we're not toooo worried about spuriously adding users.
for key, event_id in iteritems(current_state_ids):
etype, state_key = key
if etype != EventTypes.Member:
continue
# check if this member has changed since any of the extremities
# at the stream_ordering, and add them to the list if so.
for state_dict in itervalues(prev_state_ids):
prev_event_id = state_dict.get(key, None)
if not prev_event_id or prev_event_id != event_id:
if state_key != user_id:
possibly_changed.add(state_key)
break
if possibly_changed or possibly_left:
users_who_share_room = yield self.store.get_users_who_share_room_with_user(
user_id
)
# Take the intersection of the users whose devices may have changed
# and those that actually still share a room with the user
possibly_joined = possibly_changed & users_who_share_room
possibly_left = (possibly_changed | possibly_left) - users_who_share_room
else:
possibly_joined = []
possibly_left = []
defer.returnValue({
"changed": list(possibly_joined),
"left": list(possibly_left),
})
class DeviceHandler(DeviceWorkerHandler):
def __init__(self, hs):
super(DeviceHandler, self).__init__(hs)
self.federation_sender = hs.get_federation_sender()
self._edu_updater = DeviceListEduUpdater(hs, self)
@@ -275,6 +103,52 @@ class DeviceHandler(DeviceWorkerHandler):
raise errors.StoreError(500, "Couldn't generate a device ID.")
@defer.inlineCallbacks
def get_devices_by_user(self, user_id):
"""
Retrieve the given user's devices
Args:
user_id (str):
Returns:
defer.Deferred: list[dict[str, X]]: info on each device
"""
device_map = yield self.store.get_devices_by_user(user_id)
ips = yield self.store.get_last_client_ip_by_device(
user_id, device_id=None
)
devices = list(device_map.values())
for device in devices:
_update_device_from_client_ips(device, ips)
defer.returnValue(devices)
@defer.inlineCallbacks
def get_device(self, user_id, device_id):
""" Retrieve the given device
Args:
user_id (str):
device_id (str):
Returns:
defer.Deferred: dict[str, X]: info on the device
Raises:
errors.NotFoundError: if the device was not found
"""
try:
device = yield self.store.get_device(user_id, device_id)
except errors.StoreError:
raise errors.NotFoundError
ips = yield self.store.get_last_client_ip_by_device(
user_id, device_id,
)
_update_device_from_client_ips(device, ips)
defer.returnValue(device)
@defer.inlineCallbacks
def delete_device(self, user_id, device_id):
""" Delete the given device
@@ -413,6 +287,126 @@ class DeviceHandler(DeviceWorkerHandler):
for host in hosts:
self.federation_sender.send_device_messages(host)
@measure_func("device.get_user_ids_changed")
@defer.inlineCallbacks
def get_user_ids_changed(self, user_id, from_token):
"""Get list of users that have had the devices updated, or have newly
joined a room, that `user_id` may be interested in.
Args:
user_id (str)
from_token (StreamToken)
"""
now_token = yield self.hs.get_event_sources().get_current_token()
room_ids = yield self.store.get_rooms_for_user(user_id)
# First we check if any devices have changed
changed = yield self.store.get_user_whose_devices_changed(
from_token.device_list_key
)
# Then work out if any users have since joined
rooms_changed = self.store.get_rooms_that_changed(room_ids, from_token.room_key)
member_events = yield self.store.get_membership_changes_for_user(
user_id, from_token.room_key, now_token.room_key
)
rooms_changed.update(event.room_id for event in member_events)
stream_ordering = RoomStreamToken.parse_stream_token(
from_token.room_key
).stream
possibly_changed = set(changed)
possibly_left = set()
for room_id in rooms_changed:
current_state_ids = yield self.store.get_current_state_ids(room_id)
# The user may have left the room
# TODO: Check if they actually did or if we were just invited.
if room_id not in room_ids:
for key, event_id in iteritems(current_state_ids):
etype, state_key = key
if etype != EventTypes.Member:
continue
possibly_left.add(state_key)
continue
# Fetch the current state at the time.
try:
event_ids = yield self.store.get_forward_extremeties_for_room(
room_id, stream_ordering=stream_ordering
)
except errors.StoreError:
# we have purged the stream_ordering index since the stream
# ordering: treat it the same as a new room
event_ids = []
# special-case for an empty prev state: include all members
# in the changed list
if not event_ids:
for key, event_id in iteritems(current_state_ids):
etype, state_key = key
if etype != EventTypes.Member:
continue
possibly_changed.add(state_key)
continue
current_member_id = current_state_ids.get((EventTypes.Member, user_id))
if not current_member_id:
continue
# mapping from event_id -> state_dict
prev_state_ids = yield self.store.get_state_ids_for_events(event_ids)
# Check if we've joined the room? If so we just blindly add all the users to
# the "possibly changed" users.
for state_dict in itervalues(prev_state_ids):
member_event = state_dict.get((EventTypes.Member, user_id), None)
if not member_event or member_event != current_member_id:
for key, event_id in iteritems(current_state_ids):
etype, state_key = key
if etype != EventTypes.Member:
continue
possibly_changed.add(state_key)
break
# If there has been any change in membership, include them in the
# possibly changed list. We'll check if they are joined below,
# and we're not toooo worried about spuriously adding users.
for key, event_id in iteritems(current_state_ids):
etype, state_key = key
if etype != EventTypes.Member:
continue
# check if this member has changed since any of the extremities
# at the stream_ordering, and add them to the list if so.
for state_dict in itervalues(prev_state_ids):
prev_event_id = state_dict.get(key, None)
if not prev_event_id or prev_event_id != event_id:
if state_key != user_id:
possibly_changed.add(state_key)
break
if possibly_changed or possibly_left:
users_who_share_room = yield self.store.get_users_who_share_room_with_user(
user_id
)
# Take the intersection of the users whose devices may have changed
# and those that actually still share a room with the user
possibly_joined = possibly_changed & users_who_share_room
possibly_left = (possibly_changed | possibly_left) - users_who_share_room
else:
possibly_joined = []
possibly_left = []
defer.returnValue({
"changed": list(possibly_joined),
"left": list(possibly_left),
})
@defer.inlineCallbacks
def on_federation_query_user_devices(self, user_id):
stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id)
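
Aside: get_user_ids_changed resolves to a dict with "changed" and "left" lists of user IDs. A hypothetical usage sketch (the wrapper name and call site are assumptions):

# Hypothetical sketch: surface device-list changes since a client's last sync.
from twisted.internet import defer

@defer.inlineCallbacks
def device_lists_for_sync(device_handler, user_id, since_token):
    result = yield device_handler.get_user_ids_changed(user_id, since_token)
    # "changed": users whose device lists may have changed since since_token;
    # "left": users we may no longer share a room with.
    defer.returnValue((result["changed"], result["left"]))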

View File

@@ -858,52 +858,6 @@ class FederationHandler(BaseHandler):
logger.debug("Not backfilling as no extremeties found.")
return
# We only want to paginate if we can actually see the events we'll get,
# as otherwise we'll just spend a lot of resources to get redacted
# events.
#
# We do this by filtering all the backwards extremities and seeing if
# any remain. Given we don't have the extremity events themselves, we
# need to actually check the events that reference them.
#
# *Note*: the spec wants us to keep backfilling until we reach the start
# of the room in case we are allowed to see some of the history. However
# in practice that causes more issues than its worth, as a) its
# relatively rare for there to be any visible history and b) even when
# there is its often sufficiently long ago that clients would stop
# attempting to paginate before backfill reached the visible history.
#
# TODO: If we do do a backfill then we should filter the backwards
# extremities to only include those that point to visible portions of
# history.
#
# TODO: Correctly handle the case where we are allowed to see the
# forward event but not the backward extremity, e.g. in the case of
# initial join of the server where we are allowed to see the join
# event but not anything before it. This would require looking at the
# state *before* the event, ignoring the special casing certain event
# types have.
forward_events = yield self.store.get_successor_events(
list(extremities),
)
extremities_events = yield self.store.get_events(
forward_events,
check_redacted=False,
get_prev_content=False,
)
# We set `check_history_visibility_only` as we might otherwise get false
# positives from users having been erased.
filtered_extremities = yield filter_events_for_server(
self.store, self.server_name, list(extremities_events.values()),
redact=False, check_history_visibility_only=True,
)
if not filtered_extremities:
defer.returnValue(False)
# Check if we reached a point where we should start backfilling.
sorted_extremeties_tuple = sorted(
extremities.items(),

View File

@@ -178,6 +178,8 @@ class Notifier(object):
self.remove_expired_streams, self.UNUSED_STREAM_EXPIRY_MS
)
self.replication_deferred = ObservableDeferred(defer.Deferred())
# This is not a very cheap test to perform, but it's only executed
# when rendering the metrics page, which is likely once per minute at
# most when scraping it.
@@ -203,9 +205,7 @@ class Notifier(object):
def add_replication_callback(self, cb):
"""Add a callback that will be called when some new data is available.
Callback is not given any arguments. It should *not* return a Deferred - if
it needs to do any asynchronous work, a background thread should be started and
wrapped with run_as_background_process.
Callback is not given any arguments.
"""
self.replication_callbacks.append(cb)
@@ -517,5 +517,60 @@ class Notifier(object):
def notify_replication(self):
"""Notify the any replication listeners that there's a new event"""
for cb in self.replication_callbacks:
cb()
with PreserveLoggingContext():
deferred = self.replication_deferred
self.replication_deferred = ObservableDeferred(defer.Deferred())
deferred.callback(None)
# the callbacks may well outlast the current request, so we run
# them in the sentinel logcontext.
#
# (ideally it would be up to the callbacks to know if they were
# starting off background processes and drop the logcontext
# accordingly, but that requires more changes)
for cb in self.replication_callbacks:
cb()
@defer.inlineCallbacks
def wait_for_replication(self, callback, timeout):
"""Wait for an event to happen.
Args:
callback: Gets called whenever an event happens. If this returns a
truthy value then ``wait_for_replication`` returns, otherwise
it waits for another event.
timeout: How many milliseconds to wait for callback return a truthy
value.
Returns:
A deferred that resolves with the value returned by the callback.
"""
listener = _NotificationListener(None)
end_time = self.clock.time_msec() + timeout
while True:
listener.deferred = self.replication_deferred.observe()
result = yield callback()
if result:
break
now = self.clock.time_msec()
if end_time <= now:
break
listener.deferred = timeout_deferred(
listener.deferred,
timeout=(end_time - now) / 1000.,
reactor=self.hs.get_reactor(),
)
try:
with PreserveLoggingContext():
yield listener.deferred
except defer.TimeoutError:
break
except defer.CancelledError:
break
defer.returnValue(result)
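
Aside: a hypothetical usage sketch of the re-added wait_for_replication (the surrounding names are assumptions): re-check a condition each time replication data arrives, giving up after the timeout:

# Hypothetical sketch: wait until get_current_token() reaches target_token,
# re-checking on every replication notification, with a 30 second cap.
from twisted.internet import defer

@defer.inlineCallbacks
def wait_until_caught_up(notifier, get_current_token, target_token):
    def check():
        # A truthy return value makes wait_for_replication() return.
        return get_current_token() >= target_token

    caught_up = yield notifier.wait_for_replication(check, timeout=30 * 1000)
    defer.returnValue(caught_up)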

View File

@@ -13,14 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.storage.deviceinbox import DeviceInboxWorkerStore
from synapse.storage import DataStore
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.caches.stream_change_cache import StreamChangeCache
from ._base import BaseSlavedStore, __func__
from ._slaved_id_tracker import SlavedIdTracker
class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
class SlavedDeviceInboxStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedDeviceInboxStore, self).__init__(db_conn, hs)
self._device_inbox_id_gen = SlavedIdTracker(
@@ -42,6 +43,12 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
expiry_ms=30 * 60 * 1000,
)
get_to_device_stream_token = __func__(DataStore.get_to_device_stream_token)
get_new_messages_for_device = __func__(DataStore.get_new_messages_for_device)
get_new_device_msgs_for_remote = __func__(DataStore.get_new_device_msgs_for_remote)
delete_messages_for_device = __func__(DataStore.delete_messages_for_device)
delete_device_msgs_for_remote = __func__(DataStore.delete_device_msgs_for_remote)
def stream_positions(self):
result = super(SlavedDeviceInboxStore, self).stream_positions()
result["to_device"] = self._device_inbox_id_gen.get_current_token()

View File

@@ -13,14 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.storage.devices import DeviceWorkerStore
from synapse.storage.end_to_end_keys import EndToEndKeyWorkerStore
from synapse.storage import DataStore
from synapse.storage.end_to_end_keys import EndToEndKeyStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
from ._base import BaseSlavedStore, __func__
from ._slaved_id_tracker import SlavedIdTracker
class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore):
class SlavedDeviceStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedDeviceStore, self).__init__(db_conn, hs)
@@ -37,6 +38,17 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
"DeviceListFederationStreamChangeCache", device_list_max,
)
get_device_stream_token = __func__(DataStore.get_device_stream_token)
get_user_whose_devices_changed = __func__(DataStore.get_user_whose_devices_changed)
get_devices_by_remote = __func__(DataStore.get_devices_by_remote)
_get_devices_by_remote_txn = __func__(DataStore._get_devices_by_remote_txn)
_get_e2e_device_keys_txn = __func__(DataStore._get_e2e_device_keys_txn)
mark_as_sent_devices_by_remote = __func__(DataStore.mark_as_sent_devices_by_remote)
_mark_as_sent_devices_by_remote_txn = (
__func__(DataStore._mark_as_sent_devices_by_remote_txn)
)
count_e2e_one_time_keys = EndToEndKeyStore.__dict__["count_e2e_one_time_keys"]
def stream_positions(self):
result = super(SlavedDeviceStore, self).stream_positions()
result["device_lists"] = self._device_list_id_gen.get_current_token()
@@ -46,23 +58,14 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
if stream_name == "device_lists":
self._device_list_id_gen.advance(token)
for row in rows:
self._invalidate_caches_for_devices(
token, row.user_id, row.destination,
self._device_list_stream_cache.entity_has_changed(
row.user_id, token
)
if row.destination:
self._device_list_federation_stream_cache.entity_has_changed(
row.destination, token
)
return super(SlavedDeviceStore, self).process_replication_rows(
stream_name, token, rows
)
def _invalidate_caches_for_devices(self, token, user_id, destination):
self._device_list_stream_cache.entity_has_changed(
user_id, token
)
if destination:
self._device_list_federation_stream_cache.entity_has_changed(
destination, token
)
self._get_cached_devices_for_user.invalidate((user_id,))
self._get_cached_user_device.invalidate_many((user_id,))
self.get_device_list_last_stream_id_for_remote.invalidate((user_id,))

View File

@@ -20,7 +20,7 @@ from ._slaved_id_tracker import SlavedIdTracker
from .events import SlavedEventStore
class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
class SlavedPushRuleStore(PushRulesWorkerStore, SlavedEventStore):
def __init__(self, db_conn, hs):
self._push_rules_stream_id_gen = SlavedIdTracker(
db_conn, "push_rules_stream", "stream_id",

View File

@@ -51,7 +51,7 @@ from synapse.handlers.acme import AcmeHandler
from synapse.handlers.appservice import ApplicationServicesHandler
from synapse.handlers.auth import AuthHandler, MacaroonGenerator
from synapse.handlers.deactivate_account import DeactivateAccountHandler
from synapse.handlers.device import DeviceHandler, DeviceWorkerHandler
from synapse.handlers.device import DeviceHandler
from synapse.handlers.devicemessage import DeviceMessageHandler
from synapse.handlers.e2e_keys import E2eKeysHandler
from synapse.handlers.e2e_room_keys import E2eRoomKeysHandler
@@ -307,10 +307,7 @@ class HomeServer(object):
return MacaroonGenerator(self)
def build_device_handler(self):
if self.config.worker_app:
return DeviceWorkerHandler(self)
else:
return DeviceHandler(self)
return DeviceHandler(self)
def build_device_message_handler(self):
return DeviceMessageHandler(self)

View File

@@ -19,174 +19,14 @@ from canonicaljson import json
from twisted.internet import defer
from synapse.storage._base import SQLBaseStore
from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.util.caches.expiringcache import ExpiringCache
from .background_updates import BackgroundUpdateStore
logger = logging.getLogger(__name__)
class DeviceInboxWorkerStore(SQLBaseStore):
def get_to_device_stream_token(self):
return self._device_inbox_id_gen.get_current_token()
def get_new_messages_for_device(
self, user_id, device_id, last_stream_id, current_stream_id, limit=100
):
"""
Args:
user_id(str): The recipient user_id.
device_id(str): The recipient device_id.
current_stream_id(int): The current position of the to device
message stream.
Returns:
Deferred ([dict], int): List of messages for the device and where
in the stream the messages got to.
"""
has_changed = self._device_inbox_stream_cache.has_entity_changed(
user_id, last_stream_id
)
if not has_changed:
return defer.succeed(([], current_stream_id))
def get_new_messages_for_device_txn(txn):
sql = (
"SELECT stream_id, message_json FROM device_inbox"
" WHERE user_id = ? AND device_id = ?"
" AND ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC"
" LIMIT ?"
)
txn.execute(sql, (
user_id, device_id, last_stream_id, current_stream_id, limit
))
messages = []
for row in txn:
stream_pos = row[0]
messages.append(json.loads(row[1]))
if len(messages) < limit:
stream_pos = current_stream_id
return (messages, stream_pos)
return self.runInteraction(
"get_new_messages_for_device", get_new_messages_for_device_txn,
)
@defer.inlineCallbacks
def delete_messages_for_device(self, user_id, device_id, up_to_stream_id):
"""
Args:
user_id(str): The recipient user_id.
device_id(str): The recipient device_id.
up_to_stream_id(int): Where to delete messages up to.
Returns:
A deferred that resolves to the number of messages deleted.
"""
# If we have cached the last stream id we've deleted up to, we can
# check if there is likely to be anything that needs deleting
last_deleted_stream_id = self._last_device_delete_cache.get(
(user_id, device_id), None
)
if last_deleted_stream_id:
has_changed = self._device_inbox_stream_cache.has_entity_changed(
user_id, last_deleted_stream_id
)
if not has_changed:
defer.returnValue(0)
def delete_messages_for_device_txn(txn):
sql = (
"DELETE FROM device_inbox"
" WHERE user_id = ? AND device_id = ?"
" AND stream_id <= ?"
)
txn.execute(sql, (user_id, device_id, up_to_stream_id))
return txn.rowcount
count = yield self.runInteraction(
"delete_messages_for_device", delete_messages_for_device_txn
)
# Update the cache, ensuring that we only ever increase the value
last_deleted_stream_id = self._last_device_delete_cache.get(
(user_id, device_id), 0
)
self._last_device_delete_cache[(user_id, device_id)] = max(
last_deleted_stream_id, up_to_stream_id
)
defer.returnValue(count)
def get_new_device_msgs_for_remote(
self, destination, last_stream_id, current_stream_id, limit=100
):
"""
Args:
destination(str): The name of the remote server.
last_stream_id(int|long): The last position of the device message stream
that the server sent up to.
current_stream_id(int|long): The current position of the device
message stream.
Returns:
Deferred ([dict], int|long): List of messages for the device and where
in the stream the messages got to.
"""
has_changed = self._device_federation_outbox_stream_cache.has_entity_changed(
destination, last_stream_id
)
if not has_changed or last_stream_id == current_stream_id:
return defer.succeed(([], current_stream_id))
def get_new_messages_for_remote_destination_txn(txn):
sql = (
"SELECT stream_id, messages_json FROM device_federation_outbox"
" WHERE destination = ?"
" AND ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC"
" LIMIT ?"
)
txn.execute(sql, (
destination, last_stream_id, current_stream_id, limit
))
messages = []
for row in txn:
stream_pos = row[0]
messages.append(json.loads(row[1]))
if len(messages) < limit:
stream_pos = current_stream_id
return (messages, stream_pos)
return self.runInteraction(
"get_new_device_msgs_for_remote",
get_new_messages_for_remote_destination_txn,
)
def delete_device_msgs_for_remote(self, destination, up_to_stream_id):
"""Used to delete messages when the remote destination acknowledges
their receipt.
Args:
destination(str): The destination server_name
up_to_stream_id(int): Where to delete messages up to.
Returns:
A deferred that resolves when the messages have been deleted.
"""
def delete_messages_for_remote_destination_txn(txn):
sql = (
"DELETE FROM device_federation_outbox"
" WHERE destination = ?"
" AND stream_id <= ?"
)
txn.execute(sql, (destination, up_to_stream_id))
return self.runInteraction(
"delete_device_msgs_for_remote",
delete_messages_for_remote_destination_txn
)
class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
class DeviceInboxStore(BackgroundUpdateStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
def __init__(self, db_conn, hs):
@@ -380,6 +220,93 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
txn.executemany(sql, rows)
def get_new_messages_for_device(
self, user_id, device_id, last_stream_id, current_stream_id, limit=100
):
"""
Args:
user_id(str): The recipient user_id.
device_id(str): The recipient device_id.
current_stream_id(int): The current position of the to device
message stream.
Returns:
Deferred ([dict], int): List of messages for the device and where
in the stream the messages got to.
"""
has_changed = self._device_inbox_stream_cache.has_entity_changed(
user_id, last_stream_id
)
if not has_changed:
return defer.succeed(([], current_stream_id))
def get_new_messages_for_device_txn(txn):
sql = (
"SELECT stream_id, message_json FROM device_inbox"
" WHERE user_id = ? AND device_id = ?"
" AND ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC"
" LIMIT ?"
)
txn.execute(sql, (
user_id, device_id, last_stream_id, current_stream_id, limit
))
messages = []
for row in txn:
stream_pos = row[0]
messages.append(json.loads(row[1]))
if len(messages) < limit:
stream_pos = current_stream_id
return (messages, stream_pos)
return self.runInteraction(
"get_new_messages_for_device", get_new_messages_for_device_txn,
)
@defer.inlineCallbacks
def delete_messages_for_device(self, user_id, device_id, up_to_stream_id):
"""
Args:
user_id(str): The recipient user_id.
device_id(str): The recipient device_id.
up_to_stream_id(int): Where to delete messages up to.
Returns:
A deferred that resolves to the number of messages deleted.
"""
# If we have cached the last stream id we've deleted up to, we can
# check if there is likely to be anything that needs deleting
last_deleted_stream_id = self._last_device_delete_cache.get(
(user_id, device_id), None
)
if last_deleted_stream_id:
has_changed = self._device_inbox_stream_cache.has_entity_changed(
user_id, last_deleted_stream_id
)
if not has_changed:
defer.returnValue(0)
def delete_messages_for_device_txn(txn):
sql = (
"DELETE FROM device_inbox"
" WHERE user_id = ? AND device_id = ?"
" AND stream_id <= ?"
)
txn.execute(sql, (user_id, device_id, up_to_stream_id))
return txn.rowcount
count = yield self.runInteraction(
"delete_messages_for_device", delete_messages_for_device_txn
)
# Update the cache, ensuring that we only ever increase the value
last_deleted_stream_id = self._last_device_delete_cache.get(
(user_id, device_id), 0
)
self._last_device_delete_cache[(user_id, device_id)] = max(
last_deleted_stream_id, up_to_stream_id
)
defer.returnValue(count)
def get_all_new_device_messages(self, last_pos, current_pos, limit):
"""
Args:
@@ -424,6 +351,77 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
"get_all_new_device_messages", get_all_new_device_messages_txn
)
def get_to_device_stream_token(self):
return self._device_inbox_id_gen.get_current_token()
def get_new_device_msgs_for_remote(
self, destination, last_stream_id, current_stream_id, limit=100
):
"""
Args:
destination(str): The name of the remote server.
last_stream_id(int|long): The last position of the device message stream
that the server sent up to.
current_stream_id(int|long): The current position of the device
message stream.
Returns:
Deferred ([dict], int|long): List of messages for the device and where
in the stream the messages got to.
"""
has_changed = self._device_federation_outbox_stream_cache.has_entity_changed(
destination, last_stream_id
)
if not has_changed or last_stream_id == current_stream_id:
return defer.succeed(([], current_stream_id))
def get_new_messages_for_remote_destination_txn(txn):
sql = (
"SELECT stream_id, messages_json FROM device_federation_outbox"
" WHERE destination = ?"
" AND ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC"
" LIMIT ?"
)
txn.execute(sql, (
destination, last_stream_id, current_stream_id, limit
))
messages = []
for row in txn:
stream_pos = row[0]
messages.append(json.loads(row[1]))
if len(messages) < limit:
stream_pos = current_stream_id
return (messages, stream_pos)
return self.runInteraction(
"get_new_device_msgs_for_remote",
get_new_messages_for_remote_destination_txn,
)
def delete_device_msgs_for_remote(self, destination, up_to_stream_id):
"""Used to delete messages when the remote destination acknowledges
their receipt.
Args:
destination(str): The destination server_name
up_to_stream_id(int): Where to delete messages up to.
Returns:
A deferred that resolves when the messages have been deleted.
"""
def delete_messages_for_remote_destination_txn(txn):
sql = (
"DELETE FROM device_federation_outbox"
" WHERE destination = ?"
" AND stream_id <= ?"
)
txn.execute(sql, (destination, up_to_stream_id))
return self.runInteraction(
"delete_device_msgs_for_remote",
delete_messages_for_remote_destination_txn
)
@defer.inlineCallbacks
def _background_drop_index_device_inbox(self, progress, batch_size):
def reindex_txn(conn):

View File

@@ -22,10 +22,11 @@ from twisted.internet import defer
from synapse.api.errors import StoreError
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import Cache, SQLBaseStore, db_to_json
from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
from ._base import Cache, db_to_json
logger = logging.getLogger(__name__)
DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = (
@@ -33,343 +34,7 @@ DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = (
)
class DeviceWorkerStore(SQLBaseStore):
def get_device(self, user_id, device_id):
"""Retrieve a device.
Args:
user_id (str): The ID of the user which owns the device
device_id (str): The ID of the device to retrieve
Returns:
defer.Deferred for a dict containing the device information
Raises:
StoreError: if the device is not found
"""
return self._simple_select_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id},
retcols=("user_id", "device_id", "display_name"),
desc="get_device",
)
@defer.inlineCallbacks
def get_devices_by_user(self, user_id):
"""Retrieve all of a user's registered devices.
Args:
user_id (str):
Returns:
defer.Deferred: resolves to a dict from device_id to a dict
containing "device_id", "user_id" and "display_name" for each
device.
"""
devices = yield self._simple_select_list(
table="devices",
keyvalues={"user_id": user_id},
retcols=("user_id", "device_id", "display_name"),
desc="get_devices_by_user"
)
defer.returnValue({d["device_id"]: d for d in devices})
def get_devices_by_remote(self, destination, from_stream_id):
"""Get stream of updates to send to remote servers
Returns:
(int, list[dict]): current stream id and list of updates
"""
now_stream_id = self._device_list_id_gen.get_current_token()
has_changed = self._device_list_federation_stream_cache.has_entity_changed(
destination, int(from_stream_id)
)
if not has_changed:
return (now_stream_id, [])
return self.runInteraction(
"get_devices_by_remote", self._get_devices_by_remote_txn,
destination, from_stream_id, now_stream_id,
)
def _get_devices_by_remote_txn(self, txn, destination, from_stream_id,
now_stream_id):
sql = """
SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes
WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
GROUP BY user_id, device_id
LIMIT 20
"""
txn.execute(
sql, (destination, from_stream_id, now_stream_id, False)
)
# maps (user_id, device_id) -> stream_id
query_map = {(r[0], r[1]): r[2] for r in txn}
if not query_map:
return (now_stream_id, [])
if len(query_map) >= 20:
now_stream_id = max(stream_id for stream_id in itervalues(query_map))
devices = self._get_e2e_device_keys_txn(
txn, query_map.keys(), include_all_devices=True, include_deleted_devices=True
)
prev_sent_id_sql = """
SELECT coalesce(max(stream_id), 0) as stream_id
FROM device_lists_outbound_last_success
WHERE destination = ? AND user_id = ? AND stream_id <= ?
"""
results = []
for user_id, user_devices in iteritems(devices):
# The prev_id for the first row is always the last row before
# `from_stream_id`
txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id))
rows = txn.fetchall()
prev_id = rows[0][0]
for device_id, device in iteritems(user_devices):
stream_id = query_map[(user_id, device_id)]
result = {
"user_id": user_id,
"device_id": device_id,
"prev_id": [prev_id] if prev_id else [],
"stream_id": stream_id,
}
prev_id = stream_id
if device is not None:
key_json = device.get("key_json", None)
if key_json:
result["keys"] = db_to_json(key_json)
device_display_name = device.get("device_display_name", None)
if device_display_name:
result["device_display_name"] = device_display_name
else:
result["deleted"] = True
results.append(result)
return (now_stream_id, results)
def mark_as_sent_devices_by_remote(self, destination, stream_id):
"""Mark that updates have successfully been sent to the destination.
"""
return self.runInteraction(
"mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn,
destination, stream_id,
)
def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id):
# We update the device_lists_outbound_last_success with the successfully
# poked users. We do the join to see which users need to be inserted and
# which updated.
sql = """
SELECT user_id, coalesce(max(o.stream_id), 0), (max(s.stream_id) IS NOT NULL)
FROM device_lists_outbound_pokes as o
LEFT JOIN device_lists_outbound_last_success as s
USING (destination, user_id)
WHERE destination = ? AND o.stream_id <= ?
GROUP BY user_id
"""
txn.execute(sql, (destination, stream_id,))
rows = txn.fetchall()
sql = """
UPDATE device_lists_outbound_last_success
SET stream_id = ?
WHERE destination = ? AND user_id = ?
"""
txn.executemany(
sql, ((row[1], destination, row[0],) for row in rows if row[2])
)
sql = """
INSERT INTO device_lists_outbound_last_success
(destination, user_id, stream_id) VALUES (?, ?, ?)
"""
txn.executemany(
sql, ((destination, row[0], row[1],) for row in rows if not row[2])
)
# Delete all sent outbound pokes
sql = """
DELETE FROM device_lists_outbound_pokes
WHERE destination = ? AND stream_id <= ?
"""
txn.execute(sql, (destination, stream_id,))
def get_device_stream_token(self):
return self._device_list_id_gen.get_current_token()
@defer.inlineCallbacks
def get_user_devices_from_cache(self, query_list):
"""Get the devices (and keys if any) for remote users from the cache.
Args:
query_list(list): List of (user_id, device_ids), if device_ids is
falsey then return all device ids for that user.
Returns:
(user_ids_not_in_cache, results_map), where user_ids_not_in_cache is
a set of user_ids and results_map is a mapping of
user_id -> device_id -> device_info
"""
user_ids = set(user_id for user_id, _ in query_list)
user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids))
user_ids_in_cache = set(
user_id for user_id, stream_id in user_map.items() if stream_id
)
user_ids_not_in_cache = user_ids - user_ids_in_cache
results = {}
for user_id, device_id in query_list:
if user_id not in user_ids_in_cache:
continue
if device_id:
device = yield self._get_cached_user_device(user_id, device_id)
results.setdefault(user_id, {})[device_id] = device
else:
results[user_id] = yield self._get_cached_devices_for_user(user_id)
defer.returnValue((user_ids_not_in_cache, results))
@cachedInlineCallbacks(num_args=2, tree=True)
def _get_cached_user_device(self, user_id, device_id):
content = yield self._simple_select_one_onecol(
table="device_lists_remote_cache",
keyvalues={
"user_id": user_id,
"device_id": device_id,
},
retcol="content",
desc="_get_cached_user_device",
)
defer.returnValue(db_to_json(content))
@cachedInlineCallbacks()
def _get_cached_devices_for_user(self, user_id):
devices = yield self._simple_select_list(
table="device_lists_remote_cache",
keyvalues={
"user_id": user_id,
},
retcols=("device_id", "content"),
desc="_get_cached_devices_for_user",
)
defer.returnValue({
device["device_id"]: db_to_json(device["content"])
for device in devices
})
def get_devices_with_keys_by_user(self, user_id):
"""Get all devices (with any device keys) for a user
Returns:
(stream_id, devices)
"""
return self.runInteraction(
"get_devices_with_keys_by_user",
self._get_devices_with_keys_by_user_txn, user_id,
)
def _get_devices_with_keys_by_user_txn(self, txn, user_id):
now_stream_id = self._device_list_id_gen.get_current_token()
devices = self._get_e2e_device_keys_txn(
txn, [(user_id, None)], include_all_devices=True
)
if devices:
user_devices = devices[user_id]
results = []
for device_id, device in iteritems(user_devices):
result = {
"device_id": device_id,
}
key_json = device.get("key_json", None)
if key_json:
result["keys"] = db_to_json(key_json)
device_display_name = device.get("device_display_name", None)
if device_display_name:
result["device_display_name"] = device_display_name
results.append(result)
return now_stream_id, results
return now_stream_id, []
@defer.inlineCallbacks
def get_user_whose_devices_changed(self, from_key):
"""Get set of users whose devices have changed since `from_key`.
"""
from_key = int(from_key)
changed = self._device_list_stream_cache.get_all_entities_changed(from_key)
if changed is not None:
defer.returnValue(set(changed))
sql = """
SELECT DISTINCT user_id FROM device_lists_stream WHERE stream_id > ?
"""
rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key)
defer.returnValue(set(row[0] for row in rows))
def get_all_device_list_changes_for_remotes(self, from_key, to_key):
"""Return a list of `(stream_id, user_id, destination)` which is the
combined list of changes to devices, and which destinations need to be
poked. `destination` may be None if no destinations need to be poked.
"""
# We do a group by here as there can be a large number of duplicate
# entries, since we throw away device IDs.
sql = """
SELECT MAX(stream_id) AS stream_id, user_id, destination
FROM device_lists_stream
LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id)
WHERE ? < stream_id AND stream_id <= ?
GROUP BY user_id, destination
"""
return self._execute(
"get_all_device_list_changes_for_remotes", None,
sql, from_key, to_key
)
@cached(max_entries=10000)
def get_device_list_last_stream_id_for_remote(self, user_id):
"""Get the last stream_id we got for a user. May be None if we haven't
got any information for them.
"""
return self._simple_select_one_onecol(
table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id},
retcol="stream_id",
desc="get_device_list_last_stream_id_for_remote",
allow_none=True,
)
@cachedList(cached_method_name="get_device_list_last_stream_id_for_remote",
list_name="user_ids", inlineCallbacks=True)
def get_device_list_last_stream_id_for_remotes(self, user_ids):
rows = yield self._simple_select_many_batch(
table="device_lists_remote_extremeties",
column="user_id",
iterable=user_ids,
retcols=("user_id", "stream_id",),
desc="get_device_list_last_stream_id_for_remotes",
)
results = {user_id: None for user_id in user_ids}
results.update({
row["user_id"]: row["stream_id"] for row in rows
})
defer.returnValue(results)
class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
class DeviceStore(BackgroundUpdateStore):
def __init__(self, db_conn, hs):
super(DeviceStore, self).__init__(db_conn, hs)
@@ -456,6 +121,24 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
initial_device_display_name, e)
raise StoreError(500, "Problem storing device.")
def get_device(self, user_id, device_id):
"""Retrieve a device.
Args:
user_id (str): The ID of the user which owns the device
device_id (str): The ID of the device to retrieve
Returns:
defer.Deferred for a dict containing the device information
Raises:
StoreError: if the device is not found
"""
return self._simple_select_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id},
retcols=("user_id", "device_id", "display_name"),
desc="get_device",
)
@defer.inlineCallbacks
def delete_device(self, user_id, device_id):
"""Delete a device.
@@ -519,6 +202,57 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
desc="update_device",
)
@defer.inlineCallbacks
def get_devices_by_user(self, user_id):
"""Retrieve all of a user's registered devices.
Args:
user_id (str):
Returns:
defer.Deferred: resolves to a dict from device_id to a dict
containing "device_id", "user_id" and "display_name" for each
device.
"""
devices = yield self._simple_select_list(
table="devices",
keyvalues={"user_id": user_id},
retcols=("user_id", "device_id", "display_name"),
desc="get_devices_by_user"
)
defer.returnValue({d["device_id"]: d for d in devices})
@cached(max_entries=10000)
def get_device_list_last_stream_id_for_remote(self, user_id):
"""Get the last stream_id we got for a user. May be None if we haven't
got any information for them.
"""
return self._simple_select_one_onecol(
table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id},
retcol="stream_id",
desc="get_device_list_remote_extremity",
allow_none=True,
)
@cachedList(cached_method_name="get_device_list_last_stream_id_for_remote",
list_name="user_ids", inlineCallbacks=True)
def get_device_list_last_stream_id_for_remotes(self, user_ids):
rows = yield self._simple_select_many_batch(
table="device_lists_remote_extremeties",
column="user_id",
iterable=user_ids,
retcols=("user_id", "stream_id",),
desc="get_user_devices_from_cache",
)
results = {user_id: None for user_id in user_ids}
results.update({
row["user_id"]: row["stream_id"] for row in rows
})
defer.returnValue(results)
@defer.inlineCallbacks
def mark_remote_user_device_list_as_unsubscribed(self, user_id):
"""Mark that we no longer track device lists for remote user.
@@ -671,6 +405,268 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
lock=False,
)
def get_devices_by_remote(self, destination, from_stream_id):
"""Get stream of updates to send to remote servers
Returns:
(int, list[dict]): current stream id and list of updates
"""
now_stream_id = self._device_list_id_gen.get_current_token()
has_changed = self._device_list_federation_stream_cache.has_entity_changed(
destination, int(from_stream_id)
)
if not has_changed:
return (now_stream_id, [])
return self.runInteraction(
"get_devices_by_remote", self._get_devices_by_remote_txn,
destination, from_stream_id, now_stream_id,
)
def _get_devices_by_remote_txn(self, txn, destination, from_stream_id,
now_stream_id):
sql = """
SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes
WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
GROUP BY user_id, device_id
LIMIT 20
"""
txn.execute(
sql, (destination, from_stream_id, now_stream_id, False)
)
# maps (user_id, device_id) -> stream_id
query_map = {(r[0], r[1]): r[2] for r in txn}
if not query_map:
return (now_stream_id, [])
if len(query_map) >= 20:
now_stream_id = max(stream_id for stream_id in itervalues(query_map))
devices = self._get_e2e_device_keys_txn(
txn, query_map.keys(), include_all_devices=True, include_deleted_devices=True
)
prev_sent_id_sql = """
SELECT coalesce(max(stream_id), 0) as stream_id
FROM device_lists_outbound_last_success
WHERE destination = ? AND user_id = ? AND stream_id <= ?
"""
results = []
for user_id, user_devices in iteritems(devices):
# The prev_id for the first row is always the last row before
# `from_stream_id`
txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id))
rows = txn.fetchall()
prev_id = rows[0][0]
for device_id, device in iteritems(user_devices):
stream_id = query_map[(user_id, device_id)]
result = {
"user_id": user_id,
"device_id": device_id,
"prev_id": [prev_id] if prev_id else [],
"stream_id": stream_id,
}
prev_id = stream_id
if device is not None:
key_json = device.get("key_json", None)
if key_json:
result["keys"] = db_to_json(key_json)
device_display_name = device.get("device_display_name", None)
if device_display_name:
result["device_display_name"] = device_display_name
else:
result["deleted"] = True
results.append(result)
return (now_stream_id, results)
@defer.inlineCallbacks
def get_user_devices_from_cache(self, query_list):
"""Get the devices (and keys if any) for remote users from the cache.
Args:
query_list(list): List of (user_id, device_ids), if device_ids is
falsey then return all device ids for that user.
Returns:
(user_ids_not_in_cache, results_map), where user_ids_not_in_cache is
a set of user_ids and results_map is a mapping of
user_id -> device_id -> device_info
"""
user_ids = set(user_id for user_id, _ in query_list)
user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids))
user_ids_in_cache = set(
user_id for user_id, stream_id in user_map.items() if stream_id
)
user_ids_not_in_cache = user_ids - user_ids_in_cache
results = {}
for user_id, device_id in query_list:
if user_id not in user_ids_in_cache:
continue
if device_id:
device = yield self._get_cached_user_device(user_id, device_id)
results.setdefault(user_id, {})[device_id] = device
else:
results[user_id] = yield self._get_cached_devices_for_user(user_id)
defer.returnValue((user_ids_not_in_cache, results))
@cachedInlineCallbacks(num_args=2, tree=True)
def _get_cached_user_device(self, user_id, device_id):
content = yield self._simple_select_one_onecol(
table="device_lists_remote_cache",
keyvalues={
"user_id": user_id,
"device_id": device_id,
},
retcol="content",
desc="_get_cached_user_device",
)
defer.returnValue(db_to_json(content))
@cachedInlineCallbacks()
def _get_cached_devices_for_user(self, user_id):
devices = yield self._simple_select_list(
table="device_lists_remote_cache",
keyvalues={
"user_id": user_id,
},
retcols=("device_id", "content"),
desc="_get_cached_devices_for_user",
)
defer.returnValue({
device["device_id"]: db_to_json(device["content"])
for device in devices
})
def get_devices_with_keys_by_user(self, user_id):
"""Get all devices (with any device keys) for a user
Returns:
(stream_id, devices)
"""
return self.runInteraction(
"get_devices_with_keys_by_user",
self._get_devices_with_keys_by_user_txn, user_id,
)
def _get_devices_with_keys_by_user_txn(self, txn, user_id):
now_stream_id = self._device_list_id_gen.get_current_token()
devices = self._get_e2e_device_keys_txn(
txn, [(user_id, None)], include_all_devices=True
)
if devices:
user_devices = devices[user_id]
results = []
for device_id, device in iteritems(user_devices):
result = {
"device_id": device_id,
}
key_json = device.get("key_json", None)
if key_json:
result["keys"] = db_to_json(key_json)
device_display_name = device.get("device_display_name", None)
if device_display_name:
result["device_display_name"] = device_display_name
results.append(result)
return now_stream_id, results
return now_stream_id, []
def mark_as_sent_devices_by_remote(self, destination, stream_id):
"""Mark that updates have successfully been sent to the destination.
"""
return self.runInteraction(
"mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn,
destination, stream_id,
)
def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id):
# We update the device_lists_outbound_last_success with the successfully
# poked users. We do the join to see which users need to be inserted and
# which updated.
sql = """
SELECT user_id, coalesce(max(o.stream_id), 0), (max(s.stream_id) IS NOT NULL)
FROM device_lists_outbound_pokes as o
LEFT JOIN device_lists_outbound_last_success as s
USING (destination, user_id)
WHERE destination = ? AND o.stream_id <= ?
GROUP BY user_id
"""
txn.execute(sql, (destination, stream_id,))
rows = txn.fetchall()
sql = """
UPDATE device_lists_outbound_last_success
SET stream_id = ?
WHERE destination = ? AND user_id = ?
"""
txn.executemany(
sql, ((row[1], destination, row[0],) for row in rows if row[2])
)
sql = """
INSERT INTO device_lists_outbound_last_success
(destination, user_id, stream_id) VALUES (?, ?, ?)
"""
txn.executemany(
sql, ((destination, row[0], row[1],) for row in rows if not row[2])
)
# Delete all sent outbound pokes
sql = """
DELETE FROM device_lists_outbound_pokes
WHERE destination = ? AND stream_id <= ?
"""
txn.execute(sql, (destination, stream_id,))
@defer.inlineCallbacks
def get_user_whose_devices_changed(self, from_key):
"""Get set of users whose devices have changed since `from_key`.
"""
from_key = int(from_key)
changed = self._device_list_stream_cache.get_all_entities_changed(from_key)
if changed is not None:
defer.returnValue(set(changed))
sql = """
SELECT DISTINCT user_id FROM device_lists_stream WHERE stream_id > ?
"""
rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key)
defer.returnValue(set(row[0] for row in rows))
def get_all_device_list_changes_for_remotes(self, from_key, to_key):
"""Return a list of `(stream_id, user_id, destination)` which is the
combined list of changes to devices, and which destinations need to be
poked. `destination` may be None if no destinations need to be poked.
"""
# We do a group by here as there can be a large number of duplicate
# entries, since we throw away device IDs.
sql = """
SELECT MAX(stream_id) AS stream_id, user_id, destination
FROM device_lists_stream
LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id)
WHERE ? < stream_id AND stream_id <= ?
GROUP BY user_id, destination
"""
return self._execute(
"get_all_device_list_changes_for_remotes", None,
sql, from_key, to_key
)
@defer.inlineCallbacks
def add_device_change_to_streams(self, user_id, device_ids, hosts):
"""Persist that a user's devices have been updated, and which hosts
@@ -736,6 +732,9 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
]
)
def get_device_stream_token(self):
return self._device_list_id_gen.get_current_token()
def _prune_old_outbound_device_pokes(self):
"""Delete old entries out of the device_lists_outbound_pokes to ensure
that we don't fill up due to dead servers. We keep one entry per

View File

@@ -23,7 +23,49 @@ from synapse.util.caches.descriptors import cached
from ._base import SQLBaseStore, db_to_json
class EndToEndKeyWorkerStore(SQLBaseStore):
class EndToEndKeyStore(SQLBaseStore):
def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
"""Stores device keys for a device. Returns whether there was a change
or the keys were already in the database.
"""
def _set_e2e_device_keys_txn(txn):
old_key_json = self._simple_select_one_onecol_txn(
txn,
table="e2e_device_keys_json",
keyvalues={
"user_id": user_id,
"device_id": device_id,
},
retcol="key_json",
allow_none=True,
)
# In py3 we need old_key_json to match new_key_json type. The DB
# returns unicode while encode_canonical_json returns bytes.
new_key_json = encode_canonical_json(device_keys).decode("utf-8")
if old_key_json == new_key_json:
return False
self._simple_upsert_txn(
txn,
table="e2e_device_keys_json",
keyvalues={
"user_id": user_id,
"device_id": device_id,
},
values={
"ts_added_ms": time_now,
"key_json": new_key_json,
}
)
return True
return self.runInteraction(
"set_e2e_device_keys", _set_e2e_device_keys_txn
)
@defer.inlineCallbacks
def get_e2e_device_keys(
self, query_list, include_all_devices=False,
@@ -196,50 +238,6 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
"count_e2e_one_time_keys", _count_e2e_one_time_keys
)
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
"""Stores device keys for a device. Returns whether there was a change
or the keys were already in the database.
"""
def _set_e2e_device_keys_txn(txn):
old_key_json = self._simple_select_one_onecol_txn(
txn,
table="e2e_device_keys_json",
keyvalues={
"user_id": user_id,
"device_id": device_id,
},
retcol="key_json",
allow_none=True,
)
# In py3 we need old_key_json to match new_key_json type. The DB
# returns unicode while encode_canonical_json returns bytes.
new_key_json = encode_canonical_json(device_keys).decode("utf-8")
if old_key_json == new_key_json:
return False
self._simple_upsert_txn(
txn,
table="e2e_device_keys_json",
keyvalues={
"user_id": user_id,
"device_id": device_id,
},
values={
"ts_added_ms": time_now,
"key_json": new_key_json,
}
)
return True
return self.runInteraction(
"set_e2e_device_keys", _set_e2e_device_keys_txn
)
def claim_e2e_one_time_keys(self, query_list):
"""Take a list of one time keys out of the database"""
def _claim_e2e_one_time_keys(txn):
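
Aside: the no-change check in set_e2e_device_keys relies on canonical JSON serialisation being stable. A small illustration (the key dicts are made-up examples):

# Two logically identical device-key dicts canonicalise to the same JSON, so
# re-uploading unchanged keys is detected as "no change" and skipped.
from canonicaljson import encode_canonical_json

old_keys = {"device_id": "ABCDEFGHIJ", "algorithms": ["m.olm.v1.curve25519-aes-sha2"]}
new_keys = {"algorithms": ["m.olm.v1.curve25519-aes-sha2"], "device_id": "ABCDEFGHIJ"}

# encode_canonical_json returns bytes; decode to compare against a text column.
assert encode_canonical_json(old_keys).decode("utf-8") == \
       encode_canonical_json(new_keys).decode("utf-8")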

View File

@@ -442,28 +442,6 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
event_results.reverse()
return event_results
@defer.inlineCallbacks
def get_successor_events(self, event_ids):
"""Fetch all events that have the given events as a prev event
Args:
event_ids (iterable[str])
Returns:
Deferred[list[str]]
"""
rows = yield self._simple_select_many_batch(
table="event_edges",
column="prev_event_id",
iterable=event_ids,
retcols=("event_id",),
desc="get_successor_events"
)
defer.returnValue([
row["event_id"] for row in rows
])
class EventFederationStore(EventFederationWorkerStore):
""" Responsible for storing and serving up the various graphs associated

View File

@@ -216,36 +216,28 @@ def filter_events_for_client(store, user_id, events, is_peeking=False,
@defer.inlineCallbacks
def filter_events_for_server(store, server_name, events, redact=True,
check_history_visibility_only=False):
"""Filter a list of events based on whether given server is allowed to
see them.
def filter_events_for_server(store, server_name, events):
# Whatever else we do, we need to check for senders which have requested
# erasure of their data.
erased_senders = yield store.are_users_erased(
(e.sender for e in events),
)
Args:
store (DataStore)
server_name (str)
events (iterable[FrozenEvent])
redact (bool): Whether to return a redacted version of the event, or
to filter them out entirely.
check_history_visibility_only (bool): Whether to only check the
history visibility, rather than things like if the sender has been
erased. This is used e.g. during pagination to decide whether to
backfill or not.
Returns
Deferred[list[FrozenEvent]]
"""
def is_sender_erased(event, erased_senders):
if erased_senders and erased_senders[event.sender]:
def redact_disallowed(event, state):
# if the sender has been gdpr17ed, always return a redacted
# copy of the event.
if erased_senders[event.sender]:
logger.info(
"Sender of %s has been erased, redacting",
event.event_id,
)
return True
return False
return prune_event(event)
# state will be None if we decided we didn't need to filter by
# room membership.
if not state:
return event
def check_event_is_visible(event, state):
history = state.get((EventTypes.RoomHistoryVisibility, ''), None)
if history:
visibility = history.content.get("history_visibility", "shared")
@@ -267,17 +259,17 @@ def filter_events_for_server(store, server_name, events, redact=True,
memtype = ev.membership
if memtype == Membership.JOIN:
return True
return event
elif memtype == Membership.INVITE:
if visibility == "invited":
return True
return event
else:
# server has no users in the room: redact
return False
return prune_event(event)
return True
return event
# Lets check to see if all the events have a history visibility
# Next lets check to see if all the events have a history visibility
# of "shared" or "world_readable". If thats the case then we don't
# need to check membership (as we know the server is in the room).
event_to_state_ids = yield store.get_state_ids_for_events(
@@ -304,31 +296,16 @@ def filter_events_for_server(store, server_name, events, redact=True,
for e in itervalues(event_map)
)
if not check_history_visibility_only:
erased_senders = yield store.are_users_erased(
(e.sender for e in events),
)
else:
# We don't want to check whether users are erased, which is equivalent
# to no users having been erased.
erased_senders = {}
if all_open:
# all the history_visibility state affecting these events is open, so
# we don't need to filter by membership state. We *do* need to check
# for user erasure, though.
if erased_senders:
to_return = []
for e in events:
if not is_sender_erased(e, erased_senders):
to_return.append(e)
elif redact:
to_return.append(prune_event(e))
events = [
redact_disallowed(e, None)
for e in events
]
defer.returnValue(to_return)
# If there are no erased users then we can just return the given list
# of events without having to copy it.
defer.returnValue(events)
# Ok, so we're dealing with events that have non-trivial visibility
@@ -384,13 +361,7 @@ def filter_events_for_server(store, server_name, events, redact=True,
for e_id, key_to_eid in iteritems(event_to_state_ids)
}
to_return = []
for e in events:
erased = is_sender_erased(e, erased_senders)
visible = check_event_is_visible(e, event_to_state[e.event_id])
if visible and not erased:
to_return.append(e)
elif redact:
to_return.append(prune_event(e))
defer.returnValue(to_return)
defer.returnValue([
redact_disallowed(e, event_to_state[e.event_id])
for e in events
])
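
Aside: in the version shown here, filter_events_for_server takes just (store, server_name, events) and returns the list with disallowed events replaced by pruned copies. A hypothetical call-site sketch (the wrapper name is an assumption):

# Hypothetical sketch: filter a batch of events before sending them to a
# remote homeserver over federation.
from twisted.internet import defer

from synapse.visibility import filter_events_for_server

@defer.inlineCallbacks
def prepare_events_for_remote(store, server_name, events):
    sendable = yield filter_events_for_server(store, server_name, events)
    defer.returnValue(sendable)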