1
0

Compare commits

...

16 Commits

Author SHA1 Message Date
Neil Johnson a17421fb23 towncrier 2019-03-28 17:27:53 +00:00
Neil Johnson 4f1d7d29a6 remove log line for password 2019-03-28 17:16:50 +00:00
Andrew Morgan 7a91b9d81c Allow password providers to bind emails (#4947)
This PR allows password provider modules to bind email addresses when a user is registering and is motivated by matrix-org/matrix-synapse-ldap3#58
2019-03-28 15:48:07 +00:00
Erik Johnston 248014379e Merge pull request #4942 from matrix-org/erikj/fix_presence
Use event streams to calculate presence
2019-03-28 14:38:31 +00:00
Erik Johnston 4e5f0f7ca0 Use an assert 2019-03-28 14:05:05 +00:00
Erik Johnston 40e56997bc Review comments 2019-03-28 13:48:41 +00:00
Richard van der Hoff 4eeb2c2f07 Merge pull request #4953 from matrix-org/rav/refactor_replication_streams
Split up replication.tcp.streams into smaller files
2019-03-28 13:43:25 +00:00
Amber Brown 2e060774ad Run black on some storage modules that the stats branch touches (#4959) 2019-03-29 00:37:16 +11:00
Richard van der Hoff 91c3513668 changelog 2019-03-27 21:30:01 +00:00
Richard van der Hoff 71dcb275f1 move FederationStream out to its own file 2019-03-27 21:13:14 +00:00
Richard van der Hoff aa1e017864 move EventsStream out to its own file 2019-03-27 21:13:14 +00:00
Richard van der Hoff a5798de067 Move replication.tcp.streams into a package 2019-03-27 21:13:14 +00:00
Richard van der Hoff acaa18f7dd Fix/improve some docstrings in the replication code. (#4949) 2019-03-27 21:12:36 +00:00
Erik Johnston d5a5d1c632 Newsfile 2019-03-27 13:41:36 +00:00
Erik Johnston b7fa834c40 Add unit tests 2019-03-27 13:41:36 +00:00
Erik Johnston 197fae1639 Use event streams to calculate presence
Primarily this fixes a bug in the handling of remote users joining a
room where the server sent out the presence for all local users in the
room to all servers in the room.

We also change to using the state delta stream, rather than the
distributor, as it will make it easier to split processing out of the
master process (as well as being more flexible).

Finally, when sending presence states to newly joined servers we filter
out old presence states to reduce the number sent. Initially we filter
out states that are offline and have a last active more than a week ago,
though this can be changed down the line.

Fixes #3962
2019-03-27 13:41:36 +00:00
23 changed files with 928 additions and 503 deletions
+1
View File
@@ -0,0 +1 @@
Fix bug where presence updates were sent to all servers in a room when a new server joined, rather than to just the new server.
+1
View File
@@ -0,0 +1 @@
Add ability for password provider modules to bind email addresses to users upon registration.
+1
View File
@@ -0,0 +1 @@
Fix/improve some docstrings in the replication code.
+2
View File
@@ -0,0 +1,2 @@
Split synapse.replication.tcp.streams into smaller files.
+1
View File
@@ -0,0 +1 @@
Run `black` to clean up formatting on `synapse/storage/roommember.py` and `synapse/storage/events.py`.
+1
View File
@@ -0,0 +1 @@
Remove log line for password via the admin API.
+1 -1
View File
@@ -38,7 +38,7 @@ from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.slave.storage.transactions import SlavedTransactionStore
from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.replication.tcp.streams import ReceiptsStream
from synapse.replication.tcp.streams._base import ReceiptsStream
from synapse.server import HomeServer
from synapse.storage.engines import create_engine
from synapse.types import ReadReceipt
+71 -2
View File
@@ -55,7 +55,12 @@ class FederationRemoteSendQueue(object):
self.is_mine_id = hs.is_mine_id
self.presence_map = {} # Pending presence map user_id -> UserPresenceState
self.presence_changed = SortedDict() # Stream position -> user_id
self.presence_changed = SortedDict() # Stream position -> list[user_id]
# Stores the destinations we need to explicitly send presence to about a
# given user.
# Stream position -> (user_id, destinations)
self.presence_destinations = SortedDict()
self.keyed_edu = {} # (destination, key) -> EDU
self.keyed_edu_changed = SortedDict() # stream position -> (destination, key)
@@ -77,7 +82,7 @@ class FederationRemoteSendQueue(object):
for queue_name in [
"presence_map", "presence_changed", "keyed_edu", "keyed_edu_changed",
"edus", "device_messages", "pos_time",
"edus", "device_messages", "pos_time", "presence_destinations",
]:
register(queue_name, getattr(self, queue_name))
@@ -121,6 +126,15 @@ class FederationRemoteSendQueue(object):
for user_id in uids
)
keys = self.presence_destinations.keys()
i = self.presence_destinations.bisect_left(position_to_delete)
for key in keys[:i]:
del self.presence_destinations[key]
user_ids.update(
user_id for user_id, _ in self.presence_destinations.values()
)
to_del = [
user_id for user_id in self.presence_map if user_id not in user_ids
]
@@ -209,6 +223,20 @@ class FederationRemoteSendQueue(object):
self.notifier.on_new_replication_data()
def send_presence_to_destinations(self, states, destinations):
"""As per FederationSender
Args:
states (list[UserPresenceState])
destinations (list[str])
"""
for state in states:
pos = self._next_pos()
self.presence_map.update({state.user_id: state for state in states})
self.presence_destinations[pos] = (state.user_id, destinations)
self.notifier.on_new_replication_data()
def send_device_messages(self, destination):
"""As per FederationSender"""
pos = self._next_pos()
@@ -261,6 +289,16 @@ class FederationRemoteSendQueue(object):
state=self.presence_map[user_id],
)))
# Fetch presence to send to destinations
i = self.presence_destinations.bisect_right(from_token)
j = self.presence_destinations.bisect_right(to_token) + 1
for pos, (user_id, dests) in self.presence_destinations.items()[i:j]:
rows.append((pos, PresenceDestinationsRow(
state=self.presence_map[user_id],
destinations=list(dests),
)))
# Fetch changes keyed edus
i = self.keyed_edu_changed.bisect_right(from_token)
j = self.keyed_edu_changed.bisect_right(to_token) + 1
@@ -357,6 +395,29 @@ class PresenceRow(BaseFederationRow, namedtuple("PresenceRow", (
buff.presence.append(self.state)
class PresenceDestinationsRow(BaseFederationRow, namedtuple("PresenceDestinationsRow", (
"state", # UserPresenceState
"destinations", # list[str]
))):
TypeId = "pd"
@staticmethod
def from_data(data):
return PresenceDestinationsRow(
state=UserPresenceState.from_dict(data["state"]),
destinations=data["dests"],
)
def to_data(self):
return {
"state": self.state.as_dict(),
"dests": self.destinations,
}
def add_to_buffer(self, buff):
buff.presence_destinations.append((self.state, self.destinations))
class KeyedEduRow(BaseFederationRow, namedtuple("KeyedEduRow", (
"key", # tuple(str) - the edu key passed to send_edu
"edu", # Edu
@@ -428,6 +489,7 @@ TypeToRow = {
Row.TypeId: Row
for Row in (
PresenceRow,
PresenceDestinationsRow,
KeyedEduRow,
EduRow,
DeviceRow,
@@ -437,6 +499,7 @@ TypeToRow = {
ParsedFederationStreamData = namedtuple("ParsedFederationStreamData", (
"presence", # list(UserPresenceState)
"presence_destinations", # list of tuples of UserPresenceState and destinations
"keyed_edus", # dict of destination -> { key -> Edu }
"edus", # dict of destination -> [Edu]
"device_destinations", # set of destinations
@@ -458,6 +521,7 @@ def process_rows_for_federation(transaction_queue, rows):
buff = ParsedFederationStreamData(
presence=[],
presence_destinations=[],
keyed_edus={},
edus={},
device_destinations=set(),
@@ -476,6 +540,11 @@ def process_rows_for_federation(transaction_queue, rows):
if buff.presence:
transaction_queue.send_presence(buff.presence)
for state, destinations in buff.presence_destinations:
transaction_queue.send_presence_to_destinations(
states=[state], destinations=destinations,
)
for destination, edu_map in iteritems(buff.keyed_edus):
for key, edu in edu_map.items():
transaction_queue.send_edu(edu, key)
+18 -1
View File
@@ -371,7 +371,7 @@ class FederationSender(object):
return
# First we queue up the new presence by user ID, so multiple presence
# updates in quick successtion are correctly handled
# updates in quick succession are correctly handled.
# We only want to send presence for our own users, so lets always just
# filter here just in case.
self.pending_presence.update({
@@ -402,6 +402,23 @@ class FederationSender(object):
finally:
self._processing_pending_presence = False
def send_presence_to_destinations(self, states, destinations):
"""Send the given presence states to the given destinations.
Args:
states (list[UserPresenceState])
destinations (list[str])
"""
if not states or not self.hs.config.use_presence:
# No-op if presence is disabled.
return
for destination in destinations:
if destination == self.server_name:
continue
self._get_per_destination_queue(destination).send_presence(states)
@measure_func("txnqueue._process_presence")
@defer.inlineCallbacks
def _process_presence_inner(self, states):
+147 -29
View File
@@ -31,9 +31,11 @@ from prometheus_client import Counter
from twisted.internet import defer
from synapse.api.constants import PresenceState
import synapse.metrics
from synapse.api.constants import EventTypes, Membership, PresenceState
from synapse.api.errors import SynapseError
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.presence import UserPresenceState
from synapse.types import UserID, get_domain_from_id
from synapse.util.async_helpers import Linearizer
@@ -98,6 +100,7 @@ class PresenceHandler(object):
self.hs = hs
self.is_mine = hs.is_mine
self.is_mine_id = hs.is_mine_id
self.server_name = hs.hostname
self.clock = hs.get_clock()
self.store = hs.get_datastore()
self.wheel_timer = WheelTimer()
@@ -132,9 +135,6 @@ class PresenceHandler(object):
)
)
distributor = hs.get_distributor()
distributor.observe("user_joined_room", self.user_joined_room)
active_presence = self.store.take_presence_startup_info()
# A dictionary of the current state of users. This is prefilled with
@@ -220,6 +220,15 @@ class PresenceHandler(object):
LaterGauge("synapse_handlers_presence_wheel_timer_size", "", [],
lambda: len(self.wheel_timer))
# Used to handle sending of presence to newly joined users/servers
if hs.config.use_presence:
self.notifier.add_replication_callback(self.notify_new_event)
# Presence is best effort and quickly heals itself, so lets just always
# stream from the current state when we restart.
self._event_pos = self.store.get_current_events_token()
self._event_processing = False
@defer.inlineCallbacks
def _on_shutdown(self):
"""Gets called when shutting down. This lets us persist any updates that
@@ -750,31 +759,6 @@ class PresenceHandler(object):
yield self._update_states([prev_state.copy_and_replace(**new_fields)])
@defer.inlineCallbacks
def user_joined_room(self, user, room_id):
"""Called (via the distributor) when a user joins a room. This funciton
sends presence updates to servers, either:
1. the joining user is a local user and we send their presence to
all servers in the room.
2. the joining user is a remote user and so we send presence for all
local users in the room.
"""
# We only need to send presence to servers that don't have it yet. We
# don't need to send to local clients here, as that is done as part
# of the event stream/sync.
# TODO: Only send to servers not already in the room.
if self.is_mine(user):
state = yield self.current_state_for_user(user.to_string())
self._push_to_remotes([state])
else:
user_ids = yield self.store.get_users_in_room(room_id)
user_ids = list(filter(self.is_mine_id, user_ids))
states = yield self.current_state_for_users(user_ids)
self._push_to_remotes(list(states.values()))
@defer.inlineCallbacks
def get_presence_list(self, observer_user, accepted=None):
"""Returns the presence for all users in their presence list.
@@ -945,6 +929,140 @@ class PresenceHandler(object):
rows = yield self.store.get_all_presence_updates(last_id, current_id)
defer.returnValue(rows)
def notify_new_event(self):
"""Called when new events have happened. Handles users and servers
joining rooms and require being sent presence.
"""
if self._event_processing:
return
@defer.inlineCallbacks
def _process_presence():
assert not self._event_processing
self._event_processing = True
try:
yield self._unsafe_process()
finally:
self._event_processing = False
run_as_background_process("presence.notify_new_event", _process_presence)
@defer.inlineCallbacks
def _unsafe_process(self):
# Loop round handling deltas until we're up to date
while True:
with Measure(self.clock, "presence_delta"):
deltas = yield self.store.get_current_state_deltas(self._event_pos)
if not deltas:
return
yield self._handle_state_delta(deltas)
self._event_pos = deltas[-1]["stream_id"]
# Expose current event processing position to prometheus
synapse.metrics.event_processing_positions.labels("presence").set(
self._event_pos
)
@defer.inlineCallbacks
def _handle_state_delta(self, deltas):
"""Process current state deltas to find new joins that need to be
handled.
"""
for delta in deltas:
typ = delta["type"]
state_key = delta["state_key"]
room_id = delta["room_id"]
event_id = delta["event_id"]
prev_event_id = delta["prev_event_id"]
logger.debug("Handling: %r %r, %s", typ, state_key, event_id)
if typ != EventTypes.Member:
continue
event = yield self.store.get_event(event_id)
if event.content.get("membership") != Membership.JOIN:
# We only care about joins
continue
if prev_event_id:
prev_event = yield self.store.get_event(prev_event_id)
if prev_event.content.get("membership") == Membership.JOIN:
# Ignore changes to join events.
continue
yield self._on_user_joined_room(room_id, state_key)
@defer.inlineCallbacks
def _on_user_joined_room(self, room_id, user_id):
"""Called when we detect a user joining the room via the current state
delta stream.
Args:
room_id (str)
user_id (str)
Returns:
Deferred
"""
if self.is_mine_id(user_id):
# If this is a local user then we need to send their presence
# out to hosts in the room (who don't already have it)
# TODO: We should be able to filter the hosts down to those that
# haven't previously seen the user
state = yield self.current_state_for_user(user_id)
hosts = yield self.state.get_current_hosts_in_room(room_id)
# Filter out ourselves.
hosts = set(host for host in hosts if host != self.server_name)
self.federation.send_presence_to_destinations(
states=[state],
destinations=hosts,
)
else:
# A remote user has joined the room, so we need to:
# 1. Check if this is a new server in the room
# 2. If so send any presence they don't already have for
# local users in the room.
# TODO: We should be able to filter the users down to those that
# the server hasn't previously seen
# TODO: Check that this is actually a new server joining the
# room.
user_ids = yield self.state.get_current_user_in_room(room_id)
user_ids = list(filter(self.is_mine_id, user_ids))
states = yield self.current_state_for_users(user_ids)
# Filter out old presence, i.e. offline presence states where
# the user hasn't been active for a week. We can change this
# depending on what we want the UX to be, but at the least we
# should filter out offline presence where the state is just the
# default state.
now = self.clock.time_msec()
states = [
state for state in states.values()
if state.state != PresenceState.OFFLINE
or now - state.last_active_ts < 7 * 24 * 60 * 60 * 1000
or state.status_msg is not None
]
if states:
self.federation.send_presence_to_destinations(
states=states,
destinations=[get_domain_from_id(user_id)],
)
def should_notify(old_state, new_state):
"""Decides if a presence state change should be sent to interested parties.
+17
View File
@@ -153,6 +153,7 @@ class RegistrationHandler(BaseHandler):
user_type=None,
default_display_name=None,
address=None,
bind_emails=[],
):
"""Registers a new client on the server.
@@ -172,6 +173,7 @@ class RegistrationHandler(BaseHandler):
default_display_name (unicode|None): if set, the new user's displayname
will be set to this. Defaults to 'localpart'.
address (str|None): the IP address used to perform the registration.
bind_emails (List[str]): list of emails to bind to this account.
Returns:
A tuple of (user_id, access_token).
Raises:
@@ -261,6 +263,21 @@ class RegistrationHandler(BaseHandler):
if not self.hs.config.user_consent_at_registration:
yield self._auto_join_rooms(user_id)
# Bind any specified emails to this account
current_time = self.hs.get_clock().time_msec()
for email in bind_emails:
# generate threepid dict
threepid_dict = {
"medium": "email",
"address": email,
"validated_at": current_time,
}
# Bind email to new account
yield self._register_email_threepid(
user_id, threepid_dict, None, False,
)
defer.returnValue((user_id, token))
@defer.inlineCallbacks
+5 -4
View File
@@ -74,14 +74,14 @@ class ModuleApi(object):
return self._auth_handler.check_user_exists(user_id)
@defer.inlineCallbacks
def register(self, localpart, displayname=None):
def register(self, localpart, displayname=None, emails=[]):
"""Registers a new user with given localpart and optional
displayname.
displayname, emails.
Args:
localpart (str): The localpart of the new user.
displayname (str|None): The displayname of the new user. If None,
the user's displayname will default to `localpart`.
displayname (str|None): The displayname of the new user.
emails (List[str]): Emails to bind to the new user.
Returns:
Deferred: a 2-tuple of (user_id, access_token)
@@ -90,6 +90,7 @@ class ModuleApi(object):
reg = self.hs.get_registration_handler()
user_id, access_token = yield reg.register(
localpart=localpart, default_display_name=displayname,
bind_emails=emails,
)
defer.returnValue((user_id, access_token))
+11 -3
View File
@@ -103,10 +103,18 @@ class ReplicationClientHandler(object):
hs.get_reactor().connectTCP(host, port, self.factory)
def on_rdata(self, stream_name, token, rows):
"""Called when we get new replication data. By default this just pokes
the slave store.
"""Called to handle a batch of replication data with a given stream token.
Can be overriden in subclasses to handle more.
By default this just pokes the slave store. Can be overriden in subclasses to
handle more.
Args:
stream_name (str): name of the replication stream for this batch of rows
token (int): stream token for this batch of rows
rows (list): a list of Stream.ROW_TYPE objects.
Returns:
Deferred|None
"""
logger.debug("Received rdata %s -> %s", stream_name, token)
return self.store.process_replication_rows(stream_name, token, rows)
+2 -1
View File
@@ -30,7 +30,8 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util.metrics import Measure, measure_func
from .protocol import ServerReplicationStreamProtocol
from .streams import STREAMS_MAP, FederationStream
from .streams import STREAMS_MAP
from .streams.federation import FederationStream
stream_updates_counter = Counter("synapse_replication_tcp_resource_stream_updates",
"", ["stream_name"])
@@ -0,0 +1,50 @@
# -*- coding: utf-8 -*-
# Copyright 2017 Vector Creations Ltd
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines all the valid streams that clients can subscribe to, and the format
of the rows returned by each stream.
Each stream is defined by the following information:
stream name: The name of the stream
row type: The type that is used to serialise/deserialse the row
current_token: The function that returns the current token for the stream
update_function: The function that returns a list of updates between two tokens
"""
from . import _base, events, federation
STREAMS_MAP = {
stream.NAME: stream
for stream in (
events.EventsStream,
_base.BackfillStream,
_base.PresenceStream,
_base.TypingStream,
_base.ReceiptsStream,
_base.PushRulesStream,
_base.PushersStream,
_base.CachesStream,
_base.PublicRoomsStream,
_base.DeviceListsStream,
_base.ToDeviceStream,
federation.FederationStream,
_base.TagAccountDataStream,
_base.AccountDataStream,
_base.CurrentStateDeltaStream,
_base.GroupServerStream,
)
}
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2017 Vector Creations Ltd
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,16 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines all the valid streams that clients can subscribe to, and the format
of the rows returned by each stream.
Each stream is defined by the following information:
stream name: The name of the stream
row type: The type that is used to serialise/deserialse the row
current_token: The function that returns the current token for the stream
update_function: The function that returns a list of updates between two tokens
"""
import itertools
import logging
from collections import namedtuple
@@ -34,14 +26,6 @@ logger = logging.getLogger(__name__)
MAX_EVENTS_BEHIND = 10000
EventStreamRow = namedtuple("EventStreamRow", (
"event_id", # str
"room_id", # str
"type", # str
"state_key", # str, optional
"redacts", # str, optional
))
BackfillStreamRow = namedtuple("BackfillStreamRow", (
"event_id", # str
"room_id", # str
@@ -96,10 +80,6 @@ DeviceListsStreamRow = namedtuple("DeviceListsStreamRow", (
ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", (
"entity", # str
))
FederationStreamRow = namedtuple("FederationStreamRow", (
"type", # str, the type of data as defined in the BaseFederationRows
"data", # dict, serialization of a federation.send_queue.BaseFederationRow
))
TagAccountDataStreamRow = namedtuple("TagAccountDataStreamRow", (
"user_id", # str
"room_id", # str
@@ -162,8 +142,10 @@ class Stream(object):
until the `upto_token`
Returns:
(list(ROW_TYPE), int): list of updates plus the token used as an
upper bound of the updates (i.e. the "current token")
Deferred[Tuple[List[Tuple[int, Any]], int]:
Resolves to a pair ``(updates, current_token)``, where ``updates`` is a
list of ``(token, row)`` entries. ``row`` will be json-serialised and
sent over the replication steam.
"""
updates, current_token = yield self.get_updates_since(self.last_token)
self.last_token = current_token
@@ -176,8 +158,10 @@ class Stream(object):
stream updates
Returns:
(list(ROW_TYPE), int): list of updates plus the token used as an
upper bound of the updates (i.e. the "current token")
Deferred[Tuple[List[Tuple[int, Any]], int]:
Resolves to a pair ``(updates, current_token)``, where ``updates`` is a
list of ``(token, row)`` entries. ``row`` will be json-serialised and
sent over the replication steam.
"""
if from_token in ("NOW", "now"):
defer.returnValue(([], self.upto_token))
@@ -232,20 +216,6 @@ class Stream(object):
raise NotImplementedError()
class EventsStream(Stream):
"""We received a new event, or an event went from being an outlier to not
"""
NAME = "events"
ROW_TYPE = EventStreamRow
def __init__(self, hs):
store = hs.get_datastore()
self.current_token = store.get_current_events_token
self.update_function = store.get_all_new_forward_event_rows
super(EventsStream, self).__init__(hs)
class BackfillStream(Stream):
"""We fetched some old events and either we had never seen that event before
or it went from being an outlier to not.
@@ -400,22 +370,6 @@ class ToDeviceStream(Stream):
super(ToDeviceStream, self).__init__(hs)
class FederationStream(Stream):
"""Data to be sent over federation. Only available when master has federation
sending disabled.
"""
NAME = "federation"
ROW_TYPE = FederationStreamRow
def __init__(self, hs):
federation_sender = hs.get_federation_sender()
self.current_token = federation_sender.get_current_token
self.update_function = federation_sender.get_replication_rows
super(FederationStream, self).__init__(hs)
class TagAccountDataStream(Stream):
"""Someone added/removed a tag for a room
"""
@@ -485,26 +439,3 @@ class GroupServerStream(Stream):
self.update_function = store.get_all_groups_changes
super(GroupServerStream, self).__init__(hs)
STREAMS_MAP = {
stream.NAME: stream
for stream in (
EventsStream,
BackfillStream,
PresenceStream,
TypingStream,
ReceiptsStream,
PushRulesStream,
PushersStream,
CachesStream,
PublicRoomsStream,
DeviceListsStream,
ToDeviceStream,
FederationStream,
TagAccountDataStream,
AccountDataStream,
CurrentStateDeltaStream,
GroupServerStream,
)
}
+40
View File
@@ -0,0 +1,40 @@
# -*- coding: utf-8 -*-
# Copyright 2017 Vector Creations Ltd
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
from ._base import Stream
EventStreamRow = namedtuple("EventStreamRow", (
"event_id", # str
"room_id", # str
"type", # str
"state_key", # str, optional
"redacts", # str, optional
))
class EventsStream(Stream):
"""We received a new event, or an event went from being an outlier to not
"""
NAME = "events"
ROW_TYPE = EventStreamRow
def __init__(self, hs):
store = hs.get_datastore()
self.current_token = store.get_current_events_token
self.update_function = store.get_all_new_forward_event_rows
super(EventsStream, self).__init__(hs)
@@ -0,0 +1,39 @@
# -*- coding: utf-8 -*-
# Copyright 2017 Vector Creations Ltd
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import namedtuple
from ._base import Stream
FederationStreamRow = namedtuple("FederationStreamRow", (
"type", # str, the type of data as defined in the BaseFederationRows
"data", # dict, serialization of a federation.send_queue.BaseFederationRow
))
class FederationStream(Stream):
"""Data to be sent over federation. Only available when master has federation
sending disabled.
"""
NAME = "federation"
ROW_TYPE = FederationStreamRow
def __init__(self, hs):
federation_sender = hs.get_federation_sender()
self.current_token = federation_sender.get_current_token
self.update_function = federation_sender.get_replication_rows
super(FederationStream, self).__init__(hs)
-2
View File
@@ -647,8 +647,6 @@ class ResetPasswordRestServlet(ClientV1RestServlet):
assert_params_in_dict(params, ["new_password"])
new_password = params['new_password']
logger.info("new_password: %r", new_password)
yield self._set_password_handler.set_password(
target_user_id, new_password, requester
)
+248 -294
View File
File diff suppressed because it is too large Load Diff
+87 -86
View File
@@ -35,28 +35,22 @@ logger = logging.getLogger(__name__)
RoomsForUser = namedtuple(
"RoomsForUser",
("room_id", "sender", "membership", "event_id", "stream_ordering")
"RoomsForUser", ("room_id", "sender", "membership", "event_id", "stream_ordering")
)
GetRoomsForUserWithStreamOrdering = namedtuple(
"_GetRoomsForUserWithStreamOrdering",
("room_id", "stream_ordering",)
"_GetRoomsForUserWithStreamOrdering", ("room_id", "stream_ordering")
)
# We store this using a namedtuple so that we save about 3x space over using a
# dict.
ProfileInfo = namedtuple(
"ProfileInfo", ("avatar_url", "display_name")
)
ProfileInfo = namedtuple("ProfileInfo", ("avatar_url", "display_name"))
# "members" points to a truncated list of (user_id, event_id) tuples for users of
# a given membership type, suitable for use in calculating heroes for a room.
# "count" points to the total numberr of users of a given membership type.
MemberSummary = namedtuple(
"MemberSummary", ("members", "count")
)
MemberSummary = namedtuple("MemberSummary", ("members", "count"))
_MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update"
@@ -67,7 +61,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
"""Returns the set of all hosts currently in the room
"""
user_ids = yield self.get_users_in_room(
room_id, on_invalidate=cache_context.invalidate,
room_id, on_invalidate=cache_context.invalidate
)
hosts = frozenset(get_domain_from_id(user_id) for user_id in user_ids)
defer.returnValue(hosts)
@@ -84,8 +78,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
" WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ?"
)
txn.execute(sql, (room_id, Membership.JOIN,))
txn.execute(sql, (room_id, Membership.JOIN))
return [to_ascii(r[0]) for r in txn]
return self.runInteraction("get_users_in_room", f)
@cached(max_entries=100000)
@@ -156,9 +151,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
A deferred list of RoomsForUser.
"""
return self.get_rooms_for_user_where_membership_is(
user_id, [Membership.INVITE]
)
return self.get_rooms_for_user_where_membership_is(user_id, [Membership.INVITE])
@defer.inlineCallbacks
def get_invite_for_user_in_room(self, user_id, room_id):
@@ -196,11 +189,13 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return self.runInteraction(
"get_rooms_for_user_where_membership_is",
self._get_rooms_for_user_where_membership_is_txn,
user_id, membership_list
user_id,
membership_list,
)
def _get_rooms_for_user_where_membership_is_txn(self, txn, user_id,
membership_list):
def _get_rooms_for_user_where_membership_is_txn(
self, txn, user_id, membership_list
):
do_invite = Membership.INVITE in membership_list
membership_list = [m for m in membership_list if m != Membership.INVITE]
@@ -227,9 +222,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
) % (where_clause,)
txn.execute(sql, args)
results = [
RoomsForUser(**r) for r in self.cursor_to_dict(txn)
]
results = [RoomsForUser(**r) for r in self.cursor_to_dict(txn)]
if do_invite:
sql = (
@@ -241,13 +234,16 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
txn.execute(sql, (user_id,))
results.extend(RoomsForUser(
room_id=r["room_id"],
sender=r["inviter"],
event_id=r["event_id"],
stream_ordering=r["stream_ordering"],
membership=Membership.INVITE,
) for r in self.cursor_to_dict(txn))
results.extend(
RoomsForUser(
room_id=r["room_id"],
sender=r["inviter"],
event_id=r["event_id"],
stream_ordering=r["stream_ordering"],
membership=Membership.INVITE,
)
for r in self.cursor_to_dict(txn)
)
return results
@@ -264,19 +260,21 @@ class RoomMemberWorkerStore(EventsWorkerStore):
of the most recent join for that user and room.
"""
rooms = yield self.get_rooms_for_user_where_membership_is(
user_id, membership_list=[Membership.JOIN],
user_id, membership_list=[Membership.JOIN]
)
defer.returnValue(
frozenset(
GetRoomsForUserWithStreamOrdering(r.room_id, r.stream_ordering)
for r in rooms
)
)
defer.returnValue(frozenset(
GetRoomsForUserWithStreamOrdering(r.room_id, r.stream_ordering)
for r in rooms
))
@defer.inlineCallbacks
def get_rooms_for_user(self, user_id, on_invalidate=None):
"""Returns a set of room_ids the user is currently joined to
"""
rooms = yield self.get_rooms_for_user_with_stream_ordering(
user_id, on_invalidate=on_invalidate,
user_id, on_invalidate=on_invalidate
)
defer.returnValue(frozenset(r.room_id for r in rooms))
@@ -285,13 +283,13 @@ class RoomMemberWorkerStore(EventsWorkerStore):
"""Returns the set of users who share a room with `user_id`
"""
room_ids = yield self.get_rooms_for_user(
user_id, on_invalidate=cache_context.invalidate,
user_id, on_invalidate=cache_context.invalidate
)
user_who_share_room = set()
for room_id in room_ids:
user_ids = yield self.get_users_in_room(
room_id, on_invalidate=cache_context.invalidate,
room_id, on_invalidate=cache_context.invalidate
)
user_who_share_room.update(user_ids)
@@ -309,9 +307,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
current_state_ids = yield context.get_current_state_ids(self)
result = yield self._get_joined_users_from_context(
event.room_id, state_group, current_state_ids,
event=event,
context=context,
event.room_id, state_group, current_state_ids, event=event, context=context
)
defer.returnValue(result)
@@ -325,13 +321,21 @@ class RoomMemberWorkerStore(EventsWorkerStore):
state_group = object()
return self._get_joined_users_from_context(
room_id, state_group, state_entry.state, context=state_entry,
room_id, state_group, state_entry.state, context=state_entry
)
@cachedInlineCallbacks(num_args=2, cache_context=True, iterable=True,
max_entries=100000)
def _get_joined_users_from_context(self, room_id, state_group, current_state_ids,
cache_context, event=None, context=None):
@cachedInlineCallbacks(
num_args=2, cache_context=True, iterable=True, max_entries=100000
)
def _get_joined_users_from_context(
self,
room_id,
state_group,
current_state_ids,
cache_context,
event=None,
context=None,
):
# We don't use `state_group`, it's there so that we can cache based
# on it. However, it's important that it's never None, since two current_states
# with a state_group of None are likely to be different.
@@ -371,9 +375,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# the hit ratio counts. After all, we don't populate the cache if we
# miss it here
event_map = self._get_events_from_cache(
member_event_ids,
allow_rejected=False,
update_metrics=False,
member_event_ids, allow_rejected=False, update_metrics=False
)
missing_member_event_ids = []
@@ -397,21 +399,21 @@ class RoomMemberWorkerStore(EventsWorkerStore):
table="room_memberships",
column="event_id",
iterable=missing_member_event_ids,
retcols=('user_id', 'display_name', 'avatar_url',),
keyvalues={
"membership": Membership.JOIN,
},
retcols=('user_id', 'display_name', 'avatar_url'),
keyvalues={"membership": Membership.JOIN},
batch_size=500,
desc="_get_joined_users_from_context",
)
users_in_room.update({
to_ascii(row["user_id"]): ProfileInfo(
avatar_url=to_ascii(row["avatar_url"]),
display_name=to_ascii(row["display_name"]),
)
for row in rows
})
users_in_room.update(
{
to_ascii(row["user_id"]): ProfileInfo(
avatar_url=to_ascii(row["avatar_url"]),
display_name=to_ascii(row["display_name"]),
)
for row in rows
}
)
if event is not None and event.type == EventTypes.Member:
if event.membership == Membership.JOIN:
@@ -505,7 +507,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
state_group = object()
return self._get_joined_hosts(
room_id, state_group, state_entry.state, state_entry=state_entry,
room_id, state_group, state_entry.state, state_entry=state_entry
)
@cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True)
@@ -531,6 +533,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
"""Returns whether user_id has elected to discard history for room_id.
Returns False if they have since re-joined."""
def f(txn):
sql = (
"SELECT"
@@ -547,6 +550,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
txn.execute(sql, (user_id, room_id))
rows = txn.fetchall()
return rows[0][0]
count = yield self.runInteraction("did_forget_membership", f)
defer.returnValue(count == 0)
@@ -575,13 +579,14 @@ class RoomMemberStore(RoomMemberWorkerStore):
"avatar_url": event.content.get("avatar_url", None),
}
for event in events
]
],
)
for event in events:
txn.call_after(
self._membership_stream_cache.entity_has_changed,
event.state_key, event.internal_metadata.stream_ordering
event.state_key,
event.internal_metadata.stream_ordering,
)
txn.call_after(
self.get_invited_rooms_for_user.invalidate, (event.state_key,)
@@ -607,7 +612,7 @@ class RoomMemberStore(RoomMemberWorkerStore):
"inviter": event.sender,
"room_id": event.room_id,
"stream_id": event.internal_metadata.stream_ordering,
}
},
)
else:
sql = (
@@ -616,12 +621,15 @@ class RoomMemberStore(RoomMemberWorkerStore):
" AND replaced_by is NULL"
)
txn.execute(sql, (
event.internal_metadata.stream_ordering,
event.event_id,
event.room_id,
event.state_key,
))
txn.execute(
sql,
(
event.internal_metadata.stream_ordering,
event.event_id,
event.room_id,
event.state_key,
),
)
@defer.inlineCallbacks
def locally_reject_invite(self, user_id, room_id):
@@ -632,18 +640,14 @@ class RoomMemberStore(RoomMemberWorkerStore):
)
def f(txn, stream_ordering):
txn.execute(sql, (
stream_ordering,
True,
room_id,
user_id,
))
txn.execute(sql, (stream_ordering, True, room_id, user_id))
with self._stream_id_gen.get_next() as stream_ordering:
yield self.runInteraction("locally_reject_invite", f, stream_ordering)
def forget(self, user_id, room_id):
"""Indicate that user_id wishes to discard history for room_id."""
def f(txn):
sql = (
"UPDATE"
@@ -657,9 +661,8 @@ class RoomMemberStore(RoomMemberWorkerStore):
)
txn.execute(sql, (user_id, room_id))
self._invalidate_cache_and_stream(
txn, self.did_forget, (user_id, room_id,),
)
self._invalidate_cache_and_stream(txn, self.did_forget, (user_id, room_id))
return self.runInteraction("forget_membership", f)
@defer.inlineCallbacks
@@ -674,7 +677,7 @@ class RoomMemberStore(RoomMemberWorkerStore):
INSERT_CLUMP_SIZE = 1000
def add_membership_profile_txn(txn):
sql = ("""
sql = """
SELECT stream_ordering, event_id, events.room_id, event_json.json
FROM events
INNER JOIN event_json USING (event_id)
@@ -683,7 +686,7 @@ class RoomMemberStore(RoomMemberWorkerStore):
AND type = 'm.room.member'
ORDER BY stream_ordering DESC
LIMIT ?
""")
"""
txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
@@ -707,16 +710,14 @@ class RoomMemberStore(RoomMemberWorkerStore):
avatar_url = content.get("avatar_url", None)
if display_name or avatar_url:
to_update.append((
display_name, avatar_url, event_id, room_id
))
to_update.append((display_name, avatar_url, event_id, room_id))
to_update_sql = ("""
to_update_sql = """
UPDATE room_memberships SET display_name = ?, avatar_url = ?
WHERE event_id = ? AND room_id = ?
""")
"""
for index in range(0, len(to_update), INSERT_CLUMP_SIZE):
clump = to_update[index:index + INSERT_CLUMP_SIZE]
clump = to_update[index : index + INSERT_CLUMP_SIZE]
txn.executemany(to_update_sql, clump)
progress = {
@@ -789,7 +790,7 @@ class _JoinedHostsCache(object):
self.hosts_to_joined_users.pop(host, None)
else:
joined_users = yield self.store.get_joined_users_from_state(
self.room_id, state_entry,
self.room_id, state_entry
)
self.hosts_to_joined_users = {}
+175 -1
View File
@@ -16,7 +16,11 @@
from mock import Mock, call
from synapse.api.constants import PresenceState
from signedjson.key import generate_signing_key
from synapse.api.constants import EventTypes, Membership, PresenceState
from synapse.events import room_version_to_event_format
from synapse.events.builder import EventBuilder
from synapse.handlers.presence import (
FEDERATION_PING_INTERVAL,
FEDERATION_TIMEOUT,
@@ -26,7 +30,9 @@ from synapse.handlers.presence import (
handle_timeout,
handle_update,
)
from synapse.rest.client.v1 import room
from synapse.storage.presence import UserPresenceState
from synapse.types import UserID, get_domain_from_id
from tests import unittest
@@ -405,3 +411,171 @@ class PresenceTimeoutTestCase(unittest.TestCase):
self.assertIsNotNone(new_state)
self.assertEquals(state, new_state)
class PresenceJoinTestCase(unittest.HomeserverTestCase):
"""Tests remote servers get told about presence of users in the room when
they join and when new local users join.
"""
user_id = "@test:server"
servlets = [room.register_servlets]
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(
"server", http_client=None,
federation_sender=Mock(),
)
return hs
def prepare(self, reactor, clock, hs):
self.federation_sender = hs.get_federation_sender()
self.event_builder_factory = hs.get_event_builder_factory()
self.federation_handler = hs.get_handlers().federation_handler
self.presence_handler = hs.get_presence_handler()
# self.event_builder_for_2 = EventBuilderFactory(hs)
# self.event_builder_for_2.hostname = "test2"
self.store = hs.get_datastore()
self.state = hs.get_state_handler()
self.auth = hs.get_auth()
# We don't actually check signatures in tests, so lets just create a
# random key to use.
self.random_signing_key = generate_signing_key("ver")
def test_remote_joins(self):
# We advance time to something that isn't 0, as we use 0 as a special
# value.
self.reactor.advance(1000000000000)
# Create a room with two local users
room_id = self.helper.create_room_as(self.user_id)
self.helper.join(room_id, "@test2:server")
# Mark test2 as online, test will be offline with a last_active of 0
self.presence_handler.set_state(
UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE},
)
self.reactor.pump([0]) # Wait for presence updates to be handled
#
# Test that a new server gets told about existing presence
#
self.federation_sender.reset_mock()
# Add a new remote server to the room
self._add_new_user(room_id, "@alice:server2")
# We shouldn't have sent out any local presence *updates*
self.federation_sender.send_presence.assert_not_called()
# When new server is joined we send it the local users presence states.
# We expect to only see user @test2:server, as @test:server is offline
# and has a zero last_active_ts
expected_state = self.get_success(
self.presence_handler.current_state_for_user("@test2:server")
)
self.assertEqual(expected_state.state, PresenceState.ONLINE)
self.federation_sender.send_presence_to_destinations.assert_called_once_with(
destinations=["server2"], states=[expected_state]
)
#
# Test that only the new server gets sent presence and not existing servers
#
self.federation_sender.reset_mock()
self._add_new_user(room_id, "@bob:server3")
self.federation_sender.send_presence.assert_not_called()
self.federation_sender.send_presence_to_destinations.assert_called_once_with(
destinations=["server3"], states=[expected_state]
)
def test_remote_gets_presence_when_local_user_joins(self):
# We advance time to something that isn't 0, as we use 0 as a special
# value.
self.reactor.advance(1000000000000)
# Create a room with one local users
room_id = self.helper.create_room_as(self.user_id)
# Mark test as online
self.presence_handler.set_state(
UserID.from_string("@test:server"), {"presence": PresenceState.ONLINE},
)
# Mark test2 as online, test will be offline with a last_active of 0.
# Note we don't join them to the room yet
self.presence_handler.set_state(
UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE},
)
# Add servers to the room
self._add_new_user(room_id, "@alice:server2")
self._add_new_user(room_id, "@bob:server3")
self.reactor.pump([0]) # Wait for presence updates to be handled
#
# Test that when a local join happens remote servers get told about it
#
self.federation_sender.reset_mock()
# Join local user to room
self.helper.join(room_id, "@test2:server")
self.reactor.pump([0]) # Wait for presence updates to be handled
# We shouldn't have sent out any local presence *updates*
self.federation_sender.send_presence.assert_not_called()
# We expect to only send test2 presence to server2 and server3
expected_state = self.get_success(
self.presence_handler.current_state_for_user("@test2:server")
)
self.assertEqual(expected_state.state, PresenceState.ONLINE)
self.federation_sender.send_presence_to_destinations.assert_called_once_with(
destinations=set(("server2", "server3")),
states=[expected_state]
)
def _add_new_user(self, room_id, user_id):
"""Add new user to the room by creating an event and poking the federation API.
"""
hostname = get_domain_from_id(user_id)
room_version = self.get_success(self.store.get_room_version(room_id))
builder = EventBuilder(
state=self.state,
auth=self.auth,
store=self.store,
clock=self.clock,
hostname=hostname,
signing_key=self.random_signing_key,
format_version=room_version_to_event_format(room_version),
room_id=room_id,
type=EventTypes.Member,
sender=user_id,
state_key=user_id,
content={"membership": Membership.JOIN}
)
prev_event_ids = self.get_success(
self.store.get_latest_event_ids_in_room(room_id)
)
event = self.get_success(builder.build(prev_event_ids))
self.get_success(self.federation_handler.on_receive_pdu(hostname, event))
# Check that it was successfully persisted.
self.get_success(self.store.get_event(event.event_id))
self.get_success(self.store.get_event(event.event_id))
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.replication.tcp.streams import ReceiptsStreamRow
from synapse.replication.tcp.streams._base import ReceiptsStreamRow
from tests.replication.tcp.streams._base import BaseStreamTestCase