Compare commits

46 Commits

Author SHA1 Message Date
Andrew Morgan 88d8a7cd19 log invalidating 2020-05-07 15:33:57 +01:00
Andrew Morgan 14a9d59edf Up get_users_in_room cache size again 2020-05-07 15:08:52 +01:00
Andrew Morgan fac30c0554 Don't async too many things 2020-05-07 15:00:02 +01:00
Andrew Morgan 908d15e904 Convert get_users_who_share_room_with_user to async/await 2020-05-07 14:56:01 +01:00
Andrew Morgan 0a1ebe5442 More logging, make get_users_who_share_room_with_user non-iterable 2020-05-07 14:46:47 +01:00
Andrew Morgan 9316091c8c Return get_users_in_room cache size 2020-05-07 14:26:04 +01:00
Andrew Morgan 7d54f2413f Increase get_users_in_room cache 2020-05-07 14:02:49 +01:00
Andrew Morgan 98a6910d94 Cache debugging 2020-05-07 13:57:11 +01:00
Andrew Morgan c08273f529 undo more dbkr 2020-05-07 13:54:22 +01:00
Andrew Morgan edca4cf1bb More debugging 2020-05-07 12:20:36 +01:00
Andrew Morgan 9d7fcb2fda Back out always telling user about their own devices 2020-05-07 12:06:32 +01:00
Andrew Morgan 1786b1ee0d Add debug logging 2020-05-07 11:58:28 +01:00
Brendan Abolivier 2929ce29d6 Merge pull request #7398 from Starbix/alpine-3.11
Update docker runtime image to Alpine v3.11
2020-05-07 11:56:56 +02:00
Richard van der Hoff 62ee862119 Merge branch 'release-v1.13.0' into develop 2020-05-06 15:56:03 +01:00
Andrew Morgan aee9130a83 Stop Auth methods from polling the config on every req. (#7420) 2020-05-06 15:54:58 +01:00
Richard van der Hoff fa0b2bd28d Merge pull request #7428 from matrix-org/rav/cross_signing_keys_cache
Make get_e2e_cross_signing_key delegate to get_e2e_cross_signing_keys_bulk
2020-05-06 12:00:01 +01:00
Richard van der Hoff 16b67c404d Make get_e2e_cross_signing_key delegate to get_e2e_cross_signing_keys_bulk
... mostly because the latter has a cache.
2020-05-06 11:59:19 +01:00
Richard van der Hoff db5f9031b7 Fix batching for fetching cross-signing keys
There's no point carefully dividing a list into batches, and then completely
ignoring the batches.
2020-05-06 11:59:19 +01:00
Richard van der Hoff 2e0c46ca07 Merge branch 'release-v1.13.0' into develop 2020-05-06 11:58:31 +01:00
Richard van der Hoff 79007a42b2 Merge pull request #7429 from matrix-org/rav/upsert_for_device_list
use an upsert to update device_lists_outbound_last_success
2020-05-06 11:53:18 +01:00
Richard van der Hoff 30a19daa02 Merge branch 'develop' into rav/upsert_for_device_list 2020-05-06 11:43:11 +01:00
Richard van der Hoff e48361545d use an upsert to update device_lists_outbound_last_success 2020-05-06 11:41:23 +01:00
Richard van der Hoff 0f6ebf393d Better type annotations for simple_upsert_txn
most of these params don't really need to be lists.
2020-05-06 11:41:23 +01:00
Erik Johnston b26f3e582c Merge pull request #7423 from matrix-org/erikj/faster_device_lists_fetch
Speed up fetching device lists changes in sync.
2020-05-06 11:14:13 +01:00
Richard van der Hoff c255b0ffdc Merge pull request #7427 from matrix-org/rav/fix_dropped_messages
Fix lost events on replication reconnection
2020-05-06 10:54:25 +01:00
Richard van der Hoff a8c17da245 Merge branch 'release-v1.13.0' into rav/fix_dropped_messages 2020-05-05 23:01:12 +01:00
Richard van der Hoff 1242267316 Merge branch 'release-v1.13.0' into rav/fix_dropped_messages 2020-05-05 22:38:44 +01:00
Richard van der Hoff 7bf788ac73 changelog 2020-05-05 22:38:16 +01:00
Richard van der Hoff 7f7eedbebb Wait for a POSITION on the right connection before accepting RDATA
... otherwise we can believe we're up to date when we're not.
2020-05-05 22:38:16 +01:00
Brendan Abolivier 5b8023dc7f Move logs about discarded RDATA to debug (#7421) 2020-05-05 21:07:33 +02:00
Richard van der Hoff d78265af0c Wait to subscribe before sending REPLICATE 2020-05-05 19:31:37 +01:00
Richard van der Hoff 13dd458b8d Merge branch 'release-v1.13.0' into erikj/faster_device_lists_fetch 2020-05-05 18:14:00 +01:00
Richard van der Hoff 714560e325 Update changelog.d/7423.misc 2020-05-05 18:03:59 +01:00
Erik Johnston 79fe3e068b Newsfile 2020-05-05 17:40:29 +01:00
Erik Johnston f9073893af Speed up fetching device lists changes in sync.
Currently we copy `users_who_share_room` needlessly about three times,
which is expensive when the set is large (which it can easily be).
2020-05-05 17:40:29 +01:00
Richard van der Hoff 16b1a34e80 Fix typing annotations in synapse/federation (#7382)
We're pretty close to having mypy working for `synapse.federation`, so let's
finish the job.
2020-05-05 14:27:13 +01:00
Patrick Cloke fe69fb6263 Add backwards compatibility codepath to LoggingContext. (#7408) 2020-05-05 09:21:34 -04:00
Erik Johnston 7941a70fa8 Fix bug in EventContext.deserialize. (#7393)
This caused `prev_state_ids` to be incorrect if the state event was not
replacing an existing state entry.
2020-05-05 14:17:27 +01:00
Richard van der Hoff d5aa7d93ed Fix catchup-on-reconnect for the Federation Stream (#7374)
looks like we managed to break this during the refactorathon.
2020-05-05 14:15:57 +01:00
Erik Johnston 8123b2f909 Add MultiWriterIdGenerator. (#7281)
This will be used to coordinate stream IDs across multiple writers.

Functions as the equivalent of both `StreamIdGenerator` and
`SlavedIdTracker`.
2020-05-04 17:17:45 +01:00
Brendan Abolivier 15aa09bbe6 Merge branch 'release-v1.13.0' into develop 2020-05-04 16:33:56 +02:00
Brendan Abolivier 9858d5c362 Fix ordering in MANIFEST.in 2020-05-04 16:33:30 +02:00
Brendan Abolivier ad088716bc Merge pull request #7404 from matrix-org/babolivier/fix_manifest
Fix MANIFEST.in
2020-05-04 16:24:15 +02:00
Brendan Abolivier 068da604c2 Fix MANIFEST.in
An update of check-manifest shone some light on some issues with MANIFEST.in, specifically that we didn't ignore/prune the contrib directory, and that we were using prune instead of exclude for files. This fixes both issues.

Fixes #7403
2020-05-04 15:18:06 +02:00
Erik Johnston 350421e058 Fix redis password support. (#7401)
We forgot to set the password on the subscriber connection, as well as
not calling super methods for overridden connectionMade/connectionLost
functions.
2020-05-04 14:04:09 +01:00
Cédric Laubacher a251e0f4ba Update runtime docker image to Alpine v3.11 2020-05-03 16:07:24 +02:00
59 changed files with 1200 additions and 402 deletions
+6 -5
View File
@@ -30,23 +30,24 @@ recursive-include synapse/static *.gif
recursive-include synapse/static *.html
recursive-include synapse/static *.js
exclude Dockerfile
exclude .codecov.yml
exclude .coveragerc
exclude .dockerignore
exclude test_postgresql.sh
exclude .editorconfig
exclude Dockerfile
exclude mypy.ini
exclude sytest-blacklist
exclude test_postgresql.sh
include pyproject.toml
recursive-include changelog.d *
prune .buildkite
prune .circleci
prune .codecov.yml
prune .coveragerc
prune .github
prune contrib
prune debian
prune demo/etc
prune docker
prune mypy.ini
prune snap
prune stubs
+1 -1
View File
@@ -1 +1 @@
Add typing information to federation server code.
Add typing annotations in `synapse.federation`.
+1
View File
@@ -0,0 +1 @@
Add MultiWriterIdGenerator to support multiple concurrent writers of streams.
+1
View File
@@ -0,0 +1 @@
Move catchup of replication streams logic to worker.
+1
View File
@@ -0,0 +1 @@
Add typing annotations in `synapse.federation`.
+1
View File
@@ -0,0 +1 @@
Fix bug in `EventContext.deserialize`.
+1
View File
@@ -0,0 +1 @@
Update docker runtime image to Alpine v3.11. Contributed by @Starbix.
+1
View File
@@ -0,0 +1 @@
Add support for running replication over Redis when using workers.
+1
View File
@@ -0,0 +1 @@
Fix issues with the Python package manifest.
+1
View File
@@ -0,0 +1 @@
Clean up some LoggingContext code.
+1
View File
@@ -0,0 +1 @@
Prevent methods in `synapse.handlers.auth` from polling the homeserver config every request.
+1
View File
@@ -0,0 +1 @@
Move catchup of replication streams logic to worker.
+1
View File
@@ -0,0 +1 @@
Speed up fetching device lists changes when handling `/sync` requests.
+1
View File
@@ -0,0 +1 @@
Add support for running replication over Redis when using workers.
+1
View File
@@ -0,0 +1 @@
Improve performance of `get_e2e_cross_signing_key`.
+1
View File
@@ -0,0 +1 @@
Improve performance of `mark_as_sent_devices_by_remote`.
+1 -1
View File
@@ -55,7 +55,7 @@ RUN pip install --prefix="/install" --no-warn-script-location \
### Stage 1: runtime
###
FROM docker.io/python:${PYTHON_VERSION}-alpine3.10
FROM docker.io/python:${PYTHON_VERSION}-alpine3.11
# xmlsec is required for saml support
RUN apk add --no-cache --virtual .runtime_deps \
+3
View File
@@ -22,7 +22,10 @@ class RedisProtocol:
def publish(self, channel: str, message: bytes): ...
class SubscriberProtocol:
password: Optional[str]
def subscribe(self, channels: Union[str, List[str]]): ...
def connectionMade(self): ...
def connectionLost(self, reason): ...
def lazyConnection(
host: str = ...,
+10 -73
View File
@@ -26,16 +26,15 @@ from twisted.internet import defer
import synapse.logging.opentracing as opentracing
import synapse.types
from synapse import event_auth
from synapse.api.constants import EventTypes, LimitBlockingTypes, Membership, UserTypes
from synapse.api.auth_blocking import AuthBlocking
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import (
AuthError,
Codes,
InvalidClientTokenError,
MissingClientTokenError,
ResourceLimitError,
)
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.config.server import is_threepid_reserved
from synapse.events import EventBase
from synapse.types import StateMap, UserID
from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache
@@ -77,7 +76,11 @@ class Auth(object):
self.token_cache = LruCache(CACHE_SIZE_FACTOR * 10000)
register_cache("cache", "token_cache", self.token_cache)
self._auth_blocking = AuthBlocking(self.hs)
self._account_validity = hs.config.account_validity
self._track_appservice_user_ips = hs.config.track_appservice_user_ips
self._macaroon_secret_key = hs.config.macaroon_secret_key
@defer.inlineCallbacks
def check_from_context(self, room_version: str, event, context, do_sig_check=True):
@@ -191,7 +194,7 @@ class Auth(object):
opentracing.set_tag("authenticated_entity", user_id)
opentracing.set_tag("appservice_id", app_service.id)
if ip_addr and self.hs.config.track_appservice_user_ips:
if ip_addr and self._track_appservice_user_ips:
yield self.store.insert_client_ip(
user_id=user_id,
access_token=access_token,
@@ -454,7 +457,7 @@ class Auth(object):
# access_tokens include a nonce for uniqueness: any value is acceptable
v.satisfy_general(lambda c: c.startswith("nonce = "))
v.verify(macaroon, self.hs.config.macaroon_secret_key)
v.verify(macaroon, self._macaroon_secret_key)
def _verify_expiry(self, caveat):
prefix = "time < "
@@ -663,71 +666,5 @@ class Auth(object):
% (user_id, room_id),
)
@defer.inlineCallbacks
def check_auth_blocking(self, user_id=None, threepid=None, user_type=None):
"""Checks if the user should be rejected for some external reason,
such as monthly active user limiting or global disable flag
Args:
user_id(str|None): If present, checks for presence against existing
MAU cohort
threepid(dict|None): If present, checks for presence against configured
reserved threepid. Used in cases where the user is trying register
with a MAU blocked server, normally they would be rejected but their
threepid is on the reserved list. user_id and
threepid should never be set at the same time.
user_type(str|None): If present, is used to decide whether to check against
certain blocking reasons like MAU.
"""
# Never fail an auth check for the server notices users or support user
# This can be a problem where event creation is prohibited due to blocking
if user_id is not None:
if user_id == self.hs.config.server_notices_mxid:
return
if (yield self.store.is_support_user(user_id)):
return
if self.hs.config.hs_disabled:
raise ResourceLimitError(
403,
self.hs.config.hs_disabled_message,
errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
admin_contact=self.hs.config.admin_contact,
limit_type=LimitBlockingTypes.HS_DISABLED,
)
if self.hs.config.limit_usage_by_mau is True:
assert not (user_id and threepid)
# If the user is already part of the MAU cohort or a trial user
if user_id:
timestamp = yield self.store.user_last_seen_monthly_active(user_id)
if timestamp:
return
is_trial = yield self.store.is_trial_user(user_id)
if is_trial:
return
elif threepid:
# If the user does not exist yet, but is signing up with a
# reserved threepid then pass auth check
if is_threepid_reserved(
self.hs.config.mau_limits_reserved_threepids, threepid
):
return
elif user_type == UserTypes.SUPPORT:
# If the user does not exist yet and is of type "support",
# allow registration. Support users are excluded from MAU checks.
return
# Else if there is no room in the MAU bucket, bail
current_mau = yield self.store.get_monthly_active_count()
if current_mau >= self.hs.config.max_mau_value:
raise ResourceLimitError(
403,
"Monthly Active User Limit Exceeded",
admin_contact=self.hs.config.admin_contact,
errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER,
)
def check_auth_blocking(self, *args, **kwargs):
return self._auth_blocking.check_auth_blocking(*args, **kwargs)
+104
View File
@@ -0,0 +1,104 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from twisted.internet import defer
from synapse.api.constants import LimitBlockingTypes, UserTypes
from synapse.api.errors import Codes, ResourceLimitError
from synapse.config.server import is_threepid_reserved
logger = logging.getLogger(__name__)
class AuthBlocking(object):
def __init__(self, hs):
self.store = hs.get_datastore()
self._server_notices_mxid = hs.config.server_notices_mxid
self._hs_disabled = hs.config.hs_disabled
self._hs_disabled_message = hs.config.hs_disabled_message
self._admin_contact = hs.config.admin_contact
self._max_mau_value = hs.config.max_mau_value
self._limit_usage_by_mau = hs.config.limit_usage_by_mau
self._mau_limits_reserved_threepids = hs.config.mau_limits_reserved_threepids
@defer.inlineCallbacks
def check_auth_blocking(self, user_id=None, threepid=None, user_type=None):
"""Checks if the user should be rejected for some external reason,
such as monthly active user limiting or global disable flag
Args:
user_id(str|None): If present, checks for presence against existing
MAU cohort
threepid(dict|None): If present, checks for presence against configured
reserved threepid. Used in cases where the user is trying register
with a MAU blocked server, normally they would be rejected but their
threepid is on the reserved list. user_id and
threepid should never be set at the same time.
user_type(str|None): If present, is used to decide whether to check against
certain blocking reasons like MAU.
"""
# Never fail an auth check for the server notices users or support user
# This can be a problem where event creation is prohibited due to blocking
if user_id is not None:
if user_id == self._server_notices_mxid:
return
if (yield self.store.is_support_user(user_id)):
return
if self._hs_disabled:
raise ResourceLimitError(
403,
self._hs_disabled_message,
errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
admin_contact=self._admin_contact,
limit_type=LimitBlockingTypes.HS_DISABLED,
)
if self._limit_usage_by_mau is True:
assert not (user_id and threepid)
# If the user is already part of the MAU cohort or a trial user
if user_id:
timestamp = yield self.store.user_last_seen_monthly_active(user_id)
if timestamp:
return
is_trial = yield self.store.is_trial_user(user_id)
if is_trial:
return
elif threepid:
# If the user does not exist yet, but is signing up with a
# reserved threepid then pass auth check
if is_threepid_reserved(self._mau_limits_reserved_threepids, threepid):
return
elif user_type == UserTypes.SUPPORT:
# If the user does not exist yet and is of type "support",
# allow registration. Support users are excluded from MAU checks.
return
# Else if there is no room in the MAU bucket, bail
current_mau = yield self.store.get_monthly_active_count()
if current_mau >= self._max_mau_value:
raise ResourceLimitError(
403,
"Monthly Active User Limit Exceeded",
admin_contact=self._admin_contact,
errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER,
)
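The new class above reads each config option once in `__init__` instead of on every call. A minimal sketch of that pattern follows, written synchronously rather than with Twisted deferreds. `FakeConfig`, `BlockingChecker` and `LimitExceeded` are hypothetical stand-ins for illustration only, not Synapse names.

```python
class FakeConfig:
    """Hypothetical stand-in for hs.config that counts option reads."""

    def __init__(self, **values):
        self._values = values
        self.reads = 0

    def __getattr__(self, name):
        # Only reached for names not found normally, i.e. config options.
        values = self.__dict__.get("_values", {})
        if name not in values:
            raise AttributeError(name)
        self.reads += 1
        return values[name]


class LimitExceeded(Exception):
    """Hypothetical stand-in for ResourceLimitError."""


class BlockingChecker:
    def __init__(self, config):
        # Read each option once, at construction time.
        self._hs_disabled = config.hs_disabled
        self._limit_usage_by_mau = config.limit_usage_by_mau
        self._max_mau_value = config.max_mau_value

    def check(self, current_mau):
        # No config access here: only the cached attributes are used.
        if self._hs_disabled:
            raise LimitExceeded("homeserver disabled")
        if self._limit_usage_by_mau and current_mau >= self._max_mau_value:
            raise LimitExceeded("Monthly Active User Limit Exceeded")


config = FakeConfig(hs_disabled=False, limit_usage_by_mau=True, max_mau_value=2)
checker = BlockingChecker(config)
reads_after_init = config.reads
for mau in (0, 1):
    checker.check(mau)
```

In this sketch the read counter stops growing once the checker is constructed. The flip side is that changes made to the config object afterwards are not seen by the checker.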
+5 -2
View File
@@ -322,11 +322,14 @@ class _AsyncEventContextImpl(EventContext):
self._current_state_ids = yield self._storage.state.get_state_ids_for_group(
self.state_group
)
if self._prev_state_id and self._event_state_key is not None:
if self._event_state_key is not None:
self._prev_state_ids = dict(self._current_state_ids)
key = (self._event_type, self._event_state_key)
self._prev_state_ids[key] = self._prev_state_id
if self._prev_state_id:
self._prev_state_ids[key] = self._prev_state_id
else:
self._prev_state_ids.pop(key, None)
else:
self._prev_state_ids = self._current_state_ids
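The hunk above changes how `_prev_state_ids` is derived when the event is a state event with no previous entry for its key. A minimal standalone sketch of the corrected logic, assuming plain dicts keyed by `(event_type, state_key)`; the function name `derive_prev_state_ids` is hypothetical:

```python
from typing import Dict, Optional, Tuple

StateKey = Tuple[str, str]


def derive_prev_state_ids(
    current_state_ids: Dict[StateKey, str],
    event_type: str,
    event_state_key: Optional[str],
    prev_state_id: Optional[str],
) -> Dict[StateKey, str]:
    if event_state_key is None:
        # Not a state event: state before and after the event is identical.
        return current_state_ids

    prev = dict(current_state_ids)
    key = (event_type, event_state_key)
    if prev_state_id:
        # The event replaced an existing entry: restore the old event ID.
        prev[key] = prev_state_id
    else:
        # The event added a new entry: it must not appear in the prior state.
        prev.pop(key, None)
    return prev


current = {
    ("m.room.create", ""): "$create",
    ("m.room.member", "@a:example.org"): "$join",
}
added = derive_prev_state_ids(current, "m.room.member", "@a:example.org", None)
replaced = derive_prev_state_ids(current, "m.room.member", "@a:example.org", "$invite")
```

The old condition skipped the copy entirely when `prev_state_id` was empty, so a newly added state key would have been left in the previous state.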
+54 -30
View File
@@ -31,6 +31,7 @@ Events are replicated via a separate events stream.
import logging
from collections import namedtuple
from typing import Dict, List, Tuple, Type
from six import iteritems
@@ -56,21 +57,35 @@ class FederationRemoteSendQueue(object):
self.notifier = hs.get_notifier()
self.is_mine_id = hs.is_mine_id
self.presence_map = {} # Pending presence map user_id -> UserPresenceState
self.presence_changed = SortedDict() # Stream position -> list[user_id]
# Pending presence map user_id -> UserPresenceState
self.presence_map = {} # type: Dict[str, UserPresenceState]
# Stream position -> list[user_id]
self.presence_changed = SortedDict() # type: SortedDict[int, List[str]]
# Stores the destinations we need to explicitly send presence to about a
# given user.
# Stream position -> (user_id, destinations)
self.presence_destinations = SortedDict()
self.presence_destinations = (
SortedDict()
) # type: SortedDict[int, Tuple[str, List[str]]]
self.keyed_edu = {} # (destination, key) -> EDU
self.keyed_edu_changed = SortedDict() # stream position -> (destination, key)
# (destination, key) -> EDU
self.keyed_edu = {} # type: Dict[Tuple[str, tuple], Edu]
self.edus = SortedDict() # stream position -> Edu
# stream position -> (destination, key)
self.keyed_edu_changed = (
SortedDict()
) # type: SortedDict[int, Tuple[str, tuple]]
self.edus = SortedDict() # type: SortedDict[int, Edu]
# stream ID for the next entry into presence_changed/keyed_edu_changed/edus.
self.pos = 1
self.pos_time = SortedDict()
# map from stream ID to the time that stream entry was generated, so that we
# can clear out entries after a while
self.pos_time = SortedDict() # type: SortedDict[int, int]
# EVERYTHING IS SAD. In particular, python only makes new scopes when
# we make a new function, so we need to make a new function so the inner
@@ -158,8 +173,10 @@ class FederationRemoteSendQueue(object):
for edu_key in self.keyed_edu_changed.values():
live_keys.add(edu_key)
to_del = [edu_key for edu_key in self.keyed_edu if edu_key not in live_keys]
for edu_key in to_del:
keys_to_del = [
edu_key for edu_key in self.keyed_edu if edu_key not in live_keys
]
for edu_key in keys_to_del:
del self.keyed_edu[edu_key]
# Delete things out of edu map
@@ -250,19 +267,23 @@ class FederationRemoteSendQueue(object):
self._clear_queue_before_pos(token)
async def get_replication_rows(
self, from_token, to_token, limit, federation_ack=None
):
self, instance_name: str, from_token: int, to_token: int, target_row_count: int
) -> Tuple[List[Tuple[int, Tuple]], int, bool]:
"""Get rows to be sent over federation between the two tokens
Args:
from_token (int)
to_token(int)
limit (int)
federation_ack (int): Optional. The position where the worker is
explicitly acknowledged it has handled. Allows us to drop
data from before that point
instance_name: the name of the current process
from_token: the previous stream token: the starting point for fetching the
updates
to_token: the new stream token: the point to get updates up to
target_row_count: a target for the number of rows to be returned.
Returns: a triplet `(updates, new_last_token, limited)`, where:
* `updates` is a list of `(token, row)` entries.
* `new_last_token` is the new position in stream.
* `limited` is whether there are more updates to fetch.
"""
# TODO: Handle limit.
# TODO: Handle target_row_count.
# To handle restarts where we wrap around
if from_token > self.pos:
@@ -270,12 +291,7 @@ class FederationRemoteSendQueue(object):
# list of tuple(int, BaseFederationRow), where the first is the position
# of the federation stream.
rows = []
# There should be only one reader, so lets delete everything its
# acknowledged its seen.
if federation_ack:
self._clear_queue_before_pos(federation_ack)
rows = [] # type: List[Tuple[int, BaseFederationRow]]
# Fetch changed presence
i = self.presence_changed.bisect_right(from_token)
@@ -332,7 +348,11 @@ class FederationRemoteSendQueue(object):
# Sort rows based on pos
rows.sort()
return [(pos, row.TypeId, row.to_data()) for pos, row in rows]
return (
[(pos, (row.TypeId, row.to_data())) for pos, row in rows],
to_token,
False,
)
class BaseFederationRow(object):
@@ -341,7 +361,7 @@ class BaseFederationRow(object):
Specifies how to identify, serialize and deserialize the different types.
"""
TypeId = None # Unique string that ids the type. Must be overriden in sub classes.
TypeId = "" # Unique string that ids the type. Must be overriden in sub classes.
@staticmethod
def from_data(data):
@@ -454,10 +474,14 @@ class EduRow(BaseFederationRow, namedtuple("EduRow", ("edu",))): # Edu
buff.edus.setdefault(self.edu.destination, []).append(self.edu)
TypeToRow = {
Row.TypeId: Row
for Row in (PresenceRow, PresenceDestinationsRow, KeyedEduRow, EduRow,)
}
_rowtypes = (
PresenceRow,
PresenceDestinationsRow,
KeyedEduRow,
EduRow,
) # type: Tuple[Type[BaseFederationRow], ...]
TypeToRow = {Row.TypeId: Row for Row in _rowtypes}
ParsedFederationStreamData = namedtuple(
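The new `get_replication_rows` signature returns a triplet `(updates, new_last_token, limited)`. The diff leaves `target_row_count` as a TODO and always returns `limited=False`, so the following is only a sketch of how such a triplet could be produced and consumed if the row target were honoured. `get_rows`, `fetch_all` and the `pending` dict are hypothetical, as is the choice of an exclusive `from_token` and inclusive `to_token`.

```python
from typing import Dict, List, Tuple

Row = Tuple[str, str]  # (type id, serialized data), simplified for the sketch


def get_rows(
    pending: Dict[int, Row], from_token: int, to_token: int, target_row_count: int
) -> Tuple[List[Tuple[int, Row]], int, bool]:
    # Assumes target_row_count >= 1.
    positions = sorted(p for p in pending if from_token < p <= to_token)
    limited = len(positions) > target_row_count
    if limited:
        positions = positions[:target_row_count]
        new_last_token = positions[-1]
    else:
        new_last_token = to_token
    return [(p, pending[p]) for p in positions], new_last_token, limited


def fetch_all(
    pending: Dict[int, Row], from_token: int, to_token: int, target_row_count: int = 2
) -> List[Tuple[int, Row]]:
    rows = []  # type: List[Tuple[int, Row]]
    limited = True
    while limited:
        # Resume from the token returned by the previous batch.
        batch, from_token, limited = get_rows(
            pending, from_token, to_token, target_row_count
        )
        rows.extend(batch)
    return rows


pending = {1: ("p", "a"), 2: ("e", "b"), 3: ("e", "c"), 5: ("p", "d")}
first_batch = get_rows(pending, 0, 5, 2)
everything = fetch_all(pending, 0, 5)
```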
+7 -5
View File
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
from typing import Dict, Hashable, Iterable, List, Optional, Set
from typing import Dict, Hashable, Iterable, List, Optional, Set, Tuple
from six import itervalues
@@ -498,14 +498,16 @@ class FederationSender(object):
self._get_per_destination_queue(destination).attempt_new_transaction()
def get_current_token(self) -> int:
@staticmethod
def get_current_token() -> int:
# Dummy implementation for case where federation sender isn't offloaded
# to a worker.
return 0
@staticmethod
async def get_replication_rows(
self, from_token, to_token, limit, federation_ack=None
):
instance_name: str, from_token: int, to_token: int, target_row_count: int
) -> Tuple[List[Tuple[int, Tuple]], int, bool]:
# Dummy implementation for case where federation sender isn't offloaded
# to a worker.
return []
return [], 0, False
@@ -15,11 +15,10 @@
# limitations under the License.
import datetime
import logging
from typing import Dict, Hashable, Iterable, List, Tuple
from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Tuple
from prometheus_client import Counter
import synapse.server
from synapse.api.errors import (
FederationDeniedError,
HttpResponseException,
@@ -34,6 +33,9 @@ from synapse.storage.presence import UserPresenceState
from synapse.types import ReadReceipt
from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
if TYPE_CHECKING:
import synapse.server
# This is defined in the Matrix spec and enforced by the receiver.
MAX_EDUS_PER_TRANSACTION = 100
@@ -13,11 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import List
from typing import TYPE_CHECKING, List
from canonicaljson import json
import synapse.server
from synapse.api.errors import HttpResponseException
from synapse.events import EventBase
from synapse.federation.persistence import TransactionActions
@@ -31,6 +30,9 @@ from synapse.logging.opentracing import (
)
from synapse.util.metrics import measure_func
if TYPE_CHECKING:
import synapse.server
logger = logging.getLogger(__name__)
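Both hunks above move `import synapse.server` under `if TYPE_CHECKING:`. The diff does not state the motivation; a common reason for this pattern is to keep an import that is only needed for annotations from running at import time. A minimal sketch, where `hypothetical_server_module` and `TransactionSender` are made-up names:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated by static type checkers only. At runtime TYPE_CHECKING is
    # False, so this import never executes.
    import hypothetical_server_module


class TransactionSender:
    # The annotation is a string, so it is not evaluated at runtime and the
    # module does not need to be importable here.
    def __init__(self, hs: "hypothetical_server_module.HomeServer") -> None:
        self.hs = hs


sender = TransactionSender(object())
```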
+8 -7
View File
@@ -122,17 +122,17 @@ class DeviceWorkerHandler(BaseHandler):
# First we check if any devices have changed for users that we share
# rooms with.
users_who_share_room = yield self.store.get_users_who_share_room_with_user(
users_who_share_room = yield defer.ensureDeferred(self.store.get_users_who_share_room_with_user(
user_id
)
))
tracked_users = set(users_who_share_room)
#tracked_users = set(users_who_share_room)
# Always tell the user about their own devices
tracked_users.add(user_id)
#tracked_users.add(user_id)
changed = yield self.store.get_users_whose_devices_changed(
from_token.device_list_key, tracked_users
from_token.device_list_key, users_who_share_room
)
# Then work out if any users have since joined
@@ -444,9 +444,10 @@ class DeviceHandler(DeviceWorkerHandler):
"""Notify that a user's device(s) has changed. Pokes the notifier, and
remote servers if the user is local.
"""
users_who_share_room = yield self.store.get_users_who_share_room_with_user(
logger.info("get_users_who_share_room... called from notify_device_update")
users_who_share_room = yield defer.ensureDeferred(self.store.get_users_who_share_room_with_user(
user_id
)
))
hosts = set()
if self.hs.is_mine_id(user_id):
+1
View File
@@ -172,6 +172,7 @@ class EventHandler(BaseHandler):
if not event:
return None
logger.info("get_users_in_room called from get_event!")
users = await self.store.get_users_in_room(event.room_id)
is_peeking = user.to_string() not in users
+6 -20
View File
@@ -1547,7 +1547,7 @@ class FederationHandler(BaseHandler):
async def do_remotely_reject_invite(
self, target_hosts: Iterable[str], room_id: str, user_id: str, content: JsonDict
) -> Optional[EventBase]:
) -> EventBase:
origin, event, room_version = await self._make_and_verify_event(
target_hosts, room_id, user_id, "leave", content=content
)
@@ -1564,26 +1564,12 @@ class FederationHandler(BaseHandler):
except ValueError:
pass
try:
await self.federation_client.send_leave(target_hosts, event)
return event
except Exception as e:
# if we were unable to reject the exception, just mark
# it as rejected on our end and plough ahead.
#
# The 'except' clause is very broad, but we need to
# capture everything from DNS failures upwards
#
logger.warning("Failed to reject invite: %s", e)
await self.federation_client.send_leave(target_hosts, event)
await self.store.locally_reject_invite(user_id, room_id)
return None
finally:
# This block will always run before returning, and will return with
# whatever value was returned in the try/except blocks
# (it will not, for example, be over-written by None)
context = await self.state_handler.compute_event_context(event)
await self.persist_events_and_notify([(event, context)])
context = await self.state_handler.compute_event_context(event)
await self.persist_events_and_notify([(event, context)])
return event
async def _make_and_verify_event(
self,
+1
View File
@@ -1098,6 +1098,7 @@ class PresenceEventSource(object):
users_interested_in = set()
users_interested_in.add(user_id) # So that we receive our own presence
logger.info("get_users_who_share_room... _get_interested_in")
users_who_share_room = await self.store.get_users_who_share_room_with_user(
user_id, on_invalidate=cache_context.invalidate
)
+18 -5
View File
@@ -976,12 +976,25 @@ class RoomMemberMasterHandler(RoomMemberHandler):
):
"""Implements RoomMemberHandler._remote_reject_invite
"""
ret = yield defer.ensureDeferred(
self.federation_handler.do_remotely_reject_invite(
remote_room_hosts, room_id, target.to_string(), content=content,
fed_handler = self.federation_handler
try:
ret = yield defer.ensureDeferred(
fed_handler.do_remotely_reject_invite(
remote_room_hosts, room_id, target.to_string(), content=content,
)
)
)
return ret if ret else {}
return ret
except Exception as e:
# if we were unable to reject the exception, just mark
# it as rejected on our end and plough ahead.
#
# The 'except' clause is very broad, but we need to
# capture everything from DNS failures upwards
#
logger.warning("Failed to reject invite: %s", e)
yield self.store.locally_reject_invite(target.to_string(), room_id)
return {}
def _user_joined_room(self, target, room_id):
"""Implements RoomMemberHandler._user_joined_room
+13 -4
View File
@@ -308,12 +308,14 @@ class SyncHandler(object):
if timeout == 0 or since_token is None or full_state:
# we are going to return immediately, so don't bother calling
# notifier.wait_for_events.
logger.info("_wait_for_sync_for_user1")
result = await self.current_sync_for_user(
sync_config, since_token, full_state=full_state
)
else:
def current_sync_callback(before_token, after_token):
logger.info("_wait_for_sync_for_user2")
return self.current_sync_for_user(sync_config, since_token)
result = await self.notifier.wait_for_events(
@@ -340,6 +342,7 @@ class SyncHandler(object):
) -> SyncResult:
"""Get the sync for client needed to match what the server has now.
"""
logger.info("current_sync_for_user")
return await self.generate_sync_result(sync_config, since_token, full_state)
async def push_rules_for_user(self, user: UserID) -> JsonDict:
@@ -1139,18 +1142,24 @@ class SyncHandler(object):
# room with by looking at all users that have left a room plus users
# that were in a room we've left.
logger.info("get_users_who_share_room... called from _generate_sync_entry")
logger.info("*Called with %s", user_id)
users_who_share_room = await self.store.get_users_who_share_room_with_user(
user_id
)
tracked_users = set(users_who_share_room)
# Always tell the user about their own devices. We check as the user
# ID is almost certainly already included (unless they're not in any
# rooms) and taking a copy of the set is relatively expensive.
#if user_id not in users_who_share_room:
# users_who_share_room = set(users_who_share_room)
# users_who_share_room.add(user_id)
# Always tell the user about their own devices
tracked_users.add(user_id)
#tracked_users = users_who_share_room
# Step 1a, check for changes in devices of users we share a room with
users_that_have_changed = await self.store.get_users_whose_devices_changed(
since_token.device_list_key, tracked_users
since_token.device_list_key, users_who_share_room
)
# Step 1b, check for newly joined rooms
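The comment in the hunk above describes copying the set only when the user's own ID is missing from it; in this diff that code appears commented out. A minimal sketch of the idea, with a hypothetical helper name `with_user`:

```python
from typing import AbstractSet


def with_user(users: AbstractSet[str], user_id: str) -> AbstractSet[str]:
    """Return a set containing `users` plus `user_id`, copying only if needed."""
    if user_id in users:
        # Common case: the user shares a room with themselves, so no copy.
        return users
    copied = set(users)
    copied.add(user_id)
    return copied


shared = frozenset({"@a:example.org", "@b:example.org"})
same = with_user(shared, "@a:example.org")
extended = with_user(shared, "@c:example.org")
```

The membership test is cheap compared with copying a large set, which is the cost the commit message for this change calls out.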
+41
View File
@@ -27,6 +27,7 @@ import inspect
import logging
import threading
import types
import warnings
from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union
from typing_extensions import Literal
@@ -287,6 +288,46 @@ class LoggingContext(object):
return str(self.request)
return "%s@%x" % (self.name, id(self))
@classmethod
def current_context(cls) -> LoggingContextOrSentinel:
"""Get the current logging context from thread local storage
This exists for backwards compatibility. ``current_context()`` should be
called directly.
Returns:
LoggingContext: the current logging context
"""
warnings.warn(
"synapse.logging.context.LoggingContext.current_context() is deprecated "
"in favor of synapse.logging.context.current_context().",
DeprecationWarning,
stacklevel=2,
)
return current_context()
@classmethod
def set_current_context(
cls, context: LoggingContextOrSentinel
) -> LoggingContextOrSentinel:
"""Set the current logging context in thread local storage
This exists for backwards compatibility. ``set_current_context()`` should be
called directly.
Args:
context(LoggingContext): The context to activate.
Returns:
The context that was previously active
"""
warnings.warn(
"synapse.logging.context.LoggingContext.set_current_context() is deprecated "
"in favor of synapse.logging.context.set_current_context().",
DeprecationWarning,
stacklevel=2,
)
return set_current_context(context)
def __enter__(self) -> "LoggingContext":
"""Enters this logging context into thread local storage"""
old_context = set_current_context(self)
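The hunk above keeps the old classmethods alive as deprecation shims that warn and then delegate to the module-level functions. A minimal stand-alone sketch of that pattern follows; the `_current` variable is a hypothetical stand-in for the thread-local storage and is not part of Synapse.

```python
import warnings

# Hypothetical stand-in for the thread-local "current context" storage.
_current = "sentinel"


def current_context():
    """Module-level accessor: the preferred entry point."""
    return _current


class LoggingContext:
    @classmethod
    def current_context(cls):
        """Deprecated shim: warn, then delegate to the module-level function."""
        warnings.warn(
            "LoggingContext.current_context() is deprecated in favor of "
            "current_context().",
            DeprecationWarning,
            stacklevel=2,  # attribute the warning to the caller, not the shim
        )
        # Inside a method body this name resolves to the module-level function.
        return current_context()


# Capture warnings so the deprecation can be observed explicitly.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = LoggingContext.current_context()
```

`stacklevel=2` makes the warning point at the code that still calls the old classmethod, which is usually what a reader of the log wants to find.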
+18 -4
@@ -135,10 +135,24 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
logger.info("remote_reject_invite: %s out of room: %s", user_id, room_id)
event = await self.federation_handler.do_remotely_reject_invite(
remote_room_hosts, room_id, user_id, event_content,
)
return 200, event.get_pdu_json() if event else 200, {}
try:
event = await self.federation_handler.do_remotely_reject_invite(
remote_room_hosts, room_id, user_id, event_content,
)
ret = event.get_pdu_json()
except Exception as e:
# if we were unable to reject the invite, just mark
# it as rejected on our end and plough ahead.
#
# The 'except' clause is very broad, but we need to
# capture everything from DNS failures upwards
#
logger.warning("Failed to reject invite: %s", e)
await self.store.locally_reject_invite(user_id, room_id)
ret = {}
return 200, ret
class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
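The new servlet body above falls back to a local rejection when the remote rejection fails for any reason. A minimal sketch of that control flow, assuming a hypothetical `FakeStore` and a remote call that always fails (neither is Synapse code):

```python
import asyncio
import logging

logger = logging.getLogger(__name__)


class FakeStore:
    """Hypothetical in-memory store, used only for this illustration."""

    def __init__(self):
        self.locally_rejected = []

    async def locally_reject_invite(self, user_id, room_id):
        self.locally_rejected.append((user_id, room_id))


async def failing_remote_reject(room_id, user_id):
    # Simulates a remote failure such as a DNS error.
    raise ConnectionError("DNS lookup failed")


async def reject_invite(store, remote_reject, user_id, room_id):
    try:
        ret = await remote_reject(room_id, user_id)
    except Exception as e:
        # Deliberately broad: any remote failure falls back to a local reject.
        logger.warning("Failed to reject invite: %s", e)
        await store.locally_reject_invite(user_id, room_id)
        ret = {}
    return 200, ret


store = FakeStore()
status, body = asyncio.run(
    reject_invite(
        store, failing_remote_reject, "@alice:example.org", "!room:example.org"
    )
)
```

The caller always sees a 200 response; the only observable difference on failure is the empty body and the locally recorded rejection.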
@@ -135,6 +135,7 @@ class SlavedEventStore(
)
if data.type == EventTypes.Member:
logger.info("INVALIDATING get_rooms_for_user_with_stream_ordering")
self.get_rooms_for_user_with_stream_ordering.invalidate(
(data.state_key,)
)
+38 -17
@@ -81,9 +81,6 @@ class ReplicationCommandHandler:
self._instance_id = hs.get_instance_id()
self._instance_name = hs.get_instance_name()
# Set of streams that we've caught up with.
self._streams_connected = set() # type: Set[str]
self._streams = {
stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
} # type: Dict[str, Stream]
@@ -99,9 +96,13 @@ class ReplicationCommandHandler:
# The factory used to create connections.
self._factory = None # type: Optional[ReconnectingClientFactory]
# The currently connected connections.
# The currently connected connections. (The list of places we need to send
# outgoing replication commands to.)
self._connections = [] # type: List[AbstractConnection]
# For each connection, the incoming streams that are coming from that connection
self._streams_by_connection = {} # type: Dict[AbstractConnection, Set[str]]
LaterGauge(
"synapse_replication_tcp_resource_total_connections",
"",
@@ -257,12 +258,14 @@ class ReplicationCommandHandler:
# 2. so we don't race with getting a POSITION command and fetching
# missing RDATA.
with await self._position_linearizer.queue(cmd.stream_name):
if stream_name not in self._streams_connected:
# If the stream isn't marked as connected then we haven't seen a
# `POSITION` command yet, and so we may have missed some rows.
# make sure that we've processed a POSITION for this stream *on this
# connection*. (A POSITION on another connection is no good, as there
# is no guarantee that we have seen all the intermediate updates.)
sbc = self._streams_by_connection.get(conn)
if not sbc or stream_name not in sbc:
# Let's drop the row for now, on the assumption we'll receive a
# `POSITION` soon and we'll catch up correctly then.
logger.warning(
logger.debug(
"Discarding RDATA for unconnected stream %s -> %s",
stream_name,
cmd.token,
@@ -302,21 +305,25 @@ class ReplicationCommandHandler:
# Ignore POSITION that are just our own echoes
return
stream = self._streams.get(cmd.stream_name)
logger.info("Handling '%s %s'", cmd.NAME, cmd.to_line())
stream_name = cmd.stream_name
stream = self._streams.get(stream_name)
if not stream:
logger.error("Got POSITION for unknown stream: %s", cmd.stream_name)
logger.error("Got POSITION for unknown stream: %s", stream_name)
return
# We protect catching up with a linearizer in case the replication
# connection reconnects under us.
with await self._position_linearizer.queue(cmd.stream_name):
with await self._position_linearizer.queue(stream_name):
# We're about to go and catch up with the stream, so remove from set
# of connected streams.
self._streams_connected.discard(cmd.stream_name)
for streams in self._streams_by_connection.values():
streams.discard(stream_name)
# We clear the pending batches for the stream as the fetching of the
# missing updates below will fetch all rows in the batch.
self._pending_batches.pop(cmd.stream_name, [])
self._pending_batches.pop(stream_name, [])
# Find where we previously streamed up to.
current_token = stream.current_token()
@@ -326,6 +333,12 @@ class ReplicationCommandHandler:
# between then and now.
missing_updates = cmd.token != current_token
while missing_updates:
logger.info(
"Fetching replication rows for '%s' between %i and %i",
stream_name,
current_token,
cmd.token,
)
(
updates,
current_token,
@@ -341,16 +354,18 @@ class ReplicationCommandHandler:
for token, rows in _batch_updates(updates):
await self.on_rdata(
cmd.stream_name,
stream_name,
cmd.instance_name,
token,
[stream.parse_row(row) for row in rows],
)
# We've now caught up to position sent to us, notify handler.
await self._replication_data_handler.on_position(cmd.stream_name, cmd.token)
logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token)
self._streams_connected.add(cmd.stream_name)
# We've now caught up to position sent to us, notify handler.
await self._replication_data_handler.on_position(stream_name, cmd.token)
self._streams_by_connection.setdefault(conn, set()).add(stream_name)
async def on_REMOTE_SERVER_UP(
self, conn: AbstractConnection, cmd: RemoteServerUpCommand
@@ -408,6 +423,12 @@ class ReplicationCommandHandler:
def lost_connection(self, connection: AbstractConnection):
"""Called when a connection is closed/lost.
"""
# we no longer need _streams_by_connection for this connection.
streams = self._streams_by_connection.pop(connection, None)
if streams:
logger.info(
"Lost replication connection; streams now disconnected: %s", streams
)
try:
self._connections.remove(connection)
except ValueError:
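The change above replaces the single `_streams_connected` set with a per-connection map, so that RDATA is only accepted on a connection that has itself delivered a POSITION for that stream. The bookkeeping can be sketched in isolation as follows; `StreamTracker` and its method names are hypothetical, and connections are represented by plain strings.

```python
from typing import Dict, Set


class StreamTracker:
    """Illustrative per-connection stream bookkeeping (not Synapse's class)."""

    def __init__(self) -> None:
        self._streams_by_connection: Dict[str, Set[str]] = {}

    def on_position_start(self, stream_name: str) -> None:
        # About to catch up: the stream is no longer connected anywhere.
        for streams in self._streams_by_connection.values():
            streams.discard(stream_name)

    def on_position_done(self, conn: str, stream_name: str) -> None:
        # Caught up via this connection, so RDATA from it is now trustworthy.
        self._streams_by_connection.setdefault(conn, set()).add(stream_name)

    def should_process_rdata(self, conn: str, stream_name: str) -> bool:
        sbc = self._streams_by_connection.get(conn)
        return bool(sbc) and stream_name in sbc

    def lost_connection(self, conn: str) -> Set[str]:
        # Forget the connection and report which streams it was serving.
        return self._streams_by_connection.pop(conn, set())


tracker = StreamTracker()
tracker.on_position_start("events")
tracker.on_position_done("conn-a", "events")
```

With this state, RDATA for `events` arriving on `conn-b` would be dropped until `conn-b` has processed its own POSITION.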
+36 -19
@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING
import txredisapi
from synapse.logging.context import PreserveLoggingContext
from synapse.logging.context import make_deferred_yieldable
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.commands import (
Command,
@@ -41,8 +41,14 @@ logger = logging.getLogger(__name__)
class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
"""Connection to redis subscribed to replication stream.
Parses incoming messages from redis into replication commands, and passes
them to `ReplicationCommandHandler`
This class fulfils two functions:
(a) it implements the twisted Protocol API, where it handles the SUBSCRIBEd redis
connection, parsing *incoming* messages into replication commands, and passing them
to `ReplicationCommandHandler`
(b) it implements the AbstractConnection API, where it sends *outgoing* commands
onto outbound_redis_connection.
Due to the vagaries of `txredisapi` we don't want to have a custom
constructor, so instead we expect the defined attributes below to be set
@@ -50,8 +56,8 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
Attributes:
handler: The command handler to handle incoming commands.
stream_name: The *redis* stream name to subscribe to (not anything to
do with Synapse replication streams).
stream_name: The *redis* stream name to subscribe to and publish from
(not anything to do with Synapse replication streams).
outbound_redis_connection: The connection to redis to use to send
commands.
"""
@@ -61,12 +67,23 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
outbound_redis_connection = None # type: txredisapi.RedisProtocol
def connectionMade(self):
logger.info("Connected to redis instance")
self.subscribe(self.stream_name)
self.send_command(ReplicateCommand())
logger.info("Connected to redis")
super().connectionMade()
run_as_background_process("subscribe-replication", self._send_subscribe)
self.handler.new_connection(self)
async def _send_subscribe(self):
# it's important to make sure that we only send the REPLICATE command once we
# have successfully subscribed to the stream - otherwise we might miss the
# POSITION response sent back by the other end.
logger.info("Sending redis SUBSCRIBE for %s", self.stream_name)
await make_deferred_yieldable(self.subscribe(self.stream_name))
logger.info(
"Successfully subscribed to redis stream, sending REPLICATE command"
)
await self._async_send_command(ReplicateCommand())
logger.info("REPLICATE successfully sent")
def messageReceived(self, pattern: str, channel: str, message: str):
"""Received a message from redis.
"""
@@ -119,7 +136,8 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
logger.warning("Unhandled command: %r", cmd)
def connectionLost(self, reason):
logger.info("Lost connection to redis instance")
logger.info("Lost connection to redis")
super().connectionLost(reason)
self.handler.lost_connection(self)
def send_command(self, cmd: Command):
@@ -128,6 +146,10 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
Args:
cmd (Command)
"""
run_as_background_process("send-cmd", self._async_send_command, cmd)
async def _async_send_command(self, cmd: Command):
"""Encode a replication command and send it over our outbound connection"""
string = "%s %s" % (cmd.NAME, cmd.to_line())
if "\n" in string:
raise Exception("Unexpected newline in command: %r", string)
@@ -138,15 +160,9 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
# remote instances.
tcp_outbound_commands_counter.labels(cmd.NAME, "redis").inc()
async def _send():
with PreserveLoggingContext():
# Note that we use the other connection as we can't send
# commands using the subscription connection.
await self.outbound_redis_connection.publish(
self.stream_name, encoded_string
)
run_as_background_process("send-cmd", _send)
await make_deferred_yieldable(
self.outbound_redis_connection.publish(self.stream_name, encoded_string)
)
class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
@@ -189,5 +205,6 @@ class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
p.handler = self.handler
p.outbound_redis_connection = self.outbound_redis_connection
p.stream_name = self.stream_name
p.password = self.password
return p
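The key ordering constraint in this file is that REPLICATE is only sent once SUBSCRIBE has completed, so the POSITION reply cannot be missed. A minimal sketch of that sequencing with plain `asyncio`; `FakeRedis` and the `"REPLICATE"` payload string are placeholders for illustration, not the txredisapi API or the real wire format.

```python
import asyncio


class FakeRedis:
    """Hypothetical connection that records the order of operations."""

    def __init__(self):
        self.log = []

    async def subscribe(self, channel):
        await asyncio.sleep(0)  # yield to the loop, as a real round-trip would
        self.log.append(("SUBSCRIBE", channel))

    async def publish(self, channel, message):
        self.log.append(("PUBLISH", channel, message))


async def send_subscribe(conn, channel):
    # Wait for the subscription to be confirmed before announcing ourselves.
    await conn.subscribe(channel)
    await conn.publish(channel, "REPLICATE")


redis = FakeRedis()
asyncio.run(send_subscribe(redis, "matrix"))
```

Awaiting the subscribe call is what enforces the ordering; firing both operations without awaiting the first would reintroduce the race the comment in `_send_subscribe` describes.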
+1 -1
@@ -80,7 +80,7 @@ class ReplicationStreamer(object):
for stream in STREAMS_MAP.values():
if stream == FederationStream and hs.config.send_federation:
# We only support federation stream if federation sending
# hase been disabled on the master.
# has been disabled on the master.
continue
self.streams.append(stream(hs))
+2 -1
@@ -104,7 +104,8 @@ class Stream(object):
implemented by subclasses.
current_token_function is called to get the current token of the underlying
stream.
stream. It is only meaningful on the process that is the source of the
replication stream (ie, usually the master).
update_function is called to get updates for this stream between a pair of
stream tokens. See the UpdateFunction type definition for more info.
+21 -9
@@ -15,7 +15,7 @@
# limitations under the License.
from collections import namedtuple
from synapse.replication.tcp.streams._base import Stream, db_query_to_update_function
from synapse.replication.tcp.streams._base import Stream, make_http_update_function
class FederationStream(Stream):
@@ -35,21 +35,33 @@ class FederationStream(Stream):
ROW_TYPE = FederationStreamRow
def __init__(self, hs):
# Not all synapse instances will have a federation sender instance,
# whether that's a `FederationSender` or a `FederationRemoteSendQueue`,
# so we stub the stream out when that is the case.
if hs.config.worker_app is None or hs.should_send_federation():
if hs.config.worker_app is None:
# master process: get updates from the FederationRemoteSendQueue.
# (if the master is configured to send federation itself, federation_sender
# will be a real FederationSender, which has stubs for current_token and
# get_replication_rows.)
federation_sender = hs.get_federation_sender()
current_token = federation_sender.get_current_token
update_function = db_query_to_update_function(
federation_sender.get_replication_rows
)
update_function = federation_sender.get_replication_rows
elif hs.should_send_federation():
# federation sender: Query master process
update_function = make_http_update_function(hs, self.NAME)
current_token = self._stub_current_token
else:
current_token = lambda: 0
# other worker: stub out the update function (we're not interested in
# any updates so when we get a POSITION we do nothing)
update_function = self._stub_update_function
current_token = self._stub_current_token
super().__init__(hs.get_instance_name(), current_token, update_function)
@staticmethod
def _stub_current_token():
# dummy current-token method for use on workers
return 0
@staticmethod
async def _stub_update_function(instance_name, from_token, upto_token, limit):
return [], upto_token, False
+1
@@ -171,6 +171,7 @@ class SyncRestServlet(RestServlet):
user.to_string(), affect_presence=affect_presence
)
with context:
logger.info("Sync servlet")
sync_result = await self.sync_handler.wait_for_sync_for_user(
sync_config,
since_token=since_token,
+44 -20
@@ -55,6 +55,10 @@ DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = (
BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
BG_UPDATE_DROP_DEVICE_LISTS_OUTBOUND_LAST_SUCCESS_NON_UNIQUE_IDX = (
"drop_device_lists_outbound_last_success_non_unique_idx"
)
class DeviceWorkerStore(SQLBaseStore):
def get_device(self, user_id, device_id):
@@ -342,32 +346,23 @@ class DeviceWorkerStore(SQLBaseStore):
def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id):
# We update the device_lists_outbound_last_success with the successfully
# poked users. We do the join to see which users need to be inserted and
# which updated.
# poked users.
sql = """
SELECT user_id, coalesce(max(o.stream_id), 0), (max(s.stream_id) IS NOT NULL)
SELECT user_id, coalesce(max(o.stream_id), 0)
FROM device_lists_outbound_pokes as o
LEFT JOIN device_lists_outbound_last_success as s
USING (destination, user_id)
WHERE destination = ? AND o.stream_id <= ?
GROUP BY user_id
"""
txn.execute(sql, (destination, stream_id))
rows = txn.fetchall()
sql = """
UPDATE device_lists_outbound_last_success
SET stream_id = ?
WHERE destination = ? AND user_id = ?
"""
txn.executemany(sql, ((row[1], destination, row[0]) for row in rows if row[2]))
sql = """
INSERT INTO device_lists_outbound_last_success
(destination, user_id, stream_id) VALUES (?, ?, ?)
"""
txn.executemany(
sql, ((destination, row[0], row[1]) for row in rows if not row[2])
self.db.simple_upsert_many_txn(
txn=txn,
table="device_lists_outbound_last_success",
key_names=("destination", "user_id"),
key_values=((destination, user_id) for user_id, _ in rows),
value_names=("stream_id",),
value_values=((stream_id,) for _, stream_id in rows),
)
# Delete all sent outbound pokes
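The hunk above collapses the separate UPDATE and INSERT passes into one `simple_upsert_many_txn` call keyed on `(destination, user_id)`. As a rough illustration of what an emulated many-row upsert does, here is a stand-alone sketch against an in-memory SQLite database. The `upsert_many` helper and the `last_success` table are hypothetical, it assumes at least one value column, and it is not Synapse's implementation.

```python
import sqlite3


def upsert_many(conn, table, key_names, key_values, value_names, value_values):
    """Emulated upsert: try UPDATE first, INSERT when no row matched."""
    all_names = list(key_names) + list(value_names)
    update_sql = "UPDATE %s SET %s WHERE %s" % (
        table,
        ", ".join("%s = ?" % n for n in value_names),
        " AND ".join("%s = ?" % n for n in key_names),
    )
    insert_sql = "INSERT INTO %s (%s) VALUES (%s)" % (
        table,
        ", ".join(all_names),
        ", ".join("?" for _ in all_names),
    )
    for keys, vals in zip(key_values, value_values):
        cur = conn.execute(update_sql, tuple(vals) + tuple(keys))
        if cur.rowcount == 0:
            conn.execute(insert_sql, tuple(keys) + tuple(vals))


conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE last_success (destination TEXT, user_id TEXT, stream_id INTEGER)"
)
conn.execute(
    "CREATE UNIQUE INDEX last_success_idx ON last_success (destination, user_id)"
)
conn.execute("INSERT INTO last_success VALUES ('remote.example', '@a:x', 3)")

# One existing user (updated) and one new user (inserted).
rows = [("@a:x", 7), ("@b:x", 5)]
upsert_many(
    conn,
    "last_success",
    ("destination", "user_id"),
    [("remote.example", user_id) for user_id, _ in rows],
    ("stream_id",),
    [(stream_id,) for _, stream_id in rows],
)
result = conn.execute(
    "SELECT user_id, stream_id FROM last_success ORDER BY user_id"
).fetchall()
```

A unique index on the key columns, like the one this diff registers as a background update, is what makes a native upsert safe to use in place of this emulation.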
@@ -541,8 +536,8 @@ class DeviceWorkerStore(SQLBaseStore):
# Get set of users who *may* have changed. Users not in the returned
# list have definitely not changed.
to_check = list(
self._device_list_stream_cache.get_entities_changed(user_ids, from_key)
to_check = self._device_list_stream_cache.get_entities_changed(
user_ids, from_key
)
if not to_check:
@@ -725,6 +720,21 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, self._remove_duplicate_outbound_pokes,
)
# create a unique index on device_lists_outbound_last_success
self.db.updates.register_background_index_update(
"device_lists_outbound_last_success_unique_idx",
index_name="device_lists_outbound_last_success_unique_idx",
table="device_lists_outbound_last_success",
columns=["destination", "user_id"],
unique=True,
)
# once that completes, we can remove the old non-unique index.
self.db.updates.register_background_update_handler(
BG_UPDATE_DROP_DEVICE_LISTS_OUTBOUND_LAST_SUCCESS_NON_UNIQUE_IDX,
self._drop_device_lists_outbound_last_success_non_unique_idx,
)
@defer.inlineCallbacks
def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
def f(conn):
@@ -799,6 +809,20 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
return rows
async def _drop_device_lists_outbound_last_success_non_unique_idx(
self, progress, batch_size
):
def f(txn):
txn.execute("DROP INDEX IF EXISTS device_lists_outbound_last_success_idx")
await self.db.runInteraction(
"drop_device_lists_outbound_last_success_non_unique_idx", f,
)
await self.db.updates._end_background_update(
BG_UPDATE_DROP_DEVICE_LISTS_OUTBOUND_LAST_SUCCESS_NON_UNIQUE_IDX
)
return 1
class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
def __init__(self, database: Database, db_conn, hs):
@@ -25,7 +25,9 @@ from twisted.internet import defer
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import make_in_list_sql_clause
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.iterutils import batch_iter
class EndToEndKeyWorkerStore(SQLBaseStore):
@@ -268,53 +270,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
"count_e2e_one_time_keys", _count_e2e_one_time_keys
)
def _get_e2e_cross_signing_key_txn(self, txn, user_id, key_type, from_user_id=None):
"""Returns a user's cross-signing key.
Args:
txn (twisted.enterprise.adbapi.Connection): db connection
user_id (str): the user whose key is being requested
key_type (str): the type of key that is being requested: either 'master'
for a master key, 'self_signing' for a self-signing key, or
'user_signing' for a user-signing key
from_user_id (str): if specified, signatures made by this user on
the key will be included in the result
Returns:
dict of the key data or None if not found
"""
sql = (
"SELECT keydata "
" FROM e2e_cross_signing_keys "
" WHERE user_id = ? AND keytype = ? ORDER BY stream_id DESC LIMIT 1"
)
txn.execute(sql, (user_id, key_type))
row = txn.fetchone()
if not row:
return None
key = json.loads(row[0])
device_id = None
for k in key["keys"].values():
device_id = k
if from_user_id is not None:
sql = (
"SELECT key_id, signature "
" FROM e2e_cross_signing_signatures "
" WHERE user_id = ? "
" AND target_user_id = ? "
" AND target_device_id = ? "
)
txn.execute(sql, (from_user_id, user_id, device_id))
row = txn.fetchone()
if row:
key.setdefault("signatures", {}).setdefault(from_user_id, {})[
row[0]
] = row[1]
return key
@defer.inlineCallbacks
def get_e2e_cross_signing_key(self, user_id, key_type, from_user_id=None):
"""Returns a user's cross-signing key.
@@ -329,13 +285,11 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
Returns:
dict of the key data or None if not found
"""
return self.db.runInteraction(
"get_e2e_cross_signing_key",
self._get_e2e_cross_signing_key_txn,
user_id,
key_type,
from_user_id,
)
res = yield self.get_e2e_cross_signing_keys_bulk([user_id], from_user_id)
user_keys = res.get(user_id)
if not user_keys:
return None
return user_keys.get(key_type)
@cached(num_args=1)
def _get_bare_e2e_cross_signing_keys(self, user_id):
@@ -391,26 +345,24 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
"""
result = {}
batch_size = 100
chunks = [
user_ids[i : i + batch_size] for i in range(0, len(user_ids), batch_size)
]
for user_chunk in chunks:
sql = """
for user_chunk in batch_iter(user_ids, 100):
clause, params = make_in_list_sql_clause(
txn.database_engine, "k.user_id", user_chunk
)
sql = (
"""
SELECT k.user_id, k.keytype, k.keydata, k.stream_id
FROM e2e_cross_signing_keys k
INNER JOIN (SELECT user_id, keytype, MAX(stream_id) AS stream_id
FROM e2e_cross_signing_keys
GROUP BY user_id, keytype) s
USING (user_id, stream_id, keytype)
WHERE k.user_id IN (%s)
""" % (
",".join("?" for u in user_chunk),
WHERE
"""
+ clause
)
query_params = []
query_params.extend(user_chunk)
txn.execute(sql, query_params)
txn.execute(sql, params)
rows = self.db.cursor_to_dict(txn)
for row in rows:
@@ -453,15 +405,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
device_id = k
devices[(user_id, device_id)] = key_type
device_list = list(devices)
# split into batches
batch_size = 100
chunks = [
device_list[i : i + batch_size]
for i in range(0, len(device_list), batch_size)
]
for user_chunk in chunks:
for batch in batch_iter(devices.keys(), size=100):
sql = """
SELECT target_user_id, target_device_id, key_id, signature
FROM e2e_cross_signing_signatures
@@ -469,11 +413,11 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
AND (%s)
""" % (
" OR ".join(
"(target_user_id = ? AND target_device_id = ?)" for d in devices
"(target_user_id = ? AND target_device_id = ?)" for _ in batch
)
)
query_params = [from_user_id]
for item in devices:
for item in batch:
# item is a (user_id, device_id) tuple
query_params.extend(item)
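The fix above makes the placeholder list and the parameter list both iterate over the current `batch` rather than over the whole `devices` dict. The sketch below shows the effect with a stand-alone `batch_iter` written for this example (Synapse ships its own in `synapse.util.iterutils`); the device data is made up.

```python
from itertools import islice


def batch_iter(iterable, size):
    """Yield tuples of up to `size` items from `iterable` (illustrative)."""
    it = iter(iterable)
    # iter(callable, sentinel) stops once islice returns an empty tuple.
    return iter(lambda: tuple(islice(it, size)), ())


# Hypothetical (user_id, device_id) -> key_type map with five entries.
devices = {("@u%d:example.org" % i, "DEV%d" % i): "master" for i in range(5)}

queries = []
for batch in batch_iter(devices.keys(), size=2):
    # Build one "(... AND ...)" clause per item in *this* batch only.
    where = " OR ".join(
        "(target_user_id = ? AND target_device_id = ?)" for _ in batch
    )
    params = ["@me:example.org"]
    for item in batch:
        # item is a (user_id, device_id) tuple
        params.extend(item)
    queries.append((where, params))
```

Each query now carries exactly as many placeholders as parameters for its batch, which is the invariant the original code broke by looping over `devices`.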
@@ -605,6 +605,7 @@ class EventsStore(
}
for member in members_changed:
logger.info("INVALIDATING get_rooms_for_user_with_stream_ordering")
txn.call_after(
self.get_rooms_for_user_with_stream_ordering.invalidate, (member,)
)
@@ -15,6 +15,7 @@
# limitations under the License.
import logging
import traceback
from typing import Iterable, List, Set
from six import iteritems, itervalues
@@ -163,8 +164,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
hosts = frozenset(get_domain_from_id(user_id) for user_id in user_ids)
return hosts
@cached(max_entries=100000, iterable=True)
@cached(max_entries=1000000, iterable=True)
def get_users_in_room(self, room_id):
logger.info("Traceback: %s", traceback.format_stack())
return self.db.runInteraction(
"get_users_in_room", self.get_users_in_room_txn, room_id
)
@@ -484,17 +486,18 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
return frozenset(r.room_id for r in rooms)
@cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True)
def get_users_who_share_room_with_user(self, user_id, cache_context):
@cached(max_entries=500000, cache_context=True, iterable=True)
async def get_users_who_share_room_with_user(self, user_id, cache_context):
"""Returns the set of users who share a room with `user_id`
"""
room_ids = yield self.get_rooms_for_user(
logger.info("Called with %s %s", user_id, cache_context)
room_ids = await self.get_rooms_for_user(
user_id, on_invalidate=cache_context.invalidate
)
user_who_share_room = set()
for room_id in room_ids:
user_ids = yield self.get_users_in_room(
user_ids = await self.get_users_in_room(
room_id, on_invalidate=cache_context.invalidate
)
user_who_share_room.update(user_ids)
@@ -0,0 +1,28 @@
/* Copyright 2020 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- register a background update which will create a unique index on
-- device_lists_outbound_last_success
INSERT into background_updates (ordering, update_name, progress_json)
VALUES (5804, 'device_lists_outbound_last_success_unique_idx', '{}');
-- once that completes, we can drop the old index.
INSERT into background_updates (ordering, update_name, progress_json, depends_on)
VALUES (
5804,
'drop_device_lists_outbound_last_success_non_unique_idx',
'{}',
'device_lists_outbound_last_success_unique_idx'
);
+44 -30
@@ -49,6 +49,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.background_updates import BackgroundUpdater
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.storage.types import Connection, Cursor
from synapse.types import Collection
from synapse.util.stringutils import exception_to_unicode
logger = logging.getLogger(__name__)
@@ -78,6 +79,7 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = {
"device_lists_remote_extremeties": "device_lists_remote_extremeties_unique_idx",
"device_lists_remote_cache": "device_lists_remote_cache_unique_idx",
"event_search": "event_search_event_id_idx",
"device_lists_outbound_last_success": "device_lists_outbound_last_success_unique_idx",
}
@@ -889,20 +891,24 @@ class Database(object):
txn.execute(sql, list(allvalues.values()))
def simple_upsert_many_txn(
self, txn, table, key_names, key_values, value_names, value_values
):
self,
txn: LoggingTransaction,
table: str,
key_names: Collection[str],
key_values: Collection[Iterable[Any]],
value_names: Collection[str],
value_values: Iterable[Iterable[str]],
) -> None:
"""
Upsert, many times.
Args:
table (str): The table to upsert into
key_names (list[str]): The key column names.
key_values (list[list]): A list of each row's key column values.
value_names (list[str]): The value column names. If empty, no
values will be used, even if value_values is provided.
value_values (list[list]): A list of each row's value column values.
Returns:
None
table: The table to upsert into
key_names: The key column names.
key_values: A list of each row's key column values.
value_names: The value column names
value_values: A list of each row's value column values.
Ignored if value_names is empty.
"""
if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
return self.simple_upsert_many_txn_native_upsert(
@@ -914,20 +920,24 @@ class Database(object):
)
def simple_upsert_many_txn_emulated(
self, txn, table, key_names, key_values, value_names, value_values
):
self,
txn: LoggingTransaction,
table: str,
key_names: Iterable[str],
key_values: Collection[Iterable[Any]],
value_names: Collection[str],
value_values: Iterable[Iterable[str]],
) -> None:
"""
Upsert, many times, but without native UPSERT support or batching.
Args:
table (str): The table to upsert into
key_names (list[str]): The key column names.
key_values (list[list]): A list of each row's key column values.
value_names (list[str]): The value column names. If empty, no
values will be used, even if value_values is provided.
value_values (list[list]): A list of each row's value column values.
Returns:
None
table: The table to upsert into
key_names: The key column names.
key_values: A list of each row's key column values.
value_names: The value column names
value_values: A list of each row's value column values.
Ignored if value_names is empty.
"""
# No value columns, therefore make a blank list so that the following
# zip() works correctly.
@@ -941,20 +951,24 @@ class Database(object):
self.simple_upsert_txn_emulated(txn, table, _keys, _vals)
def simple_upsert_many_txn_native_upsert(
self, txn, table, key_names, key_values, value_names, value_values
):
self,
txn: LoggingTransaction,
table: str,
key_names: Collection[str],
key_values: Collection[Iterable[Any]],
value_names: Collection[str],
value_values: Iterable[Iterable[Any]],
) -> None:
"""
Upsert, many times, using batching where possible.
Args:
table (str): The table to upsert into
key_names (list[str]): The key column names.
key_values (list[list]): A list of each row's key column values.
value_names (list[str]): The value column names. If empty, no
values will be used, even if value_values is provided.
value_values (list[list]): A list of each row's value column values.
Returns:
None
table: The table to upsert into
key_names: The key column names.
key_values: A list of each row's key column values.
value_names: The value column names
value_values: A list of each row's value column values.
Ignored if value_names is empty.
"""
allnames = [] # type: List[str]
allnames.extend(key_names)
+167 -2
@@ -16,6 +16,11 @@
import contextlib
import threading
from collections import deque
from typing import Dict, Set, Tuple
from typing_extensions import Deque
from synapse.storage.database import Database, LoggingTransaction
class IdGenerator(object):
@@ -87,7 +92,7 @@ class StreamIdGenerator(object):
self._current = (max if step > 0 else min)(
self._current, _load_current_id(db_conn, table, column, step)
)
self._unfinished_ids = deque()
self._unfinished_ids = deque() # type: Deque[int]
def get_next(self):
"""
@@ -163,7 +168,7 @@ class ChainedIdGenerator(object):
self.chained_generator = chained_generator
self._lock = threading.Lock()
self._current_max = _load_current_id(db_conn, table, column)
self._unfinished_ids = deque()
self._unfinished_ids = deque() # type: Deque[Tuple[int, int]]
def get_next(self):
"""
@@ -198,3 +203,163 @@ class ChainedIdGenerator(object):
return stream_id - 1, chained_id
return self._current_max, self.chained_generator.get_current_token()
class MultiWriterIdGenerator:
"""An ID generator that tracks a stream that can have multiple writers.
Uses a Postgres sequence to coordinate ID assignment, but positions of other
writers will only get updated when `advance` is called (by replication).
Note: Only works with Postgres.
Args:
db_conn
db
instance_name: The name of this instance.
table: Database table associated with stream.
instance_column: Column that stores the row's writer's instance name
id_column: Column that stores the stream ID.
sequence_name: The name of the postgres sequence used to generate new
IDs.
"""
def __init__(
self,
db_conn,
db: Database,
instance_name: str,
table: str,
instance_column: str,
id_column: str,
sequence_name: str,
):
self._db = db
self._instance_name = instance_name
self._sequence_name = sequence_name
# We lock as some functions may be called from DB threads.
self._lock = threading.Lock()
self._current_positions = self._load_current_ids(
db_conn, table, instance_column, id_column
)
# Set of local IDs that we're still processing. The current position
# should be less than the minimum of this set (if not empty).
self._unfinished_ids = set() # type: Set[int]
def _load_current_ids(
self, db_conn, table: str, instance_column: str, id_column: str
) -> Dict[str, int]:
sql = """
SELECT %(instance)s, MAX(%(id)s) FROM %(table)s
GROUP BY %(instance)s
""" % {
"instance": instance_column,
"id": id_column,
"table": table,
}
cur = db_conn.cursor()
cur.execute(sql)
# `cur` is an iterable over returned rows, which are 2-tuples.
current_positions = dict(cur)
cur.close()
return current_positions
def _load_next_id_txn(self, txn):
txn.execute("SELECT nextval(?)", (self._sequence_name,))
(next_id,) = txn.fetchone()
return next_id
async def get_next(self):
"""
Usage:
with await stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
next_id = await self._db.runInteraction("_load_next_id", self._load_next_id_txn)
# Assert the fetched ID is actually greater than what we currently
# believe the ID to be. If not, then the sequence and table have got
# out of sync somehow.
assert self.get_current_token() < next_id
with self._lock:
self._unfinished_ids.add(next_id)
@contextlib.contextmanager
def manager():
try:
yield next_id
finally:
self._mark_id_as_finished(next_id)
return manager()
def get_next_txn(self, txn: LoggingTransaction):
"""
Usage:
stream_id = stream_id_gen.get_next_txn(txn)
# ... persist event ...
"""
next_id = self._load_next_id_txn(txn)
with self._lock:
self._unfinished_ids.add(next_id)
txn.call_after(self._mark_id_as_finished, next_id)
txn.call_on_exception(self._mark_id_as_finished, next_id)
return next_id
def _mark_id_as_finished(self, next_id: int):
"""The ID has finished being processed so we should advance the
current position if possible.
"""
with self._lock:
self._unfinished_ids.discard(next_id)
# Figure out if it's safe to advance the position by checking there
# aren't any lower allocated IDs that are yet to finish.
if all(c > next_id for c in self._unfinished_ids):
curr = self._current_positions.get(self._instance_name, 0)
self._current_positions[self._instance_name] = max(curr, next_id)
def get_current_token(self, instance_name: str = None) -> int:
"""Gets the current position of a named writer (defaults to current
instance).
Returns 0 if we don't have a position for the named writer (likely due
to it being a new writer).
"""
if instance_name is None:
instance_name = self._instance_name
with self._lock:
return self._current_positions.get(instance_name, 0)
def get_positions(self) -> Dict[str, int]:
"""Get a copy of the current position map.
"""
with self._lock:
return dict(self._current_positions)
def advance(self, instance_name: str, new_id: int):
"""Advance the position of the named writer to the given ID, if greater
than existing entry.
"""
with self._lock:
self._current_positions[instance_name] = max(
new_id, self._current_positions.get(instance_name, 0)
)
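The interaction between `_unfinished_ids` and `_mark_id_as_finished` is easiest to see when IDs finish out of order. The following is a minimal in-memory sketch, not the real class: `InMemoryIdTracker`, `allocate`, `mark_finished` and `current` are hypothetical names, and a plain counter stands in for the Postgres sequence.

```python
import threading


class InMemoryIdTracker:
    """Hypothetical stand-in for the position tracking; no database involved."""

    def __init__(self, instance_name, start=0):
        self._lock = threading.Lock()
        self._instance_name = instance_name
        self._current_positions = {instance_name: start} if start else {}
        self._unfinished_ids = set()
        self._next = start

    def allocate(self):
        # Assumption: a local counter replaces `SELECT nextval(...)`.
        with self._lock:
            self._next += 1
            self._unfinished_ids.add(self._next)
            return self._next

    def mark_finished(self, next_id):
        with self._lock:
            self._unfinished_ids.discard(next_id)
            # Only advance if no lower ID is still in flight.
            if all(c > next_id for c in self._unfinished_ids):
                curr = self._current_positions.get(self._instance_name, 0)
                self._current_positions[self._instance_name] = max(curr, next_id)

    def current(self):
        with self._lock:
            return self._current_positions.get(self._instance_name, 0)


tracker = InMemoryIdTracker("master", start=7)
a = tracker.allocate()
b = tracker.allocate()
tracker.mark_finished(b)
after_b = tracker.current()  # position is held back while `a` is unfinished
tracker.mark_finished(a)
after_a = tracker.current()
```

In this sketch the position advances only to the ID passed to the final `mark_finished` call, so it ends at `a` even though the higher ID `b` has also finished. Whether that is compensated for elsewhere in the real class is not visible in this diff.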
+15 -4
@@ -14,12 +14,13 @@
# limitations under the License.
import logging
from typing import Dict, Iterable, List, Mapping, Optional, Set
from typing import Dict, FrozenSet, List, Mapping, Optional, Set, Union
from six import integer_types
from sortedcontainers import SortedDict
from synapse.types import Collection
from synapse.util import caches
logger = logging.getLogger(__name__)
@@ -85,8 +86,8 @@ class StreamChangeCache:
return False
def get_entities_changed(
self, entities: Iterable[EntityType], stream_pos: int
) -> Set[EntityType]:
self, entities: Collection[EntityType], stream_pos: int
) -> Union[Set[EntityType], FrozenSet[EntityType]]:
"""
Returns subset of entities that have had new things since the given
position. Entities unknown to the cache will be returned. If the
@@ -94,7 +95,17 @@ class StreamChangeCache:
"""
changed_entities = self.get_all_entities_changed(stream_pos)
if changed_entities is not None:
result = set(changed_entities).intersection(entities)
# We now do an intersection, trying to do so in the most efficient
# way possible (some of these sets are *large*). First check if the
# given iterable is already a set that we can reuse, otherwise we
# create a set of the *smallest* of the two iterables and call
# `intersection(..)` on it (this can be twice as fast as the reverse).
if isinstance(entities, (set, frozenset)):
result = entities.intersection(changed_entities)
elif len(changed_entities) < len(entities):
result = set(changed_entities).intersection(entities)
else:
result = set(entities).intersection(changed_entities)
self.metrics.inc_hits()
else:
result = set(entities)
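The intersection strategy in this hunk can be tried in isolation. This is a minimal sketch under the assumption that both arguments are sized collections; `intersect_changed` is a hypothetical helper name, not part of `StreamChangeCache`.

```python
from typing import Collection, FrozenSet, Set, TypeVar, Union

T = TypeVar("T")


def intersect_changed(
    entities: Collection[T], changed: Collection[T]
) -> Union[Set[T], FrozenSet[T]]:
    # Reuse an existing set rather than copying it.
    if isinstance(entities, (set, frozenset)):
        return entities.intersection(changed)
    # Otherwise build a set from the smaller collection and probe it
    # with the larger one.
    if len(changed) < len(entities):
        return set(changed).intersection(entities)
    return set(entities).intersection(changed)


from_frozenset = intersect_changed(frozenset({"@a:x", "@b:x"}), ["@b:x", "@c:x"])
from_lists = intersect_changed(["@a:x", "@b:x", "@c:x"], ["@b:x"])
```

Calling `intersection` on a `frozenset` returns a `frozenset`, which is why the return annotation in the diff widens to a `Union`.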
+20 -16
@@ -52,6 +52,10 @@ class AuthTestCase(unittest.TestCase):
self.hs.handlers = TestHandlers(self.hs)
self.auth = Auth(self.hs)
# AuthBlocking reads from the hs' config on initialization. We need to
# modify its config instead of the hs'
self.auth_blocking = self.auth._auth_blocking
self.test_user = "@foo:bar"
self.test_token = b"_test_token_"
@@ -321,15 +325,15 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_blocking_mau(self):
self.hs.config.limit_usage_by_mau = False
self.hs.config.max_mau_value = 50
self.auth_blocking._limit_usage_by_mau = False
self.auth_blocking._max_mau_value = 50
lots_of_users = 100
small_number_of_users = 1
# Ensure no error thrown
yield defer.ensureDeferred(self.auth.check_auth_blocking())
self.hs.config.limit_usage_by_mau = True
self.auth_blocking._limit_usage_by_mau = True
self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(lots_of_users)
@@ -349,8 +353,8 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_blocking_mau__depending_on_user_type(self):
self.hs.config.max_mau_value = 50
self.hs.config.limit_usage_by_mau = True
self.auth_blocking._max_mau_value = 50
self.auth_blocking._limit_usage_by_mau = True
self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100))
# Support users allowed
@@ -370,12 +374,12 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_reserved_threepid(self):
self.hs.config.limit_usage_by_mau = True
self.hs.config.max_mau_value = 1
self.auth_blocking._limit_usage_by_mau = True
self.auth_blocking._max_mau_value = 1
self.store.get_monthly_active_count = lambda: defer.succeed(2)
threepid = {"medium": "email", "address": "reserved@server.com"}
unknown_threepid = {"medium": "email", "address": "unreserved@server.com"}
self.hs.config.mau_limits_reserved_threepids = [threepid]
self.auth_blocking._mau_limits_reserved_threepids = [threepid]
with self.assertRaises(ResourceLimitError):
yield defer.ensureDeferred(self.auth.check_auth_blocking())
@@ -389,8 +393,8 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_hs_disabled(self):
self.hs.config.hs_disabled = True
self.hs.config.hs_disabled_message = "Reason for being disabled"
self.auth_blocking._hs_disabled = True
self.auth_blocking._hs_disabled_message = "Reason for being disabled"
with self.assertRaises(ResourceLimitError) as e:
yield defer.ensureDeferred(self.auth.check_auth_blocking())
self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact)
@@ -404,10 +408,10 @@ class AuthTestCase(unittest.TestCase):
"""
# this should be the default, but we had a bug where the test was doing the wrong
# thing, so let's make it explicit
self.hs.config.server_notices_mxid = None
self.auth_blocking._server_notices_mxid = None
self.hs.config.hs_disabled = True
self.hs.config.hs_disabled_message = "Reason for being disabled"
self.auth_blocking._hs_disabled = True
self.auth_blocking._hs_disabled_message = "Reason for being disabled"
with self.assertRaises(ResourceLimitError) as e:
yield defer.ensureDeferred(self.auth.check_auth_blocking())
self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact)
@@ -416,8 +420,8 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_server_notices_mxid_special_cased(self):
self.hs.config.hs_disabled = True
self.auth_blocking._hs_disabled = True
user = "@user:server"
self.hs.config.server_notices_mxid = user
self.hs.config.hs_disabled_message = "Reason for being disabled"
self.auth_blocking._server_notices_mxid = user
self.auth_blocking._hs_disabled_message = "Reason for being disabled"
yield defer.ensureDeferred(self.auth.check_auth_blocking(user))
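The test changes above follow from one property stated in the added comment: `AuthBlocking` reads the config when it is constructed, so later writes to `hs.config` are not seen. A minimal sketch of that behaviour, where `BlockingSketch`, `is_blocked` and the `>=` comparison are illustrative assumptions rather than Synapse's actual implementation:

```python
from types import SimpleNamespace


class BlockingSketch:
    """Hypothetical object that snapshots config values at construction."""

    def __init__(self, config):
        # Values are copied once; the config object is not consulted again.
        self._limit_usage_by_mau = config.limit_usage_by_mau
        self._max_mau_value = config.max_mau_value

    def is_blocked(self, current_mau):
        return self._limit_usage_by_mau and current_mau >= self._max_mau_value


config = SimpleNamespace(limit_usage_by_mau=False, max_mau_value=50)
blocking = BlockingSketch(config)

# Changing the config after construction has no effect on the snapshot.
config.limit_usage_by_mau = True
blocked_via_config = blocking.is_blocked(100)

# Tests therefore have to set the private attribute directly.
blocking._limit_usage_by_mau = True
blocked_via_attribute = blocking.is_blocked(100)
```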
+100
@@ -0,0 +1,100 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.events.snapshot import EventContext
from synapse.rest import admin
from synapse.rest.client.v1 import login, room
from tests import unittest
from tests.test_utils.event_injection import create_event
class TestEventContext(unittest.HomeserverTestCase):
servlets = [
admin.register_servlets,
login.register_servlets,
room.register_servlets,
]
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.storage = hs.get_storage()
self.user_id = self.register_user("u1", "pass")
self.user_tok = self.login("u1", "pass")
self.room_id = self.helper.create_room_as(tok=self.user_tok)
def test_serialize_deserialize_msg(self):
"""Test that an EventContext for a message event is the same after
serialize/deserialize.
"""
event, context = create_event(
self.hs, room_id=self.room_id, type="m.test", sender=self.user_id,
)
self._check_serialize_deserialize(event, context)
def test_serialize_deserialize_state_no_prev(self):
"""Test that an EventContext for a state event (with no previous entry)
is the same after serialize/deserialize.
"""
event, context = create_event(
self.hs,
room_id=self.room_id,
type="m.test",
sender=self.user_id,
state_key="",
)
self._check_serialize_deserialize(event, context)
def test_serialize_deserialize_state_prev(self):
"""Test that an EventContext for a state event (which replaces a
previous entry) is the same after serialize/deserialize.
"""
event, context = create_event(
self.hs,
room_id=self.room_id,
type="m.room.member",
sender=self.user_id,
state_key=self.user_id,
content={"membership": "leave"},
)
self._check_serialize_deserialize(event, context)
def _check_serialize_deserialize(self, event, context):
serialized = self.get_success(context.serialize(event, self.store))
d_context = EventContext.deserialize(self.storage, serialized)
self.assertEqual(context.state_group, d_context.state_group)
self.assertEqual(context.rejected, d_context.rejected)
self.assertEqual(
context.state_group_before_event, d_context.state_group_before_event
)
self.assertEqual(context.prev_group, d_context.prev_group)
self.assertEqual(context.delta_ids, d_context.delta_ids)
self.assertEqual(context.app_service, d_context.app_service)
self.assertEqual(
self.get_success(context.get_current_state_ids()),
self.get_success(d_context.get_current_state_ids()),
)
self.assertEqual(
self.get_success(context.get_prev_state_ids()),
self.get_success(d_context.get_prev_state_ids()),
)
+14 -9
@@ -39,8 +39,13 @@ class AuthTestCase(unittest.TestCase):
self.hs.handlers = AuthHandlers(self.hs)
self.auth_handler = self.hs.handlers.auth_handler
self.macaroon_generator = self.hs.get_macaroon_generator()
# MAU tests
self.hs.config.max_mau_value = 50
# AuthBlocking reads from the hs' config on initialization. We need to
# modify its config instead of the hs'
self.auth_blocking = self.hs.get_auth()._auth_blocking
self.auth_blocking._max_mau_value = 50
self.small_number_of_users = 1
self.large_number_of_users = 100
@@ -119,7 +124,7 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_mau_limits_disabled(self):
self.hs.config.limit_usage_by_mau = False
self.auth_blocking._limit_usage_by_mau = False
# Ensure does not throw exception
yield defer.ensureDeferred(
self.auth_handler.get_access_token_for_user_id(
@@ -135,7 +140,7 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_mau_limits_exceeded_large(self):
self.hs.config.limit_usage_by_mau = True
self.auth_blocking._limit_usage_by_mau = True
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.large_number_of_users)
)
@@ -159,11 +164,11 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_mau_limits_parity(self):
self.hs.config.limit_usage_by_mau = True
self.auth_blocking._limit_usage_by_mau = True
# If not in monthly active cohort
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value)
return_value=defer.succeed(self.auth_blocking._max_mau_value)
)
with self.assertRaises(ResourceLimitError):
yield defer.ensureDeferred(
@@ -173,7 +178,7 @@ class AuthTestCase(unittest.TestCase):
)
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value)
return_value=defer.succeed(self.auth_blocking._max_mau_value)
)
with self.assertRaises(ResourceLimitError):
yield defer.ensureDeferred(
@@ -186,7 +191,7 @@ class AuthTestCase(unittest.TestCase):
return_value=defer.succeed(self.hs.get_clock().time_msec())
)
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value)
return_value=defer.succeed(self.auth_blocking._max_mau_value)
)
yield defer.ensureDeferred(
self.auth_handler.get_access_token_for_user_id(
@@ -197,7 +202,7 @@ class AuthTestCase(unittest.TestCase):
return_value=defer.succeed(self.hs.get_clock().time_msec())
)
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value)
return_value=defer.succeed(self.auth_blocking._max_mau_value)
)
yield defer.ensureDeferred(
self.auth_handler.validate_short_term_login_token_and_get_user_id(
@@ -207,7 +212,7 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_mau_limits_not_exceeded(self):
self.hs.config.limit_usage_by_mau = True
self.auth_blocking._limit_usage_by_mau = True
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.small_number_of_users)
+8 -5
@@ -30,28 +30,31 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
self.sync_handler = self.hs.get_sync_handler()
self.store = self.hs.get_datastore()
def test_wait_for_sync_for_user_auth_blocking(self):
# AuthBlocking reads from the hs' config on initialization. We need to
# modify its config instead of the hs'
self.auth_blocking = self.hs.get_auth()._auth_blocking
def test_wait_for_sync_for_user_auth_blocking(self):
user_id1 = "@user1:test"
user_id2 = "@user2:test"
sync_config = self._generate_sync_config(user_id1)
self.reactor.advance(100) # So we get a non-zero time
self.hs.config.limit_usage_by_mau = True
self.hs.config.max_mau_value = 1
self.auth_blocking._limit_usage_by_mau = True
self.auth_blocking._max_mau_value = 1
# Check that the happy case does not throw errors
self.get_success(self.store.upsert_monthly_active_user(user_id1))
self.get_success(self.sync_handler.wait_for_sync_for_user(sync_config))
# Test that global lock works
self.hs.config.hs_disabled = True
self.auth_blocking._hs_disabled = True
e = self.get_failure(
self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError
)
self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.hs.config.hs_disabled = False
self.auth_blocking._hs_disabled = False
sync_config = self._generate_sync_config(user_id2)
+13 -7
@@ -27,6 +27,7 @@ from synapse.app.generic_worker import (
GenericWorkerServer,
)
from synapse.http.site import SynapseRequest
from synapse.replication.http import streams
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
@@ -42,6 +43,10 @@ logger = logging.getLogger(__name__)
class BaseStreamTestCase(unittest.HomeserverTestCase):
"""Base class for tests of the replication streams"""
servlets = [
streams.register_servlets,
]
def prepare(self, reactor, clock, hs):
# build a replication server
server_factory = ReplicationStreamProtocolFactory(hs)
@@ -49,17 +54,11 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self.server = server_factory.buildProtocol(None)
# Make a new HomeServer object for the worker
config = self.default_config()
config["worker_app"] = "synapse.app.generic_worker"
config["worker_replication_host"] = "testserv"
config["worker_replication_http_port"] = "8765"
self.reactor.lookups["testserv"] = "1.2.3.4"
self.worker_hs = self.setup_test_homeserver(
http_client=None,
homeserverToUse=GenericWorkerServer,
config=config,
config=self._get_worker_hs_config(),
reactor=self.reactor,
)
@@ -78,6 +77,13 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self._client_transport = None
self._server_transport = None
def _get_worker_hs_config(self) -> dict:
config = self.default_config()
config["worker_app"] = "synapse.app.generic_worker"
config["worker_replication_host"] = "testserv"
config["worker_replication_http_port"] = "8765"
return config
def _build_replication_data_handler(self):
return TestReplicationDataHandler(self.worker_hs)
@@ -0,0 +1,81 @@
# -*- coding: utf-8 -*-
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.federation.send_queue import EduRow
from synapse.replication.tcp.streams.federation import FederationStream
from tests.replication.tcp.streams._base import BaseStreamTestCase
class FederationStreamTestCase(BaseStreamTestCase):
def _get_worker_hs_config(self) -> dict:
# enable federation sending on the worker
config = super()._get_worker_hs_config()
# TODO: make it so we don't need both of these
config["send_federation"] = True
config["worker_app"] = "synapse.app.federation_sender"
return config
def test_catchup(self):
"""Basic test of catchup on reconnect
Makes sure that updates sent while we are offline are received later.
"""
fed_sender = self.hs.get_federation_sender()
received_rows = self.test_handler.received_rdata_rows
fed_sender.build_and_send_edu("testdest", "m.test_edu", {"a": "b"})
self.reconnect()
self.reactor.advance(0)
# check we're testing what we think we are: no rows should yet have been
# received
self.assertEqual(received_rows, [])
# We should now see an attempt to connect to the master
request = self.handle_http_replication_attempt()
self.assert_request_is_get_repl_stream_updates(request, "federation")
# we should have received an update row
stream_name, token, row = received_rows.pop()
self.assertEqual(stream_name, "federation")
self.assertIsInstance(row, FederationStream.FederationStreamRow)
self.assertEqual(row.type, EduRow.TypeId)
edurow = EduRow.from_data(row.data)
self.assertEqual(edurow.edu.edu_type, "m.test_edu")
self.assertEqual(edurow.edu.origin, self.hs.hostname)
self.assertEqual(edurow.edu.destination, "testdest")
self.assertEqual(edurow.edu.content, {"a": "b"})
self.assertEqual(received_rows, [])
# additional updates should be transferred without an HTTP hit
fed_sender.build_and_send_edu("testdest", "m.test1", {"c": "d"})
self.reactor.advance(0)
# there should be no http hit
self.assertEqual(len(self.reactor.tcpClients), 0)
# ... but we should have a row
self.assertEqual(len(received_rows), 1)
stream_name, token, row = received_rows.pop()
self.assertEqual(stream_name, "federation")
self.assertIsInstance(row, FederationStream.FederationStreamRow)
self.assertEqual(row.type, EduRow.TypeId)
edurow = EduRow.from_data(row.data)
self.assertEqual(edurow.edu.edu_type, "m.test1")
self.assertEqual(edurow.edu.origin, self.hs.hostname)
self.assertEqual(edurow.edu.destination, "testdest")
self.assertEqual(edurow.edu.content, {"c": "d"})
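The `_get_worker_hs_config` hook lets a subclass adjust the worker config without repeating the base class's setup. A minimal sketch of the pattern with plain classes; `BaseCase`, `FederationCase` and `build_worker_config` are hypothetical names, and `default_config` here returns a made-up dict rather than Synapse's test defaults.

```python
class BaseCase:
    def default_config(self):
        # Assumption: a placeholder for the real test default config.
        return {"server_name": "test"}

    def _get_worker_hs_config(self):
        config = self.default_config()
        config["worker_app"] = "synapse.app.generic_worker"
        config["worker_replication_host"] = "testserv"
        config["worker_replication_http_port"] = "8765"
        return config

    def build_worker_config(self):
        # The base class always goes through the hook, so overrides apply.
        return self._get_worker_hs_config()


class FederationCase(BaseCase):
    def _get_worker_hs_config(self):
        # Start from the base config and change only what differs.
        config = super()._get_worker_hs_config()
        config["send_federation"] = True
        config["worker_app"] = "synapse.app.federation_sender"
        return config


base_config = BaseCase().build_worker_config()
fed_config = FederationCase().build_worker_config()
```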
@@ -15,7 +15,6 @@
from mock import Mock
from synapse.handlers.typing import RoomMember
from synapse.replication.http import streams
from synapse.replication.tcp.streams import TypingStream
from tests.replication.tcp.streams._base import BaseStreamTestCase
@@ -24,10 +23,6 @@ USER_ID = "@feeling:blue"
class TypingStreamTestCase(BaseStreamTestCase):
servlets = [
streams.register_servlets,
]
def _build_replication_data_handler(self):
return Mock(wraps=super()._build_replication_data_handler())
+184
@@ -0,0 +1,184 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.storage.database import Database
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from tests.unittest import HomeserverTestCase
from tests.utils import USE_POSTGRES_FOR_TESTS
class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
if not USE_POSTGRES_FOR_TESTS:
skip = "Requires Postgres"
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.db = self.store.db # type: Database
self.get_success(self.db.runInteraction("_setup_db", self._setup_db))
def _setup_db(self, txn):
txn.execute("CREATE SEQUENCE foobar_seq")
txn.execute(
"""
CREATE TABLE foobar (
stream_id BIGINT NOT NULL,
instance_name TEXT NOT NULL,
data TEXT
);
"""
)
def _create_id_generator(self, instance_name="master") -> MultiWriterIdGenerator:
def _create(conn):
return MultiWriterIdGenerator(
conn,
self.db,
instance_name=instance_name,
table="foobar",
instance_column="instance_name",
id_column="stream_id",
sequence_name="foobar_seq",
)
return self.get_success(self.db.runWithConnection(_create))
def _insert_rows(self, instance_name: str, number: int):
def _insert(txn):
for _ in range(number):
txn.execute(
"INSERT INTO foobar VALUES (nextval('foobar_seq'), ?)",
(instance_name,),
)
self.get_success(self.db.runInteraction("test_single_instance", _insert))
def test_empty(self):
"""Test an ID generator against an empty database gives sensible
current positions.
"""
id_gen = self._create_id_generator()
# The table is empty so we expect an empty map for positions
self.assertEqual(id_gen.get_positions(), {})
def test_single_instance(self):
"""Test that reads and writes from a single process are handled
correctly.
"""
# Prefill table with 7 rows written by 'master'
self._insert_rows("master", 7)
id_gen = self._create_id_generator()
self.assertEqual(id_gen.get_positions(), {"master": 7})
self.assertEqual(id_gen.get_current_token("master"), 7)
# Try allocating a new ID and check that we only see position
# advanced after we leave the context manager.
async def _get_next_async():
with await id_gen.get_next() as stream_id:
self.assertEqual(stream_id, 8)
self.assertEqual(id_gen.get_positions(), {"master": 7})
self.assertEqual(id_gen.get_current_token("master"), 7)
self.get_success(_get_next_async())
self.assertEqual(id_gen.get_positions(), {"master": 8})
self.assertEqual(id_gen.get_current_token("master"), 8)
def test_multi_instance(self):
"""Test that reads and writes from multiple processes are handled
correctly.
"""
self._insert_rows("first", 3)
self._insert_rows("second", 4)
first_id_gen = self._create_id_generator("first")
second_id_gen = self._create_id_generator("second")
self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
self.assertEqual(first_id_gen.get_current_token("first"), 3)
self.assertEqual(first_id_gen.get_current_token("second"), 7)
# Try allocating a new ID and check that we only see position
# advanced after we leave the context manager.
async def _get_next_async():
with await first_id_gen.get_next() as stream_id:
self.assertEqual(stream_id, 8)
self.assertEqual(
first_id_gen.get_positions(), {"first": 3, "second": 7}
)
self.get_success(_get_next_async())
self.assertEqual(first_id_gen.get_positions(), {"first": 8, "second": 7})
# However the ID gen on the second instance won't have seen the update
self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
# ... but calling `get_next` on the second instance should give a unique
# stream ID
async def _get_next_async():
with await second_id_gen.get_next() as stream_id:
self.assertEqual(stream_id, 9)
self.assertEqual(
second_id_gen.get_positions(), {"first": 3, "second": 7}
)
self.get_success(_get_next_async())
self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 9})
# If the second ID gen gets told about the first, it correctly updates
second_id_gen.advance("first", 8)
self.assertEqual(second_id_gen.get_positions(), {"first": 8, "second": 9})
def test_get_next_txn(self):
"""Test that the `get_next_txn` function works correctly.
"""
# Prefill table with 7 rows written by 'master'
self._insert_rows("master", 7)
id_gen = self._create_id_generator()
self.assertEqual(id_gen.get_positions(), {"master": 7})
self.assertEqual(id_gen.get_current_token("master"), 7)
# Try allocating a new ID and check that we only see position
# advanced after the transaction has completed.
def _get_next_txn(txn):
stream_id = id_gen.get_next_txn(txn)
self.assertEqual(stream_id, 8)
self.assertEqual(id_gen.get_positions(), {"master": 7})
self.assertEqual(id_gen.get_current_token("master"), 7)
self.get_success(self.db.runInteraction("test", _get_next_txn))
self.assertEqual(id_gen.get_positions(), {"master": 8})
self.assertEqual(id_gen.get_current_token("master"), 8)
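The starting positions asserted in `test_multi_instance` come from the `GROUP BY` query that the generator runs at start-up. The sketch below reproduces that query against an in-memory SQLite database instead of Postgres, so it omits the sequence. `load_current_positions` is a hypothetical free function, and the table and column names must be trusted because they are interpolated directly into the SQL.

```python
import sqlite3


def load_current_positions(conn, table, instance_column, id_column):
    sql = """
        SELECT %(instance)s, MAX(%(id)s) FROM %(table)s
        GROUP BY %(instance)s
    """ % {
        "instance": instance_column,
        "id": id_column,
        "table": table,
    }
    cur = conn.cursor()
    cur.execute(sql)
    # The cursor iterates over 2-tuples, which `dict` accepts directly.
    positions = dict(cur)
    cur.close()
    return positions


conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE foobar ("
    "stream_id BIGINT NOT NULL, instance_name TEXT NOT NULL, data TEXT)"
)
# Three rows written by "first" (IDs 1-3), then four by "second" (IDs 4-7).
rows = [(i, "first") for i in range(1, 4)] + [(i, "second") for i in range(4, 8)]
conn.executemany(
    "INSERT INTO foobar (stream_id, instance_name) VALUES (?, ?)", rows
)
positions = load_current_positions(conn, "foobar", "instance_name", "stream_id")
conn.close()
```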
+11 -3
@@ -19,6 +19,7 @@ import json
from mock import Mock
from synapse.api.auth_blocking import AuthBlocking
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.rest.client.v2_alpha import register, sync
@@ -45,11 +46,17 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.hs.config.limit_usage_by_mau = True
self.hs.config.hs_disabled = False
self.hs.config.max_mau_value = 2
self.hs.config.mau_trial_days = 0
self.hs.config.server_notices_mxid = "@server:red"
self.hs.config.server_notices_mxid_display_name = None
self.hs.config.server_notices_mxid_avatar_url = None
self.hs.config.server_notices_room_name = "Test Server Notice Room"
self.hs.config.mau_trial_days = 0
# AuthBlocking reads config options during hs creation. Recreate the
# hs' copy of AuthBlocking after we've updated config values above
self.auth_blocking = AuthBlocking(self.hs)
self.hs.get_auth()._auth_blocking = self.auth_blocking
return self.hs
def test_simple_deny_mau(self):
@@ -121,6 +128,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
def test_trial_users_cant_come_back(self):
self.auth_blocking._mau_trial_days = 1
self.hs.config.mau_trial_days = 1
# We should be able to register more than the limit initially
@@ -169,8 +177,8 @@ class TestMauLimit(unittest.HomeserverTestCase):
self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
def test_tracked_but_not_limited(self):
self.hs.config.max_mau_value = 1 # should not matter
self.hs.config.limit_usage_by_mau = False
self.auth_blocking._max_mau_value = 1 # should not matter
self.auth_blocking._limit_usage_by_mau = False
self.hs.config.mau_stats_only = True
# Simply being able to create 2 users indicates that the
+20 -6
@@ -14,12 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from typing import Optional, Tuple
import synapse.server
from synapse.api.constants import EventTypes
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.types import Collection
from tests.test_utils import get_awaitable_result
@@ -75,6 +76,23 @@ def inject_event(
"""
test_reactor = hs.get_reactor()
event, context = create_event(hs, room_version, prev_event_ids, **kwargs)
d = hs.get_storage().persistence.persist_event(event, context)
test_reactor.advance(0)
get_awaitable_result(d)
return event
def create_event(
hs: synapse.server.HomeServer,
room_version: Optional[str] = None,
prev_event_ids: Optional[Collection[str]] = None,
**kwargs
) -> Tuple[EventBase, EventContext]:
test_reactor = hs.get_reactor()
if room_version is None:
d = hs.get_datastore().get_room_version_id(kwargs["room_id"])
test_reactor.advance(0)
@@ -89,8 +107,4 @@ def inject_event(
test_reactor.advance(0)
event, context = get_awaitable_result(d)
d = hs.get_storage().persistence.persist_event(event, context)
test_reactor.advance(0)
get_awaitable_result(d)
return event
return event, context
+2 -5
@@ -181,11 +181,7 @@ commands = mypy \
synapse/appservice \
synapse/config \
synapse/events/spamcheck.py \
synapse/federation/federation_base.py \
synapse/federation/federation_client.py \
synapse/federation/federation_server.py \
synapse/federation/sender \
synapse/federation/transport \
synapse/federation \
synapse/handlers/auth.py \
synapse/handlers/cas_handler.py \
synapse/handlers/directory.py \
@@ -203,6 +199,7 @@ commands = mypy \
synapse/storage/data_stores/main/ui_auth.py \
synapse/storage/database.py \
synapse/storage/engines \
synapse/storage/util \
synapse/streams \
synapse/util/caches/stream_change_cache.py \
tests/replication/tcp/streams \