Compare commits
46 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 88d8a7cd19 | |||
| 14a9d59edf | |||
| fac30c0554 | |||
| 908d15e904 | |||
| 0a1ebe5442 | |||
| 9316091c8c | |||
| 7d54f2413f | |||
| 98a6910d94 | |||
| c08273f529 | |||
| edca4cf1bb | |||
| 9d7fcb2fda | |||
| 1786b1ee0d | |||
| 2929ce29d6 | |||
| 62ee862119 | |||
| aee9130a83 | |||
| fa0b2bd28d | |||
| 16b67c404d | |||
| db5f9031b7 | |||
| 2e0c46ca07 | |||
| 79007a42b2 | |||
| 30a19daa02 | |||
| e48361545d | |||
| 0f6ebf393d | |||
| b26f3e582c | |||
| c255b0ffdc | |||
| a8c17da245 | |||
| 1242267316 | |||
| 7bf788ac73 | |||
| 7f7eedbebb | |||
| 5b8023dc7f | |||
| d78265af0c | |||
| 13dd458b8d | |||
| 714560e325 | |||
| 79fe3e068b | |||
| f9073893af | |||
| 16b1a34e80 | |||
| fe69fb6263 | |||
| 7941a70fa8 | |||
| d5aa7d93ed | |||
| 8123b2f909 | |||
| 15aa09bbe6 | |||
| 9858d5c362 | |||
| ad088716bc | |||
| 068da604c2 | |||
| 350421e058 | |||
| a251e0f4ba |
+6
-5
@@ -30,23 +30,24 @@ recursive-include synapse/static *.gif
|
||||
recursive-include synapse/static *.html
|
||||
recursive-include synapse/static *.js
|
||||
|
||||
exclude Dockerfile
|
||||
exclude .codecov.yml
|
||||
exclude .coveragerc
|
||||
exclude .dockerignore
|
||||
exclude test_postgresql.sh
|
||||
exclude .editorconfig
|
||||
exclude Dockerfile
|
||||
exclude mypy.ini
|
||||
exclude sytest-blacklist
|
||||
exclude test_postgresql.sh
|
||||
|
||||
include pyproject.toml
|
||||
recursive-include changelog.d *
|
||||
|
||||
prune .buildkite
|
||||
prune .circleci
|
||||
prune .codecov.yml
|
||||
prune .coveragerc
|
||||
prune .github
|
||||
prune contrib
|
||||
prune debian
|
||||
prune demo/etc
|
||||
prune docker
|
||||
prune mypy.ini
|
||||
prune snap
|
||||
prune stubs
|
||||
|
||||
@@ -1 +1 @@
|
||||
Add typing information to federation server code.
|
||||
Add typing annotations in `synapse.federation`.
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
Add MultiWriterIdGenerator to support multiple concurrent writers of streams.
|
||||
@@ -0,0 +1 @@
|
||||
Move catchup of replication streams logic to worker.
|
||||
@@ -0,0 +1 @@
|
||||
Add typing annotations in `synapse.federation`.
|
||||
@@ -0,0 +1 @@
|
||||
Fix bug in `EventContext.deserialize`.
|
||||
@@ -0,0 +1 @@
|
||||
Update docker runtime image to Alpine v3.11. Contributed by @Starbix.
|
||||
@@ -0,0 +1 @@
|
||||
Add support for running replication over Redis when using workers.
|
||||
@@ -0,0 +1 @@
|
||||
Fix issues with the Python package manifest.
|
||||
@@ -0,0 +1 @@
|
||||
Clean up some LoggingContext code.
|
||||
@@ -0,0 +1 @@
|
||||
Prevent methods in `synapse.handlers.auth` from polling the homeserver config every request.
|
||||
@@ -0,0 +1 @@
|
||||
Move catchup of replication streams logic to worker.
|
||||
@@ -0,0 +1 @@
|
||||
Speed up fetching device lists changes when handling `/sync` requests.
|
||||
@@ -0,0 +1 @@
|
||||
Add support for running replication over Redis when using workers.
|
||||
@@ -0,0 +1 @@
|
||||
Improve performance of `get_e2e_cross_signing_key`.
|
||||
@@ -0,0 +1 @@
|
||||
Improve performance of `mark_as_sent_devices_by_remote`.
|
||||
+1
-1
@@ -55,7 +55,7 @@ RUN pip install --prefix="/install" --no-warn-script-location \
|
||||
### Stage 1: runtime
|
||||
###
|
||||
|
||||
FROM docker.io/python:${PYTHON_VERSION}-alpine3.10
|
||||
FROM docker.io/python:${PYTHON_VERSION}-alpine3.11
|
||||
|
||||
# xmlsec is required for saml support
|
||||
RUN apk add --no-cache --virtual .runtime_deps \
|
||||
|
||||
@@ -22,7 +22,10 @@ class RedisProtocol:
|
||||
def publish(self, channel: str, message: bytes): ...
|
||||
|
||||
class SubscriberProtocol:
|
||||
password: Optional[str]
|
||||
def subscribe(self, channels: Union[str, List[str]]): ...
|
||||
def connectionMade(self): ...
|
||||
def connectionLost(self, reason): ...
|
||||
|
||||
def lazyConnection(
|
||||
host: str = ...,
|
||||
|
||||
+10
-73
@@ -26,16 +26,15 @@ from twisted.internet import defer
|
||||
import synapse.logging.opentracing as opentracing
|
||||
import synapse.types
|
||||
from synapse import event_auth
|
||||
from synapse.api.constants import EventTypes, LimitBlockingTypes, Membership, UserTypes
|
||||
from synapse.api.auth_blocking import AuthBlocking
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.api.errors import (
|
||||
AuthError,
|
||||
Codes,
|
||||
InvalidClientTokenError,
|
||||
MissingClientTokenError,
|
||||
ResourceLimitError,
|
||||
)
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
||||
from synapse.config.server import is_threepid_reserved
|
||||
from synapse.events import EventBase
|
||||
from synapse.types import StateMap, UserID
|
||||
from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache
|
||||
@@ -77,7 +76,11 @@ class Auth(object):
|
||||
self.token_cache = LruCache(CACHE_SIZE_FACTOR * 10000)
|
||||
register_cache("cache", "token_cache", self.token_cache)
|
||||
|
||||
self._auth_blocking = AuthBlocking(self.hs)
|
||||
|
||||
self._account_validity = hs.config.account_validity
|
||||
self._track_appservice_user_ips = hs.config.track_appservice_user_ips
|
||||
self._macaroon_secret_key = hs.config.macaroon_secret_key
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def check_from_context(self, room_version: str, event, context, do_sig_check=True):
|
||||
@@ -191,7 +194,7 @@ class Auth(object):
|
||||
opentracing.set_tag("authenticated_entity", user_id)
|
||||
opentracing.set_tag("appservice_id", app_service.id)
|
||||
|
||||
if ip_addr and self.hs.config.track_appservice_user_ips:
|
||||
if ip_addr and self._track_appservice_user_ips:
|
||||
yield self.store.insert_client_ip(
|
||||
user_id=user_id,
|
||||
access_token=access_token,
|
||||
@@ -454,7 +457,7 @@ class Auth(object):
|
||||
# access_tokens include a nonce for uniqueness: any value is acceptable
|
||||
v.satisfy_general(lambda c: c.startswith("nonce = "))
|
||||
|
||||
v.verify(macaroon, self.hs.config.macaroon_secret_key)
|
||||
v.verify(macaroon, self._macaroon_secret_key)
|
||||
|
||||
def _verify_expiry(self, caveat):
|
||||
prefix = "time < "
|
||||
@@ -663,71 +666,5 @@ class Auth(object):
|
||||
% (user_id, room_id),
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def check_auth_blocking(self, user_id=None, threepid=None, user_type=None):
|
||||
"""Checks if the user should be rejected for some external reason,
|
||||
such as monthly active user limiting or global disable flag
|
||||
|
||||
Args:
|
||||
user_id(str|None): If present, checks for presence against existing
|
||||
MAU cohort
|
||||
|
||||
threepid(dict|None): If present, checks for presence against configured
|
||||
reserved threepid. Used in cases where the user is trying register
|
||||
with a MAU blocked server, normally they would be rejected but their
|
||||
threepid is on the reserved list. user_id and
|
||||
threepid should never be set at the same time.
|
||||
|
||||
user_type(str|None): If present, is used to decide whether to check against
|
||||
certain blocking reasons like MAU.
|
||||
"""
|
||||
|
||||
# Never fail an auth check for the server notices users or support user
|
||||
# This can be a problem where event creation is prohibited due to blocking
|
||||
if user_id is not None:
|
||||
if user_id == self.hs.config.server_notices_mxid:
|
||||
return
|
||||
if (yield self.store.is_support_user(user_id)):
|
||||
return
|
||||
|
||||
if self.hs.config.hs_disabled:
|
||||
raise ResourceLimitError(
|
||||
403,
|
||||
self.hs.config.hs_disabled_message,
|
||||
errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
|
||||
admin_contact=self.hs.config.admin_contact,
|
||||
limit_type=LimitBlockingTypes.HS_DISABLED,
|
||||
)
|
||||
if self.hs.config.limit_usage_by_mau is True:
|
||||
assert not (user_id and threepid)
|
||||
|
||||
# If the user is already part of the MAU cohort or a trial user
|
||||
if user_id:
|
||||
timestamp = yield self.store.user_last_seen_monthly_active(user_id)
|
||||
if timestamp:
|
||||
return
|
||||
|
||||
is_trial = yield self.store.is_trial_user(user_id)
|
||||
if is_trial:
|
||||
return
|
||||
elif threepid:
|
||||
# If the user does not exist yet, but is signing up with a
|
||||
# reserved threepid then pass auth check
|
||||
if is_threepid_reserved(
|
||||
self.hs.config.mau_limits_reserved_threepids, threepid
|
||||
):
|
||||
return
|
||||
elif user_type == UserTypes.SUPPORT:
|
||||
# If the user does not exist yet and is of type "support",
|
||||
# allow registration. Support users are excluded from MAU checks.
|
||||
return
|
||||
# Else if there is no room in the MAU bucket, bail
|
||||
current_mau = yield self.store.get_monthly_active_count()
|
||||
if current_mau >= self.hs.config.max_mau_value:
|
||||
raise ResourceLimitError(
|
||||
403,
|
||||
"Monthly Active User Limit Exceeded",
|
||||
admin_contact=self.hs.config.admin_contact,
|
||||
errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
|
||||
limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER,
|
||||
)
|
||||
def check_auth_blocking(self, *args, **kwargs):
|
||||
return self._auth_blocking.check_auth_blocking(*args, **kwargs)
|
||||
|
||||
@@ -0,0 +1,104 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import LimitBlockingTypes, UserTypes
|
||||
from synapse.api.errors import Codes, ResourceLimitError
|
||||
from synapse.config.server import is_threepid_reserved
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AuthBlocking(object):
|
||||
def __init__(self, hs):
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
self._server_notices_mxid = hs.config.server_notices_mxid
|
||||
self._hs_disabled = hs.config.hs_disabled
|
||||
self._hs_disabled_message = hs.config.hs_disabled_message
|
||||
self._admin_contact = hs.config.admin_contact
|
||||
self._max_mau_value = hs.config.max_mau_value
|
||||
self._limit_usage_by_mau = hs.config.limit_usage_by_mau
|
||||
self._mau_limits_reserved_threepids = hs.config.mau_limits_reserved_threepids
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def check_auth_blocking(self, user_id=None, threepid=None, user_type=None):
|
||||
"""Checks if the user should be rejected for some external reason,
|
||||
such as monthly active user limiting or global disable flag
|
||||
|
||||
Args:
|
||||
user_id(str|None): If present, checks for presence against existing
|
||||
MAU cohort
|
||||
|
||||
threepid(dict|None): If present, checks for presence against configured
|
||||
reserved threepid. Used in cases where the user is trying register
|
||||
with a MAU blocked server, normally they would be rejected but their
|
||||
threepid is on the reserved list. user_id and
|
||||
threepid should never be set at the same time.
|
||||
|
||||
user_type(str|None): If present, is used to decide whether to check against
|
||||
certain blocking reasons like MAU.
|
||||
"""
|
||||
|
||||
# Never fail an auth check for the server notices users or support user
|
||||
# This can be a problem where event creation is prohibited due to blocking
|
||||
if user_id is not None:
|
||||
if user_id == self._server_notices_mxid:
|
||||
return
|
||||
if (yield self.store.is_support_user(user_id)):
|
||||
return
|
||||
|
||||
if self._hs_disabled:
|
||||
raise ResourceLimitError(
|
||||
403,
|
||||
self._hs_disabled_message,
|
||||
errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
|
||||
admin_contact=self._admin_contact,
|
||||
limit_type=LimitBlockingTypes.HS_DISABLED,
|
||||
)
|
||||
if self._limit_usage_by_mau is True:
|
||||
assert not (user_id and threepid)
|
||||
|
||||
# If the user is already part of the MAU cohort or a trial user
|
||||
if user_id:
|
||||
timestamp = yield self.store.user_last_seen_monthly_active(user_id)
|
||||
if timestamp:
|
||||
return
|
||||
|
||||
is_trial = yield self.store.is_trial_user(user_id)
|
||||
if is_trial:
|
||||
return
|
||||
elif threepid:
|
||||
# If the user does not exist yet, but is signing up with a
|
||||
# reserved threepid then pass auth check
|
||||
if is_threepid_reserved(self._mau_limits_reserved_threepids, threepid):
|
||||
return
|
||||
elif user_type == UserTypes.SUPPORT:
|
||||
# If the user does not exist yet and is of type "support",
|
||||
# allow registration. Support users are excluded from MAU checks.
|
||||
return
|
||||
# Else if there is no room in the MAU bucket, bail
|
||||
current_mau = yield self.store.get_monthly_active_count()
|
||||
if current_mau >= self._max_mau_value:
|
||||
raise ResourceLimitError(
|
||||
403,
|
||||
"Monthly Active User Limit Exceeded",
|
||||
admin_contact=self._admin_contact,
|
||||
errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
|
||||
limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER,
|
||||
)
|
||||
@@ -322,11 +322,14 @@ class _AsyncEventContextImpl(EventContext):
|
||||
self._current_state_ids = yield self._storage.state.get_state_ids_for_group(
|
||||
self.state_group
|
||||
)
|
||||
if self._prev_state_id and self._event_state_key is not None:
|
||||
if self._event_state_key is not None:
|
||||
self._prev_state_ids = dict(self._current_state_ids)
|
||||
|
||||
key = (self._event_type, self._event_state_key)
|
||||
self._prev_state_ids[key] = self._prev_state_id
|
||||
if self._prev_state_id:
|
||||
self._prev_state_ids[key] = self._prev_state_id
|
||||
else:
|
||||
self._prev_state_ids.pop(key, None)
|
||||
else:
|
||||
self._prev_state_ids = self._current_state_ids
|
||||
|
||||
|
||||
@@ -31,6 +31,7 @@ Events are replicated via a separate events stream.
|
||||
|
||||
import logging
|
||||
from collections import namedtuple
|
||||
from typing import Dict, List, Tuple, Type
|
||||
|
||||
from six import iteritems
|
||||
|
||||
@@ -56,21 +57,35 @@ class FederationRemoteSendQueue(object):
|
||||
self.notifier = hs.get_notifier()
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
|
||||
self.presence_map = {} # Pending presence map user_id -> UserPresenceState
|
||||
self.presence_changed = SortedDict() # Stream position -> list[user_id]
|
||||
# Pending presence map user_id -> UserPresenceState
|
||||
self.presence_map = {} # type: Dict[str, UserPresenceState]
|
||||
|
||||
# Stream position -> list[user_id]
|
||||
self.presence_changed = SortedDict() # type: SortedDict[int, List[str]]
|
||||
|
||||
# Stores the destinations we need to explicitly send presence to about a
|
||||
# given user.
|
||||
# Stream position -> (user_id, destinations)
|
||||
self.presence_destinations = SortedDict()
|
||||
self.presence_destinations = (
|
||||
SortedDict()
|
||||
) # type: SortedDict[int, Tuple[str, List[str]]]
|
||||
|
||||
self.keyed_edu = {} # (destination, key) -> EDU
|
||||
self.keyed_edu_changed = SortedDict() # stream position -> (destination, key)
|
||||
# (destination, key) -> EDU
|
||||
self.keyed_edu = {} # type: Dict[Tuple[str, tuple], Edu]
|
||||
|
||||
self.edus = SortedDict() # stream position -> Edu
|
||||
# stream position -> (destination, key)
|
||||
self.keyed_edu_changed = (
|
||||
SortedDict()
|
||||
) # type: SortedDict[int, Tuple[str, tuple]]
|
||||
|
||||
self.edus = SortedDict() # type: SortedDict[int, Edu]
|
||||
|
||||
# stream ID for the next entry into presence_changed/keyed_edu_changed/edus.
|
||||
self.pos = 1
|
||||
self.pos_time = SortedDict()
|
||||
|
||||
# map from stream ID to the time that stream entry was generated, so that we
|
||||
# can clear out entries after a while
|
||||
self.pos_time = SortedDict() # type: SortedDict[int, int]
|
||||
|
||||
# EVERYTHING IS SAD. In particular, python only makes new scopes when
|
||||
# we make a new function, so we need to make a new function so the inner
|
||||
@@ -158,8 +173,10 @@ class FederationRemoteSendQueue(object):
|
||||
for edu_key in self.keyed_edu_changed.values():
|
||||
live_keys.add(edu_key)
|
||||
|
||||
to_del = [edu_key for edu_key in self.keyed_edu if edu_key not in live_keys]
|
||||
for edu_key in to_del:
|
||||
keys_to_del = [
|
||||
edu_key for edu_key in self.keyed_edu if edu_key not in live_keys
|
||||
]
|
||||
for edu_key in keys_to_del:
|
||||
del self.keyed_edu[edu_key]
|
||||
|
||||
# Delete things out of edu map
|
||||
@@ -250,19 +267,23 @@ class FederationRemoteSendQueue(object):
|
||||
self._clear_queue_before_pos(token)
|
||||
|
||||
async def get_replication_rows(
|
||||
self, from_token, to_token, limit, federation_ack=None
|
||||
):
|
||||
self, instance_name: str, from_token: int, to_token: int, target_row_count: int
|
||||
) -> Tuple[List[Tuple[int, Tuple]], int, bool]:
|
||||
"""Get rows to be sent over federation between the two tokens
|
||||
|
||||
Args:
|
||||
from_token (int)
|
||||
to_token(int)
|
||||
limit (int)
|
||||
federation_ack (int): Optional. The position where the worker is
|
||||
explicitly acknowledged it has handled. Allows us to drop
|
||||
data from before that point
|
||||
instance_name: the name of the current process
|
||||
from_token: the previous stream token: the starting point for fetching the
|
||||
updates
|
||||
to_token: the new stream token: the point to get updates up to
|
||||
target_row_count: a target for the number of rows to be returned.
|
||||
|
||||
Returns: a triplet `(updates, new_last_token, limited)`, where:
|
||||
* `updates` is a list of `(token, row)` entries.
|
||||
* `new_last_token` is the new position in stream.
|
||||
* `limited` is whether there are more updates to fetch.
|
||||
"""
|
||||
# TODO: Handle limit.
|
||||
# TODO: Handle target_row_count.
|
||||
|
||||
# To handle restarts where we wrap around
|
||||
if from_token > self.pos:
|
||||
@@ -270,12 +291,7 @@ class FederationRemoteSendQueue(object):
|
||||
|
||||
# list of tuple(int, BaseFederationRow), where the first is the position
|
||||
# of the federation stream.
|
||||
rows = []
|
||||
|
||||
# There should be only one reader, so lets delete everything its
|
||||
# acknowledged its seen.
|
||||
if federation_ack:
|
||||
self._clear_queue_before_pos(federation_ack)
|
||||
rows = [] # type: List[Tuple[int, BaseFederationRow]]
|
||||
|
||||
# Fetch changed presence
|
||||
i = self.presence_changed.bisect_right(from_token)
|
||||
@@ -332,7 +348,11 @@ class FederationRemoteSendQueue(object):
|
||||
# Sort rows based on pos
|
||||
rows.sort()
|
||||
|
||||
return [(pos, row.TypeId, row.to_data()) for pos, row in rows]
|
||||
return (
|
||||
[(pos, (row.TypeId, row.to_data())) for pos, row in rows],
|
||||
to_token,
|
||||
False,
|
||||
)
|
||||
|
||||
|
||||
class BaseFederationRow(object):
|
||||
@@ -341,7 +361,7 @@ class BaseFederationRow(object):
|
||||
Specifies how to identify, serialize and deserialize the different types.
|
||||
"""
|
||||
|
||||
TypeId = None # Unique string that ids the type. Must be overriden in sub classes.
|
||||
TypeId = "" # Unique string that ids the type. Must be overriden in sub classes.
|
||||
|
||||
@staticmethod
|
||||
def from_data(data):
|
||||
@@ -454,10 +474,14 @@ class EduRow(BaseFederationRow, namedtuple("EduRow", ("edu",))): # Edu
|
||||
buff.edus.setdefault(self.edu.destination, []).append(self.edu)
|
||||
|
||||
|
||||
TypeToRow = {
|
||||
Row.TypeId: Row
|
||||
for Row in (PresenceRow, PresenceDestinationsRow, KeyedEduRow, EduRow,)
|
||||
}
|
||||
_rowtypes = (
|
||||
PresenceRow,
|
||||
PresenceDestinationsRow,
|
||||
KeyedEduRow,
|
||||
EduRow,
|
||||
) # type: Tuple[Type[BaseFederationRow], ...]
|
||||
|
||||
TypeToRow = {Row.TypeId: Row for Row in _rowtypes}
|
||||
|
||||
|
||||
ParsedFederationStreamData = namedtuple(
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import Dict, Hashable, Iterable, List, Optional, Set
|
||||
from typing import Dict, Hashable, Iterable, List, Optional, Set, Tuple
|
||||
|
||||
from six import itervalues
|
||||
|
||||
@@ -498,14 +498,16 @@ class FederationSender(object):
|
||||
|
||||
self._get_per_destination_queue(destination).attempt_new_transaction()
|
||||
|
||||
def get_current_token(self) -> int:
|
||||
@staticmethod
|
||||
def get_current_token() -> int:
|
||||
# Dummy implementation for case where federation sender isn't offloaded
|
||||
# to a worker.
|
||||
return 0
|
||||
|
||||
@staticmethod
|
||||
async def get_replication_rows(
|
||||
self, from_token, to_token, limit, federation_ack=None
|
||||
):
|
||||
instance_name: str, from_token: int, to_token: int, target_row_count: int
|
||||
) -> Tuple[List[Tuple[int, Tuple]], int, bool]:
|
||||
# Dummy implementation for case where federation sender isn't offloaded
|
||||
# to a worker.
|
||||
return []
|
||||
return [], 0, False
|
||||
|
||||
@@ -15,11 +15,10 @@
|
||||
# limitations under the License.
|
||||
import datetime
|
||||
import logging
|
||||
from typing import Dict, Hashable, Iterable, List, Tuple
|
||||
from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Tuple
|
||||
|
||||
from prometheus_client import Counter
|
||||
|
||||
import synapse.server
|
||||
from synapse.api.errors import (
|
||||
FederationDeniedError,
|
||||
HttpResponseException,
|
||||
@@ -34,6 +33,9 @@ from synapse.storage.presence import UserPresenceState
|
||||
from synapse.types import ReadReceipt
|
||||
from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import synapse.server
|
||||
|
||||
# This is defined in the Matrix spec and enforced by the receiver.
|
||||
MAX_EDUS_PER_TRANSACTION = 100
|
||||
|
||||
|
||||
@@ -13,11 +13,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from typing import List
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
from canonicaljson import json
|
||||
|
||||
import synapse.server
|
||||
from synapse.api.errors import HttpResponseException
|
||||
from synapse.events import EventBase
|
||||
from synapse.federation.persistence import TransactionActions
|
||||
@@ -31,6 +30,9 @@ from synapse.logging.opentracing import (
|
||||
)
|
||||
from synapse.util.metrics import measure_func
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import synapse.server
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
||||
@@ -122,17 +122,17 @@ class DeviceWorkerHandler(BaseHandler):
|
||||
|
||||
# First we check if any devices have changed for users that we share
|
||||
# rooms with.
|
||||
users_who_share_room = yield self.store.get_users_who_share_room_with_user(
|
||||
users_who_share_room = yield defer.ensureDeferred(self.store.get_users_who_share_room_with_user(
|
||||
user_id
|
||||
)
|
||||
))
|
||||
|
||||
tracked_users = set(users_who_share_room)
|
||||
#tracked_users = set(users_who_share_room)
|
||||
|
||||
# Always tell the user about their own devices
|
||||
tracked_users.add(user_id)
|
||||
#tracked_users.add(user_id)
|
||||
|
||||
changed = yield self.store.get_users_whose_devices_changed(
|
||||
from_token.device_list_key, tracked_users
|
||||
from_token.device_list_key, users_who_share_room
|
||||
)
|
||||
|
||||
# Then work out if any users have since joined
|
||||
@@ -444,9 +444,10 @@ class DeviceHandler(DeviceWorkerHandler):
|
||||
"""Notify that a user's device(s) has changed. Pokes the notifier, and
|
||||
remote servers if the user is local.
|
||||
"""
|
||||
users_who_share_room = yield self.store.get_users_who_share_room_with_user(
|
||||
logger.info("get_users_who_share_room... called from notify_device_update")
|
||||
users_who_share_room = yield defer.ensureDeferred(self.store.get_users_who_share_room_with_user(
|
||||
user_id
|
||||
)
|
||||
))
|
||||
|
||||
hosts = set()
|
||||
if self.hs.is_mine_id(user_id):
|
||||
|
||||
@@ -172,6 +172,7 @@ class EventHandler(BaseHandler):
|
||||
if not event:
|
||||
return None
|
||||
|
||||
logger.info("get_users_in_room called from get_event!")
|
||||
users = await self.store.get_users_in_room(event.room_id)
|
||||
is_peeking = user.to_string() not in users
|
||||
|
||||
|
||||
@@ -1547,7 +1547,7 @@ class FederationHandler(BaseHandler):
|
||||
|
||||
async def do_remotely_reject_invite(
|
||||
self, target_hosts: Iterable[str], room_id: str, user_id: str, content: JsonDict
|
||||
) -> Optional[EventBase]:
|
||||
) -> EventBase:
|
||||
origin, event, room_version = await self._make_and_verify_event(
|
||||
target_hosts, room_id, user_id, "leave", content=content
|
||||
)
|
||||
@@ -1564,26 +1564,12 @@ class FederationHandler(BaseHandler):
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
try:
|
||||
await self.federation_client.send_leave(target_hosts, event)
|
||||
return event
|
||||
except Exception as e:
|
||||
# if we were unable to reject the exception, just mark
|
||||
# it as rejected on our end and plough ahead.
|
||||
#
|
||||
# The 'except' clause is very broad, but we need to
|
||||
# capture everything from DNS failures upwards
|
||||
#
|
||||
logger.warning("Failed to reject invite: %s", e)
|
||||
await self.federation_client.send_leave(target_hosts, event)
|
||||
|
||||
await self.store.locally_reject_invite(user_id, room_id)
|
||||
return None
|
||||
finally:
|
||||
# This block will always run before returning, and will return with
|
||||
# whatever value was returned in the try/except blocks
|
||||
# (it will not, for example, be over-written by None)
|
||||
context = await self.state_handler.compute_event_context(event)
|
||||
await self.persist_events_and_notify([(event, context)])
|
||||
context = await self.state_handler.compute_event_context(event)
|
||||
await self.persist_events_and_notify([(event, context)])
|
||||
|
||||
return event
|
||||
|
||||
async def _make_and_verify_event(
|
||||
self,
|
||||
|
||||
@@ -1098,6 +1098,7 @@ class PresenceEventSource(object):
|
||||
users_interested_in = set()
|
||||
users_interested_in.add(user_id) # So that we receive our own presence
|
||||
|
||||
logger.info("get_users_who_share_room... _get_interested_in")
|
||||
users_who_share_room = await self.store.get_users_who_share_room_with_user(
|
||||
user_id, on_invalidate=cache_context.invalidate
|
||||
)
|
||||
|
||||
@@ -976,12 +976,25 @@ class RoomMemberMasterHandler(RoomMemberHandler):
|
||||
):
|
||||
"""Implements RoomMemberHandler._remote_reject_invite
|
||||
"""
|
||||
ret = yield defer.ensureDeferred(
|
||||
self.federation_handler.do_remotely_reject_invite(
|
||||
remote_room_hosts, room_id, target.to_string(), content=content,
|
||||
fed_handler = self.federation_handler
|
||||
try:
|
||||
ret = yield defer.ensureDeferred(
|
||||
fed_handler.do_remotely_reject_invite(
|
||||
remote_room_hosts, room_id, target.to_string(), content=content,
|
||||
)
|
||||
)
|
||||
)
|
||||
return ret if ret else {}
|
||||
return ret
|
||||
except Exception as e:
|
||||
# if we were unable to reject the exception, just mark
|
||||
# it as rejected on our end and plough ahead.
|
||||
#
|
||||
# The 'except' clause is very broad, but we need to
|
||||
# capture everything from DNS failures upwards
|
||||
#
|
||||
logger.warning("Failed to reject invite: %s", e)
|
||||
|
||||
yield self.store.locally_reject_invite(target.to_string(), room_id)
|
||||
return {}
|
||||
|
||||
def _user_joined_room(self, target, room_id):
|
||||
"""Implements RoomMemberHandler._user_joined_room
|
||||
|
||||
@@ -308,12 +308,14 @@ class SyncHandler(object):
|
||||
if timeout == 0 or since_token is None or full_state:
|
||||
# we are going to return immediately, so don't bother calling
|
||||
# notifier.wait_for_events.
|
||||
logger.info("_wait_for_sync_for_user1")
|
||||
result = await self.current_sync_for_user(
|
||||
sync_config, since_token, full_state=full_state
|
||||
)
|
||||
else:
|
||||
|
||||
def current_sync_callback(before_token, after_token):
|
||||
logger.info("_wait_for_sync_for_user2")
|
||||
return self.current_sync_for_user(sync_config, since_token)
|
||||
|
||||
result = await self.notifier.wait_for_events(
|
||||
@@ -340,6 +342,7 @@ class SyncHandler(object):
|
||||
) -> SyncResult:
|
||||
"""Get the sync for client needed to match what the server has now.
|
||||
"""
|
||||
logger.info("current_sync_for_user")
|
||||
return await self.generate_sync_result(sync_config, since_token, full_state)
|
||||
|
||||
async def push_rules_for_user(self, user: UserID) -> JsonDict:
|
||||
@@ -1139,18 +1142,24 @@ class SyncHandler(object):
|
||||
# room with by looking at all users that have left a room plus users
|
||||
# that were in a room we've left.
|
||||
|
||||
logger.info("get_users_who_share_room... called from _generate_sync_entry")
|
||||
logger.info("*Called with %s", user_id)
|
||||
users_who_share_room = await self.store.get_users_who_share_room_with_user(
|
||||
user_id
|
||||
)
|
||||
|
||||
tracked_users = set(users_who_share_room)
|
||||
# Always tell the user about their own devices. We check as the user
|
||||
# ID is almost certainly already included (unless they're not in any
|
||||
# rooms) and taking a copy of the set is relatively expensive.
|
||||
#if user_id not in users_who_share_room:
|
||||
# users_who_share_room = set(users_who_share_room)
|
||||
# users_who_share_room.add(user_id)
|
||||
|
||||
# Always tell the user about their own devices
|
||||
tracked_users.add(user_id)
|
||||
#tracked_users = users_who_share_room
|
||||
|
||||
# Step 1a, check for changes in devices of users we share a room with
|
||||
users_that_have_changed = await self.store.get_users_whose_devices_changed(
|
||||
since_token.device_list_key, tracked_users
|
||||
since_token.device_list_key, users_who_share_room
|
||||
)
|
||||
|
||||
# Step 1b, check for newly joined rooms
|
||||
|
||||
@@ -27,6 +27,7 @@ import inspect
|
||||
import logging
|
||||
import threading
|
||||
import types
|
||||
import warnings
|
||||
from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union
|
||||
|
||||
from typing_extensions import Literal
|
||||
@@ -287,6 +288,46 @@ class LoggingContext(object):
|
||||
return str(self.request)
|
||||
return "%s@%x" % (self.name, id(self))
|
||||
|
||||
@classmethod
|
||||
def current_context(cls) -> LoggingContextOrSentinel:
|
||||
"""Get the current logging context from thread local storage
|
||||
|
||||
This exists for backwards compatibility. ``current_context()`` should be
|
||||
called directly.
|
||||
|
||||
Returns:
|
||||
LoggingContext: the current logging context
|
||||
"""
|
||||
warnings.warn(
|
||||
"synapse.logging.context.LoggingContext.current_context() is deprecated "
|
||||
"in favor of synapse.logging.context.current_context().",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return current_context()
|
||||
|
||||
@classmethod
|
||||
def set_current_context(
|
||||
cls, context: LoggingContextOrSentinel
|
||||
) -> LoggingContextOrSentinel:
|
||||
"""Set the current logging context in thread local storage
|
||||
|
||||
This exists for backwards compatibility. ``set_current_context()`` should be
|
||||
called directly.
|
||||
|
||||
Args:
|
||||
context(LoggingContext): The context to activate.
|
||||
Returns:
|
||||
The context that was previously active
|
||||
"""
|
||||
warnings.warn(
|
||||
"synapse.logging.context.LoggingContext.set_current_context() is deprecated "
|
||||
"in favor of synapse.logging.context.set_current_context().",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return set_current_context(context)
|
||||
|
||||
def __enter__(self) -> "LoggingContext":
|
||||
"""Enters this logging context into thread local storage"""
|
||||
old_context = set_current_context(self)
|
||||
|
||||
@@ -135,10 +135,24 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
|
||||
|
||||
logger.info("remote_reject_invite: %s out of room: %s", user_id, room_id)
|
||||
|
||||
event = await self.federation_handler.do_remotely_reject_invite(
|
||||
remote_room_hosts, room_id, user_id, event_content,
|
||||
)
|
||||
return 200, event.get_pdu_json() if event else 200, {}
|
||||
try:
|
||||
event = await self.federation_handler.do_remotely_reject_invite(
|
||||
remote_room_hosts, room_id, user_id, event_content,
|
||||
)
|
||||
ret = event.get_pdu_json()
|
||||
except Exception as e:
|
||||
# if we were unable to reject the exception, just mark
|
||||
# it as rejected on our end and plough ahead.
|
||||
#
|
||||
# The 'except' clause is very broad, but we need to
|
||||
# capture everything from DNS failures upwards
|
||||
#
|
||||
logger.warning("Failed to reject invite: %s", e)
|
||||
|
||||
await self.store.locally_reject_invite(user_id, room_id)
|
||||
ret = {}
|
||||
|
||||
return 200, ret
|
||||
|
||||
|
||||
class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
|
||||
|
||||
@@ -135,6 +135,7 @@ class SlavedEventStore(
|
||||
)
|
||||
|
||||
if data.type == EventTypes.Member:
|
||||
logger.info("INVALIDATING get_rooms_for_user_with_stream_ordering")
|
||||
self.get_rooms_for_user_with_stream_ordering.invalidate(
|
||||
(data.state_key,)
|
||||
)
|
||||
|
||||
@@ -81,9 +81,6 @@ class ReplicationCommandHandler:
|
||||
self._instance_id = hs.get_instance_id()
|
||||
self._instance_name = hs.get_instance_name()
|
||||
|
||||
# Set of streams that we've caught up with.
|
||||
self._streams_connected = set() # type: Set[str]
|
||||
|
||||
self._streams = {
|
||||
stream.NAME: stream(hs) for stream in STREAMS_MAP.values()
|
||||
} # type: Dict[str, Stream]
|
||||
@@ -99,9 +96,13 @@ class ReplicationCommandHandler:
|
||||
# The factory used to create connections.
|
||||
self._factory = None # type: Optional[ReconnectingClientFactory]
|
||||
|
||||
# The currently connected connections.
|
||||
# The currently connected connections. (The list of places we need to send
|
||||
# outgoing replication commands to.)
|
||||
self._connections = [] # type: List[AbstractConnection]
|
||||
|
||||
# For each connection, the incoming streams that are coming from that connection
|
||||
self._streams_by_connection = {} # type: Dict[AbstractConnection, Set[str]]
|
||||
|
||||
LaterGauge(
|
||||
"synapse_replication_tcp_resource_total_connections",
|
||||
"",
|
||||
@@ -257,12 +258,14 @@ class ReplicationCommandHandler:
|
||||
# 2. so we don't race with getting a POSITION command and fetching
|
||||
# missing RDATA.
|
||||
with await self._position_linearizer.queue(cmd.stream_name):
|
||||
if stream_name not in self._streams_connected:
|
||||
# If the stream isn't marked as connected then we haven't seen a
|
||||
# `POSITION` command yet, and so we may have missed some rows.
|
||||
# make sure that we've processed a POSITION for this stream *on this
|
||||
# connection*. (A POSITION on another connection is no good, as there
|
||||
# is no guarantee that we have seen all the intermediate updates.)
|
||||
sbc = self._streams_by_connection.get(conn)
|
||||
if not sbc or stream_name not in sbc:
|
||||
# Let's drop the row for now, on the assumption we'll receive a
|
||||
# `POSITION` soon and we'll catch up correctly then.
|
||||
logger.warning(
|
||||
logger.debug(
|
||||
"Discarding RDATA for unconnected stream %s -> %s",
|
||||
stream_name,
|
||||
cmd.token,
|
||||
@@ -302,21 +305,25 @@ class ReplicationCommandHandler:
|
||||
# Ignore POSITION that are just our own echoes
|
||||
return
|
||||
|
||||
stream = self._streams.get(cmd.stream_name)
|
||||
logger.info("Handling '%s %s'", cmd.NAME, cmd.to_line())
|
||||
|
||||
stream_name = cmd.stream_name
|
||||
stream = self._streams.get(stream_name)
|
||||
if not stream:
|
||||
logger.error("Got POSITION for unknown stream: %s", cmd.stream_name)
|
||||
logger.error("Got POSITION for unknown stream: %s", stream_name)
|
||||
return
|
||||
|
||||
# We protect catching up with a linearizer in case the replication
|
||||
# connection reconnects under us.
|
||||
with await self._position_linearizer.queue(cmd.stream_name):
|
||||
with await self._position_linearizer.queue(stream_name):
|
||||
# We're about to go and catch up with the stream, so remove from set
|
||||
# of connected streams.
|
||||
self._streams_connected.discard(cmd.stream_name)
|
||||
for streams in self._streams_by_connection.values():
|
||||
streams.discard(stream_name)
|
||||
|
||||
# We clear the pending batches for the stream as the fetching of the
|
||||
# missing updates below will fetch all rows in the batch.
|
||||
self._pending_batches.pop(cmd.stream_name, [])
|
||||
self._pending_batches.pop(stream_name, [])
|
||||
|
||||
# Find where we previously streamed up to.
|
||||
current_token = stream.current_token()
|
||||
@@ -326,6 +333,12 @@ class ReplicationCommandHandler:
|
||||
# between then and now.
|
||||
missing_updates = cmd.token != current_token
|
||||
while missing_updates:
|
||||
logger.info(
|
||||
"Fetching replication rows for '%s' between %i and %i",
|
||||
stream_name,
|
||||
current_token,
|
||||
cmd.token,
|
||||
)
|
||||
(
|
||||
updates,
|
||||
current_token,
|
||||
@@ -341,16 +354,18 @@ class ReplicationCommandHandler:
|
||||
|
||||
for token, rows in _batch_updates(updates):
|
||||
await self.on_rdata(
|
||||
cmd.stream_name,
|
||||
stream_name,
|
||||
cmd.instance_name,
|
||||
token,
|
||||
[stream.parse_row(row) for row in rows],
|
||||
)
|
||||
|
||||
# We've now caught up to position sent to us, notify handler.
|
||||
await self._replication_data_handler.on_position(cmd.stream_name, cmd.token)
|
||||
logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token)
|
||||
|
||||
self._streams_connected.add(cmd.stream_name)
|
||||
# We've now caught up to position sent to us, notify handler.
|
||||
await self._replication_data_handler.on_position(stream_name, cmd.token)
|
||||
|
||||
self._streams_by_connection.setdefault(conn, set()).add(stream_name)
|
||||
|
||||
async def on_REMOTE_SERVER_UP(
|
||||
self, conn: AbstractConnection, cmd: RemoteServerUpCommand
|
||||
@@ -408,6 +423,12 @@ class ReplicationCommandHandler:
|
||||
def lost_connection(self, connection: AbstractConnection):
|
||||
"""Called when a connection is closed/lost.
|
||||
"""
|
||||
# we no longer need _streams_by_connection for this connection.
|
||||
streams = self._streams_by_connection.pop(connection, None)
|
||||
if streams:
|
||||
logger.info(
|
||||
"Lost replication connection; streams now disconnected: %s", streams
|
||||
)
|
||||
try:
|
||||
self._connections.remove(connection)
|
||||
except ValueError:
|
||||
|
||||
@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING
|
||||
|
||||
import txredisapi
|
||||
|
||||
from synapse.logging.context import PreserveLoggingContext
|
||||
from synapse.logging.context import make_deferred_yieldable
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.replication.tcp.commands import (
|
||||
Command,
|
||||
@@ -41,8 +41,14 @@ logger = logging.getLogger(__name__)
|
||||
class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
|
||||
"""Connection to redis subscribed to replication stream.
|
||||
|
||||
Parses incoming messages from redis into replication commands, and passes
|
||||
them to `ReplicationCommandHandler`
|
||||
This class fulfils two functions:
|
||||
|
||||
(a) it implements the twisted Protocol API, where it handles the SUBSCRIBEd redis
|
||||
connection, parsing *incoming* messages into replication commands, and passing them
|
||||
to `ReplicationCommandHandler`
|
||||
|
||||
(b) it implements the AbstractConnection API, where it sends *outgoing* commands
|
||||
onto outbound_redis_connection.
|
||||
|
||||
Due to the vagaries of `txredisapi` we don't want to have a custom
|
||||
constructor, so instead we expect the defined attributes below to be set
|
||||
@@ -50,8 +56,8 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
|
||||
|
||||
Attributes:
|
||||
handler: The command handler to handle incoming commands.
|
||||
stream_name: The *redis* stream name to subscribe to (not anything to
|
||||
do with Synapse replication streams).
|
||||
stream_name: The *redis* stream name to subscribe to and publish from
|
||||
(not anything to do with Synapse replication streams).
|
||||
outbound_redis_connection: The connection to redis to use to send
|
||||
commands.
|
||||
"""
|
||||
@@ -61,12 +67,23 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
|
||||
outbound_redis_connection = None # type: txredisapi.RedisProtocol
|
||||
|
||||
def connectionMade(self):
|
||||
logger.info("Connected to redis instance")
|
||||
self.subscribe(self.stream_name)
|
||||
self.send_command(ReplicateCommand())
|
||||
|
||||
logger.info("Connected to redis")
|
||||
super().connectionMade()
|
||||
run_as_background_process("subscribe-replication", self._send_subscribe)
|
||||
self.handler.new_connection(self)
|
||||
|
||||
async def _send_subscribe(self):
|
||||
# it's important to make sure that we only send the REPLICATE command once we
|
||||
# have successfully subscribed to the stream - otherwise we might miss the
|
||||
# POSITION response sent back by the other end.
|
||||
logger.info("Sending redis SUBSCRIBE for %s", self.stream_name)
|
||||
await make_deferred_yieldable(self.subscribe(self.stream_name))
|
||||
logger.info(
|
||||
"Successfully subscribed to redis stream, sending REPLICATE command"
|
||||
)
|
||||
await self._async_send_command(ReplicateCommand())
|
||||
logger.info("REPLICATE successfully sent")
|
||||
|
||||
def messageReceived(self, pattern: str, channel: str, message: str):
|
||||
"""Received a message from redis.
|
||||
"""
|
||||
@@ -119,7 +136,8 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
|
||||
logger.warning("Unhandled command: %r", cmd)
|
||||
|
||||
def connectionLost(self, reason):
|
||||
logger.info("Lost connection to redis instance")
|
||||
logger.info("Lost connection to redis")
|
||||
super().connectionLost(reason)
|
||||
self.handler.lost_connection(self)
|
||||
|
||||
def send_command(self, cmd: Command):
|
||||
@@ -128,6 +146,10 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
|
||||
Args:
|
||||
cmd (Command)
|
||||
"""
|
||||
run_as_background_process("send-cmd", self._async_send_command, cmd)
|
||||
|
||||
async def _async_send_command(self, cmd: Command):
|
||||
"""Encode a replication command and send it over our outbound connection"""
|
||||
string = "%s %s" % (cmd.NAME, cmd.to_line())
|
||||
if "\n" in string:
|
||||
raise Exception("Unexpected newline in command: %r", string)
|
||||
@@ -138,15 +160,9 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
|
||||
# remote instances.
|
||||
tcp_outbound_commands_counter.labels(cmd.NAME, "redis").inc()
|
||||
|
||||
async def _send():
|
||||
with PreserveLoggingContext():
|
||||
# Note that we use the other connection as we can't send
|
||||
# commands using the subscription connection.
|
||||
await self.outbound_redis_connection.publish(
|
||||
self.stream_name, encoded_string
|
||||
)
|
||||
|
||||
run_as_background_process("send-cmd", _send)
|
||||
await make_deferred_yieldable(
|
||||
self.outbound_redis_connection.publish(self.stream_name, encoded_string)
|
||||
)
|
||||
|
||||
|
||||
class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
|
||||
@@ -189,5 +205,6 @@ class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
|
||||
p.handler = self.handler
|
||||
p.outbound_redis_connection = self.outbound_redis_connection
|
||||
p.stream_name = self.stream_name
|
||||
p.password = self.password
|
||||
|
||||
return p
|
||||
|
||||
@@ -80,7 +80,7 @@ class ReplicationStreamer(object):
|
||||
for stream in STREAMS_MAP.values():
|
||||
if stream == FederationStream and hs.config.send_federation:
|
||||
# We only support federation stream if federation sending
|
||||
# hase been disabled on the master.
|
||||
# has been disabled on the master.
|
||||
continue
|
||||
|
||||
self.streams.append(stream(hs))
|
||||
|
||||
@@ -104,7 +104,8 @@ class Stream(object):
|
||||
implemented by subclasses.
|
||||
|
||||
current_token_function is called to get the current token of the underlying
|
||||
stream.
|
||||
stream. It is only meaningful on the process that is the source of the
|
||||
replication stream (ie, usually the master).
|
||||
|
||||
update_function is called to get updates for this stream between a pair of
|
||||
stream tokens. See the UpdateFunction type definition for more info.
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
# limitations under the License.
|
||||
from collections import namedtuple
|
||||
|
||||
from synapse.replication.tcp.streams._base import Stream, db_query_to_update_function
|
||||
from synapse.replication.tcp.streams._base import Stream, make_http_update_function
|
||||
|
||||
|
||||
class FederationStream(Stream):
|
||||
@@ -35,21 +35,33 @@ class FederationStream(Stream):
|
||||
ROW_TYPE = FederationStreamRow
|
||||
|
||||
def __init__(self, hs):
|
||||
# Not all synapse instances will have a federation sender instance,
|
||||
# whether that's a `FederationSender` or a `FederationRemoteSendQueue`,
|
||||
# so we stub the stream out when that is the case.
|
||||
if hs.config.worker_app is None or hs.should_send_federation():
|
||||
if hs.config.worker_app is None:
|
||||
# master process: get updates from the FederationRemoteSendQueue.
|
||||
# (if the master is configured to send federation itself, federation_sender
|
||||
# will be a real FederationSender, which has stubs for current_token and
|
||||
# get_replication_rows.)
|
||||
federation_sender = hs.get_federation_sender()
|
||||
current_token = federation_sender.get_current_token
|
||||
update_function = db_query_to_update_function(
|
||||
federation_sender.get_replication_rows
|
||||
)
|
||||
update_function = federation_sender.get_replication_rows
|
||||
|
||||
elif hs.should_send_federation():
|
||||
# federation sender: Query master process
|
||||
update_function = make_http_update_function(hs, self.NAME)
|
||||
current_token = self._stub_current_token
|
||||
|
||||
else:
|
||||
current_token = lambda: 0
|
||||
# other worker: stub out the update function (we're not interested in
|
||||
# any updates so when we get a POSITION we do nothing)
|
||||
update_function = self._stub_update_function
|
||||
current_token = self._stub_current_token
|
||||
|
||||
super().__init__(hs.get_instance_name(), current_token, update_function)
|
||||
|
||||
@staticmethod
|
||||
def _stub_current_token():
|
||||
# dummy current-token method for use on workers
|
||||
return 0
|
||||
|
||||
@staticmethod
|
||||
async def _stub_update_function(instance_name, from_token, upto_token, limit):
|
||||
return [], upto_token, False
|
||||
|
||||
@@ -171,6 +171,7 @@ class SyncRestServlet(RestServlet):
|
||||
user.to_string(), affect_presence=affect_presence
|
||||
)
|
||||
with context:
|
||||
logger.info("Sync servlet")
|
||||
sync_result = await self.sync_handler.wait_for_sync_for_user(
|
||||
sync_config,
|
||||
since_token=since_token,
|
||||
|
||||
@@ -55,6 +55,10 @@ DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = (
|
||||
|
||||
BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
|
||||
|
||||
BG_UPDATE_DROP_DEVICE_LISTS_OUTBOUND_LAST_SUCCESS_NON_UNIQUE_IDX = (
|
||||
"drop_device_lists_outbound_last_success_non_unique_idx"
|
||||
)
|
||||
|
||||
|
||||
class DeviceWorkerStore(SQLBaseStore):
|
||||
def get_device(self, user_id, device_id):
|
||||
@@ -342,32 +346,23 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||
|
||||
def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id):
|
||||
# We update the device_lists_outbound_last_success with the successfully
|
||||
# poked users. We do the join to see which users need to be inserted and
|
||||
# which updated.
|
||||
# poked users.
|
||||
sql = """
|
||||
SELECT user_id, coalesce(max(o.stream_id), 0), (max(s.stream_id) IS NOT NULL)
|
||||
SELECT user_id, coalesce(max(o.stream_id), 0)
|
||||
FROM device_lists_outbound_pokes as o
|
||||
LEFT JOIN device_lists_outbound_last_success as s
|
||||
USING (destination, user_id)
|
||||
WHERE destination = ? AND o.stream_id <= ?
|
||||
GROUP BY user_id
|
||||
"""
|
||||
txn.execute(sql, (destination, stream_id))
|
||||
rows = txn.fetchall()
|
||||
|
||||
sql = """
|
||||
UPDATE device_lists_outbound_last_success
|
||||
SET stream_id = ?
|
||||
WHERE destination = ? AND user_id = ?
|
||||
"""
|
||||
txn.executemany(sql, ((row[1], destination, row[0]) for row in rows if row[2]))
|
||||
|
||||
sql = """
|
||||
INSERT INTO device_lists_outbound_last_success
|
||||
(destination, user_id, stream_id) VALUES (?, ?, ?)
|
||||
"""
|
||||
txn.executemany(
|
||||
sql, ((destination, row[0], row[1]) for row in rows if not row[2])
|
||||
self.db.simple_upsert_many_txn(
|
||||
txn=txn,
|
||||
table="device_lists_outbound_last_success",
|
||||
key_names=("destination", "user_id"),
|
||||
key_values=((destination, user_id) for user_id, _ in rows),
|
||||
value_names=("stream_id",),
|
||||
value_values=((stream_id,) for _, stream_id in rows),
|
||||
)
|
||||
|
||||
# Delete all sent outbound pokes
|
||||
@@ -541,8 +536,8 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||
|
||||
# Get set of users who *may* have changed. Users not in the returned
|
||||
# list have definitely not changed.
|
||||
to_check = list(
|
||||
self._device_list_stream_cache.get_entities_changed(user_ids, from_key)
|
||||
to_check = self._device_list_stream_cache.get_entities_changed(
|
||||
user_ids, from_key
|
||||
)
|
||||
|
||||
if not to_check:
|
||||
@@ -725,6 +720,21 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
|
||||
BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, self._remove_duplicate_outbound_pokes,
|
||||
)
|
||||
|
||||
# create a unique index on device_lists_outbound_last_success
|
||||
self.db.updates.register_background_index_update(
|
||||
"device_lists_outbound_last_success_unique_idx",
|
||||
index_name="device_lists_outbound_last_success_unique_idx",
|
||||
table="device_lists_outbound_last_success",
|
||||
columns=["destination", "user_id"],
|
||||
unique=True,
|
||||
)
|
||||
|
||||
# once that completes, we can remove the old non-unique index.
|
||||
self.db.updates.register_background_update_handler(
|
||||
BG_UPDATE_DROP_DEVICE_LISTS_OUTBOUND_LAST_SUCCESS_NON_UNIQUE_IDX,
|
||||
self._drop_device_lists_outbound_last_success_non_unique_idx,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
|
||||
def f(conn):
|
||||
@@ -799,6 +809,20 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
|
||||
|
||||
return rows
|
||||
|
||||
async def _drop_device_lists_outbound_last_success_non_unique_idx(
|
||||
self, progress, batch_size
|
||||
):
|
||||
def f(txn):
|
||||
txn.execute("DROP INDEX IF EXISTS device_lists_outbound_last_success_idx")
|
||||
|
||||
await self.db.runInteraction(
|
||||
"drop_device_lists_outbound_last_success_non_unique_idx", f,
|
||||
)
|
||||
await self.db.updates._end_background_update(
|
||||
BG_UPDATE_DROP_DEVICE_LISTS_OUTBOUND_LAST_SUCCESS_NON_UNIQUE_IDX
|
||||
)
|
||||
return 1
|
||||
|
||||
|
||||
class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||
def __init__(self, database: Database, db_conn, hs):
|
||||
|
||||
@@ -25,7 +25,9 @@ from twisted.internet import defer
|
||||
|
||||
from synapse.logging.opentracing import log_kv, set_tag, trace
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||
from synapse.storage.database import make_in_list_sql_clause
|
||||
from synapse.util.caches.descriptors import cached, cachedList
|
||||
from synapse.util.iterutils import batch_iter
|
||||
|
||||
|
||||
class EndToEndKeyWorkerStore(SQLBaseStore):
|
||||
@@ -268,53 +270,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
||||
"count_e2e_one_time_keys", _count_e2e_one_time_keys
|
||||
)
|
||||
|
||||
def _get_e2e_cross_signing_key_txn(self, txn, user_id, key_type, from_user_id=None):
|
||||
"""Returns a user's cross-signing key.
|
||||
|
||||
Args:
|
||||
txn (twisted.enterprise.adbapi.Connection): db connection
|
||||
user_id (str): the user whose key is being requested
|
||||
key_type (str): the type of key that is being requested: either 'master'
|
||||
for a master key, 'self_signing' for a self-signing key, or
|
||||
'user_signing' for a user-signing key
|
||||
from_user_id (str): if specified, signatures made by this user on
|
||||
the key will be included in the result
|
||||
|
||||
Returns:
|
||||
dict of the key data or None if not found
|
||||
"""
|
||||
sql = (
|
||||
"SELECT keydata "
|
||||
" FROM e2e_cross_signing_keys "
|
||||
" WHERE user_id = ? AND keytype = ? ORDER BY stream_id DESC LIMIT 1"
|
||||
)
|
||||
txn.execute(sql, (user_id, key_type))
|
||||
row = txn.fetchone()
|
||||
if not row:
|
||||
return None
|
||||
key = json.loads(row[0])
|
||||
|
||||
device_id = None
|
||||
for k in key["keys"].values():
|
||||
device_id = k
|
||||
|
||||
if from_user_id is not None:
|
||||
sql = (
|
||||
"SELECT key_id, signature "
|
||||
" FROM e2e_cross_signing_signatures "
|
||||
" WHERE user_id = ? "
|
||||
" AND target_user_id = ? "
|
||||
" AND target_device_id = ? "
|
||||
)
|
||||
txn.execute(sql, (from_user_id, user_id, device_id))
|
||||
row = txn.fetchone()
|
||||
if row:
|
||||
key.setdefault("signatures", {}).setdefault(from_user_id, {})[
|
||||
row[0]
|
||||
] = row[1]
|
||||
|
||||
return key
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_e2e_cross_signing_key(self, user_id, key_type, from_user_id=None):
|
||||
"""Returns a user's cross-signing key.
|
||||
|
||||
@@ -329,13 +285,11 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
||||
Returns:
|
||||
dict of the key data or None if not found
|
||||
"""
|
||||
return self.db.runInteraction(
|
||||
"get_e2e_cross_signing_key",
|
||||
self._get_e2e_cross_signing_key_txn,
|
||||
user_id,
|
||||
key_type,
|
||||
from_user_id,
|
||||
)
|
||||
res = yield self.get_e2e_cross_signing_keys_bulk([user_id], from_user_id)
|
||||
user_keys = res.get(user_id)
|
||||
if not user_keys:
|
||||
return None
|
||||
return user_keys.get(key_type)
|
||||
|
||||
@cached(num_args=1)
|
||||
def _get_bare_e2e_cross_signing_keys(self, user_id):
|
||||
@@ -391,26 +345,24 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
||||
"""
|
||||
result = {}
|
||||
|
||||
batch_size = 100
|
||||
chunks = [
|
||||
user_ids[i : i + batch_size] for i in range(0, len(user_ids), batch_size)
|
||||
]
|
||||
for user_chunk in chunks:
|
||||
sql = """
|
||||
for user_chunk in batch_iter(user_ids, 100):
|
||||
clause, params = make_in_list_sql_clause(
|
||||
txn.database_engine, "k.user_id", user_chunk
|
||||
)
|
||||
sql = (
|
||||
"""
|
||||
SELECT k.user_id, k.keytype, k.keydata, k.stream_id
|
||||
FROM e2e_cross_signing_keys k
|
||||
INNER JOIN (SELECT user_id, keytype, MAX(stream_id) AS stream_id
|
||||
FROM e2e_cross_signing_keys
|
||||
GROUP BY user_id, keytype) s
|
||||
USING (user_id, stream_id, keytype)
|
||||
WHERE k.user_id IN (%s)
|
||||
""" % (
|
||||
",".join("?" for u in user_chunk),
|
||||
WHERE
|
||||
"""
|
||||
+ clause
|
||||
)
|
||||
query_params = []
|
||||
query_params.extend(user_chunk)
|
||||
|
||||
txn.execute(sql, query_params)
|
||||
txn.execute(sql, params)
|
||||
rows = self.db.cursor_to_dict(txn)
|
||||
|
||||
for row in rows:
|
||||
@@ -453,15 +405,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
||||
device_id = k
|
||||
devices[(user_id, device_id)] = key_type
|
||||
|
||||
device_list = list(devices)
|
||||
|
||||
# split into batches
|
||||
batch_size = 100
|
||||
chunks = [
|
||||
device_list[i : i + batch_size]
|
||||
for i in range(0, len(device_list), batch_size)
|
||||
]
|
||||
for user_chunk in chunks:
|
||||
for batch in batch_iter(devices.keys(), size=100):
|
||||
sql = """
|
||||
SELECT target_user_id, target_device_id, key_id, signature
|
||||
FROM e2e_cross_signing_signatures
|
||||
@@ -469,11 +413,11 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
||||
AND (%s)
|
||||
""" % (
|
||||
" OR ".join(
|
||||
"(target_user_id = ? AND target_device_id = ?)" for d in devices
|
||||
"(target_user_id = ? AND target_device_id = ?)" for _ in batch
|
||||
)
|
||||
)
|
||||
query_params = [from_user_id]
|
||||
for item in devices:
|
||||
for item in batch:
|
||||
# item is a (user_id, device_id) tuple
|
||||
query_params.extend(item)
|
||||
|
||||
|
||||
@@ -605,6 +605,7 @@ class EventsStore(
|
||||
}
|
||||
|
||||
for member in members_changed:
|
||||
logger.info("INVALIDATING get_rooms_for_user_with_stream_ordering")
|
||||
txn.call_after(
|
||||
self.get_rooms_for_user_with_stream_ordering.invalidate, (member,)
|
||||
)
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import traceback
|
||||
from typing import Iterable, List, Set
|
||||
|
||||
from six import iteritems, itervalues
|
||||
@@ -163,8 +164,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
hosts = frozenset(get_domain_from_id(user_id) for user_id in user_ids)
|
||||
return hosts
|
||||
|
||||
@cached(max_entries=100000, iterable=True)
|
||||
@cached(max_entries=1000000, iterable=True)
|
||||
def get_users_in_room(self, room_id):
|
||||
logger.info("Traceback: %s", traceback.format_stack())
|
||||
return self.db.runInteraction(
|
||||
"get_users_in_room", self.get_users_in_room_txn, room_id
|
||||
)
|
||||
@@ -484,17 +486,18 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
)
|
||||
return frozenset(r.room_id for r in rooms)
|
||||
|
||||
@cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True)
|
||||
def get_users_who_share_room_with_user(self, user_id, cache_context):
|
||||
@cached(max_entries=500000, cache_context=True, iterable=True)
|
||||
async def get_users_who_share_room_with_user(self, user_id, cache_context):
|
||||
"""Returns the set of users who share a room with `user_id`
|
||||
"""
|
||||
room_ids = yield self.get_rooms_for_user(
|
||||
logger.info("Called with %s %s", user_id, cache_context)
|
||||
room_ids = await self.get_rooms_for_user(
|
||||
user_id, on_invalidate=cache_context.invalidate
|
||||
)
|
||||
|
||||
user_who_share_room = set()
|
||||
for room_id in room_ids:
|
||||
user_ids = yield self.get_users_in_room(
|
||||
user_ids = await self.get_users_in_room(
|
||||
room_id, on_invalidate=cache_context.invalidate
|
||||
)
|
||||
user_who_share_room.update(user_ids)
|
||||
|
||||
+28
@@ -0,0 +1,28 @@
|
||||
/* Copyright 2020 The Matrix.org Foundation C.I.C
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
-- register a background update which will create a unique index on
|
||||
-- device_lists_outbound_last_success
|
||||
INSERT into background_updates (ordering, update_name, progress_json)
|
||||
VALUES (5804, 'device_lists_outbound_last_success_unique_idx', '{}');
|
||||
|
||||
-- once that completes, we can drop the old index.
|
||||
INSERT into background_updates (ordering, update_name, progress_json, depends_on)
|
||||
VALUES (
|
||||
5804,
|
||||
'drop_device_lists_outbound_last_success_non_unique_idx',
|
||||
'{}',
|
||||
'device_lists_outbound_last_success_unique_idx'
|
||||
);
|
||||
+44
-30
@@ -49,6 +49,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.storage.background_updates import BackgroundUpdater
|
||||
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
|
||||
from synapse.storage.types import Connection, Cursor
|
||||
from synapse.types import Collection
|
||||
from synapse.util.stringutils import exception_to_unicode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -78,6 +79,7 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = {
|
||||
"device_lists_remote_extremeties": "device_lists_remote_extremeties_unique_idx",
|
||||
"device_lists_remote_cache": "device_lists_remote_cache_unique_idx",
|
||||
"event_search": "event_search_event_id_idx",
|
||||
"device_lists_outbound_last_success": "device_lists_outbound_last_success_unique_idx",
|
||||
}
|
||||
|
||||
|
||||
@@ -889,20 +891,24 @@ class Database(object):
|
||||
txn.execute(sql, list(allvalues.values()))
|
||||
|
||||
def simple_upsert_many_txn(
|
||||
self, txn, table, key_names, key_values, value_names, value_values
|
||||
):
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
table: str,
|
||||
key_names: Collection[str],
|
||||
key_values: Collection[Iterable[Any]],
|
||||
value_names: Collection[str],
|
||||
value_values: Iterable[Iterable[str]],
|
||||
) -> None:
|
||||
"""
|
||||
Upsert, many times.
|
||||
|
||||
Args:
|
||||
table (str): The table to upsert into
|
||||
key_names (list[str]): The key column names.
|
||||
key_values (list[list]): A list of each row's key column values.
|
||||
value_names (list[str]): The value column names. If empty, no
|
||||
values will be used, even if value_values is provided.
|
||||
value_values (list[list]): A list of each row's value column values.
|
||||
Returns:
|
||||
None
|
||||
table: The table to upsert into
|
||||
key_names: The key column names.
|
||||
key_values: A list of each row's key column values.
|
||||
value_names: The value column names
|
||||
value_values: A list of each row's value column values.
|
||||
Ignored if value_names is empty.
|
||||
"""
|
||||
if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
|
||||
return self.simple_upsert_many_txn_native_upsert(
|
||||
@@ -914,20 +920,24 @@ class Database(object):
|
||||
)
|
||||
|
||||
def simple_upsert_many_txn_emulated(
|
||||
self, txn, table, key_names, key_values, value_names, value_values
|
||||
):
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
table: str,
|
||||
key_names: Iterable[str],
|
||||
key_values: Collection[Iterable[Any]],
|
||||
value_names: Collection[str],
|
||||
value_values: Iterable[Iterable[str]],
|
||||
) -> None:
|
||||
"""
|
||||
Upsert, many times, but without native UPSERT support or batching.
|
||||
|
||||
Args:
|
||||
table (str): The table to upsert into
|
||||
key_names (list[str]): The key column names.
|
||||
key_values (list[list]): A list of each row's key column values.
|
||||
value_names (list[str]): The value column names. If empty, no
|
||||
values will be used, even if value_values is provided.
|
||||
value_values (list[list]): A list of each row's value column values.
|
||||
Returns:
|
||||
None
|
||||
table: The table to upsert into
|
||||
key_names: The key column names.
|
||||
key_values: A list of each row's key column values.
|
||||
value_names: The value column names
|
||||
value_values: A list of each row's value column values.
|
||||
Ignored if value_names is empty.
|
||||
"""
|
||||
# No value columns, therefore make a blank list so that the following
|
||||
# zip() works correctly.
|
||||
@@ -941,20 +951,24 @@ class Database(object):
|
||||
self.simple_upsert_txn_emulated(txn, table, _keys, _vals)
|
||||
|
||||
def simple_upsert_many_txn_native_upsert(
|
||||
self, txn, table, key_names, key_values, value_names, value_values
|
||||
):
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
table: str,
|
||||
key_names: Collection[str],
|
||||
key_values: Collection[Iterable[Any]],
|
||||
value_names: Collection[str],
|
||||
value_values: Iterable[Iterable[Any]],
|
||||
) -> None:
|
||||
"""
|
||||
Upsert, many times, using batching where possible.
|
||||
|
||||
Args:
|
||||
table (str): The table to upsert into
|
||||
key_names (list[str]): The key column names.
|
||||
key_values (list[list]): A list of each row's key column values.
|
||||
value_names (list[str]): The value column names. If empty, no
|
||||
values will be used, even if value_values is provided.
|
||||
value_values (list[list]): A list of each row's value column values.
|
||||
Returns:
|
||||
None
|
||||
table: The table to upsert into
|
||||
key_names: The key column names.
|
||||
key_values: A list of each row's key column values.
|
||||
value_names: The value column names
|
||||
value_values: A list of each row's value column values.
|
||||
Ignored if value_names is empty.
|
||||
"""
|
||||
allnames = [] # type: List[str]
|
||||
allnames.extend(key_names)
|
||||
|
||||
@@ -16,6 +16,11 @@
|
||||
import contextlib
|
||||
import threading
|
||||
from collections import deque
|
||||
from typing import Dict, Set, Tuple
|
||||
|
||||
from typing_extensions import Deque
|
||||
|
||||
from synapse.storage.database import Database, LoggingTransaction
|
||||
|
||||
|
||||
class IdGenerator(object):
|
||||
@@ -87,7 +92,7 @@ class StreamIdGenerator(object):
|
||||
self._current = (max if step > 0 else min)(
|
||||
self._current, _load_current_id(db_conn, table, column, step)
|
||||
)
|
||||
self._unfinished_ids = deque()
|
||||
self._unfinished_ids = deque() # type: Deque[int]
|
||||
|
||||
def get_next(self):
|
||||
"""
|
||||
@@ -163,7 +168,7 @@ class ChainedIdGenerator(object):
|
||||
self.chained_generator = chained_generator
|
||||
self._lock = threading.Lock()
|
||||
self._current_max = _load_current_id(db_conn, table, column)
|
||||
self._unfinished_ids = deque()
|
||||
self._unfinished_ids = deque() # type: Deque[Tuple[int, int]]
|
||||
|
||||
def get_next(self):
|
||||
"""
|
||||
@@ -198,3 +203,163 @@ class ChainedIdGenerator(object):
|
||||
return stream_id - 1, chained_id
|
||||
|
||||
return self._current_max, self.chained_generator.get_current_token()
|
||||
|
||||
|
||||
class MultiWriterIdGenerator:
|
||||
"""An ID generator that tracks a stream that can have multiple writers.
|
||||
|
||||
Uses a Postgres sequence to coordinate ID assignment, but positions of other
|
||||
writers will only get updated when `advance` is called (by replication).
|
||||
|
||||
Note: Only works with Postgres.
|
||||
|
||||
Args:
|
||||
db_conn
|
||||
db
|
||||
instance_name: The name of this instance.
|
||||
table: Database table associated with stream.
|
||||
instance_column: Column that stores the row's writer's instance name
|
||||
id_column: Column that stores the stream ID.
|
||||
sequence_name: The name of the postgres sequence used to generate new
|
||||
IDs.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db_conn,
|
||||
db: Database,
|
||||
instance_name: str,
|
||||
table: str,
|
||||
instance_column: str,
|
||||
id_column: str,
|
||||
sequence_name: str,
|
||||
):
|
||||
self._db = db
|
||||
self._instance_name = instance_name
|
||||
self._sequence_name = sequence_name
|
||||
|
||||
# We lock as some functions may be called from DB threads.
|
||||
self._lock = threading.Lock()
|
||||
|
||||
self._current_positions = self._load_current_ids(
|
||||
db_conn, table, instance_column, id_column
|
||||
)
|
||||
|
||||
# Set of local IDs that we're still processing. The current position
|
||||
# should be less than the minimum of this set (if not empty).
|
||||
self._unfinished_ids = set() # type: Set[int]
|
||||
|
||||
def _load_current_ids(
|
||||
self, db_conn, table: str, instance_column: str, id_column: str
|
||||
) -> Dict[str, int]:
|
||||
sql = """
|
||||
SELECT %(instance)s, MAX(%(id)s) FROM %(table)s
|
||||
GROUP BY %(instance)s
|
||||
""" % {
|
||||
"instance": instance_column,
|
||||
"id": id_column,
|
||||
"table": table,
|
||||
}
|
||||
|
||||
cur = db_conn.cursor()
|
||||
cur.execute(sql)
|
||||
|
||||
# `cur` is an iterable over returned rows, which are 2-tuples.
|
||||
current_positions = dict(cur)
|
||||
|
||||
cur.close()
|
||||
|
||||
return current_positions
|
||||
|
||||
def _load_next_id_txn(self, txn):
|
||||
txn.execute("SELECT nextval(?)", (self._sequence_name,))
|
||||
(next_id,) = txn.fetchone()
|
||||
return next_id
|
||||
|
||||
async def get_next(self):
|
||||
"""
|
||||
Usage:
|
||||
with await stream_id_gen.get_next() as stream_id:
|
||||
# ... persist event ...
|
||||
"""
|
||||
next_id = await self._db.runInteraction("_load_next_id", self._load_next_id_txn)
|
||||
|
||||
# Assert the fetched ID is actually greater than what we currently
|
||||
# believe the ID to be. If not, then the sequence and table have got
|
||||
# out of sync somehow.
|
||||
assert self.get_current_token() < next_id
|
||||
|
||||
with self._lock:
|
||||
self._unfinished_ids.add(next_id)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def manager():
|
||||
try:
|
||||
yield next_id
|
||||
finally:
|
||||
self._mark_id_as_finished(next_id)
|
||||
|
||||
return manager()
|
||||
|
||||
def get_next_txn(self, txn: LoggingTransaction):
|
||||
"""
|
||||
Usage:
|
||||
|
||||
stream_id = stream_id_gen.get_next(txn)
|
||||
# ... persist event ...
|
||||
"""
|
||||
|
||||
next_id = self._load_next_id_txn(txn)
|
||||
|
||||
with self._lock:
|
||||
self._unfinished_ids.add(next_id)
|
||||
|
||||
txn.call_after(self._mark_id_as_finished, next_id)
|
||||
txn.call_on_exception(self._mark_id_as_finished, next_id)
|
||||
|
||||
return next_id
|
||||
|
||||
def _mark_id_as_finished(self, next_id: int):
|
||||
"""The ID has finished being processed so we should advance the
|
||||
current poistion if possible.
|
||||
"""
|
||||
|
||||
with self._lock:
|
||||
self._unfinished_ids.discard(next_id)
|
||||
|
||||
# Figure out if its safe to advance the position by checking there
|
||||
# aren't any lower allocated IDs that are yet to finish.
|
||||
if all(c > next_id for c in self._unfinished_ids):
|
||||
curr = self._current_positions.get(self._instance_name, 0)
|
||||
self._current_positions[self._instance_name] = max(curr, next_id)
|
||||
|
||||
def get_current_token(self, instance_name: str = None) -> int:
|
||||
"""Gets the current position of a named writer (defaults to current
|
||||
instance).
|
||||
|
||||
Returns 0 if we don't have a position for the named writer (likely due
|
||||
to it being a new writer).
|
||||
"""
|
||||
|
||||
if instance_name is None:
|
||||
instance_name = self._instance_name
|
||||
|
||||
with self._lock:
|
||||
return self._current_positions.get(instance_name, 0)
|
||||
|
||||
def get_positions(self) -> Dict[str, int]:
|
||||
"""Get a copy of the current positon map.
|
||||
"""
|
||||
|
||||
with self._lock:
|
||||
return dict(self._current_positions)
|
||||
|
||||
def advance(self, instance_name: str, new_id: int):
|
||||
"""Advance the postion of the named writer to the given ID, if greater
|
||||
than existing entry.
|
||||
"""
|
||||
|
||||
with self._lock:
|
||||
self._current_positions[instance_name] = max(
|
||||
new_id, self._current_positions.get(instance_name, 0)
|
||||
)
|
||||
|
||||
@@ -14,12 +14,13 @@
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import Dict, Iterable, List, Mapping, Optional, Set
|
||||
from typing import Dict, FrozenSet, List, Mapping, Optional, Set, Union
|
||||
|
||||
from six import integer_types
|
||||
|
||||
from sortedcontainers import SortedDict
|
||||
|
||||
from synapse.types import Collection
|
||||
from synapse.util import caches
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -85,8 +86,8 @@ class StreamChangeCache:
|
||||
return False
|
||||
|
||||
def get_entities_changed(
|
||||
self, entities: Iterable[EntityType], stream_pos: int
|
||||
) -> Set[EntityType]:
|
||||
self, entities: Collection[EntityType], stream_pos: int
|
||||
) -> Union[Set[EntityType], FrozenSet[EntityType]]:
|
||||
"""
|
||||
Returns subset of entities that have had new things since the given
|
||||
position. Entities unknown to the cache will be returned. If the
|
||||
@@ -94,7 +95,17 @@ class StreamChangeCache:
|
||||
"""
|
||||
changed_entities = self.get_all_entities_changed(stream_pos)
|
||||
if changed_entities is not None:
|
||||
result = set(changed_entities).intersection(entities)
|
||||
# We now do an intersection, trying to do so in the most efficient
|
||||
# way possible (some of these sets are *large*). First check in the
|
||||
# given iterable is already set that we can reuse, otherwise we
|
||||
# create a set of the *smallest* of the two iterables and call
|
||||
# `intersection(..)` on it (this can be twice as fast as the reverse).
|
||||
if isinstance(entities, (set, frozenset)):
|
||||
result = entities.intersection(changed_entities)
|
||||
elif len(changed_entities) < len(entities):
|
||||
result = set(changed_entities).intersection(entities)
|
||||
else:
|
||||
result = set(entities).intersection(changed_entities)
|
||||
self.metrics.inc_hits()
|
||||
else:
|
||||
result = set(entities)
|
||||
|
||||
+20
-16
@@ -52,6 +52,10 @@ class AuthTestCase(unittest.TestCase):
|
||||
self.hs.handlers = TestHandlers(self.hs)
|
||||
self.auth = Auth(self.hs)
|
||||
|
||||
# AuthBlocking reads from the hs' config on initialization. We need to
|
||||
# modify its config instead of the hs'
|
||||
self.auth_blocking = self.auth._auth_blocking
|
||||
|
||||
self.test_user = "@foo:bar"
|
||||
self.test_token = b"_test_token_"
|
||||
|
||||
@@ -321,15 +325,15 @@ class AuthTestCase(unittest.TestCase):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_blocking_mau(self):
|
||||
self.hs.config.limit_usage_by_mau = False
|
||||
self.hs.config.max_mau_value = 50
|
||||
self.auth_blocking._limit_usage_by_mau = False
|
||||
self.auth_blocking._max_mau_value = 50
|
||||
lots_of_users = 100
|
||||
small_number_of_users = 1
|
||||
|
||||
# Ensure no error thrown
|
||||
yield defer.ensureDeferred(self.auth.check_auth_blocking())
|
||||
|
||||
self.hs.config.limit_usage_by_mau = True
|
||||
self.auth_blocking._limit_usage_by_mau = True
|
||||
|
||||
self.store.get_monthly_active_count = Mock(
|
||||
return_value=defer.succeed(lots_of_users)
|
||||
@@ -349,8 +353,8 @@ class AuthTestCase(unittest.TestCase):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_blocking_mau__depending_on_user_type(self):
|
||||
self.hs.config.max_mau_value = 50
|
||||
self.hs.config.limit_usage_by_mau = True
|
||||
self.auth_blocking._max_mau_value = 50
|
||||
self.auth_blocking._limit_usage_by_mau = True
|
||||
|
||||
self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100))
|
||||
# Support users allowed
|
||||
@@ -370,12 +374,12 @@ class AuthTestCase(unittest.TestCase):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_reserved_threepid(self):
|
||||
self.hs.config.limit_usage_by_mau = True
|
||||
self.hs.config.max_mau_value = 1
|
||||
self.auth_blocking._limit_usage_by_mau = True
|
||||
self.auth_blocking._max_mau_value = 1
|
||||
self.store.get_monthly_active_count = lambda: defer.succeed(2)
|
||||
threepid = {"medium": "email", "address": "reserved@server.com"}
|
||||
unknown_threepid = {"medium": "email", "address": "unreserved@server.com"}
|
||||
self.hs.config.mau_limits_reserved_threepids = [threepid]
|
||||
self.auth_blocking._mau_limits_reserved_threepids = [threepid]
|
||||
|
||||
with self.assertRaises(ResourceLimitError):
|
||||
yield defer.ensureDeferred(self.auth.check_auth_blocking())
|
||||
@@ -389,8 +393,8 @@ class AuthTestCase(unittest.TestCase):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_hs_disabled(self):
|
||||
self.hs.config.hs_disabled = True
|
||||
self.hs.config.hs_disabled_message = "Reason for being disabled"
|
||||
self.auth_blocking._hs_disabled = True
|
||||
self.auth_blocking._hs_disabled_message = "Reason for being disabled"
|
||||
with self.assertRaises(ResourceLimitError) as e:
|
||||
yield defer.ensureDeferred(self.auth.check_auth_blocking())
|
||||
self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact)
|
||||
@@ -404,10 +408,10 @@ class AuthTestCase(unittest.TestCase):
|
||||
"""
|
||||
# this should be the default, but we had a bug where the test was doing the wrong
|
||||
# thing, so let's make it explicit
|
||||
self.hs.config.server_notices_mxid = None
|
||||
self.auth_blocking._server_notices_mxid = None
|
||||
|
||||
self.hs.config.hs_disabled = True
|
||||
self.hs.config.hs_disabled_message = "Reason for being disabled"
|
||||
self.auth_blocking._hs_disabled = True
|
||||
self.auth_blocking._hs_disabled_message = "Reason for being disabled"
|
||||
with self.assertRaises(ResourceLimitError) as e:
|
||||
yield defer.ensureDeferred(self.auth.check_auth_blocking())
|
||||
self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact)
|
||||
@@ -416,8 +420,8 @@ class AuthTestCase(unittest.TestCase):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_server_notices_mxid_special_cased(self):
|
||||
self.hs.config.hs_disabled = True
|
||||
self.auth_blocking._hs_disabled = True
|
||||
user = "@user:server"
|
||||
self.hs.config.server_notices_mxid = user
|
||||
self.hs.config.hs_disabled_message = "Reason for being disabled"
|
||||
self.auth_blocking._server_notices_mxid = user
|
||||
self.auth_blocking._hs_disabled_message = "Reason for being disabled"
|
||||
yield defer.ensureDeferred(self.auth.check_auth_blocking(user))
|
||||
|
||||
@@ -0,0 +1,100 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.events.snapshot import EventContext
|
||||
from synapse.rest import admin
|
||||
from synapse.rest.client.v1 import login, room
|
||||
|
||||
from tests import unittest
|
||||
from tests.test_utils.event_injection import create_event
|
||||
|
||||
|
||||
class TestEventContext(unittest.HomeserverTestCase):
|
||||
servlets = [
|
||||
admin.register_servlets,
|
||||
login.register_servlets,
|
||||
room.register_servlets,
|
||||
]
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
self.store = hs.get_datastore()
|
||||
self.storage = hs.get_storage()
|
||||
|
||||
self.user_id = self.register_user("u1", "pass")
|
||||
self.user_tok = self.login("u1", "pass")
|
||||
self.room_id = self.helper.create_room_as(tok=self.user_tok)
|
||||
|
||||
def test_serialize_deserialize_msg(self):
|
||||
"""Test that an EventContext for a message event is the same after
|
||||
serialize/deserialize.
|
||||
"""
|
||||
|
||||
event, context = create_event(
|
||||
self.hs, room_id=self.room_id, type="m.test", sender=self.user_id,
|
||||
)
|
||||
|
||||
self._check_serialize_deserialize(event, context)
|
||||
|
||||
def test_serialize_deserialize_state_no_prev(self):
|
||||
"""Test that an EventContext for a state event (with not previous entry)
|
||||
is the same after serialize/deserialize.
|
||||
"""
|
||||
event, context = create_event(
|
||||
self.hs,
|
||||
room_id=self.room_id,
|
||||
type="m.test",
|
||||
sender=self.user_id,
|
||||
state_key="",
|
||||
)
|
||||
|
||||
self._check_serialize_deserialize(event, context)
|
||||
|
||||
def test_serialize_deserialize_state_prev(self):
|
||||
"""Test that an EventContext for a state event (which replaces a
|
||||
previous entry) is the same after serialize/deserialize.
|
||||
"""
|
||||
event, context = create_event(
|
||||
self.hs,
|
||||
room_id=self.room_id,
|
||||
type="m.room.member",
|
||||
sender=self.user_id,
|
||||
state_key=self.user_id,
|
||||
content={"membership": "leave"},
|
||||
)
|
||||
|
||||
self._check_serialize_deserialize(event, context)
|
||||
|
||||
def _check_serialize_deserialize(self, event, context):
|
||||
serialized = self.get_success(context.serialize(event, self.store))
|
||||
|
||||
d_context = EventContext.deserialize(self.storage, serialized)
|
||||
|
||||
self.assertEqual(context.state_group, d_context.state_group)
|
||||
self.assertEqual(context.rejected, d_context.rejected)
|
||||
self.assertEqual(
|
||||
context.state_group_before_event, d_context.state_group_before_event
|
||||
)
|
||||
self.assertEqual(context.prev_group, d_context.prev_group)
|
||||
self.assertEqual(context.delta_ids, d_context.delta_ids)
|
||||
self.assertEqual(context.app_service, d_context.app_service)
|
||||
|
||||
self.assertEqual(
|
||||
self.get_success(context.get_current_state_ids()),
|
||||
self.get_success(d_context.get_current_state_ids()),
|
||||
)
|
||||
self.assertEqual(
|
||||
self.get_success(context.get_prev_state_ids()),
|
||||
self.get_success(d_context.get_prev_state_ids()),
|
||||
)
|
||||
@@ -39,8 +39,13 @@ class AuthTestCase(unittest.TestCase):
|
||||
self.hs.handlers = AuthHandlers(self.hs)
|
||||
self.auth_handler = self.hs.handlers.auth_handler
|
||||
self.macaroon_generator = self.hs.get_macaroon_generator()
|
||||
|
||||
# MAU tests
|
||||
self.hs.config.max_mau_value = 50
|
||||
# AuthBlocking reads from the hs' config on initialization. We need to
|
||||
# modify its config instead of the hs'
|
||||
self.auth_blocking = self.hs.get_auth()._auth_blocking
|
||||
self.auth_blocking._max_mau_value = 50
|
||||
|
||||
self.small_number_of_users = 1
|
||||
self.large_number_of_users = 100
|
||||
|
||||
@@ -119,7 +124,7 @@ class AuthTestCase(unittest.TestCase):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_mau_limits_disabled(self):
|
||||
self.hs.config.limit_usage_by_mau = False
|
||||
self.auth_blocking._limit_usage_by_mau = False
|
||||
# Ensure does not throw exception
|
||||
yield defer.ensureDeferred(
|
||||
self.auth_handler.get_access_token_for_user_id(
|
||||
@@ -135,7 +140,7 @@ class AuthTestCase(unittest.TestCase):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_mau_limits_exceeded_large(self):
|
||||
self.hs.config.limit_usage_by_mau = True
|
||||
self.auth_blocking._limit_usage_by_mau = True
|
||||
self.hs.get_datastore().get_monthly_active_count = Mock(
|
||||
return_value=defer.succeed(self.large_number_of_users)
|
||||
)
|
||||
@@ -159,11 +164,11 @@ class AuthTestCase(unittest.TestCase):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_mau_limits_parity(self):
|
||||
self.hs.config.limit_usage_by_mau = True
|
||||
self.auth_blocking._limit_usage_by_mau = True
|
||||
|
||||
# If not in monthly active cohort
|
||||
self.hs.get_datastore().get_monthly_active_count = Mock(
|
||||
return_value=defer.succeed(self.hs.config.max_mau_value)
|
||||
return_value=defer.succeed(self.auth_blocking._max_mau_value)
|
||||
)
|
||||
with self.assertRaises(ResourceLimitError):
|
||||
yield defer.ensureDeferred(
|
||||
@@ -173,7 +178,7 @@ class AuthTestCase(unittest.TestCase):
|
||||
)
|
||||
|
||||
self.hs.get_datastore().get_monthly_active_count = Mock(
|
||||
return_value=defer.succeed(self.hs.config.max_mau_value)
|
||||
return_value=defer.succeed(self.auth_blocking._max_mau_value)
|
||||
)
|
||||
with self.assertRaises(ResourceLimitError):
|
||||
yield defer.ensureDeferred(
|
||||
@@ -186,7 +191,7 @@ class AuthTestCase(unittest.TestCase):
|
||||
return_value=defer.succeed(self.hs.get_clock().time_msec())
|
||||
)
|
||||
self.hs.get_datastore().get_monthly_active_count = Mock(
|
||||
return_value=defer.succeed(self.hs.config.max_mau_value)
|
||||
return_value=defer.succeed(self.auth_blocking._max_mau_value)
|
||||
)
|
||||
yield defer.ensureDeferred(
|
||||
self.auth_handler.get_access_token_for_user_id(
|
||||
@@ -197,7 +202,7 @@ class AuthTestCase(unittest.TestCase):
|
||||
return_value=defer.succeed(self.hs.get_clock().time_msec())
|
||||
)
|
||||
self.hs.get_datastore().get_monthly_active_count = Mock(
|
||||
return_value=defer.succeed(self.hs.config.max_mau_value)
|
||||
return_value=defer.succeed(self.auth_blocking._max_mau_value)
|
||||
)
|
||||
yield defer.ensureDeferred(
|
||||
self.auth_handler.validate_short_term_login_token_and_get_user_id(
|
||||
@@ -207,7 +212,7 @@ class AuthTestCase(unittest.TestCase):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_mau_limits_not_exceeded(self):
|
||||
self.hs.config.limit_usage_by_mau = True
|
||||
self.auth_blocking._limit_usage_by_mau = True
|
||||
|
||||
self.hs.get_datastore().get_monthly_active_count = Mock(
|
||||
return_value=defer.succeed(self.small_number_of_users)
|
||||
|
||||
@@ -30,28 +30,31 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
|
||||
self.sync_handler = self.hs.get_sync_handler()
|
||||
self.store = self.hs.get_datastore()
|
||||
|
||||
def test_wait_for_sync_for_user_auth_blocking(self):
|
||||
# AuthBlocking reads from the hs' config on initialization. We need to
|
||||
# modify its config instead of the hs'
|
||||
self.auth_blocking = self.hs.get_auth()._auth_blocking
|
||||
|
||||
def test_wait_for_sync_for_user_auth_blocking(self):
|
||||
user_id1 = "@user1:test"
|
||||
user_id2 = "@user2:test"
|
||||
sync_config = self._generate_sync_config(user_id1)
|
||||
|
||||
self.reactor.advance(100) # So we get not 0 time
|
||||
self.hs.config.limit_usage_by_mau = True
|
||||
self.hs.config.max_mau_value = 1
|
||||
self.auth_blocking._limit_usage_by_mau = True
|
||||
self.auth_blocking._max_mau_value = 1
|
||||
|
||||
# Check that the happy case does not throw errors
|
||||
self.get_success(self.store.upsert_monthly_active_user(user_id1))
|
||||
self.get_success(self.sync_handler.wait_for_sync_for_user(sync_config))
|
||||
|
||||
# Test that global lock works
|
||||
self.hs.config.hs_disabled = True
|
||||
self.auth_blocking._hs_disabled = True
|
||||
e = self.get_failure(
|
||||
self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError
|
||||
)
|
||||
self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
|
||||
|
||||
self.hs.config.hs_disabled = False
|
||||
self.auth_blocking._hs_disabled = False
|
||||
|
||||
sync_config = self._generate_sync_config(user_id2)
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@ from synapse.app.generic_worker import (
|
||||
GenericWorkerServer,
|
||||
)
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.replication.http import streams
|
||||
from synapse.replication.tcp.handler import ReplicationCommandHandler
|
||||
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
|
||||
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
|
||||
@@ -42,6 +43,10 @@ logger = logging.getLogger(__name__)
|
||||
class BaseStreamTestCase(unittest.HomeserverTestCase):
|
||||
"""Base class for tests of the replication streams"""
|
||||
|
||||
servlets = [
|
||||
streams.register_servlets,
|
||||
]
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
# build a replication server
|
||||
server_factory = ReplicationStreamProtocolFactory(hs)
|
||||
@@ -49,17 +54,11 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
||||
self.server = server_factory.buildProtocol(None)
|
||||
|
||||
# Make a new HomeServer object for the worker
|
||||
config = self.default_config()
|
||||
config["worker_app"] = "synapse.app.generic_worker"
|
||||
config["worker_replication_host"] = "testserv"
|
||||
config["worker_replication_http_port"] = "8765"
|
||||
|
||||
self.reactor.lookups["testserv"] = "1.2.3.4"
|
||||
|
||||
self.worker_hs = self.setup_test_homeserver(
|
||||
http_client=None,
|
||||
homeserverToUse=GenericWorkerServer,
|
||||
config=config,
|
||||
config=self._get_worker_hs_config(),
|
||||
reactor=self.reactor,
|
||||
)
|
||||
|
||||
@@ -78,6 +77,13 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
||||
self._client_transport = None
|
||||
self._server_transport = None
|
||||
|
||||
def _get_worker_hs_config(self) -> dict:
|
||||
config = self.default_config()
|
||||
config["worker_app"] = "synapse.app.generic_worker"
|
||||
config["worker_replication_host"] = "testserv"
|
||||
config["worker_replication_http_port"] = "8765"
|
||||
return config
|
||||
|
||||
def _build_replication_data_handler(self):
|
||||
return TestReplicationDataHandler(self.worker_hs)
|
||||
|
||||
|
||||
@@ -0,0 +1,81 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2019 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.federation.send_queue import EduRow
|
||||
from synapse.replication.tcp.streams.federation import FederationStream
|
||||
|
||||
from tests.replication.tcp.streams._base import BaseStreamTestCase
|
||||
|
||||
|
||||
class FederationStreamTestCase(BaseStreamTestCase):
|
||||
def _get_worker_hs_config(self) -> dict:
|
||||
# enable federation sending on the worker
|
||||
config = super()._get_worker_hs_config()
|
||||
# TODO: make it so we don't need both of these
|
||||
config["send_federation"] = True
|
||||
config["worker_app"] = "synapse.app.federation_sender"
|
||||
return config
|
||||
|
||||
def test_catchup(self):
|
||||
"""Basic test of catchup on reconnect
|
||||
|
||||
Makes sure that updates sent while we are offline are received later.
|
||||
"""
|
||||
fed_sender = self.hs.get_federation_sender()
|
||||
received_rows = self.test_handler.received_rdata_rows
|
||||
|
||||
fed_sender.build_and_send_edu("testdest", "m.test_edu", {"a": "b"})
|
||||
|
||||
self.reconnect()
|
||||
self.reactor.advance(0)
|
||||
|
||||
# check we're testing what we think we are: no rows should yet have been
|
||||
# received
|
||||
self.assertEqual(received_rows, [])
|
||||
|
||||
# We should now see an attempt to connect to the master
|
||||
request = self.handle_http_replication_attempt()
|
||||
self.assert_request_is_get_repl_stream_updates(request, "federation")
|
||||
|
||||
# we should have received an update row
|
||||
stream_name, token, row = received_rows.pop()
|
||||
self.assertEqual(stream_name, "federation")
|
||||
self.assertIsInstance(row, FederationStream.FederationStreamRow)
|
||||
self.assertEqual(row.type, EduRow.TypeId)
|
||||
edurow = EduRow.from_data(row.data)
|
||||
self.assertEqual(edurow.edu.edu_type, "m.test_edu")
|
||||
self.assertEqual(edurow.edu.origin, self.hs.hostname)
|
||||
self.assertEqual(edurow.edu.destination, "testdest")
|
||||
self.assertEqual(edurow.edu.content, {"a": "b"})
|
||||
|
||||
self.assertEqual(received_rows, [])
|
||||
|
||||
# additional updates should be transferred without an HTTP hit
|
||||
fed_sender.build_and_send_edu("testdest", "m.test1", {"c": "d"})
|
||||
self.reactor.advance(0)
|
||||
# there should be no http hit
|
||||
self.assertEqual(len(self.reactor.tcpClients), 0)
|
||||
# ... but we should have a row
|
||||
self.assertEqual(len(received_rows), 1)
|
||||
|
||||
stream_name, token, row = received_rows.pop()
|
||||
self.assertEqual(stream_name, "federation")
|
||||
self.assertIsInstance(row, FederationStream.FederationStreamRow)
|
||||
self.assertEqual(row.type, EduRow.TypeId)
|
||||
edurow = EduRow.from_data(row.data)
|
||||
self.assertEqual(edurow.edu.edu_type, "m.test1")
|
||||
self.assertEqual(edurow.edu.origin, self.hs.hostname)
|
||||
self.assertEqual(edurow.edu.destination, "testdest")
|
||||
self.assertEqual(edurow.edu.content, {"c": "d"})
|
||||
@@ -15,7 +15,6 @@
|
||||
from mock import Mock
|
||||
|
||||
from synapse.handlers.typing import RoomMember
|
||||
from synapse.replication.http import streams
|
||||
from synapse.replication.tcp.streams import TypingStream
|
||||
|
||||
from tests.replication.tcp.streams._base import BaseStreamTestCase
|
||||
@@ -24,10 +23,6 @@ USER_ID = "@feeling:blue"
|
||||
|
||||
|
||||
class TypingStreamTestCase(BaseStreamTestCase):
|
||||
servlets = [
|
||||
streams.register_servlets,
|
||||
]
|
||||
|
||||
def _build_replication_data_handler(self):
|
||||
return Mock(wraps=super()._build_replication_data_handler())
|
||||
|
||||
|
||||
@@ -0,0 +1,184 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from synapse.storage.database import Database
|
||||
from synapse.storage.util.id_generators import MultiWriterIdGenerator
|
||||
|
||||
from tests.unittest import HomeserverTestCase
|
||||
from tests.utils import USE_POSTGRES_FOR_TESTS
|
||||
|
||||
|
||||
class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||
if not USE_POSTGRES_FOR_TESTS:
|
||||
skip = "Requires Postgres"
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
self.store = hs.get_datastore()
|
||||
self.db = self.store.db # type: Database
|
||||
|
||||
self.get_success(self.db.runInteraction("_setup_db", self._setup_db))
|
||||
|
||||
def _setup_db(self, txn):
|
||||
txn.execute("CREATE SEQUENCE foobar_seq")
|
||||
txn.execute(
|
||||
"""
|
||||
CREATE TABLE foobar (
|
||||
stream_id BIGINT NOT NULL,
|
||||
instance_name TEXT NOT NULL,
|
||||
data TEXT
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
def _create_id_generator(self, instance_name="master") -> MultiWriterIdGenerator:
|
||||
def _create(conn):
|
||||
return MultiWriterIdGenerator(
|
||||
conn,
|
||||
self.db,
|
||||
instance_name=instance_name,
|
||||
table="foobar",
|
||||
instance_column="instance_name",
|
||||
id_column="stream_id",
|
||||
sequence_name="foobar_seq",
|
||||
)
|
||||
|
||||
return self.get_success(self.db.runWithConnection(_create))
|
||||
|
||||
def _insert_rows(self, instance_name: str, number: int):
|
||||
def _insert(txn):
|
||||
for _ in range(number):
|
||||
txn.execute(
|
||||
"INSERT INTO foobar VALUES (nextval('foobar_seq'), ?)",
|
||||
(instance_name,),
|
||||
)
|
||||
|
||||
self.get_success(self.db.runInteraction("test_single_instance", _insert))
|
||||
|
||||
def test_empty(self):
|
||||
"""Test an ID generator against an empty database gives sensible
|
||||
current positions.
|
||||
"""
|
||||
|
||||
id_gen = self._create_id_generator()
|
||||
|
||||
# The table is empty so we expect an empty map for positions
|
||||
self.assertEqual(id_gen.get_positions(), {})
|
||||
|
||||
def test_single_instance(self):
|
||||
"""Test that reads and writes from a single process are handled
|
||||
correctly.
|
||||
"""
|
||||
|
||||
# Prefill table with 7 rows written by 'master'
|
||||
self._insert_rows("master", 7)
|
||||
|
||||
id_gen = self._create_id_generator()
|
||||
|
||||
self.assertEqual(id_gen.get_positions(), {"master": 7})
|
||||
self.assertEqual(id_gen.get_current_token("master"), 7)
|
||||
|
||||
# Try allocating a new ID gen and check that we only see position
|
||||
# advanced after we leave the context manager.
|
||||
|
||||
async def _get_next_async():
|
||||
with await id_gen.get_next() as stream_id:
|
||||
self.assertEqual(stream_id, 8)
|
||||
|
||||
self.assertEqual(id_gen.get_positions(), {"master": 7})
|
||||
self.assertEqual(id_gen.get_current_token("master"), 7)
|
||||
|
||||
self.get_success(_get_next_async())
|
||||
|
||||
self.assertEqual(id_gen.get_positions(), {"master": 8})
|
||||
self.assertEqual(id_gen.get_current_token("master"), 8)
|
||||
|
||||
def test_multi_instance(self):
|
||||
"""Test that reads and writes from multiple processes are handled
|
||||
correctly.
|
||||
"""
|
||||
self._insert_rows("first", 3)
|
||||
self._insert_rows("second", 4)
|
||||
|
||||
first_id_gen = self._create_id_generator("first")
|
||||
second_id_gen = self._create_id_generator("second")
|
||||
|
||||
self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
|
||||
self.assertEqual(first_id_gen.get_current_token("first"), 3)
|
||||
self.assertEqual(first_id_gen.get_current_token("second"), 7)
|
||||
|
||||
# Try allocating a new ID gen and check that we only see position
|
||||
# advanced after we leave the context manager.
|
||||
|
||||
async def _get_next_async():
|
||||
with await first_id_gen.get_next() as stream_id:
|
||||
self.assertEqual(stream_id, 8)
|
||||
|
||||
self.assertEqual(
|
||||
first_id_gen.get_positions(), {"first": 3, "second": 7}
|
||||
)
|
||||
|
||||
self.get_success(_get_next_async())
|
||||
|
||||
self.assertEqual(first_id_gen.get_positions(), {"first": 8, "second": 7})
|
||||
|
||||
# However the ID gen on the second instance won't have seen the update
|
||||
self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7})
|
||||
|
||||
# ... but calling `get_next` on the second instance should give a unique
|
||||
# stream ID
|
||||
|
||||
async def _get_next_async():
|
||||
with await second_id_gen.get_next() as stream_id:
|
||||
self.assertEqual(stream_id, 9)
|
||||
|
||||
self.assertEqual(
|
||||
second_id_gen.get_positions(), {"first": 3, "second": 7}
|
||||
)
|
||||
|
||||
self.get_success(_get_next_async())
|
||||
|
||||
self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 9})
|
||||
|
||||
# If the second ID gen gets told about the first, it correctly updates
|
||||
second_id_gen.advance("first", 8)
|
||||
self.assertEqual(second_id_gen.get_positions(), {"first": 8, "second": 9})
|
||||
|
||||
def test_get_next_txn(self):
|
||||
"""Test that the `get_next_txn` function works correctly.
|
||||
"""
|
||||
|
||||
# Prefill table with 7 rows written by 'master'
|
||||
self._insert_rows("master", 7)
|
||||
|
||||
id_gen = self._create_id_generator()
|
||||
|
||||
self.assertEqual(id_gen.get_positions(), {"master": 7})
|
||||
self.assertEqual(id_gen.get_current_token("master"), 7)
|
||||
|
||||
# Try allocating a new ID gen and check that we only see position
|
||||
# advanced after we leave the context manager.
|
||||
|
||||
def _get_next_txn(txn):
|
||||
stream_id = id_gen.get_next_txn(txn)
|
||||
self.assertEqual(stream_id, 8)
|
||||
|
||||
self.assertEqual(id_gen.get_positions(), {"master": 7})
|
||||
self.assertEqual(id_gen.get_current_token("master"), 7)
|
||||
|
||||
self.get_success(self.db.runInteraction("test", _get_next_txn))
|
||||
|
||||
self.assertEqual(id_gen.get_positions(), {"master": 8})
|
||||
self.assertEqual(id_gen.get_current_token("master"), 8)
|
||||
+11
-3
@@ -19,6 +19,7 @@ import json
|
||||
|
||||
from mock import Mock
|
||||
|
||||
from synapse.api.auth_blocking import AuthBlocking
|
||||
from synapse.api.constants import LoginType
|
||||
from synapse.api.errors import Codes, HttpResponseException, SynapseError
|
||||
from synapse.rest.client.v2_alpha import register, sync
|
||||
@@ -45,11 +46,17 @@ class TestMauLimit(unittest.HomeserverTestCase):
|
||||
self.hs.config.limit_usage_by_mau = True
|
||||
self.hs.config.hs_disabled = False
|
||||
self.hs.config.max_mau_value = 2
|
||||
self.hs.config.mau_trial_days = 0
|
||||
self.hs.config.server_notices_mxid = "@server:red"
|
||||
self.hs.config.server_notices_mxid_display_name = None
|
||||
self.hs.config.server_notices_mxid_avatar_url = None
|
||||
self.hs.config.server_notices_room_name = "Test Server Notice Room"
|
||||
self.hs.config.mau_trial_days = 0
|
||||
|
||||
# AuthBlocking reads config options during hs creation. Recreate the
|
||||
# hs' copy of AuthBlocking after we've updated config values above
|
||||
self.auth_blocking = AuthBlocking(self.hs)
|
||||
self.hs.get_auth()._auth_blocking = self.auth_blocking
|
||||
|
||||
return self.hs
|
||||
|
||||
def test_simple_deny_mau(self):
|
||||
@@ -121,6 +128,7 @@ class TestMauLimit(unittest.HomeserverTestCase):
|
||||
self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
|
||||
|
||||
def test_trial_users_cant_come_back(self):
|
||||
self.auth_blocking._mau_trial_days = 1
|
||||
self.hs.config.mau_trial_days = 1
|
||||
|
||||
# We should be able to register more than the limit initially
|
||||
@@ -169,8 +177,8 @@ class TestMauLimit(unittest.HomeserverTestCase):
|
||||
self.assertEqual(e.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
|
||||
|
||||
def test_tracked_but_not_limited(self):
|
||||
self.hs.config.max_mau_value = 1 # should not matter
|
||||
self.hs.config.limit_usage_by_mau = False
|
||||
self.auth_blocking._max_mau_value = 1 # should not matter
|
||||
self.auth_blocking._limit_usage_by_mau = False
|
||||
self.hs.config.mau_stats_only = True
|
||||
|
||||
# Simply being able to create 2 users indicates that the
|
||||
|
||||
@@ -14,12 +14,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import synapse.server
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
||||
from synapse.events import EventBase
|
||||
from synapse.events.snapshot import EventContext
|
||||
from synapse.types import Collection
|
||||
|
||||
from tests.test_utils import get_awaitable_result
|
||||
@@ -75,6 +76,23 @@ def inject_event(
|
||||
"""
|
||||
test_reactor = hs.get_reactor()
|
||||
|
||||
event, context = create_event(hs, room_version, prev_event_ids, **kwargs)
|
||||
|
||||
d = hs.get_storage().persistence.persist_event(event, context)
|
||||
test_reactor.advance(0)
|
||||
get_awaitable_result(d)
|
||||
|
||||
return event
|
||||
|
||||
|
||||
def create_event(
|
||||
hs: synapse.server.HomeServer,
|
||||
room_version: Optional[str] = None,
|
||||
prev_event_ids: Optional[Collection[str]] = None,
|
||||
**kwargs
|
||||
) -> Tuple[EventBase, EventContext]:
|
||||
test_reactor = hs.get_reactor()
|
||||
|
||||
if room_version is None:
|
||||
d = hs.get_datastore().get_room_version_id(kwargs["room_id"])
|
||||
test_reactor.advance(0)
|
||||
@@ -89,8 +107,4 @@ def inject_event(
|
||||
test_reactor.advance(0)
|
||||
event, context = get_awaitable_result(d)
|
||||
|
||||
d = hs.get_storage().persistence.persist_event(event, context)
|
||||
test_reactor.advance(0)
|
||||
get_awaitable_result(d)
|
||||
|
||||
return event
|
||||
return event, context
|
||||
|
||||
@@ -181,11 +181,7 @@ commands = mypy \
|
||||
synapse/appservice \
|
||||
synapse/config \
|
||||
synapse/events/spamcheck.py \
|
||||
synapse/federation/federation_base.py \
|
||||
synapse/federation/federation_client.py \
|
||||
synapse/federation/federation_server.py \
|
||||
synapse/federation/sender \
|
||||
synapse/federation/transport \
|
||||
synapse/federation \
|
||||
synapse/handlers/auth.py \
|
||||
synapse/handlers/cas_handler.py \
|
||||
synapse/handlers/directory.py \
|
||||
@@ -203,6 +199,7 @@ commands = mypy \
|
||||
synapse/storage/data_stores/main/ui_auth.py \
|
||||
synapse/storage/database.py \
|
||||
synapse/storage/engines \
|
||||
synapse/storage/util \
|
||||
synapse/streams \
|
||||
synapse/util/caches/stream_change_cache.py \
|
||||
tests/replication/tcp/streams \
|
||||
|
||||
Reference in New Issue
Block a user