1
0

Compare commits

..

2 Commits

Author SHA1 Message Date
Travis Ralston
177f2b838c changelog 2019-05-22 19:16:02 -06:00
Travis Ralston
f9d7d3aa89 Remove m.relates_to from events if the client set it to null
It appears as though Python only checks to see if the key exists in a dictionary, not necessarily for a useful value. This means that when clients submit (valid) requests with `m.relates_to: null` and Synapse later reads it, it gets a None reference error on access.

This is the easier route than guarding all the places where it could be None.
2019-05-22 19:14:10 -06:00
63 changed files with 544 additions and 1197 deletions

View File

@@ -1 +0,0 @@
Synapse will now serve the experimental "room complexity" API endpoint.

View File

@@ -1 +0,0 @@
Add experimental support for relations (aka reactions and edits).

View File

@@ -1 +0,0 @@
Ability to configure default room version.

View File

@@ -1 +0,0 @@
Simplifications and comments in do_auth.

View File

@@ -1 +0,0 @@
Fix appservice timestamp massaging.

View File

@@ -1 +0,0 @@
Rewrite store_server_verify_key to store several keys at once.

View File

@@ -1 +0,0 @@
Remove unused VerifyKey.expired and .time_added fields.

View File

@@ -1 +0,0 @@
Simplify Keyring.process_v2_response.

View File

@@ -1 +0,0 @@
Store key validity time in the storage layer.

1
changelog.d/5239.bugfix Normal file
View File

@@ -0,0 +1 @@
Fix 500 Internal Server Error when sending an event with `m.relates_to: null`.

View File

@@ -1 +0,0 @@
Refactor synapse.crypto.keyring to use a KeyFetcher interface.

View File

@@ -1 +0,0 @@
Ability to configure default room version.

View File

@@ -1 +0,0 @@
Simplification to Keyring.wait_for_previous_lookups.

View File

@@ -1 +0,0 @@
Ensure that server_keys fetched via a notary server are correctly signed.

View File

@@ -1 +0,0 @@
Fix error code when there is an invalid parameter on /_matrix/client/r0/publicRooms

View File

@@ -1 +0,0 @@
Fix error when downloading thumbnail with missing width/height parameter.

View File

@@ -1 +0,0 @@
Synapse now more efficiently collates room statistics.

View File

@@ -1 +0,0 @@
Fix schema update for account validity.

View File

@@ -1 +0,0 @@
Fix bug where we leaked extremities when we soft failed events, leading to performance degradation.

View File

@@ -1 +0,0 @@
Fix "db txn 'update_presence' from sentinel context" log messages.

View File

@@ -1 +0,0 @@
Fix dropped logcontexts during high outbound traffic.

View File

@@ -1 +0,0 @@
Fix docs on resetting the user directory.

View File

@@ -1 +0,0 @@
Specify the type of reCAPTCHA key to use.

View File

@@ -1 +0,0 @@
CAS login will now hit the r0 API, not the deprecated v1 one.

View File

@@ -1 +0,0 @@
Remove spurious debug from MatrixFederationHttpClient.get_json.

View File

@@ -1 +0,0 @@
Improve logging for logcontext leaks.

View File

@@ -7,7 +7,6 @@ Requires a public/private key pair from:
https://developers.google.com/recaptcha/
Must be a reCAPTCHA v2 key using the "I'm not a robot" Checkbox option
Setting ReCaptcha Keys
----------------------

View File

@@ -83,16 +83,6 @@ pid_file: DATADIR/homeserver.pid
#
#restrict_public_rooms_to_local_users: true
# The default room version for newly created rooms.
#
# Known room versions are listed here:
# https://matrix.org/docs/spec/#complete-list-of-room-versions
#
# For example, for room version 1, default_room_version should be set
# to "1".
#
#default_room_version: "1"
# The GC threshold parameters to pass to `gc.set_threshold`, if defined
#
#gc_thresholds: [700, 10, 10]
@@ -1103,9 +1093,9 @@ password_config:
#
# 'search_all_users' defines whether to search all users visible to your HS
# when searching the user directory, rather than limiting to users visible
# in public rooms. Defaults to false. If you set it True, you'll have to
# rebuild the user_directory search indexes, see
# https://github.com/matrix-org/synapse/blob/master/docs/user_directory.md
# in public rooms. Defaults to false. If you set it True, you'll have to run
# UPDATE user_directory_stream_pos SET stream_id = NULL;
# on your database to tell it to rebuild the user_directory search indexes.
#
#user_directory:
# enabled: true

View File

@@ -7,7 +7,11 @@ who are present in a publicly viewable room present on the server.
The directory info is stored in various tables, which can (typically after
DB corruption) get stale or out of sync. If this happens, for now the
solution to fix it is to execute the SQL here
https://github.com/matrix-org/synapse/blob/master/synapse/storage/schema/delta/53/user_dir_populate.sql
and then restart synapse. This should then start a background task to
quickest solution to fix it is:
```
UPDATE user_directory_stream_pos SET stream_id = NULL;
```
and restart the synapse, which should then start a background task to
flush the current tables and regenerate the directory.

View File

