1
0

Merge remote-tracking branch 'origin/rav/state_stream_limit_assertion' into hs/sssh-testing-redis-things

This commit is contained in:
Half-Shot
2020-04-28 18:36:16 +01:00
36 changed files with 1172 additions and 153 deletions

1
changelog.d/7091.doc Normal file
View File

@@ -0,0 +1 @@
Improve the documentation of application service configuration files.

1
changelog.d/7146.misc Normal file
View File

@@ -0,0 +1 @@
Run replication streamers on workers.

1
changelog.d/7278.misc Normal file
View File

@@ -0,0 +1 @@
Add some unit tests for replication.

1
changelog.d/7337.bugfix Normal file
View File

@@ -0,0 +1 @@
Fix a bug where event updates might not be sent over replication to worker processes after the stream falls behind.

1
changelog.d/7338.misc Normal file
View File

@@ -0,0 +1 @@
Convert some federation handler code to async/await.

1
changelog.d/7341.bugfix Normal file
View File

@@ -0,0 +1 @@
Fix bad error handling that would cause Synapse to crash if it's provided with a YAML configuration file that's either empty or doesn't parse into a key-value map.

1
changelog.d/7343.feature Normal file
View File

@@ -0,0 +1 @@
Support SSO in the user interactive authentication workflow.

1
changelog.d/7344.bugfix Normal file
View File

@@ -0,0 +1 @@
Fix incorrect metrics reporting for `renew_attestations` background task.

1
changelog.d/7357.doc Normal file
View File

@@ -0,0 +1 @@
Add documentation on monitoring workers with Prometheus.

1
changelog.d/7358.bugfix Normal file
View File

@@ -0,0 +1 @@
Fix a bug where event updates might not be sent over replication to worker processes after the stream falls behind.

1
changelog.d/7359.misc Normal file
View File

@@ -0,0 +1 @@
Fix collation for postgres for unit tests.

View File

@@ -23,9 +23,13 @@ namespaces:
users: # List of users we're interested in
- exclusive: <bool>
regex: <regex>
group_id: <group>
- ...
aliases: [] # List of aliases we're interested in
rooms: [] # List of room ids we're interested in
```
`exclusive`: If enabled, only this application service is allowed to register users in its namespace(s).
`group_id`: All users of this application service are dynamically joined to this group. This is useful for e.g user organisation or flairs.
See the [spec](https://matrix.org/docs/spec/application_service/unstable.html) for further details on how application services work.

View File

@@ -60,6 +60,31 @@
1. Restart Prometheus.
## Monitoring workers
To monitor a Synapse installation using
[workers](https://github.com/matrix-org/synapse/blob/master/docs/workers.md),
every worker needs to be monitored independently, in addition to
the main homeserver process. This is because workers don't send
their metrics to the main homeserver process, but expose them
directly (if they are configured to do so).
To allow collecting metrics from a worker, you need to add a
`metrics` listener to its configuration, by adding the following
under `worker_listeners`:
```yaml
- type: metrics
bind_address: ''
port: 9101
```
The `bind_address` and `port` parameters should be set so that
the resulting listener can be reached by prometheus, and they
don't clash with an existing worker.
With this example, the worker's metrics would then be available
on `http://127.0.0.1:9101`.
## Renaming of metrics & deprecation of old names in 1.2
Synapse 1.2 updates the Prometheus metrics to match the naming

View File

@@ -1518,6 +1518,30 @@ sso:
#
# * server_name: the homeserver's name.
#
# * HTML page which notifies the user that they are authenticating to confirm
# an operation on their account during the user interactive authentication
# process: 'sso_auth_confirm.html'.
#
# When rendering, this template is given the following variables:
# * redirect_url: the URL the user is about to be redirected to. Needs
# manual escaping (see
# https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping).
#
# * description: the operation which the user is being asked to confirm
#
# * HTML page shown after a successful user interactive authentication session:
# 'sso_auth_success.html'.
#
# Note that this page must include the JavaScript which notifies of a successful authentication
# (see https://matrix.org/docs/spec/client_server/r0.6.0#fallback).
#
# This template has no additional variables.
#
# * HTML page shown during single sign-on if a deactivated user (according to Synapse's database)
# attempts to login: 'sso_account_deactivated.html'.
#
# This template has no additional variables.
#
# You can see the default templates at:
# https://github.com/matrix-org/synapse/tree/master/synapse/res/templates
#

View File

