Merge commit 'a973bcb8a' into anoa/dinsic_release_1_18_x
* commit 'a973bcb8a':
  Add some tiny type annotations (#7870)
  Remove obsolete comment.
  Ensure that calls to `json.dumps` are compatible with the standard library json. (#7836)
  Avoid brand new rooms in `delete_old_current_state_events` (#7854)
  Allow accounts to be re-activated from the admin APIs. (#7847)
  Fix tests
  Fix typo
  Newsfile
  Use get_users_in_room rather than state handler in typing for speed
  Fix client reader sharding tests (#7853)
  Convert E2E key and room key handlers to async/await. (#7851)
  Return the proper 403 Forbidden error during errors with JWT logins. (#7844)
  remove `retry_on_integrity_error` wrapper for persist_events (#7848)

changelog.d/7836.misc (new file)
@@ -0,0 +1 @@
+Ensure that calls to `json.dumps` are compatible with the standard library json.

changelog.d/7844.bugfix (new file)
@@ -0,0 +1 @@
+Errors which occur while using the non-standard JWT login now return the proper error: `403 Forbidden` with an error code of `M_FORBIDDEN`.

changelog.d/7847.feature (new file)
@@ -0,0 +1 @@
+Add the ability to re-activate an account from the admin API.

changelog.d/7848.misc (new file)
@@ -0,0 +1 @@
+Remove redundant `retry_on_integrity_error` wrapper for event persistence code.

changelog.d/7851.misc (new file)
@@ -0,0 +1 @@
+Convert E2E keys and room keys handlers to async/await.

changelog.d/7853.misc (new file)
@@ -0,0 +1 @@
+Add support for handling registration requests across multiple client reader workers.

changelog.d/7854.bugfix (new file)
@@ -0,0 +1 @@
+Fix a bug introduced in Synapse 1.10.0 which could cause a "no create event in auth events" error during room creation.

changelog.d/7856.misc (new file)
@@ -0,0 +1 @@
+Small performance improvement in typing processing.

changelog.d/7870.misc (new file)
@@ -0,0 +1 @@
+Add some type annotations to `HomeServer` and `BaseHandler`.
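The `7851.misc` entry above covers a mechanical conversion that repeats throughout the handler diffs below: `@defer.inlineCallbacks` generators become `async def` coroutines and each `yield` becomes an `await`. A minimal, self-contained sketch of the pattern follows; the `FakeStore`, `fetch_keys`, and user ID are made-up illustrations, not Synapse APIs.

```python
from twisted.internet import defer, task


class FakeStore:
    """Stand-in for a storage layer; returns an already-fired Deferred."""

    def get_keys(self, user_id):
        return defer.succeed({"user_id": user_id, "keys": ["k1", "k2"]})


class BeforeHandler:
    store = FakeStore()

    @defer.inlineCallbacks
    def fetch_keys(self, user_id):
        # old style: generator function, yield on each Deferred
        keys = yield self.store.get_keys(user_id)
        return keys


class AfterHandler:
    store = FakeStore()

    async def fetch_keys(self, user_id):
        # new style: native coroutine; Twisted Deferreds can be awaited directly
        keys = await self.store.get_keys(user_id)
        return keys


def main(reactor):
    # ensureDeferred bridges the async def back into Deferred-land for callers
    d1 = BeforeHandler().fetch_keys("@alice:example.com")
    d2 = defer.ensureDeferred(AfterHandler().fetch_keys("@alice:example.com"))
    return defer.gatherResults([d1, d2]).addCallback(print)


if __name__ == "__main__":
    task.react(main)
```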
@@ -91,10 +91,14 @@ Body parameters:

- ``admin``, optional, defaults to ``false``.

-- ``deactivated``, optional, defaults to ``false``.
+- ``deactivated``, optional. If unspecified, deactivation state will be left
+  unchanged on existing accounts and set to ``false`` for new accounts.

If the user already exists then optional parameters default to the current value.

+In order to re-activate an account ``deactivated`` must be set to ``false``. If
+users do not login via single-sign-on, a new ``password`` must be provided.

List Accounts
=============
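The re-activation flow documented above comes down to a single admin request. A hedged sketch of what that might look like, assuming the `PUT /_synapse/admin/v2/users/<user_id>` endpoint this file documents; the homeserver URL, token, and password are placeholders.

```python
# Sketch: re-activate a previously deactivated account via the admin API.
# Setting "deactivated": false plus a fresh password is the key part; the rest
# (URL, token, user ID) is illustrative only.
import requests

HOMESERVER = "https://matrix.example.com"   # hypothetical homeserver
ADMIN_TOKEN = "syt_admin_access_token"      # hypothetical admin access token
USER_ID = "@alice:example.com"

resp = requests.put(
    f"{HOMESERVER}/_synapse/admin/v2/users/{USER_ID}",
    headers={"Authorization": f"Bearer {ADMIN_TOKEN}"},
    json={
        "deactivated": False,          # re-activate the account
        "password": "a-new-password",  # required unless the user logs in via SSO
    },
)
resp.raise_for_status()
print(resp.json())
```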
@@ -31,10 +31,7 @@ The `token` field should include the JSON web token with the following claims:

Providing the audience claim when not configured will cause validation to fail.

In the case that the token is not valid, the homeserver must respond with
-`401 Unauthorized` and an error code of `M_UNAUTHORIZED`.
-
-(Note that this differs from the token based logins which return a
-`403 Forbidden` and an error code of `M_FORBIDDEN` if an error occurs.)
+`403 Forbidden` and an error code of `M_FORBIDDEN`.

As with other login types, there are additional fields (e.g. `device_id` and
`initial_device_display_name`) which can be included in the above request.
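To make the behavioural change concrete: a failed JWT login now yields `403 Forbidden` with `M_FORBIDDEN` rather than `401`/`M_UNAUTHORIZED`. A sketch of the client side, assuming the `org.matrix.login.jwt` login type and a shared-secret (HS256) configuration; the homeserver URL and secret are placeholders.

```python
# Sketch only: mint a token with PyJWT and post it to /login. On a validation
# failure (bad signature, expired token, missing "sub", ...) the server now
# responds 403 with errcode M_FORBIDDEN instead of 401/M_UNAUTHORIZED.
import time

import jwt        # PyJWT
import requests

HOMESERVER = "https://matrix.example.com"   # hypothetical
JWT_SECRET = "shared-secret"                # must match the homeserver config

token = jwt.encode(
    {"sub": "alice", "exp": int(time.time()) + 300},
    JWT_SECRET,
    algorithm="HS256",
)
if isinstance(token, bytes):  # PyJWT 1.x returns bytes, 2.x returns str
    token = token.decode("ascii")

resp = requests.post(
    f"{HOMESERVER}/_matrix/client/r0/login",
    json={"type": "org.matrix.login.jwt", "token": token},
)
if resp.status_code == 403:
    # e.g. {"errcode": "M_FORBIDDEN", "error": "JWT validation failed: ..."}
    print("login rejected:", resp.json())
else:
    print("logged in as", resp.json().get("user_id"))
```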
@@ -16,12 +16,14 @@
|
||||
# limitations under the License.
|
||||
|
||||
"""Contains exceptions and error codes."""
|
||||
import json
|
||||
|
||||
import logging
|
||||
import typing
|
||||
from http import HTTPStatus
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from canonicaljson import json
|
||||
|
||||
from twisted.web import http
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
|
||||
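This import swap (and the matching ones in later hunks) is what the `7836.misc` changelog entry refers to: call the standard library `json` directly rather than the module re-exported by `canonicaljson`. A small sketch of the two call-site details worth keeping in mind when doing that; the sample data is made up.

```python
# Sketch: using the standard library json module directly.
import json

payload = {"b": 2, "a": 1}

# 1. Stable output: ask for it explicitly (sort_keys plus compact separators)
#    instead of relying on whatever defaults the previous encoder happened to use.
print(json.dumps(payload, sort_keys=True, separators=(",", ":")))  # {"a":1,"b":2}

# 2. Decoding: several hunks below add an explicit .decode("utf-8") before
#    parsing bytes received off the wire; that keeps the call valid even on
#    interpreters where json.loads() only accepts str (pre-3.6).
raw = b'{"a": 1, "b": 2}'
print(json.loads(raw.decode("utf-8")))
```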
@@ -14,10 +14,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import json
import logging
from typing import Any, Callable, Dict, List, Match, Optional, Tuple, Union

-from canonicaljson import json
from prometheus_client import Counter, Histogram

from twisted.internet import defer
@@ -61,8 +61,6 @@ class TransactionManager(object):
        # all the edus in that transaction. This needs to be done since there is
        # no active span here, so if the edus were not received by the remote the
        # span would have no causality and it would be forgotten.
-        # The span_contexts is a generator so that it won't be evaluated if
-        # opentracing is disabled. (Yay speed!)

        span_contexts = []
        keep_destination = whitelisted_homeserver(destination)
@@ -17,6 +17,8 @@ import logging

from twisted.internet import defer

+import synapse.state
+import synapse.storage
import synapse.types
from synapse.api.constants import EventTypes, Membership
from synapse.api.ratelimiting import Ratelimiter

@@ -28,10 +30,6 @@ logger = logging.getLogger(__name__)
class BaseHandler(object):
    """
    Common base class for the event handlers.
-
-    Attributes:
-        store (synapse.storage.DataStore):
-        state_handler (synapse.state.StateHandler):
    """

    def __init__(self, hs):

@@ -39,10 +37,10 @@ class BaseHandler(object):
        Args:
            hs (synapse.server.HomeServer):
        """
-        self.store = hs.get_datastore()
+        self.store = hs.get_datastore()  # type: synapse.storage.DataStore
        self.auth = hs.get_auth()
        self.notifier = hs.get_notifier()
-        self.state_handler = hs.get_state_handler()
+        self.state_handler = hs.get_state_handler()  # type: synapse.state.StateHandler
        self.distributor = hs.get_distributor()
        self.clock = hs.get_clock()
        self.hs = hs
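The `# type:` comments above are the comment form of annotations; the same PR also adds a real return annotation to `HomeServer.get_datastore()` further down. A tiny standalone sketch of how the two styles look to a type checker; the classes here are placeholders, not Synapse's.

```python
# Sketch: comment-style vs inline annotations. mypy treats them the same, so
# hs.get_datastore() / self.store get a concrete type either way.
class DataStore:                            # placeholder for synapse.storage.DataStore
    def get_user(self, user_id: str) -> dict:
        return {"name": user_id}


class HomeServer:                           # placeholder
    def get_datastore(self) -> DataStore:   # inline return annotation
        return DataStore()


class BaseHandler:                          # placeholder
    def __init__(self, hs: HomeServer):
        self.store = hs.get_datastore()  # type: DataStore  # comment annotation
        # With either style, a typo such as self.store.get_usr(...) becomes a
        # mypy error instead of a silent AttributeError at runtime.


print(BaseHandler(HomeServer()).store.get_user("@alice:example.com"))
```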
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from typing import Optional

from synapse.api.errors import SynapseError
from synapse.metrics.background_process_metrics import run_as_background_process

@@ -46,19 +47,20 @@ class DeactivateAccountHandler(BaseHandler):

        self._account_validity_enabled = hs.config.account_validity.enabled

-    async def deactivate_account(self, user_id, erase_data, id_server=None):
+    async def deactivate_account(
+        self, user_id: str, erase_data: bool, id_server: Optional[str] = None
+    ) -> bool:
        """Deactivate a user's account

        Args:
-            user_id (str): ID of user to be deactivated
-            erase_data (bool): whether to GDPR-erase the user's data
-            id_server (str|None): Use the given identity server when unbinding
+            user_id: ID of user to be deactivated
+            erase_data: whether to GDPR-erase the user's data
+            id_server: Use the given identity server when unbinding
                any threepids. If None then will attempt to unbind using the
                identity server specified when binding (if known).

        Returns:
-            Deferred[bool]: True if identity server supports removing
-            threepids, otherwise False.
+            True if identity server supports removing threepids, otherwise False.
        """
        # FIXME: Theoretically there is a race here wherein user resets
        # password using threepid.

@@ -138,11 +140,11 @@ class DeactivateAccountHandler(BaseHandler):

        return identity_server_supports_unbinding

-    async def _reject_pending_invites_for_user(self, user_id):
+    async def _reject_pending_invites_for_user(self, user_id: str):
        """Reject pending invites addressed to a given user ID.

        Args:
-            user_id (str): The user ID to reject pending invites for.
+            user_id: The user ID to reject pending invites for.
        """
        user = UserID.from_string(user_id)
        pending_invites = await self.store.get_invited_rooms_for_local_user(user_id)

@@ -170,22 +172,16 @@ class DeactivateAccountHandler(BaseHandler):
                room.room_id,
            )

-    def _start_user_parting(self):
+    def _start_user_parting(self) -> None:
        """
        Start the process that goes through the table of users
        pending deactivation, if it isn't already running.
-
-        Returns:
-            None
        """
        if not self._user_parter_running:
            run_as_background_process("user_parter_loop", self._user_parter_loop)

-    async def _user_parter_loop(self):
+    async def _user_parter_loop(self) -> None:
        """Loop that parts deactivated users from rooms
-
-        Returns:
-            None
        """
        self._user_parter_running = True
        logger.info("Starting user parter")

@@ -202,11 +198,8 @@ class DeactivateAccountHandler(BaseHandler):
        finally:
            self._user_parter_running = False

-    async def _part_user(self, user_id):
+    async def _part_user(self, user_id: str) -> None:
        """Causes the given user_id to leave all the rooms they're joined to
-
-        Returns:
-            None
        """
        user = UserID.from_string(user_id)

@@ -228,3 +221,18 @@ class DeactivateAccountHandler(BaseHandler):
                user_id,
                room_id,
            )
+
+    async def activate_account(self, user_id: str) -> None:
+        """
+        Activate an account that was previously deactivated.
+
+        This simply marks the user as activate in the database and does not
+        attempt to rejoin rooms, re-add threepids, etc.
+
+        The user will also need a password hash set to actually login.
+
+        Args:
+            user_id: ID of user to be deactivated
+        """
+        # Mark the user as activate.
+        await self.store.set_user_deactivated_status(user_id, False)
@@ -77,8 +77,7 @@ class E2eKeysHandler(object):
|
||||
)
|
||||
|
||||
@trace
|
||||
@defer.inlineCallbacks
|
||||
def query_devices(self, query_body, timeout, from_user_id):
|
||||
async def query_devices(self, query_body, timeout, from_user_id):
|
||||
""" Handle a device key query from a client
|
||||
|
||||
{
|
||||
@@ -124,7 +123,7 @@ class E2eKeysHandler(object):
|
||||
failures = {}
|
||||
results = {}
|
||||
if local_query:
|
||||
local_result = yield self.query_local_devices(local_query)
|
||||
local_result = await self.query_local_devices(local_query)
|
||||
for user_id, keys in local_result.items():
|
||||
if user_id in local_query:
|
||||
results[user_id] = keys
|
||||
@@ -142,7 +141,7 @@ class E2eKeysHandler(object):
|
||||
(
|
||||
user_ids_not_in_cache,
|
||||
remote_results,
|
||||
) = yield self.store.get_user_devices_from_cache(query_list)
|
||||
) = await self.store.get_user_devices_from_cache(query_list)
|
||||
for user_id, devices in remote_results.items():
|
||||
user_devices = results.setdefault(user_id, {})
|
||||
for device_id, device in devices.items():
|
||||
@@ -161,14 +160,13 @@ class E2eKeysHandler(object):
|
||||
r[user_id] = remote_queries[user_id]
|
||||
|
||||
# Get cached cross-signing keys
|
||||
cross_signing_keys = yield self.get_cross_signing_keys_from_cache(
|
||||
cross_signing_keys = await self.get_cross_signing_keys_from_cache(
|
||||
device_keys_query, from_user_id
|
||||
)
|
||||
|
||||
# Now fetch any devices that we don't have in our cache
|
||||
@trace
|
||||
@defer.inlineCallbacks
|
||||
def do_remote_query(destination):
|
||||
async def do_remote_query(destination):
|
||||
"""This is called when we are querying the device list of a user on
|
||||
a remote homeserver and their device list is not in the device list
|
||||
cache. If we share a room with this user and we're not querying for
|
||||
@@ -192,7 +190,7 @@ class E2eKeysHandler(object):
|
||||
if device_list:
|
||||
continue
|
||||
|
||||
room_ids = yield self.store.get_rooms_for_user(user_id)
|
||||
room_ids = await self.store.get_rooms_for_user(user_id)
|
||||
if not room_ids:
|
||||
continue
|
||||
|
||||
@@ -201,11 +199,11 @@ class E2eKeysHandler(object):
|
||||
# done an initial sync on the device list so we do it now.
|
||||
try:
|
||||
if self._is_master:
|
||||
user_devices = yield self.device_handler.device_list_updater.user_device_resync(
|
||||
user_devices = await self.device_handler.device_list_updater.user_device_resync(
|
||||
user_id
|
||||
)
|
||||
else:
|
||||
user_devices = yield self._user_device_resync_client(
|
||||
user_devices = await self._user_device_resync_client(
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
@@ -227,7 +225,7 @@ class E2eKeysHandler(object):
|
||||
destination_query.pop(user_id)
|
||||
|
||||
try:
|
||||
remote_result = yield self.federation.query_client_keys(
|
||||
remote_result = await self.federation.query_client_keys(
|
||||
destination, {"device_keys": destination_query}, timeout=timeout
|
||||
)
|
||||
|
||||
@@ -251,7 +249,7 @@ class E2eKeysHandler(object):
|
||||
set_tag("error", True)
|
||||
set_tag("reason", failure)
|
||||
|
||||
yield make_deferred_yieldable(
|
||||
await make_deferred_yieldable(
|
||||
defer.gatherResults(
|
||||
[
|
||||
run_in_background(do_remote_query, destination)
|
||||
@@ -267,8 +265,7 @@ class E2eKeysHandler(object):
|
||||
|
||||
return ret
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_cross_signing_keys_from_cache(self, query, from_user_id):
|
||||
async def get_cross_signing_keys_from_cache(self, query, from_user_id):
|
||||
"""Get cross-signing keys for users from the database
|
||||
|
||||
Args:
|
||||
@@ -289,7 +286,7 @@ class E2eKeysHandler(object):
|
||||
|
||||
user_ids = list(query)
|
||||
|
||||
keys = yield self.store.get_e2e_cross_signing_keys_bulk(user_ids, from_user_id)
|
||||
keys = await self.store.get_e2e_cross_signing_keys_bulk(user_ids, from_user_id)
|
||||
|
||||
for user_id, user_info in keys.items():
|
||||
if user_info is None:
|
||||
@@ -315,8 +312,7 @@ class E2eKeysHandler(object):
|
||||
}
|
||||
|
||||
@trace
|
||||
@defer.inlineCallbacks
|
||||
def query_local_devices(self, query):
|
||||
async def query_local_devices(self, query):
|
||||
"""Get E2E device keys for local users
|
||||
|
||||
Args:
|
||||
@@ -354,7 +350,7 @@ class E2eKeysHandler(object):
|
||||
# make sure that each queried user appears in the result dict
|
||||
result_dict[user_id] = {}
|
||||
|
||||
results = yield self.store.get_e2e_device_keys(local_query)
|
||||
results = await self.store.get_e2e_device_keys(local_query)
|
||||
|
||||
# Build the result structure
|
||||
for user_id, device_keys in results.items():
|
||||
@@ -364,16 +360,15 @@ class E2eKeysHandler(object):
|
||||
log_kv(results)
|
||||
return result_dict
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_federation_query_client_keys(self, query_body):
|
||||
async def on_federation_query_client_keys(self, query_body):
|
||||
""" Handle a device key query from a federated server
|
||||
"""
|
||||
device_keys_query = query_body.get("device_keys", {})
|
||||
res = yield self.query_local_devices(device_keys_query)
|
||||
res = await self.query_local_devices(device_keys_query)
|
||||
ret = {"device_keys": res}
|
||||
|
||||
# add in the cross-signing keys
|
||||
cross_signing_keys = yield self.get_cross_signing_keys_from_cache(
|
||||
cross_signing_keys = await self.get_cross_signing_keys_from_cache(
|
||||
device_keys_query, None
|
||||
)
|
||||
|
||||
@@ -382,8 +377,7 @@ class E2eKeysHandler(object):
|
||||
return ret
|
||||
|
||||
@trace
|
||||
@defer.inlineCallbacks
|
||||
def claim_one_time_keys(self, query, timeout):
|
||||
async def claim_one_time_keys(self, query, timeout):
|
||||
local_query = []
|
||||
remote_queries = {}
|
||||
|
||||
@@ -399,7 +393,7 @@ class E2eKeysHandler(object):
|
||||
set_tag("local_key_query", local_query)
|
||||
set_tag("remote_key_query", remote_queries)
|
||||
|
||||
results = yield self.store.claim_e2e_one_time_keys(local_query)
|
||||
results = await self.store.claim_e2e_one_time_keys(local_query)
|
||||
|
||||
json_result = {}
|
||||
failures = {}
|
||||
@@ -411,12 +405,11 @@ class E2eKeysHandler(object):
|
||||
}
|
||||
|
||||
@trace
|
||||
@defer.inlineCallbacks
|
||||
def claim_client_keys(destination):
|
||||
async def claim_client_keys(destination):
|
||||
set_tag("destination", destination)
|
||||
device_keys = remote_queries[destination]
|
||||
try:
|
||||
remote_result = yield self.federation.claim_client_keys(
|
||||
remote_result = await self.federation.claim_client_keys(
|
||||
destination, {"one_time_keys": device_keys}, timeout=timeout
|
||||
)
|
||||
for user_id, keys in remote_result["one_time_keys"].items():
|
||||
@@ -429,7 +422,7 @@ class E2eKeysHandler(object):
|
||||
set_tag("error", True)
|
||||
set_tag("reason", failure)
|
||||
|
||||
yield make_deferred_yieldable(
|
||||
await make_deferred_yieldable(
|
||||
defer.gatherResults(
|
||||
[
|
||||
run_in_background(claim_client_keys, destination)
|
||||
@@ -454,9 +447,8 @@ class E2eKeysHandler(object):
|
||||
log_kv({"one_time_keys": json_result, "failures": failures})
|
||||
return {"one_time_keys": json_result, "failures": failures}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@tag_args
|
||||
def upload_keys_for_user(self, user_id, device_id, keys):
|
||||
async def upload_keys_for_user(self, user_id, device_id, keys):
|
||||
|
||||
time_now = self.clock.time_msec()
|
||||
|
||||
@@ -477,12 +469,12 @@ class E2eKeysHandler(object):
|
||||
}
|
||||
)
|
||||
# TODO: Sign the JSON with the server key
|
||||
changed = yield self.store.set_e2e_device_keys(
|
||||
changed = await self.store.set_e2e_device_keys(
|
||||
user_id, device_id, time_now, device_keys
|
||||
)
|
||||
if changed:
|
||||
# Only notify about device updates *if* the keys actually changed
|
||||
yield self.device_handler.notify_device_update(user_id, [device_id])
|
||||
await self.device_handler.notify_device_update(user_id, [device_id])
|
||||
else:
|
||||
log_kv({"message": "Not updating device_keys for user", "user_id": user_id})
|
||||
one_time_keys = keys.get("one_time_keys", None)
|
||||
@@ -494,7 +486,7 @@ class E2eKeysHandler(object):
|
||||
"device_id": device_id,
|
||||
}
|
||||
)
|
||||
yield self._upload_one_time_keys_for_user(
|
||||
await self._upload_one_time_keys_for_user(
|
||||
user_id, device_id, time_now, one_time_keys
|
||||
)
|
||||
else:
|
||||
@@ -507,15 +499,14 @@ class E2eKeysHandler(object):
|
||||
# old access_token without an associated device_id. Either way, we
|
||||
# need to double-check the device is registered to avoid ending up with
|
||||
# keys without a corresponding device.
|
||||
yield self.device_handler.check_device_registered(user_id, device_id)
|
||||
await self.device_handler.check_device_registered(user_id, device_id)
|
||||
|
||||
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
|
||||
result = await self.store.count_e2e_one_time_keys(user_id, device_id)
|
||||
|
||||
set_tag("one_time_key_counts", result)
|
||||
return {"one_time_key_counts": result}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _upload_one_time_keys_for_user(
|
||||
async def _upload_one_time_keys_for_user(
|
||||
self, user_id, device_id, time_now, one_time_keys
|
||||
):
|
||||
logger.info(
|
||||
@@ -533,7 +524,7 @@ class E2eKeysHandler(object):
|
||||
key_list.append((algorithm, key_id, key_obj))
|
||||
|
||||
# First we check if we have already persisted any of the keys.
|
||||
existing_key_map = yield self.store.get_e2e_one_time_keys(
|
||||
existing_key_map = await self.store.get_e2e_one_time_keys(
|
||||
user_id, device_id, [k_id for _, k_id, _ in key_list]
|
||||
)
|
||||
|
||||
@@ -556,10 +547,9 @@ class E2eKeysHandler(object):
|
||||
)
|
||||
|
||||
log_kv({"message": "Inserting new one_time_keys.", "keys": new_keys})
|
||||
yield self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys)
|
||||
await self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def upload_signing_keys_for_user(self, user_id, keys):
|
||||
async def upload_signing_keys_for_user(self, user_id, keys):
|
||||
"""Upload signing keys for cross-signing
|
||||
|
||||
Args:
|
||||
@@ -574,7 +564,7 @@ class E2eKeysHandler(object):
|
||||
|
||||
_check_cross_signing_key(master_key, user_id, "master")
|
||||
else:
|
||||
master_key = yield self.store.get_e2e_cross_signing_key(user_id, "master")
|
||||
master_key = await self.store.get_e2e_cross_signing_key(user_id, "master")
|
||||
|
||||
# if there is no master key, then we can't do anything, because all the
|
||||
# other cross-signing keys need to be signed by the master key
|
||||
@@ -613,10 +603,10 @@ class E2eKeysHandler(object):
|
||||
# if everything checks out, then store the keys and send notifications
|
||||
deviceids = []
|
||||
if "master_key" in keys:
|
||||
yield self.store.set_e2e_cross_signing_key(user_id, "master", master_key)
|
||||
await self.store.set_e2e_cross_signing_key(user_id, "master", master_key)
|
||||
deviceids.append(master_verify_key.version)
|
||||
if "self_signing_key" in keys:
|
||||
yield self.store.set_e2e_cross_signing_key(
|
||||
await self.store.set_e2e_cross_signing_key(
|
||||
user_id, "self_signing", self_signing_key
|
||||
)
|
||||
try:
|
||||
@@ -626,23 +616,22 @@ class E2eKeysHandler(object):
|
||||
except ValueError:
|
||||
raise SynapseError(400, "Invalid self-signing key", Codes.INVALID_PARAM)
|
||||
if "user_signing_key" in keys:
|
||||
yield self.store.set_e2e_cross_signing_key(
|
||||
await self.store.set_e2e_cross_signing_key(
|
||||
user_id, "user_signing", user_signing_key
|
||||
)
|
||||
# the signature stream matches the semantics that we want for
|
||||
# user-signing key updates: only the user themselves is notified of
|
||||
# their own user-signing key updates
|
||||
yield self.device_handler.notify_user_signature_update(user_id, [user_id])
|
||||
await self.device_handler.notify_user_signature_update(user_id, [user_id])
|
||||
|
||||
# master key and self-signing key updates match the semantics of device
|
||||
# list updates: all users who share an encrypted room are notified
|
||||
if len(deviceids):
|
||||
yield self.device_handler.notify_device_update(user_id, deviceids)
|
||||
await self.device_handler.notify_device_update(user_id, deviceids)
|
||||
|
||||
return {}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def upload_signatures_for_device_keys(self, user_id, signatures):
|
||||
async def upload_signatures_for_device_keys(self, user_id, signatures):
|
||||
"""Upload device signatures for cross-signing
|
||||
|
||||
Args:
|
||||
@@ -667,13 +656,13 @@ class E2eKeysHandler(object):
|
||||
self_signatures = signatures.get(user_id, {})
|
||||
other_signatures = {k: v for k, v in signatures.items() if k != user_id}
|
||||
|
||||
self_signature_list, self_failures = yield self._process_self_signatures(
|
||||
self_signature_list, self_failures = await self._process_self_signatures(
|
||||
user_id, self_signatures
|
||||
)
|
||||
signature_list.extend(self_signature_list)
|
||||
failures.update(self_failures)
|
||||
|
||||
other_signature_list, other_failures = yield self._process_other_signatures(
|
||||
other_signature_list, other_failures = await self._process_other_signatures(
|
||||
user_id, other_signatures
|
||||
)
|
||||
signature_list.extend(other_signature_list)
|
||||
@@ -681,21 +670,20 @@ class E2eKeysHandler(object):
|
||||
|
||||
# store the signature, and send the appropriate notifications for sync
|
||||
logger.debug("upload signature failures: %r", failures)
|
||||
yield self.store.store_e2e_cross_signing_signatures(user_id, signature_list)
|
||||
await self.store.store_e2e_cross_signing_signatures(user_id, signature_list)
|
||||
|
||||
self_device_ids = [item.target_device_id for item in self_signature_list]
|
||||
if self_device_ids:
|
||||
yield self.device_handler.notify_device_update(user_id, self_device_ids)
|
||||
await self.device_handler.notify_device_update(user_id, self_device_ids)
|
||||
signed_users = [item.target_user_id for item in other_signature_list]
|
||||
if signed_users:
|
||||
yield self.device_handler.notify_user_signature_update(
|
||||
await self.device_handler.notify_user_signature_update(
|
||||
user_id, signed_users
|
||||
)
|
||||
|
||||
return {"failures": failures}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _process_self_signatures(self, user_id, signatures):
|
||||
async def _process_self_signatures(self, user_id, signatures):
|
||||
"""Process uploaded signatures of the user's own keys.
|
||||
|
||||
Signatures of the user's own keys from this API come in two forms:
|
||||
@@ -728,7 +716,7 @@ class E2eKeysHandler(object):
|
||||
_,
|
||||
self_signing_key_id,
|
||||
self_signing_verify_key,
|
||||
) = yield self._get_e2e_cross_signing_verify_key(user_id, "self_signing")
|
||||
) = await self._get_e2e_cross_signing_verify_key(user_id, "self_signing")
|
||||
|
||||
# get our master key, since we may have received a signature of it.
|
||||
# We need to fetch it here so that we know what its key ID is, so
|
||||
@@ -738,12 +726,12 @@ class E2eKeysHandler(object):
|
||||
master_key,
|
||||
_,
|
||||
master_verify_key,
|
||||
) = yield self._get_e2e_cross_signing_verify_key(user_id, "master")
|
||||
) = await self._get_e2e_cross_signing_verify_key(user_id, "master")
|
||||
|
||||
# fetch our stored devices. This is used to 1. verify
|
||||
# signatures on the master key, and 2. to compare with what
|
||||
# was sent if the device was signed
|
||||
devices = yield self.store.get_e2e_device_keys([(user_id, None)])
|
||||
devices = await self.store.get_e2e_device_keys([(user_id, None)])
|
||||
|
||||
if user_id not in devices:
|
||||
raise NotFoundError("No device keys found")
|
||||
@@ -853,8 +841,7 @@ class E2eKeysHandler(object):
|
||||
|
||||
return master_key_signature_list
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _process_other_signatures(self, user_id, signatures):
|
||||
async def _process_other_signatures(self, user_id, signatures):
|
||||
"""Process uploaded signatures of other users' keys. These will be the
|
||||
target user's master keys, signed by the uploading user's user-signing
|
||||
key.
|
||||
@@ -882,7 +869,7 @@ class E2eKeysHandler(object):
|
||||
user_signing_key,
|
||||
user_signing_key_id,
|
||||
user_signing_verify_key,
|
||||
) = yield self._get_e2e_cross_signing_verify_key(user_id, "user_signing")
|
||||
) = await self._get_e2e_cross_signing_verify_key(user_id, "user_signing")
|
||||
except SynapseError as e:
|
||||
failure = _exception_to_failure(e)
|
||||
for user, devicemap in signatures.items():
|
||||
@@ -905,7 +892,7 @@ class E2eKeysHandler(object):
|
||||
master_key,
|
||||
master_key_id,
|
||||
_,
|
||||
) = yield self._get_e2e_cross_signing_verify_key(
|
||||
) = await self._get_e2e_cross_signing_verify_key(
|
||||
target_user, "master", user_id
|
||||
)
|
||||
|
||||
@@ -958,8 +945,7 @@ class E2eKeysHandler(object):
|
||||
|
||||
return signature_list, failures
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_e2e_cross_signing_verify_key(
|
||||
async def _get_e2e_cross_signing_verify_key(
|
||||
self, user_id: str, key_type: str, from_user_id: str = None
|
||||
):
|
||||
"""Fetch locally or remotely query for a cross-signing public key.
|
||||
@@ -983,7 +969,7 @@ class E2eKeysHandler(object):
|
||||
SynapseError: if `user_id` is invalid
|
||||
"""
|
||||
user = UserID.from_string(user_id)
|
||||
key = yield self.store.get_e2e_cross_signing_key(
|
||||
key = await self.store.get_e2e_cross_signing_key(
|
||||
user_id, key_type, from_user_id
|
||||
)
|
||||
|
||||
@@ -1009,15 +995,14 @@ class E2eKeysHandler(object):
|
||||
key,
|
||||
key_id,
|
||||
verify_key,
|
||||
) = yield self._retrieve_cross_signing_keys_for_remote_user(user, key_type)
|
||||
) = await self._retrieve_cross_signing_keys_for_remote_user(user, key_type)
|
||||
|
||||
if key is None:
|
||||
raise NotFoundError("No %s key found for %s" % (key_type, user_id))
|
||||
|
||||
return key, key_id, verify_key
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _retrieve_cross_signing_keys_for_remote_user(
|
||||
async def _retrieve_cross_signing_keys_for_remote_user(
|
||||
self, user: UserID, desired_key_type: str,
|
||||
):
|
||||
"""Queries cross-signing keys for a remote user and saves them to the database
|
||||
@@ -1035,7 +1020,7 @@ class E2eKeysHandler(object):
|
||||
If the key cannot be retrieved, all values in the tuple will instead be None.
|
||||
"""
|
||||
try:
|
||||
remote_result = yield self.federation.query_user_devices(
|
||||
remote_result = await self.federation.query_user_devices(
|
||||
user.domain, user.to_string()
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -1101,14 +1086,14 @@ class E2eKeysHandler(object):
|
||||
desired_key_id = key_id
|
||||
|
||||
# At the same time, store this key in the db for subsequent queries
|
||||
yield self.store.set_e2e_cross_signing_key(
|
||||
await self.store.set_e2e_cross_signing_key(
|
||||
user.to_string(), key_type, key_content
|
||||
)
|
||||
|
||||
# Notify clients that new devices for this user have been discovered
|
||||
if retrieved_device_ids:
|
||||
# XXX is this necessary?
|
||||
yield self.device_handler.notify_device_update(
|
||||
await self.device_handler.notify_device_update(
|
||||
user.to_string(), retrieved_device_ids
|
||||
)
|
||||
|
||||
@@ -1250,8 +1235,7 @@ class SigningKeyEduUpdater(object):
|
||||
iterable=True,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def incoming_signing_key_update(self, origin, edu_content):
|
||||
async def incoming_signing_key_update(self, origin, edu_content):
|
||||
"""Called on incoming signing key update from federation. Responsible for
|
||||
parsing the EDU and adding to pending updates list.
|
||||
|
||||
@@ -1268,7 +1252,7 @@ class SigningKeyEduUpdater(object):
|
||||
logger.warning("Got signing key update edu for %r from %r", user_id, origin)
|
||||
return
|
||||
|
||||
room_ids = yield self.store.get_rooms_for_user(user_id)
|
||||
room_ids = await self.store.get_rooms_for_user(user_id)
|
||||
if not room_ids:
|
||||
# We don't share any rooms with this user. Ignore update, as we
|
||||
# probably won't get any further updates.
|
||||
@@ -1278,10 +1262,9 @@ class SigningKeyEduUpdater(object):
|
||||
(master_key, self_signing_key)
|
||||
)
|
||||
|
||||
yield self._handle_signing_key_updates(user_id)
|
||||
await self._handle_signing_key_updates(user_id)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _handle_signing_key_updates(self, user_id):
|
||||
async def _handle_signing_key_updates(self, user_id):
|
||||
"""Actually handle pending updates.
|
||||
|
||||
Args:
|
||||
@@ -1291,7 +1274,7 @@ class SigningKeyEduUpdater(object):
|
||||
device_handler = self.e2e_keys_handler.device_handler
|
||||
device_list_updater = device_handler.device_list_updater
|
||||
|
||||
with (yield self._remote_edu_linearizer.queue(user_id)):
|
||||
with (await self._remote_edu_linearizer.queue(user_id)):
|
||||
pending_updates = self._pending_updates.pop(user_id, [])
|
||||
if not pending_updates:
|
||||
# This can happen since we batch updates
|
||||
@@ -1302,9 +1285,9 @@ class SigningKeyEduUpdater(object):
|
||||
logger.info("pending updates: %r", pending_updates)
|
||||
|
||||
for master_key, self_signing_key in pending_updates:
|
||||
new_device_ids = yield device_list_updater.process_cross_signing_key_update(
|
||||
new_device_ids = await device_list_updater.process_cross_signing_key_update(
|
||||
user_id, master_key, self_signing_key,
|
||||
)
|
||||
device_ids = device_ids + new_device_ids
|
||||
|
||||
yield device_handler.notify_device_update(user_id, device_ids)
|
||||
await device_handler.notify_device_update(user_id, device_ids)
|
||||
|
||||
@@ -16,8 +16,6 @@
|
||||
|
||||
import logging
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import (
|
||||
Codes,
|
||||
NotFoundError,
|
||||
@@ -50,8 +48,7 @@ class E2eRoomKeysHandler(object):
|
||||
self._upload_linearizer = Linearizer("upload_room_keys_lock")
|
||||
|
||||
@trace
|
||||
@defer.inlineCallbacks
|
||||
def get_room_keys(self, user_id, version, room_id=None, session_id=None):
|
||||
async def get_room_keys(self, user_id, version, room_id=None, session_id=None):
|
||||
"""Bulk get the E2E room keys for a given backup, optionally filtered to a given
|
||||
room, or a given session.
|
||||
See EndToEndRoomKeyStore.get_e2e_room_keys for full details.
|
||||
@@ -71,17 +68,17 @@ class E2eRoomKeysHandler(object):
|
||||
|
||||
# we deliberately take the lock to get keys so that changing the version
|
||||
# works atomically
|
||||
with (yield self._upload_linearizer.queue(user_id)):
|
||||
with (await self._upload_linearizer.queue(user_id)):
|
||||
# make sure the backup version exists
|
||||
try:
|
||||
yield self.store.get_e2e_room_keys_version_info(user_id, version)
|
||||
await self.store.get_e2e_room_keys_version_info(user_id, version)
|
||||
except StoreError as e:
|
||||
if e.code == 404:
|
||||
raise NotFoundError("Unknown backup version")
|
||||
else:
|
||||
raise
|
||||
|
||||
results = yield self.store.get_e2e_room_keys(
|
||||
results = await self.store.get_e2e_room_keys(
|
||||
user_id, version, room_id, session_id
|
||||
)
|
||||
|
||||
@@ -89,8 +86,7 @@ class E2eRoomKeysHandler(object):
|
||||
return results
|
||||
|
||||
@trace
|
||||
@defer.inlineCallbacks
|
||||
def delete_room_keys(self, user_id, version, room_id=None, session_id=None):
|
||||
async def delete_room_keys(self, user_id, version, room_id=None, session_id=None):
|
||||
"""Bulk delete the E2E room keys for a given backup, optionally filtered to a given
|
||||
room or a given session.
|
||||
See EndToEndRoomKeyStore.delete_e2e_room_keys for full details.
|
||||
@@ -109,10 +105,10 @@ class E2eRoomKeysHandler(object):
|
||||
"""
|
||||
|
||||
# lock for consistency with uploading
|
||||
with (yield self._upload_linearizer.queue(user_id)):
|
||||
with (await self._upload_linearizer.queue(user_id)):
|
||||
# make sure the backup version exists
|
||||
try:
|
||||
version_info = yield self.store.get_e2e_room_keys_version_info(
|
||||
version_info = await self.store.get_e2e_room_keys_version_info(
|
||||
user_id, version
|
||||
)
|
||||
except StoreError as e:
|
||||
@@ -121,19 +117,18 @@ class E2eRoomKeysHandler(object):
|
||||
else:
|
||||
raise
|
||||
|
||||
yield self.store.delete_e2e_room_keys(user_id, version, room_id, session_id)
|
||||
await self.store.delete_e2e_room_keys(user_id, version, room_id, session_id)
|
||||
|
||||
version_etag = version_info["etag"] + 1
|
||||
yield self.store.update_e2e_room_keys_version(
|
||||
await self.store.update_e2e_room_keys_version(
|
||||
user_id, version, None, version_etag
|
||||
)
|
||||
|
||||
count = yield self.store.count_e2e_room_keys(user_id, version)
|
||||
count = await self.store.count_e2e_room_keys(user_id, version)
|
||||
return {"etag": str(version_etag), "count": count}
|
||||
|
||||
@trace
|
||||
@defer.inlineCallbacks
|
||||
def upload_room_keys(self, user_id, version, room_keys):
|
||||
async def upload_room_keys(self, user_id, version, room_keys):
|
||||
"""Bulk upload a list of room keys into a given backup version, asserting
|
||||
that the given version is the current backup version. room_keys are merged
|
||||
into the current backup as described in RoomKeysServlet.on_PUT().
|
||||
@@ -169,11 +164,11 @@ class E2eRoomKeysHandler(object):
|
||||
# TODO: Validate the JSON to make sure it has the right keys.
|
||||
|
||||
# XXX: perhaps we should use a finer grained lock here?
|
||||
with (yield self._upload_linearizer.queue(user_id)):
|
||||
with (await self._upload_linearizer.queue(user_id)):
|
||||
|
||||
# Check that the version we're trying to upload is the current version
|
||||
try:
|
||||
version_info = yield self.store.get_e2e_room_keys_version_info(user_id)
|
||||
version_info = await self.store.get_e2e_room_keys_version_info(user_id)
|
||||
except StoreError as e:
|
||||
if e.code == 404:
|
||||
raise NotFoundError("Version '%s' not found" % (version,))
|
||||
@@ -183,7 +178,7 @@ class E2eRoomKeysHandler(object):
|
||||
if version_info["version"] != version:
|
||||
# Check that the version we're trying to upload actually exists
|
||||
try:
|
||||
version_info = yield self.store.get_e2e_room_keys_version_info(
|
||||
version_info = await self.store.get_e2e_room_keys_version_info(
|
||||
user_id, version
|
||||
)
|
||||
# if we get this far, the version must exist
|
||||
@@ -198,7 +193,7 @@ class E2eRoomKeysHandler(object):
|
||||
# submitted. Then compare them with the submitted keys. If the
|
||||
# key is new, insert it; if the key should be updated, then update
|
||||
# it; otherwise, drop it.
|
||||
existing_keys = yield self.store.get_e2e_room_keys_multi(
|
||||
existing_keys = await self.store.get_e2e_room_keys_multi(
|
||||
user_id, version, room_keys["rooms"]
|
||||
)
|
||||
to_insert = [] # batch the inserts together
|
||||
@@ -227,7 +222,7 @@ class E2eRoomKeysHandler(object):
|
||||
# updates are done one at a time in the DB, so send
|
||||
# updates right away rather than batching them up,
|
||||
# like we do with the inserts
|
||||
yield self.store.update_e2e_room_key(
|
||||
await self.store.update_e2e_room_key(
|
||||
user_id, version, room_id, session_id, room_key
|
||||
)
|
||||
changed = True
|
||||
@@ -246,16 +241,16 @@ class E2eRoomKeysHandler(object):
|
||||
changed = True
|
||||
|
||||
if len(to_insert):
|
||||
yield self.store.add_e2e_room_keys(user_id, version, to_insert)
|
||||
await self.store.add_e2e_room_keys(user_id, version, to_insert)
|
||||
|
||||
version_etag = version_info["etag"]
|
||||
if changed:
|
||||
version_etag = version_etag + 1
|
||||
yield self.store.update_e2e_room_keys_version(
|
||||
await self.store.update_e2e_room_keys_version(
|
||||
user_id, version, None, version_etag
|
||||
)
|
||||
|
||||
count = yield self.store.count_e2e_room_keys(user_id, version)
|
||||
count = await self.store.count_e2e_room_keys(user_id, version)
|
||||
return {"etag": str(version_etag), "count": count}
|
||||
|
||||
@staticmethod
|
||||
@@ -291,8 +286,7 @@ class E2eRoomKeysHandler(object):
|
||||
return True
|
||||
|
||||
@trace
|
||||
@defer.inlineCallbacks
|
||||
def create_version(self, user_id, version_info):
|
||||
async def create_version(self, user_id, version_info):
|
||||
"""Create a new backup version. This automatically becomes the new
|
||||
backup version for the user's keys; previous backups will no longer be
|
||||
writeable to.
|
||||
@@ -313,14 +307,13 @@ class E2eRoomKeysHandler(object):
|
||||
# TODO: Validate the JSON to make sure it has the right keys.
|
||||
|
||||
# lock everyone out until we've switched version
|
||||
with (yield self._upload_linearizer.queue(user_id)):
|
||||
new_version = yield self.store.create_e2e_room_keys_version(
|
||||
with (await self._upload_linearizer.queue(user_id)):
|
||||
new_version = await self.store.create_e2e_room_keys_version(
|
||||
user_id, version_info
|
||||
)
|
||||
return new_version
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_version_info(self, user_id, version=None):
|
||||
async def get_version_info(self, user_id, version=None):
|
||||
"""Get the info about a given version of the user's backup
|
||||
|
||||
Args:
|
||||
@@ -339,22 +332,21 @@ class E2eRoomKeysHandler(object):
|
||||
}
|
||||
"""
|
||||
|
||||
with (yield self._upload_linearizer.queue(user_id)):
|
||||
with (await self._upload_linearizer.queue(user_id)):
|
||||
try:
|
||||
res = yield self.store.get_e2e_room_keys_version_info(user_id, version)
|
||||
res = await self.store.get_e2e_room_keys_version_info(user_id, version)
|
||||
except StoreError as e:
|
||||
if e.code == 404:
|
||||
raise NotFoundError("Unknown backup version")
|
||||
else:
|
||||
raise
|
||||
|
||||
res["count"] = yield self.store.count_e2e_room_keys(user_id, res["version"])
|
||||
res["count"] = await self.store.count_e2e_room_keys(user_id, res["version"])
|
||||
res["etag"] = str(res["etag"])
|
||||
return res
|
||||
|
||||
@trace
|
||||
@defer.inlineCallbacks
|
||||
def delete_version(self, user_id, version=None):
|
||||
async def delete_version(self, user_id, version=None):
|
||||
"""Deletes a given version of the user's e2e_room_keys backup
|
||||
|
||||
Args:
|
||||
@@ -364,9 +356,9 @@ class E2eRoomKeysHandler(object):
|
||||
NotFoundError: if this backup version doesn't exist
|
||||
"""
|
||||
|
||||
with (yield self._upload_linearizer.queue(user_id)):
|
||||
with (await self._upload_linearizer.queue(user_id)):
|
||||
try:
|
||||
yield self.store.delete_e2e_room_keys_version(user_id, version)
|
||||
await self.store.delete_e2e_room_keys_version(user_id, version)
|
||||
except StoreError as e:
|
||||
if e.code == 404:
|
||||
raise NotFoundError("Unknown backup version")
|
||||
@@ -374,8 +366,7 @@ class E2eRoomKeysHandler(object):
|
||||
raise
|
||||
|
||||
@trace
|
||||
@defer.inlineCallbacks
|
||||
def update_version(self, user_id, version, version_info):
|
||||
async def update_version(self, user_id, version, version_info):
|
||||
"""Update the info about a given version of the user's backup
|
||||
|
||||
Args:
|
||||
@@ -393,9 +384,9 @@ class E2eRoomKeysHandler(object):
|
||||
raise SynapseError(
|
||||
400, "Version in body does not match", Codes.INVALID_PARAM
|
||||
)
|
||||
with (yield self._upload_linearizer.queue(user_id)):
|
||||
with (await self._upload_linearizer.queue(user_id)):
|
||||
try:
|
||||
old_info = yield self.store.get_e2e_room_keys_version_info(
|
||||
old_info = await self.store.get_e2e_room_keys_version_info(
|
||||
user_id, version
|
||||
)
|
||||
except StoreError as e:
|
||||
@@ -406,7 +397,7 @@ class E2eRoomKeysHandler(object):
|
||||
if old_info["algorithm"] != version_info["algorithm"]:
|
||||
raise SynapseError(400, "Algorithm does not match", Codes.INVALID_PARAM)
|
||||
|
||||
yield self.store.update_e2e_room_keys_version(
|
||||
await self.store.update_e2e_room_keys_version(
|
||||
user_id, version, version_info
|
||||
)
|
||||
|
||||
|
||||
@@ -185,7 +185,7 @@ class TypingHandler(object):

    async def _push_remote(self, member, typing):
        try:
-            users = await self.state.get_current_users_in_room(member.room_id)
+            users = await self.store.get_users_in_room(member.room_id)
            self._member_last_federation_poke[member] = self.clock.time_msec()

            now = self.clock.time_msec()

@@ -224,7 +224,7 @@ class TypingHandler(object):
            )
            return

-        users = await self.state.get_current_users_in_room(room_id)
+        users = await self.store.get_users_in_room(room_id)
        domains = {get_domain_from_id(u) for u in users}

        if self.server_name in domains:
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

from canonicaljson import json

@@ -117,7 +118,7 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker):
        except PartialDownloadError as pde:
            # Twisted is silly
            data = pde.response
-            resp_body = json.loads(data)
+            resp_body = json.loads(data.decode("utf-8"))

            if "success" in resp_body:
                # Note that we do NOT check the hostname here: we explicitly
@@ -13,13 +13,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import json

import logging
import urllib
from io import BytesIO

import treq
-from canonicaljson import encode_canonical_json, json
+from canonicaljson import encode_canonical_json
from netaddr import IPAddress
from prometheus_client import Counter
from zope.interface import implementer, provider

@@ -31,6 +31,7 @@ from twisted.internet.interfaces import (
    IReactorPluggableNameResolver,
    IResolutionReceiver,
)
+from twisted.internet.task import Cooperator
from twisted.python.failure import Failure
from twisted.web._newclient import ResponseDone
from twisted.web.client import Agent, HTTPConnectionPool, readBody

@@ -69,6 +70,21 @@ def check_against_blacklist(ip_address, ip_whitelist, ip_blacklist):
    return False


+_EPSILON = 0.00000001
+
+
+def _make_scheduler(reactor):
+    """Makes a schedular suitable for a Cooperator using the given reactor.
+
+    (This is effectively just a copy from `twisted.internet.task`)
+    """
+
+    def _scheduler(x):
+        return reactor.callLater(_EPSILON, x)
+
+    return _scheduler
+
+
class IPBlacklistingResolver(object):
    """
    A proxy for reactor.nameResolver which only produces non-blacklisted IP

@@ -212,6 +228,10 @@ class SimpleHttpClient(object):
        if hs.config.user_agent_suffix:
            self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix)

+        # We use this for our body producers to ensure that they use the correct
+        # reactor.
+        self._cooperator = Cooperator(scheduler=_make_scheduler(hs.get_reactor()))
+
        self.user_agent = self.user_agent.encode("ascii")

        if self._ip_blacklist:

@@ -292,7 +312,9 @@ class SimpleHttpClient(object):
        try:
            body_producer = None
            if data is not None:
-                body_producer = QuieterFileBodyProducer(BytesIO(data))
+                body_producer = QuieterFileBodyProducer(
+                    BytesIO(data), cooperator=self._cooperator,
+                )

            request_deferred = treq.request(
                method,
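The `Cooperator` wiring added above exists so body producers iterate on the homeserver's own reactor rather than the global one. A self-contained sketch of the same pattern outside Synapse; `_EPSILON` and `_make_scheduler` mirror the helper in the diff, while the counting task is purely illustrative.

```python
# Sketch: a Cooperator whose scheduler is bound to an explicit reactor.
from twisted.internet import reactor, task

_EPSILON = 0.00000001


def _make_scheduler(bound_reactor):
    def _scheduler(call):
        # schedule the cooperator's next step on the chosen reactor
        return bound_reactor.callLater(_EPSILON, call)

    return _scheduler


def counting_task():
    for i in range(5):
        print("step", i)
        yield  # hand control back so the cooperator can interleave other work


def main():
    cooperator = task.Cooperator(scheduler=_make_scheduler(reactor))
    d = cooperator.coiterate(counting_task())
    d.addBoth(lambda _: reactor.stop())


if __name__ == "__main__":
    main()
    reactor.run()
```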
@@ -14,9 +14,11 @@
# limitations under the License.

""" This module contains base REST classes for constructing REST servlets. """
+import json

import logging

-from canonicaljson import json

from synapse.api.errors import Codes, SynapseError

logger = logging.getLogger(__name__)
@@ -239,6 +239,15 @@ class UserRestServletV2(RestServlet):
                    await self.deactivate_account_handler.deactivate_account(
                        target_user.to_string(), False
                    )
+                elif not deactivate and user["deactivated"]:
+                    if "password" not in body:
+                        raise SynapseError(
+                            400, "Must provide a password to re-activate an account."
+                        )
+
+                    await self.deactivate_account_handler.activate_account(
+                        target_user.to_string()
+                    )

            user = await self.admin_handler.get_user(target_user)
            return 200, user

@@ -254,7 +263,6 @@
            admin = body.get("admin", None)
            user_type = body.get("user_type", None)
-            displayname = body.get("displayname", None)
            threepids = body.get("threepids", None)

            if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES:
                raise SynapseError(400, "Invalid user type")
@@ -371,7 +371,7 @@ class LoginRestServlet(RestServlet):
        token = login_submission.get("token", None)
        if token is None:
            raise LoginError(
-                401, "Token field for JWT is missing", errcode=Codes.UNAUTHORIZED
+                403, "Token field for JWT is missing", errcode=Codes.FORBIDDEN
            )

        import jwt

@@ -387,14 +387,12 @@
        except jwt.PyJWTError as e:
            # A JWT error occurred, return some info back to the client.
            raise LoginError(
-                401,
-                "JWT validation failed: %s" % (str(e),),
-                errcode=Codes.UNAUTHORIZED,
+                403, "JWT validation failed: %s" % (str(e),), errcode=Codes.FORBIDDEN,
            )

        user = payload.get("sub", None)
        if user is None:
-            raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
+            raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN)

        user_id = UserID(user, self.hs.hostname).to_string()
        result = await self._complete_login(
@@ -15,6 +15,7 @@
# limitations under the License.

""" This module contains REST servlets to do with rooms: /rooms/<paths> """

import logging
import re
from typing import List, Optional

@@ -515,9 +516,9 @@ class RoomMessageListRestServlet(RestServlet):
        requester = await self.auth.get_user_by_req(request, allow_guest=True)
        pagination_config = PaginationConfig.from_request(request, default_limit=10)
        as_client_event = b"raw" not in request.args
-        filter_bytes = parse_string(request, b"filter", encoding=None)
-        if filter_bytes:
-            filter_json = urlparse.unquote(filter_bytes.decode("UTF-8"))
+        filter_str = parse_string(request, b"filter", encoding="utf-8")
+        if filter_str:
+            filter_json = urlparse.unquote(filter_str)
            event_filter = Filter(json.loads(filter_json))  # type: Optional[Filter]
            if (
                event_filter

@@ -627,9 +628,9 @@ class RoomEventContextServlet(RestServlet):
        limit = parse_integer(request, "limit", default=10)

        # picking the API shape for symmetry with /messages
-        filter_bytes = parse_string(request, "filter")
-        if filter_bytes:
-            filter_json = urlparse.unquote(filter_bytes)
+        filter_str = parse_string(request, b"filter", encoding="utf-8")
+        if filter_str:
+            filter_json = urlparse.unquote(filter_str)
            event_filter = Filter(json.loads(filter_json))  # type: Optional[Filter]
        else:
            event_filter = None
@@ -202,9 +202,11 @@ class RemoteKey(DirectServeJsonResource):

                if miss:
                    cache_misses.setdefault(server_name, set()).add(key_id)
+                # Cast to bytes since postgresql returns a memoryview.
                json_results.add(bytes(most_recent_result["key_json"]))
            else:
                for ts_added, result in results:
+                    # Cast to bytes since postgresql returns a memoryview.
                    json_results.add(bytes(result["key_json"]))

        if cache_misses and query_remote_on_cache_miss:

@@ -213,7 +215,7 @@ class RemoteKey(DirectServeJsonResource):
        else:
            signed_keys = []
            for key_json in json_results:
-                key_json = json.loads(key_json)
+                key_json = json.loads(key_json.decode("utf-8"))
                for signing_key in self.config.key_server_signing_keys:
                    key_json = sign_json(key_json, self.config.server_name, signing_key)
@@ -106,7 +106,7 @@ from synapse.server_notices.worker_server_notices_sender import (
    WorkerServerNoticesSender,
)
from synapse.state import StateHandler, StateResolutionHandler
-from synapse.storage import DataStores, Storage
+from synapse.storage import DataStore, DataStores, Storage
from synapse.streams.events import EventSources
from synapse.util import Clock
from synapse.util.distributor import Distributor

@@ -314,7 +314,7 @@ class HomeServer(object):
    def get_clock(self):
        return self.clock

-    def get_datastore(self):
+    def get_datastore(self) -> DataStore:
        return self.datastores.main

    def get_datastores(self):
@@ -20,6 +20,7 @@ import synapse.handlers.room
import synapse.handlers.room_member
import synapse.handlers.set_password
import synapse.http.client
+import synapse.http.matrixfederationclient
import synapse.notifier
import synapse.push.pusherpool
import synapse.replication.tcp.client

@@ -143,3 +144,7 @@ class HomeServer(object):
        pass
    def get_replication_streams(self) -> Dict[str, Stream]:
        pass
+    def get_http_client(
+        self,
+    ) -> synapse.http.matrixfederationclient.MatrixFederationHttpClient:
+        pass
@@ -17,7 +17,6 @@
import itertools
import logging
from collections import OrderedDict, namedtuple
-from functools import wraps
from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple

import attr

@@ -69,27 +68,6 @@ def encode_json(json_object):
_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))


-def _retry_on_integrity_error(func):
-    """Wraps a database function so that it gets retried on IntegrityError,
-    with `delete_existing=True` passed in.
-
-    Args:
-        func: function that returns a Deferred and accepts a `delete_existing` arg
-    """
-
-    @wraps(func)
-    @defer.inlineCallbacks
-    def f(self, *args, **kwargs):
-        try:
-            res = yield func(self, *args, delete_existing=False, **kwargs)
-        except self.database_engine.module.IntegrityError:
-            logger.exception("IntegrityError, retrying.")
-            res = yield func(self, *args, delete_existing=True, **kwargs)
-        return res
-
-    return f
-
-
@attr.s(slots=True)
class DeltaState:
    """Deltas to use to update the `current_state_events` table.

@@ -134,7 +112,6 @@ class PersistEventsStore:
            hs.config.worker.writers.events == hs.get_instance_name()
        ), "Can only instantiate EventsStore on master"

-    @_retry_on_integrity_error
    @defer.inlineCallbacks
    def _persist_events_and_state_updates(
        self,

@@ -143,7 +120,6 @@
        state_delta_for_room: Dict[str, DeltaState],
        new_forward_extremeties: Dict[str, List[str]],
        backfilled: bool = False,
-        delete_existing: bool = False,
    ):
        """Persist a set of events alongside updates to the current state and
        forward extremities tables.

@@ -157,7 +133,6 @@
            new_forward_extremities: Map from room_id to list of event IDs
                that are the new forward extremities of the room.
            backfilled
-            delete_existing

        Returns:
            Deferred: resolves when the events have been persisted

@@ -197,7 +172,6 @@
            self._persist_events_txn,
            events_and_contexts=events_and_contexts,
            backfilled=backfilled,
-            delete_existing=delete_existing,
            state_delta_for_room=state_delta_for_room,
            new_forward_extremeties=new_forward_extremeties,
        )

@@ -341,7 +315,6 @@
        txn: LoggingTransaction,
        events_and_contexts: List[Tuple[EventBase, EventContext]],
        backfilled: bool,
-        delete_existing: bool = False,
        state_delta_for_room: Dict[str, DeltaState] = {},
        new_forward_extremeties: Dict[str, List[str]] = {},
    ):

@@ -393,13 +366,6 @@
        # From this point onwards the events are only events that we haven't
        # seen before.

-        if delete_existing:
-            # For paranoia reasons, we go and delete all the existing entries
-            # for these events so we can reinsert them.
-            # This gets around any problems with some tables already having
-            # entries.
-            self._delete_existing_rows_txn(txn, events_and_contexts=events_and_contexts)
-
        self._store_event_txn(txn, events_and_contexts=events_and_contexts)

        # Insert into event_to_state_groups.

@@ -797,39 +763,6 @@ class PersistEventsStore:

        return [ec for ec in events_and_contexts if ec[0] not in to_remove]

-    @classmethod
-    def _delete_existing_rows_txn(cls, txn, events_and_contexts):
-        if not events_and_contexts:
-            # nothing to do here
-            return
-
-        logger.info("Deleting existing")
-
-        for table in (
-            "events",
-            "event_auth",
-            "event_json",
-            "event_edges",
-            "event_forward_extremities",
-            "event_reference_hashes",
-            "event_search",
-            "event_to_state_groups",
-            "state_events",
-            "rejections",
-            "redactions",
-            "room_memberships",
-        ):
-            txn.executemany(
-                "DELETE FROM %s WHERE event_id = ?" % (table,),
-                [(ev.event_id,) for ev, _ in events_and_contexts],
-            )
-
-        for table in ("event_push_actions",):
-            txn.executemany(
-                "DELETE FROM %s WHERE room_id = ? AND event_id = ?" % (table,),
-                [(ev.room_id, ev.event_id) for ev, _ in events_and_contexts],
-            )
-
    def _store_event_txn(self, txn, events_and_contexts):
        """Insert new events into the event and event_json tables
@@ -353,6 +353,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
|
||||
last_room_id = progress.get("last_room_id", "")
|
||||
|
||||
def _background_remove_left_rooms_txn(txn):
|
||||
# get a batch of room ids to consider
|
||||
sql = """
|
||||
SELECT DISTINCT room_id FROM current_state_events
|
||||
WHERE room_id > ? ORDER BY room_id LIMIT ?
|
||||
@@ -363,24 +364,68 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
|
||||
if not room_ids:
|
||||
return True, set()
|
||||
|
||||
###########################################################################
|
||||
#
|
||||
# exclude rooms where we have active members
|
||||
|
||||
sql = """
|
||||
SELECT room_id
|
||||
FROM current_state_events
|
||||
FROM local_current_membership
|
||||
WHERE
|
||||
room_id > ? AND room_id <= ?
|
||||
AND type = 'm.room.member'
|
||||
AND membership = 'join'
|
||||
AND state_key LIKE ?
|
||||
GROUP BY room_id
|
||||
"""
|
||||
|
||||
txn.execute(sql, (last_room_id, room_ids[-1], "%:" + self.server_name))
|
||||
|
||||
txn.execute(sql, (last_room_id, room_ids[-1]))
|
||||
joined_room_ids = {row[0] for row in txn}
|
||||
to_delete = set(room_ids) - joined_room_ids
|
||||
|
||||
left_rooms = set(room_ids) - joined_room_ids
|
||||
###########################################################################
|
||||
#
|
||||
# exclude rooms which we are in the process of constructing; these otherwise
|
||||
# qualify as "rooms with no local users", and would have their
|
||||
# forward extremities cleaned up.
|
||||
|
||||
logger.info("Deleting current state left rooms: %r", left_rooms)
|
||||
# the following query will return a list of rooms which have forward
|
||||
# extremities that are *not* also the create event in the room - ie
|
||||
# those that are not being created currently.
|
||||
|
||||
sql = """
|
||||
SELECT DISTINCT efe.room_id
|
||||
FROM event_forward_extremities efe
|
||||
LEFT JOIN current_state_events cse ON
|
||||
cse.event_id = efe.event_id
|
||||
AND cse.type = 'm.room.create'
|
||||
AND cse.state_key = ''
|
||||
WHERE
|
||||
cse.event_id IS NULL
|
||||
AND efe.room_id > ? AND efe.room_id <= ?
|
||||
"""
|
||||
|
||||
txn.execute(sql, (last_room_id, room_ids[-1]))
|
||||
|
||||
# build a set of those rooms within `to_delete` that do not appear in
|
||||
# the above, leaving us with the rooms in `to_delete` that *are* being
|
||||
# created.
|
||||
creating_rooms = to_delete.difference(row[0] for row in txn)
|
||||
logger.info("skipping rooms which are being created: %s", creating_rooms)
|
||||
|
||||
# now remove the rooms being created from the list of those to delete.
|
||||
#
|
||||
# (we could have just taken the intersection of `to_delete` with the result
|
||||
# of the sql query, but it's useful to be able to log `creating_rooms`; and
|
||||
# having done so, it's quicker to remove the (few) creating rooms from
|
||||
# `to_delete` than it is to form the intersection with the (larger) list of
|
||||
# not-creating-rooms)
|
||||
|
||||
to_delete -= creating_rooms
|
||||
|
||||
###########################################################################
|
||||
#
|
||||
# now clear the state for the rooms
|
||||
|
||||
logger.info("Deleting current state left rooms: %r", to_delete)
|
||||
|
||||
# First we get all users that we still think were joined to the
|
||||
# room. This is so that we can mark those device lists as
|
||||
@@ -391,7 +436,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
|
||||
txn,
|
||||
table="current_state_events",
|
||||
column="room_id",
|
||||
iterable=left_rooms,
|
||||
iterable=to_delete,
|
||||
keyvalues={"type": EventTypes.Member, "membership": Membership.JOIN},
|
||||
retcols=("state_key",),
|
||||
)
|
||||
@@ -403,7 +448,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
|
||||
txn,
|
||||
table="current_state_events",
|
||||
column="room_id",
|
||||
iterable=left_rooms,
|
||||
iterable=to_delete,
|
||||
keyvalues={},
|
||||
)
|
||||
|
||||
@@ -411,7 +456,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
|
||||
txn,
|
||||
table="event_forward_extremities",
|
||||
column="room_id",
|
||||
iterable=left_rooms,
|
||||
iterable=to_delete,
|
||||
keyvalues={},
|
||||
)
|
||||
|
||||
|
||||
@@ -46,7 +46,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
|
||||
"""If the user has no devices, we expect an empty list.
|
||||
"""
|
||||
local_user = "@boris:" + self.hs.hostname
|
||||
res = yield self.handler.query_local_devices({local_user: None})
|
||||
res = yield defer.ensureDeferred(
|
||||
self.handler.query_local_devices({local_user: None})
|
||||
)
|
||||
self.assertDictEqual(res, {local_user: {}})
|
||||
|
||||
@defer.inlineCallbacks
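The pattern above recurs throughout these test changes: the E2E key and room key handlers are now coroutines, so an `@defer.inlineCallbacks` test can no longer `yield` them directly and instead wraps each call in `defer.ensureDeferred`. A minimal sketch of the idea, separate from the diff itself (the fake handler and test names here are illustrative only):

    from twisted.internet import defer
    from twisted.trial import unittest


    class _FakeHandler:
        # Stand-in for any handler method converted to async/await.
        async def do_thing(self):
            return {"ok": True}


    class ExampleHandlerTestCase(unittest.TestCase):
        def setUp(self):
            self.handler = _FakeHandler()

        @defer.inlineCallbacks
        def test_calls_async_handler(self):
            # ensureDeferred wraps the coroutine in a Deferred, which
            # inlineCallbacks can yield on just as it did before the
            # handler became async.
            res = yield defer.ensureDeferred(self.handler.do_thing())
            self.assertEqual(res, {"ok": True})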
|
||||
@@ -60,15 +62,19 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
|
||||
"alg2:k3": {"key": "key3"},
|
||||
}
|
||||
|
||||
res = yield self.handler.upload_keys_for_user(
|
||||
local_user, device_id, {"one_time_keys": keys}
|
||||
res = yield defer.ensureDeferred(
|
||||
self.handler.upload_keys_for_user(
|
||||
local_user, device_id, {"one_time_keys": keys}
|
||||
)
|
||||
)
|
||||
self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}})
|
||||
|
||||
# we should be able to change the signature without a problem
|
||||
keys["alg2:k2"]["signatures"]["k1"] = "sig2"
|
||||
res = yield self.handler.upload_keys_for_user(
|
||||
local_user, device_id, {"one_time_keys": keys}
|
||||
res = yield defer.ensureDeferred(
|
||||
self.handler.upload_keys_for_user(
|
||||
local_user, device_id, {"one_time_keys": keys}
|
||||
)
|
||||
)
|
||||
self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}})
|
||||
|
||||
@@ -84,44 +90,56 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
|
||||
"alg2:k3": {"key": "key3"},
|
||||
}
|
||||
|
||||
res = yield self.handler.upload_keys_for_user(
|
||||
local_user, device_id, {"one_time_keys": keys}
|
||||
res = yield defer.ensureDeferred(
|
||||
self.handler.upload_keys_for_user(
|
||||
local_user, device_id, {"one_time_keys": keys}
|
||||
)
|
||||
)
|
||||
self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}})
|
||||
|
||||
try:
|
||||
yield self.handler.upload_keys_for_user(
|
||||
local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}}
|
||||
yield defer.ensureDeferred(
|
||||
self.handler.upload_keys_for_user(
|
||||
local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}}
|
||||
)
|
||||
)
|
||||
self.fail("No error when changing string key")
|
||||
except errors.SynapseError:
|
||||
pass
|
||||
|
||||
try:
|
||||
yield self.handler.upload_keys_for_user(
|
||||
local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}}
|
||||
yield defer.ensureDeferred(
|
||||
self.handler.upload_keys_for_user(
|
||||
local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}}
|
||||
)
|
||||
)
|
||||
self.fail("No error when replacing dict key with string")
|
||||
except errors.SynapseError:
|
||||
pass
|
||||
|
||||
try:
|
||||
yield self.handler.upload_keys_for_user(
|
||||
local_user, device_id, {"one_time_keys": {"alg1:k1": {"key": "key"}}}
|
||||
yield defer.ensureDeferred(
|
||||
self.handler.upload_keys_for_user(
|
||||
local_user,
|
||||
device_id,
|
||||
{"one_time_keys": {"alg1:k1": {"key": "key"}}},
|
||||
)
|
||||
)
|
||||
self.fail("No error when replacing string key with dict")
|
||||
except errors.SynapseError:
|
||||
pass
|
||||
|
||||
try:
|
||||
yield self.handler.upload_keys_for_user(
|
||||
local_user,
|
||||
device_id,
|
||||
{
|
||||
"one_time_keys": {
|
||||
"alg2:k2": {"key": "key3", "signatures": {"k1": "sig1"}}
|
||||
}
|
||||
},
|
||||
yield defer.ensureDeferred(
|
||||
self.handler.upload_keys_for_user(
|
||||
local_user,
|
||||
device_id,
|
||||
{
|
||||
"one_time_keys": {
|
||||
"alg2:k2": {"key": "key3", "signatures": {"k1": "sig1"}}
|
||||
}
|
||||
},
|
||||
)
|
||||
)
|
||||
self.fail("No error when replacing dict key")
|
||||
except errors.SynapseError:
|
||||
@@ -133,13 +151,17 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
|
||||
device_id = "xyz"
|
||||
keys = {"alg1:k1": "key1"}
|
||||
|
||||
res = yield self.handler.upload_keys_for_user(
|
||||
local_user, device_id, {"one_time_keys": keys}
|
||||
res = yield defer.ensureDeferred(
|
||||
self.handler.upload_keys_for_user(
|
||||
local_user, device_id, {"one_time_keys": keys}
|
||||
)
|
||||
)
|
||||
self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1}})
|
||||
|
||||
res2 = yield self.handler.claim_one_time_keys(
|
||||
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
|
||||
res2 = yield defer.ensureDeferred(
|
||||
self.handler.claim_one_time_keys(
|
||||
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
|
||||
)
|
||||
)
|
||||
self.assertEqual(
|
||||
res2,
|
||||
@@ -163,7 +185,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
|
||||
},
|
||||
}
|
||||
}
|
||||
yield self.handler.upload_signing_keys_for_user(local_user, keys1)
|
||||
yield defer.ensureDeferred(
|
||||
self.handler.upload_signing_keys_for_user(local_user, keys1)
|
||||
)
|
||||
|
||||
keys2 = {
|
||||
"master_key": {
|
||||
@@ -175,10 +199,12 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
|
||||
},
|
||||
}
|
||||
}
|
||||
yield self.handler.upload_signing_keys_for_user(local_user, keys2)
|
||||
yield defer.ensureDeferred(
|
||||
self.handler.upload_signing_keys_for_user(local_user, keys2)
|
||||
)
|
||||
|
||||
devices = yield self.handler.query_devices(
|
||||
{"device_keys": {local_user: []}}, 0, local_user
|
||||
devices = yield defer.ensureDeferred(
|
||||
self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user)
|
||||
)
|
||||
self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]})
|
||||
|
||||
@@ -215,7 +241,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
|
||||
"nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
|
||||
"2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0",
|
||||
)
|
||||
yield self.handler.upload_signing_keys_for_user(local_user, keys1)
|
||||
yield defer.ensureDeferred(
|
||||
self.handler.upload_signing_keys_for_user(local_user, keys1)
|
||||
)
|
||||
|
||||
# upload two device keys, which will be signed later by the self-signing key
|
||||
device_key_1 = {
|
||||
@@ -245,18 +273,24 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
|
||||
"signatures": {local_user: {"ed25519:def": "base64+signature"}},
|
||||
}
|
||||
|
||||
yield self.handler.upload_keys_for_user(
|
||||
local_user, "abc", {"device_keys": device_key_1}
|
||||
yield defer.ensureDeferred(
|
||||
self.handler.upload_keys_for_user(
|
||||
local_user, "abc", {"device_keys": device_key_1}
|
||||
)
|
||||
)
|
||||
yield self.handler.upload_keys_for_user(
|
||||
local_user, "def", {"device_keys": device_key_2}
|
||||
yield defer.ensureDeferred(
|
||||
self.handler.upload_keys_for_user(
|
||||
local_user, "def", {"device_keys": device_key_2}
|
||||
)
|
||||
)
|
||||
|
||||
# sign the first device key and upload it
|
||||
del device_key_1["signatures"]
|
||||
sign.sign_json(device_key_1, local_user, signing_key)
|
||||
yield self.handler.upload_signatures_for_device_keys(
|
||||
local_user, {local_user: {"abc": device_key_1}}
|
||||
yield defer.ensureDeferred(
|
||||
self.handler.upload_signatures_for_device_keys(
|
||||
local_user, {local_user: {"abc": device_key_1}}
|
||||
)
|
||||
)
|
||||
|
||||
# sign the second device key and upload both device keys. The server
|
||||
@@ -264,14 +298,16 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
|
||||
# signature for it
|
||||
del device_key_2["signatures"]
|
||||
sign.sign_json(device_key_2, local_user, signing_key)
|
||||
yield self.handler.upload_signatures_for_device_keys(
|
||||
local_user, {local_user: {"abc": device_key_1, "def": device_key_2}}
|
||||
yield defer.ensureDeferred(
|
||||
self.handler.upload_signatures_for_device_keys(
|
||||
local_user, {local_user: {"abc": device_key_1, "def": device_key_2}}
|
||||
)
|
||||
)
|
||||
|
||||
device_key_1["signatures"][local_user]["ed25519:abc"] = "base64+signature"
|
||||
device_key_2["signatures"][local_user]["ed25519:def"] = "base64+signature"
|
||||
devices = yield self.handler.query_devices(
|
||||
{"device_keys": {local_user: []}}, 0, local_user
|
||||
devices = yield defer.ensureDeferred(
|
||||
self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user)
|
||||
)
|
||||
del devices["device_keys"][local_user]["abc"]["unsigned"]
|
||||
del devices["device_keys"][local_user]["def"]["unsigned"]
|
||||
@@ -292,7 +328,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
|
||||
},
|
||||
}
|
||||
}
|
||||
yield self.handler.upload_signing_keys_for_user(local_user, keys1)
|
||||
yield defer.ensureDeferred(
|
||||
self.handler.upload_signing_keys_for_user(local_user, keys1)
|
||||
)
|
||||
|
||||
res = None
|
||||
try:
|
||||
@@ -305,7 +343,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
|
||||
res = e.code
|
||||
self.assertEqual(res, 400)
|
||||
|
||||
res = yield self.handler.query_local_devices({local_user: None})
|
||||
res = yield defer.ensureDeferred(
|
||||
self.handler.query_local_devices({local_user: None})
|
||||
)
|
||||
self.assertDictEqual(res, {local_user: {}})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@@ -331,8 +371,10 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
|
||||
"ed25519", "xyz", "OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA"
|
||||
)
|
||||
|
||||
yield self.handler.upload_keys_for_user(
|
||||
local_user, device_id, {"device_keys": device_key}
|
||||
yield defer.ensureDeferred(
|
||||
self.handler.upload_keys_for_user(
|
||||
local_user, device_id, {"device_keys": device_key}
|
||||
)
|
||||
)
|
||||
|
||||
# private key: 2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0
|
||||
@@ -372,7 +414,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
|
||||
"user_signing_key": usersigning_key,
|
||||
"self_signing_key": selfsigning_key,
|
||||
}
|
||||
yield self.handler.upload_signing_keys_for_user(local_user, cross_signing_keys)
|
||||
yield defer.ensureDeferred(
|
||||
self.handler.upload_signing_keys_for_user(local_user, cross_signing_keys)
|
||||
)
|
||||
|
||||
# set up another user with a master key. This user will be signed by
|
||||
# the first user
|
||||
@@ -384,76 +428,90 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
|
||||
"usage": ["master"],
|
||||
"keys": {"ed25519:" + other_master_pubkey: other_master_pubkey},
|
||||
}
|
||||
yield self.handler.upload_signing_keys_for_user(
|
||||
other_user, {"master_key": other_master_key}
|
||||
yield defer.ensureDeferred(
|
||||
self.handler.upload_signing_keys_for_user(
|
||||
other_user, {"master_key": other_master_key}
|
||||
)
|
||||
)
|
||||
|
||||
# test various signature failures (see below)
|
||||
ret = yield self.handler.upload_signatures_for_device_keys(
|
||||
local_user,
|
||||
{
|
||||
local_user: {
|
||||
# fails because the signature is invalid
|
||||
# should fail with INVALID_SIGNATURE
|
||||
device_id: {
|
||||
"user_id": local_user,
|
||||
"device_id": device_id,
|
||||
"algorithms": [
|
||||
"m.olm.curve25519-aes-sha2",
|
||||
RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
|
||||
],
|
||||
"keys": {
|
||||
"curve25519:xyz": "curve25519+key",
|
||||
# private key: OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA
|
||||
"ed25519:xyz": device_pubkey,
|
||||
ret = yield defer.ensureDeferred(
|
||||
self.handler.upload_signatures_for_device_keys(
|
||||
local_user,
|
||||
{
|
||||
local_user: {
|
||||
# fails because the signature is invalid
|
||||
# should fail with INVALID_SIGNATURE
|
||||
device_id: {
|
||||
"user_id": local_user,
|
||||
"device_id": device_id,
|
||||
"algorithms": [
|
||||
"m.olm.curve25519-aes-sha2",
|
||||
RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
|
||||
],
|
||||
"keys": {
|
||||
"curve25519:xyz": "curve25519+key",
|
||||
# private key: OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA
|
||||
"ed25519:xyz": device_pubkey,
|
||||
},
|
||||
"signatures": {
|
||||
local_user: {
|
||||
"ed25519:" + selfsigning_pubkey: "something"
|
||||
}
|
||||
},
|
||||
},
|
||||
"signatures": {
|
||||
local_user: {"ed25519:" + selfsigning_pubkey: "something"}
|
||||
# fails because device is unknown
|
||||
# should fail with NOT_FOUND
|
||||
"unknown": {
|
||||
"user_id": local_user,
|
||||
"device_id": "unknown",
|
||||
"signatures": {
|
||||
local_user: {
|
||||
"ed25519:" + selfsigning_pubkey: "something"
|
||||
}
|
||||
},
|
||||
},
|
||||
# fails because the signature is invalid
|
||||
# should fail with INVALID_SIGNATURE
|
||||
master_pubkey: {
|
||||
"user_id": local_user,
|
||||
"usage": ["master"],
|
||||
"keys": {"ed25519:" + master_pubkey: master_pubkey},
|
||||
"signatures": {
|
||||
local_user: {"ed25519:" + device_pubkey: "something"}
|
||||
},
|
||||
},
|
||||
},
|
||||
# fails because device is unknown
|
||||
# should fail with NOT_FOUND
|
||||
"unknown": {
|
||||
"user_id": local_user,
|
||||
"device_id": "unknown",
|
||||
"signatures": {
|
||||
local_user: {"ed25519:" + selfsigning_pubkey: "something"}
|
||||
other_user: {
|
||||
# fails because the device is not the user's master-signing key
|
||||
# should fail with NOT_FOUND
|
||||
"unknown": {
|
||||
"user_id": other_user,
|
||||
"device_id": "unknown",
|
||||
"signatures": {
|
||||
local_user: {
|
||||
"ed25519:" + usersigning_pubkey: "something"
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
# fails because the signature is invalid
|
||||
# should fail with INVALID_SIGNATURE
|
||||
master_pubkey: {
|
||||
"user_id": local_user,
|
||||
"usage": ["master"],
|
||||
"keys": {"ed25519:" + master_pubkey: master_pubkey},
|
||||
"signatures": {
|
||||
local_user: {"ed25519:" + device_pubkey: "something"}
|
||||
other_master_pubkey: {
|
||||
# fails because the key doesn't match what the server has
|
||||
# should fail with UNKNOWN
|
||||
"user_id": other_user,
|
||||
"usage": ["master"],
|
||||
"keys": {
|
||||
"ed25519:" + other_master_pubkey: other_master_pubkey
|
||||
},
|
||||
"something": "random",
|
||||
"signatures": {
|
||||
local_user: {
|
||||
"ed25519:" + usersigning_pubkey: "something"
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
other_user: {
|
||||
# fails because the device is not the user's master-signing key
|
||||
# should fail with NOT_FOUND
|
||||
"unknown": {
|
||||
"user_id": other_user,
|
||||
"device_id": "unknown",
|
||||
"signatures": {
|
||||
local_user: {"ed25519:" + usersigning_pubkey: "something"}
|
||||
},
|
||||
},
|
||||
other_master_pubkey: {
|
||||
# fails because the key doesn't match what the server has
|
||||
# should fail with UNKNOWN
|
||||
"user_id": other_user,
|
||||
"usage": ["master"],
|
||||
"keys": {"ed25519:" + other_master_pubkey: other_master_pubkey},
|
||||
"something": "random",
|
||||
"signatures": {
|
||||
local_user: {"ed25519:" + usersigning_pubkey: "something"}
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
user_failures = ret["failures"][local_user]
|
||||
@@ -478,19 +536,23 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
|
||||
sign.sign_json(device_key, local_user, selfsigning_signing_key)
|
||||
sign.sign_json(master_key, local_user, device_signing_key)
|
||||
sign.sign_json(other_master_key, local_user, usersigning_signing_key)
|
||||
ret = yield self.handler.upload_signatures_for_device_keys(
|
||||
local_user,
|
||||
{
|
||||
local_user: {device_id: device_key, master_pubkey: master_key},
|
||||
other_user: {other_master_pubkey: other_master_key},
|
||||
},
|
||||
ret = yield defer.ensureDeferred(
|
||||
self.handler.upload_signatures_for_device_keys(
|
||||
local_user,
|
||||
{
|
||||
local_user: {device_id: device_key, master_pubkey: master_key},
|
||||
other_user: {other_master_pubkey: other_master_key},
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(ret["failures"], {})
|
||||
|
||||
# fetch the signed keys/devices and make sure that the signatures are there
|
||||
ret = yield self.handler.query_devices(
|
||||
{"device_keys": {local_user: [], other_user: []}}, 0, local_user
|
||||
ret = yield defer.ensureDeferred(
|
||||
self.handler.query_devices(
|
||||
{"device_keys": {local_user: [], other_user: []}}, 0, local_user
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
|
||||
@@ -66,7 +66,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
|
||||
"""
|
||||
res = None
|
||||
try:
|
||||
yield self.handler.get_version_info(self.local_user)
|
||||
yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
|
||||
except errors.SynapseError as e:
|
||||
res = e.code
|
||||
self.assertEqual(res, 404)
|
||||
@@ -78,7 +78,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
|
||||
"""
|
||||
res = None
|
||||
try:
|
||||
yield self.handler.get_version_info(self.local_user, "bogus_version")
|
||||
yield defer.ensureDeferred(
|
||||
self.handler.get_version_info(self.local_user, "bogus_version")
|
||||
)
|
||||
except errors.SynapseError as e:
|
||||
res = e.code
|
||||
self.assertEqual(res, 404)
|
||||
@@ -87,14 +89,19 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
|
||||
def test_create_version(self):
|
||||
"""Check that we can create and then retrieve versions.
|
||||
"""
|
||||
res = yield self.handler.create_version(
|
||||
self.local_user,
|
||||
{"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
|
||||
res = yield defer.ensureDeferred(
|
||||
self.handler.create_version(
|
||||
self.local_user,
|
||||
{
|
||||
"algorithm": "m.megolm_backup.v1",
|
||||
"auth_data": "first_version_auth_data",
|
||||
},
|
||||
)
|
||||
)
|
||||
self.assertEqual(res, "1")
|
||||
|
||||
# check we can retrieve it as the current version
|
||||
res = yield self.handler.get_version_info(self.local_user)
|
||||
res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
|
||||
version_etag = res["etag"]
|
||||
self.assertIsInstance(version_etag, str)
|
||||
del res["etag"]
|
||||
@@ -109,7 +116,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
|
||||
)
|
||||
|
||||
# check we can retrieve it as a specific version
|
||||
res = yield self.handler.get_version_info(self.local_user, "1")
|
||||
res = yield defer.ensureDeferred(
|
||||
self.handler.get_version_info(self.local_user, "1")
|
||||
)
|
||||
self.assertEqual(res["etag"], version_etag)
|
||||
del res["etag"]
|
||||
self.assertDictEqual(
|
||||
@@ -123,17 +132,19 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
|
||||
)
|
||||
|
||||
# upload a new one...
|
||||
res = yield self.handler.create_version(
|
||||
self.local_user,
|
||||
{
|
||||
"algorithm": "m.megolm_backup.v1",
|
||||
"auth_data": "second_version_auth_data",
|
||||
},
|
||||
res = yield defer.ensureDeferred(
|
||||
self.handler.create_version(
|
||||
self.local_user,
|
||||
{
|
||||
"algorithm": "m.megolm_backup.v1",
|
||||
"auth_data": "second_version_auth_data",
|
||||
},
|
||||
)
|
||||
)
|
||||
self.assertEqual(res, "2")
|
||||
|
||||
# check we can retrieve it as the current version
|
||||
res = yield self.handler.get_version_info(self.local_user)
|
||||
res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
|
||||
del res["etag"]
|
||||
self.assertDictEqual(
|
||||
res,
|
||||
@@ -149,25 +160,32 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
|
||||
def test_update_version(self):
|
||||
"""Check that we can update versions.
|
||||
"""
|
||||
version = yield self.handler.create_version(
|
||||
self.local_user,
|
||||
{"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
|
||||
version = yield defer.ensureDeferred(
|
||||
self.handler.create_version(
|
||||
self.local_user,
|
||||
{
|
||||
"algorithm": "m.megolm_backup.v1",
|
||||
"auth_data": "first_version_auth_data",
|
||||
},
|
||||
)
|
||||
)
|
||||
self.assertEqual(version, "1")
|
||||
|
||||
res = yield self.handler.update_version(
|
||||
self.local_user,
|
||||
version,
|
||||
{
|
||||
"algorithm": "m.megolm_backup.v1",
|
||||
"auth_data": "revised_first_version_auth_data",
|
||||
"version": version,
|
||||
},
|
||||
res = yield defer.ensureDeferred(
|
||||
self.handler.update_version(
|
||||
self.local_user,
|
||||
version,
|
||||
{
|
||||
"algorithm": "m.megolm_backup.v1",
|
||||
"auth_data": "revised_first_version_auth_data",
|
||||
"version": version,
|
||||
},
|
||||
)
|
||||
)
|
||||
self.assertDictEqual(res, {})
|
||||
|
||||
# check we can retrieve it as the current version
|
||||
res = yield self.handler.get_version_info(self.local_user)
|
||||
res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
|
||||
del res["etag"]
|
||||
self.assertDictEqual(
|
||||
res,
|
||||
@@ -185,14 +203,16 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
|
||||
"""
|
||||
res = None
|
||||
try:
|
||||
yield self.handler.update_version(
|
||||
self.local_user,
|
||||
"1",
|
||||
{
|
||||
"algorithm": "m.megolm_backup.v1",
|
||||
"auth_data": "revised_first_version_auth_data",
|
||||
"version": "1",
|
||||
},
|
||||
yield defer.ensureDeferred(
|
||||
self.handler.update_version(
|
||||
self.local_user,
|
||||
"1",
|
||||
{
|
||||
"algorithm": "m.megolm_backup.v1",
|
||||
"auth_data": "revised_first_version_auth_data",
|
||||
"version": "1",
|
||||
},
|
||||
)
|
||||
)
|
||||
except errors.SynapseError as e:
|
||||
res = e.code
|
||||
@@ -202,23 +222,30 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
|
||||
def test_update_omitted_version(self):
|
||||
"""Check that the update succeeds if the version is missing from the body
|
||||
"""
|
||||
version = yield self.handler.create_version(
|
||||
self.local_user,
|
||||
{"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
|
||||
version = yield defer.ensureDeferred(
|
||||
self.handler.create_version(
|
||||
self.local_user,
|
||||
{
|
||||
"algorithm": "m.megolm_backup.v1",
|
||||
"auth_data": "first_version_auth_data",
|
||||
},
|
||||
)
|
||||
)
|
||||
self.assertEqual(version, "1")
|
||||
|
||||
yield self.handler.update_version(
|
||||
self.local_user,
|
||||
version,
|
||||
{
|
||||
"algorithm": "m.megolm_backup.v1",
|
||||
"auth_data": "revised_first_version_auth_data",
|
||||
},
|
||||
yield defer.ensureDeferred(
|
||||
self.handler.update_version(
|
||||
self.local_user,
|
||||
version,
|
||||
{
|
||||
"algorithm": "m.megolm_backup.v1",
|
||||
"auth_data": "revised_first_version_auth_data",
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# check we can retrieve it as the current version
|
||||
res = yield self.handler.get_version_info(self.local_user)
|
||||
res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
|
||||
del res["etag"] # etag is opaque, so don't test its contents
|
||||
self.assertDictEqual(
|
||||
res,
|
||||
@@ -234,22 +261,29 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
|
||||
def test_update_bad_version(self):
|
||||
"""Check that we get a 400 if the version in the body doesn't match
|
||||
"""
|
||||
version = yield self.handler.create_version(
|
||||
self.local_user,
|
||||
{"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
|
||||
version = yield defer.ensureDeferred(
|
||||
self.handler.create_version(
|
||||
self.local_user,
|
||||
{
|
||||
"algorithm": "m.megolm_backup.v1",
|
||||
"auth_data": "first_version_auth_data",
|
||||
},
|
||||
)
|
||||
)
|
||||
self.assertEqual(version, "1")
|
||||
|
||||
res = None
|
||||
try:
|
||||
yield self.handler.update_version(
|
||||
self.local_user,
|
||||
version,
|
||||
{
|
||||
"algorithm": "m.megolm_backup.v1",
|
||||
"auth_data": "revised_first_version_auth_data",
|
||||
"version": "incorrect",
|
||||
},
|
||||
yield defer.ensureDeferred(
|
||||
self.handler.update_version(
|
||||
self.local_user,
|
||||
version,
|
||||
{
|
||||
"algorithm": "m.megolm_backup.v1",
|
||||
"auth_data": "revised_first_version_auth_data",
|
||||
"version": "incorrect",
|
||||
},
|
||||
)
|
||||
)
|
||||
except errors.SynapseError as e:
|
||||
res = e.code
|
||||
@@ -261,7 +295,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
|
||||
"""
|
||||
res = None
|
||||
try:
|
||||
yield self.handler.delete_version(self.local_user, "1")
|
||||
yield defer.ensureDeferred(
|
||||
self.handler.delete_version(self.local_user, "1")
|
||||
)
|
||||
except errors.SynapseError as e:
|
||||
res = e.code
|
||||
self.assertEqual(res, 404)
|
||||
@@ -272,7 +308,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
|
||||
"""
|
||||
res = None
|
||||
try:
|
||||
yield self.handler.delete_version(self.local_user)
|
||||
yield defer.ensureDeferred(self.handler.delete_version(self.local_user))
|
||||
except errors.SynapseError as e:
|
||||
res = e.code
|
||||
self.assertEqual(res, 404)
|
||||
@@ -281,19 +317,26 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
|
||||
def test_delete_version(self):
|
||||
"""Check that we can create and then delete versions.
|
||||
"""
|
||||
res = yield self.handler.create_version(
|
||||
self.local_user,
|
||||
{"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
|
||||
res = yield defer.ensureDeferred(
|
||||
self.handler.create_version(
|
||||
self.local_user,
|
||||
{
|
||||
"algorithm": "m.megolm_backup.v1",
|
||||
"auth_data": "first_version_auth_data",
|
||||
},
|
||||
)
|
||||
)
|
||||
self.assertEqual(res, "1")
|
||||
|
||||
# check we can delete it
|
||||
yield self.handler.delete_version(self.local_user, "1")
|
||||
yield defer.ensureDeferred(self.handler.delete_version(self.local_user, "1"))
|
||||
|
||||
# check that it's gone
|
||||
res = None
|
||||
try:
|
||||
yield self.handler.get_version_info(self.local_user, "1")
|
||||
yield defer.ensureDeferred(
|
||||
self.handler.get_version_info(self.local_user, "1")
|
||||
)
|
||||
except errors.SynapseError as e:
|
||||
res = e.code
|
||||
self.assertEqual(res, 404)
|
||||
@@ -304,7 +347,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
|
||||
"""
|
||||
res = None
|
||||
try:
|
||||
yield self.handler.get_room_keys(self.local_user, "bogus_version")
|
||||
yield defer.ensureDeferred(
|
||||
self.handler.get_room_keys(self.local_user, "bogus_version")
|
||||
)
|
||||
except errors.SynapseError as e:
|
||||
res = e.code
|
||||
self.assertEqual(res, 404)
|
||||
@@ -313,13 +358,20 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
|
||||
def test_get_missing_room_keys(self):
|
||||
"""Check we get an empty response from an empty backup
|
||||
"""
|
||||
version = yield self.handler.create_version(
|
||||
self.local_user,
|
||||
{"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
|
||||
version = yield defer.ensureDeferred(
|
||||
self.handler.create_version(
|
||||
self.local_user,
|
||||
{
|
||||
"algorithm": "m.megolm_backup.v1",
|
||||
"auth_data": "first_version_auth_data",
|
||||
},
|
||||
)
|
||||
)
|
||||
self.assertEqual(version, "1")
|
||||
|
||||
res = yield self.handler.get_room_keys(self.local_user, version)
|
||||
res = yield defer.ensureDeferred(
|
||||
self.handler.get_room_keys(self.local_user, version)
|
||||
)
|
||||
self.assertDictEqual(res, {"rooms": {}})
|
||||
|
||||
# TODO: test the locking semantics when uploading room_keys,
|
||||
@@ -331,8 +383,8 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
|
||||
"""
|
||||
res = None
|
||||
try:
|
||||
yield self.handler.upload_room_keys(
|
||||
self.local_user, "no_version", room_keys
|
||||
yield defer.ensureDeferred(
|
||||
self.handler.upload_room_keys(self.local_user, "no_version", room_keys)
|
||||
)
|
||||
except errors.SynapseError as e:
|
||||
res = e.code
|
||||
@@ -343,16 +395,23 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
|
||||
"""Check that we get a 404 on uploading keys when an nonexistent version
|
||||
is specified
|
||||
"""
|
||||
version = yield self.handler.create_version(
|
||||
self.local_user,
|
||||
{"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
|
||||
version = yield defer.ensureDeferred(
|
||||
self.handler.create_version(
|
||||
self.local_user,
|
||||
{
|
||||
"algorithm": "m.megolm_backup.v1",
|
||||
"auth_data": "first_version_auth_data",
|
||||
},
|
||||
)
|
||||
)
|
||||
self.assertEqual(version, "1")
|
||||
|
||||
res = None
|
||||
try:
|
||||
yield self.handler.upload_room_keys(
|
||||
self.local_user, "bogus_version", room_keys
|
||||
yield defer.ensureDeferred(
|
||||
self.handler.upload_room_keys(
|
||||
self.local_user, "bogus_version", room_keys
|
||||
)
|
||||
)
|
||||
except errors.SynapseError as e:
|
||||
res = e.code
|
||||
@@ -362,24 +421,33 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
|
||||
def test_upload_room_keys_wrong_version(self):
|
||||
"""Check that we get a 403 on uploading keys for an old version
|
||||
"""
|
||||
version = yield self.handler.create_version(
|
||||
self.local_user,
|
||||
{"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
|
||||
version = yield defer.ensureDeferred(
|
||||
self.handler.create_version(
|
||||
self.local_user,
|
||||
{
|
||||
"algorithm": "m.megolm_backup.v1",
|
||||
"auth_data": "first_version_auth_data",
|
||||
},
|
||||
)
|
||||
)
|
||||
self.assertEqual(version, "1")
|
||||
|
||||
version = yield self.handler.create_version(
|
||||
self.local_user,
|
||||
{
|
||||
"algorithm": "m.megolm_backup.v1",
|
||||
"auth_data": "second_version_auth_data",
|
||||
},
|
||||
version = yield defer.ensureDeferred(
|
||||
self.handler.create_version(
|
||||
self.local_user,
|
||||
{
|
||||
"algorithm": "m.megolm_backup.v1",
|
||||
"auth_data": "second_version_auth_data",
|
||||
},
|
||||
)
|
||||
)
|
||||
self.assertEqual(version, "2")
|
||||
|
||||
res = None
|
||||
try:
|
||||
yield self.handler.upload_room_keys(self.local_user, "1", room_keys)
|
||||
yield defer.ensureDeferred(
|
||||
self.handler.upload_room_keys(self.local_user, "1", room_keys)
|
||||
)
|
||||
except errors.SynapseError as e:
|
||||
res = e.code
|
||||
self.assertEqual(res, 403)
|
||||
@@ -388,26 +456,39 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
|
||||
def test_upload_room_keys_insert(self):
|
||||
"""Check that we can insert and retrieve keys for a session
|
||||
"""
|
||||
version = yield self.handler.create_version(
|
||||
self.local_user,
|
||||
{"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
|
||||
version = yield defer.ensureDeferred(
|
||||
self.handler.create_version(
|
||||
self.local_user,
|
||||
{
|
||||
"algorithm": "m.megolm_backup.v1",
|
||||
"auth_data": "first_version_auth_data",
|
||||
},
|
||||
)
|
||||
)
|
||||
self.assertEqual(version, "1")
|
||||
|
||||
yield self.handler.upload_room_keys(self.local_user, version, room_keys)
|
||||
yield defer.ensureDeferred(
|
||||
self.handler.upload_room_keys(self.local_user, version, room_keys)
|
||||
)
|
||||
|
||||
res = yield self.handler.get_room_keys(self.local_user, version)
|
||||
res = yield defer.ensureDeferred(
|
||||
self.handler.get_room_keys(self.local_user, version)
|
||||
)
|
||||
self.assertDictEqual(res, room_keys)
|
||||
|
||||
# check getting room_keys for a given room
|
||||
res = yield self.handler.get_room_keys(
|
||||
self.local_user, version, room_id="!abc:matrix.org"
|
||||
res = yield defer.ensureDeferred(
|
||||
self.handler.get_room_keys(
|
||||
self.local_user, version, room_id="!abc:matrix.org"
|
||||
)
|
||||
)
|
||||
self.assertDictEqual(res, room_keys)
|
||||
|
||||
# check getting room_keys for a given session_id
|
||||
res = yield self.handler.get_room_keys(
|
||||
self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
|
||||
res = yield defer.ensureDeferred(
|
||||
self.handler.get_room_keys(
|
||||
self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
|
||||
)
|
||||
)
|
||||
self.assertDictEqual(res, room_keys)
|
||||
|
||||
@@ -415,16 +496,23 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
|
||||
def test_upload_room_keys_merge(self):
|
||||
"""Check that we can upload a new room_key for an existing session and
|
||||
have it correctly merged"""
|
||||
version = yield self.handler.create_version(
|
||||
self.local_user,
|
||||
{"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
|
||||
version = yield defer.ensureDeferred(
|
||||
self.handler.create_version(
|
||||
self.local_user,
|
||||
{
|
||||
"algorithm": "m.megolm_backup.v1",
|
||||
"auth_data": "first_version_auth_data",
|
||||
},
|
||||
)
|
||||
)
|
||||
self.assertEqual(version, "1")
|
||||
|
||||
yield self.handler.upload_room_keys(self.local_user, version, room_keys)
|
||||
yield defer.ensureDeferred(
|
||||
self.handler.upload_room_keys(self.local_user, version, room_keys)
|
||||
)
|
||||
|
||||
# get the etag to compare to future versions
|
||||
res = yield self.handler.get_version_info(self.local_user)
|
||||
res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
|
||||
backup_etag = res["etag"]
|
||||
self.assertEqual(res["count"], 1)
|
||||
|
||||
@@ -434,29 +522,37 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
|
||||
# test that increasing the message_index doesn't replace the existing session
|
||||
new_room_key["first_message_index"] = 2
|
||||
new_room_key["session_data"] = "new"
|
||||
yield self.handler.upload_room_keys(self.local_user, version, new_room_keys)
|
||||
yield defer.ensureDeferred(
|
||||
self.handler.upload_room_keys(self.local_user, version, new_room_keys)
|
||||
)
|
||||
|
||||
res = yield self.handler.get_room_keys(self.local_user, version)
|
||||
res = yield defer.ensureDeferred(
|
||||
self.handler.get_room_keys(self.local_user, version)
|
||||
)
|
||||
self.assertEqual(
|
||||
res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"],
|
||||
"SSBBTSBBIEZJU0gK",
|
||||
)
|
||||
|
||||
# the etag should be the same since the session did not change
|
||||
res = yield self.handler.get_version_info(self.local_user)
|
||||
res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
|
||||
self.assertEqual(res["etag"], backup_etag)
|
||||
|
||||
# test that marking the session as verified however /does/ replace it
|
||||
new_room_key["is_verified"] = True
|
||||
yield self.handler.upload_room_keys(self.local_user, version, new_room_keys)
|
||||
yield defer.ensureDeferred(
|
||||
self.handler.upload_room_keys(self.local_user, version, new_room_keys)
|
||||
)
|
||||
|
||||
res = yield self.handler.get_room_keys(self.local_user, version)
|
||||
res = yield defer.ensureDeferred(
|
||||
self.handler.get_room_keys(self.local_user, version)
|
||||
)
|
||||
self.assertEqual(
|
||||
res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new"
|
||||
)
|
||||
|
||||
# the etag should NOT be equal now, since the key changed
|
||||
res = yield self.handler.get_version_info(self.local_user)
|
||||
res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
|
||||
self.assertNotEqual(res["etag"], backup_etag)
|
||||
backup_etag = res["etag"]
|
||||
|
||||
@@ -464,15 +560,19 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
|
||||
# with a lower forwarding count
|
||||
new_room_key["forwarded_count"] = 2
|
||||
new_room_key["session_data"] = "other"
|
||||
yield self.handler.upload_room_keys(self.local_user, version, new_room_keys)
|
||||
yield defer.ensureDeferred(
|
||||
self.handler.upload_room_keys(self.local_user, version, new_room_keys)
|
||||
)
|
||||
|
||||
res = yield self.handler.get_room_keys(self.local_user, version)
|
||||
res = yield defer.ensureDeferred(
|
||||
self.handler.get_room_keys(self.local_user, version)
|
||||
)
|
||||
self.assertEqual(
|
||||
res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new"
|
||||
)
|
||||
|
||||
# the etag should be the same since the session did not change
|
||||
res = yield self.handler.get_version_info(self.local_user)
|
||||
res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
|
||||
self.assertEqual(res["etag"], backup_etag)
|
||||
|
||||
# TODO: check edge cases as well as the common variations here
|
||||
@@ -481,36 +581,59 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
|
||||
def test_delete_room_keys(self):
|
||||
"""Check that we can insert and delete keys for a session
|
||||
"""
|
||||
version = yield self.handler.create_version(
|
||||
self.local_user,
|
||||
{"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
|
||||
version = yield defer.ensureDeferred(
|
||||
self.handler.create_version(
|
||||
self.local_user,
|
||||
{
|
||||
"algorithm": "m.megolm_backup.v1",
|
||||
"auth_data": "first_version_auth_data",
|
||||
},
|
||||
)
|
||||
)
|
||||
self.assertEqual(version, "1")
|
||||
|
||||
# check for bulk-delete
|
||||
yield self.handler.upload_room_keys(self.local_user, version, room_keys)
|
||||
yield self.handler.delete_room_keys(self.local_user, version)
|
||||
res = yield self.handler.get_room_keys(
|
||||
self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
|
||||
yield defer.ensureDeferred(
|
||||
self.handler.upload_room_keys(self.local_user, version, room_keys)
|
||||
)
|
||||
yield defer.ensureDeferred(
|
||||
self.handler.delete_room_keys(self.local_user, version)
|
||||
)
|
||||
res = yield defer.ensureDeferred(
|
||||
self.handler.get_room_keys(
|
||||
self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
|
||||
)
|
||||
)
|
||||
self.assertDictEqual(res, {"rooms": {}})
|
||||
|
||||
# check for bulk-delete per room
|
||||
yield self.handler.upload_room_keys(self.local_user, version, room_keys)
|
||||
yield self.handler.delete_room_keys(
|
||||
self.local_user, version, room_id="!abc:matrix.org"
|
||||
yield defer.ensureDeferred(
|
||||
self.handler.upload_room_keys(self.local_user, version, room_keys)
|
||||
)
|
||||
res = yield self.handler.get_room_keys(
|
||||
self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
|
||||
yield defer.ensureDeferred(
|
||||
self.handler.delete_room_keys(
|
||||
self.local_user, version, room_id="!abc:matrix.org"
|
||||
)
|
||||
)
|
||||
res = yield defer.ensureDeferred(
|
||||
self.handler.get_room_keys(
|
||||
self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
|
||||
)
|
||||
)
|
||||
self.assertDictEqual(res, {"rooms": {}})
|
||||
|
||||
# check for bulk-delete per session
|
||||
yield self.handler.upload_room_keys(self.local_user, version, room_keys)
|
||||
yield self.handler.delete_room_keys(
|
||||
self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
|
||||
yield defer.ensureDeferred(
|
||||
self.handler.upload_room_keys(self.local_user, version, room_keys)
|
||||
)
|
||||
res = yield self.handler.get_room_keys(
|
||||
self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
|
||||
yield defer.ensureDeferred(
|
||||
self.handler.delete_room_keys(
|
||||
self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
|
||||
)
|
||||
)
|
||||
res = yield defer.ensureDeferred(
|
||||
self.handler.get_room_keys(
|
||||
self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
|
||||
)
|
||||
)
|
||||
self.assertDictEqual(res, {"rooms": {}})
|
||||
|
||||
@@ -138,10 +138,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):

self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room

def get_current_users_in_room(room_id):
def get_users_in_room(room_id):
return defer.succeed({str(u) for u in self.room_members})

hs.get_state_handler().get_current_users_in_room = get_current_users_in_room
self.datastore.get_users_in_room = get_users_in_room

self.datastore.get_user_directory_stream_pos.return_value = (
# we deliberately return a non-None stream pos to avoid doing an initial_spam

@@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import Any, List, Optional, Tuple
|
||||
from typing import Any, Callable, List, Optional, Tuple
|
||||
|
||||
import attr
|
||||
|
||||
@@ -26,8 +26,9 @@ from synapse.app.generic_worker import (
|
||||
GenericWorkerReplicationHandler,
|
||||
GenericWorkerServer,
|
||||
)
|
||||
from synapse.http.server import JsonResource
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.replication.http import streams
|
||||
from synapse.replication.http import ReplicationRestResource, streams
|
||||
from synapse.replication.tcp.handler import ReplicationCommandHandler
|
||||
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
|
||||
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
|
||||
@@ -35,7 +36,7 @@ from synapse.server import HomeServer
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
from tests.server import FakeTransport
|
||||
from tests.server import FakeTransport, render
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -180,6 +181,159 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEqual(request.method, b"GET")
|
||||
|
||||
|
||||
class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
"""Base class for tests running multiple workers.

Automatically handle HTTP replication requests from workers to master,
unlike `BaseStreamTestCase`.
"""

servlets = [] # type: List[Callable[[HomeServer, JsonResource], None]]

def setUp(self):
super().setUp()
|
||||
|
||||
# build a replication server
|
||||
self.server_factory = ReplicationStreamProtocolFactory(self.hs)
|
||||
self.streamer = self.hs.get_replication_streamer()
|
||||
|
||||
store = self.hs.get_datastore()
|
||||
self.database = store.db
|
||||
|
||||
self.reactor.lookups["testserv"] = "1.2.3.4"
|
||||
|
||||
self._worker_hs_to_resource = {}
|
||||
|
||||
# When we see a connection attempt to the master replication listener we
|
||||
# automatically set up the connection. This is so that tests don't
|
||||
# manually have to go and explicitly set it up each time (plus sometimes
|
||||
# it is impossible to write the handling explicitly in the tests).
|
||||
self.reactor.add_tcp_client_callback(
|
||||
"1.2.3.4", 8765, self._handle_http_replication_attempt
|
||||
)
|
||||
|
||||
def create_test_json_resource(self):
|
||||
"""Overrides `HomeserverTestCase.create_test_json_resource`.
|
||||
"""
|
||||
# We override this so that it automatically registers all the HTTP
|
||||
# replication servlets, without having to explicitly do that in all
|
||||
# subclasses.
|
||||
|
||||
resource = ReplicationRestResource(self.hs)
|
||||
|
||||
for servlet in self.servlets:
|
||||
servlet(self.hs, resource)
|
||||
|
||||
return resource
|
||||
|
||||
def make_worker_hs(
|
||||
self, worker_app: str, extra_config: dict = {}, **kwargs
|
||||
) -> HomeServer:
|
||||
"""Make a new worker HS instance, correctly connecting replcation
|
||||
stream to the master HS.
|
||||
|
||||
Args:
|
||||
worker_app: Type of worker, e.g. `synapse.app.federation_sender`.
|
||||
extra_config: Any extra config to use for this instance.
|
||||
**kwargs: Options that get passed to `self.setup_test_homeserver`,
|
||||
useful to e.g. pass some mocks for things like `http_client`
|
||||
|
||||
Returns:
|
||||
The new worker HomeServer instance.
|
||||
"""
|
||||
|
||||
config = self._get_worker_hs_config()
|
||||
config["worker_app"] = worker_app
|
||||
config.update(extra_config)
|
||||
|
||||
worker_hs = self.setup_test_homeserver(
|
||||
homeserverToUse=GenericWorkerServer,
|
||||
config=config,
|
||||
reactor=self.reactor,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
store = worker_hs.get_datastore()
|
||||
store.db._db_pool = self.database._db_pool
|
||||
|
||||
repl_handler = ReplicationCommandHandler(worker_hs)
|
||||
client = ClientReplicationStreamProtocol(
|
||||
worker_hs, "client", "test", self.clock, repl_handler,
|
||||
)
|
||||
server = self.server_factory.buildProtocol(None)
|
||||
|
||||
client_transport = FakeTransport(server, self.reactor)
|
||||
client.makeConnection(client_transport)
|
||||
|
||||
server_transport = FakeTransport(client, self.reactor)
|
||||
server.makeConnection(server_transport)
|
||||
|
||||
# Set up a resource for the worker
|
||||
resource = ReplicationRestResource(self.hs)
|
||||
|
||||
for servlet in self.servlets:
|
||||
servlet(worker_hs, resource)
|
||||
|
||||
self._worker_hs_to_resource[worker_hs] = resource
|
||||
|
||||
return worker_hs
|
||||
|
||||
def _get_worker_hs_config(self) -> dict:
|
||||
config = self.default_config()
|
||||
config["worker_replication_host"] = "testserv"
|
||||
config["worker_replication_http_port"] = "8765"
|
||||
return config
|
||||
|
||||
def render_on_worker(self, worker_hs: HomeServer, request: SynapseRequest):
|
||||
render(request, self._worker_hs_to_resource[worker_hs], self.reactor)
|
||||
|
||||
def replicate(self):
|
||||
"""Tell the master side of replication that something has happened, and then
|
||||
wait for the replication to occur.
|
||||
"""
|
||||
self.streamer.on_notifier_poke()
|
||||
self.pump()
|
||||
|
||||
def _handle_http_replication_attempt(self):
"""Handles a connection attempt to the master replication HTTP
listener.
"""

# We should have at least one outbound connection attempt, where the
# last is one to the HTTP replication IP/port.
clients = self.reactor.tcpClients
self.assertGreaterEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients.pop()
self.assertEqual(host, "1.2.3.4")
self.assertEqual(port, 8765)

# Set up client side protocol
|
||||
client_protocol = client_factory.buildProtocol(None)
|
||||
|
||||
request_factory = OneShotRequestFactory()
|
||||
|
||||
# Set up the server side protocol
|
||||
channel = _PushHTTPChannel(self.reactor)
|
||||
channel.requestFactory = request_factory
|
||||
channel.site = self.site
|
||||
|
||||
# Connect client to server and vice versa.
|
||||
client_to_server_transport = FakeTransport(
|
||||
channel, self.reactor, client_protocol
|
||||
)
|
||||
client_protocol.makeConnection(client_to_server_transport)
|
||||
|
||||
server_to_client_transport = FakeTransport(
|
||||
client_protocol, self.reactor, channel
|
||||
)
|
||||
channel.makeConnection(server_to_client_transport)
|
||||
|
||||
# Note: at this point we've wired everything up, but we need to return
# before the data starts flowing over the connections, as this is called
# inside `connectTCP` before the connection has been passed back to the
# code that requested the TCP connection.
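For context, a subclass of this new base class (such as the reworked ClientReaderTestCase further down) typically just declares its servlets, stands up workers with `make_worker_hs`, and routes requests with `render_on_worker`. Roughly, as a sketch rather than part of the diff:

    class ExampleWorkerTestCase(BaseMultiWorkerStreamTestCase):
        servlets = [register.register_servlets]

        def test_hits_worker(self):
            # Stand up a worker homeserver of the desired type; replication
            # to the master is wired up automatically by the base class.
            worker_hs = self.make_worker_hs("synapse.app.client_reader")

            # Render the request against that worker's resource rather than
            # against the master homeserver.
            request, channel = self.make_request(
                "POST",
                "register",
                {"username": "user", "type": "m.login.password", "password": "bar"},
            )
            self.render_on_worker(worker_hs, request)
            self.assertEqual(request.code, 401)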
|
||||
|
||||
|
||||
class TestReplicationDataHandler(GenericWorkerReplicationHandler):
|
||||
"""Drop-in for ReplicationDataHandler which just collects RDATA rows"""
|
||||
|
||||
@@ -241,6 +395,14 @@ class _PushHTTPChannel(HTTPChannel):
|
||||
# We need to manually stop the _PullToPushProducer.
|
||||
self._pull_to_push_producer.stop()
|
||||
|
||||
def checkPersistence(self, request, version):
|
||||
"""Check whether the connection can be re-used
|
||||
"""
|
||||
# We hijack this to always say no for ease of wiring stuff up in
|
||||
# `handle_http_replication_attempt`.
|
||||
request.responseHeaders.setRawHeaders(b"connection", [b"close"])
|
||||
return False
|
||||
|
||||
|
||||
class _PullToPushProducer:
|
||||
"""A push producer that wraps a pull producer.
|
||||
|
||||
@@ -15,63 +15,26 @@
|
||||
import logging
|
||||
|
||||
from synapse.api.constants import LoginType
|
||||
from synapse.app.generic_worker import GenericWorkerServer
|
||||
from synapse.http.server import JsonResource
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
|
||||
from synapse.rest.client.v2_alpha import register
|
||||
|
||||
from tests import unittest
|
||||
from tests.replication._base import BaseMultiWorkerStreamTestCase
|
||||
from tests.rest.client.v2_alpha.test_auth import DummyRecaptchaChecker
|
||||
from tests.server import FakeChannel, render
|
||||
from tests.server import FakeChannel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ClientReaderTestCase(unittest.HomeserverTestCase):
|
||||
class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
|
||||
"""Base class for tests of the replication streams"""
|
||||
|
||||
servlets = [
|
||||
register.register_servlets,
|
||||
]
|
||||
servlets = [register.register_servlets]
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
# build a replication server
|
||||
self.server_factory = ReplicationStreamProtocolFactory(hs)
|
||||
self.streamer = hs.get_replication_streamer()
|
||||
|
||||
store = hs.get_datastore()
|
||||
self.database = store.db
|
||||
|
        self.recaptcha_checker = DummyRecaptchaChecker(hs)
        auth_handler = hs.get_auth_handler()
        auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker

        self.reactor.lookups["testserv"] = "1.2.3.4"

    def make_worker_hs(self, extra_config={}):
        config = self._get_worker_hs_config()
        config.update(extra_config)

        worker_hs = self.setup_test_homeserver(
            homeserverToUse=GenericWorkerServer, config=config, reactor=self.reactor,
        )

        store = worker_hs.get_datastore()
        store.db._db_pool = self.database._db_pool

        # Register the expected servlets, essentially this is HomeserverTestCase.create_test_json_resource.
        resource = JsonResource(self.hs)

        for servlet in self.servlets:
            servlet(worker_hs, resource)

        # Essentially HomeserverTestCase.render.
        def _render(request):
            render(request, self.resource, self.reactor)

        return worker_hs, _render

    def _get_worker_hs_config(self) -> dict:
        config = self.default_config()
        config["worker_app"] = "synapse.app.client_reader"
@@ -82,14 +45,14 @@ class ClientReaderTestCase(unittest.HomeserverTestCase):
    def test_register_single_worker(self):
        """Test that registration works when using a single client reader worker.
        """
        _, worker_render = self.make_worker_hs()
        worker_hs = self.make_worker_hs("synapse.app.client_reader")

        request_1, channel_1 = self.make_request(
            "POST",
            "register",
            {"username": "user", "type": "m.login.password", "password": "bar"},
        )  # type: SynapseRequest, FakeChannel
        worker_render(request_1)
        self.render_on_worker(worker_hs, request_1)
        self.assertEqual(request_1.code, 401)

        # Grab the session
@@ -99,7 +62,7 @@ class ClientReaderTestCase(unittest.HomeserverTestCase):
        request_2, channel_2 = self.make_request(
            "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}}
        )  # type: SynapseRequest, FakeChannel
        worker_render(request_2)
        self.render_on_worker(worker_hs, request_2)
        self.assertEqual(request_2.code, 200)

        # We're given a registered user.
@@ -108,15 +71,15 @@ class ClientReaderTestCase(unittest.HomeserverTestCase):
    def test_register_multi_worker(self):
        """Test that registration works when using multiple client reader workers.
        """
        _, worker_render_1 = self.make_worker_hs()
        _, worker_render_2 = self.make_worker_hs()
        worker_hs_1 = self.make_worker_hs("synapse.app.client_reader")
        worker_hs_2 = self.make_worker_hs("synapse.app.client_reader")

        request_1, channel_1 = self.make_request(
            "POST",
            "register",
            {"username": "user", "type": "m.login.password", "password": "bar"},
        )  # type: SynapseRequest, FakeChannel
        worker_render_1(request_1)
        self.render_on_worker(worker_hs_1, request_1)
        self.assertEqual(request_1.code, 401)

        # Grab the session
@@ -126,7 +89,7 @@ class ClientReaderTestCase(unittest.HomeserverTestCase):
        request_2, channel_2 = self.make_request(
            "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}}
        )  # type: SynapseRequest, FakeChannel
        worker_render_2(request_2)
        self.render_on_worker(worker_hs_2, request_2)
        self.assertEqual(request_2.code, 200)

        # We're given a registered user.

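Note: the updated tests above call `self.render_on_worker(worker_hs, request)`, a helper provided by the shared multi-worker test base class rather than the per-test `_render` closure removed here. Its implementation is not part of this diff; a minimal sketch of what such a helper does, modelled on the `JsonResource`/`_render` code it replaces (method body assumed, not necessarily Synapse's actual implementation):

    # Sketch only: approximation of a render_on_worker helper, based on the
    # JsonResource/_render code deleted above. Not the real implementation.
    def render_on_worker(self, worker_hs, request):
        # Build a JsonResource for this worker and register the declared servlets on it.
        resource = JsonResource(worker_hs)
        for servlet in self.servlets:
            servlet(worker_hs, resource)
        # Drive the request through the worker's resource on the shared fake reactor.
        render(request, resource, self.reactor)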
@@ -19,132 +19,40 @@ from mock import Mock
from twisted.internet import defer

from synapse.api.constants import EventTypes, Membership
from synapse.app.generic_worker import GenericWorkerServer
from synapse.events.builder import EventBuilderFactory
from synapse.replication.http import streams
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
from synapse.rest.admin import register_servlets_for_client_rest_resource
from synapse.rest.client.v1 import login, room
from synapse.types import UserID

from tests import unittest
from tests.server import FakeTransport
from tests.replication._base import BaseMultiWorkerStreamTestCase

logger = logging.getLogger(__name__)


class BaseStreamTestCase(unittest.HomeserverTestCase):
    """Base class for tests of the replication streams"""

    servlets = [
        streams.register_servlets,
    ]

    def prepare(self, reactor, clock, hs):
        # build a replication server
        self.server_factory = ReplicationStreamProtocolFactory(hs)
        self.streamer = hs.get_replication_streamer()

        store = hs.get_datastore()
        self.database = store.db

        self.reactor.lookups["testserv"] = "1.2.3.4"

    def default_config(self):
        conf = super().default_config()
        conf["send_federation"] = False
        return conf

    def make_worker_hs(self, extra_config={}):
        config = self._get_worker_hs_config()
        config.update(extra_config)

        mock_federation_client = Mock(spec=["put_json"])
        mock_federation_client.put_json.side_effect = lambda *_, **__: defer.succeed({})

        worker_hs = self.setup_test_homeserver(
            http_client=mock_federation_client,
            homeserverToUse=GenericWorkerServer,
            config=config,
            reactor=self.reactor,
        )

        store = worker_hs.get_datastore()
        store.db._db_pool = self.database._db_pool

        repl_handler = ReplicationCommandHandler(worker_hs)
        client = ClientReplicationStreamProtocol(
            worker_hs, "client", "test", self.clock, repl_handler,
        )
        server = self.server_factory.buildProtocol(None)

        client_transport = FakeTransport(server, self.reactor)
        client.makeConnection(client_transport)

        server_transport = FakeTransport(client, self.reactor)
        server.makeConnection(server_transport)

        return worker_hs

    def _get_worker_hs_config(self) -> dict:
        config = self.default_config()
        config["worker_app"] = "synapse.app.federation_sender"
        config["worker_replication_host"] = "testserv"
        config["worker_replication_http_port"] = "8765"
        return config

    def replicate(self):
        """Tell the master side of replication that something has happened, and then
        wait for the replication to occur.
        """
        self.streamer.on_notifier_poke()
        self.pump()

    def create_room_with_remote_server(self, user, token, remote_server="other_server"):
        room = self.helper.create_room_as(user, tok=token)
        store = self.hs.get_datastore()
        federation = self.hs.get_handlers().federation_handler

        prev_event_ids = self.get_success(store.get_latest_event_ids_in_room(room))
        room_version = self.get_success(store.get_room_version(room))

        factory = EventBuilderFactory(self.hs)
        factory.hostname = remote_server

        user_id = UserID("user", remote_server).to_string()

        event_dict = {
            "type": EventTypes.Member,
            "state_key": user_id,
            "content": {"membership": Membership.JOIN},
            "sender": user_id,
            "room_id": room,
        }

        builder = factory.for_room_version(room_version, event_dict)
        join_event = self.get_success(builder.build(prev_event_ids))

        self.get_success(federation.on_send_join_request(remote_server, join_event))
        self.replicate()

        return room


class FederationSenderTestCase(BaseStreamTestCase):
class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
    servlets = [
        login.register_servlets,
        register_servlets_for_client_rest_resource,
        room.register_servlets,
    ]

    def default_config(self):
        conf = super().default_config()
        conf["send_federation"] = False
        return conf

    def test_send_event_single_sender(self):
        """Test that using a single federation sender worker correctly sends a
        new event.
        """
        worker_hs = self.make_worker_hs({"send_federation": True})
        mock_client = worker_hs.get_http_client()
        mock_client = Mock(spec=["put_json"])
        mock_client.put_json.side_effect = lambda *_, **__: defer.succeed({})

        self.make_worker_hs(
            "synapse.app.federation_sender",
            {"send_federation": True},
            http_client=mock_client,
        )

        user = self.register_user("user", "pass")
        token = self.login("user", "pass")
@@ -165,23 +73,29 @@ class FederationSenderTestCase(BaseStreamTestCase):
        """Test that using two federation sender workers correctly sends
        new events.
        """
        worker1 = self.make_worker_hs(
        mock_client1 = Mock(spec=["put_json"])
        mock_client1.put_json.side_effect = lambda *_, **__: defer.succeed({})
        self.make_worker_hs(
            "synapse.app.federation_sender",
            {
                "send_federation": True,
                "worker_name": "sender1",
                "federation_sender_instances": ["sender1", "sender2"],
            }
            },
            http_client=mock_client1,
        )
        mock_client1 = worker1.get_http_client()

        worker2 = self.make_worker_hs(
        mock_client2 = Mock(spec=["put_json"])
        mock_client2.put_json.side_effect = lambda *_, **__: defer.succeed({})
        self.make_worker_hs(
            "synapse.app.federation_sender",
            {
                "send_federation": True,
                "worker_name": "sender2",
                "federation_sender_instances": ["sender1", "sender2"],
            }
            },
            http_client=mock_client2,
        )
        mock_client2 = worker2.get_http_client()

        user = self.register_user("user2", "pass")
        token = self.login("user2", "pass")
@@ -191,8 +105,8 @@ class FederationSenderTestCase(BaseStreamTestCase):
        for i in range(20):
            server_name = "other_server_%d" % (i,)
            room = self.create_room_with_remote_server(user, token, server_name)
            mock_client1.reset_mock()
            mock_client2.reset_mock()
            mock_client1.reset_mock()  # type: ignore[attr-defined]
            mock_client2.reset_mock()  # type: ignore[attr-defined]

            self.create_and_send_event(room, UserID.from_string(user))
            self.replicate()
@@ -222,23 +136,29 @@ class FederationSenderTestCase(BaseStreamTestCase):
        """Test that using two federation sender workers correctly sends
        new typing EDUs.
        """
        worker1 = self.make_worker_hs(
        mock_client1 = Mock(spec=["put_json"])
        mock_client1.put_json.side_effect = lambda *_, **__: defer.succeed({})
        self.make_worker_hs(
            "synapse.app.federation_sender",
            {
                "send_federation": True,
                "worker_name": "sender1",
                "federation_sender_instances": ["sender1", "sender2"],
            }
            },
            http_client=mock_client1,
        )
        mock_client1 = worker1.get_http_client()

        worker2 = self.make_worker_hs(
        mock_client2 = Mock(spec=["put_json"])
        mock_client2.put_json.side_effect = lambda *_, **__: defer.succeed({})
        self.make_worker_hs(
            "synapse.app.federation_sender",
            {
                "send_federation": True,
                "worker_name": "sender2",
                "federation_sender_instances": ["sender1", "sender2"],
            }
            },
            http_client=mock_client2,
        )
        mock_client2 = worker2.get_http_client()

        user = self.register_user("user3", "pass")
        token = self.login("user3", "pass")
@@ -250,8 +170,8 @@ class FederationSenderTestCase(BaseStreamTestCase):
        for i in range(20):
            server_name = "other_server_%d" % (i,)
            room = self.create_room_with_remote_server(user, token, server_name)
            mock_client1.reset_mock()
            mock_client2.reset_mock()
            mock_client1.reset_mock()  # type: ignore[attr-defined]
            mock_client2.reset_mock()  # type: ignore[attr-defined]

            self.get_success(
                typing_handler.started_typing(
@@ -284,3 +204,32 @@ class FederationSenderTestCase(BaseStreamTestCase):

        self.assertTrue(sent_on_1)
        self.assertTrue(sent_on_2)

    def create_room_with_remote_server(self, user, token, remote_server="other_server"):
        room = self.helper.create_room_as(user, tok=token)
        store = self.hs.get_datastore()
        federation = self.hs.get_handlers().federation_handler

        prev_event_ids = self.get_success(store.get_latest_event_ids_in_room(room))
        room_version = self.get_success(store.get_room_version(room))

        factory = EventBuilderFactory(self.hs)
        factory.hostname = remote_server

        user_id = UserID("user", remote_server).to_string()

        event_dict = {
            "type": EventTypes.Member,
            "state_key": user_id,
            "content": {"membership": Membership.JOIN},
            "sender": user_id,
            "room_id": room,
        }

        builder = factory.for_room_version(room_version, event_dict)
        join_event = self.get_success(builder.build(prev_event_ids))

        self.get_success(federation.on_send_join_request(remote_server, join_event))
        self.replicate()

        return room

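The two-sender tests above rely on `federation_sender_instances` naming the sender workers, and assert that for any given remote server the outgoing traffic is put through exactly one of the two mocked HTTP clients. As an illustration of the idea only (not necessarily the exact scheme Synapse uses internally), destinations can be pinned to a sender instance with a stable hash:

    # Illustration only: stable assignment of a destination to one of the configured
    # federation sender instances. The real sharding logic may differ.
    import hashlib

    def choose_sender_instance(destination, instances):
        digest = hashlib.sha256(destination.encode("utf-8")).digest()
        return instances[int.from_bytes(digest[:4], "big") % len(instances)]

    # choose_sender_instance("other_server_3", ["sender1", "sender2"]) always returns
    # the same instance for the same destination, so each remote server is handled
    # by exactly one sender.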
@@ -857,6 +857,53 @@ class UserRestTestCase(unittest.HomeserverTestCase):
        self.assertEqual("@user:test", channel.json_body["name"])
        self.assertEqual(True, channel.json_body["deactivated"])

    def test_reactivate_user(self):
        """
        Test reactivating another user.
        """

        # Deactivate the user.
        request, channel = self.make_request(
            "PUT",
            self.url_other_user,
            access_token=self.admin_user_tok,
            content=json.dumps({"deactivated": True}).encode(encoding="utf_8"),
        )
        self.render(request)
        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])

        # Attempt to reactivate the user (without a password).
        request, channel = self.make_request(
            "PUT",
            self.url_other_user,
            access_token=self.admin_user_tok,
            content=json.dumps({"deactivated": False}).encode(encoding="utf_8"),
        )
        self.render(request)
        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])

        # Reactivate the user.
        request, channel = self.make_request(
            "PUT",
            self.url_other_user,
            access_token=self.admin_user_tok,
            content=json.dumps({"deactivated": False, "password": "foo"}).encode(
                encoding="utf_8"
            ),
        )
        self.render(request)
        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])

        # Get user
        request, channel = self.make_request(
            "GET", self.url_other_user, access_token=self.admin_user_tok,
        )
        self.render(request)

        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
        self.assertEqual("@user:test", channel.json_body["name"])
        self.assertEqual(False, channel.json_body["deactivated"])

    def test_set_user_as_admin(self):
        """
        Test setting the admin flag on a user.

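The reactivation test above exercises the flow end to end: deactivate a user, observe that clearing `deactivated` without supplying a password is rejected with 400, then reactivate by also sending a new password. A rough equivalent against a running server (sketch; the endpoint path, base URL, and token are assumptions, since `self.url_other_user` is defined outside this hunk):

    # Sketch: driving the same admin-API flow with the requests library.
    import requests

    url = "http://localhost:8008/_synapse/admin/v2/users/@user:test"  # assumed path
    headers = {"Authorization": "Bearer ADMIN_ACCESS_TOKEN"}  # placeholder token

    requests.put(url, headers=headers, json={"deactivated": True})   # deactivate, expect 200
    requests.put(url, headers=headers, json={"deactivated": False})  # no password, expect 400
    requests.put(url, headers=headers, json={"deactivated": False, "password": "foo"})  # expect 200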
@@ -547,8 +547,8 @@ class JWTTestCase(unittest.HomeserverTestCase):

    def test_login_jwt_invalid_signature(self):
        channel = self.jwt_login({"sub": "frog"}, "notsecret")
        self.assertEqual(channel.result["code"], b"401", channel.result)
        self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
        self.assertEqual(channel.result["code"], b"403", channel.result)
        self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
        self.assertEqual(
            channel.json_body["error"],
            "JWT validation failed: Signature verification failed",
@@ -556,8 +556,8 @@ class JWTTestCase(unittest.HomeserverTestCase):

    def test_login_jwt_expired(self):
        channel = self.jwt_login({"sub": "frog", "exp": 864000})
        self.assertEqual(channel.result["code"], b"401", channel.result)
        self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
        self.assertEqual(channel.result["code"], b"403", channel.result)
        self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
        self.assertEqual(
            channel.json_body["error"], "JWT validation failed: Signature has expired"
        )
@@ -565,8 +565,8 @@ class JWTTestCase(unittest.HomeserverTestCase):
    def test_login_jwt_not_before(self):
        now = int(time.time())
        channel = self.jwt_login({"sub": "frog", "nbf": now + 3600})
        self.assertEqual(channel.result["code"], b"401", channel.result)
        self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
        self.assertEqual(channel.result["code"], b"403", channel.result)
        self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
        self.assertEqual(
            channel.json_body["error"],
            "JWT validation failed: The token is not yet valid (nbf)",
@@ -574,8 +574,8 @@ class JWTTestCase(unittest.HomeserverTestCase):

    def test_login_no_sub(self):
        channel = self.jwt_login({"username": "root"})
        self.assertEqual(channel.result["code"], b"401", channel.result)
        self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
        self.assertEqual(channel.result["code"], b"403", channel.result)
        self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
        self.assertEqual(channel.json_body["error"], "Invalid JWT")

    @override_config(
@@ -597,16 +597,16 @@ class JWTTestCase(unittest.HomeserverTestCase):

        # An invalid issuer.
        channel = self.jwt_login({"sub": "kermit", "iss": "invalid"})
        self.assertEqual(channel.result["code"], b"401", channel.result)
        self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
        self.assertEqual(channel.result["code"], b"403", channel.result)
        self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
        self.assertEqual(
            channel.json_body["error"], "JWT validation failed: Invalid issuer"
        )

        # Not providing an issuer.
        channel = self.jwt_login({"sub": "kermit"})
        self.assertEqual(channel.result["code"], b"401", channel.result)
        self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
        self.assertEqual(channel.result["code"], b"403", channel.result)
        self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
        self.assertEqual(
            channel.json_body["error"],
            'JWT validation failed: Token is missing the "iss" claim',
@@ -637,16 +637,16 @@ class JWTTestCase(unittest.HomeserverTestCase):

        # An invalid audience.
        channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
        self.assertEqual(channel.result["code"], b"401", channel.result)
        self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
        self.assertEqual(channel.result["code"], b"403", channel.result)
        self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
        self.assertEqual(
            channel.json_body["error"], "JWT validation failed: Invalid audience"
        )

        # Not providing an audience.
        channel = self.jwt_login({"sub": "kermit"})
        self.assertEqual(channel.result["code"], b"401", channel.result)
        self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
        self.assertEqual(channel.result["code"], b"403", channel.result)
        self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
        self.assertEqual(
            channel.json_body["error"],
            'JWT validation failed: Token is missing the "aud" claim',
@@ -655,7 +655,8 @@ class JWTTestCase(unittest.HomeserverTestCase):
    def test_login_aud_no_config(self):
        """Test providing an audience without requiring it in the configuration."""
        channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
        self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
        self.assertEqual(channel.result["code"], b"403", channel.result)
        self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
        self.assertEqual(
            channel.json_body["error"], "JWT validation failed: Invalid audience"
        )
@@ -664,8 +665,8 @@ class JWTTestCase(unittest.HomeserverTestCase):
        params = json.dumps({"type": "org.matrix.login.jwt"})
        request, channel = self.make_request(b"POST", LOGIN_URL, params)
        self.render(request)
        self.assertEqual(channel.result["code"], b"401", channel.result)
        self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
        self.assertEqual(channel.result["code"], b"403", channel.result)
        self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
        self.assertEqual(channel.json_body["error"], "Token field for JWT is missing")


@@ -747,8 +748,8 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):

    def test_login_jwt_invalid_signature(self):
        channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey)
        self.assertEqual(channel.result["code"], b"401", channel.result)
        self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
        self.assertEqual(channel.result["code"], b"403", channel.result)
        self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
        self.assertEqual(
            channel.json_body["error"],
            "JWT validation failed: Signature verification failed",

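All of the JWT failure cases above now assert a `403` with `M_FORBIDDEN` in place of the previous `401`/`M_UNAUTHORIZED`. From a client's point of view, a failed JWT login therefore looks roughly like this (sketch; the login path is assumed to be the standard client-server `/login` endpoint, and the token value is a placeholder):

    # Sketch: observing the new error shape for a bad JWT login.
    import requests

    resp = requests.post(
        "http://localhost:8008/_matrix/client/r0/login",
        json={"type": "org.matrix.login.jwt", "token": "NOT_A_VALID_JWT"},
    )
    assert resp.status_code == 403
    assert resp.json()["errcode"] == "M_FORBIDDEN"
    # resp.json()["error"] carries a human-readable reason, e.g.
    # "JWT validation failed: Signature verification failed"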
@@ -237,6 +237,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
    def __init__(self):
        self.threadpool = ThreadPool(self)

        self._tcp_callbacks = {}
        self._udp = []
        lookups = self.lookups = {}

@@ -268,6 +269,29 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
    def getThreadPool(self):
        return self.threadpool

    def add_tcp_client_callback(self, host, port, callback):
        """Add a callback that will be invoked when we receive a connection
        attempt to the given IP/port using `connectTCP`.

        Note that the callback gets run before we return the connection to the
        client, which means callbacks cannot block while waiting for writes.
        """
        self._tcp_callbacks[(host, port)] = callback

    def connectTCP(self, host, port, factory, timeout=30, bindAddress=None):
        """Fake L{IReactorTCP.connectTCP}.
        """

        conn = super().connectTCP(
            host, port, factory, timeout=timeout, bindAddress=None
        )

        callback = self._tcp_callbacks.get((host, port))
        if callback:
            callback()

        return conn


class ThreadPool:
    """
@@ -486,7 +510,7 @@ class FakeTransport(object):
        try:
            self.other.dataReceived(to_write)
        except Exception as e:
            logger.warning("Exception writing to protocol: %s", e)
            logger.exception("Exception writing to protocol: %s", e)
            return

        self.buffer = self.buffer[len(to_write) :]

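The new `add_tcp_client_callback` hook on the fake reactor lets a test run code at the moment something dials a given host and port, before the connector is handed back (as the docstring above notes, callbacks must therefore not block). A hedged usage sketch, outside the normal HomeserverTestCase machinery:

    # Sketch: registering a TCP-connect callback on the fake reactor. Constructing the
    # reactor by hand here is an assumption; in the test suite it is created for you.
    from twisted.internet.protocol import ClientFactory, Protocol

    class _Noop(Protocol):
        pass

    class _NoopFactory(ClientFactory):
        protocol = _Noop

    reactor = ThreadedMemoryReactorClock()
    connected = []
    reactor.add_tcp_client_callback("1.2.3.4", 8765, lambda: connected.append(True))

    reactor.connectTCP("1.2.3.4", 8765, _NoopFactory())
    assert connected  # the callback ran before connectTCP returned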