@@ -85,6 +85,10 @@ class RoomVersions(object):
)
# the version we will give rooms which are created on this server
DEFAULT_ROOM_VERSION = RoomVersions.V1
KNOWN_ROOM_VERSIONS = {
v.identifier: v for v in (
RoomVersions.V1,

View File

@@ -26,7 +26,6 @@ CLIENT_API_PREFIX = "/_matrix/client"
FEDERATION_PREFIX = "/_matrix/federation"
FEDERATION_V1_PREFIX = FEDERATION_PREFIX + "/v1"
FEDERATION_V2_PREFIX = FEDERATION_PREFIX + "/v2"
FEDERATION_UNSTABLE_PREFIX = FEDERATION_PREFIX + "/unstable"
STATIC_PREFIX = "/_matrix/static"
WEB_CLIENT_PREFIX = "/_matrix/client"
CONTENT_REPO_PREFIX = "/_matrix/content"

View File

@@ -344,21 +344,15 @@ class _LimitedHostnameResolver(object):
def resolveHostName(self, resolutionReceiver, hostName, portNumber=0,
addressTypes=None, transportSemantics='TCP'):
# Note this is happening deep within the reactor, so we don't need to
# worry about log contexts.
# We need this function to return `resolutionReceiver` so we do all the
# actual logic involving deferreds in a separate function.
# even though this is happening within the depths of twisted, we need to drop
# our logcontext before starting _resolve, otherwise: (a) _resolve will drop
# the logcontext if it returns an incomplete deferred; (b) _resolve will
# call the resolutionReceiver *with* a logcontext, which it won't be expecting.
with PreserveLoggingContext():
self._resolve(
resolutionReceiver,
hostName,
portNumber,
addressTypes,
transportSemantics,
)
self._resolve(
resolutionReceiver, hostName, portNumber,
addressTypes, transportSemantics,
)
return resolutionReceiver

View File

@@ -20,7 +20,6 @@ import os.path
from netaddr import IPSet
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.http.endpoint import parse_and_validate_server_name
from synapse.python_dependencies import DependencyException, check_requirements
@@ -36,8 +35,6 @@ logger = logging.Logger(__name__)
# in the list.
DEFAULT_BIND_ADDRESSES = ['::', '0.0.0.0']
DEFAULT_ROOM_VERSION = "1"
class ServerConfig(Config):
@@ -91,22 +88,6 @@ class ServerConfig(Config):
"restrict_public_rooms_to_local_users", False,
)
default_room_version = config.get(
"default_room_version", DEFAULT_ROOM_VERSION,
)
# Ensure room version is a str
default_room_version = str(default_room_version)
if default_room_version not in KNOWN_ROOM_VERSIONS:
raise ConfigError(
"Unknown default_room_version: %s, known room versions: %s" %
(default_room_version, list(KNOWN_ROOM_VERSIONS.keys()))
)
# Get the actual room version object rather than just the identifier
self.default_room_version = KNOWN_ROOM_VERSIONS[default_room_version]
# whether to enable search. If disabled, new entries will not be inserted
# into the search tables and they will not be indexed. Users will receive
# errors when attempting to search for messages.
@@ -329,10 +310,6 @@ class ServerConfig(Config):
unsecure_port = 8008
pid_file = os.path.join(data_dir_path, "homeserver.pid")
# Bring DEFAULT_ROOM_VERSION into the local-scope for use in the
# default config string
default_room_version = DEFAULT_ROOM_VERSION
return """\
## Server ##
@@ -407,16 +384,6 @@ class ServerConfig(Config):
#
#restrict_public_rooms_to_local_users: true
# The default room version for newly created rooms.
#
# Known room versions are listed here:
# https://matrix.org/docs/spec/#complete-list-of-room-versions
#
# For example, for room version 1, default_room_version should be set
# to "1".
#
#default_room_version: "%(default_room_version)s"
# The GC threshold parameters to pass to `gc.set_threshold`, if defined
#
#gc_thresholds: [700, 10, 10]

View File

@@ -43,9 +43,9 @@ class UserDirectoryConfig(Config):
#
# 'search_all_users' defines whether to search all users visible to your HS
# when searching the user directory, rather than limiting to users visible
# in public rooms. Defaults to false. If you set it True, you'll have to
# rebuild the user_directory search indexes, see
# https://github.com/matrix-org/synapse/blob/master/docs/user_directory.md
# in public rooms. Defaults to false. If you set it True, you'll have to run
# UPDATE user_directory_stream_pos SET stream_id = NULL;
# on your database to tell it to rebuild the user_directory search indexes.
#
#user_directory:
# enabled: true

View File

@@ -17,10 +17,10 @@
import logging
from collections import namedtuple
import six
from six import raise_from
from six.moves import urllib
import nacl.signing
from signedjson.key import (
decode_verify_key_bytes,
encode_verify_key_base64,
@@ -43,7 +43,6 @@ from synapse.api.errors import (
RequestSendFailed,
SynapseError,
)
from synapse.storage.keys import FetchKeyResult
from synapse.util import logcontext, unwrapFirstError
from synapse.util.logcontext import (
LoggingContext,
@@ -81,13 +80,12 @@ class KeyLookupError(ValueError):
class Keyring(object):
def __init__(self, hs):
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self._key_fetchers = (
StoreKeyFetcher(hs),
PerspectivesKeyFetcher(hs),
ServerKeyFetcher(hs),
)
self.client = hs.get_http_client()
self.config = hs.get_config()
self.perspective_servers = self.config.perspectives
self.hs = hs
# map from server name to Deferred. Has an entry for each server with
# an ongoing key download; the Deferred completes once the download
@@ -181,7 +179,9 @@ class Keyring(object):
# We want to wait for any previous lookups to complete before
# proceeding.
yield self.wait_for_previous_lookups(server_to_deferred)
yield self.wait_for_previous_lookups(
[rq.server_name for rq in verify_requests], server_to_deferred
)
# Actually start fetching keys.
self._get_server_verify_keys(verify_requests)
@@ -214,11 +214,12 @@ class Keyring(object):
logger.exception("Error starting key lookups")
@defer.inlineCallbacks
def wait_for_previous_lookups(self, server_to_deferred):
def wait_for_previous_lookups(self, server_names, server_to_deferred):
"""Waits for any previous key lookups for the given servers to finish.
Args:
server_to_deferred (dict[str, Deferred]): server_name to deferred which gets
server_names (list): list of server_names we want to lookup
server_to_deferred (dict): server_name to deferred which gets
resolved once we've finished looking up keys for that server.
The Deferreds should be regular twisted ones which call their
callbacks with no logcontext.
@@ -231,7 +232,7 @@ class Keyring(object):
while True:
wait_on = [
(server_name, self.key_downloads[server_name])
for server_name in server_to_deferred.keys()
for server_name in server_names
if server_name in self.key_downloads
]
if not wait_on:
@@ -270,6 +271,13 @@ class Keyring(object):
verify_requests (list[VerifyKeyRequest]): list of verify requests
"""
# These are functions that produce keys given a list of key ids
key_fetch_fns = (
self.get_keys_from_store, # First try the local store
self.get_keys_from_perspectives, # Then try via perspectives
self.get_keys_from_server, # Then try directly
)
@defer.inlineCallbacks
def do_iterations():
with Measure(self.clock, "get_server_verify_keys"):
@@ -280,8 +288,8 @@ class Keyring(object):
verify_request.key_ids
)
for f in self._key_fetchers:
results = yield f.get_keys(missing_keys.items())
for fn in key_fetch_fns:
results = yield fn(missing_keys.items())
# We now need to figure out which verify requests we have keys
# for and which we don't
@@ -299,15 +307,11 @@ class Keyring(object):
# complete this VerifyKeyRequest.
result_keys = results.get(server_name, {})
for key_id in verify_request.key_ids:
fetch_key_result = result_keys.get(key_id)
if fetch_key_result:
key = result_keys.get(key_id)
if key:
with PreserveLoggingContext():
verify_request.deferred.callback(
(
server_name,
key_id,
fetch_key_result.verify_key,
)
(server_name, key_id, key)
)
break
else:
@@ -340,31 +344,17 @@ class Keyring(object):
run_in_background(do_iterations).addErrback(on_err)
class KeyFetcher(object):
def get_keys(self, server_name_and_key_ids):
@defer.inlineCallbacks
def get_keys_from_store(self, server_name_and_key_ids):
"""
Args:
server_name_and_key_ids (iterable[Tuple[str, iterable[str]]]):
server_name_and_key_ids (iterable(Tuple[str, iterable[str]]):
list of (server_name, iterable[key_id]) tuples to fetch keys for
Note that the iterables may be iterated more than once.
Returns:
Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]:
map from server_name -> key_id -> FetchKeyResult
Deferred: resolves to dict[str, dict[str, VerifyKey|None]]: map from
server_name -> key_id -> VerifyKey
"""
raise NotImplementedError
class StoreKeyFetcher(KeyFetcher):
"""KeyFetcher impl which fetches keys from our data store"""
def __init__(self, hs):
self.store = hs.get_datastore()
@defer.inlineCallbacks
def get_keys(self, server_name_and_key_ids):
"""see KeyFetcher.get_keys"""
keys_to_fetch = (
(server_name, key_id)
for server_name, key_ids in server_name_and_key_ids
@@ -376,135 +366,8 @@ class StoreKeyFetcher(KeyFetcher):
keys.setdefault(server_name, {})[key_id] = key
defer.returnValue(keys)
class BaseV2KeyFetcher(object):
def __init__(self, hs):
self.store = hs.get_datastore()
self.config = hs.get_config()
@defer.inlineCallbacks
def process_v2_response(
self, from_server, response_json, time_added_ms, requested_ids=[]
):
"""Parse a 'Server Keys' structure from the result of a /key request
This is used to parse either the entirety of the response from
GET /_matrix/key/v2/server, or a single entry from the list returned by
POST /_matrix/key/v2/query.
Checks that each signature in the response that claims to come from the origin
server is valid, and that there is at least one such signature.
Stores the json in server_keys_json so that it can be used for future responses
to /_matrix/key/v2/query.
Args:
from_server (str): the name of the server producing this result: either
the origin server for a /_matrix/key/v2/server request, or the notary
for a /_matrix/key/v2/query.
response_json (dict): the json-decoded Server Keys response object
time_added_ms (int): the timestamp to record in server_keys_json
requested_ids (iterable[str]): a list of the key IDs that were requested.
We will store the json for these key ids as well as any that are
actually in the response
Returns:
Deferred[dict[str, FetchKeyResult]]: map from key_id to result object
"""
ts_valid_until_ms = response_json[u"valid_until_ts"]
# start by extracting the keys from the response, since they may be required
# to validate the signature on the response.
verify_keys = {}
for key_id, key_data in response_json["verify_keys"].items():
if is_signing_algorithm_supported(key_id):
key_base64 = key_data["key"]
key_bytes = decode_base64(key_base64)
verify_key = decode_verify_key_bytes(key_id, key_bytes)
verify_keys[key_id] = FetchKeyResult(
verify_key=verify_key, valid_until_ts=ts_valid_until_ms
)
server_name = response_json["server_name"]
verified = False
for key_id in response_json["signatures"].get(server_name, {}):
# each of the keys used for the signature must be present in the response
# json.
key = verify_keys.get(key_id)
if not key:
raise KeyLookupError(
"Key response is signed by key id %s:%s but that key is not "
"present in the response" % (server_name, key_id)
)
verify_signed_json(response_json, server_name, key.verify_key)
verified = True
if not verified:
raise KeyLookupError(
"Key response for %s is not signed by the origin server"
% (server_name,)
)
for key_id, key_data in response_json["old_verify_keys"].items():
if is_signing_algorithm_supported(key_id):
key_base64 = key_data["key"]
key_bytes = decode_base64(key_base64)
verify_key = decode_verify_key_bytes(key_id, key_bytes)
verify_keys[key_id] = FetchKeyResult(
verify_key=verify_key, valid_until_ts=key_data["expired_ts"]
)
# re-sign the json with our own key, so that it is ready if we are asked to
# give it out as a notary server
signed_key_json = sign_json(
response_json, self.config.server_name, self.config.signing_key[0]
)
signed_key_json_bytes = encode_canonical_json(signed_key_json)
# for reasons I don't quite understand, we store this json for the key ids we
# requested, as well as those we got.
updated_key_ids = set(requested_ids)
updated_key_ids.update(verify_keys)
yield logcontext.make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(
self.store.store_server_keys_json,
server_name=server_name,
key_id=key_id,
from_server=from_server,
ts_now_ms=time_added_ms,
ts_expires_ms=ts_valid_until_ms,
key_json_bytes=signed_key_json_bytes,
)
for key_id in updated_key_ids
],
consumeErrors=True,
).addErrback(unwrapFirstError)
)
defer.returnValue(verify_keys)
class PerspectivesKeyFetcher(BaseV2KeyFetcher):
"""KeyFetcher impl which fetches keys from the "perspectives" servers"""
def __init__(self, hs):
super(PerspectivesKeyFetcher, self).__init__(hs)
self.clock = hs.get_clock()
self.client = hs.get_http_client()
self.perspective_servers = self.config.perspectives
@defer.inlineCallbacks
def get_keys(self, server_name_and_key_ids):
"""see KeyFetcher.get_keys"""
def get_keys_from_perspectives(self, server_name_and_key_ids):
@defer.inlineCallbacks
def get_key(perspective_name, perspective_keys):
try:
@@ -541,31 +404,32 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
defer.returnValue(union_of_keys)
@defer.inlineCallbacks
def get_keys_from_server(self, server_name_and_key_ids):
results = yield logcontext.make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(
self.get_server_verify_key_v2_direct, server_name, key_ids
)
for server_name, key_ids in server_name_and_key_ids
],
consumeErrors=True,
).addErrback(unwrapFirstError)
)
merged = {}
for result in results:
merged.update(result)
defer.returnValue(
{server_name: keys for server_name, keys in merged.items() if keys}
)
@defer.inlineCallbacks
def get_server_verify_key_v2_indirect(
self, server_names_and_key_ids, perspective_name, perspective_keys
):
"""
Args:
server_names_and_key_ids (iterable[Tuple[str, iterable[str]]]):
list of (server_name, iterable[key_id]) tuples to fetch keys for
perspective_name (str): name of the notary server to query for the keys
perspective_keys (dict[str, VerifyKey]): map of key_id->key for the
notary server
Returns:
Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult]]]: map
from server_name -> key_id -> FetchKeyResult
Raises:
KeyLookupError if there was an error processing the entire response from
the server
"""
logger.info(
"Requesting keys %s from notary server %s",
server_names_and_key_ids,
perspective_name,
)
# TODO(mark): Set the minimum_valid_until_ts to that needed by
# the events being validated or the current time if validating
# an incoming request.
@@ -589,136 +453,72 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
raise_from(KeyLookupError("Remote server returned an error"), e)
keys = {}
added_keys = []
time_now_ms = self.clock.time_msec()
responses = query_response["server_keys"]
for response in query_response["server_keys"]:
# do this first, so that we can give useful errors thereafter
server_name = response.get("server_name")
if not isinstance(server_name, six.string_types):
for response in responses:
if (
u"signatures" not in response
or perspective_name not in response[u"signatures"]
):
raise KeyLookupError(
"Malformed response from key notary server %s: invalid server_name"
% (perspective_name,)
"Key response not signed by perspective server"
" %r" % (perspective_name,)
)
try:
processed_response = yield self._process_perspectives_response(
perspective_name,
perspective_keys,
response,
time_added_ms=time_now_ms,
)
except KeyLookupError as e:
logger.warning(
"Error processing response from key notary server %s for origin "
"server %s: %s",
perspective_name,
server_name,
e,
)
# we continue to process the rest of the response
continue
verified = False
for key_id in response[u"signatures"][perspective_name]:
if key_id in perspective_keys:
verify_signed_json(
response, perspective_name, perspective_keys[key_id]
)
verified = True
added_keys.extend(
(server_name, key_id, key) for key_id, key in processed_response.items()
if not verified:
logging.info(
"Response from perspective server %r not signed with a"
" known key, signed with: %r, known keys: %r",
perspective_name,
list(response[u"signatures"][perspective_name]),
list(perspective_keys),
)
raise KeyLookupError(
"Response not signed with a known key for perspective"
" server %r" % (perspective_name,)
)
processed_response = yield self.process_v2_response(
perspective_name, response
)
server_name = response["server_name"]
keys.setdefault(server_name, {}).update(processed_response)
yield self.store.store_server_verify_keys(
perspective_name, time_now_ms, added_keys
)
defer.returnValue(keys)
def _process_perspectives_response(
self, perspective_name, perspective_keys, response, time_added_ms
):
"""Parse a 'Server Keys' structure from the result of a /key/query request
Checks that the entry is correctly signed by the perspectives server, and then
passes over to process_v2_response
Args:
perspective_name (str): the name of the notary server that produced this
result
perspective_keys (dict[str, VerifyKey]): map of key_id->key for the
notary server
response (dict): the json-decoded Server Keys response object
time_added_ms (int): the timestamp to record in server_keys_json
Returns:
Deferred[dict[str, FetchKeyResult]]: map from key_id to result object
"""
if (
u"signatures" not in response
or perspective_name not in response[u"signatures"]
):
raise KeyLookupError("Response not signed by the notary server")
verified = False
for key_id in response[u"signatures"][perspective_name]:
if key_id in perspective_keys:
verify_signed_json(response, perspective_name, perspective_keys[key_id])
verified = True
if not verified:
raise KeyLookupError(
"Response not signed with a known key: signed with: %r, known keys: %r"
% (
list(response[u"signatures"][perspective_name].keys()),
list(perspective_keys.keys()),
)
)
return self.process_v2_response(
perspective_name, response, time_added_ms=time_added_ms
)
class ServerKeyFetcher(BaseV2KeyFetcher):
"""KeyFetcher impl which fetches keys from the origin servers"""
def __init__(self, hs):
super(ServerKeyFetcher, self).__init__(hs)
self.clock = hs.get_clock()
self.client = hs.get_http_client()
@defer.inlineCallbacks
def get_keys(self, server_name_and_key_ids):
"""see KeyFetcher.get_keys"""
results = yield logcontext.make_deferred_yieldable(
yield logcontext.make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(
self.get_server_verify_key_v2_direct, server_name, key_ids
self.store_keys,
server_name=server_name,
from_server=perspective_name,
verify_keys=response_keys,
)
for server_name, key_ids in server_name_and_key_ids
for server_name, response_keys in keys.items()
],
consumeErrors=True,
).addErrback(unwrapFirstError)
)
merged = {}
for result in results:
merged.update(result)
defer.returnValue(
{server_name: keys for server_name, keys in merged.items() if keys}
)
defer.returnValue(keys)
@defer.inlineCallbacks
def get_server_verify_key_v2_direct(self, server_name, key_ids):
keys = {} # type: dict[str, FetchKeyResult]
keys = {} # type: dict[str, nacl.signing.VerifyKey]
for requested_key_id in key_ids:
if requested_key_id in keys:
continue
time_now_ms = self.clock.time_msec()
try:
response = yield self.client.get_json(
destination=server_name,
@@ -731,6 +531,12 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
except HttpResponseException as e:
raise_from(KeyLookupError("Remote server returned an error"), e)
if (
u"signatures" not in response
or server_name not in response[u"signatures"]
):
raise KeyLookupError("Key response not signed by remote server")
if response["server_name"] != server_name:
raise KeyLookupError(
"Expected a response for server %r not %r"
@@ -741,17 +547,135 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
from_server=server_name,
requested_ids=[requested_key_id],
response_json=response,
time_added_ms=time_now_ms,
)
yield self.store.store_server_verify_keys(
server_name,
time_now_ms,
((server_name, key_id, key) for key_id, key in response_keys.items()),
)
keys.update(response_keys)
yield self.store_keys(
server_name=server_name, from_server=server_name, verify_keys=keys
)
defer.returnValue({server_name: keys})
@defer.inlineCallbacks
def process_v2_response(self, from_server, response_json, requested_ids=[]):
"""Parse a 'Server Keys' structure from the result of a /key request
This is used to parse either the entirety of the response from
GET /_matrix/key/v2/server, or a single entry from the list returned by
POST /_matrix/key/v2/query.
Checks that each signature in the response that claims to come from the origin
server is valid. (Does not check that there actually is such a signature, for
some reason.)
Stores the json in server_keys_json so that it can be used for future responses
to /_matrix/key/v2/query.
Args:
from_server (str): the name of the server producing this result: either
the origin server for a /_matrix/key/v2/server request, or the notary
for a /_matrix/key/v2/query.
response_json (dict): the json-decoded Server Keys response object
requested_ids (iterable[str]): a list of the key IDs that were requested.
We will store the json for these key ids as well as any that are
actually in the response
Returns:
Deferred[dict[str, nacl.signing.VerifyKey]]:
map from key_id to key object
"""
time_now_ms = self.clock.time_msec()
response_keys = {}
verify_keys = {}
for key_id, key_data in response_json["verify_keys"].items():
if is_signing_algorithm_supported(key_id):
key_base64 = key_data["key"]
key_bytes = decode_base64(key_base64)
verify_key = decode_verify_key_bytes(key_id, key_bytes)
verify_key.time_added = time_now_ms
verify_keys[key_id] = verify_key
old_verify_keys = {}
for key_id, key_data in response_json["old_verify_keys"].items():
if is_signing_algorithm_supported(key_id):
key_base64 = key_data["key"]
key_bytes = decode_base64(key_base64)
verify_key = decode_verify_key_bytes(key_id, key_bytes)
verify_key.expired = key_data["expired_ts"]
verify_key.time_added = time_now_ms
old_verify_keys[key_id] = verify_key
server_name = response_json["server_name"]
for key_id in response_json["signatures"].get(server_name, {}):
if key_id not in response_json["verify_keys"]:
raise KeyLookupError(
"Key response must include verification keys for all" " signatures"
)
if key_id in verify_keys:
verify_signed_json(response_json, server_name, verify_keys[key_id])
signed_key_json = sign_json(
response_json, self.config.server_name, self.config.signing_key[0]
)
signed_key_json_bytes = encode_canonical_json(signed_key_json)
ts_valid_until_ms = signed_key_json[u"valid_until_ts"]
updated_key_ids = set(requested_ids)
updated_key_ids.update(verify_keys)
updated_key_ids.update(old_verify_keys)
response_keys.update(verify_keys)
response_keys.update(old_verify_keys)
yield logcontext.make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(
self.store.store_server_keys_json,
server_name=server_name,
key_id=key_id,
from_server=from_server,
ts_now_ms=time_now_ms,
ts_expires_ms=ts_valid_until_ms,
key_json_bytes=signed_key_json_bytes,
)
for key_id in updated_key_ids
],
consumeErrors=True,
).addErrback(unwrapFirstError)
)
defer.returnValue(response_keys)
def store_keys(self, server_name, from_server, verify_keys):
"""Store a collection of verify keys for a given server
Args:
server_name(str): The name of the server the keys are for.
from_server(str): The server the keys were downloaded from.
verify_keys(dict): A mapping of key_id to VerifyKey.
Returns:
A deferred that completes when the keys are stored.
"""
# TODO(markjh): Store whether the keys have expired.
return logcontext.make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(
self.store.store_server_verify_key,
server_name,
server_name,
key.time_added,
key,
)
for key_id, key in verify_keys.items()
],
consumeErrors=True,
).addErrback(unwrapFirstError)
)
@defer.inlineCallbacks
def _handle_key_deferred(verify_request):

View File

@@ -76,7 +76,6 @@ class EventBuilder(object):
# someone tries to get them when they don't exist.
_state_key = attr.ib(default=None)
_redacts = attr.ib(default=None)
_origin_server_ts = attr.ib(default=None)
internal_metadata = attr.ib(default=attr.Factory(lambda: _EventInternalMetadata({})))
@@ -143,9 +142,6 @@ class EventBuilder(object):
if self._redacts is not None:
event_dict["redacts"] = self._redacts
if self._origin_server_ts is not None:
event_dict["origin_server_ts"] = self._origin_server_ts
defer.returnValue(
create_local_event_from_event_dict(
clock=self._clock,
@@ -213,7 +209,6 @@ class EventBuilderFactory(object):
content=key_values.get("content", {}),
unsigned=key_values.get("unsigned", {}),
redacts=key_values.get("redacts", None),
origin_server_ts=key_values.get("origin_server_ts", None),
)
@@ -250,7 +245,7 @@ def create_local_event_from_event_dict(clock, hostname, signing_key,
event_dict["event_id"] = _create_event_id(clock, hostname)
event_dict["origin"] = hostname
event_dict.setdefault("origin_server_ts", time_now)
event_dict["origin_server_ts"] = time_now
event_dict.setdefault("unsigned", {})
age = event_dict["unsigned"].pop("age", 0)

View File

@@ -330,13 +330,12 @@ class EventClientSerializer(object):
)
@defer.inlineCallbacks
def serialize_event(self, event, time_now, bundle_aggregations=True, **kwargs):
def serialize_event(self, event, time_now, **kwargs):
"""Serializes a single event.
Args:
event (EventBase)
time_now (int): The current time in milliseconds
bundle_aggregations (bool): Whether to bundle in related events
**kwargs: Arguments to pass to `serialize_event`
Returns:
@@ -351,7 +350,7 @@ class EventClientSerializer(object):
# If MSC1849 is enabled then we need to look if thre are any relations
# we need to bundle in with the event
if self.experimental_msc1849_support_enabled and bundle_aggregations:
if self.experimental_msc1849_support_enabled:
annotations = yield self.store.get_aggregation_groups_for_event(
event_id,
)