@@ -960,17 +960,22 @@ def start(config_options):
synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts
ss = GenericWorkerServer(
hs = GenericWorkerServer(
config.server_name,
config=config,
version_string="Synapse/" + get_version_string(synapse),
)
setup_logging(ss, config, use_worker_options=True)
setup_logging(hs, config, use_worker_options=True)
hs.setup()
# Ensure the replication streamer is always started in case we write to any
# streams. Will no-op if no streams can be written to by this worker.
hs.get_replication_streamer()
ss.setup()
reactor.addSystemEventTrigger(
"before", "startup", _base.start, ss, config.worker_listeners
"before", "startup", _base.start, hs, config.worker_listeners
)
_base.start_worker_reactor("synapse-generic-worker", config)

View File

@@ -657,6 +657,12 @@ def read_config_files(config_files):
for config_file in config_files:
with open(config_file) as file_stream:
yaml_config = yaml.safe_load(file_stream)
if not isinstance(yaml_config, dict):
err = "File %r is empty or doesn't parse into a key-value map. IGNORING."
print(err % (config_file,))
continue
specified_config.update(yaml_config)
if "server_name" not in specified_config:

View File

@@ -138,7 +138,7 @@ class DatabaseConfig(Config):
database_path = config.get("database_path")
if multi_database_config and database_config:
raise ConfigError("Can't specify both 'database' and 'datbases' in config")
raise ConfigError("Can't specify both 'database' and 'databases' in config")
if multi_database_config:
if database_path:

View File

@@ -113,6 +113,30 @@ class SSOConfig(Config):
#
# * server_name: the homeserver's name.
#
# * HTML page which notifies the user that they are authenticating to confirm
# an operation on their account during the user interactive authentication
# process: 'sso_auth_confirm.html'.
#
# When rendering, this template is given the following variables:
# * redirect_url: the URL the user is about to be redirected to. Needs
# manual escaping (see
# https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping).
#
# * description: the operation which the user is being asked to confirm
#
# * HTML page shown after a successful user interactive authentication session:
# 'sso_auth_success.html'.
#
# Note that this page must include the JavaScript which notifies of a successful authentication
# (see https://matrix.org/docs/spec/client_server/r0.6.0#fallback).
#
# This template has no additional variables.
#
# * HTML page shown during single sign-on if a deactivated user (according to Synapse's database)
# attempts to login: 'sso_account_deactivated.html'.
#
# This template has no additional variables.
#
# You can see the default templates at:
# https://github.com/matrix-org/synapse/tree/master/synapse/res/templates
#

View File

@@ -37,15 +37,16 @@ An attestation is a signed blob of json that looks like:
import logging
import random
from typing import Tuple
from signedjson.sign import sign_json
from twisted.internet import defer
from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import get_domain_from_id
from synapse.util.async_helpers import yieldable_gather_results
logger = logging.getLogger(__name__)
@@ -162,19 +163,19 @@ class GroupAttestionRenewer(object):
def _start_renew_attestations(self):
return run_as_background_process("renew_attestations", self._renew_attestations)
@defer.inlineCallbacks
def _renew_attestations(self):
async def _renew_attestations(self):
"""Called periodically to check if we need to update any of our attestations
"""
now = self.clock.time_msec()
rows = yield self.store.get_attestations_need_renewals(
rows = await self.store.get_attestations_need_renewals(
now + UPDATE_ATTESTATION_TIME_MS
)
@defer.inlineCallbacks
def _renew_attestation(group_id, user_id):
def _renew_attestation(group_user: Tuple[str, str]):
group_id, user_id = group_user
try:
if not self.is_mine_id(group_id):
destination = get_domain_from_id(group_id)
@@ -207,8 +208,6 @@ class GroupAttestionRenewer(object):
"Error renewing attestation of %r in %r", user_id, group_id
)
for row in rows:
group_id = row["group_id"]
user_id = row["user_id"]
run_in_background(_renew_attestation, group_id, user_id)
await yieldable_gather_results(
_renew_attestation, ((row["group_id"], row["user_id"]) for row in rows)
)

View File

@@ -343,7 +343,7 @@ class FederationHandler(BaseHandler):
ours = await self.state_store.get_state_groups_ids(room_id, seen)
# state_maps is a list of mappings from (type, state_key) to event_id
state_maps = list(ours.values()) # type: list[StateMap[str]]
state_maps = list(ours.values()) # type: List[StateMap[str]]
# we don't need this any more, let's delete it.
del ours
@@ -1694,16 +1694,15 @@ class FederationHandler(BaseHandler):
return None
@defer.inlineCallbacks
def get_state_for_pdu(self, room_id, event_id):
async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase]:
"""Returns the state at the event. i.e. not including said event.
"""
event = yield self.store.get_event(
event = await self.store.get_event(
event_id, allow_none=False, check_room_id=room_id
)
state_groups = yield self.state_store.get_state_groups(room_id, [event_id])
state_groups = await self.state_store.get_state_groups(room_id, [event_id])
if state_groups:
_, state = list(iteritems(state_groups)).pop()
@@ -1714,7 +1713,7 @@ class FederationHandler(BaseHandler):
if "replaces_state" in event.unsigned:
prev_id = event.unsigned["replaces_state"]
if prev_id != event.event_id:
prev_event = yield self.store.get_event(prev_id)
prev_event = await self.store.get_event(prev_id)
results[(event.type, event.state_key)] = prev_event
else:
del results[(event.type, event.state_key)]
@@ -1724,15 +1723,14 @@ class FederationHandler(BaseHandler):
else:
return []
@defer.inlineCallbacks
def get_state_ids_for_pdu(self, room_id, event_id):
async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]:
"""Returns the state at the event. i.e. not including said event.
"""
event = yield self.store.get_event(
event = await self.store.get_event(
event_id, allow_none=False, check_room_id=room_id
)
state_groups = yield self.state_store.get_state_groups_ids(room_id, [event_id])
state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id])
if state_groups:
_, state = list(state_groups.items()).pop()
@@ -1751,49 +1749,50 @@ class FederationHandler(BaseHandler):
else:
return []
@defer.inlineCallbacks
@log_function
def on_backfill_request(self, origin, room_id, pdu_list, limit):
in_room = yield self.auth.check_host_in_room(room_id, origin)
async def on_backfill_request(
self, origin: str, room_id: str, pdu_list: List[str], limit: int
) -> List[EventBase]:
in_room = await self.auth.check_host_in_room(room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
# Synapse asks for 100 events per backfill request. Do not allow more.
limit = min(limit, 100)
events = yield self.store.get_backfill_events(room_id, pdu_list, limit)
events = await self.store.get_backfill_events(room_id, pdu_list, limit)
events = yield filter_events_for_server(self.storage, origin, events)
events = await filter_events_for_server(self.storage, origin, events)
return events
@defer.inlineCallbacks
@log_function
def get_persisted_pdu(self, origin, event_id):
async def get_persisted_pdu(
self, origin: str, event_id: str
) -> Optional[EventBase]:
"""Get an event from the database for the given server.
Args:
origin [str]: hostname of server which is requesting the event; we
origin: hostname of server which is requesting the event; we
will check that the server is allowed to see it.
event_id [str]: id of the event being requested
event_id: id of the event being requested
Returns:
Deferred[EventBase|None]: None if we know nothing about the event;
otherwise the (possibly-redacted) event.
None if we know nothing about the event; otherwise the (possibly-redacted) event.
Raises:
AuthError if the server is not currently in the room
"""
event = yield self.store.get_event(
event = await self.store.get_event(
event_id, allow_none=True, allow_rejected=True
)
if event:
in_room = yield self.auth.check_host_in_room(event.room_id, origin)
in_room = await self.auth.check_host_in_room(event.room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
events = yield filter_events_for_server(self.storage, origin, [event])
events = await filter_events_for_server(self.storage, origin, [event])
event = events[0]
return event
else:
@@ -2397,7 +2396,7 @@ class FederationHandler(BaseHandler):
"""
# exclude the state key of the new event from the current_state in the context.
if event.is_state():
event_key = (event.type, event.state_key)
event_key = (event.type, event.state_key) # type: Optional[Tuple[str, str]]
else:
event_key = None
state_updates = {

View File

@@ -28,7 +28,7 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
The API looks like:
GET /_synapse/replication/get_repl_stream_updates/events?from_token=0&to_token=10&limit=100
GET /_synapse/replication/get_repl_stream_updates/<stream name>?from_token=0&to_token=10
200 OK
@@ -38,6 +38,9 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
limited: False,
}
If there are more rows than can sensibly be returned in one lump, `limited` will be
set to true, and the caller should call again with a new `from_token`.
"""
NAME = "get_repl_stream_updates"
@@ -52,8 +55,8 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
self.streams = hs.get_replication_streamer().get_streams()
@staticmethod
def _serialize_payload(stream_name, from_token, upto_token, limit):
return {"from_token": from_token, "upto_token": upto_token, "limit": limit}
def _serialize_payload(stream_name, from_token, upto_token):
return {"from_token": from_token, "upto_token": upto_token}
async def _handle_request(self, request, stream_name):
stream = self.streams.get(stream_name)
@@ -62,10 +65,9 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
from_token = parse_integer(request, "from_token", required=True)
upto_token = parse_integer(request, "upto_token", required=True)
limit = parse_integer(request, "limit", required=True)
updates, upto_token, limited = await stream.get_updates_since(
from_token, upto_token, limit
from_token, upto_token
)
return (

View File

@@ -17,9 +17,7 @@
import logging
import random
from typing import Dict
from six import itervalues
from typing import Dict, List
from prometheus_client import Counter
@@ -71,29 +69,28 @@ class ReplicationStreamer(object):
def __init__(self, hs):
self.store = hs.get_datastore()
self.presence_handler = hs.get_presence_handler()
self.clock = hs.get_clock()
self.notifier = hs.get_notifier()
self._server_notices_sender = hs.get_server_notices_sender()
self._replication_torture_level = hs.config.replication_torture_level
# List of streams that clients can subscribe to.
# We only support federation stream if federation sending hase been
# disabled on the master.
self.streams = [
stream(hs)
for stream in itervalues(STREAMS_MAP)
if stream != FederationStream or not hs.config.send_federation
]
# Work out list of streams that this instance is the source of.
self.streams = [] # type: List[Stream]
if hs.config.worker_app is None:
for stream in STREAMS_MAP.values():
if stream == FederationStream and hs.config.send_federation:
# We only support federation stream if federation sending
# hase been disabled on the master.
continue
self.streams.append(stream(hs))
self.streams_by_name = {stream.NAME: stream for stream in self.streams}
self.federation_sender = None
if not hs.config.send_federation:
self.federation_sender = hs.get_federation_sender()
self.notifier.add_replication_callback(self.on_notifier_poke)
# Only bother registering the notifier callback if we have streams to
# publish.
if self.streams:
self.notifier.add_replication_callback(self.on_notifier_poke)
# Keeps track of whether we are currently checking for updates
self.is_looping = False

View File

@@ -24,8 +24,8 @@ from synapse.replication.http.streams import ReplicationGetStreamUpdates
logger = logging.getLogger(__name__)
MAX_EVENTS_BEHIND = 500000
# the number of rows to request from an update_function.
_STREAM_UPDATE_TARGET_ROW_COUNT = 100
# Some type aliases to make things a bit easier.
@@ -56,7 +56,11 @@ StreamUpdateResult = Tuple[List[Tuple[Token, StreamRow]], Token, bool]
# * from_token: the previous stream token: the starting point for fetching the
# updates
# * to_token: the new stream token: the point to get updates up to
# * limit: the maximum number of rows to return
# * target_row_count: a target for the number of rows to be returned.
#
# The update_function is expected to return up to _approximately_ target_row_count rows.
# If there are more updates available, it should set `limited` in the result, and
# it will be called again to get the next batch.
#
UpdateFunction = Callable[[Token, Token, int], Awaitable[StreamUpdateResult]]
@@ -138,7 +142,7 @@ class Stream(object):
return updates, current_token, limited
async def get_updates_since(
self, from_token: Token, upto_token: Token, limit: int = 100
self, from_token: Token, upto_token: Token
) -> StreamUpdateResult:
"""Like get_updates except allows specifying from when we should
stream updates
@@ -156,7 +160,7 @@ class Stream(object):
return [], upto_token, False
updates, upto_token, limited = await self.update_function(
from_token, upto_token, limit,
from_token, upto_token, _STREAM_UPDATE_TARGET_ROW_COUNT,
)
return updates, upto_token, limited
@@ -193,10 +197,7 @@ def make_http_update_function(hs, stream_name: str) -> UpdateFunction:
from_token: int, upto_token: int, limit: int
) -> StreamUpdateResult:
result = await client(
stream_name=stream_name,
from_token=from_token,
upto_token=upto_token,
limit=limit,
stream_name=stream_name, from_token=from_token, upto_token=upto_token,
)
return result["updates"], result["upto_token"], result["limited"]

View File

@@ -15,11 +15,12 @@
# limitations under the License.
import heapq
from typing import Iterable, Tuple, Type
from collections import Iterable
from typing import List, Tuple, Type
import attr
from ._base import Stream, Token, db_query_to_update_function
from ._base import Stream, StreamUpdateResult, Token
"""Handling of the 'events' replication stream
@@ -117,30 +118,120 @@ class EventsStream(Stream):
def __init__(self, hs):
self._store = hs.get_datastore()
super().__init__(
self._store.get_current_events_token,
db_query_to_update_function(self._update_function),
self._store.get_current_events_token, self._update_function,
)
async def _update_function(
self, from_token: Token, current_token: Token, limit: int
) -> Iterable[tuple]:
self, from_token: Token, current_token: Token, target_row_count: int
) -> StreamUpdateResult:
# the events stream merges together three separate sources:
# * new events
# * current_state changes
# * events which were previously outliers, but have now been de-outliered.
#
# The merge operation is complicated by the fact that we only have a single
# "stream token" which is supposed to indicate how far we have got through
# all three streams. It's therefore no good to return rows 1-1000 from the
# "new events" table if the state_deltas are limited to rows 1-100 by the
# target_row_count.
#
# In other words: we must pick a new upper limit, and must return *all* rows
# up to that point for each of the three sources.
#
# Start by trying to split the target_row_count up. We expect to have a
# negligible number of ex-outliers, and a rough approximation based on recent
# traffic on sw1v.org shows that there are approximately the same number of
# event rows between a given pair of stream ids as there are state
# updates, so let's split our target_row_count among those two types. The target
# is only an approximation - it doesn't matter if we end up going a bit over it.
target_row_count //= 2
# now we fetch up to that many rows from the events table
event_rows = await self._store.get_all_new_forward_event_rows(
from_token, current_token, limit
)
event_updates = (
(row[0], EventsStreamEventRow.TypeId, row[1:]) for row in event_rows
)
from_token, current_token, target_row_count
) # type: List[Tuple]
# we rely on get_all_new_forward_event_rows strictly honouring the limit, so
# that we know it is safe to just take upper_limit = event_rows[-1][0].
assert (
len(event_rows) <= target_row_count
), "get_all_new_forward_event_rows did not honour row limit"
# if we hit the limit on event_updates, there's no point in going beyond the
# last stream_id in the batch for the other sources.
if len(event_rows) == target_row_count:
limited = True
upper_limit = event_rows[-1][0] # type: int
else:
limited = False
upper_limit = current_token
# next up is the state delta table
state_rows = await self._store.get_all_updated_current_state_deltas(
from_token, current_token, limit
)
from_token, upper_limit, target_row_count
) # type: List[Tuple]
assert len(state_rows) <= target_row_count
# there can be more than one row per stream_id in that table, so if we hit
# the limit there, we'll need to truncate the results so that we have a complete
# set of changes for all the stream IDs we include.
if len(state_rows) == target_row_count:
assert state_rows[-1][0] <= upper_limit
upper_limit = state_rows[-1][0] - 1
# search for the point to truncate the list
for idx in range(len(state_rows) - 1, 0, -1):
if state_rows[idx - 1][0] <= upper_limit:
state_rows = state_rows[:idx]
break
else:
# bother. We didn't get a full set of changes for even a single
# stream id. let's run the query again, without a row limit, but for
# just one stream id.
upper_limit += 1
state_rows = await self._store.get_all_updated_current_state_deltas(
from_token, upper_limit, limit=None
)
limited = True
# finally, fetch the ex-outliers rows. We assume there are few enough of these
# not to bother with the limit.
ex_outliers_rows = await self._store.get_ex_outlier_stream_rows(
from_token, upper_limit
) # type: List[Tuple]
# we now need to turn the raw database rows returned into tuples suitable
# for the replication protocol (basically, we add an identifier to
# distinguish the row type). At the same time, we can limit the event_rows
# to the max stream_id from state_rows.
event_updates = (
(stream_id, (EventsStreamEventRow.TypeId, rest))
for (stream_id, *rest) in event_rows
if stream_id <= upper_limit
) # type: Iterable[Tuple[int, Tuple]]
state_updates = (
(row[0], EventsStreamCurrentStateRow.TypeId, row[1:]) for row in state_rows
)
(stream_id, (EventsStreamCurrentStateRow.TypeId, rest))
for (stream_id, *rest) in state_rows
) # type: Iterable[Tuple[int, Tuple]]
all_updates = heapq.merge(event_updates, state_updates)
ex_outliers_updates = (
(stream_id, (EventsStreamEventRow.TypeId, rest))
for (stream_id, *rest) in ex_outliers_rows
) # type: Iterable[Tuple[int, Tuple]]
return all_updates
# we need to return a sorted list, so merge them together.
updates = list(heapq.merge(event_updates, state_updates, ex_outliers_updates))
return updates, upper_limit, limited
@classmethod
def parse_row(cls, row):

