1
0

Merge commit 'a973bcb8a' into anoa/dinsic_release_1_18_x

* commit 'a973bcb8a':
  Add some tiny type annotations (#7870)
  Remove obsolete comment.
  Ensure that calls to `json.dumps` are compatible with the standard library json. (#7836)
  Avoid brand new rooms in `delete_old_current_state_events` (#7854)
  Allow accounts to be re-activated from the admin APIs. (#7847)
  Fix tests
  Fix typo
  Newsfile
  Use get_users_in_room rather than state handler in typing for speed
  Fix client reader sharding tests (#7853)
  Convert E2E key and room key handlers to async/await. (#7851)
  Return the proper 403 Forbidden error during errors with JWT logins. (#7844)
  remove `retry_on_integrity_error` wrapper for persist_events (#7848)
This commit is contained in:
Andrew Morgan
2020-08-03 17:31:36 -07:00
39 changed files with 1032 additions and 694 deletions

1
changelog.d/7836.misc Normal file
View File

@@ -0,0 +1 @@
Ensure that calls to `json.dumps` are compatible with the standard library json.

1
changelog.d/7844.bugfix Normal file
View File

@@ -0,0 +1 @@
Errors which occur while using the non-standard JWT login now return the proper error: `403 Forbidden` with an error code of `M_FORBIDDEN`.

1
changelog.d/7847.feature Normal file
View File

@@ -0,0 +1 @@
Add the ability to re-activate an account from the admin API.

1
changelog.d/7848.misc Normal file
View File

@@ -0,0 +1 @@
Remove redundant `retry_on_integrity_error` wrapper for event persistence code.

1
changelog.d/7851.misc Normal file
View File

@@ -0,0 +1 @@
Convert E2E keys and room keys handlers to async/await.

1
changelog.d/7853.misc Normal file
View File

@@ -0,0 +1 @@
Add support for handling registration requests across multiple client reader workers.

1
changelog.d/7854.bugfix Normal file
View File

@@ -0,0 +1 @@
Fix a bug introduced in Synapse 1.10.0 which could cause a "no create event in auth events" error during room creation.

1
changelog.d/7856.misc Normal file
View File

@@ -0,0 +1 @@
Small performance improvement in typing processing.

1
changelog.d/7870.misc Normal file
View File

@@ -0,0 +1 @@
Add some type annotations to `HomeServer` and `BaseHandler`.

View File

@@ -91,10 +91,14 @@ Body parameters:
- ``admin``, optional, defaults to ``false``.
- ``deactivated``, optional, defaults to ``false``.
- ``deactivated``, optional. If unspecified, deactivation state will be left
unchanged on existing accounts and set to ``false`` for new accounts.
If the user already exists then optional parameters default to the current value.
In order to re-activate an account ``deactivated`` must be set to ``false``. If
users do not login via single-sign-on, a new ``password`` must be provided.
List Accounts
=============

View File

@@ -31,10 +31,7 @@ The `token` field should include the JSON web token with the following claims:
Providing the audience claim when not configured will cause validation to fail.
In the case that the token is not valid, the homeserver must respond with
`401 Unauthorized` and an error code of `M_UNAUTHORIZED`.
(Note that this differs from the token based logins which return a
`403 Forbidden` and an error code of `M_FORBIDDEN` if an error occurs.)
`403 Forbidden` and an error code of `M_FORBIDDEN`.
As with other login types, there are additional fields (e.g. `device_id` and
`initial_device_display_name`) which can be included in the above request.

View File

@@ -16,12 +16,14 @@
# limitations under the License.
"""Contains exceptions and error codes."""
import json
import logging
import typing
from http import HTTPStatus
from typing import Dict, List, Optional, Union
from canonicaljson import json
from twisted.web import http
if typing.TYPE_CHECKING:

View File

@@ -14,10 +14,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
from typing import Any, Callable, Dict, List, Match, Optional, Tuple, Union
from canonicaljson import json
from prometheus_client import Counter, Histogram
from twisted.internet import defer

View File

@@ -61,8 +61,6 @@ class TransactionManager(object):
# all the edus in that transaction. This needs to be done since there is
# no active span here, so if the edus were not received by the remote the
# span would have no causality and it would be forgotten.
# The span_contexts is a generator so that it won't be evaluated if
# opentracing is disabled. (Yay speed!)
span_contexts = []
keep_destination = whitelisted_homeserver(destination)

View File

@@ -17,6 +17,8 @@ import logging
from twisted.internet import defer
import synapse.state
import synapse.storage
import synapse.types
from synapse.api.constants import EventTypes, Membership
from synapse.api.ratelimiting import Ratelimiter
@@ -28,10 +30,6 @@ logger = logging.getLogger(__name__)
class BaseHandler(object):
"""
Common base class for the event handlers.
Attributes:
store (synapse.storage.DataStore):
state_handler (synapse.state.StateHandler):
"""
def __init__(self, hs):
@@ -39,10 +37,10 @@ class BaseHandler(object):
Args:
hs (synapse.server.HomeServer):
"""
self.store = hs.get_datastore()
self.store = hs.get_datastore() # type: synapse.storage.DataStore
self.auth = hs.get_auth()
self.notifier = hs.get_notifier()
self.state_handler = hs.get_state_handler()
self.state_handler = hs.get_state_handler() # type: synapse.state.StateHandler
self.distributor = hs.get_distributor()
self.clock = hs.get_clock()
self.hs = hs

View File

@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Optional
from synapse.api.errors import SynapseError
from synapse.metrics.background_process_metrics import run_as_background_process
@@ -46,19 +47,20 @@ class DeactivateAccountHandler(BaseHandler):
self._account_validity_enabled = hs.config.account_validity.enabled
async def deactivate_account(self, user_id, erase_data, id_server=None):
async def deactivate_account(
self, user_id: str, erase_data: bool, id_server: Optional[str] = None
) -> bool:
"""Deactivate a user's account
Args:
user_id (str): ID of user to be deactivated
erase_data (bool): whether to GDPR-erase the user's data
id_server (str|None): Use the given identity server when unbinding
user_id: ID of user to be deactivated
erase_data: whether to GDPR-erase the user's data
id_server: Use the given identity server when unbinding
any threepids. If None then will attempt to unbind using the
identity server specified when binding (if known).
Returns:
Deferred[bool]: True if identity server supports removing
threepids, otherwise False.
True if identity server supports removing threepids, otherwise False.
"""
# FIXME: Theoretically there is a race here wherein user resets
# password using threepid.
@@ -138,11 +140,11 @@ class DeactivateAccountHandler(BaseHandler):
return identity_server_supports_unbinding
async def _reject_pending_invites_for_user(self, user_id):
async def _reject_pending_invites_for_user(self, user_id: str):
"""Reject pending invites addressed to a given user ID.
Args:
user_id (str): The user ID to reject pending invites for.
user_id: The user ID to reject pending invites for.
"""
user = UserID.from_string(user_id)
pending_invites = await self.store.get_invited_rooms_for_local_user(user_id)
@@ -170,22 +172,16 @@ class DeactivateAccountHandler(BaseHandler):
room.room_id,
)
def _start_user_parting(self):
def _start_user_parting(self) -> None:
"""
Start the process that goes through the table of users
pending deactivation, if it isn't already running.
Returns:
None
"""
if not self._user_parter_running:
run_as_background_process("user_parter_loop", self._user_parter_loop)
async def _user_parter_loop(self):
async def _user_parter_loop(self) -> None:
"""Loop that parts deactivated users from rooms
Returns:
None
"""
self._user_parter_running = True
logger.info("Starting user parter")
@@ -202,11 +198,8 @@ class DeactivateAccountHandler(BaseHandler):
finally:
self._user_parter_running = False
async def _part_user(self, user_id):
async def _part_user(self, user_id: str) -> None:
"""Causes the given user_id to leave all the rooms they're joined to
Returns:
None
"""
user = UserID.from_string(user_id)
@@ -228,3 +221,18 @@ class DeactivateAccountHandler(BaseHandler):
user_id,
room_id,
)
async def activate_account(self, user_id: str) -> None:
"""
Activate an account that was previously deactivated.
This simply marks the user as activate in the database and does not
attempt to rejoin rooms, re-add threepids, etc.
The user will also need a password hash set to actually login.
Args:
user_id: ID of user to be deactivated
"""
# Mark the user as activate.
await self.store.set_user_deactivated_status(user_id, False)

View File

@@ -77,8 +77,7 @@ class E2eKeysHandler(object):
)
@trace
@defer.inlineCallbacks
def query_devices(self, query_body, timeout, from_user_id):
async def query_devices(self, query_body, timeout, from_user_id):
""" Handle a device key query from a client
{
@@ -124,7 +123,7 @@ class E2eKeysHandler(object):
failures = {}
results = {}
if local_query:
local_result = yield self.query_local_devices(local_query)
local_result = await self.query_local_devices(local_query)
for user_id, keys in local_result.items():
if user_id in local_query:
results[user_id] = keys
@@ -142,7 +141,7 @@ class E2eKeysHandler(object):
(
user_ids_not_in_cache,
remote_results,
) = yield self.store.get_user_devices_from_cache(query_list)
) = await self.store.get_user_devices_from_cache(query_list)
for user_id, devices in remote_results.items():
user_devices = results.setdefault(user_id, {})
for device_id, device in devices.items():
@@ -161,14 +160,13 @@ class E2eKeysHandler(object):
r[user_id] = remote_queries[user_id]
# Get cached cross-signing keys
cross_signing_keys = yield self.get_cross_signing_keys_from_cache(
cross_signing_keys = await self.get_cross_signing_keys_from_cache(
device_keys_query, from_user_id
)
# Now fetch any devices that we don't have in our cache
@trace
@defer.inlineCallbacks
def do_remote_query(destination):
async def do_remote_query(destination):
"""This is called when we are querying the device list of a user on
a remote homeserver and their device list is not in the device list
cache. If we share a room with this user and we're not querying for
@@ -192,7 +190,7 @@ class E2eKeysHandler(object):
if device_list:
continue
room_ids = yield self.store.get_rooms_for_user(user_id)
room_ids = await self.store.get_rooms_for_user(user_id)
if not room_ids:
continue
@@ -201,11 +199,11 @@ class E2eKeysHandler(object):
# done an initial sync on the device list so we do it now.
try:
if self._is_master:
user_devices = yield self.device_handler.device_list_updater.user_device_resync(
user_devices = await self.device_handler.device_list_updater.user_device_resync(
user_id
)
else:
user_devices = yield self._user_device_resync_client(
user_devices = await self._user_device_resync_client(
user_id=user_id
)
@@ -227,7 +225,7 @@ class E2eKeysHandler(object):
destination_query.pop(user_id)
try:
remote_result = yield self.federation.query_client_keys(
remote_result = await self.federation.query_client_keys(
destination, {"device_keys": destination_query}, timeout=timeout
)
@@ -251,7 +249,7 @@ class E2eKeysHandler(object):
set_tag("error", True)
set_tag("reason", failure)
yield make_deferred_yieldable(
await make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(do_remote_query, destination)
@@ -267,8 +265,7 @@ class E2eKeysHandler(object):
return ret
@defer.inlineCallbacks
def get_cross_signing_keys_from_cache(self, query, from_user_id):
async def get_cross_signing_keys_from_cache(self, query, from_user_id):
"""Get cross-signing keys for users from the database
Args:
@@ -289,7 +286,7 @@ class E2eKeysHandler(object):
user_ids = list(query)
keys = yield self.store.get_e2e_cross_signing_keys_bulk(user_ids, from_user_id)
keys = await self.store.get_e2e_cross_signing_keys_bulk(user_ids, from_user_id)
for user_id, user_info in keys.items():
if user_info is None:
@@ -315,8 +312,7 @@ class E2eKeysHandler(object):
}
@trace
@defer.inlineCallbacks
def query_local_devices(self, query):
async def query_local_devices(self, query):
"""Get E2E device keys for local users
Args:
@@ -354,7 +350,7 @@ class E2eKeysHandler(object):
# make sure that each queried user appears in the result dict
result_dict[user_id] = {}
results = yield self.store.get_e2e_device_keys(local_query)
results = await self.store.get_e2e_device_keys(local_query)
# Build the result structure
for user_id, device_keys in results.items():
@@ -364,16 +360,15 @@ class E2eKeysHandler(object):
log_kv(results)
return result_dict
@defer.inlineCallbacks
def on_federation_query_client_keys(self, query_body):
async def on_federation_query_client_keys(self, query_body):
""" Handle a device key query from a federated server
"""
device_keys_query = query_body.get("device_keys", {})
res = yield self.query_local_devices(device_keys_query)
res = await self.query_local_devices(device_keys_query)
ret = {"device_keys": res}
# add in the cross-signing keys
cross_signing_keys = yield self.get_cross_signing_keys_from_cache(
cross_signing_keys = await self.get_cross_signing_keys_from_cache(
device_keys_query, None
)
@@ -382,8 +377,7 @@ class E2eKeysHandler(object):
return ret
@trace
@defer.inlineCallbacks
def claim_one_time_keys(self, query, timeout):
async def claim_one_time_keys(self, query, timeout):
local_query = []
remote_queries = {}
@@ -399,7 +393,7 @@ class E2eKeysHandler(object):
set_tag("local_key_query", local_query)
set_tag("remote_key_query", remote_queries)
results = yield self.store.claim_e2e_one_time_keys(local_query)
results = await self.store.claim_e2e_one_time_keys(local_query)
json_result = {}
failures = {}
@@ -411,12 +405,11 @@ class E2eKeysHandler(object):
}
@trace
@defer.inlineCallbacks
def claim_client_keys(destination):
async def claim_client_keys(destination):
set_tag("destination", destination)
device_keys = remote_queries[destination]
try:
remote_result = yield self.federation.claim_client_keys(
remote_result = await self.federation.claim_client_keys(
destination, {"one_time_keys": device_keys}, timeout=timeout
)
for user_id, keys in remote_result["one_time_keys"].items():
@@ -429,7 +422,7 @@ class E2eKeysHandler(object):
set_tag("error", True)
set_tag("reason", failure)
yield make_deferred_yieldable(
await make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(claim_client_keys, destination)
@@ -454,9 +447,8 @@ class E2eKeysHandler(object):
log_kv({"one_time_keys": json_result, "failures": failures})
return {"one_time_keys": json_result, "failures": failures}
@defer.inlineCallbacks
@tag_args
def upload_keys_for_user(self, user_id, device_id, keys):
async def upload_keys_for_user(self, user_id, device_id, keys):
time_now = self.clock.time_msec()
@@ -477,12 +469,12 @@ class E2eKeysHandler(object):
}
)
# TODO: Sign the JSON with the server key
changed = yield self.store.set_e2e_device_keys(
changed = await self.store.set_e2e_device_keys(
user_id, device_id, time_now, device_keys
)
if changed:
# Only notify about device updates *if* the keys actually changed
yield self.device_handler.notify_device_update(user_id, [device_id])
await self.device_handler.notify_device_update(user_id, [device_id])
else:
log_kv({"message": "Not updating device_keys for user", "user_id": user_id})
one_time_keys = keys.get("one_time_keys", None)
@@ -494,7 +486,7 @@ class E2eKeysHandler(object):
"device_id": device_id,
}
)
yield self._upload_one_time_keys_for_user(
await self._upload_one_time_keys_for_user(
user_id, device_id, time_now, one_time_keys
)
else:
@@ -507,15 +499,14 @@ class E2eKeysHandler(object):
# old access_token without an associated device_id. Either way, we
# need to double-check the device is registered to avoid ending up with
# keys without a corresponding device.
yield self.device_handler.check_device_registered(user_id, device_id)
await self.device_handler.check_device_registered(user_id, device_id)
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
result = await self.store.count_e2e_one_time_keys(user_id, device_id)
set_tag("one_time_key_counts", result)
return {"one_time_key_counts": result}
@defer.inlineCallbacks
def _upload_one_time_keys_for_user(
async def _upload_one_time_keys_for_user(
self, user_id, device_id, time_now, one_time_keys
):
logger.info(
@@ -533,7 +524,7 @@ class E2eKeysHandler(object):
key_list.append((algorithm, key_id, key_obj))
# First we check if we have already persisted any of the keys.
existing_key_map = yield self.store.get_e2e_one_time_keys(
existing_key_map = await self.store.get_e2e_one_time_keys(
user_id, device_id, [k_id for _, k_id, _ in key_list]
)
@@ -556,10 +547,9 @@ class E2eKeysHandler(object):
)
log_kv({"message": "Inserting new one_time_keys.", "keys": new_keys})
yield self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys)
await self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys)
@defer.inlineCallbacks
def upload_signing_keys_for_user(self, user_id, keys):
async def upload_signing_keys_for_user(self, user_id, keys):
"""Upload signing keys for cross-signing
Args:
@@ -574,7 +564,7 @@ class E2eKeysHandler(object):
_check_cross_signing_key(master_key, user_id, "master")
else:
master_key = yield self.store.get_e2e_cross_signing_key(user_id, "master")
master_key = await self.store.get_e2e_cross_signing_key(user_id, "master")
# if there is no master key, then we can't do anything, because all the
# other cross-signing keys need to be signed by the master key
@@ -613,10 +603,10 @@ class E2eKeysHandler(object):
# if everything checks out, then store the keys and send notifications
deviceids = []
if "master_key" in keys:
yield self.store.set_e2e_cross_signing_key(user_id, "master", master_key)
await self.store.set_e2e_cross_signing_key(user_id, "master", master_key)
deviceids.append(master_verify_key.version)
if "self_signing_key" in keys:
yield self.store.set_e2e_cross_signing_key(
await self.store.set_e2e_cross_signing_key(
user_id, "self_signing", self_signing_key
)
try:
@@ -626,23 +616,22 @@ class E2eKeysHandler(object):
except ValueError:
raise SynapseError(400, "Invalid self-signing key", Codes.INVALID_PARAM)
if "user_signing_key" in keys:
yield self.store.set_e2e_cross_signing_key(
await self.store.set_e2e_cross_signing_key(
user_id, "user_signing", user_signing_key
)
# the signature stream matches the semantics that we want for
# user-signing key updates: only the user themselves is notified of
# their own user-signing key updates
yield self.device_handler.notify_user_signature_update(user_id, [user_id])
await self.device_handler.notify_user_signature_update(user_id, [user_id])
# master key and self-signing key updates match the semantics of device
# list updates: all users who share an encrypted room are notified
if len(deviceids):
yield self.device_handler.notify_device_update(user_id, deviceids)
await self.device_handler.notify_device_update(user_id, deviceids)
return {}
@defer.inlineCallbacks
def upload_signatures_for_device_keys(self, user_id, signatures):
async def upload_signatures_for_device_keys(self, user_id, signatures):
"""Upload device signatures for cross-signing
Args:
@@ -667,13 +656,13 @@ class E2eKeysHandler(object):
self_signatures = signatures.get(user_id, {})
other_signatures = {k: v for k, v in signatures.items() if k != user_id}
self_signature_list, self_failures = yield self._process_self_signatures(
self_signature_list, self_failures = await self._process_self_signatures(
user_id, self_signatures
)
signature_list.extend(self_signature_list)
failures.update(self_failures)
other_signature_list, other_failures = yield self._process_other_signatures(
other_signature_list, other_failures = await self._process_other_signatures(
user_id, other_signatures
)
signature_list.extend(other_signature_list)
@@ -681,21 +670,20 @@ class E2eKeysHandler(object):
# store the signature, and send the appropriate notifications for sync
logger.debug("upload signature failures: %r", failures)
yield self.store.store_e2e_cross_signing_signatures(user_id, signature_list)
await self.store.store_e2e_cross_signing_signatures(user_id, signature_list)
self_device_ids = [item.target_device_id for item in self_signature_list]
if self_device_ids:
yield self.device_handler.notify_device_update(user_id, self_device_ids)
await self.device_handler.notify_device_update(user_id, self_device_ids)
signed_users = [item.target_user_id for item in other_signature_list]
if signed_users:
yield self.device_handler.notify_user_signature_update(
await self.device_handler.notify_user_signature_update(
user_id, signed_users
)
return {"failures": failures}
@defer.inlineCallbacks
def _process_self_signatures(self, user_id, signatures):
async def _process_self_signatures(self, user_id, signatures):
"""Process uploaded signatures of the user's own keys.
Signatures of the user's own keys from this API come in two forms:
@@ -728,7 +716,7 @@ class E2eKeysHandler(object):
_,
self_signing_key_id,
self_signing_verify_key,
) = yield self._get_e2e_cross_signing_verify_key(user_id, "self_signing")
) = await self._get_e2e_cross_signing_verify_key(user_id, "self_signing")
# get our master key, since we may have received a signature of it.
# We need to fetch it here so that we know what its key ID is, so
@@ -738,12 +726,12 @@ class E2eKeysHandler(object):
master_key,
_,
master_verify_key,
) = yield self._get_e2e_cross_signing_verify_key(user_id, "master")
) = await self._get_e2e_cross_signing_verify_key(user_id, "master")
# fetch our stored devices. This is used to 1. verify
# signatures on the master key, and 2. to compare with what
# was sent if the device was signed
devices = yield self.store.get_e2e_device_keys([(user_id, None)])
devices = await self.store.get_e2e_device_keys([(user_id, None)])
if user_id not in devices:
raise NotFoundError("No device keys found")
@@ -853,8 +841,7 @@ class E2eKeysHandler(object):
return master_key_signature_list
@defer.inlineCallbacks
def _process_other_signatures(self, user_id, signatures):
async def _process_other_signatures(self, user_id, signatures):
"""Process uploaded signatures of other users' keys. These will be the
target user's master keys, signed by the uploading user's user-signing
key.
@@ -882,7 +869,7 @@ class E2eKeysHandler(object):
user_signing_key,
user_signing_key_id,
user_signing_verify_key,
) = yield self._get_e2e_cross_signing_verify_key(user_id, "user_signing")
) = await self._get_e2e_cross_signing_verify_key(user_id, "user_signing")
except SynapseError as e:
failure = _exception_to_failure(e)
for user, devicemap in signatures.items():
@@ -905,7 +892,7 @@ class E2eKeysHandler(object):
master_key,
master_key_id,
_,
) = yield self._get_e2e_cross_signing_verify_key(
) = await self._get_e2e_cross_signing_verify_key(
target_user, "master", user_id
)
@@ -958,8 +945,7 @@ class E2eKeysHandler(object):
return signature_list, failures
@defer.inlineCallbacks
def _get_e2e_cross_signing_verify_key(
async def _get_e2e_cross_signing_verify_key(
self, user_id: str, key_type: str, from_user_id: str = None
):
"""Fetch locally or remotely query for a cross-signing public key.
@@ -983,7 +969,7 @@ class E2eKeysHandler(object):
SynapseError: if `user_id` is invalid
"""
user = UserID.from_string(user_id)
key = yield self.store.get_e2e_cross_signing_key(
key = await self.store.get_e2e_cross_signing_key(
user_id, key_type, from_user_id
)
@@ -1009,15 +995,14 @@ class E2eKeysHandler(object):
key,
key_id,
verify_key,
) = yield self._retrieve_cross_signing_keys_for_remote_user(user, key_type)
) = await self._retrieve_cross_signing_keys_for_remote_user(user, key_type)
if key is None:
raise NotFoundError("No %s key found for %s" % (key_type, user_id))
return key, key_id, verify_key
@defer.inlineCallbacks
def _retrieve_cross_signing_keys_for_remote_user(
async def _retrieve_cross_signing_keys_for_remote_user(
self, user: UserID, desired_key_type: str,
):
"""Queries cross-signing keys for a remote user and saves them to the database
@@ -1035,7 +1020,7 @@ class E2eKeysHandler(object):
If the key cannot be retrieved, all values in the tuple will instead be None.
"""
try:
remote_result = yield self.federation.query_user_devices(
remote_result = await self.federation.query_user_devices(
user.domain, user.to_string()
)
except Exception as e:
@@ -1101,14 +1086,14 @@ class E2eKeysHandler(object):
desired_key_id = key_id
# At the same time, store this key in the db for subsequent queries
yield self.store.set_e2e_cross_signing_key(
await self.store.set_e2e_cross_signing_key(
user.to_string(), key_type, key_content
)
# Notify clients that new devices for this user have been discovered
if retrieved_device_ids:
# XXX is this necessary?
yield self.device_handler.notify_device_update(
await self.device_handler.notify_device_update(
user.to_string(), retrieved_device_ids
)
@@ -1250,8 +1235,7 @@ class SigningKeyEduUpdater(object):
iterable=True,
)
@defer.inlineCallbacks
def incoming_signing_key_update(self, origin, edu_content):
async def incoming_signing_key_update(self, origin, edu_content):
"""Called on incoming signing key update from federation. Responsible for
parsing the EDU and adding to pending updates list.
@@ -1268,7 +1252,7 @@ class SigningKeyEduUpdater(object):
logger.warning("Got signing key update edu for %r from %r", user_id, origin)
return
room_ids = yield self.store.get_rooms_for_user(user_id)
room_ids = await self.store.get_rooms_for_user(user_id)
if not room_ids:
# We don't share any rooms with this user. Ignore update, as we
# probably won't get any further updates.
@@ -1278,10 +1262,9 @@ class SigningKeyEduUpdater(object):
(master_key, self_signing_key)
)
yield self._handle_signing_key_updates(user_id)
await self._handle_signing_key_updates(user_id)
@defer.inlineCallbacks
def _handle_signing_key_updates(self, user_id):
async def _handle_signing_key_updates(self, user_id):
"""Actually handle pending updates.
Args:
@@ -1291,7 +1274,7 @@ class SigningKeyEduUpdater(object):
device_handler = self.e2e_keys_handler.device_handler
device_list_updater = device_handler.device_list_updater
with (yield self._remote_edu_linearizer.queue(user_id)):
with (await self._remote_edu_linearizer.queue(user_id)):
pending_updates = self._pending_updates.pop(user_id, [])
if not pending_updates:
# This can happen since we batch updates
@@ -1302,9 +1285,9 @@ class SigningKeyEduUpdater(object):
logger.info("pending updates: %r", pending_updates)
for master_key, self_signing_key in pending_updates:
new_device_ids = yield device_list_updater.process_cross_signing_key_update(
new_device_ids = await device_list_updater.process_cross_signing_key_update(
user_id, master_key, self_signing_key,
)
device_ids = device_ids + new_device_ids
yield device_handler.notify_device_update(user_id, device_ids)
await device_handler.notify_device_update(user_id, device_ids)

View File

@@ -16,8 +16,6 @@
import logging
from twisted.internet import defer
from synapse.api.errors import (
Codes,
NotFoundError,
@@ -50,8 +48,7 @@ class E2eRoomKeysHandler(object):
self._upload_linearizer = Linearizer("upload_room_keys_lock")
@trace
@defer.inlineCallbacks
def get_room_keys(self, user_id, version, room_id=None, session_id=None):
async def get_room_keys(self, user_id, version, room_id=None, session_id=None):
"""Bulk get the E2E room keys for a given backup, optionally filtered to a given
room, or a given session.
See EndToEndRoomKeyStore.get_e2e_room_keys for full details.
@@ -71,17 +68,17 @@ class E2eRoomKeysHandler(object):
# we deliberately take the lock to get keys so that changing the version
# works atomically
with (yield self._upload_linearizer.queue(user_id)):
with (await self._upload_linearizer.queue(user_id)):
# make sure the backup version exists
try:
yield self.store.get_e2e_room_keys_version_info(user_id, version)
await self.store.get_e2e_room_keys_version_info(user_id, version)
except StoreError as e:
if e.code == 404:
raise NotFoundError("Unknown backup version")
else:
raise
results = yield self.store.get_e2e_room_keys(
results = await self.store.get_e2e_room_keys(
user_id, version, room_id, session_id
)
@@ -89,8 +86,7 @@ class E2eRoomKeysHandler(object):
return results
@trace
@defer.inlineCallbacks
def delete_room_keys(self, user_id, version, room_id=None, session_id=None):
async def delete_room_keys(self, user_id, version, room_id=None, session_id=None):
"""Bulk delete the E2E room keys for a given backup, optionally filtered to a given
room or a given session.
See EndToEndRoomKeyStore.delete_e2e_room_keys for full details.
@@ -109,10 +105,10 @@ class E2eRoomKeysHandler(object):
"""
# lock for consistency with uploading
with (yield self._upload_linearizer.queue(user_id)):
with (await self._upload_linearizer.queue(user_id)):
# make sure the backup version exists
try:
version_info = yield self.store.get_e2e_room_keys_version_info(
version_info = await self.store.get_e2e_room_keys_version_info(
user_id, version
)
except StoreError as e:
@@ -121,19 +117,18 @@ class E2eRoomKeysHandler(object):
else:
raise
yield self.store.delete_e2e_room_keys(user_id, version, room_id, session_id)
await self.store.delete_e2e_room_keys(user_id, version, room_id, session_id)
version_etag = version_info["etag"] + 1
yield self.store.update_e2e_room_keys_version(
await self.store.update_e2e_room_keys_version(
user_id, version, None, version_etag
)
count = yield self.store.count_e2e_room_keys(user_id, version)
count = await self.store.count_e2e_room_keys(user_id, version)
return {"etag": str(version_etag), "count": count}
@trace
@defer.inlineCallbacks
def upload_room_keys(self, user_id, version, room_keys):
async def upload_room_keys(self, user_id, version, room_keys):
"""Bulk upload a list of room keys into a given backup version, asserting
that the given version is the current backup version. room_keys are merged
into the current backup as described in RoomKeysServlet.on_PUT().
@@ -169,11 +164,11 @@ class E2eRoomKeysHandler(object):
# TODO: Validate the JSON to make sure it has the right keys.
# XXX: perhaps we should use a finer grained lock here?
with (yield self._upload_linearizer.queue(user_id)):
with (await self._upload_linearizer.queue(user_id)):
# Check that the version we're trying to upload is the current version
try:
version_info = yield self.store.get_e2e_room_keys_version_info(user_id)
version_info = await self.store.get_e2e_room_keys_version_info(user_id)
except StoreError as e:
if e.code == 404:
raise NotFoundError("Version '%s' not found" % (version,))
@@ -183,7 +178,7 @@ class E2eRoomKeysHandler(object):
if version_info["version"] != version:
# Check that the version we're trying to upload actually exists
try:
version_info = yield self.store.get_e2e_room_keys_version_info(
version_info = await self.store.get_e2e_room_keys_version_info(
user_id, version
)
# if we get this far, the version must exist
@@ -198,7 +193,7 @@ class E2eRoomKeysHandler(object):
# submitted. Then compare them with the submitted keys. If the
# key is new, insert it; if the key should be updated, then update
# it; otherwise, drop it.
existing_keys = yield self.store.get_e2e_room_keys_multi(
existing_keys = await self.store.get_e2e_room_keys_multi(
user_id, version, room_keys["rooms"]
)
to_insert = [] # batch the inserts together
@@ -227,7 +222,7 @@ class E2eRoomKeysHandler(object):
# updates are done one at a time in the DB, so send
# updates right away rather than batching them up,
# like we do with the inserts
yield self.store.update_e2e_room_key(
await self.store.update_e2e_room_key(
user_id, version, room_id, session_id, room_key
)
changed = True
@@ -246,16 +241,16 @@ class E2eRoomKeysHandler(object):
changed = True
if len(to_insert):
yield self.store.add_e2e_room_keys(user_id, version, to_insert)
await self.store.add_e2e_room_keys(user_id, version, to_insert)
version_etag = version_info["etag"]
if changed:
version_etag = version_etag + 1
yield self.store.update_e2e_room_keys_version(
await self.store.update_e2e_room_keys_version(
user_id, version, None, version_etag
)
count = yield self.store.count_e2e_room_keys(user_id, version)
count = await self.store.count_e2e_room_keys(user_id, version)
return {"etag": str(version_etag), "count": count}
@staticmethod
@@ -291,8 +286,7 @@ class E2eRoomKeysHandler(object):
return True
@trace
@defer.inlineCallbacks
def create_version(self, user_id, version_info):
async def create_version(self, user_id, version_info):
"""Create a new backup version. This automatically becomes the new
backup version for the user's keys; previous backups will no longer be
writeable to.
@@ -313,14 +307,13 @@ class E2eRoomKeysHandler(object):
# TODO: Validate the JSON to make sure it has the right keys.
# lock everyone out until we've switched version
with (yield self._upload_linearizer.queue(user_id)):
new_version = yield self.store.create_e2e_room_keys_version(
with (await self._upload_linearizer.queue(user_id)):
new_version = await self.store.create_e2e_room_keys_version(
user_id, version_info
)
return new_version
@defer.inlineCallbacks
def get_version_info(self, user_id, version=None):
async def get_version_info(self, user_id, version=None):
"""Get the info about a given version of the user's backup
Args:
@@ -339,22 +332,21 @@ class E2eRoomKeysHandler(object):
}
"""
with (yield self._upload_linearizer.queue(user_id)):
with (await self._upload_linearizer.queue(user_id)):
try:
res = yield self.store.get_e2e_room_keys_version_info(user_id, version)
res = await self.store.get_e2e_room_keys_version_info(user_id, version)
except StoreError as e:
if e.code == 404:
raise NotFoundError("Unknown backup version")
else:
raise
res["count"] = yield self.store.count_e2e_room_keys(user_id, res["version"])
res["count"] = await self.store.count_e2e_room_keys(user_id, res["version"])
res["etag"] = str(res["etag"])
return res
@trace
@defer.inlineCallbacks
def delete_version(self, user_id, version=None):
async def delete_version(self, user_id, version=None):
"""Deletes a given version of the user's e2e_room_keys backup
Args:
@@ -364,9 +356,9 @@ class E2eRoomKeysHandler(object):
NotFoundError: if this backup version doesn't exist
"""
with (yield self._upload_linearizer.queue(user_id)):
with (await self._upload_linearizer.queue(user_id)):
try:
yield self.store.delete_e2e_room_keys_version(user_id, version)
await self.store.delete_e2e_room_keys_version(user_id, version)
except StoreError as e:
if e.code == 404:
raise NotFoundError("Unknown backup version")
@@ -374,8 +366,7 @@ class E2eRoomKeysHandler(object):
raise
@trace
@defer.inlineCallbacks
def update_version(self, user_id, version, version_info):
async def update_version(self, user_id, version, version_info):
"""Update the info about a given version of the user's backup
Args:
@@ -393,9 +384,9 @@ class E2eRoomKeysHandler(object):
raise SynapseError(
400, "Version in body does not match", Codes.INVALID_PARAM
)
with (yield self._upload_linearizer.queue(user_id)):
with (await self._upload_linearizer.queue(user_id)):
try:
old_info = yield self.store.get_e2e_room_keys_version_info(
old_info = await self.store.get_e2e_room_keys_version_info(
user_id, version
)
except StoreError as e:
@@ -406,7 +397,7 @@ class E2eRoomKeysHandler(object):
if old_info["algorithm"] != version_info["algorithm"]:
raise SynapseError(400, "Algorithm does not match", Codes.INVALID_PARAM)
yield self.store.update_e2e_room_keys_version(
await self.store.update_e2e_room_keys_version(
user_id, version, version_info
)

View File

@@ -185,7 +185,7 @@ class TypingHandler(object):
async def _push_remote(self, member, typing):
try:
users = await self.state.get_current_users_in_room(member.room_id)
users = await self.store.get_users_in_room(member.room_id)
self._member_last_federation_poke[member] = self.clock.time_msec()
now = self.clock.time_msec()
@@ -224,7 +224,7 @@ class TypingHandler(object):
)
return
users = await self.state.get_current_users_in_room(room_id)
users = await self.store.get_users_in_room(room_id)
domains = {get_domain_from_id(u) for u in users}
if self.server_name in domains:

View File

@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from canonicaljson import json
@@ -117,7 +118,7 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker):
except PartialDownloadError as pde:
# Twisted is silly
data = pde.response
resp_body = json.loads(data)
resp_body = json.loads(data.decode("utf-8"))
if "success" in resp_body:
# Note that we do NOT check the hostname here: we explicitly

View File

@@ -13,13 +13,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import urllib
from io import BytesIO
import treq
from canonicaljson import encode_canonical_json
from canonicaljson import encode_canonical_json, json
from netaddr import IPAddress
from prometheus_client import Counter
from zope.interface import implementer, provider
@@ -31,6 +31,7 @@ from twisted.internet.interfaces import (
IReactorPluggableNameResolver,
IResolutionReceiver,
)
from twisted.internet.task import Cooperator
from twisted.python.failure import Failure
from twisted.web._newclient import ResponseDone
from twisted.web.client import Agent, HTTPConnectionPool, readBody
@@ -69,6 +70,21 @@ def check_against_blacklist(ip_address, ip_whitelist, ip_blacklist):
return False
_EPSILON = 0.00000001
def _make_scheduler(reactor):
"""Makes a schedular suitable for a Cooperator using the given reactor.
(This is effectively just a copy from `twisted.internet.task`)
"""
def _scheduler(x):
return reactor.callLater(_EPSILON, x)
return _scheduler
class IPBlacklistingResolver(object):
"""
A proxy for reactor.nameResolver which only produces non-blacklisted IP
@@ -212,6 +228,10 @@ class SimpleHttpClient(object):
if hs.config.user_agent_suffix:
self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix)
# We use this for our body producers to ensure that they use the correct
# reactor.
self._cooperator = Cooperator(scheduler=_make_scheduler(hs.get_reactor()))
self.user_agent = self.user_agent.encode("ascii")
if self._ip_blacklist:
@@ -292,7 +312,9 @@ class SimpleHttpClient(object):
try:
body_producer = None
if data is not None:
body_producer = QuieterFileBodyProducer(BytesIO(data))
body_producer = QuieterFileBodyProducer(
BytesIO(data), cooperator=self._cooperator,
)
request_deferred = treq.request(
method,

View File

@@ -14,9 +14,11 @@
# limitations under the License.
""" This module contains base REST classes for constructing REST servlets. """
import json
import logging
from canonicaljson import json
from synapse.api.errors import Codes, SynapseError
logger = logging.getLogger(__name__)

View File

@@ -239,6 +239,15 @@ class UserRestServletV2(RestServlet):
await self.deactivate_account_handler.deactivate_account(
target_user.to_string(), False
)
elif not deactivate and user["deactivated"]:
if "password" not in body:
raise SynapseError(
400, "Must provide a password to re-activate an account."
)
await self.deactivate_account_handler.activate_account(
target_user.to_string()
)
user = await self.admin_handler.get_user(target_user)
return 200, user
@@ -254,7 +263,6 @@ class UserRestServletV2(RestServlet):
admin = body.get("admin", None)
user_type = body.get("user_type", None)
displayname = body.get("displayname", None)
threepids = body.get("threepids", None)
if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES:
raise SynapseError(400, "Invalid user type")

View File

@@ -371,7 +371,7 @@ class LoginRestServlet(RestServlet):
token = login_submission.get("token", None)
if token is None:
raise LoginError(
401, "Token field for JWT is missing", errcode=Codes.UNAUTHORIZED
403, "Token field for JWT is missing", errcode=Codes.FORBIDDEN
)
import jwt
@@ -387,14 +387,12 @@ class LoginRestServlet(RestServlet):
except jwt.PyJWTError as e:
# A JWT error occurred, return some info back to the client.
raise LoginError(
401,
"JWT validation failed: %s" % (str(e),),
errcode=Codes.UNAUTHORIZED,
403, "JWT validation failed: %s" % (str(e),), errcode=Codes.FORBIDDEN,
)
user = payload.get("sub", None)
if user is None:
raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
raise LoginError(403, "Invalid JWT", errcode=Codes.FORBIDDEN)
user_id = UserID(user, self.hs.hostname).to_string()
result = await self._complete_login(

View File

@@ -15,6 +15,7 @@
# limitations under the License.
""" This module contains REST servlets to do with rooms: /rooms/<paths> """
import logging
import re
from typing import List, Optional
@@ -515,9 +516,9 @@ class RoomMessageListRestServlet(RestServlet):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
pagination_config = PaginationConfig.from_request(request, default_limit=10)
as_client_event = b"raw" not in request.args
filter_bytes = parse_string(request, b"filter", encoding=None)
if filter_bytes:
filter_json = urlparse.unquote(filter_bytes.decode("UTF-8"))
filter_str = parse_string(request, b"filter", encoding="utf-8")
if filter_str:
filter_json = urlparse.unquote(filter_str)
event_filter = Filter(json.loads(filter_json)) # type: Optional[Filter]
if (
event_filter
@@ -627,9 +628,9 @@ class RoomEventContextServlet(RestServlet):
limit = parse_integer(request, "limit", default=10)
# picking the API shape for symmetry with /messages
filter_bytes = parse_string(request, "filter")
if filter_bytes:
filter_json = urlparse.unquote(filter_bytes)
filter_str = parse_string(request, b"filter", encoding="utf-8")
if filter_str:
filter_json = urlparse.unquote(filter_str)
event_filter = Filter(json.loads(filter_json)) # type: Optional[Filter]
else:
event_filter = None

View File

@@ -202,9 +202,11 @@ class RemoteKey(DirectServeJsonResource):
if miss:
cache_misses.setdefault(server_name, set()).add(key_id)
# Cast to bytes since postgresql returns a memoryview.
json_results.add(bytes(most_recent_result["key_json"]))
else:
for ts_added, result in results:
# Cast to bytes since postgresql returns a memoryview.
json_results.add(bytes(result["key_json"]))
if cache_misses and query_remote_on_cache_miss:
@@ -213,7 +215,7 @@ class RemoteKey(DirectServeJsonResource):
else:
signed_keys = []
for key_json in json_results:
key_json = json.loads(key_json)
key_json = json.loads(key_json.decode("utf-8"))
for signing_key in self.config.key_server_signing_keys:
key_json = sign_json(key_json, self.config.server_name, signing_key)

View File

@@ -106,7 +106,7 @@ from synapse.server_notices.worker_server_notices_sender import (
WorkerServerNoticesSender,
)
from synapse.state import StateHandler, StateResolutionHandler
from synapse.storage import DataStores, Storage
from synapse.storage import DataStore, DataStores, Storage
from synapse.streams.events import EventSources
from synapse.util import Clock
from synapse.util.distributor import Distributor
@@ -314,7 +314,7 @@ class HomeServer(object):
def get_clock(self):
return self.clock
def get_datastore(self):
def get_datastore(self) -> DataStore:
return self.datastores.main
def get_datastores(self):

View File

@@ -20,6 +20,7 @@ import synapse.handlers.room
import synapse.handlers.room_member
import synapse.handlers.set_password
import synapse.http.client
import synapse.http.matrixfederationclient
import synapse.notifier
import synapse.push.pusherpool
import synapse.replication.tcp.client
@@ -143,3 +144,7 @@ class HomeServer(object):
pass
def get_replication_streams(self) -> Dict[str, Stream]:
pass
def get_http_client(
self,
) -> synapse.http.matrixfederationclient.MatrixFederationHttpClient:
pass

View File

@@ -17,7 +17,6 @@
import itertools
import logging
from collections import OrderedDict, namedtuple
from functools import wraps
from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple
import attr
@@ -69,27 +68,6 @@ def encode_json(json_object):
_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
def _retry_on_integrity_error(func):
"""Wraps a database function so that it gets retried on IntegrityError,
with `delete_existing=True` passed in.
Args:
func: function that returns a Deferred and accepts a `delete_existing` arg
"""
@wraps(func)
@defer.inlineCallbacks
def f(self, *args, **kwargs):
try:
res = yield func(self, *args, delete_existing=False, **kwargs)
except self.database_engine.module.IntegrityError:
logger.exception("IntegrityError, retrying.")
res = yield func(self, *args, delete_existing=True, **kwargs)
return res
return f
@attr.s(slots=True)
class DeltaState:
"""Deltas to use to update the `current_state_events` table.
@@ -134,7 +112,6 @@ class PersistEventsStore:
hs.config.worker.writers.events == hs.get_instance_name()
), "Can only instantiate EventsStore on master"
@_retry_on_integrity_error
@defer.inlineCallbacks
def _persist_events_and_state_updates(
self,
@@ -143,7 +120,6 @@ class PersistEventsStore:
state_delta_for_room: Dict[str, DeltaState],
new_forward_extremeties: Dict[str, List[str]],
backfilled: bool = False,
delete_existing: bool = False,
):
"""Persist a set of events alongside updates to the current state and
forward extremities tables.
@@ -157,7 +133,6 @@ class PersistEventsStore:
new_forward_extremities: Map from room_id to list of event IDs
that are the new forward extremities of the room.
backfilled
delete_existing
Returns:
Deferred: resolves when the events have been persisted
@@ -197,7 +172,6 @@ class PersistEventsStore:
self._persist_events_txn,
events_and_contexts=events_and_contexts,
backfilled=backfilled,
delete_existing=delete_existing,
state_delta_for_room=state_delta_for_room,
new_forward_extremeties=new_forward_extremeties,
)
@@ -341,7 +315,6 @@ class PersistEventsStore:
txn: LoggingTransaction,
events_and_contexts: List[Tuple[EventBase, EventContext]],
backfilled: bool,
delete_existing: bool = False,
state_delta_for_room: Dict[str, DeltaState] = {},
new_forward_extremeties: Dict[str, List[str]] = {},
):
@@ -393,13 +366,6 @@ class PersistEventsStore:
# From this point onwards the events are only events that we haven't
# seen before.
if delete_existing:
# For paranoia reasons, we go and delete all the existing entries
# for these events so we can reinsert them.
# This gets around any problems with some tables already having
# entries.
self._delete_existing_rows_txn(txn, events_and_contexts=events_and_contexts)
self._store_event_txn(txn, events_and_contexts=events_and_contexts)
# Insert into event_to_state_groups.
@@ -797,39 +763,6 @@ class PersistEventsStore:
return [ec for ec in events_and_contexts if ec[0] not in to_remove]
@classmethod
def _delete_existing_rows_txn(cls, txn, events_and_contexts):
if not events_and_contexts:
# nothing to do here
return
logger.info("Deleting existing")
for table in (
"events",
"event_auth",
"event_json",
"event_edges",
"event_forward_extremities",
"event_reference_hashes",
"event_search",
"event_to_state_groups",
"state_events",
"rejections",
"redactions",
"room_memberships",
):
txn.executemany(
"DELETE FROM %s WHERE event_id = ?" % (table,),
[(ev.event_id,) for ev, _ in events_and_contexts],
)
for table in ("event_push_actions",):
txn.executemany(
"DELETE FROM %s WHERE room_id = ? AND event_id = ?" % (table,),
[(ev.room_id, ev.event_id) for ev, _ in events_and_contexts],
)
def _store_event_txn(self, txn, events_and_contexts):
"""Insert new events into the event and event_json tables

View File

@@ -353,6 +353,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
last_room_id = progress.get("last_room_id", "")
def _background_remove_left_rooms_txn(txn):
# get a batch of room ids to consider
sql = """
SELECT DISTINCT room_id FROM current_state_events
WHERE room_id > ? ORDER BY room_id LIMIT ?
@@ -363,24 +364,68 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
if not room_ids:
return True, set()
###########################################################################
#
# exclude rooms where we have active members
sql = """
SELECT room_id
FROM current_state_events
FROM local_current_membership
WHERE
room_id > ? AND room_id <= ?
AND type = 'm.room.member'
AND membership = 'join'
AND state_key LIKE ?
GROUP BY room_id
"""
txn.execute(sql, (last_room_id, room_ids[-1], "%:" + self.server_name))
txn.execute(sql, (last_room_id, room_ids[-1]))
joined_room_ids = {row[0] for row in txn}
to_delete = set(room_ids) - joined_room_ids
left_rooms = set(room_ids) - joined_room_ids
###########################################################################
#
# exclude rooms which we are in the process of constructing; these otherwise
# qualify as "rooms with no local users", and would have their
# forward extremities cleaned up.
logger.info("Deleting current state left rooms: %r", left_rooms)
# the following query will return a list of rooms which have forward
# extremities that are *not* also the create event in the room - ie
# those that are not being created currently.
sql = """
SELECT DISTINCT efe.room_id
FROM event_forward_extremities efe
LEFT JOIN current_state_events cse ON
cse.event_id = efe.event_id
AND cse.type = 'm.room.create'
AND cse.state_key = ''
WHERE
cse.event_id IS NULL
AND efe.room_id > ? AND efe.room_id <= ?
"""
txn.execute(sql, (last_room_id, room_ids[-1]))
# build a set of those rooms within `to_delete` that do not appear in
# the above, leaving us with the rooms in `to_delete` that *are* being
# created.
creating_rooms = to_delete.difference(row[0] for row in txn)
logger.info("skipping rooms which are being created: %s", creating_rooms)
# now remove the rooms being created from the list of those to delete.
#
# (we could have just taken the intersection of `to_delete` with the result
# of the sql query, but it's useful to be able to log `creating_rooms`; and
# having done so, it's quicker to remove the (few) creating rooms from
# `to_delete` than it is to form the intersection with the (larger) list of
# not-creating-rooms)
to_delete -= creating_rooms
###########################################################################
#
# now clear the state for the rooms
logger.info("Deleting current state left rooms: %r", to_delete)
# First we get all users that we still think were joined to the
# room. This is so that we can mark those device lists as
@@ -391,7 +436,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
txn,
table="current_state_events",
column="room_id",
iterable=left_rooms,
iterable=to_delete,
keyvalues={"type": EventTypes.Member, "membership": Membership.JOIN},
retcols=("state_key",),
)
@@ -403,7 +448,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
txn,
table="current_state_events",
column="room_id",
iterable=left_rooms,
iterable=to_delete,
keyvalues={},
)
@@ -411,7 +456,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
txn,
table="event_forward_extremities",
column="room_id",
iterable=left_rooms,
iterable=to_delete,
keyvalues={},
)

View File

@@ -46,7 +46,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"""If the user has no devices, we expect an empty list.
"""
local_user = "@boris:" + self.hs.hostname
res = yield self.handler.query_local_devices({local_user: None})
res = yield defer.ensureDeferred(
self.handler.query_local_devices({local_user: None})
)
self.assertDictEqual(res, {local_user: {}})
@defer.inlineCallbacks
@@ -60,15 +62,19 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"alg2:k3": {"key": "key3"},
}
res = yield self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": keys}
res = yield defer.ensureDeferred(
self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": keys}
)
)
self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}})
# we should be able to change the signature without a problem
keys["alg2:k2"]["signatures"]["k1"] = "sig2"
res = yield self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": keys}
res = yield defer.ensureDeferred(
self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": keys}
)
)
self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}})
@@ -84,44 +90,56 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"alg2:k3": {"key": "key3"},
}
res = yield self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": keys}
res = yield defer.ensureDeferred(
self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": keys}
)
)
self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}})
try:
yield self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}}
yield defer.ensureDeferred(
self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}}
)
)
self.fail("No error when changing string key")
except errors.SynapseError:
pass
try:
yield self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}}
yield defer.ensureDeferred(
self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}}
)
)
self.fail("No error when replacing dict key with string")
except errors.SynapseError:
pass
try:
yield self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": {"alg1:k1": {"key": "key"}}}
yield defer.ensureDeferred(
self.handler.upload_keys_for_user(
local_user,
device_id,
{"one_time_keys": {"alg1:k1": {"key": "key"}}},
)
)
self.fail("No error when replacing string key with dict")
except errors.SynapseError:
pass
try:
yield self.handler.upload_keys_for_user(
local_user,
device_id,
{
"one_time_keys": {
"alg2:k2": {"key": "key3", "signatures": {"k1": "sig1"}}
}
},
yield defer.ensureDeferred(
self.handler.upload_keys_for_user(
local_user,
device_id,
{
"one_time_keys": {
"alg2:k2": {"key": "key3", "signatures": {"k1": "sig1"}}
}
},
)
)
self.fail("No error when replacing dict key")
except errors.SynapseError:
@@ -133,13 +151,17 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
device_id = "xyz"
keys = {"alg1:k1": "key1"}
res = yield self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": keys}
res = yield defer.ensureDeferred(
self.handler.upload_keys_for_user(
local_user, device_id, {"one_time_keys": keys}
)
)
self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1}})
res2 = yield self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
res2 = yield defer.ensureDeferred(
self.handler.claim_one_time_keys(
{"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
)
)
self.assertEqual(
res2,
@@ -163,7 +185,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
},
}
}
yield self.handler.upload_signing_keys_for_user(local_user, keys1)
yield defer.ensureDeferred(
self.handler.upload_signing_keys_for_user(local_user, keys1)
)
keys2 = {
"master_key": {
@@ -175,10 +199,12 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
},
}
}
yield self.handler.upload_signing_keys_for_user(local_user, keys2)
yield defer.ensureDeferred(
self.handler.upload_signing_keys_for_user(local_user, keys2)
)
devices = yield self.handler.query_devices(
{"device_keys": {local_user: []}}, 0, local_user
devices = yield defer.ensureDeferred(
self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user)
)
self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]})
@@ -215,7 +241,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
"2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0",
)
yield self.handler.upload_signing_keys_for_user(local_user, keys1)
yield defer.ensureDeferred(
self.handler.upload_signing_keys_for_user(local_user, keys1)
)
# upload two device keys, which will be signed later by the self-signing key
device_key_1 = {
@@ -245,18 +273,24 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"signatures": {local_user: {"ed25519:def": "base64+signature"}},
}
yield self.handler.upload_keys_for_user(
local_user, "abc", {"device_keys": device_key_1}
yield defer.ensureDeferred(
self.handler.upload_keys_for_user(
local_user, "abc", {"device_keys": device_key_1}
)
)
yield self.handler.upload_keys_for_user(
local_user, "def", {"device_keys": device_key_2}
yield defer.ensureDeferred(
self.handler.upload_keys_for_user(
local_user, "def", {"device_keys": device_key_2}
)
)
# sign the first device key and upload it
del device_key_1["signatures"]
sign.sign_json(device_key_1, local_user, signing_key)
yield self.handler.upload_signatures_for_device_keys(
local_user, {local_user: {"abc": device_key_1}}
yield defer.ensureDeferred(
self.handler.upload_signatures_for_device_keys(
local_user, {local_user: {"abc": device_key_1}}
)
)
# sign the second device key and upload both device keys. The server
@@ -264,14 +298,16 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
# signature for it
del device_key_2["signatures"]
sign.sign_json(device_key_2, local_user, signing_key)
yield self.handler.upload_signatures_for_device_keys(
local_user, {local_user: {"abc": device_key_1, "def": device_key_2}}
yield defer.ensureDeferred(
self.handler.upload_signatures_for_device_keys(
local_user, {local_user: {"abc": device_key_1, "def": device_key_2}}
)
)
device_key_1["signatures"][local_user]["ed25519:abc"] = "base64+signature"
device_key_2["signatures"][local_user]["ed25519:def"] = "base64+signature"
devices = yield self.handler.query_devices(
{"device_keys": {local_user: []}}, 0, local_user
devices = yield defer.ensureDeferred(
self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user)
)
del devices["device_keys"][local_user]["abc"]["unsigned"]
del devices["device_keys"][local_user]["def"]["unsigned"]
@@ -292,7 +328,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
},
}
}
yield self.handler.upload_signing_keys_for_user(local_user, keys1)
yield defer.ensureDeferred(
self.handler.upload_signing_keys_for_user(local_user, keys1)
)
res = None
try:
@@ -305,7 +343,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
res = e.code
self.assertEqual(res, 400)
res = yield self.handler.query_local_devices({local_user: None})
res = yield defer.ensureDeferred(
self.handler.query_local_devices({local_user: None})
)
self.assertDictEqual(res, {local_user: {}})
@defer.inlineCallbacks
@@ -331,8 +371,10 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"ed25519", "xyz", "OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA"
)
yield self.handler.upload_keys_for_user(
local_user, device_id, {"device_keys": device_key}
yield defer.ensureDeferred(
self.handler.upload_keys_for_user(
local_user, device_id, {"device_keys": device_key}
)
)
# private key: 2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0
@@ -372,7 +414,9 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"user_signing_key": usersigning_key,
"self_signing_key": selfsigning_key,
}
yield self.handler.upload_signing_keys_for_user(local_user, cross_signing_keys)
yield defer.ensureDeferred(
self.handler.upload_signing_keys_for_user(local_user, cross_signing_keys)
)
# set up another user with a master key. This user will be signed by
# the first user
@@ -384,76 +428,90 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
"usage": ["master"],
"keys": {"ed25519:" + other_master_pubkey: other_master_pubkey},
}
yield self.handler.upload_signing_keys_for_user(
other_user, {"master_key": other_master_key}
yield defer.ensureDeferred(
self.handler.upload_signing_keys_for_user(
other_user, {"master_key": other_master_key}
)
)
# test various signature failures (see below)
ret = yield self.handler.upload_signatures_for_device_keys(
local_user,
{
local_user: {
# fails because the signature is invalid
# should fail with INVALID_SIGNATURE
device_id: {
"user_id": local_user,
"device_id": device_id,
"algorithms": [
"m.olm.curve25519-aes-sha2",
RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
],
"keys": {
"curve25519:xyz": "curve25519+key",
# private key: OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA
"ed25519:xyz": device_pubkey,
ret = yield defer.ensureDeferred(
self.handler.upload_signatures_for_device_keys(
local_user,
{
local_user: {
# fails because the signature is invalid
# should fail with INVALID_SIGNATURE
device_id: {
"user_id": local_user,
"device_id": device_id,
"algorithms": [
"m.olm.curve25519-aes-sha2",
RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
],
"keys": {
"curve25519:xyz": "curve25519+key",
# private key: OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA
"ed25519:xyz": device_pubkey,
},
"signatures": {
local_user: {
"ed25519:" + selfsigning_pubkey: "something"
}
},
},
"signatures": {
local_user: {"ed25519:" + selfsigning_pubkey: "something"}
# fails because device is unknown
# should fail with NOT_FOUND
"unknown": {
"user_id": local_user,
"device_id": "unknown",
"signatures": {
local_user: {
"ed25519:" + selfsigning_pubkey: "something"
}
},
},
# fails because the signature is invalid
# should fail with INVALID_SIGNATURE
master_pubkey: {
"user_id": local_user,
"usage": ["master"],
"keys": {"ed25519:" + master_pubkey: master_pubkey},
"signatures": {
local_user: {"ed25519:" + device_pubkey: "something"}
},
},
},
# fails because device is unknown
# should fail with NOT_FOUND
"unknown": {
"user_id": local_user,
"device_id": "unknown",
"signatures": {
local_user: {"ed25519:" + selfsigning_pubkey: "something"}
other_user: {
# fails because the device is not the user's master-signing key
# should fail with NOT_FOUND
"unknown": {
"user_id": other_user,
"device_id": "unknown",
"signatures": {
local_user: {
"ed25519:" + usersigning_pubkey: "something"
}
},
},
},
# fails because the signature is invalid
# should fail with INVALID_SIGNATURE
master_pubkey: {
"user_id": local_user,
"usage": ["master"],
"keys": {"ed25519:" + master_pubkey: master_pubkey},
"signatures": {
local_user: {"ed25519:" + device_pubkey: "something"}
other_master_pubkey: {
# fails because the key doesn't match what the server has
# should fail with UNKNOWN
"user_id": other_user,
"usage": ["master"],
"keys": {
"ed25519:" + other_master_pubkey: other_master_pubkey
},
"something": "random",
"signatures": {
local_user: {
"ed25519:" + usersigning_pubkey: "something"
}
},
},
},
},
other_user: {
# fails because the device is not the user's master-signing key
# should fail with NOT_FOUND
"unknown": {
"user_id": other_user,
"device_id": "unknown",
"signatures": {
local_user: {"ed25519:" + usersigning_pubkey: "something"}
},
},
other_master_pubkey: {
# fails because the key doesn't match what the server has
# should fail with UNKNOWN
"user_id": other_user,
"usage": ["master"],
"keys": {"ed25519:" + other_master_pubkey: other_master_pubkey},
"something": "random",
"signatures": {
local_user: {"ed25519:" + usersigning_pubkey: "something"}
},
},
},
},
)
)
user_failures = ret["failures"][local_user]
@@ -478,19 +536,23 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
sign.sign_json(device_key, local_user, selfsigning_signing_key)
sign.sign_json(master_key, local_user, device_signing_key)
sign.sign_json(other_master_key, local_user, usersigning_signing_key)
ret = yield self.handler.upload_signatures_for_device_keys(
local_user,
{
local_user: {device_id: device_key, master_pubkey: master_key},
other_user: {other_master_pubkey: other_master_key},
},
ret = yield defer.ensureDeferred(
self.handler.upload_signatures_for_device_keys(
local_user,
{
local_user: {device_id: device_key, master_pubkey: master_key},
other_user: {other_master_pubkey: other_master_key},
},
)
)
self.assertEqual(ret["failures"], {})
# fetch the signed keys/devices and make sure that the signatures are there
ret = yield self.handler.query_devices(
{"device_keys": {local_user: [], other_user: []}}, 0, local_user
ret = yield defer.ensureDeferred(
self.handler.query_devices(
{"device_keys": {local_user: [], other_user: []}}, 0, local_user
)
)
self.assertEqual(

View File

@@ -66,7 +66,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
"""
res = None
try:
yield self.handler.get_version_info(self.local_user)
yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
except errors.SynapseError as e:
res = e.code
self.assertEqual(res, 404)
@@ -78,7 +78,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
"""
res = None
try:
yield self.handler.get_version_info(self.local_user, "bogus_version")
yield defer.ensureDeferred(
self.handler.get_version_info(self.local_user, "bogus_version")
)
except errors.SynapseError as e:
res = e.code
self.assertEqual(res, 404)
@@ -87,14 +89,19 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_create_version(self):
"""Check that we can create and then retrieve versions.
"""
res = yield self.handler.create_version(
self.local_user,
{"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
res = yield defer.ensureDeferred(
self.handler.create_version(
self.local_user,
{
"algorithm": "m.megolm_backup.v1",
"auth_data": "first_version_auth_data",
},
)
)
self.assertEqual(res, "1")
# check we can retrieve it as the current version
res = yield self.handler.get_version_info(self.local_user)
res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
version_etag = res["etag"]
self.assertIsInstance(version_etag, str)
del res["etag"]
@@ -109,7 +116,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
)
# check we can retrieve it as a specific version
res = yield self.handler.get_version_info(self.local_user, "1")
res = yield defer.ensureDeferred(
self.handler.get_version_info(self.local_user, "1")
)
self.assertEqual(res["etag"], version_etag)
del res["etag"]
self.assertDictEqual(
@@ -123,17 +132,19 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
)
# upload a new one...
res = yield self.handler.create_version(
self.local_user,
{
"algorithm": "m.megolm_backup.v1",
"auth_data": "second_version_auth_data",
},
res = yield defer.ensureDeferred(
self.handler.create_version(
self.local_user,
{
"algorithm": "m.megolm_backup.v1",
"auth_data": "second_version_auth_data",
},
)
)
self.assertEqual(res, "2")
# check we can retrieve it as the current version
res = yield self.handler.get_version_info(self.local_user)
res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
del res["etag"]
self.assertDictEqual(
res,
@@ -149,25 +160,32 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_update_version(self):
"""Check that we can update versions.
"""
version = yield self.handler.create_version(
self.local_user,
{"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
version = yield defer.ensureDeferred(
self.handler.create_version(
self.local_user,
{
"algorithm": "m.megolm_backup.v1",
"auth_data": "first_version_auth_data",
},
)
)
self.assertEqual(version, "1")
res = yield self.handler.update_version(
self.local_user,
version,
{
"algorithm": "m.megolm_backup.v1",
"auth_data": "revised_first_version_auth_data",
"version": version,
},
res = yield defer.ensureDeferred(
self.handler.update_version(
self.local_user,
version,
{
"algorithm": "m.megolm_backup.v1",
"auth_data": "revised_first_version_auth_data",
"version": version,
},
)
)
self.assertDictEqual(res, {})
# check we can retrieve it as the current version
res = yield self.handler.get_version_info(self.local_user)
res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
del res["etag"]
self.assertDictEqual(
res,
@@ -185,14 +203,16 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
"""
res = None
try:
yield self.handler.update_version(
self.local_user,
"1",
{
"algorithm": "m.megolm_backup.v1",
"auth_data": "revised_first_version_auth_data",
"version": "1",
},
yield defer.ensureDeferred(
self.handler.update_version(
self.local_user,
"1",
{
"algorithm": "m.megolm_backup.v1",
"auth_data": "revised_first_version_auth_data",
"version": "1",
},
)
)
except errors.SynapseError as e:
res = e.code
@@ -202,23 +222,30 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_update_omitted_version(self):
"""Check that the update succeeds if the version is missing from the body
"""
version = yield self.handler.create_version(
self.local_user,
{"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
version = yield defer.ensureDeferred(
self.handler.create_version(
self.local_user,
{
"algorithm": "m.megolm_backup.v1",
"auth_data": "first_version_auth_data",
},
)
)
self.assertEqual(version, "1")
yield self.handler.update_version(
self.local_user,
version,
{
"algorithm": "m.megolm_backup.v1",
"auth_data": "revised_first_version_auth_data",
},
yield defer.ensureDeferred(
self.handler.update_version(
self.local_user,
version,
{
"algorithm": "m.megolm_backup.v1",
"auth_data": "revised_first_version_auth_data",
},
)
)
# check we can retrieve it as the current version
res = yield self.handler.get_version_info(self.local_user)
res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
del res["etag"] # etag is opaque, so don't test its contents
self.assertDictEqual(
res,
@@ -234,22 +261,29 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_update_bad_version(self):
"""Check that we get a 400 if the version in the body doesn't match
"""
version = yield self.handler.create_version(
self.local_user,
{"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
version = yield defer.ensureDeferred(
self.handler.create_version(
self.local_user,
{
"algorithm": "m.megolm_backup.v1",
"auth_data": "first_version_auth_data",
},
)
)
self.assertEqual(version, "1")
res = None
try:
yield self.handler.update_version(
self.local_user,
version,
{
"algorithm": "m.megolm_backup.v1",
"auth_data": "revised_first_version_auth_data",
"version": "incorrect",
},
yield defer.ensureDeferred(
self.handler.update_version(
self.local_user,
version,
{
"algorithm": "m.megolm_backup.v1",
"auth_data": "revised_first_version_auth_data",
"version": "incorrect",
},
)
)
except errors.SynapseError as e:
res = e.code
@@ -261,7 +295,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
"""
res = None
try:
yield self.handler.delete_version(self.local_user, "1")
yield defer.ensureDeferred(
self.handler.delete_version(self.local_user, "1")
)
except errors.SynapseError as e:
res = e.code
self.assertEqual(res, 404)
@@ -272,7 +308,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
"""
res = None
try:
yield self.handler.delete_version(self.local_user)
yield defer.ensureDeferred(self.handler.delete_version(self.local_user))
except errors.SynapseError as e:
res = e.code
self.assertEqual(res, 404)
@@ -281,19 +317,26 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_delete_version(self):
"""Check that we can create and then delete versions.
"""
res = yield self.handler.create_version(
self.local_user,
{"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
res = yield defer.ensureDeferred(
self.handler.create_version(
self.local_user,
{
"algorithm": "m.megolm_backup.v1",
"auth_data": "first_version_auth_data",
},
)
)
self.assertEqual(res, "1")
# check we can delete it
yield self.handler.delete_version(self.local_user, "1")
yield defer.ensureDeferred(self.handler.delete_version(self.local_user, "1"))
# check that it's gone
res = None
try:
yield self.handler.get_version_info(self.local_user, "1")
yield defer.ensureDeferred(
self.handler.get_version_info(self.local_user, "1")
)
except errors.SynapseError as e:
res = e.code
self.assertEqual(res, 404)
@@ -304,7 +347,9 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
"""
res = None
try:
yield self.handler.get_room_keys(self.local_user, "bogus_version")
yield defer.ensureDeferred(
self.handler.get_room_keys(self.local_user, "bogus_version")
)
except errors.SynapseError as e:
res = e.code
self.assertEqual(res, 404)
@@ -313,13 +358,20 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_get_missing_room_keys(self):
"""Check we get an empty response from an empty backup
"""
version = yield self.handler.create_version(
self.local_user,
{"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
version = yield defer.ensureDeferred(
self.handler.create_version(
self.local_user,
{
"algorithm": "m.megolm_backup.v1",
"auth_data": "first_version_auth_data",
},
)
)
self.assertEqual(version, "1")
res = yield self.handler.get_room_keys(self.local_user, version)
res = yield defer.ensureDeferred(
self.handler.get_room_keys(self.local_user, version)
)
self.assertDictEqual(res, {"rooms": {}})
# TODO: test the locking semantics when uploading room_keys,
@@ -331,8 +383,8 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
"""
res = None
try:
yield self.handler.upload_room_keys(
self.local_user, "no_version", room_keys
yield defer.ensureDeferred(
self.handler.upload_room_keys(self.local_user, "no_version", room_keys)
)
except errors.SynapseError as e:
res = e.code
@@ -343,16 +395,23 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
"""Check that we get a 404 on uploading keys when an nonexistent version
is specified
"""
version = yield self.handler.create_version(
self.local_user,
{"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
version = yield defer.ensureDeferred(
self.handler.create_version(
self.local_user,
{
"algorithm": "m.megolm_backup.v1",
"auth_data": "first_version_auth_data",
},
)
)
self.assertEqual(version, "1")
res = None
try:
yield self.handler.upload_room_keys(
self.local_user, "bogus_version", room_keys
yield defer.ensureDeferred(
self.handler.upload_room_keys(
self.local_user, "bogus_version", room_keys
)
)
except errors.SynapseError as e:
res = e.code
@@ -362,24 +421,33 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_upload_room_keys_wrong_version(self):
"""Check that we get a 403 on uploading keys for an old version
"""
version = yield self.handler.create_version(
self.local_user,
{"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
version = yield defer.ensureDeferred(
self.handler.create_version(
self.local_user,
{
"algorithm": "m.megolm_backup.v1",
"auth_data": "first_version_auth_data",
},
)
)
self.assertEqual(version, "1")
version = yield self.handler.create_version(
self.local_user,
{
"algorithm": "m.megolm_backup.v1",
"auth_data": "second_version_auth_data",
},
version = yield defer.ensureDeferred(
self.handler.create_version(
self.local_user,
{
"algorithm": "m.megolm_backup.v1",
"auth_data": "second_version_auth_data",
},
)
)
self.assertEqual(version, "2")
res = None
try:
yield self.handler.upload_room_keys(self.local_user, "1", room_keys)
yield defer.ensureDeferred(
self.handler.upload_room_keys(self.local_user, "1", room_keys)
)
except errors.SynapseError as e:
res = e.code
self.assertEqual(res, 403)
@@ -388,26 +456,39 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_upload_room_keys_insert(self):
"""Check that we can insert and retrieve keys for a session
"""
version = yield self.handler.create_version(
self.local_user,
{"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
version = yield defer.ensureDeferred(
self.handler.create_version(
self.local_user,
{
"algorithm": "m.megolm_backup.v1",
"auth_data": "first_version_auth_data",
},
)
)
self.assertEqual(version, "1")
yield self.handler.upload_room_keys(self.local_user, version, room_keys)
yield defer.ensureDeferred(
self.handler.upload_room_keys(self.local_user, version, room_keys)
)
res = yield self.handler.get_room_keys(self.local_user, version)
res = yield defer.ensureDeferred(
self.handler.get_room_keys(self.local_user, version)
)
self.assertDictEqual(res, room_keys)
# check getting room_keys for a given room
res = yield self.handler.get_room_keys(
self.local_user, version, room_id="!abc:matrix.org"
res = yield defer.ensureDeferred(
self.handler.get_room_keys(
self.local_user, version, room_id="!abc:matrix.org"
)
)
self.assertDictEqual(res, room_keys)
# check getting room_keys for a given session_id
res = yield self.handler.get_room_keys(
self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
res = yield defer.ensureDeferred(
self.handler.get_room_keys(
self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
)
)
self.assertDictEqual(res, room_keys)
@@ -415,16 +496,23 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_upload_room_keys_merge(self):
"""Check that we can upload a new room_key for an existing session and
have it correctly merged"""
version = yield self.handler.create_version(
self.local_user,
{"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
version = yield defer.ensureDeferred(
self.handler.create_version(
self.local_user,
{
"algorithm": "m.megolm_backup.v1",
"auth_data": "first_version_auth_data",
},
)
)
self.assertEqual(version, "1")
yield self.handler.upload_room_keys(self.local_user, version, room_keys)
yield defer.ensureDeferred(
self.handler.upload_room_keys(self.local_user, version, room_keys)
)
# get the etag to compare to future versions
res = yield self.handler.get_version_info(self.local_user)
res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
backup_etag = res["etag"]
self.assertEqual(res["count"], 1)
@@ -434,29 +522,37 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
# test that increasing the message_index doesn't replace the existing session
new_room_key["first_message_index"] = 2
new_room_key["session_data"] = "new"
yield self.handler.upload_room_keys(self.local_user, version, new_room_keys)
yield defer.ensureDeferred(
self.handler.upload_room_keys(self.local_user, version, new_room_keys)
)
res = yield self.handler.get_room_keys(self.local_user, version)
res = yield defer.ensureDeferred(
self.handler.get_room_keys(self.local_user, version)
)
self.assertEqual(
res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"],
"SSBBTSBBIEZJU0gK",
)
# the etag should be the same since the session did not change
res = yield self.handler.get_version_info(self.local_user)
res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
self.assertEqual(res["etag"], backup_etag)
# test that marking the session as verified however /does/ replace it
new_room_key["is_verified"] = True
yield self.handler.upload_room_keys(self.local_user, version, new_room_keys)
yield defer.ensureDeferred(
self.handler.upload_room_keys(self.local_user, version, new_room_keys)
)
res = yield self.handler.get_room_keys(self.local_user, version)
res = yield defer.ensureDeferred(
self.handler.get_room_keys(self.local_user, version)
)
self.assertEqual(
res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new"
)
# the etag should NOT be equal now, since the key changed
res = yield self.handler.get_version_info(self.local_user)
res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
self.assertNotEqual(res["etag"], backup_etag)
backup_etag = res["etag"]
@@ -464,15 +560,19 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
# with a lower forwarding count
new_room_key["forwarded_count"] = 2
new_room_key["session_data"] = "other"
yield self.handler.upload_room_keys(self.local_user, version, new_room_keys)
yield defer.ensureDeferred(
self.handler.upload_room_keys(self.local_user, version, new_room_keys)
)
res = yield self.handler.get_room_keys(self.local_user, version)
res = yield defer.ensureDeferred(
self.handler.get_room_keys(self.local_user, version)
)
self.assertEqual(
res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new"
)
# the etag should be the same since the session did not change
res = yield self.handler.get_version_info(self.local_user)
res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
self.assertEqual(res["etag"], backup_etag)
# TODO: check edge cases as well as the common variations here
@@ -481,36 +581,59 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
def test_delete_room_keys(self):
"""Check that we can insert and delete keys for a session
"""
version = yield self.handler.create_version(
self.local_user,
{"algorithm": "m.megolm_backup.v1", "auth_data": "first_version_auth_data"},
version = yield defer.ensureDeferred(
self.handler.create_version(
self.local_user,
{
"algorithm": "m.megolm_backup.v1",
"auth_data": "first_version_auth_data",
},
)
)
self.assertEqual(version, "1")
# check for bulk-delete
yield self.handler.upload_room_keys(self.local_user, version, room_keys)
yield self.handler.delete_room_keys(self.local_user, version)
res = yield self.handler.get_room_keys(
self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
yield defer.ensureDeferred(
self.handler.upload_room_keys(self.local_user, version, room_keys)
)
yield defer.ensureDeferred(
self.handler.delete_room_keys(self.local_user, version)
)
res = yield defer.ensureDeferred(
self.handler.get_room_keys(
self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
)
)
self.assertDictEqual(res, {"rooms": {}})
# check for bulk-delete per room
yield self.handler.upload_room_keys(self.local_user, version, room_keys)
yield self.handler.delete_room_keys(
self.local_user, version, room_id="!abc:matrix.org"
yield defer.ensureDeferred(
self.handler.upload_room_keys(self.local_user, version, room_keys)
)
res = yield self.handler.get_room_keys(
self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
yield defer.ensureDeferred(
self.handler.delete_room_keys(
self.local_user, version, room_id="!abc:matrix.org"
)
)
res = yield defer.ensureDeferred(
self.handler.get_room_keys(
self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
)
)
self.assertDictEqual(res, {"rooms": {}})
# check for bulk-delete per session
yield self.handler.upload_room_keys(self.local_user, version, room_keys)
yield self.handler.delete_room_keys(
self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
yield defer.ensureDeferred(
self.handler.upload_room_keys(self.local_user, version, room_keys)
)
res = yield self.handler.get_room_keys(
self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
yield defer.ensureDeferred(
self.handler.delete_room_keys(
self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
)
)
res = yield defer.ensureDeferred(
self.handler.get_room_keys(
self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
)
)
self.assertDictEqual(res, {"rooms": {}})

View File

@@ -138,10 +138,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room
def get_current_users_in_room(room_id):
def get_users_in_room(room_id):
return defer.succeed({str(u) for u in self.room_members})
hs.get_state_handler().get_current_users_in_room = get_current_users_in_room
self.datastore.get_users_in_room = get_users_in_room
self.datastore.get_user_directory_stream_pos.return_value = (
# we deliberately return a non-None stream pos to avoid doing an initial_spam

View File

@@ -14,7 +14,7 @@
# limitations under the License.
import logging
from typing import Any, List, Optional, Tuple
from typing import Any, Callable, List, Optional, Tuple
import attr
@@ -26,8 +26,9 @@ from synapse.app.generic_worker import (
GenericWorkerReplicationHandler,
GenericWorkerServer,
)
from synapse.http.server import JsonResource
from synapse.http.site import SynapseRequest
from synapse.replication.http import streams
from synapse.replication.http import ReplicationRestResource, streams
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
@@ -35,7 +36,7 @@ from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest
from tests.server import FakeTransport
from tests.server import FakeTransport, render
logger = logging.getLogger(__name__)
@@ -180,6 +181,159 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self.assertEqual(request.method, b"GET")
class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
"""Base class for tests running multiple workers.
Automatically handle HTTP replication requests from workers to master,
unlike `BaseStreamTestCase`.
"""
servlets = [] # type: List[Callable[[HomeServer, JsonResource], None]]
def setUp(self):
super().setUp()
# build a replication server
self.server_factory = ReplicationStreamProtocolFactory(self.hs)
self.streamer = self.hs.get_replication_streamer()
store = self.hs.get_datastore()
self.database = store.db
self.reactor.lookups["testserv"] = "1.2.3.4"
self._worker_hs_to_resource = {}
# When we see a connection attempt to the master replication listener we
# automatically set up the connection. This is so that tests don't
# manually have to go and explicitly set it up each time (plus sometimes
# it is impossible to write the handling explicitly in the tests).
self.reactor.add_tcp_client_callback(
"1.2.3.4", 8765, self._handle_http_replication_attempt
)
def create_test_json_resource(self):
"""Overrides `HomeserverTestCase.create_test_json_resource`.
"""
# We override this so that it automatically registers all the HTTP
# replication servlets, without having to explicitly do that in all
# subclassses.
resource = ReplicationRestResource(self.hs)
for servlet in self.servlets:
servlet(self.hs, resource)
return resource
def make_worker_hs(
self, worker_app: str, extra_config: dict = {}, **kwargs
) -> HomeServer:
"""Make a new worker HS instance, correctly connecting replcation
stream to the master HS.
Args:
worker_app: Type of worker, e.g. `synapse.app.federation_sender`.
extra_config: Any extra config to use for this instances.
**kwargs: Options that get passed to `self.setup_test_homeserver`,
useful to e.g. pass some mocks for things like `http_client`
Returns:
The new worker HomeServer instance.
"""
config = self._get_worker_hs_config()
config["worker_app"] = worker_app
config.update(extra_config)
worker_hs = self.setup_test_homeserver(
homeserverToUse=GenericWorkerServer,
config=config,
reactor=self.reactor,
**kwargs
)
store = worker_hs.get_datastore()
store.db._db_pool = self.database._db_pool
repl_handler = ReplicationCommandHandler(worker_hs)
client = ClientReplicationStreamProtocol(
worker_hs, "client", "test", self.clock, repl_handler,
)
server = self.server_factory.buildProtocol(None)
client_transport = FakeTransport(server, self.reactor)
client.makeConnection(client_transport)
server_transport = FakeTransport(client, self.reactor)
server.makeConnection(server_transport)
# Set up a resource for the worker
resource = ReplicationRestResource(self.hs)
for servlet in self.servlets:
servlet(worker_hs, resource)
self._worker_hs_to_resource[worker_hs] = resource
return worker_hs
def _get_worker_hs_config(self) -> dict:
config = self.default_config()
config["worker_replication_host"] = "testserv"
config["worker_replication_http_port"] = "8765"
return config
def render_on_worker(self, worker_hs: HomeServer, request: SynapseRequest):
render(request, self._worker_hs_to_resource[worker_hs], self.reactor)
def replicate(self):
"""Tell the master side of replication that something has happened, and then
wait for the replication to occur.
"""
self.streamer.on_notifier_poke()
self.pump()
def _handle_http_replication_attempt(self):
"""Handles a connection attempt to the master replication HTTP
listener.
"""
# We should have at least one outbound connection attempt, where the
# last is one to the HTTP repication IP/port.
clients = self.reactor.tcpClients
self.assertGreaterEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients.pop()
self.assertEqual(host, "1.2.3.4")
self.assertEqual(port, 8765)
# Set up client side protocol
client_protocol = client_factory.buildProtocol(None)
request_factory = OneShotRequestFactory()
# Set up the server side protocol
channel = _PushHTTPChannel(self.reactor)
channel.requestFactory = request_factory
channel.site = self.site
# Connect client to server and vice versa.
client_to_server_transport = FakeTransport(
channel, self.reactor, client_protocol
)
client_protocol.makeConnection(client_to_server_transport)
server_to_client_transport = FakeTransport(
client_protocol, self.reactor, channel
)
channel.makeConnection(server_to_client_transport)
# Note: at this point we've wired everything up, but we need to return
# before the data starts flowing over the connections as this is called
# inside `connecTCP` before the connection has been passed back to the
# code that requested the TCP connection.
class TestReplicationDataHandler(GenericWorkerReplicationHandler):
"""Drop-in for ReplicationDataHandler which just collects RDATA rows"""
@@ -241,6 +395,14 @@ class _PushHTTPChannel(HTTPChannel):
# We need to manually stop the _PullToPushProducer.
self._pull_to_push_producer.stop()
def checkPersistence(self, request, version):
"""Check whether the connection can be re-used
"""
# We hijack this to always say no for ease of wiring stuff up in
# `handle_http_replication_attempt`.
request.responseHeaders.setRawHeaders(b"connection", [b"close"])
return False
class _PullToPushProducer:
"""A push producer that wraps a pull producer.

View File

@@ -15,63 +15,26 @@
import logging
from synapse.api.constants import LoginType
from synapse.app.generic_worker import GenericWorkerServer
from synapse.http.server import JsonResource
from synapse.http.site import SynapseRequest
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
from synapse.rest.client.v2_alpha import register
from tests import unittest
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.rest.client.v2_alpha.test_auth import DummyRecaptchaChecker
from tests.server import FakeChannel, render
from tests.server import FakeChannel
logger = logging.getLogger(__name__)
class ClientReaderTestCase(unittest.HomeserverTestCase):
class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
"""Base class for tests of the replication streams"""
servlets = [
register.register_servlets,
]
servlets = [register.register_servlets]
def prepare(self, reactor, clock, hs):
# build a replication server
self.server_factory = ReplicationStreamProtocolFactory(hs)
self.streamer = hs.get_replication_streamer()
store = hs.get_datastore()
self.database = store.db
self.recaptcha_checker = DummyRecaptchaChecker(hs)
auth_handler = hs.get_auth_handler()
auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker
self.reactor.lookups["testserv"] = "1.2.3.4"
def make_worker_hs(self, extra_config={}):
config = self._get_worker_hs_config()
config.update(extra_config)
worker_hs = self.setup_test_homeserver(
homeserverToUse=GenericWorkerServer, config=config, reactor=self.reactor,
)
store = worker_hs.get_datastore()
store.db._db_pool = self.database._db_pool
# Register the expected servlets, essentially this is HomeserverTestCase.create_test_json_resource.
resource = JsonResource(self.hs)
for servlet in self.servlets:
servlet(worker_hs, resource)
# Essentially HomeserverTestCase.render.
def _render(request):
render(request, self.resource, self.reactor)
return worker_hs, _render
def _get_worker_hs_config(self) -> dict:
config = self.default_config()
config["worker_app"] = "synapse.app.client_reader"
@@ -82,14 +45,14 @@ class ClientReaderTestCase(unittest.HomeserverTestCase):
def test_register_single_worker(self):
"""Test that registration works when using a single client reader worker.
"""
_, worker_render = self.make_worker_hs()
worker_hs = self.make_worker_hs("synapse.app.client_reader")
request_1, channel_1 = self.make_request(
"POST",
"register",
{"username": "user", "type": "m.login.password", "password": "bar"},
) # type: SynapseRequest, FakeChannel
worker_render(request_1)
self.render_on_worker(worker_hs, request_1)
self.assertEqual(request_1.code, 401)
# Grab the session
@@ -99,7 +62,7 @@ class ClientReaderTestCase(unittest.HomeserverTestCase):
request_2, channel_2 = self.make_request(
"POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}}
) # type: SynapseRequest, FakeChannel
worker_render(request_2)
self.render_on_worker(worker_hs, request_2)
self.assertEqual(request_2.code, 200)
# We're given a registered user.
@@ -108,15 +71,15 @@ class ClientReaderTestCase(unittest.HomeserverTestCase):
def test_register_multi_worker(self):
"""Test that registration works when using multiple client reader workers.
"""
_, worker_render_1 = self.make_worker_hs()
_, worker_render_2 = self.make_worker_hs()
worker_hs_1 = self.make_worker_hs("synapse.app.client_reader")
worker_hs_2 = self.make_worker_hs("synapse.app.client_reader")
request_1, channel_1 = self.make_request(
"POST",
"register",
{"username": "user", "type": "m.login.password", "password": "bar"},
) # type: SynapseRequest, FakeChannel
worker_render_1(request_1)
self.render_on_worker(worker_hs_1, request_1)
self.assertEqual(request_1.code, 401)
# Grab the session
@@ -126,7 +89,7 @@ class ClientReaderTestCase(unittest.HomeserverTestCase):
request_2, channel_2 = self.make_request(
"POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}}
) # type: SynapseRequest, FakeChannel
worker_render_2(request_2)
self.render_on_worker(worker_hs_2, request_2)
self.assertEqual(request_2.code, 200)
# We're given a registered user.

View File

@@ -19,132 +19,40 @@ from mock import Mock
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.app.generic_worker import GenericWorkerServer
from synapse.events.builder import EventBuilderFactory
from synapse.replication.http import streams
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
from synapse.rest.admin import register_servlets_for_client_rest_resource
from synapse.rest.client.v1 import login, room
from synapse.types import UserID
from tests import unittest
from tests.server import FakeTransport
from tests.replication._base import BaseMultiWorkerStreamTestCase
logger = logging.getLogger(__name__)
class BaseStreamTestCase(unittest.HomeserverTestCase):
"""Base class for tests of the replication streams"""
servlets = [
streams.register_servlets,
]
def prepare(self, reactor, clock, hs):
# build a replication server
self.server_factory = ReplicationStreamProtocolFactory(hs)
self.streamer = hs.get_replication_streamer()
store = hs.get_datastore()
self.database = store.db
self.reactor.lookups["testserv"] = "1.2.3.4"
def default_config(self):
conf = super().default_config()
conf["send_federation"] = False
return conf
def make_worker_hs(self, extra_config={}):
config = self._get_worker_hs_config()
config.update(extra_config)
mock_federation_client = Mock(spec=["put_json"])
mock_federation_client.put_json.side_effect = lambda *_, **__: defer.succeed({})
worker_hs = self.setup_test_homeserver(
http_client=mock_federation_client,
homeserverToUse=GenericWorkerServer,
config=config,
reactor=self.reactor,
)
store = worker_hs.get_datastore()
store.db._db_pool = self.database._db_pool
repl_handler = ReplicationCommandHandler(worker_hs)
client = ClientReplicationStreamProtocol(
worker_hs, "client", "test", self.clock, repl_handler,
)
server = self.server_factory.buildProtocol(None)
client_transport = FakeTransport(server, self.reactor)
client.makeConnection(client_transport)
server_transport = FakeTransport(client, self.reactor)
server.makeConnection(server_transport)
return worker_hs
def _get_worker_hs_config(self) -> dict:
config = self.default_config()
config["worker_app"] = "synapse.app.federation_sender"
config["worker_replication_host"] = "testserv"
config["worker_replication_http_port"] = "8765"
return config
def replicate(self):
"""Tell the master side of replication that something has happened, and then
wait for the replication to occur.
"""
self.streamer.on_notifier_poke()
self.pump()
def create_room_with_remote_server(self, user, token, remote_server="other_server"):
room = self.helper.create_room_as(user, tok=token)
store = self.hs.get_datastore()
federation = self.hs.get_handlers().federation_handler
prev_event_ids = self.get_success(store.get_latest_event_ids_in_room(room))
room_version = self.get_success(store.get_room_version(room))
factory = EventBuilderFactory(self.hs)
factory.hostname = remote_server
user_id = UserID("user", remote_server).to_string()
event_dict = {
"type": EventTypes.Member,
"state_key": user_id,
"content": {"membership": Membership.JOIN},
"sender": user_id,
"room_id": room,
}
builder = factory.for_room_version(room_version, event_dict)
join_event = self.get_success(builder.build(prev_event_ids))
self.get_success(federation.on_send_join_request(remote_server, join_event))
self.replicate()
return room
class FederationSenderTestCase(BaseStreamTestCase):
class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
servlets = [
login.register_servlets,
register_servlets_for_client_rest_resource,
room.register_servlets,
]
def default_config(self):
conf = super().default_config()
conf["send_federation"] = False
return conf
def test_send_event_single_sender(self):
"""Test that using a single federation sender worker correctly sends a
new event.
"""
worker_hs = self.make_worker_hs({"send_federation": True})
mock_client = worker_hs.get_http_client()
mock_client = Mock(spec=["put_json"])
mock_client.put_json.side_effect = lambda *_, **__: defer.succeed({})
self.make_worker_hs(
"synapse.app.federation_sender",
{"send_federation": True},
http_client=mock_client,
)
user = self.register_user("user", "pass")
token = self.login("user", "pass")
@@ -165,23 +73,29 @@ class FederationSenderTestCase(BaseStreamTestCase):
"""Test that using two federation sender workers correctly sends
new events.
"""
worker1 = self.make_worker_hs(
mock_client1 = Mock(spec=["put_json"])
mock_client1.put_json.side_effect = lambda *_, **__: defer.succeed({})
self.make_worker_hs(
"synapse.app.federation_sender",
{
"send_federation": True,
"worker_name": "sender1",
"federation_sender_instances": ["sender1", "sender2"],
}
},
http_client=mock_client1,
)
mock_client1 = worker1.get_http_client()
worker2 = self.make_worker_hs(
mock_client2 = Mock(spec=["put_json"])
mock_client2.put_json.side_effect = lambda *_, **__: defer.succeed({})
self.make_worker_hs(
"synapse.app.federation_sender",
{
"send_federation": True,
"worker_name": "sender2",
"federation_sender_instances": ["sender1", "sender2"],
}
},
http_client=mock_client2,
)
mock_client2 = worker2.get_http_client()
user = self.register_user("user2", "pass")
token = self.login("user2", "pass")
@@ -191,8 +105,8 @@ class FederationSenderTestCase(BaseStreamTestCase):
for i in range(20):
server_name = "other_server_%d" % (i,)
room = self.create_room_with_remote_server(user, token, server_name)
mock_client1.reset_mock()
mock_client2.reset_mock()
mock_client1.reset_mock() # type: ignore[attr-defined]
mock_client2.reset_mock() # type: ignore[attr-defined]
self.create_and_send_event(room, UserID.from_string(user))
self.replicate()
@@ -222,23 +136,29 @@ class FederationSenderTestCase(BaseStreamTestCase):
"""Test that using two federation sender workers correctly sends
new typing EDUs.
"""
worker1 = self.make_worker_hs(
mock_client1 = Mock(spec=["put_json"])
mock_client1.put_json.side_effect = lambda *_, **__: defer.succeed({})
self.make_worker_hs(
"synapse.app.federation_sender",
{
"send_federation": True,
"worker_name": "sender1",
"federation_sender_instances": ["sender1", "sender2"],
}
},
http_client=mock_client1,
)
mock_client1 = worker1.get_http_client()
worker2 = self.make_worker_hs(
mock_client2 = Mock(spec=["put_json"])
mock_client2.put_json.side_effect = lambda *_, **__: defer.succeed({})
self.make_worker_hs(
"synapse.app.federation_sender",
{
"send_federation": True,
"worker_name": "sender2",
"federation_sender_instances": ["sender1", "sender2"],
}
},
http_client=mock_client2,
)
mock_client2 = worker2.get_http_client()
user = self.register_user("user3", "pass")
token = self.login("user3", "pass")
@@ -250,8 +170,8 @@ class FederationSenderTestCase(BaseStreamTestCase):
for i in range(20):
server_name = "other_server_%d" % (i,)
room = self.create_room_with_remote_server(user, token, server_name)
mock_client1.reset_mock()
mock_client2.reset_mock()
mock_client1.reset_mock() # type: ignore[attr-defined]
mock_client2.reset_mock() # type: ignore[attr-defined]
self.get_success(
typing_handler.started_typing(
@@ -284,3 +204,32 @@ class FederationSenderTestCase(BaseStreamTestCase):
self.assertTrue(sent_on_1)
self.assertTrue(sent_on_2)
def create_room_with_remote_server(self, user, token, remote_server="other_server"):
room = self.helper.create_room_as(user, tok=token)
store = self.hs.get_datastore()
federation = self.hs.get_handlers().federation_handler
prev_event_ids = self.get_success(store.get_latest_event_ids_in_room(room))
room_version = self.get_success(store.get_room_version(room))
factory = EventBuilderFactory(self.hs)
factory.hostname = remote_server
user_id = UserID("user", remote_server).to_string()
event_dict = {
"type": EventTypes.Member,
"state_key": user_id,
"content": {"membership": Membership.JOIN},
"sender": user_id,
"room_id": room,
}
builder = factory.for_room_version(room_version, event_dict)
join_event = self.get_success(builder.build(prev_event_ids))
self.get_success(federation.on_send_join_request(remote_server, join_event))
self.replicate()
return room

View File

@@ -857,6 +857,53 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(True, channel.json_body["deactivated"])
def test_reactivate_user(self):
"""
Test reactivating another user.
"""
# Deactivate the user.
request, channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
content=json.dumps({"deactivated": True}).encode(encoding="utf_8"),
)
self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# Attempt to reactivate the user (without a password).
request, channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
content=json.dumps({"deactivated": False}).encode(encoding="utf_8"),
)
self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
# Reactivate the user.
request, channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
content=json.dumps({"deactivated": False, "password": "foo"}).encode(
encoding="utf_8"
),
)
self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# Get user
request, channel = self.make_request(
"GET", self.url_other_user, access_token=self.admin_user_tok,
)
self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(False, channel.json_body["deactivated"])
def test_set_user_as_admin(self):
"""
Test setting the admin flag on a user.

View File

@@ -547,8 +547,8 @@ class JWTTestCase(unittest.HomeserverTestCase):
def test_login_jwt_invalid_signature(self):
channel = self.jwt_login({"sub": "frog"}, "notsecret")
self.assertEqual(channel.result["code"], b"401", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"],
"JWT validation failed: Signature verification failed",
@@ -556,8 +556,8 @@ class JWTTestCase(unittest.HomeserverTestCase):
def test_login_jwt_expired(self):
channel = self.jwt_login({"sub": "frog", "exp": 864000})
self.assertEqual(channel.result["code"], b"401", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"], "JWT validation failed: Signature has expired"
)
@@ -565,8 +565,8 @@ class JWTTestCase(unittest.HomeserverTestCase):
def test_login_jwt_not_before(self):
now = int(time.time())
channel = self.jwt_login({"sub": "frog", "nbf": now + 3600})
self.assertEqual(channel.result["code"], b"401", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"],
"JWT validation failed: The token is not yet valid (nbf)",
@@ -574,8 +574,8 @@ class JWTTestCase(unittest.HomeserverTestCase):
def test_login_no_sub(self):
channel = self.jwt_login({"username": "root"})
self.assertEqual(channel.result["code"], b"401", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(channel.json_body["error"], "Invalid JWT")
@override_config(
@@ -597,16 +597,16 @@ class JWTTestCase(unittest.HomeserverTestCase):
# An invalid issuer.
channel = self.jwt_login({"sub": "kermit", "iss": "invalid"})
self.assertEqual(channel.result["code"], b"401", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"], "JWT validation failed: Invalid issuer"
)
# Not providing an issuer.
channel = self.jwt_login({"sub": "kermit"})
self.assertEqual(channel.result["code"], b"401", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"],
'JWT validation failed: Token is missing the "iss" claim',
@@ -637,16 +637,16 @@ class JWTTestCase(unittest.HomeserverTestCase):
# An invalid audience.
channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
self.assertEqual(channel.result["code"], b"401", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"], "JWT validation failed: Invalid audience"
)
# Not providing an audience.
channel = self.jwt_login({"sub": "kermit"})
self.assertEqual(channel.result["code"], b"401", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"],
'JWT validation failed: Token is missing the "aud" claim',
@@ -655,7 +655,8 @@ class JWTTestCase(unittest.HomeserverTestCase):
def test_login_aud_no_config(self):
"""Test providing an audience without requiring it in the configuration."""
channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"], "JWT validation failed: Invalid audience"
)
@@ -664,8 +665,8 @@ class JWTTestCase(unittest.HomeserverTestCase):
params = json.dumps({"type": "org.matrix.login.jwt"})
request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
self.assertEqual(channel.result["code"], b"401", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(channel.json_body["error"], "Token field for JWT is missing")
@@ -747,8 +748,8 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
def test_login_jwt_invalid_signature(self):
channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey)
self.assertEqual(channel.result["code"], b"401", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_UNAUTHORIZED")
self.assertEqual(channel.result["code"], b"403", channel.result)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
self.assertEqual(
channel.json_body["error"],
"JWT validation failed: Signature verification failed",

View File

@@ -237,6 +237,7 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
def __init__(self):
self.threadpool = ThreadPool(self)
self._tcp_callbacks = {}
self._udp = []
lookups = self.lookups = {}
@@ -268,6 +269,29 @@ class ThreadedMemoryReactorClock(MemoryReactorClock):
def getThreadPool(self):
return self.threadpool
def add_tcp_client_callback(self, host, port, callback):
"""Add a callback that will be invoked when we receive a connection
attempt to the given IP/port using `connectTCP`.
Note that the callback gets run before we return the connection to the
client, which means callbacks cannot block while waiting for writes.
"""
self._tcp_callbacks[(host, port)] = callback
def connectTCP(self, host, port, factory, timeout=30, bindAddress=None):
"""Fake L{IReactorTCP.connectTCP}.
"""
conn = super().connectTCP(
host, port, factory, timeout=timeout, bindAddress=None
)
callback = self._tcp_callbacks.get((host, port))
if callback:
callback()
return conn
class ThreadPool:
"""
@@ -486,7 +510,7 @@ class FakeTransport(object):
try:
self.other.dataReceived(to_write)
except Exception as e:
logger.warning("Exception writing to protocol: %s", e)
logger.exception("Exception writing to protocol: %s", e)
return
self.buffer = self.buffer[len(to_write) :]