View File

@@ -23,11 +23,7 @@ from twisted.internet import defer
import synapse
from synapse.api.errors import Codes, FederationDeniedError, SynapseError
from synapse.api.room_versions import RoomVersions
from synapse.api.urls import (
FEDERATION_UNSTABLE_PREFIX,
FEDERATION_V1_PREFIX,
FEDERATION_V2_PREFIX,
)
from synapse.api.urls import FEDERATION_V1_PREFIX, FEDERATION_V2_PREFIX
from synapse.http.endpoint import parse_and_validate_server_name
from synapse.http.server import JsonResource
from synapse.http.servlet import (
@@ -1308,30 +1304,6 @@ class FederationGroupsSettingJoinPolicyServlet(BaseFederationServlet):
defer.returnValue((200, new_content))
class RoomComplexityServlet(BaseFederationServlet):
"""
Indicates to other servers how complex (and therefore likely
resource-intensive) a public room this server knows about is.
"""
PATH = "/rooms/(?P<room_id>[^/]*)/complexity"
PREFIX = FEDERATION_UNSTABLE_PREFIX
@defer.inlineCallbacks
def on_GET(self, origin, content, query, room_id):
store = self.handler.hs.get_datastore()
is_public = yield store.is_room_world_readable_or_publicly_joinable(
room_id
)
if not is_public:
raise SynapseError(404, "Room not found", errcode=Codes.INVALID_PARAM)
complexity = yield store.get_room_complexity(room_id)
defer.returnValue((200, complexity))
FEDERATION_SERVLET_CLASSES = (
FederationSendServlet,
FederationEventServlet,
@@ -1355,7 +1327,6 @@ FEDERATION_SERVLET_CLASSES = (
FederationThirdPartyInviteExchangeServlet,
On3pidBindServlet,
FederationVersionServlet,
RoomComplexityServlet,
)
OPENID_SERVLET_CLASSES = (

View File

@@ -122,9 +122,6 @@ class EventStreamHandler(BaseHandler):
chunks = yield self._event_serializer.serialize_events(
events, time_now, as_client_event=as_client_event,
# We don't bundle "live" events, as otherwise clients
# will end up double counting annotations.
bundle_aggregations=False,
)
chunk = {

View File

@@ -2013,44 +2013,15 @@ class FederationHandler(BaseHandler):
Args:
origin (str):
event (synapse.events.EventBase):
event (synapse.events.FrozenEvent):
context (synapse.events.snapshot.EventContext):
auth_events (dict[(str, str)->synapse.events.EventBase]):
Map from (event_type, state_key) to event
What we expect the event's auth_events to be, based on the event's
position in the dag. I think? maybe??
Also NB that this function adds entries to it.
Returns:
defer.Deferred[None]
"""
room_version = yield self.store.get_room_version(event.room_id)
yield self._update_auth_events_and_context_for_auth(
origin, event, context, auth_events
)
try:
self.auth.check(room_version, event, auth_events=auth_events)
except AuthError as e:
logger.warn("Failed auth resolution for %r because %s", event, e)
raise e
@defer.inlineCallbacks
def _update_auth_events_and_context_for_auth(
self, origin, event, context, auth_events
):
"""Helper for do_auth. See there for docs.
Args:
origin (str):
event (synapse.events.EventBase):
context (synapse.events.snapshot.EventContext):
auth_events (dict[(str, str)->synapse.events.EventBase]):
auth_events (dict[(str, str)->str]):
Returns:
defer.Deferred[None]
"""
# Check if we have all the auth events.
current_state = set(e.event_id for e in auth_events.values())
event_auth_events = set(event.auth_event_ids())
if event.is_state():
@@ -2058,21 +2029,11 @@ class FederationHandler(BaseHandler):
else:
event_key = None
# if the event's auth_events refers to events which are not in our
# calculated auth_events, we need to fetch those events from somewhere.
#
# we start by fetching them from the store, and then try calling /event_auth/.
missing_auth = event_auth_events.difference(
e.event_id for e in auth_events.values()
)
if missing_auth:
if event_auth_events - current_state:
# TODO: can we use store.have_seen_events here instead?
have_events = yield self.store.get_seen_events_with_rejections(
missing_auth
event_auth_events - current_state
)
logger.debug("Got events %s from store", have_events)
missing_auth.difference_update(have_events.keys())
else:
have_events = {}
@@ -2081,12 +2042,13 @@ class FederationHandler(BaseHandler):
for e in auth_events.values()
})
seen_events = set(have_events.keys())
missing_auth = event_auth_events - seen_events - current_state
if missing_auth:
logger.info("Missing auth: %s", missing_auth)
# If we don't have all the auth events, we need to get them.
logger.info(
"auth_events contains unknown events: %s",
missing_auth,
)
try:
remote_auth_chain = yield self.federation_client.get_event_auth(
origin, event.room_id, event.event_id
@@ -2127,168 +2089,145 @@ class FederationHandler(BaseHandler):
have_events = yield self.store.get_seen_events_with_rejections(
event.auth_event_ids()
)
seen_events = set(have_events.keys())
except Exception:
# FIXME:
logger.exception("Failed to get auth chain")
if event.internal_metadata.is_outlier():
logger.info("Skipping auth_event fetch for outlier")
return
# FIXME: Assumes we have and stored all the state for all the
# prev_events
different_auth = event_auth_events.difference(
e.event_id for e in auth_events.values()
)
if not different_auth:
return
logger.info(
"auth_events refers to events which are not in our calculated auth "
"chain: %s",
different_auth,
)
current_state = set(e.event_id for e in auth_events.values())
different_auth = event_auth_events - current_state
room_version = yield self.store.get_room_version(event.room_id)
different_events = yield logcontext.make_deferred_yieldable(
defer.gatherResults([
logcontext.run_in_background(
self.store.get_event,
d,
allow_none=True,
allow_rejected=False,
if different_auth and not event.internal_metadata.is_outlier():
# Do auth conflict res.
logger.info("Different auth: %s", different_auth)
different_events = yield logcontext.make_deferred_yieldable(
defer.gatherResults([
logcontext.run_in_background(
self.store.get_event,
d,
allow_none=True,
allow_rejected=False,
)
for d in different_auth
if d in have_events and not have_events[d]
], consumeErrors=True)
).addErrback(unwrapFirstError)
if different_events:
local_view = dict(auth_events)
remote_view = dict(auth_events)
remote_view.update({
(d.type, d.state_key): d for d in different_events if d
})
new_state = yield self.state_handler.resolve_events(
room_version,
[list(local_view.values()), list(remote_view.values())],
event
)
for d in different_auth
if d in have_events and not have_events[d]
], consumeErrors=True)
).addErrback(unwrapFirstError)
if different_events:
local_view = dict(auth_events)
remote_view = dict(auth_events)
remote_view.update({
(d.type, d.state_key): d for d in different_events if d
})
auth_events.update(new_state)
new_state = yield self.state_handler.resolve_events(
room_version,
[list(local_view.values()), list(remote_view.values())],
event
)
current_state = set(e.event_id for e in auth_events.values())
different_auth = event_auth_events - current_state
logger.info(
"After state res: updating auth_events with new state %s",
{
(d.type, d.state_key): d.event_id for d in new_state.values()
if auth_events.get((d.type, d.state_key)) != d
},
)
yield self._update_context_for_auth_events(
event, context, auth_events, event_key,
)
auth_events.update(new_state)
if different_auth and not event.internal_metadata.is_outlier():
logger.info("Different auth after resolution: %s", different_auth)
different_auth = event_auth_events.difference(
e.event_id for e in auth_events.values()
)
# Only do auth resolution if we have something new to say.
# We can't rove an auth failure.
do_resolution = False
yield self._update_context_for_auth_events(
event, context, auth_events, event_key,
)
provable = [
RejectedReason.NOT_ANCESTOR, RejectedReason.NOT_ANCESTOR,
]
if not different_auth:
# we're done
return
for e_id in different_auth:
if e_id in have_events:
if have_events[e_id] in provable:
do_resolution = True
break
logger.info(
"auth_events still refers to events which are not in the calculated auth "
"chain after state resolution: %s",
different_auth,
)
# Only do auth resolution if we have something new to say.
# We can't prove an auth failure.
do_resolution = False
for e_id in different_auth:
if e_id in have_events:
if have_events[e_id] == RejectedReason.NOT_ANCESTOR:
do_resolution = True
break
if not do_resolution:
logger.info(
"Skipping auth resolution due to lack of provable rejection reasons"
)
return
logger.info("Doing auth resolution")
prev_state_ids = yield context.get_prev_state_ids(self.store)
# 1. Get what we think is the auth chain.
auth_ids = yield self.auth.compute_auth_events(
event, prev_state_ids
)
local_auth_chain = yield self.store.get_auth_chain(
auth_ids, include_given=True
)
try:
# 2. Get remote difference.
result = yield self.federation_client.query_auth(
origin,
event.room_id,
event.event_id,
local_auth_chain,
)
seen_remotes = yield self.store.have_seen_events(
[e.event_id for e in result["auth_chain"]]
)
# 3. Process any remote auth chain events we haven't seen.
for ev in result["auth_chain"]:
if ev.event_id in seen_remotes:
continue
if ev.event_id == event.event_id:
continue
if do_resolution:
prev_state_ids = yield context.get_prev_state_ids(self.store)
# 1. Get what we think is the auth chain.
auth_ids = yield self.auth.compute_auth_events(
event, prev_state_ids
)
local_auth_chain = yield self.store.get_auth_chain(
auth_ids, include_given=True
)
try:
auth_ids = ev.auth_event_ids()
auth = {
(e.type, e.state_key): e
for e in result["auth_chain"]
if e.event_id in auth_ids
or event.type == EventTypes.Create
}
ev.internal_metadata.outlier = True
logger.debug(
"do_auth %s different_auth: %s",
event.event_id, e.event_id
# 2. Get remote difference.
result = yield self.federation_client.query_auth(
origin,
event.room_id,
event.event_id,
local_auth_chain,
)
yield self._handle_new_event(
origin, ev, auth_events=auth
seen_remotes = yield self.store.have_seen_events(
[e.event_id for e in result["auth_chain"]]
)
if ev.event_id in event_auth_events:
auth_events[(ev.type, ev.state_key)] = ev
except AuthError:
pass
# 3. Process any remote auth chain events we haven't seen.
for ev in result["auth_chain"]:
if ev.event_id in seen_remotes:
continue
except Exception:
# FIXME:
logger.exception("Failed to query auth chain")
if ev.event_id == event.event_id:
continue
# 4. Look at rejects and their proofs.
# TODO.
try:
auth_ids = ev.auth_event_ids()
auth = {
(e.type, e.state_key): e
for e in result["auth_chain"]
if e.event_id in auth_ids
or event.type == EventTypes.Create
}
ev.internal_metadata.outlier = True
yield self._update_context_for_auth_events(
event, context, auth_events, event_key,
)
logger.debug(
"do_auth %s different_auth: %s",
event.event_id, e.event_id
)
yield self._handle_new_event(
origin, ev, auth_events=auth
)
if ev.event_id in event_auth_events:
auth_events[(ev.type, ev.state_key)] = ev
except AuthError:
pass
except Exception:
# FIXME:
logger.exception("Failed to query auth chain")
# 4. Look at rejects and their proofs.
# TODO.
yield self._update_context_for_auth_events(
event, context, auth_events, event_key,
)
try:
self.auth.check(room_version, event, auth_events=auth_events)
except AuthError as e:
logger.warn("Failed auth resolution for %r because %s", event, e)
raise e
@defer.inlineCallbacks
def _update_context_for_auth_events(self, event, context, auth_events,

View File

@@ -166,9 +166,6 @@ class MessageHandler(object):
now = self.clock.time_msec()
events = yield self._event_serializer.serialize_events(
room_state.values(), now,
# We don't bother bundling aggregations in when asked for state
# events, as clients won't use them.
bundle_aggregations=False,
)
defer.returnValue(events)

View File

@@ -182,27 +182,17 @@ class PresenceHandler(object):
# Start a LoopingCall in 30s that fires every 5s.
# The initial delay is to allow disconnected clients a chance to
# reconnect before we treat them as offline.
def run_timeout_handler():
return run_as_background_process(
"handle_presence_timeouts", self._handle_timeouts
)
self.clock.call_later(
30,
self.clock.looping_call,
run_timeout_handler,
self._handle_timeouts,
5000,
)
def run_persister():
return run_as_background_process(
"persist_presence_changes", self._persist_unpersisted_changes
)
self.clock.call_later(
60,
self.clock.looping_call,
run_persister,
self._persist_unpersisted_changes,
60 * 1000,
)
@@ -239,7 +229,6 @@ class PresenceHandler(object):
)
if self.unpersisted_users_changes:
yield self.store.update_presence([
self.user_to_current_state[user_id]
for user_id in self.unpersisted_users_changes
@@ -251,18 +240,30 @@ class PresenceHandler(object):
"""We periodically persist the unpersisted changes, as otherwise they
may stack up and slow down shutdown times.
"""
logger.info(
"Performing _persist_unpersisted_changes. Persisting %d unpersisted changes",
len(self.unpersisted_users_changes)
)
unpersisted = self.unpersisted_users_changes
self.unpersisted_users_changes = set()
if unpersisted:
logger.info(
"Persisting %d upersisted presence updates", len(unpersisted)
)
yield self.store.update_presence([
self.user_to_current_state[user_id]
for user_id in unpersisted
])
logger.info("Finished _persist_unpersisted_changes")
@defer.inlineCallbacks
def _update_states_and_catch_exception(self, new_states):
try:
res = yield self._update_states(new_states)
defer.returnValue(res)
except Exception:
logger.exception("Error updating presence")
@defer.inlineCallbacks
def _update_states(self, new_states):
"""Updates presence of users. Sets the appropriate timeouts. Pokes
@@ -337,41 +338,45 @@ class PresenceHandler(object):
logger.info("Handling presence timeouts")
now = self.clock.time_msec()
# Fetch the list of users that *may* have timed out. Things may have
# changed since the timeout was set, so we won't necessarily have to
# take any action.
users_to_check = set(self.wheel_timer.fetch(now))
try:
with Measure(self.clock, "presence_handle_timeouts"):
# Fetch the list of users that *may* have timed out. Things may have
# changed since the timeout was set, so we won't necessarily have to
# take any action.
users_to_check = set(self.wheel_timer.fetch(now))
# Check whether the lists of syncing processes from an external
# process have expired.
expired_process_ids = [
process_id for process_id, last_update
in self.external_process_last_updated_ms.items()
if now - last_update > EXTERNAL_PROCESS_EXPIRY
]
for process_id in expired_process_ids:
users_to_check.update(
self.external_process_last_updated_ms.pop(process_id, ())
)
self.external_process_last_update.pop(process_id)
# Check whether the lists of syncing processes from an external
# process have expired.
expired_process_ids = [
process_id for process_id, last_update
in self.external_process_last_updated_ms.items()
if now - last_update > EXTERNAL_PROCESS_EXPIRY
]
for process_id in expired_process_ids:
users_to_check.update(
self.external_process_last_updated_ms.pop(process_id, ())
)
self.external_process_last_update.pop(process_id)
states = [
self.user_to_current_state.get(
user_id, UserPresenceState.default(user_id)
)
for user_id in users_to_check
]
states = [
self.user_to_current_state.get(
user_id, UserPresenceState.default(user_id)
)
for user_id in users_to_check
]
timers_fired_counter.inc(len(states))
timers_fired_counter.inc(len(states))
changes = handle_timeouts(
states,
is_mine_fn=self.is_mine_id,
syncing_user_ids=self.get_currently_syncing_users(),
now=now,
)
changes = handle_timeouts(
states,
is_mine_fn=self.is_mine_id,
syncing_user_ids=self.get_currently_syncing_users(),
now=now,
)
return self._update_states(changes)
run_in_background(self._update_states_and_catch_exception, changes)
except Exception:
logger.exception("Exception in _handle_timeouts loop")
@defer.inlineCallbacks
def bump_presence_active_time(self, user):

View File

@@ -27,7 +27,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, JoinRules, RoomCreationPreset
from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.api.room_versions import DEFAULT_ROOM_VERSION, KNOWN_ROOM_VERSIONS
from synapse.storage.state import StateFilter
from synapse.types import RoomAlias, RoomID, RoomStreamToken, StreamToken, UserID
from synapse.util import stringutils
@@ -70,7 +70,6 @@ class RoomCreationHandler(BaseHandler):
self.spam_checker = hs.get_spam_checker()
self.event_creation_handler = hs.get_event_creation_handler()
self.room_member_handler = hs.get_room_member_handler()
self.config = hs.config
# linearizer to stop two upgrades happening at once
self._upgrade_linearizer = Linearizer("room_upgrade_linearizer")
@@ -476,11 +475,7 @@ class RoomCreationHandler(BaseHandler):
if ratelimit:
yield self.ratelimit(requester)
room_version = config.get(
"room_version",
self.config.default_room_version.identifier,
)
room_version = config.get("room_version", DEFAULT_ROOM_VERSION.identifier)
if not isinstance(room_version, string_types):
raise SynapseError(
400,

View File

@@ -711,6 +711,10 @@ class MatrixFederationHttpClient(object):
RequestSendFailed: If there were problems connecting to the
remote, due to e.g. DNS failures, connection timeouts etc.
"""
logger.debug("get_json args: %s", args)
logger.debug("Query bytes: %s Retry DNS: %s", args, retry_on_dns_fail)
request = MatrixFederationRequest(
method="GET",
destination=destination,

View File

@@ -55,7 +55,7 @@ def parse_integer_from_args(args, name, default=None, required=False):
return int(args[name][0])
except Exception:
message = "Query parameter %r must be an integer" % (name,)
raise SynapseError(400, message, errcode=Codes.INVALID_PARAM)
raise SynapseError(400, message)
else:
if required:
message = "Missing integer query parameter %r" % (name,)

View File

@@ -822,16 +822,10 @@ class AdminRestResource(JsonResource):
def __init__(self, hs):
JsonResource.__init__(self, hs, canonical_json=False)
register_servlets(hs, self)
def register_servlets(hs, http_server):
"""
Register all the admin servlets.
"""
register_servlets_for_client_rest_resource(hs, http_server)
SendServerNoticeServlet(hs).register(http_server)
VersionServlet(hs).register(http_server)
register_servlets_for_client_rest_resource(hs, self)
SendServerNoticeServlet(hs).register(self)
VersionServlet(hs).register(self)
def register_servlets_for_client_rest_resource(hs, http_server):

View File

@@ -386,7 +386,7 @@ class CasRedirectServlet(RestServlet):
b"redirectUrl": args[b"redirectUrl"][0]
}).encode('ascii')
hs_redirect_url = (self.cas_service_url +
b"/_matrix/client/r0/login/cas/ticket")
b"/_matrix/client/api/v1/login/cas/ticket")
service_param = urllib.parse.urlencode({
b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param)
}).encode('ascii')
@@ -395,7 +395,7 @@ class CasRedirectServlet(RestServlet):
class CasTicketServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/login/cas/ticket")
PATTERNS = client_path_patterns("/login/cas/ticket", releases=())
def __init__(self, hs):
super(CasTicketServlet, self).__init__(hs)

View File

@@ -201,6 +201,11 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request)
# Pull out the relationship early if the client sent us something
# which cannot possibly be processed by us.
if content.get("m.relates_to", "not None") is None:
del content["m.relates_to"]
event_dict = {
"type": event_type,
"content": content,

View File

@@ -16,7 +16,7 @@ import logging
from twisted.internet import defer
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.api.room_versions import DEFAULT_ROOM_VERSION, KNOWN_ROOM_VERSIONS
from synapse.http.servlet import RestServlet
from ._base import client_v2_patterns
@@ -36,7 +36,6 @@ class CapabilitiesRestServlet(RestServlet):
"""
super(CapabilitiesRestServlet, self).__init__()
self.hs = hs
self.config = hs.config
self.auth = hs.get_auth()
self.store = hs.get_datastore()
@@ -49,7 +48,7 @@ class CapabilitiesRestServlet(RestServlet):
response = {
"capabilities": {
"m.room_versions": {
"default": self.config.default_room_version.identifier,
"default": DEFAULT_ROOM_VERSION.identifier,
"available": {
v.identifier: v.disposition
for v in KNOWN_ROOM_VERSIONS.values()

View File

@@ -358,9 +358,6 @@ class SyncRestServlet(RestServlet):
def serialize(events):
return self._event_serializer.serialize_events(
events, time_now=time_now,
# We don't bundle "live" events, as otherwise clients
# will end up double counting annotations.
bundle_aggregations=False,
token_id=token_id,
event_format=event_formatter,
only_event_fields=only_fields,

View File

@@ -20,7 +20,7 @@ from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
from synapse.api.errors import Codes, SynapseError
from synapse.crypto.keyring import KeyLookupError, ServerKeyFetcher
from synapse.crypto.keyring import KeyLookupError
from synapse.http.server import respond_with_json_bytes, wrap_json_request_handler
from synapse.http.servlet import parse_integer, parse_json_object_from_request
@@ -89,7 +89,7 @@ class RemoteKey(Resource):
isLeaf = True
def __init__(self, hs):
self.fetcher = ServerKeyFetcher(hs)
self.keyring = hs.get_keyring()
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
@@ -217,7 +217,7 @@ class RemoteKey(Resource):
if cache_misses and query_remote_on_cache_miss:
for server_name, key_ids in cache_misses.items():
try:
yield self.fetcher.get_server_verify_key_v2_direct(
yield self.keyring.get_server_verify_key_v2_direct(
server_name, key_ids
)
except KeyLookupError as e:

View File

@@ -56,8 +56,8 @@ class ThumbnailResource(Resource):
def _async_render_GET(self, request):
set_cors_headers(request)
server_name, media_id, _ = parse_media_id(request)
width = parse_integer(request, "width", required=True)
height = parse_integer(request, "height", required=True)
width = parse_integer(request, "width")
height = parse_integer(request, "height")
method = parse_string(request, "method", "scale")
m_type = parse_string(request, "type", "image/png")

View File

@@ -554,18 +554,10 @@ class EventsStore(
e_id for event in new_events for e_id in event.prev_event_ids()
)
# Remove any events which are prev_events of any existing events.
# Finally, remove any events which are prev_events of any existing events.
existing_prevs = yield self._get_events_which_are_prevs(result)
result.difference_update(existing_prevs)
# Finally handle the case where the new events have soft-failed prev
# events. If they do we need to remove them and their prev events,
# otherwise we end up with dangling extremities.
existing_prevs = yield self._get_prevs_before_rejected(
e_id for event in new_events for e_id in event.prev_event_ids()
)
result.difference_update(existing_prevs)
defer.returnValue(result)
@defer.inlineCallbacks
@@ -581,7 +573,7 @@ class EventsStore(
"""
results = []
def _get_events_which_are_prevs_txn(txn, batch):
def _get_events(txn, batch):
sql = """
SELECT prev_event_id, internal_metadata
FROM event_edges
@@ -604,78 +596,10 @@ class EventsStore(
)
for chunk in batch_iter(event_ids, 100):
yield self.runInteraction(
"_get_events_which_are_prevs",
_get_events_which_are_prevs_txn,
chunk,
)
yield self.runInteraction("_get_events_which_are_prevs", _get_events, chunk)
defer.returnValue(results)
@defer.inlineCallbacks
def _get_prevs_before_rejected(self, event_ids):
"""Get soft-failed ancestors to remove from the extremities.
Given a set of events, find all those that have been soft-failed or
rejected. Returns those soft failed/rejected events and their prev
events (whether soft-failed/rejected or not), and recurses up the
prev-event graph until it finds no more soft-failed/rejected events.
This is used to find extremities that are ancestors of new events, but
are separated by soft failed events.
Args:
event_ids (Iterable[str]): Events to find prev events for. Note
that these must have already been persisted.
Returns:
Deferred[set[str]]
"""
# The set of event_ids to return. This includes all soft-failed events
# and their prev events.
existing_prevs = set()
def _get_prevs_before_rejected_txn(txn, batch):
to_recursively_check = batch
while to_recursively_check:
sql = """
SELECT
event_id, prev_event_id, internal_metadata,
rejections.event_id IS NOT NULL
FROM event_edges
INNER JOIN events USING (event_id)
LEFT JOIN rejections USING (event_id)
LEFT JOIN event_json USING (event_id)
WHERE
event_id IN (%s)
AND NOT events.outlier
""" % (
",".join("?" for _ in to_recursively_check),
)
txn.execute(sql, to_recursively_check)
to_recursively_check = []
for event_id, prev_event_id, metadata, rejected in txn:
if prev_event_id in existing_prevs:
continue
soft_failed = json.loads(metadata).get("soft_failed")
if soft_failed or rejected:
to_recursively_check.append(prev_event_id)
existing_prevs.add(prev_event_id)
for chunk in batch_iter(event_ids, 100):
yield self.runInteraction(
"_get_prevs_before_rejected",
_get_prevs_before_rejected_txn,
chunk,
)
defer.returnValue(existing_prevs)
@defer.inlineCallbacks
def _get_new_state_after_events(
self, room_id, events_context, old_latest_event_ids, new_latest_event_ids

View File

@@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
import itertools
import logging
from collections import namedtuple
@@ -612,11 +610,11 @@ class EventsWorkerStore(SQLBaseStore):
return res
return self.runInteraction("get_seen_events_with_rejections", f)
return self.runInteraction("get_rejection_reasons", f)
def _get_total_state_event_counts_txn(self, txn, room_id):
"""
See get_total_state_event_counts.
See get_state_event_counts.
"""
sql = "SELECT COUNT(*) FROM state_events WHERE room_id=?"
txn.execute(sql, (room_id,))
@@ -637,49 +635,3 @@ class EventsWorkerStore(SQLBaseStore):
"get_total_state_event_counts",
self._get_total_state_event_counts_txn, room_id
)
def _get_current_state_event_counts_txn(self, txn, room_id):
"""
See get_current_state_event_counts.
"""
sql = "SELECT COUNT(*) FROM current_state_events WHERE room_id=?"
txn.execute(sql, (room_id,))
row = txn.fetchone()
return row[0] if row else 0
def get_current_state_event_counts(self, room_id):
"""
Gets the current number of state events in a room.
Args:
room_id (str)
Returns:
Deferred[int]
"""
return self.runInteraction(
"get_current_state_event_counts",
self._get_current_state_event_counts_txn, room_id
)
@defer.inlineCallbacks
def get_room_complexity(self, room_id):
"""
Get a rough approximation of the complexity of the room. This is used by
remote servers to decide whether they wish to join the room or not.
Higher complexity value indicates that being in the room will consume
more resources.
Args:
room_id (str)
Returns:
Deferred[dict[str:int]] of complexity version to complexity.
"""
state_events = yield self.get_current_state_event_counts(room_id)
# Call this one "v1", so we can introduce new ones as we want to develop
# it.
complexity_v1 = round(state_events / 500, 2)
defer.returnValue({"v1": complexity_v1})

View File

@@ -19,7 +19,6 @@ import logging
import six
import attr
from signedjson.key import decode_verify_key_bytes
from synapse.util import batch_iter
@@ -37,12 +36,6 @@ else:
db_binary_type = memoryview
@attr.s(slots=True, frozen=True)
class FetchKeyResult(object):
verify_key = attr.ib() # VerifyKey: the key itself
valid_until_ts = attr.ib() # int: how long we can use this key for
class KeyStore(SQLBaseStore):
"""Persistence for signature verification keys
"""
@@ -61,8 +54,8 @@ class KeyStore(SQLBaseStore):
iterable of (server_name, key-id) tuples to fetch keys for
Returns:
Deferred: resolves to dict[Tuple[str, str], FetchKeyResult|None]:
map from (server_name, key_id) -> FetchKeyResult, or None if the key is
Deferred: resolves to dict[Tuple[str, str], VerifyKey|None]:
map from (server_name, key_id) -> VerifyKey, or None if the key is
unknown
"""
keys = {}
@@ -72,19 +65,17 @@ class KeyStore(SQLBaseStore):
# batch_iter always returns tuples so it's safe to do len(batch)
sql = (
"SELECT server_name, key_id, verify_key, ts_valid_until_ms "
"FROM server_signature_keys WHERE 1=0"
"SELECT server_name, key_id, verify_key FROM server_signature_keys "
"WHERE 1=0"
) + " OR (server_name=? AND key_id=?)" * len(batch)
txn.execute(sql, tuple(itertools.chain.from_iterable(batch)))
for row in txn:
server_name, key_id, key_bytes, ts_valid_until_ms = row
res = FetchKeyResult(
verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)),
valid_until_ts=ts_valid_until_ms,
server_name, key_id, key_bytes = row
keys[(server_name, key_id)] = decode_verify_key_bytes(
key_id, bytes(key_bytes)
)
keys[(server_name, key_id)] = res
def _txn(txn):
for batch in batch_iter(server_name_and_key_ids, 50):
@@ -93,53 +84,38 @@ class KeyStore(SQLBaseStore):
return self.runInteraction("get_server_verify_keys", _txn)
def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys):
"""Stores NACL verification keys for remote servers.
def store_server_verify_key(
self, server_name, from_server, time_now_ms, verify_key
):
"""Stores a NACL verification key for the given server.
Args:
from_server (str): Where the verification keys were looked up
ts_added_ms (int): The time to record that the key was added
verify_keys (iterable[tuple[str, str, FetchKeyResult]]):
keys to be stored. Each entry is a triplet of
(server_name, key_id, key).
server_name (str): The name of the server.
from_server (str): Where the verification key was looked up
time_now_ms (int): The time now in milliseconds
verify_key (nacl.signing.VerifyKey): The NACL verify key.
"""
key_values = []
value_values = []
invalidations = []
for server_name, key_id, fetch_result in verify_keys:
key_values.append((server_name, key_id))
value_values.append(
(
from_server,
ts_added_ms,
fetch_result.valid_until_ts,
db_binary_type(fetch_result.verify_key.encode()),
)
key_id = "%s:%s" % (verify_key.alg, verify_key.version)
# XXX fix this to not need a lock (#3819)
def _txn(txn):
self._simple_upsert_txn(
txn,
table="server_signature_keys",
keyvalues={"server_name": server_name, "key_id": key_id},
values={
"from_server": from_server,
"ts_added_ms": time_now_ms,
"verify_key": db_binary_type(verify_key.encode()),
},
)
# invalidate takes a tuple corresponding to the params of
# _get_server_verify_key. _get_server_verify_key only takes one
# param, which is itself the 2-tuple (server_name, key_id).
invalidations.append((server_name, key_id))
txn.call_after(
self._get_server_verify_key.invalidate, ((server_name, key_id),)
)
def _invalidate(res):
f = self._get_server_verify_key.invalidate
for i in invalidations:
f((i, ))
return res
return self.runInteraction(
"store_server_verify_keys",
self._simple_upsert_many_txn,
table="server_signature_keys",
key_names=("server_name", "key_id"),
key_values=key_values,
value_names=(
"from_server",
"ts_added_ms",
"ts_valid_until_ms",
"verify_key",
),
value_values=value_values,
).addCallback(_invalidate)
return self.runInteraction("store_server_verify_key", _txn)
def store_server_keys_json(
self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes

View File

@@ -13,9 +13,6 @@
* limitations under the License.
*/
-- We previously changed the schema for this table without renaming the file, which means
-- that some databases might still be using the old schema. This ensures Synapse uses the
-- right schema for the table.
DROP TABLE IF EXISTS account_validity;
-- Track what users are in public rooms.

View File

@@ -1,23 +0,0 @@
/* Copyright 2019 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/* When we can use this key until, before we have to refresh it. */
ALTER TABLE server_signature_keys ADD COLUMN ts_valid_until_ms BIGINT;
UPDATE server_signature_keys SET ts_valid_until_ms = (
SELECT MAX(ts_valid_until_ms) FROM server_keys_json skj WHERE
skj.server_name = server_signature_keys.server_name AND
skj.key_id = server_signature_keys.key_id
);

View File

@@ -169,7 +169,7 @@ class StatsStore(StateDeltasStore):
logger.info(
"Processing the next %d rooms of %d remaining",
len(rooms_to_work_on), progress["remaining"],
(len(rooms_to_work_on), progress["remaining"]),
)
# Number of state events we've processed by going through each room

View File

@@ -226,8 +226,6 @@ class LoggingContext(object):
self.request = request
def __str__(self):
if self.request:
return str(self.request)
return "%s@%x" % (self.name, id(self))
@classmethod
@@ -276,10 +274,12 @@ class LoggingContext(object):
current = self.set_current_context(self.previous_context)
if current is not self:
if current is self.sentinel:
logger.warning("Expected logging context %s was lost", self)
logger.warn("Expected logging context %s has been lost", self)
else:
logger.warning(
"Expected logging context %s but found %s", self, current
logger.warn(
"Current logging context %s is not expected context %s",
current,
self
)
self.previous_context = None
self.alive = False
@@ -433,14 +433,10 @@ class PreserveLoggingContext(object):
context = LoggingContext.set_current_context(self.current_context)
if context != self.new_context:
if context is LoggingContext.sentinel:
logger.warning("Expected logging context %s was lost", self.new_context)
else:
logger.warning(
"Expected logging context %s but found %s",
self.new_context,
context,
)
logger.warn(
"Unexpected logging context: %s is not %s",
context, self.new_context,
)
if self.current_context is not LoggingContext.sentinel:
if not self.current_context.alive:

View File

@@ -24,12 +24,7 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.crypto import keyring
from synapse.crypto.keyring import (
KeyLookupError,
PerspectivesKeyFetcher,
ServerKeyFetcher,
)
from synapse.storage.keys import FetchKeyResult
from synapse.crypto.keyring import KeyLookupError
from synapse.util import logcontext
from synapse.util.logcontext import LoggingContext
@@ -55,11 +50,11 @@ class MockPerspectiveServer(object):
key_id: {"key": signedjson.key.encode_verify_key_base64(verify_key)}
},
}
self.sign_response(res)
return res
return self.get_signed_response(res)
def sign_response(self, res):
def get_signed_response(self, res):
signedjson.sign.sign_json(res, self.server_name, self.key)
return res
class KeyringTestCase(unittest.HomeserverTestCase):
@@ -85,7 +80,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
# we run the lookup in a logcontext so that the patched inlineCallbacks can check
# it is doing the right thing with logcontexts.
wait_1_deferred = run_in_context(
kr.wait_for_previous_lookups, {"server1": lookup_1_deferred}
kr.wait_for_previous_lookups, ["server1"], {"server1": lookup_1_deferred}
)
# there were no previous lookups, so the deferred should be ready
@@ -94,7 +89,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
# set off another wait. It should block because the first lookup
# hasn't yet completed.
wait_2_deferred = run_in_context(
kr.wait_for_previous_lookups, {"server1": lookup_2_deferred}
kr.wait_for_previous_lookups, ["server1"], {"server1": lookup_2_deferred}
)
self.assertFalse(wait_2_deferred.called)
@@ -197,18 +192,8 @@ class KeyringTestCase(unittest.HomeserverTestCase):
kr = keyring.Keyring(self.hs)
key1 = signedjson.key.generate_signing_key(1)
key1_id = "%s:%s" % (key1.alg, key1.version)
r = self.hs.datastore.store_server_verify_keys(
"server9",
time.time() * 1000,
[
(
"server9",
key1_id,
FetchKeyResult(signedjson.key.get_verify_key(key1), 1000),
),
],
r = self.hs.datastore.store_server_verify_key(
"server9", "", time.time() * 1000, signedjson.key.get_verify_key(key1)
)
self.get_success(r)
json1 = {}
@@ -222,23 +207,16 @@ class KeyringTestCase(unittest.HomeserverTestCase):
self.assertFalse(d.called)
self.get_success(d)
class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
self.http_client = Mock()
hs = self.setup_test_homeserver(handlers=None, http_client=self.http_client)
return hs
def test_get_keys_from_server(self):
# arbitrarily advance the clock a bit
self.reactor.advance(100)
SERVER_NAME = "server2"
fetcher = ServerKeyFetcher(self.hs)
kr = keyring.Keyring(self.hs)
testkey = signedjson.key.generate_signing_key("ver1")
testverifykey = signedjson.key.get_verify_key(testkey)
testverifykey_id = "ed25519:ver1"
VALID_UNTIL_TS = 200 * 1000
VALID_UNTIL_TS = 1000
# valid response
response = {
@@ -261,12 +239,11 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
self.http_client.get_json.side_effect = get_json
server_name_and_key_ids = [(SERVER_NAME, ("key1",))]
keys = self.get_success(fetcher.get_keys(server_name_and_key_ids))
keys = self.get_success(kr.get_keys_from_server(server_name_and_key_ids))
k = keys[SERVER_NAME][testverifykey_id]
self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
self.assertEqual(k.verify_key, testverifykey)
self.assertEqual(k.verify_key.alg, "ed25519")
self.assertEqual(k.verify_key.version, "ver1")
self.assertEqual(k, testverifykey)
self.assertEqual(k.alg, "ed25519")
self.assertEqual(k.version, "ver1")
# check that the perspectives store is correctly updated
lookup_triplet = (SERVER_NAME, testverifykey_id, None)
@@ -289,26 +266,15 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
# change the server name: it should cause a rejection
response["server_name"] = "OTHER_SERVER"
self.get_failure(
fetcher.get_keys(server_name_and_key_ids), KeyLookupError
kr.get_keys_from_server(server_name_and_key_ids), KeyLookupError
)
class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
self.mock_perspective_server = MockPerspectiveServer()
self.http_client = Mock()
hs = self.setup_test_homeserver(handlers=None, http_client=self.http_client)
keys = self.mock_perspective_server.get_verify_keys()
hs.config.perspectives = {self.mock_perspective_server.server_name: keys}
return hs
def test_get_keys_from_perspectives(self):
# arbitrarily advance the clock a bit
self.reactor.advance(100)
fetcher = PerspectivesKeyFetcher(self.hs)
SERVER_NAME = "server2"
kr = keyring.Keyring(self.hs)
testkey = signedjson.key.generate_signing_key("ver1")
testverifykey = signedjson.key.get_verify_key(testkey)
testverifykey_id = "ed25519:ver1"
@@ -326,10 +292,9 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
},
}
# the response must be signed by both the origin server and the perspectives
# server.
signedjson.sign.sign_json(response, SERVER_NAME, testkey)
self.mock_perspective_server.sign_response(response)
persp_resp = {
"server_keys": [self.mock_perspective_server.get_signed_response(response)]
}
def post_json(destination, path, data, **kwargs):
self.assertEqual(destination, self.mock_perspective_server.server_name)
@@ -338,18 +303,17 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
# check that the request is for the expected key
q = data["server_keys"]
self.assertEqual(list(q[SERVER_NAME].keys()), ["key1"])
return {"server_keys": [response]}
return persp_resp
self.http_client.post_json.side_effect = post_json
server_name_and_key_ids = [(SERVER_NAME, ("key1",))]
keys = self.get_success(fetcher.get_keys(server_name_and_key_ids))
keys = self.get_success(kr.get_keys_from_perspectives(server_name_and_key_ids))
self.assertIn(SERVER_NAME, keys)
k = keys[SERVER_NAME][testverifykey_id]
self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
self.assertEqual(k.verify_key, testverifykey)
self.assertEqual(k.verify_key.alg, "ed25519")
self.assertEqual(k.verify_key.version, "ver1")
self.assertEqual(k, testverifykey)
self.assertEqual(k.alg, "ed25519")
self.assertEqual(k.version, "ver1")
# check that the perspectives store is correctly updated
lookup_triplet = (SERVER_NAME, testverifykey_id, None)
@@ -366,81 +330,13 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
self.assertEqual(
bytes(res["key_json"]),
canonicaljson.encode_canonical_json(response),
canonicaljson.encode_canonical_json(persp_resp["server_keys"][0]),
)
def test_invalid_perspectives_responses(self):
"""Check that invalid responses from the perspectives server are rejected"""
# arbitrarily advance the clock a bit
self.reactor.advance(100)
SERVER_NAME = "server2"
testkey = signedjson.key.generate_signing_key("ver1")
testverifykey = signedjson.key.get_verify_key(testkey)
testverifykey_id = "ed25519:ver1"
VALID_UNTIL_TS = 200 * 1000
def build_response():
# valid response
response = {
"server_name": SERVER_NAME,
"old_verify_keys": {},
"valid_until_ts": VALID_UNTIL_TS,
"verify_keys": {
testverifykey_id: {
"key": signedjson.key.encode_verify_key_base64(testverifykey)
}
},
}
# the response must be signed by both the origin server and the perspectives
# server.
signedjson.sign.sign_json(response, SERVER_NAME, testkey)
self.mock_perspective_server.sign_response(response)
return response
def get_key_from_perspectives(response):
fetcher = PerspectivesKeyFetcher(self.hs)
server_name_and_key_ids = [(SERVER_NAME, ("key1",))]
def post_json(destination, path, data, **kwargs):
self.assertEqual(destination, self.mock_perspective_server.server_name)
self.assertEqual(path, "/_matrix/key/v2/query")
return {"server_keys": [response]}
self.http_client.post_json.side_effect = post_json
return self.get_success(
fetcher.get_keys(server_name_and_key_ids)
)
# start with a valid response so we can check we are testing the right thing
response = build_response()
keys = get_key_from_perspectives(response)
k = keys[SERVER_NAME][testverifykey_id]
self.assertEqual(k.verify_key, testverifykey)
# remove the perspectives server's signature
response = build_response()
del response["signatures"][self.mock_perspective_server.server_name]
self.http_client.post_json.return_value = {"server_keys": [response]}
keys = get_key_from_perspectives(response)
self.assertEqual(keys, {}, "Expected empty dict with missing persp server sig")
# remove the origin server's signature
response = build_response()
del response["signatures"][SERVER_NAME]
self.http_client.post_json.return_value = {"server_keys": [response]}
keys = get_key_from_perspectives(response)
self.assertEqual(keys, {}, "Expected empty dict with missing origin server sig")
@defer.inlineCallbacks
def run_in_context(f, *args, **kwargs):
with LoggingContext("testctx") as ctx:
# we set the "request" prop to make it easier to follow what's going on in the
# logs.
ctx.request = "testctx"
with LoggingContext("testctx"):
rv = yield f(*args, **kwargs)
defer.returnValue(rv)

View File

@@ -1,90 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2019 Matrix.org Foundation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.config.ratelimiting import FederationRateLimitConfig
from synapse.federation.transport import server
from synapse.rest import admin
from synapse.rest.client.v1 import login, room
from synapse.util.ratelimitutils import FederationRateLimiter
from tests import unittest
class RoomComplexityTests(unittest.HomeserverTestCase):
servlets = [
admin.register_servlets,
room.register_servlets,
login.register_servlets,
]
def default_config(self, name='test'):
config = super(RoomComplexityTests, self).default_config(name=name)
config["limit_large_remote_room_joins"] = True
config["limit_large_remote_room_complexity"] = 0.05
return config
def prepare(self, reactor, clock, homeserver):
class Authenticator(object):
def authenticate_request(self, request, content):
return defer.succeed("otherserver.nottld")
ratelimiter = FederationRateLimiter(
clock,
FederationRateLimitConfig(
window_size=1,
sleep_limit=1,
sleep_msec=1,
reject_limit=1000,
concurrent_requests=1000,
),
)
server.register_servlets(
homeserver, self.resource, Authenticator(), ratelimiter
)
def test_complexity_simple(self):
u1 = self.register_user("u1", "pass")
u1_token = self.login("u1", "pass")
room_1 = self.helper.create_room_as(u1, tok=u1_token)
self.helper.send_state(
room_1, event_type="m.room.topic", body={"topic": "foo"}, tok=u1_token
)
# Get the room complexity
request, channel = self.make_request(
"GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,)
)
self.render(request)
self.assertEquals(200, channel.code)
complexity = channel.json_body["v1"]
self.assertTrue(complexity > 0, complexity)
# Artificially raise the complexity
store = self.hs.get_datastore()
store.get_current_state_event_counts = lambda x: defer.succeed(500 * 1.23)
# Get the room complexity again -- make sure it's our artificial value
request, channel = self.make_request(
"GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,)
)
self.render(request)
self.assertEquals(200, channel.code)
complexity = channel.json_body["v1"]
self.assertEqual(complexity, 1.23)

View File

@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import synapse.rest.admin
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.api.room_versions import DEFAULT_ROOM_VERSION, KNOWN_ROOM_VERSIONS
from synapse.rest.client.v1 import login
from synapse.rest.client.v2_alpha import capabilities
@@ -32,7 +32,6 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
self.url = b"/_matrix/client/r0/capabilities"
hs = self.setup_test_homeserver()
self.store = hs.get_datastore()
self.config = hs.config
return hs
def test_check_auth_required(self):
@@ -52,10 +51,8 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200)
for room_version in capabilities['m.room_versions']['available'].keys():
self.assertTrue(room_version in KNOWN_ROOM_VERSIONS, "" + room_version)
self.assertEqual(
self.config.default_room_version.identifier,
capabilities['m.room_versions']['default'],
DEFAULT_ROOM_VERSION.identifier, capabilities['m.room_versions']['default']
)
def test_get_change_password_capabilities(self):

View File

@@ -17,8 +17,6 @@ import signedjson.key
from twisted.internet.defer import Deferred
from synapse.storage.keys import FetchKeyResult
import tests.unittest
KEY_1 = signedjson.key.decode_verify_key_base64(
@@ -33,34 +31,23 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
def test_get_server_verify_keys(self):
store = self.hs.get_datastore()
key_id_1 = "ed25519:key1"
key_id_2 = "ed25519:KEY_ID_2"
d = store.store_server_verify_keys(
"from_server",
10,
[
("server1", key_id_1, FetchKeyResult(KEY_1, 100)),
("server1", key_id_2, FetchKeyResult(KEY_2, 200)),
],
)
d = store.store_server_verify_key("server1", "from_server", 0, KEY_1)
self.get_success(d)
d = store.store_server_verify_key("server1", "from_server", 0, KEY_2)
self.get_success(d)
d = store.get_server_verify_keys(
[("server1", key_id_1), ("server1", key_id_2), ("server1", "ed25519:key3")]
[
("server1", "ed25519:key1"),
("server1", "ed25519:key2"),
("server1", "ed25519:key3"),
]
)
res = self.get_success(d)
self.assertEqual(len(res.keys()), 3)
res1 = res[("server1", key_id_1)]
self.assertEqual(res1.verify_key, KEY_1)
self.assertEqual(res1.verify_key.version, "key1")
self.assertEqual(res1.valid_until_ts, 100)
res2 = res[("server1", key_id_2)]
self.assertEqual(res2.verify_key, KEY_2)
# version comes from the ID it was stored with
self.assertEqual(res2.verify_key.version, "KEY_ID_2")
self.assertEqual(res2.valid_until_ts, 200)
self.assertEqual(res[("server1", "ed25519:key1")].version, "key1")
self.assertEqual(res[("server1", "ed25519:key2")].version, "key2")
# non-existent result gives None
self.assertIsNone(res[("server1", "ed25519:key3")])
@@ -73,51 +60,32 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
key_id_1 = "ed25519:key1"
key_id_2 = "ed25519:key2"
d = store.store_server_verify_keys(
"from_server",
0,
[
("srv1", key_id_1, FetchKeyResult(KEY_1, 100)),
("srv1", key_id_2, FetchKeyResult(KEY_2, 200)),
],
)
d = store.store_server_verify_key("srv1", "from_server", 0, KEY_1)
self.get_success(d)
d = store.store_server_verify_key("srv1", "from_server", 0, KEY_2)
self.get_success(d)
d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
res = self.get_success(d)
self.assertEqual(len(res.keys()), 2)
res1 = res[("srv1", key_id_1)]
self.assertEqual(res1.verify_key, KEY_1)
self.assertEqual(res1.valid_until_ts, 100)
res2 = res[("srv1", key_id_2)]
self.assertEqual(res2.verify_key, KEY_2)
self.assertEqual(res2.valid_until_ts, 200)
self.assertEqual(res[("srv1", key_id_1)], KEY_1)
self.assertEqual(res[("srv1", key_id_2)], KEY_2)
# we should be able to look up the same thing again without a db hit
res = store.get_server_verify_keys([("srv1", key_id_1)])
if isinstance(res, Deferred):
res = self.successResultOf(res)
self.assertEqual(len(res.keys()), 1)
self.assertEqual(res[("srv1", key_id_1)].verify_key, KEY_1)
self.assertEqual(res[("srv1", key_id_1)], KEY_1)
new_key_2 = signedjson.key.get_verify_key(
signedjson.key.generate_signing_key("key2")
)
d = store.store_server_verify_keys(
"from_server", 10, [("srv1", key_id_2, FetchKeyResult(new_key_2, 300))]
)
d = store.store_server_verify_key("srv1", "from_server", 10, new_key_2)
self.get_success(d)
d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
res = self.get_success(d)
self.assertEqual(len(res.keys()), 2)
res1 = res[("srv1", key_id_1)]
self.assertEqual(res1.verify_key, KEY_1)
self.assertEqual(res1.valid_until_ts, 100)
res2 = res[("srv1", key_id_2)]
self.assertEqual(res2.verify_key, new_key_2)
self.assertEqual(res2.valid_until_ts, 300)
self.assertEqual(res[("srv1", key_id_1)], KEY_1)
self.assertEqual(res[("srv1", key_id_2)], new_key_2)