View File

@@ -25,6 +25,7 @@ import synapse.server_notices.server_notices_manager
import synapse.server_notices.server_notices_sender
import synapse.state
import synapse.storage
from synapse.events.builder import EventBuilderFactory
class HomeServer(object):
@property
@@ -121,3 +122,7 @@ class HomeServer(object):
pass
def get_instance_id(self) -> str:
pass
def get_event_builder_factory(self) -> EventBuilderFactory:
pass
def get_storage(self) -> synapse.storage.Storage:
pass

View File

@@ -973,8 +973,18 @@ class EventsWorkerStore(SQLBaseStore):
return self._stream_id_gen.get_current_token()
def get_all_new_forward_event_rows(self, last_id, current_id, limit):
if last_id == current_id:
return defer.succeed([])
"""Returns new events, for the Events replication stream
Args:
last_id: the last stream_id from the previous batch.
current_id: the maximum stream_id to return up to
limit: the maximum number of rows to return
Returns: Deferred[List[Tuple]]
a list of events stream rows. Each tuple consists of a stream id as
the first element, followed by fields suitable for casting into an
EventsStreamRow.
"""
def get_all_new_forward_event_rows(txn):
sql = (
@@ -989,13 +999,26 @@ class EventsWorkerStore(SQLBaseStore):
" LIMIT ?"
)
txn.execute(sql, (last_id, current_id, limit))
new_event_updates = txn.fetchall()
return txn.fetchall()
if len(new_event_updates) == limit:
upper_bound = new_event_updates[-1][0]
else:
upper_bound = current_id
return self.db.runInteraction(
"get_all_new_forward_event_rows", get_all_new_forward_event_rows
)
def get_ex_outlier_stream_rows(self, last_id, current_id):
"""Returns de-outliered events, for the Events replication stream
Args:
last_id: the last stream_id from the previous batch.
current_id: the maximum stream_id to return up to
Returns: Deferred[List[Tuple]]
a list of events stream rows. Each tuple consists of a stream id as
the first element, followed by fields suitable for casting into an
EventsStreamRow.
"""
def get_ex_outlier_stream_rows_txn(txn):
sql = (
"SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts, relates_to_id"
@@ -1006,15 +1029,14 @@ class EventsWorkerStore(SQLBaseStore):
" LEFT JOIN event_relations USING (event_id)"
" WHERE ? < event_stream_ordering"
" AND event_stream_ordering <= ?"
" ORDER BY event_stream_ordering DESC"
" ORDER BY event_stream_ordering ASC"
)
txn.execute(sql, (last_id, upper_bound))
new_event_updates.extend(txn)
return new_event_updates
txn.execute(sql, (last_id, current_id))
return txn.fetchall()
return self.db.runInteraction(
"get_all_new_forward_event_rows", get_all_new_forward_event_rows
"get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn
)
def get_all_new_backfill_event_rows(self, last_id, current_id, limit):
@@ -1062,15 +1084,23 @@ class EventsWorkerStore(SQLBaseStore):
"get_all_new_backfill_event_rows", get_all_new_backfill_event_rows
)
def get_all_updated_current_state_deltas(self, from_token, to_token, limit):
def get_all_updated_current_state_deltas(
self, from_token: int, to_token: int, limit: Optional[int]
):
def get_all_updated_current_state_deltas_txn(txn):
sql = """
SELECT stream_id, room_id, type, state_key, event_id
FROM current_state_delta_stream
WHERE ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC LIMIT ?
ORDER BY stream_id ASC
"""
txn.execute(sql, (from_token, to_token, limit))
params = [from_token, to_token]
if limit is not None:
sql += "LIMIT ?"
params.append(limit)
txn.execute(sql, params)
return txn.fetchall()
return self.db.runInteraction(

View File

@@ -13,38 +13,72 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from mock import Mock
import logging
from typing import Any, Dict, List, Optional, Tuple
import attr
from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
from twisted.internet.task import LoopingCall
from twisted.web.http import HTTPChannel
from synapse.app.generic_worker import GenericWorkerServer
from synapse.http.site import SynapseRequest
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.tcp.client import ReplicationDataHandler
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
from synapse.util import Clock
from tests import unittest
from tests.server import FakeTransport
logger = logging.getLogger(__name__)
class BaseStreamTestCase(unittest.HomeserverTestCase):
"""Base class for tests of the replication streams"""
def make_homeserver(self, reactor, clock):
self.test_handler = Mock(wraps=TestReplicationDataHandler())
return self.setup_test_homeserver(replication_data_handler=self.test_handler)
def prepare(self, reactor, clock, hs):
# build a replication server
server_factory = ReplicationStreamProtocolFactory(hs)
self.streamer = hs.get_replication_streamer()
self.server = server_factory.buildProtocol(None)
repl_handler = ReplicationCommandHandler(hs)
repl_handler.handler = self.test_handler
# Make a new HomeServer object for the worker
config = self.default_config()
config["worker_app"] = "synapse.app.generic_worker"
config["worker_replication_host"] = "testserv"
config["worker_replication_http_port"] = "8765"
self.reactor.lookups["testserv"] = "1.2.3.4"
self.worker_hs = self.setup_test_homeserver(
http_client=None,
homeserverToUse=GenericWorkerServer,
config=config,
reactor=self.reactor,
)
# Since we use sqlite in memory databases we need to make sure the
# databases objects are the same.
self.worker_hs.get_datastore().db = hs.get_datastore().db
self.test_handler = self._build_replication_data_handler()
self.worker_hs.replication_data_handler = self.test_handler
repl_handler = ReplicationCommandHandler(self.worker_hs)
self.client = ClientReplicationStreamProtocol(
hs, "client", "test", clock, repl_handler,
self.worker_hs, "client", "test", clock, repl_handler,
)
self._client_transport = None
self._server_transport = None
def _build_replication_data_handler(self):
return TestReplicationDataHandler(self.worker_hs.get_datastore())
def reconnect(self):
if self._client_transport:
self.client.close()
@@ -74,24 +108,204 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self.streamer.on_notifier_poke()
self.pump(0.1)
def handle_http_replication_attempt(self) -> SynapseRequest:
"""Asserts that a connection attempt was made to the master HS on the
HTTP replication port, then proxies it to the master HS object to be
handled.
class TestReplicationDataHandler:
Returns:
The request object received by master HS.
"""
# We should have an outbound connection attempt.
clients = self.reactor.tcpClients
self.assertEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
self.assertEqual(host, "1.2.3.4")
self.assertEqual(port, 8765)
# Set up client side protocol
client_protocol = client_factory.buildProtocol(None)
request_factory = OneShotRequestFactory()
# Set up the server side protocol
channel = _PushHTTPChannel(self.reactor)
channel.requestFactory = request_factory
channel.site = self.site
# Connect client to server and vice versa.
client_to_server_transport = FakeTransport(
channel, self.reactor, client_protocol
)
client_protocol.makeConnection(client_to_server_transport)
server_to_client_transport = FakeTransport(
client_protocol, self.reactor, channel
)
channel.makeConnection(server_to_client_transport)
# The request will now be processed by `self.site` and the response
# streamed back.
self.reactor.advance(0)
# We tear down the connection so it doesn't get reused without our
# knowledge.
server_to_client_transport.loseConnection()
client_to_server_transport.loseConnection()
return request_factory.request
def assert_request_is_get_repl_stream_updates(
self, request: SynapseRequest, stream_name: str
):
"""Asserts that the given request is a HTTP replication request for
fetching updates for given stream.
"""
self.assertRegex(
request.path,
br"^/_synapse/replication/get_repl_stream_updates/%s/[^/]+$"
% (stream_name.encode("ascii"),),
)
self.assertEqual(request.method, b"GET")
class TestReplicationDataHandler(ReplicationDataHandler):
"""Drop-in for ReplicationDataHandler which just collects RDATA rows"""
def __init__(self):
self.streams = set()
self._received_rdata_rows = []
def __init__(self, store: BaseSlavedStore):
super().__init__(store)
# streams to subscribe to: map from stream id to position
self.stream_positions = {} # type: Dict[str, int]
# list of received (stream_name, token, row) tuples
self.received_rdata_rows = [] # type: List[Tuple[str, int, Any]]
def get_streams_to_replicate(self):
positions = {s: 0 for s in self.streams}
for stream, token, _ in self._received_rdata_rows:
if stream in self.streams:
positions[stream] = max(token, positions.get(stream, 0))
return positions
return self.stream_positions
async def on_rdata(self, stream_name, token, rows):
await super().on_rdata(stream_name, token, rows)
for r in rows:
self._received_rdata_rows.append((stream_name, token, r))
self.received_rdata_rows.append((stream_name, token, r))
async def on_position(self, stream_name, token):
pass
if (
stream_name in self.stream_positions
and token > self.stream_positions[stream_name]
):
self.stream_positions[stream_name] = token
@attr.s()
class OneShotRequestFactory:
"""A simple request factory that generates a single `SynapseRequest` and
stores it for future use. Can only be used once.
"""
request = attr.ib(default=None)
def __call__(self, *args, **kwargs):
assert self.request is None
self.request = SynapseRequest(*args, **kwargs)
return self.request
class _PushHTTPChannel(HTTPChannel):
"""A HTTPChannel that wraps pull producers to push producers.
This is a hack to get around the fact that HTTPChannel transparently wraps a
pull producer (which is what Synapse uses to reply to requests) with
`_PullToPush` to convert it to a push producer. Unfortunately `_PullToPush`
uses the standard reactor rather than letting us use our test reactor, which
makes it very hard to test.
"""
def __init__(self, reactor: IReactorTime):
super().__init__()
self.reactor = reactor
self._pull_to_push_producer = None # type: Optional[_PullToPushProducer]
def registerProducer(self, producer, streaming):
# Convert pull producers to push producer.
if not streaming:
self._pull_to_push_producer = _PullToPushProducer(
self.reactor, producer, self
)
producer = self._pull_to_push_producer
super().registerProducer(producer, True)
def unregisterProducer(self):
if self._pull_to_push_producer:
# We need to manually stop the _PullToPushProducer.
self._pull_to_push_producer.stop()
class _PullToPushProducer:
"""A push producer that wraps a pull producer.
"""
def __init__(
self, reactor: IReactorTime, producer: IPullProducer, consumer: IConsumer
):
self._clock = Clock(reactor)
self._producer = producer
self._consumer = consumer
# While running we use a looping call with a zero delay to call
# resumeProducing on given producer.
self._looping_call = None # type: Optional[LoopingCall]
# We start writing next reactor tick.
self._start_loop()
def _start_loop(self):
"""Start the looping call to
"""
if not self._looping_call:
# Start a looping call which runs every tick.
self._looping_call = self._clock.looping_call(self._run_once, 0)
def stop(self):
"""Stops calling resumeProducing.
"""
if self._looping_call:
self._looping_call.stop()
self._looping_call = None
def pauseProducing(self):
"""Implements IPushProducer
"""
self.stop()
def resumeProducing(self):
"""Implements IPushProducer
"""
self._start_loop()
def stopProducing(self):
"""Implements IPushProducer
"""
self.stop()
self._producer.stopProducing()
def _run_once(self):
"""Calls resumeProducing on producer once.
"""
try:
self._producer.resumeProducing()
except Exception:
logger.exception("Failed to call resumeProducing")
try:
self._consumer.unregisterProducer()
except Exception:
pass
self.stopProducing()

View File

@@ -0,0 +1,390 @@
# -*- coding: utf-8 -*-
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional
from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
from synapse.replication.tcp.streams._base import _STREAM_UPDATE_TARGET_ROW_COUNT
from synapse.replication.tcp.streams.events import (
EventsStreamCurrentStateRow,
EventsStreamEventRow,
EventsStreamRow,
)
from synapse.rest import admin
from synapse.rest.client.v1 import login, room
from tests.replication.tcp.streams._base import BaseStreamTestCase
from tests.test_utils.event_injection import inject_event, inject_member_event
class EventsStreamTestCase(BaseStreamTestCase):
servlets = [
admin.register_servlets,
login.register_servlets,
room.register_servlets,
]
def prepare(self, reactor, clock, hs):
super().prepare(reactor, clock, hs)
self.user_id = self.register_user("u1", "pass")
self.user_tok = self.login("u1", "pass")
self.reconnect()
self.test_handler.stream_positions["events"] = 0
self.room_id = self.helper.create_room_as(tok=self.user_tok)
self.test_handler.received_rdata_rows.clear()
def test_update_function_event_row_limit(self):
"""Test replication with many non-state events
Checks that all events are correctly replicated when there are lots of
event rows to be replicated.
"""
# generate lots of non-state events. We inject them using inject_event
# so that they are not send out over replication until we call self.replicate().
events = [
self._inject_test_event()
for _ in range(_STREAM_UPDATE_TARGET_ROW_COUNT + 1)
]
# also one state event
state_event = self._inject_state_event()
# check we're testing what we think we are: no rows should yet have been
# receieved
self.assertEqual([], self.test_handler.received_rdata_rows)
# now fire up the replicator
self.replicate()
# we should have received all the expected rows in the right order
received_rows = self.test_handler.received_rdata_rows
for event in events:
stream_name, token, row = received_rows.pop(0)
self.assertEqual("events", stream_name)
self.assertIsInstance(row, EventsStreamRow)
self.assertEqual(row.type, "ev")
self.assertIsInstance(row.data, EventsStreamEventRow)
self.assertEqual(row.data.event_id, event.event_id)
stream_name, token, row = received_rows.pop(0)
self.assertIsInstance(row, EventsStreamRow)
self.assertIsInstance(row.data, EventsStreamEventRow)
self.assertEqual(row.data.event_id, state_event.event_id)
stream_name, token, row = received_rows.pop(0)
self.assertEqual("events", stream_name)
self.assertIsInstance(row, EventsStreamRow)
self.assertEqual(row.type, "state")
self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
self.assertEqual(row.data.event_id, state_event.event_id)
self.assertEqual([], received_rows)
def test_update_function_huge_state_change(self):
"""Test replication with many state events
Ensures that all events are correctly replicated when there are lots of
state change rows to be replicated.
"""
# we want to generate lots of state changes at a single stream ID.
#
# We do this by having two branches in the DAG. On one, we have a moderator
# which that generates lots of state; on the other, we de-op the moderator,
# thus invalidating all the state.
OTHER_USER = "@other_user:localhost"
# have the user join
inject_member_event(self.hs, self.room_id, OTHER_USER, Membership.JOIN)
# Update existing power levels with mod at PL50
pls = self.helper.get_state(
self.room_id, EventTypes.PowerLevels, tok=self.user_tok
)
pls["users"][OTHER_USER] = 50
self.helper.send_state(
self.room_id, EventTypes.PowerLevels, pls, tok=self.user_tok,
)
# this is the point in the DAG where we make a fork
fork_point = self.get_success(
self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id)
) # type: List[str]
events = [
self._inject_state_event(sender=OTHER_USER)
for _ in range(_STREAM_UPDATE_TARGET_ROW_COUNT)
]
self.replicate()
# all those events and state changes should have landed
self.assertGreaterEqual(
len(self.test_handler.received_rdata_rows), 2 * len(events)
)
self.test_handler.received_rdata_rows.clear()
# a state event which doesn't get rolled back, to check that the state
# before the huge update comes through ok
state1 = self._inject_state_event()
# roll back all the state by de-modding the user
prev_events = fork_point
pls["users"][OTHER_USER] = 0
pl_event = inject_event(
self.hs,
prev_event_ids=prev_events,
type=EventTypes.PowerLevels,
state_key="",
sender=self.user_id,
room_id=self.room_id,
content=pls,
)
# one more bit of state that doesn't get rolled back
state2 = self._inject_state_event()
# check we're testing what we think we are: no rows should yet have been
# receieved
self.assertEqual([], self.test_handler.received_rdata_rows)
# now fire up the replicator
self.replicate()
# now we should have received all the expected rows in the right order.
#
# we expect:
#
# - two rows for state1
# - the PL event row, plus state rows for the PL event and each
# of the states that got reverted.
# - two rows for state2
received_rows = self.test_handler.received_rdata_rows
# first check the first two rows, which should be state1
stream_name, token, row = received_rows.pop(0)
self.assertEqual("events", stream_name)
self.assertIsInstance(row, EventsStreamRow)
self.assertEqual(row.type, "ev")
self.assertIsInstance(row.data, EventsStreamEventRow)
self.assertEqual(row.data.event_id, state1.event_id)
stream_name, token, row = received_rows.pop(0)
self.assertIsInstance(row, EventsStreamRow)
self.assertEqual(row.type, "state")
self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
self.assertEqual(row.data.event_id, state1.event_id)
# now the last two rows, which should be state2
stream_name, token, row = received_rows.pop(-2)
self.assertEqual("events", stream_name)
self.assertIsInstance(row, EventsStreamRow)
self.assertEqual(row.type, "ev")
self.assertIsInstance(row.data, EventsStreamEventRow)
self.assertEqual(row.data.event_id, state2.event_id)
stream_name, token, row = received_rows.pop(-1)
self.assertIsInstance(row, EventsStreamRow)
self.assertEqual(row.type, "state")
self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
self.assertEqual(row.data.event_id, state2.event_id)
# that should leave us with the rows for the PL event
self.assertEqual(len(received_rows), len(events) + 2)
stream_name, token, row = received_rows.pop(0)
self.assertEqual("events", stream_name)
self.assertIsInstance(row, EventsStreamRow)
self.assertEqual(row.type, "ev")
self.assertIsInstance(row.data, EventsStreamEventRow)
self.assertEqual(row.data.event_id, pl_event.event_id)
# the state rows are unsorted
state_rows = [] # type: List[EventsStreamCurrentStateRow]
for stream_name, token, row in received_rows:
self.assertEqual("events", stream_name)
self.assertIsInstance(row, EventsStreamRow)
self.assertEqual(row.type, "state")
self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
state_rows.append(row.data)
state_rows.sort(key=lambda r: r.state_key)
sr = state_rows.pop(0)
self.assertEqual(sr.type, EventTypes.PowerLevels)
self.assertEqual(sr.event_id, pl_event.event_id)
for sr in state_rows:
self.assertEqual(sr.type, "test_state_event")
# "None" indicates the state has been deleted
self.assertIsNone(sr.event_id)
def test_update_function_state_row_limit(self):
"""Test replication with many state events over several stream ids.
"""
# we want to generate lots of state changes, but for this test, we want to
# spread out the state changes over a few stream IDs.
#
# We do this by having two branches in the DAG. On one, we have four moderators,
# each of which that generates lots of state; on the other, we de-op the users,
# thus invalidating all the state.
NUM_USERS = 4
STATES_PER_USER = _STREAM_UPDATE_TARGET_ROW_COUNT // 4 + 1
user_ids = ["@user%i:localhost" % (i,) for i in range(NUM_USERS)]
# have the users join
for u in user_ids:
inject_member_event(self.hs, self.room_id, u, Membership.JOIN)
# Update existing power levels with mod at PL50
pls = self.helper.get_state(
self.room_id, EventTypes.PowerLevels, tok=self.user_tok
)
pls["users"].update({u: 50 for u in user_ids})
self.helper.send_state(
self.room_id, EventTypes.PowerLevels, pls, tok=self.user_tok,
)
# this is the point in the DAG where we make a fork
fork_point = self.get_success(
self.hs.get_datastore().get_latest_event_ids_in_room(self.room_id)
) # type: List[str]
events = [] # type: List[EventBase]
for user in user_ids:
events.extend(
self._inject_state_event(sender=user) for _ in range(STATES_PER_USER)
)
self.replicate()
# all those events and state changes should have landed
self.assertGreaterEqual(
len(self.test_handler.received_rdata_rows), 2 * len(events)
)
self.test_handler.received_rdata_rows.clear()
# now roll back all that state by de-modding the users
prev_events = fork_point
pl_events = []
for u in user_ids:
pls["users"][u] = 0
e = inject_event(
self.hs,
prev_event_ids=prev_events,
type=EventTypes.PowerLevels,
state_key="",
sender=self.user_id,
room_id=self.room_id,
content=pls,
)
prev_events = [e.event_id]
pl_events.append(e)
# check we're testing what we think we are: no rows should yet have been
# receieved
self.assertEqual([], self.test_handler.received_rdata_rows)
# now fire up the replicator
self.replicate()
# we should have received all the expected rows in the right order
received_rows = self.test_handler.received_rdata_rows
self.assertGreaterEqual(len(received_rows), len(events))
for i in range(NUM_USERS):
# for each user, we expect the PL event row, followed by state rows for
# the PL event and each of the states that got reverted.
stream_name, token, row = received_rows.pop(0)
self.assertEqual("events", stream_name)
self.assertIsInstance(row, EventsStreamRow)
self.assertEqual(row.type, "ev")
self.assertIsInstance(row.data, EventsStreamEventRow)
self.assertEqual(row.data.event_id, pl_events[i].event_id)
# the state rows are unsorted
state_rows = [] # type: List[EventsStreamCurrentStateRow]
for j in range(STATES_PER_USER + 1):
stream_name, token, row = received_rows.pop(0)
self.assertEqual("events", stream_name)
self.assertIsInstance(row, EventsStreamRow)
self.assertEqual(row.type, "state")
self.assertIsInstance(row.data, EventsStreamCurrentStateRow)
state_rows.append(row.data)
state_rows.sort(key=lambda r: r.state_key)
sr = state_rows.pop(0)
self.assertEqual(sr.type, EventTypes.PowerLevels)
self.assertEqual(sr.event_id, pl_events[i].event_id)
for sr in state_rows:
self.assertEqual(sr.type, "test_state_event")
# "None" indicates the state has been deleted
self.assertIsNone(sr.event_id)
self.assertEqual([], received_rows)
event_count = 0
def _inject_test_event(
self, body: Optional[str] = None, sender: Optional[str] = None, **kwargs
) -> EventBase:
if sender is None:
sender = self.user_id
if body is None:
body = "event %i" % (self.event_count,)
self.event_count += 1
return inject_event(
self.hs,
room_id=self.room_id,
sender=sender,
type="test_event",
content={"body": body},
**kwargs
)
def _inject_state_event(
self,
body: Optional[str] = None,
state_key: Optional[str] = None,
sender: Optional[str] = None,
) -> EventBase:
if sender is None:
sender = self.user_id
if state_key is None:
state_key = "state_%i" % (self.event_count,)
self.event_count += 1
if body is None:
body = "state event %s" % (state_key,)
return inject_event(
self.hs,
room_id=self.room_id,
sender=sender,
type="test_state_event",
state_key=state_key,
content={"body": body},
)

View File

@@ -12,6 +12,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# type: ignore
from mock import Mock
from synapse.replication.tcp.streams._base import ReceiptsStream
from tests.replication.tcp.streams._base import BaseStreamTestCase
@@ -20,11 +25,14 @@ USER_ID = "@feeling:blue"
class ReceiptsStreamTestCase(BaseStreamTestCase):
def _build_replication_data_handler(self):
return Mock(wraps=super()._build_replication_data_handler())
def test_receipt(self):
self.reconnect()
# make the client subscribe to the receipts stream
self.test_handler.streams.add("receipts")
self.test_handler.stream_positions.update({"receipts": 0})
# tell the master to send a new receipt
self.get_success(

View File

@@ -0,0 +1,85 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import Mock
from synapse.handlers.typing import RoomMember
from synapse.replication.http import streams
from synapse.replication.tcp.streams import TypingStream
from tests.replication.tcp.streams._base import BaseStreamTestCase
USER_ID = "@feeling:blue"
class TypingStreamTestCase(BaseStreamTestCase):
servlets = [
streams.register_servlets,
]
def _build_replication_data_handler(self):
return Mock(wraps=super()._build_replication_data_handler())
def test_typing(self):
typing = self.hs.get_typing_handler()
room_id = "!bar:blue"
self.reconnect()
# make the client subscribe to the typing stream
self.test_handler.stream_positions.update({"typing": 0})
typing._push_update(member=RoomMember(room_id, USER_ID), typing=True)
self.reactor.advance(0)
# We should now see an attempt to connect to the master
request = self.handle_http_replication_attempt()
self.assert_request_is_get_repl_stream_updates(request, "typing")
self.test_handler.on_rdata.assert_called_once()
stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "typing")
self.assertEqual(1, len(rdata_rows))
row = rdata_rows[0] # type: TypingStream.TypingStreamRow
self.assertEqual(room_id, row.room_id)
self.assertEqual([USER_ID], row.user_ids)
# Now let's disconnect and insert some data.
self.disconnect()
self.test_handler.on_rdata.reset_mock()
typing._push_update(member=RoomMember(room_id, USER_ID), typing=False)
self.test_handler.on_rdata.assert_not_called()
self.reconnect()
self.pump(0.1)
# We should now see an attempt to connect to the master
request = self.handle_http_replication_attempt()
self.assert_request_is_get_repl_stream_updates(request, "typing")
# The from token should be the token from the last RDATA we got.
self.assertEqual(int(request.args[b"from_token"][0]), token)
self.test_handler.on_rdata.assert_called_once()
stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "typing")
self.assertEqual(1, len(rdata_rows))
row = rdata_rows[0]
self.assertEqual(room_id, row.room_id)
self.assertEqual([], row.user_ids)

View File

@@ -39,7 +39,7 @@ class RestHelper(object):
resource = attr.ib()
auth_user_id = attr.ib()
def create_room_as(self, room_creator, is_public=True, tok=None):
def create_room_as(self, room_creator=None, is_public=True, tok=None):
temp_id = self.auth_user_id
self.auth_user_id = room_creator
path = "/_matrix/client/r0/createRoom"

View File

@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2019 New Vector Ltd
# Copyright 2020 The Matrix.org Foundation C.I.C
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,3 +17,22 @@
"""
Utilities for running the unit tests
"""
from typing import Awaitable, TypeVar
TV = TypeVar("TV")
def get_awaitable_result(awaitable: Awaitable[TV]) -> TV:
"""Get the result from an Awaitable which should have completed
Asserts that the given awaitable has a result ready, and returns its value
"""
i = awaitable.__await__()
try:
next(i)
except StopIteration as e:
# awaitable returned a result
return e.value
# if next didn't raise, the awaitable hasn't completed.
raise Exception("awaitable has not yet completed")

View File

@@ -0,0 +1,96 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
# Copyright 2020 The Matrix.org Foundation C.I.C
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import synapse.server
from synapse.api.constants import EventTypes
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase
from synapse.types import Collection
from tests.test_utils import get_awaitable_result
"""
Utility functions for poking events into the storage of the server under test.
"""
def inject_member_event(
hs: synapse.server.HomeServer,
room_id: str,
sender: str,
membership: str,
target: Optional[str] = None,
extra_content: Optional[dict] = None,
**kwargs
) -> EventBase:
"""Inject a membership event into a room."""
if target is None:
target = sender
content = {"membership": membership}
if extra_content:
content.update(extra_content)
return inject_event(
hs,
room_id=room_id,
type=EventTypes.Member,
sender=sender,
state_key=target,
content=content,
**kwargs
)
def inject_event(
hs: synapse.server.HomeServer,
room_version: Optional[str] = None,
prev_event_ids: Optional[Collection[str]] = None,
**kwargs
) -> EventBase:
"""Inject a generic event into a room
Args:
hs: the homeserver under test
room_version: the version of the room we're inserting into.
if not specified, will be looked up
prev_event_ids: prev_events for the event. If not specified, will be looked up
kwargs: fields for the event to be created
"""
test_reactor = hs.get_reactor()
if room_version is None:
d = hs.get_datastore().get_room_version_id(kwargs["room_id"])
test_reactor.advance(0)
room_version = get_awaitable_result(d)
builder = hs.get_event_builder_factory().for_room_version(
KNOWN_ROOM_VERSIONS[room_version], kwargs
)
d = hs.get_event_creation_handler().create_new_client_event(
builder, prev_event_ids=prev_event_ids
)
test_reactor.advance(0)
event, context = get_awaitable_result(d)
d = hs.get_storage().persistence.persist_event(event, context)
test_reactor.advance(0)
get_awaitable_result(d)
return event

View File

@@ -32,7 +32,6 @@ from twisted.python.threadpool import ThreadPool
from twisted.trial import unittest
from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.config.homeserver import HomeServerConfig
from synapse.config.ratelimiting import FederationRateLimitConfig
from synapse.federation.transport import server as federation_server
@@ -55,6 +54,7 @@ from tests.server import (
render,
setup_test_homeserver,
)
from tests.test_utils import event_injection
from tests.test_utils.logging_setup import setup_logging
from tests.utils import default_config, setupdb
@@ -596,36 +596,14 @@ class HomeserverTestCase(TestCase):
"""
Inject a membership event into a room.
Deprecated: use event_injection.inject_room_member directly
Args:
room: Room ID to inject the event into.
user: MXID of the user to inject the membership for.
membership: The membership type.
"""
event_builder_factory = self.hs.get_event_builder_factory()
event_creation_handler = self.hs.get_event_creation_handler()
room_version = self.get_success(
self.hs.get_datastore().get_room_version_id(room)
)
builder = event_builder_factory.for_room_version(
KNOWN_ROOM_VERSIONS[room_version],
{
"type": EventTypes.Member,
"sender": user,
"state_key": user,
"room_id": room,
"content": {"membership": membership},
},
)
event, context = self.get_success(
event_creation_handler.create_new_client_event(builder)
)
self.get_success(
self.hs.get_storage().persistence.persist_event(event, context)
)
event_injection.inject_member_event(self.hs, room, user, membership)
class FederatingHomeserverTestCase(HomeserverTestCase):

View File

@@ -74,7 +74,10 @@ def setupdb():
db_conn.autocommit = True
cur = db_conn.cursor()
cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,))
cur.execute("CREATE DATABASE %s;" % (POSTGRES_BASE_DB,))
cur.execute(
"CREATE DATABASE %s ENCODING 'UTF8' LC_COLLATE='C' LC_CTYPE='C' "
"template=template0;" % (POSTGRES_BASE_DB,)
)
cur.close()
db_conn.close()

View File

@@ -204,6 +204,8 @@ commands = mypy \
synapse/storage/database.py \
synapse/streams \
synapse/util/caches/stream_change_cache.py \
tests/replication/tcp/streams \
tests/test_utils \
tests/util/test_stream_change_cache.py
# To find all folders that pass mypy you run: