1
0

Compare commits

..

21 Commits

Author SHA1 Message Date
Andrew Morgan
fd835d20a0 Pass a list of arg tuples to urlencode instead of request.args 2020-08-28 19:00:45 +01:00
Andrew Morgan
4d8bf7d021 Convert confirmation from request args to HTML form data 2020-08-28 18:42:44 +01:00
Andrew Morgan
1b4458ed26 Return 400 when accessing submit_token/_confirm with REMOTE behaviour 2020-08-28 15:17:30 +01:00
Andrew Morgan
990178f58b Pull things from, instead of copying the entirety of, the config 2020-08-28 14:59:05 +01:00
Andrew Morgan
d6addba84e Update synapse/rest/client/v2_alpha/account.py
Co-authored-by: Richard van der Hoff <1389908+richvdh@users.noreply.github.com>
2020-08-28 14:15:25 +01:00
Andrew Morgan
3f4e350cd9 Merge branch 'develop' of github.com:matrix-org/synapse into anoa/password_reset_confirmation
* 'develop' of github.com:matrix-org/synapse: (22 commits)
  Update the test federation client to handle streaming responses (#8130)
  Do not propagate profile changes of shadow-banned users into rooms. (#8157)
  Make SlavedIdTracker.advance have same interface as MultiWriterIDGenerator (#8171)
  Convert simple_select_one and simple_select_one_onecol to async (#8162)
  Fix rate limiting unit tests. (#8167)
  Add functions to `MultiWriterIdGen` used by events stream (#8164)
  Do not allow send_nonmember_event to be called with shadow-banned users. (#8158)
  Changelog fixes
  1.19.1rc1
  Make StreamIdGen `get_next` and `get_next_mult` async  (#8161)
  Wording fixes to 'name' user admin api filter (#8163)
  Fix missing double-backtick in RST document
  Search in columns 'name' and 'displayname' in the admin users endpoint (#7377)
  Add type hints for state. (#8140)
  Stop shadow-banned users from sending non-member events. (#8142)
  Allow capping a room's retention policy (#8104)
  Add healthcheck for default localhost 8008 port on /health endpoint. (#8147)
  Fix flaky shadow-ban tests. (#8152)
  Fix join ratelimiter breaking profile updates and idempotency (#8153)
  Do not apply ratelimiting on joins to appservices (#8139)
  ...
2020-08-26 14:28:12 +01:00
Richard van der Hoff
88b9807ba4 Update the test federation client to handle streaming responses (#8130)
Now that the server supports streaming back JSON responses, it would be nice to
show the response as it is streamed, in the test tool.
2020-08-26 14:11:38 +01:00
Patrick Cloke
2e6c90ff84 Do not propagate profile changes of shadow-banned users into rooms. (#8157) 2020-08-26 08:49:01 -04:00
Erik Johnston
e3c91a3c55 Make SlavedIdTracker.advance have same interface as MultiWriterIDGenerator (#8171) 2020-08-26 13:15:20 +01:00
Patrick Cloke
4c6c56dc58 Convert simple_select_one and simple_select_one_onecol to async (#8162) 2020-08-26 07:19:32 -04:00
Patrick Cloke
56efa9ec71 Fix rate limiting unit tests. (#8167)
These were passing on the release-v1.19.1 branch but started failing once merged
to develop.
2020-08-26 07:19:20 -04:00
Andrew Morgan
1d79f7b22b Remove another unused class var 2020-08-21 17:24:44 +01:00
Andrew Morgan
0a2b11f361 Ensure we don't fail a request due to the config 2020-08-21 17:21:02 +01:00
Andrew Morgan
b41b0512f9 typing 2020-08-21 17:18:42 +01:00
Andrew Morgan
8cc8ee4448 Remove unused fields 2020-08-21 17:14:36 +01:00
Andrew Morgan
3568e1897c Add changelog 2020-08-21 11:41:56 +01:00
Andrew Morgan
622946e881 Remind people about the new template in the upgrade notes 2020-08-21 11:41:56 +01:00
Andrew Morgan
6140b32397 Add confirmation as a step to the password reset unit tests 2020-08-21 11:41:56 +01:00
Andrew Morgan
d9a19fc696 Create new password reset confirmation endpoint
Creates a new implementation-specific endpoint for accepting
confirmation of password resets. This endpoint should be hit with the
same parameters as the regular password reset submit_token endpoint.
This endpoint will now complete the password resets, whereas the
existing endpoint will now just show the confirmation template page.
2020-08-21 11:41:52 +01:00
Andrew Morgan
5f7a834a50 Load the new template 2020-08-21 11:19:44 +01:00
Andrew Morgan
9003eb4bcd Add confirmation page template 2020-08-21 11:19:44 +01:00
59 changed files with 938 additions and 657 deletions

View File

@@ -75,6 +75,30 @@ for example:
wget https://packages.matrix.org/debian/pool/main/m/matrix-synapse-py3/matrix-synapse-py3_1.3.0+stretch1_amd64.deb
dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb
Upgrading to v1.20.0
====================
New HTML templates
------------------
A new HTML template,
`password_reset_confirmation.html <https://github.com/matrix-org/synapse/blob/develop/synapse/res/templates/password_reset_confirmation.html>`_,
has been added to the ``synapse/res/templates`` directory. If you are using a
custom template directory, you may want to copy the template over and modify it.
Note that as of v1.20.0, templates do not need to be included in custom template
directories for Synapse to start. The default templates will be used if a custom
template cannot be found.
This page will appear to the user after clicking a password reset link that has
been emailed to them.
To complete password reset, the page must include a way to make a `POST`
request to
``/_matrix/client/unstable/password_reset/{medium}/submit_token_confirm``
with the query parameters from the original link. See the file itself for more
details.
Upgrading to v1.18.0
====================

View File

@@ -1 +0,0 @@
Support `identifier` dictionary fields in User-Interactive Authentication flows. Relax requirement of the `user` parameter.

1
changelog.d/8004.feature Normal file
View File

@@ -0,0 +1 @@
Require the user to confirm that their password should be reset after clicking the email confirmation link.

1
changelog.d/8130.misc Normal file
View File

@@ -0,0 +1 @@
Update the test federation client to handle streaming responses.

1
changelog.d/8157.feature Normal file
View File

@@ -0,0 +1 @@
Add support for shadow-banning users (ignoring any message send requests).

1
changelog.d/8162.misc Normal file
View File

@@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

1
changelog.d/8167.misc Normal file
View File

@@ -0,0 +1 @@
Fix tests that were broken due to the merge of 1.19.1.

1
changelog.d/8171.misc Normal file
View File

@@ -0,0 +1 @@
Make `SlavedIdTracker.advance` have the same interface as `MultiWriterIDGenerator`.

View File

@@ -2021,9 +2021,13 @@ email:
# * The contents of password reset emails sent by the homeserver:
# 'password_reset.html' and 'password_reset.txt'
#
# * HTML pages for success and failure that a user will see when they follow
# the link in the password reset email: 'password_reset_success.html' and
# 'password_reset_failure.html'
# * An HTML page that a user will see when they follow the link in the password
# reset email. The user will be asked to confirm the action before their
# password is reset: 'password_reset_confirmation.html'
#
# * HTML pages for success and failure that a user will see when they confirm
# the password reset flow using the page above: 'password_reset_success.html'
# and 'password_reset_failure.html'
#
# * The contents of address verification emails sent during registration:
# 'registration.html' and 'registration.txt'

View File

@@ -21,10 +21,12 @@ import argparse
import base64
import json
import sys
from typing import Any, Optional
from urllib import parse as urlparse
import nacl.signing
import requests
import signedjson.types
import srvlookup
import yaml
from requests.adapters import HTTPAdapter
@@ -69,7 +71,9 @@ def encode_canonical_json(value):
).encode("UTF-8")
def sign_json(json_object, signing_key, signing_name):
def sign_json(
json_object: Any, signing_key: signedjson.types.SigningKey, signing_name: str
) -> Any:
signatures = json_object.pop("signatures", {})
unsigned = json_object.pop("unsigned", None)
@@ -122,7 +126,14 @@ def read_signing_keys(stream):
return keys
def request_json(method, origin_name, origin_key, destination, path, content):
def request(
method: Optional[str],
origin_name: str,
origin_key: signedjson.types.SigningKey,
destination: str,
path: str,
content: Optional[str],
) -> requests.Response:
if method is None:
if content is None:
method = "GET"
@@ -159,11 +170,14 @@ def request_json(method, origin_name, origin_key, destination, path, content):
if method == "POST":
headers["Content-Type"] = "application/json"
result = s.request(
method=method, url=dest, headers=headers, verify=False, data=content
return s.request(
method=method,
url=dest,
headers=headers,
verify=False,
data=content,
stream=True,
)
sys.stderr.write("Status Code: %d\n" % (result.status_code,))
return result.json()
def main():
@@ -222,7 +236,7 @@ def main():
with open(args.signing_key_path) as f:
key = read_signing_keys(f)[0]
result = request_json(
result = request(
args.method,
args.server_name,
key,
@@ -231,7 +245,12 @@ def main():
content=args.body,
)
json.dump(result, sys.stdout)
sys.stderr.write("Status Code: %d\n" % (result.status_code,))
for chunk in result.iter_content():
# we write raw utf8 to stdout.
sys.stdout.buffer.write(chunk)
print("")

View File

@@ -198,6 +198,9 @@ class EmailConfig(Config):
"add_threepid_template_text", "add_threepid.txt"
)
password_reset_template_confirmation_html = (
"password_reset_confirmation.html"
)
password_reset_template_failure_html = email_config.get(
"password_reset_template_failure_html", "password_reset_failure.html"
)
@@ -228,6 +231,7 @@ class EmailConfig(Config):
self.email_registration_template_text,
self.email_add_threepid_template_html,
self.email_add_threepid_template_text,
self.email_password_reset_template_confirmation_html,
self.email_password_reset_template_failure_html,
self.email_registration_template_failure_html,
self.email_add_threepid_template_failure_html,
@@ -242,6 +246,7 @@ class EmailConfig(Config):
registration_template_text,
add_threepid_template_html,
add_threepid_template_text,
password_reset_template_confirmation_html,
password_reset_template_failure_html,
registration_template_failure_html,
add_threepid_template_failure_html,
@@ -404,9 +409,13 @@ class EmailConfig(Config):
# * The contents of password reset emails sent by the homeserver:
# 'password_reset.html' and 'password_reset.txt'
#
# * HTML pages for success and failure that a user will see when they follow
# the link in the password reset email: 'password_reset_success.html' and
# 'password_reset_failure.html'
# * An HTML page that a user will see when they follow the link in the password
# reset email. The user will be asked to confirm the action before their
# password is reset: 'password_reset_confirmation.html'
#
# * HTML pages for success and failure that a user will see when they confirm
# the password reset flow using the page above: 'password_reset_success.html'
# and 'password_reset_failure.html'
#
# * The contents of address verification emails sent during registration:
# 'registration.html' and 'registration.txt'

View File

@@ -38,14 +38,12 @@ from synapse.api.ratelimiting import Ratelimiter
from synapse.handlers.ui_auth import INTERACTIVE_AUTH_CHECKERS
from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
from synapse.http.server import finish_request, respond_with_html
from synapse.http.servlet import assert_params_in_dict
from synapse.http.site import SynapseRequest
from synapse.logging.context import defer_to_thread
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.module_api import ModuleApi
from synapse.types import Requester, UserID
from synapse.util import stringutils as stringutils
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.threepids import canonicalise_email
from ._base import BaseHandler
@@ -53,82 +51,6 @@ from ._base import BaseHandler
logger = logging.getLogger(__name__)
def client_dict_convert_legacy_fields_to_identifier(
submission: Dict[str, Union[str, Dict]]
):
"""
Convert a legacy-formatted login submission to an identifier dict.
Legacy login submissions (used in both login and user-interactive authentication)
provide user-identifying information at the top-level instead of in an `indentifier`
property. This is now deprecated and replaced with identifiers:
https://matrix.org/docs/spec/client_server/r0.6.1#identifier-types
Args:
submission: The client dict to convert. Passed by reference and modified
Raises:
SynapseError: If the format of the client dict is invalid
"""
if "user" in submission:
submission["identifier"] = {"type": "m.id.user", "user": submission.pop("user")}
if "medium" in submission and "address" in submission:
submission["identifier"] = {
"type": "m.id.thirdparty",
"medium": submission.pop("medium"),
"address": submission.pop("address"),
}
# We've converted valid, legacy login submissions to an identifier. If the
# dict still doesn't have an identifier, it's invalid
assert_params_in_dict(submission, required=["identifier"])
# Ensure the identifier has a type
if "type" not in submission["identifier"]:
raise SynapseError(
400, "'identifier' dict has no key 'type'", errcode=Codes.MISSING_PARAM,
)
def login_id_phone_to_thirdparty(identifier: Dict[str, str]) -> Dict[str, str]:
"""Convert a phone login identifier type to a generic threepid identifier.
Args:
identifier: Login identifier dict of type 'm.id.phone'
Returns:
An equivalent m.id.thirdparty identifier dict.
"""
if "type" not in identifier:
raise SynapseError(
400, "Invalid phone-type identifier", errcode=Codes.MISSING_PARAM
)
if "country" not in identifier or (
# XXX: We used to require `number` instead of `phone`. The spec
# defines `phone`. So accept both
"phone" not in identifier
and "number" not in identifier
):
raise SynapseError(
400, "Invalid phone-type identifier", errcode=Codes.INVALID_PARAM
)
# Accept both "phone" and "number" as valid keys in m.id.phone
phone_number = identifier.get("phone", identifier.get("number"))
# Convert user-provided phone number to a consistent representation
msisdn = phone_number_to_msisdn(identifier["country"], phone_number)
# Return the new dictionary
return {
"type": "m.id.thirdparty",
"medium": "msisdn",
"address": msisdn,
}
class AuthHandler(BaseHandler):
SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
@@ -397,7 +319,7 @@ class AuthHandler(BaseHandler):
# otherwise use whatever was last provided.
#
# This was designed to allow the client to omit the parameters
# and just supply the session in subsequent calls. So it splits
# and just supply the session in subsequent calls so it split
# auth between devices by just sharing the session, (eg. so you
# could continue registration from your phone having clicked the
# email auth link on there). It's probably too open to abuse
@@ -602,129 +524,16 @@ class AuthHandler(BaseHandler):
res = await checker.check_auth(authdict, clientip=clientip)
return res
# We don't have a checker for the auth type provided by the client
# Assume that it is `m.login.password`.
if login_type != LoginType.PASSWORD:
raise SynapseError(
400, "Unknown authentication type", errcode=Codes.INVALID_PARAM,
)
# build a v1-login-style dict out of the authdict and fall back to the
# v1 code
user_id = authdict.get("user")
password = authdict.get("password")
if password is None:
raise SynapseError(
400,
"Missing parameter for m.login.password dict: 'password'",
errcode=Codes.INVALID_PARAM,
)
# Retrieve the user ID using details provided in the authdict
# Deprecation notice: Clients used to be able to simply provide a
# `user` field which pointed to a user_id or localpart. This has
# been deprecated in favour of an `identifier` key, which is a
# dictionary providing information on how to identify a single
# user.
# https://matrix.org/docs/spec/client_server/r0.6.1#identifier-types
#
# We convert old-style dicts to new ones here
client_dict_convert_legacy_fields_to_identifier(authdict)
# Extract a user ID from the values in the identifier
username = await self.username_from_identifier(authdict["identifier"], password)
if username is None:
raise SynapseError(400, "Valid username not found")
# Now that we've found the username, validate that the password is correct
canonical_id, _ = await self.validate_login(username, authdict)
if user_id is None:
raise SynapseError(400, "", Codes.MISSING_PARAM)
(canonical_id, callback) = await self.validate_login(user_id, authdict)
return canonical_id
async def username_from_identifier(
self, identifier: Dict[str, str], password: Optional[str] = None
) -> Optional[str]:
"""Given a dictionary containing an identifier from a client, extract the
possibly unqualified username of the user that it identifies. Does *not*
guarantee that the user exists.
If this identifier dict contains a threepid, we attempt to ask password
auth providers about it or, failing that, look up an associated user in
the database.
Args:
identifier: The identifier dictionary provided by the client
password: The user provided password if one exists. Used for asking
password auth providers for usernames from 3pid+password combos.
Returns:
A username if one was found, or None otherwise
Raises:
SynapseError: If the identifier dict is invalid
"""
# Convert phone type identifiers to generic threepid identifiers, which
# will be handled in the next step
if identifier["type"] == "m.id.phone":
identifier = login_id_phone_to_thirdparty(identifier)
# Convert a threepid identifier to an user identifier
if identifier["type"] == "m.id.thirdparty":
address = identifier.get("address")
medium = identifier.get("medium")
if not medium or not address:
# An error would've already been raised in
# `login_id_thirdparty_from_phone` if the original submission
# was a phone identifier
raise SynapseError(
400, "Invalid thirdparty identifier", errcode=Codes.INVALID_PARAM,
)
if medium == "email":
# For emails, transform the address to lowercase.
# We store all email addresses as lowercase in the DB.
# (See add_threepid in synapse/handlers/auth.py)
address = address.lower()
# Check for auth providers that support 3pid login types
if password is not None:
canonical_user_id, _ = await self.check_password_provider_3pid(
medium, address, password,
)
if canonical_user_id:
# Authentication through password provider and 3pid succeeded
return canonical_user_id
# Check local store
user_id = await self.hs.get_datastore().get_user_id_by_threepid(
medium, address
)
if not user_id:
# We were unable to find a user_id that belonged to the threepid returned
# by the password auth provider
return None
identifier = {"type": "m.id.user", "user": user_id}
# By this point, the identifier should be a `m.id.user`: if it's anything
# else, we haven't understood it.
if identifier["type"] != "m.id.user":
raise SynapseError(
400, "Unknown login identifier type", errcode=Codes.INVALID_PARAM,
)
# User identifiers have a "user" key
user = identifier.get("user")
if user is None:
raise SynapseError(
400,
"User identifier is missing 'user' key",
errcode=Codes.INVALID_PARAM,
)
return user
def _get_params_recaptcha(self) -> dict:
return {"public_key": self.hs.config.recaptcha_public_key}
@@ -889,8 +698,7 @@ class AuthHandler(BaseHandler):
m.login.password auth types.
Args:
username: a localpart or fully qualified user ID - what is provided by the
client
username: username supplied by the user
login_submission: the whole of the login submission
(including 'type' and other relevant fields)
Returns:
@@ -902,10 +710,10 @@ class AuthHandler(BaseHandler):
LoginError if there was an authentication problem.
"""
# We need a fully qualified User ID for some method calls here
qualified_user_id = username
if not qualified_user_id.startswith("@"):
qualified_user_id = UserID(qualified_user_id, self.hs.hostname).to_string()
if username.startswith("@"):
qualified_user_id = username
else:
qualified_user_id = UserID(username, self.hs.hostname).to_string()
login_type = login_submission.get("type")
known_login_type = False

View File

@@ -14,6 +14,7 @@
# limitations under the License.
import logging
import random
from synapse.api.errors import (
AuthError,
@@ -213,8 +214,14 @@ class BaseProfileHandler(BaseHandler):
async def set_avatar_url(
self, target_user, requester, new_avatar_url, by_admin=False
):
"""target_user is the user whose avatar_url is to be changed;
auth_user is the user attempting to make this change."""
"""Set a new avatar URL for a user.
Args:
target_user (UserID): the user whose avatar URL is to be changed.
requester (Requester): The user attempting to make this change.
new_avatar_url (str): The avatar URL to give this user.
by_admin (bool): Whether this change was made by an administrator.
"""
if not self.hs.is_mine(target_user):
raise SynapseError(400, "User is not hosted on this homeserver")
@@ -278,6 +285,12 @@ class BaseProfileHandler(BaseHandler):
await self.ratelimit(requester)
# Do not actually update the room state for shadow-banned users.
if requester.shadow_banned:
# We randomly sleep a bit just to annoy the requester.
await self.clock.sleep(random.randint(1, 10))
return
room_ids = await self.store.get_rooms_for_user(target_user.to_string())
for room_id in room_ids:

View File

@@ -380,7 +380,7 @@ class RoomMemberHandler(object):
# later on.
content = dict(content)
if not self.allow_per_room_profiles:
if not self.allow_per_room_profiles or requester.shadow_banned:
# Strip profile data, knowing that new profile data will be added to the
# event's content in event_creation_handler.create_event() using the target's
# global profile.

View File

@@ -21,9 +21,9 @@ class SlavedIdTracker(object):
self.step = step
self._current = _load_current_id(db_conn, table, column, step)
for table, column in extra_tables:
self.advance(_load_current_id(db_conn, table, column))
self.advance(None, _load_current_id(db_conn, table, column))
def advance(self, new_id):
def advance(self, instance_name, new_id):
self._current = (max if self.step > 0 else min)(self._current, new_id)
def get_current_token(self):

View File

@@ -41,12 +41,12 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == TagAccountDataStream.NAME:
self._account_data_id_gen.advance(token)
self._account_data_id_gen.advance(instance_name, token)
for row in rows:
self.get_tags_for_user.invalidate((row.user_id,))
self._account_data_stream_cache.entity_has_changed(row.user_id, token)
elif stream_name == AccountDataStream.NAME:
self._account_data_id_gen.advance(token)
self._account_data_id_gen.advance(instance_name, token)
for row in rows:
if not row.room_id:
self.get_global_account_data_by_type_for_user.invalidate(

View File

@@ -46,7 +46,7 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == ToDeviceStream.NAME:
self._device_inbox_id_gen.advance(token)
self._device_inbox_id_gen.advance(instance_name, token)
for row in rows:
if row.entity.startswith("@"):
self._device_inbox_stream_cache.entity_has_changed(

View File

@@ -50,10 +50,10 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == DeviceListsStream.NAME:
self._device_list_id_gen.advance(token)
self._device_list_id_gen.advance(instance_name, token)
self._invalidate_caches_for_devices(token, rows)
elif stream_name == UserSignatureStream.NAME:
self._device_list_id_gen.advance(token)
self._device_list_id_gen.advance(instance_name, token)
for row in rows:
self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
return super().process_replication_rows(stream_name, instance_name, token, rows)

View File

@@ -40,7 +40,7 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == GroupServerStream.NAME:
self._group_updates_id_gen.advance(token)
self._group_updates_id_gen.advance(instance_name, token)
for row in rows:
self._group_updates_stream_cache.entity_has_changed(row.user_id, token)

View File

@@ -44,7 +44,7 @@ class SlavedPresenceStore(BaseSlavedStore):
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == PresenceStream.NAME:
self._presence_id_gen.advance(token)
self._presence_id_gen.advance(instance_name, token)
for row in rows:
self.presence_stream_cache.entity_has_changed(row.user_id, token)
self._get_presence_for_user.invalidate((row.user_id,))

View File

@@ -30,7 +30,7 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
assert isinstance(self._push_rules_stream_id_gen, SlavedIdTracker)
if stream_name == PushRulesStream.NAME:
self._push_rules_stream_id_gen.advance(token)
self._push_rules_stream_id_gen.advance(instance_name, token)
for row in rows:
self.get_push_rules_for_user.invalidate((row.user_id,))
self.get_push_rules_enabled_for_user.invalidate((row.user_id,))

View File

@@ -34,5 +34,5 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == PushersStream.NAME:
self._pushers_id_gen.advance(token)
self._pushers_id_gen.advance(instance_name, token)
return super().process_replication_rows(stream_name, instance_name, token, rows)

View File

@@ -46,7 +46,7 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == ReceiptsStream.NAME:
self._receipts_id_gen.advance(token)
self._receipts_id_gen.advance(instance_name, token)
for row in rows:
self.invalidate_caches_for_receipt(
row.room_id, row.receipt_type, row.user_id

View File

@@ -33,6 +33,6 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore):
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == PublicRoomsStream.NAME:
self._public_room_id_gen.advance(token)
self._public_room_id_gen.advance(instance_name, token)
return super().process_replication_rows(stream_name, instance_name, token, rows)

View File

@@ -0,0 +1,16 @@
<html>
<head></head>
<body>
<!--Use a hidden form to resubmit the information necessary to reset the password-->
<form action="/_matrix/client/unstable/password_reset/{{ medium }}/submit_token_confirm" method="post">
<input type="hidden" name="sid" value="{{ sid }}">
<input type="hidden" name="token" value="{{ token }}">
<input type="hidden" name="client_secret" value="{{ client_secret }}">
<p>You have requested to <strong>reset your Matrix account password</strong>. Click the link below to confirm this action. <br /><br />
If you did not mean to do this, please close this page and your password will not be changed.</p>
<p><button type="submit">Confirm changing my password</button></p>
</form>
</body>
</html>

View File

@@ -18,7 +18,6 @@ from typing import Awaitable, Callable, Dict, Optional
from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.api.ratelimiting import Ratelimiter
from synapse.handlers.auth import client_dict_convert_legacy_fields_to_identifier
from synapse.http.server import finish_request
from synapse.http.servlet import (
RestServlet,
@@ -29,11 +28,56 @@ from synapse.http.site import SynapseRequest
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.rest.well_known import WellKnownBuilder
from synapse.types import JsonDict, UserID
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.threepids import canonicalise_email
logger = logging.getLogger(__name__)
def login_submission_legacy_convert(submission):
"""
If the input login submission is an old style object
(ie. with top-level user / medium / address) convert it
to a typed object.
"""
if "user" in submission:
submission["identifier"] = {"type": "m.id.user", "user": submission["user"]}
del submission["user"]
if "medium" in submission and "address" in submission:
submission["identifier"] = {
"type": "m.id.thirdparty",
"medium": submission["medium"],
"address": submission["address"],
}
del submission["medium"]
del submission["address"]
def login_id_thirdparty_from_phone(identifier):
"""
Convert a phone login identifier type to a generic threepid identifier
Args:
identifier(dict): Login identifier dict of type 'm.id.phone'
Returns: Login identifier dict of type 'm.id.threepid'
"""
if "country" not in identifier or (
# The specification requires a "phone" field, while Synapse used to require a "number"
# field. Accept both for backwards compatibility.
"phone" not in identifier
and "number" not in identifier
):
raise SynapseError(400, "Invalid phone-type identifier")
# Accept both "phone" and "number" as valid keys in m.id.phone
phone_number = identifier.get("phone", identifier["number"])
msisdn = phone_number_to_msisdn(identifier["country"], phone_number)
return {"type": "m.id.thirdparty", "medium": "msisdn", "address": msisdn}
class LoginRestServlet(RestServlet):
PATTERNS = client_patterns("/login$", v1=True)
CAS_TYPE = "m.login.cas"
@@ -123,8 +167,7 @@ class LoginRestServlet(RestServlet):
result = await self._do_token_login(login_submission)
else:
result = await self._do_other_login(login_submission)
except KeyError as e:
logger.debug("KeyError during login: %s", e)
except KeyError:
raise SynapseError(400, "Missing JSON keys.")
well_known_data = self._well_known_builder.get_well_known()
@@ -151,14 +194,27 @@ class LoginRestServlet(RestServlet):
login_submission.get("address"),
login_submission.get("user"),
)
# Convert deprecated authdict formats to the current scheme
client_dict_convert_legacy_fields_to_identifier(login_submission)
login_submission_legacy_convert(login_submission)
if "identifier" not in login_submission:
raise SynapseError(400, "Missing param: identifier")
identifier = login_submission["identifier"]
if "type" not in identifier:
raise SynapseError(400, "Login identifier has no type")
# convert phone type identifiers to generic threepids
if identifier["type"] == "m.id.phone":
identifier = login_id_thirdparty_from_phone(identifier)
# convert threepid identifiers to user IDs
if identifier["type"] == "m.id.thirdparty":
address = identifier.get("address")
medium = identifier.get("medium")
if medium is None or address is None:
raise SynapseError(400, "Invalid thirdparty identifier")
# Check whether this attempt uses a threepid, if so, check if our failed attempt
# ratelimiter allows another attempt at this time
medium = login_submission.get("medium")
address = login_submission.get("address")
if medium and address:
# For emails, canonicalise the address.
# We store all email addresses canonicalised in the DB.
# (See add_threepid in synapse/handlers/auth.py)
@@ -168,41 +224,74 @@ class LoginRestServlet(RestServlet):
except ValueError as e:
raise SynapseError(400, str(e))
# We also apply account rate limiting using the 3PID as a key, as
# otherwise using 3PID bypasses the ratelimiting based on user ID.
self._failed_attempts_ratelimiter.ratelimit((medium, address), update=False)
# Extract a localpart or user ID from the values in the identifier
username = await self.auth_handler.username_from_identifier(
login_submission["identifier"], login_submission.get("password")
)
# Check for login providers that support 3pid login types
(
canonical_user_id,
callback_3pid,
) = await self.auth_handler.check_password_provider_3pid(
medium, address, login_submission["password"]
)
if canonical_user_id:
# Authentication through password provider and 3pid succeeded
if not username:
if medium and address:
# The user attempted to login via threepid and failed
# Record this failed attempt using the threepid as a key, as otherwise
# the user could bypass the ratelimiter by not providing a username
self._failed_attempts_ratelimiter.can_do_action(
(medium, address.lower())
result = await self._complete_login(
canonical_user_id, login_submission, callback_3pid
)
return result
raise LoginError(403, "Unauthorized threepid", errcode=Codes.FORBIDDEN)
# No password providers were able to handle this 3pid
# Check local store
user_id = await self.hs.get_datastore().get_user_id_by_threepid(
medium, address
)
if not user_id:
logger.warning(
"unknown 3pid identifier medium %s, address %r", medium, address
)
# We mark that we've failed to log in here, as
# `check_password_provider_3pid` might have returned `None` due
# to an incorrect password, rather than the account not
# existing.
#
# If it returned None but the 3PID was bound then we won't hit
# this code path, which is fine as then the per-user ratelimit
# will kick in below.
self._failed_attempts_ratelimiter.can_do_action((medium, address))
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
# The login failed for another reason
raise LoginError(403, "Invalid login", errcode=Codes.FORBIDDEN)
identifier = {"type": "m.id.user", "user": user_id}
# We were able to extract a username successfully
# Check if we've hit the failed ratelimit for this user ID
self._failed_attempts_ratelimiter.ratelimit(username.lower(), update=False)
# by this point, the identifier should be an m.id.user: if it's anything
# else, we haven't understood it.
if identifier["type"] != "m.id.user":
raise SynapseError(400, "Unknown login identifier type")
if "user" not in identifier:
raise SynapseError(400, "User identifier is missing 'user' key")
if identifier["user"].startswith("@"):
qualified_user_id = identifier["user"]
else:
qualified_user_id = UserID(identifier["user"], self.hs.hostname).to_string()
# Check if we've hit the failed ratelimit (but don't update it)
self._failed_attempts_ratelimiter.ratelimit(
qualified_user_id.lower(), update=False
)
try:
canonical_user_id, callback = await self.auth_handler.validate_login(
username, login_submission
identifier["user"], login_submission
)
except LoginError:
# The user has failed to log in, so we need to update the rate
# limiter. Using `can_do_action` avoids us raising a ratelimit
# exception and masking the LoginError. This just records the attempt.
# The actual rate-limiting happens above
self._failed_attempts_ratelimiter.can_do_action(username.lower())
# exception and masking the LoginError. The actual ratelimiting
# should have happened above.
self._failed_attempts_ratelimiter.can_do_action(qualified_user_id.lower())
raise
result = await self._complete_login(
@@ -220,7 +309,7 @@ class LoginRestServlet(RestServlet):
create_non_existent_users: bool = False,
) -> Dict[str, str]:
"""Called when we've successfully authed the user and now need to
actually log them in (e.g. create devices). This gets called on
actually login them in (e.g. create devices). This gets called on
all successful logins.
Applies the ratelimiting for successful login attempts against an

View File

@@ -17,7 +17,9 @@
import logging
import random
from http import HTTPStatus
from typing import TYPE_CHECKING
from twisted.web.server import Request
from synapse.api.constants import LoginType
from synapse.api.errors import (
Codes,
@@ -38,6 +40,9 @@ from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.stringutils import assert_valid_client_secret, random_string
from synapse.util.threepids import canonicalise_email, check_3pid_allowed
if TYPE_CHECKING:
from synapse.server import HomeServer
from ._base import client_patterns, interactive_auth_handler
logger = logging.getLogger(__name__)
@@ -157,14 +162,14 @@ class PasswordResetSubmitTokenServlet(RestServlet):
hs (synapse.server.HomeServer): server
"""
super(PasswordResetSubmitTokenServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.config = hs.config
self.clock = hs.get_clock()
self.store = hs.get_datastore()
if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
self._failure_email_template = (
self.config.email_password_reset_template_failure_html
self._threepid_behaviour_email = hs.config.threepid_behaviour_email
self._local_threepid_handling_disabled_due_to_email_config = (
hs.config.local_threepid_handling_disabled_due_to_email_config
)
if self._threepid_behaviour_email == ThreepidBehaviour.LOCAL:
self._confirmation_email_template = (
hs.config.email_password_reset_template_confirmation_html
)
async def on_GET(self, request, medium):
@@ -173,20 +178,91 @@ class PasswordResetSubmitTokenServlet(RestServlet):
raise SynapseError(
400, "This medium is currently not supported for password resets"
)
if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
if self.config.local_threepid_handling_disabled_due_to_email_config:
if self._threepid_behaviour_email == ThreepidBehaviour.OFF:
if self._local_threepid_handling_disabled_due_to_email_config:
logger.warning(
"Password reset emails have been disabled due to lack of an email config"
)
raise SynapseError(
400, "Email-based password resets are disabled on this server"
)
elif self._threepid_behaviour_email == ThreepidBehaviour.REMOTE:
raise SynapseError(
400,
"Password resets for this homeserver are handled by a separate program",
)
sid = parse_string(request, "sid", required=True)
token = parse_string(request, "token", required=True)
client_secret = parse_string(request, "client_secret", required=True)
assert_valid_client_secret(client_secret)
# Show a confirmation page, just in case someone accidentally clicked this link when
# they didn't mean to
template_vars = {
"sid": sid,
"token": token,
"client_secret": client_secret,
"medium": medium,
}
respond_with_html(
request, 200, self._confirmation_email_template.render(**template_vars)
)
class PasswordResetConfirmationSubmitTokenServlet(RestServlet):
"""Handles confirmation of 3PID validation token submission.
A user will land on PasswordResetSubmitTokenServlet, confirm the password reset, then
submit the same parameters to this servlet.
"""
PATTERNS = client_patterns(
"/password_reset/email/submit_token_confirm$", releases=(), unstable=True,
)
def __init__(self, hs: "HomeServer"):
"""
Args:
hs: server
"""
super().__init__()
self.clock = hs.get_clock()
self.store = hs.get_datastore()
self._threepid_behaviour_email = hs.config.threepid_behaviour_email
self._local_threepid_handling_disabled_due_to_email_config = (
hs.config.local_threepid_handling_disabled_due_to_email_config
)
if self._threepid_behaviour_email == ThreepidBehaviour.LOCAL:
self._email_password_reset_template_success_html = (
hs.config.email_password_reset_template_success_html_content
)
self._failure_email_template = (
hs.config.email_password_reset_template_failure_html
)
async def on_POST(self, request: Request):
if self._threepid_behaviour_email == ThreepidBehaviour.OFF:
if self._local_threepid_handling_disabled_due_to_email_config:
logger.warning(
"Password reset emails have been disabled due to lack of an email config"
)
raise SynapseError(
400, "Email-based password resets are disabled on this server"
)
elif self._threepid_behaviour_email == ThreepidBehaviour.REMOTE:
raise SynapseError(
400,
"Password resets for this homeserver are handled by a separate program",
)
logger.info("ARGS: %s, CONTENT: %s, HEADERS: %s", request.args, request.content,
request.getAllHeaders())
sid = parse_string(request, "sid", required=True)
token = parse_string(request, "token", required=True)
client_secret = parse_string(request, "client_secret", required=True)
# Attempt to validate a 3PID session
try:
# Mark the session as valid
@@ -207,7 +283,7 @@ class PasswordResetSubmitTokenServlet(RestServlet):
return None
# Otherwise show the success template
html = self.config.email_password_reset_template_success_html_content
html = self._email_password_reset_template_success_html
status_code = 200
except ThreepidValidationError as e:
status_code = e.code
@@ -891,6 +967,7 @@ class WhoamiRestServlet(RestServlet):
def register_servlets(hs, http_server):
EmailPasswordRequestTokenRestServlet(hs).register(http_server)
PasswordResetSubmitTokenServlet(hs).register(http_server)
PasswordResetConfirmationSubmitTokenServlet(hs).register(http_server)
PasswordRestServlet(hs).register(http_server)
DeactivateAccountRestServlet(hs).register(http_server)
EmailThreepidRequestTokenRestServlet(hs).register(http_server)

View File

@@ -29,9 +29,11 @@ from typing import (
Tuple,
TypeVar,
Union,
overload,
)
from prometheus_client import Histogram
from typing_extensions import Literal
from twisted.enterprise import adbapi
from twisted.internet import defer
@@ -1020,14 +1022,36 @@ class DatabasePool(object):
return txn.execute_batch(sql, args)
def simple_select_one(
@overload
async def simple_select_one(
self,
table: str,
keyvalues: Dict[str, Any],
retcols: Iterable[str],
allow_none: Literal[False] = False,
desc: str = "simple_select_one",
) -> Dict[str, Any]:
...
@overload
async def simple_select_one(
self,
table: str,
keyvalues: Dict[str, Any],
retcols: Iterable[str],
allow_none: Literal[True] = True,
desc: str = "simple_select_one",
) -> Optional[Dict[str, Any]]:
...
async def simple_select_one(
self,
table: str,
keyvalues: Dict[str, Any],
retcols: Iterable[str],
allow_none: bool = False,
desc: str = "simple_select_one",
) -> defer.Deferred:
) -> Optional[Dict[str, Any]]:
"""Executes a SELECT query on the named table, which is expected to
return a single row, returning multiple columns from it.
@@ -1038,18 +1062,18 @@ class DatabasePool(object):
allow_none: If true, return None instead of failing if the SELECT
statement returns no rows
"""
return self.runInteraction(
return await self.runInteraction(
desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none
)
def simple_select_one_onecol(
async def simple_select_one_onecol(
self,
table: str,
keyvalues: Dict[str, Any],
retcol: Iterable[str],
allow_none: bool = False,
desc: str = "simple_select_one_onecol",
) -> defer.Deferred:
) -> Optional[Any]:
"""Executes a SELECT query on the named table, which is expected to
return a single row, returning a single column from it.
@@ -1061,7 +1085,7 @@ class DatabasePool(object):
statement returns no rows
desc: description of the transaction, for logging and metrics
"""
return self.runInteraction(
return await self.runInteraction(
desc,
self.simple_select_one_onecol_txn,
table,

View File

@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Dict, Iterable, List, Optional, Set, Tuple
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
from synapse.api.errors import Codes, StoreError
from synapse.logging.opentracing import (
@@ -47,7 +47,7 @@ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
class DeviceWorkerStore(SQLBaseStore):
def get_device(self, user_id: str, device_id: str):
async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]:
"""Retrieve a device. Only returns devices that are not marked as
hidden.
@@ -55,11 +55,11 @@ class DeviceWorkerStore(SQLBaseStore):
user_id: The ID of the user which owns the device
device_id: The ID of the device to retrieve
Returns:
defer.Deferred for a dict containing the device information
A dict containing the device information
Raises:
StoreError: if the device is not found
"""
return self.db_pool.simple_select_one(
return await self.db_pool.simple_select_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
retcols=("user_id", "device_id", "display_name"),
@@ -656,11 +656,13 @@ class DeviceWorkerStore(SQLBaseStore):
)
@cached(max_entries=10000)
def get_device_list_last_stream_id_for_remote(self, user_id: str):
async def get_device_list_last_stream_id_for_remote(
self, user_id: str
) -> Optional[Any]:
"""Get the last stream_id we got for a user. May be None if we haven't
got any information for them.
"""
return self.db_pool.simple_select_one_onecol(
return await self.db_pool.simple_select_one_onecol(
table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id},
retcol="stream_id",

View File

@@ -59,8 +59,8 @@ class DirectoryWorkerStore(SQLBaseStore):
return RoomAliasMapping(room_id, room_alias.to_string(), servers)
def get_room_alias_creator(self, room_alias):
return self.db_pool.simple_select_one_onecol(
async def get_room_alias_creator(self, room_alias: str) -> str:
return await self.db_pool.simple_select_one_onecol(
table="room_aliases",
keyvalues={"room_alias": room_alias},
retcol="creator",

View File

@@ -223,15 +223,15 @@ class EndToEndRoomKeyStore(SQLBaseStore):
return ret
def count_e2e_room_keys(self, user_id, version):
async def count_e2e_room_keys(self, user_id: str, version: str) -> int:
"""Get the number of keys in a backup version.
Args:
user_id (str): the user whose backup we're querying
version (str): the version ID of the backup we're querying about
user_id: the user whose backup we're querying
version: the version ID of the backup we're querying about
"""
return self.db_pool.simple_select_one_onecol(
return await self.db_pool.simple_select_one_onecol(
table="e2e_room_keys",
keyvalues={"user_id": user_id, "version": version},
retcol="COUNT(*)",

View File

@@ -113,25 +113,25 @@ class EventsWorkerStore(SQLBaseStore):
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == EventsStream.NAME:
self._stream_id_gen.advance(token)
self._stream_id_gen.advance(instance_name, token)
elif stream_name == BackfillStream.NAME:
self._backfill_id_gen.advance(-token)
self._backfill_id_gen.advance(instance_name, -token)
super().process_replication_rows(stream_name, instance_name, token, rows)
def get_received_ts(self, event_id):
async def get_received_ts(self, event_id: str) -> Optional[int]:
"""Get received_ts (when it was persisted) for the event.
Raises an exception for unknown events.
Args:
event_id (str)
event_id: The event ID to query.
Returns:
Deferred[int|None]: Timestamp in milliseconds, or None for events
that were persisted before received_ts was implemented.
Timestamp in milliseconds, or None for events that were persisted
before received_ts was implemented.
"""
return self.db_pool.simple_select_one_onecol(
return await self.db_pool.simple_select_one_onecol(
table="events",
keyvalues={"event_id": event_id},
retcol="received_ts",

View File

@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple
from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json
@@ -28,8 +28,8 @@ _DEFAULT_ROLE_ID = ""
class GroupServerWorkerStore(SQLBaseStore):
def get_group(self, group_id):
return self.db_pool.simple_select_one(
async def get_group(self, group_id: str) -> Optional[Dict[str, Any]]:
return await self.db_pool.simple_select_one(
table="groups",
keyvalues={"group_id": group_id},
retcols=(
@@ -351,8 +351,10 @@ class GroupServerWorkerStore(SQLBaseStore):
)
return bool(result)
def is_user_admin_in_group(self, group_id, user_id):
return self.db_pool.simple_select_one_onecol(
async def is_user_admin_in_group(
self, group_id: str, user_id: str
) -> Optional[bool]:
return await self.db_pool.simple_select_one_onecol(
table="group_users",
keyvalues={"group_id": group_id, "user_id": user_id},
retcol="is_admin",
@@ -360,10 +362,12 @@ class GroupServerWorkerStore(SQLBaseStore):
desc="is_user_admin_in_group",
)
def is_user_invited_to_local_group(self, group_id, user_id):
async def is_user_invited_to_local_group(
self, group_id: str, user_id: str
) -> Optional[bool]:
"""Has the group server invited a user?
"""
return self.db_pool.simple_select_one_onecol(
return await self.db_pool.simple_select_one_onecol(
table="group_invites",
keyvalues={"group_id": group_id, "user_id": user_id},
retcol="user_id",

View File

@@ -12,6 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
@@ -37,12 +39,13 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super(MediaRepositoryStore, self).__init__(database, db_conn, hs)
def get_local_media(self, media_id):
async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]:
"""Get the metadata for a local piece of media
Returns:
None if the media_id doesn't exist.
"""
return self.db_pool.simple_select_one(
return await self.db_pool.simple_select_one(
"local_media_repository",
{"media_id": media_id},
(
@@ -191,8 +194,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
desc="store_local_thumbnail",
)
def get_cached_remote_media(self, origin, media_id):
return self.db_pool.simple_select_one(
async def get_cached_remote_media(
self, origin, media_id: str
) -> Optional[Dict[str, Any]]:
return await self.db_pool.simple_select_one(
"remote_media_cache",
{"media_origin": origin, "media_id": media_id},
(

View File

@@ -99,17 +99,18 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
return users
@cached(num_args=1)
def user_last_seen_monthly_active(self, user_id):
async def user_last_seen_monthly_active(self, user_id: str) -> int:
"""
Checks if a given user is part of the monthly active user group
Arguments:
user_id (str): user to add/update
Return:
Deferred[int] : timestamp since last seen, None if never seen
Checks if a given user is part of the monthly active user group
Arguments:
user_id: user to add/update
Return:
Timestamp since last seen, None if never seen
"""
return self.db_pool.simple_select_one_onecol(
return await self.db_pool.simple_select_one_onecol(
table="monthly_active_users",
keyvalues={"user_id": user_id},
retcol="timestamp",

View File

@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional
from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore
@@ -19,7 +20,7 @@ from synapse.storage.databases.main.roommember import ProfileInfo
class ProfileWorkerStore(SQLBaseStore):
async def get_profileinfo(self, user_localpart):
async def get_profileinfo(self, user_localpart: str) -> ProfileInfo:
try:
profile = await self.db_pool.simple_select_one(
table="profiles",
@@ -38,24 +39,26 @@ class ProfileWorkerStore(SQLBaseStore):
avatar_url=profile["avatar_url"], display_name=profile["displayname"]
)
def get_profile_displayname(self, user_localpart):
return self.db_pool.simple_select_one_onecol(
async def get_profile_displayname(self, user_localpart: str) -> str:
return await self.db_pool.simple_select_one_onecol(
table="profiles",
keyvalues={"user_id": user_localpart},
retcol="displayname",
desc="get_profile_displayname",
)
def get_profile_avatar_url(self, user_localpart):
return self.db_pool.simple_select_one_onecol(
async def get_profile_avatar_url(self, user_localpart: str) -> str:
return await self.db_pool.simple_select_one_onecol(
table="profiles",
keyvalues={"user_id": user_localpart},
retcol="avatar_url",
desc="get_profile_avatar_url",
)
def get_from_remote_profile_cache(self, user_id):
return self.db_pool.simple_select_one(
async def get_from_remote_profile_cache(
self, user_id: str
) -> Optional[Dict[str, Any]]:
return await self.db_pool.simple_select_one(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
retcols=("displayname", "avatar_url"),

View File

@@ -71,8 +71,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
)
@cached(num_args=3)
def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type):
return self.db_pool.simple_select_one_onecol(
async def get_last_receipt_event_id_for_user(
self, user_id: str, room_id: str, receipt_type: str
) -> Optional[str]:
return await self.db_pool.simple_select_one_onecol(
table="receipts_linearized",
keyvalues={
"room_id": room_id,

View File

@@ -17,7 +17,7 @@
import logging
import re
from typing import Awaitable, Dict, List, Optional
from typing import Any, Awaitable, Dict, List, Optional
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
@@ -46,8 +46,8 @@ class RegistrationWorkerStore(SQLBaseStore):
)
@cached()
def get_user_by_id(self, user_id):
return self.db_pool.simple_select_one(
async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
return await self.db_pool.simple_select_one(
table="users",
keyvalues={"name": user_id},
retcols=[
@@ -1259,12 +1259,12 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
desc="del_user_pending_deactivation",
)
def get_user_pending_deactivation(self):
async def get_user_pending_deactivation(self) -> Optional[str]:
"""
Gets one user from the table of users waiting to be parted from all the rooms
they're in.
"""
return self.db_pool.simple_select_one_onecol(
return await self.db_pool.simple_select_one_onecol(
"users_pending_deactivation",
keyvalues={},
retcol="user_id",

View File

@@ -14,6 +14,7 @@
# limitations under the License.
import logging
from typing import Optional
from synapse.storage._base import SQLBaseStore
@@ -21,8 +22,8 @@ logger = logging.getLogger(__name__)
class RejectionsStore(SQLBaseStore):
def get_rejection_reason(self, event_id):
return self.db_pool.simple_select_one_onecol(
async def get_rejection_reason(self, event_id: str) -> Optional[str]:
return await self.db_pool.simple_select_one_onecol(
table="rejections",
retcol="reason",
keyvalues={"event_id": event_id},

View File

@@ -73,15 +73,15 @@ class RoomWorkerStore(SQLBaseStore):
self.config = hs.config
def get_room(self, room_id):
async def get_room(self, room_id: str) -> dict:
"""Retrieve a room.
Args:
room_id (str): The ID of the room to retrieve.
room_id: The ID of the room to retrieve.
Returns:
A dict containing the room information, or None if the room is unknown.
"""
return self.db_pool.simple_select_one(
return await self.db_pool.simple_select_one(
table="rooms",
keyvalues={"room_id": room_id},
retcols=("room_id", "is_public", "creator"),
@@ -330,8 +330,8 @@ class RoomWorkerStore(SQLBaseStore):
return ret_val
@cached(max_entries=10000)
def is_room_blocked(self, room_id):
return self.db_pool.simple_select_one_onecol(
async def is_room_blocked(self, room_id: str) -> Optional[bool]:
return await self.db_pool.simple_select_one_onecol(
table="blocked_rooms",
keyvalues={"room_id": room_id},
retcol="1",

View File

@@ -260,8 +260,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return event.content.get("canonical_alias")
@cached(max_entries=50000)
def _get_state_group_for_event(self, event_id):
return self.db_pool.simple_select_one_onecol(
async def _get_state_group_for_event(self, event_id: str) -> Optional[int]:
return await self.db_pool.simple_select_one_onecol(
table="event_to_state_groups",
keyvalues={"event_id": event_id},
retcol="state_group",

View File

@@ -211,11 +211,11 @@ class StatsStore(StateDeltasStore):
return len(rooms_to_work_on)
def get_stats_positions(self):
async def get_stats_positions(self) -> int:
"""
Returns the stats processor positions.
"""
return self.db_pool.simple_select_one_onecol(
return await self.db_pool.simple_select_one_onecol(
table="stats_incremental_position",
keyvalues={},
retcol="stream_id",
@@ -300,7 +300,7 @@ class StatsStore(StateDeltasStore):
return slice_list
@cached()
def get_earliest_token_for_stats(self, stats_type, id):
async def get_earliest_token_for_stats(self, stats_type: str, id: str) -> int:
"""
Fetch the "earliest token". This is used by the room stats delta
processor to ignore deltas that have been processed between the
@@ -308,11 +308,11 @@ class StatsStore(StateDeltasStore):
being calculated.
Returns:
Deferred[int]
The earliest token.
"""
table, id_col = TYPE_TO_TABLE[stats_type]
return self.db_pool.simple_select_one_onecol(
return await self.db_pool.simple_select_one_onecol(
"%s_current" % (table,),
keyvalues={id_col: id},
retcol="completed_delta_stream_id",

View File

@@ -15,6 +15,7 @@
import logging
import re
from typing import Any, Dict, Optional
from synapse.api.constants import EventTypes, JoinRules
from synapse.storage.database import DatabasePool
@@ -527,8 +528,8 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
)
@cached()
def get_user_in_directory(self, user_id):
return self.db_pool.simple_select_one(
async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, Any]]:
return await self.db_pool.simple_select_one(
table="user_directory",
keyvalues={"user_id": user_id},
retcols=("display_name", "avatar_url"),
@@ -663,8 +664,8 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
users.update(rows)
return list(users)
def get_user_directory_stream_pos(self):
return self.db_pool.simple_select_one_onecol(
async def get_user_directory_stream_pos(self) -> int:
return await self.db_pool.simple_select_one_onecol(
table="user_directory_stream_pos",
keyvalues={},
retcol="stream_id",

View File

@@ -71,7 +71,9 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_my_name(self):
yield self.store.set_profile_displayname(self.frank.localpart, "Frank")
yield defer.ensureDeferred(
self.store.set_profile_displayname(self.frank.localpart, "Frank")
)
displayname = yield defer.ensureDeferred(
self.handler.get_displayname(self.frank)
@@ -104,7 +106,12 @@ class ProfileTestCase(unittest.TestCase):
)
self.assertEquals(
(yield self.store.get_profile_displayname(self.frank.localpart)), "Frank",
(
yield defer.ensureDeferred(
self.store.get_profile_displayname(self.frank.localpart)
)
),
"Frank",
)
@defer.inlineCallbacks
@@ -112,10 +119,17 @@ class ProfileTestCase(unittest.TestCase):
self.hs.config.enable_set_displayname = False
# Setting displayname for the first time is allowed
yield self.store.set_profile_displayname(self.frank.localpart, "Frank")
yield defer.ensureDeferred(
self.store.set_profile_displayname(self.frank.localpart, "Frank")
)
self.assertEquals(
(yield self.store.get_profile_displayname(self.frank.localpart)), "Frank",
(
yield defer.ensureDeferred(
self.store.get_profile_displayname(self.frank.localpart)
)
),
"Frank",
)
# Setting displayname a second time is forbidden
@@ -158,7 +172,9 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_incoming_fed_query(self):
yield defer.ensureDeferred(self.store.create_profile("caroline"))
yield self.store.set_profile_displayname("caroline", "Caroline")
yield defer.ensureDeferred(
self.store.set_profile_displayname("caroline", "Caroline")
)
response = yield defer.ensureDeferred(
self.query_handlers["profile"](
@@ -170,8 +186,10 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_my_avatar(self):
yield self.store.set_profile_avatar_url(
self.frank.localpart, "http://my.server/me.png"
yield defer.ensureDeferred(
self.store.set_profile_avatar_url(
self.frank.localpart, "http://my.server/me.png"
)
)
avatar_url = yield defer.ensureDeferred(self.handler.get_avatar_url(self.frank))
@@ -188,7 +206,11 @@ class ProfileTestCase(unittest.TestCase):
)
self.assertEquals(
(yield self.store.get_profile_avatar_url(self.frank.localpart)),
(
yield defer.ensureDeferred(
self.store.get_profile_avatar_url(self.frank.localpart)
)
),
"http://my.server/pic.gif",
)
@@ -202,7 +224,11 @@ class ProfileTestCase(unittest.TestCase):
)
self.assertEquals(
(yield self.store.get_profile_avatar_url(self.frank.localpart)),
(
yield defer.ensureDeferred(
self.store.get_profile_avatar_url(self.frank.localpart)
)
),
"http://my.server/me.png",
)
@@ -211,12 +237,18 @@ class ProfileTestCase(unittest.TestCase):
self.hs.config.enable_set_avatar_url = False
# Setting displayname for the first time is allowed
yield self.store.set_profile_avatar_url(
self.frank.localpart, "http://my.server/me.png"
yield defer.ensureDeferred(
self.store.set_profile_avatar_url(
self.frank.localpart, "http://my.server/me.png"
)
)
self.assertEquals(
(yield self.store.get_profile_avatar_url(self.frank.localpart)),
(
yield defer.ensureDeferred(
self.store.get_profile_avatar_url(self.frank.localpart)
)
),
"http://my.server/me.png",
)

View File

@@ -144,9 +144,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.datastore.get_users_in_room = get_users_in_room
self.datastore.get_user_directory_stream_pos.return_value = (
self.datastore.get_user_directory_stream_pos.side_effect = (
# we deliberately return a non-None stream pos to avoid doing an initial_spam
defer.succeed(1)
lambda: make_awaitable(1)
)
self.datastore.get_current_state_deltas.return_value = (0, None)

View File

@@ -35,7 +35,7 @@ class ModuleApiTestCase(HomeserverTestCase):
# Check that the new user exists with all provided attributes
self.assertEqual(user_id, "@bob:test")
self.assertTrue(access_token)
self.assertTrue(self.store.get_user_by_id(user_id))
self.assertTrue(self.get_success(self.store.get_user_by_id(user_id)))
# Check that the email was assigned
emails = self.get_success(self.store.user_get_threepids(user_id))

View File

@@ -0,0 +1,272 @@
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from mock import Mock, patch
import synapse.rest.admin
from synapse.api.constants import EventTypes
from synapse.rest.client.v1 import directory, login, profile, room
from synapse.rest.client.v2_alpha import room_upgrade_rest_servlet
from tests import unittest
class _ShadowBannedBase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, homeserver):
# Create two users, one of which is shadow-banned.
self.banned_user_id = self.register_user("banned", "test")
self.banned_access_token = self.login("banned", "test")
self.store = self.hs.get_datastore()
self.get_success(
self.store.db_pool.simple_update(
table="users",
keyvalues={"name": self.banned_user_id},
updatevalues={"shadow_banned": True},
desc="shadow_ban",
)
)
self.other_user_id = self.register_user("otheruser", "pass")
self.other_access_token = self.login("otheruser", "pass")
# To avoid the tests timing out don't add a delay to "annoy the requester".
@patch("random.randint", new=lambda a, b: 0)
class RoomTestCase(_ShadowBannedBase):
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
directory.register_servlets,
login.register_servlets,
room.register_servlets,
room_upgrade_rest_servlet.register_servlets,
]
def test_invite(self):
"""Invites from shadow-banned users don't actually get sent."""
# The create works fine.
room_id = self.helper.create_room_as(
self.banned_user_id, tok=self.banned_access_token
)
# Inviting the user completes successfully.
self.helper.invite(
room=room_id,
src=self.banned_user_id,
tok=self.banned_access_token,
targ=self.other_user_id,
)
# But the user wasn't actually invited.
invited_rooms = self.get_success(
self.store.get_invited_rooms_for_local_user(self.other_user_id)
)
self.assertEqual(invited_rooms, [])
def test_invite_3pid(self):
"""Ensure that a 3PID invite does not attempt to contact the identity server."""
identity_handler = self.hs.get_handlers().identity_handler
identity_handler.lookup_3pid = Mock(
side_effect=AssertionError("This should not get called")
)
# The create works fine.
room_id = self.helper.create_room_as(
self.banned_user_id, tok=self.banned_access_token
)
# Inviting the user completes successfully.
request, channel = self.make_request(
"POST",
"/rooms/%s/invite" % (room_id,),
{"id_server": "test", "medium": "email", "address": "test@test.test"},
access_token=self.banned_access_token,
)
self.render(request)
self.assertEquals(200, channel.code, channel.result)
# This should have raised an error earlier, but double check this wasn't called.
identity_handler.lookup_3pid.assert_not_called()
def test_create_room(self):
"""Invitations during a room creation should be discarded, but the room still gets created."""
# The room creation is successful.
request, channel = self.make_request(
"POST",
"/_matrix/client/r0/createRoom",
{"visibility": "public", "invite": [self.other_user_id]},
access_token=self.banned_access_token,
)
self.render(request)
self.assertEquals(200, channel.code, channel.result)
room_id = channel.json_body["room_id"]
# But the user wasn't actually invited.
invited_rooms = self.get_success(
self.store.get_invited_rooms_for_local_user(self.other_user_id)
)
self.assertEqual(invited_rooms, [])
# Since a real room was created, the other user should be able to join it.
self.helper.join(room_id, self.other_user_id, tok=self.other_access_token)
# Both users should be in the room.
users = self.get_success(self.store.get_users_in_room(room_id))
self.assertCountEqual(users, ["@banned:test", "@otheruser:test"])
def test_message(self):
"""Messages from shadow-banned users don't actually get sent."""
room_id = self.helper.create_room_as(
self.other_user_id, tok=self.other_access_token
)
# The user should be in the room.
self.helper.join(room_id, self.banned_user_id, tok=self.banned_access_token)
# Sending a message should complete successfully.
result = self.helper.send_event(
room_id=room_id,
type=EventTypes.Message,
content={"msgtype": "m.text", "body": "with right label"},
tok=self.banned_access_token,
)
self.assertIn("event_id", result)
event_id = result["event_id"]
latest_events = self.get_success(
self.store.get_latest_event_ids_in_room(room_id)
)
self.assertNotIn(event_id, latest_events)
def test_upgrade(self):
"""A room upgrade should fail, but look like it succeeded."""
# The create works fine.
room_id = self.helper.create_room_as(
self.banned_user_id, tok=self.banned_access_token
)
request, channel = self.make_request(
"POST",
"/_matrix/client/r0/rooms/%s/upgrade" % (room_id,),
{"new_version": "6"},
access_token=self.banned_access_token,
)
self.render(request)
self.assertEquals(200, channel.code, channel.result)
# A new room_id should be returned.
self.assertIn("replacement_room", channel.json_body)
new_room_id = channel.json_body["replacement_room"]
# It doesn't really matter what API we use here, we just want to assert
# that the room doesn't exist.
summary = self.get_success(self.store.get_room_summary(new_room_id))
# The summary should be empty since the room doesn't exist.
self.assertEqual(summary, {})
# To avoid the tests timing out don't add a delay to "annoy the requester".
@patch("random.randint", new=lambda a, b: 0)
class ProfileTestCase(_ShadowBannedBase):
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
login.register_servlets,
profile.register_servlets,
room.register_servlets,
]
def test_displayname(self):
"""Profile changes should succeed, but don't end up in a room."""
original_display_name = "banned"
new_display_name = "new name"
# Join a room.
room_id = self.helper.create_room_as(
self.banned_user_id, tok=self.banned_access_token
)
# The update should succeed.
request, channel = self.make_request(
"PUT",
"/_matrix/client/r0/profile/%s/displayname" % (self.banned_user_id,),
{"displayname": new_display_name},
access_token=self.banned_access_token,
)
self.render(request)
self.assertEquals(200, channel.code, channel.result)
self.assertEqual(channel.json_body, {})
# The user's display name should be updated.
request, channel = self.make_request(
"GET", "/profile/%s/displayname" % (self.banned_user_id,)
)
self.render(request)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.json_body["displayname"], new_display_name)
# But the display name in the room should not be.
message_handler = self.hs.get_message_handler()
event = self.get_success(
message_handler.get_room_data(
self.banned_user_id,
room_id,
"m.room.member",
self.banned_user_id,
False,
)
)
self.assertEqual(
event.content, {"membership": "join", "displayname": original_display_name}
)
def test_room_displayname(self):
"""Changes to state events for a room should be processed, but not end up in the room."""
original_display_name = "banned"
new_display_name = "new name"
# Join a room.
room_id = self.helper.create_room_as(
self.banned_user_id, tok=self.banned_access_token
)
# The update should succeed.
request, channel = self.make_request(
"PUT",
"/_matrix/client/r0/rooms/%s/state/m.room.member/%s"
% (room_id, self.banned_user_id),
{"membership": "join", "displayname": new_display_name},
access_token=self.banned_access_token,
)
self.render(request)
self.assertEquals(200, channel.code, channel.result)
self.assertIn("event_id", channel.json_body)
# The display name in the room should not be changed.
message_handler = self.hs.get_message_handler()
event = self.get_success(
message_handler.get_room_data(
self.banned_user_id,
room_id,
"m.room.member",
self.banned_user_id,
False,
)
)
self.assertEqual(
event.content, {"membership": "join", "displayname": original_display_name}
)

View File

@@ -267,7 +267,9 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
auth = {
"type": "m.login.password",
"identifier": {"type": "m.id.user", "user": user_id},
# https://github.com/matrix-org/synapse/issues/5665
# "identifier": {"type": "m.id.user", "user": user_id},
"user": user_id,
"password": password,
"session": channel.json_body["session"],
}

View File

@@ -21,13 +21,13 @@
import json
from urllib import parse as urlparse
from mock import Mock, patch
from mock import Mock
import synapse.rest.admin
from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.handlers.pagination import PurgeStatus
from synapse.rest.client.v1 import directory, login, profile, room
from synapse.rest.client.v2_alpha import account, room_upgrade_rest_servlet
from synapse.rest.client.v2_alpha import account
from synapse.types import JsonDict, RoomAlias, UserID
from synapse.util.stringutils import random_string
@@ -684,38 +684,39 @@ class RoomJoinRatelimitTestCase(RoomBase):
]
@unittest.override_config(
{"rc_joins": {"local": {"per_second": 3, "burst_count": 3}}}
{"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
)
def test_join_local_ratelimit(self):
"""Tests that local joins are actually rate-limited."""
for i in range(5):
for i in range(3):
self.helper.create_room_as(self.user_id)
self.helper.create_room_as(self.user_id, expect_code=429)
@unittest.override_config(
{"rc_joins": {"local": {"per_second": 3, "burst_count": 3}}}
{"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
)
def test_join_local_ratelimit_profile_change(self):
"""Tests that sending a profile update into all of the user's joined rooms isn't
rate-limited by the rate-limiter on joins."""
# Create and join more rooms than the rate-limiting config allows in a second.
# Create and join as many rooms as the rate-limiting config allows in a second.
room_ids = [
self.helper.create_room_as(self.user_id),
self.helper.create_room_as(self.user_id),
self.helper.create_room_as(self.user_id),
]
self.reactor.advance(1)
room_ids = room_ids + [
self.helper.create_room_as(self.user_id),
self.helper.create_room_as(self.user_id),
self.helper.create_room_as(self.user_id),
]
# Let some time for the rate-limiter to forget about our multi-join.
self.reactor.advance(2)
# Add one to make sure we're joined to more rooms than the config allows us to
# join in a second.
room_ids.append(self.helper.create_room_as(self.user_id))
# Create a profile for the user, since it hasn't been done on registration.
store = self.hs.get_datastore()
store.create_profile(UserID.from_string(self.user_id).localpart)
self.get_success(
store.create_profile(UserID.from_string(self.user_id).localpart)
)
# Update the display name for the user.
path = "/_matrix/client/r0/profile/%s/displayname" % self.user_id
@@ -738,7 +739,7 @@ class RoomJoinRatelimitTestCase(RoomBase):
self.assertEquals(channel.json_body["displayname"], "John Doe")
@unittest.override_config(
{"rc_joins": {"local": {"per_second": 3, "burst_count": 3}}}
{"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
)
def test_join_local_ratelimit_idempotent(self):
"""Tests that the room join endpoints remain idempotent despite rate-limiting
@@ -754,7 +755,7 @@ class RoomJoinRatelimitTestCase(RoomBase):
for path in paths_to_test:
# Make sure we send more requests than the rate-limiting config would allow
# if all of these requests ended up joining the user to a room.
for i in range(6):
for i in range(4):
request, channel = self.make_request("POST", path % room_id, {})
self.render(request)
self.assertEquals(channel.code, 200)
@@ -2059,158 +2060,3 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
"""An alias which does not point to the room raises a SynapseError."""
self._set_canonical_alias({"alias": "@unknown:test"}, expected_code=400)
self._set_canonical_alias({"alt_aliases": ["@unknown:test"]}, expected_code=400)
# To avoid the tests timing out don't add a delay to "annoy the requester".
@patch("random.randint", new=lambda a, b: 0)
class ShadowBannedTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
directory.register_servlets,
login.register_servlets,
room.register_servlets,
room_upgrade_rest_servlet.register_servlets,
]
def prepare(self, reactor, clock, homeserver):
self.banned_user_id = self.register_user("banned", "test")
self.banned_access_token = self.login("banned", "test")
self.store = self.hs.get_datastore()
self.get_success(
self.store.db_pool.simple_update(
table="users",
keyvalues={"name": self.banned_user_id},
updatevalues={"shadow_banned": True},
desc="shadow_ban",
)
)
self.other_user_id = self.register_user("otheruser", "pass")
self.other_access_token = self.login("otheruser", "pass")
def test_invite(self):
"""Invites from shadow-banned users don't actually get sent."""
# The create works fine.
room_id = self.helper.create_room_as(
self.banned_user_id, tok=self.banned_access_token
)
# Inviting the user completes successfully.
self.helper.invite(
room=room_id,
src=self.banned_user_id,
tok=self.banned_access_token,
targ=self.other_user_id,
)
# But the user wasn't actually invited.
invited_rooms = self.get_success(
self.store.get_invited_rooms_for_local_user(self.other_user_id)
)
self.assertEqual(invited_rooms, [])
def test_invite_3pid(self):
"""Ensure that a 3PID invite does not attempt to contact the identity server."""
identity_handler = self.hs.get_handlers().identity_handler
identity_handler.lookup_3pid = Mock(
side_effect=AssertionError("This should not get called")
)
# The create works fine.
room_id = self.helper.create_room_as(
self.banned_user_id, tok=self.banned_access_token
)
# Inviting the user completes successfully.
request, channel = self.make_request(
"POST",
"/rooms/%s/invite" % (room_id,),
{"id_server": "test", "medium": "email", "address": "test@test.test"},
access_token=self.banned_access_token,
)
self.render(request)
self.assertEquals(200, channel.code, channel.result)
# This should have raised an error earlier, but double check this wasn't called.
identity_handler.lookup_3pid.assert_not_called()
def test_create_room(self):
"""Invitations during a room creation should be discarded, but the room still gets created."""
# The room creation is successful.
request, channel = self.make_request(
"POST",
"/_matrix/client/r0/createRoom",
{"visibility": "public", "invite": [self.other_user_id]},
access_token=self.banned_access_token,
)
self.render(request)
self.assertEquals(200, channel.code, channel.result)
room_id = channel.json_body["room_id"]
# But the user wasn't actually invited.
invited_rooms = self.get_success(
self.store.get_invited_rooms_for_local_user(self.other_user_id)
)
self.assertEqual(invited_rooms, [])
# Since a real room was created, the other user should be able to join it.
self.helper.join(room_id, self.other_user_id, tok=self.other_access_token)
# Both users should be in the room.
users = self.get_success(self.store.get_users_in_room(room_id))
self.assertCountEqual(users, ["@banned:test", "@otheruser:test"])
def test_message(self):
"""Messages from shadow-banned users don't actually get sent."""
room_id = self.helper.create_room_as(
self.other_user_id, tok=self.other_access_token
)
# The user should be in the room.
self.helper.join(room_id, self.banned_user_id, tok=self.banned_access_token)
# Sending a message should complete successfully.
result = self.helper.send_event(
room_id=room_id,
type=EventTypes.Message,
content={"msgtype": "m.text", "body": "with right label"},
tok=self.banned_access_token,
)
self.assertIn("event_id", result)
event_id = result["event_id"]
latest_events = self.get_success(
self.store.get_latest_event_ids_in_room(room_id)
)
self.assertNotIn(event_id, latest_events)
def test_upgrade(self):
"""A room upgrade should fail, but look like it succeeded."""
# The create works fine.
room_id = self.helper.create_room_as(
self.banned_user_id, tok=self.banned_access_token
)
request, channel = self.make_request(
"POST",
"/_matrix/client/r0/rooms/%s/upgrade" % (room_id,),
{"new_version": "6"},
access_token=self.banned_access_token,
)
self.render(request)
self.assertEquals(200, channel.code, channel.result)
# A new room_id should be returned.
self.assertIn("replacement_room", channel.json_body)
new_room_id = channel.json_body["replacement_room"]
# It doesn't really matter what API we use here, we just want to assert
# that the room doesn't exist.
summary = self.get_success(self.store.get_room_summary(new_room_id))
# The summary should be empty since the room doesn't exist.
self.assertEqual(summary, {})

View File

@@ -16,6 +16,7 @@
# limitations under the License.
import json
from urllib.parse import urlencode
import os
import re
from email.parser import Parser
@@ -70,6 +71,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
@unittest.INFO
def test_basic_password_reset(self):
"""Test basic password reset flow
"""
@@ -250,10 +252,33 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
# Remove the host
path = link.replace("https://example.com", "")
# Load the password reset confirmation page
request, channel = self.make_request("GET", path, shorthand=False)
self.render(request)
self.assertEquals(200, channel.code, channel.result)
# Replace the path with the confirmation path
path = "/_matrix/client/unstable/password_reset/email/submit_token_confirm"
form_args = []
for key, value_list in request.args.items():
for value in value_list:
arg = (key, value)
form_args.append(arg)
print("form_args:", form_args)
print("encoded form_args:", urlencode(form_args))
# Confirm the password reset
request, channel = self.make_request(
"POST",
path,
content=urlencode(form_args).encode("utf8"),
shorthand=False,
)
self.render(request)
self.assertEquals(200, channel.code, channel.result)
def _get_link_from_email(self):
assert self.email_attempts, "No emails have been sent"

View File

@@ -38,6 +38,11 @@ class DummyRecaptchaChecker(UserInteractiveAuthChecker):
return succeed(True)
class DummyPasswordChecker(UserInteractiveAuthChecker):
def check_auth(self, authdict, clientip):
return succeed(authdict["identifier"]["user"])
class FallbackAuthTests(unittest.HomeserverTestCase):
servlets = [
@@ -161,6 +166,9 @@ class UIAuthTests(unittest.HomeserverTestCase):
]
def prepare(self, reactor, clock, hs):
auth_handler = hs.get_auth_handler()
auth_handler.checkers[LoginType.PASSWORD] = DummyPasswordChecker(hs)
self.user_pass = "pass"
self.user = self.register_user("test", self.user_pass)
self.user_tok = self.login("test", self.user_pass)

View File

@@ -593,89 +593,6 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
self.assertEqual(len(self.email_attempts), 0)
def test_deactivated_user_using_user_identifier(self):
self.email_attempts = []
(user_id, tok) = self.create_user()
request_data = json.dumps(
{
"auth": {
"type": "m.login.password",
"identifier": {"type": "m.id.user", "user": user_id},
"password": "monkey",
},
"erase": False,
}
)
request, channel = self.make_request(
"POST", "account/deactivate", request_data, access_token=tok
)
self.render(request)
self.assertEqual(request.code, 200)
self.reactor.advance(datetime.timedelta(days=8).total_seconds())
self.assertEqual(len(self.email_attempts), 0)
def test_deactivated_user_using_thirdparty_identifier(self):
self.email_attempts = []
(user_id, tok) = self.create_user()
request_data = json.dumps(
{
"auth": {
"type": "m.login.password",
"identifier": {
"type": "m.id.thirdparty",
"medium": "email",
"address": "kermit@example.com",
},
"password": "monkey",
},
"erase": False,
}
)
request, channel = self.make_request(
"POST", "account/deactivate", request_data, access_token=tok
)
self.render(request)
self.assertEqual(request.code, 200)
self.reactor.advance(datetime.timedelta(days=8).total_seconds())
self.assertEqual(len(self.email_attempts), 0)
def test_deactivated_user_using_phone_identifier(self):
self.email_attempts = []
(user_id, tok) = self.create_user()
request_data = json.dumps(
{
"auth": {
"type": "m.login.password",
"identifier": {
"type": "m.id.phone",
"country": "GB",
"phone": "077-009-00001",
},
"password": "monkey",
},
"erase": False,
}
)
request, channel = self.make_request(
"POST", "account/deactivate", request_data, access_token=tok
)
self.render(request)
self.assertEqual(request.code, 200)
self.reactor.advance(datetime.timedelta(days=8).total_seconds())
self.assertEqual(len(self.email_attempts), 0)
def create_user(self):
user_id = self.register_user("kermit", "monkey")
tok = self.login("kermit", "monkey")
@@ -691,15 +608,6 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
added_at=now,
)
)
self.get_success(
self.store.user_add_threepid(
user_id=user_id,
medium="msisdn",
address="447700900001",
validated_at=now,
added_at=now,
)
)
return user_id, tok
def test_manual_email_send_expired_account(self):

View File

@@ -1,6 +1,7 @@
import json
import logging
from io import BytesIO
from json.decoder import JSONDecodeError
import attr
from zope.interface import implementer
@@ -195,7 +196,19 @@ def make_request(
)
if content:
req.requestHeaders.addRawHeader(b"Content-Type", b"application/json")
content_is_json = True
try:
json.loads(content)
except JSONDecodeError:
content_is_json = False
print("Content is json?", content_is_json, path)
if content_is_json:
req.requestHeaders.addRawHeader(b"Content-Type", b"application/json")
else:
req.requestHeaders.addRawHeader(
b"Content-Type", b"application/x-www-form-urlencoded"
)
req.requestReceived(method, path, b"1.1")

View File

@@ -97,8 +97,10 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.rowcount = 1
self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)]))
value = yield self.datastore.db_pool.simple_select_one_onecol(
table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol"
value = yield defer.ensureDeferred(
self.datastore.db_pool.simple_select_one_onecol(
table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol"
)
)
self.assertEquals("Value", value)
@@ -111,10 +113,12 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.rowcount = 1
self.mock_txn.fetchone.return_value = (1, 2, 3)
ret = yield self.datastore.db_pool.simple_select_one(
table="tablename",
keyvalues={"keycol": "TheKey"},
retcols=["colA", "colB", "colC"],
ret = yield defer.ensureDeferred(
self.datastore.db_pool.simple_select_one(
table="tablename",
keyvalues={"keycol": "TheKey"},
retcols=["colA", "colB", "colC"],
)
)
self.assertEquals({"colA": 1, "colB": 2, "colC": 3}, ret)
@@ -127,11 +131,13 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.rowcount = 0
self.mock_txn.fetchone.return_value = None
ret = yield self.datastore.db_pool.simple_select_one(
table="tablename",
keyvalues={"keycol": "Not here"},
retcols=["colA"],
allow_none=True,
ret = yield defer.ensureDeferred(
self.datastore.db_pool.simple_select_one(
table="tablename",
keyvalues={"keycol": "Not here"},
retcols=["colA"],
allow_none=True,
)
)
self.assertFalse(ret)

View File

@@ -38,7 +38,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
self.store.store_device("user_id", "device_id", "display_name")
)
res = yield self.store.get_device("user_id", "device_id")
res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
self.assertDictContainsSubset(
{
"user_id": "user_id",
@@ -111,12 +111,12 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
self.store.store_device("user_id", "device_id", "display_name 1")
)
res = yield self.store.get_device("user_id", "device_id")
res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
self.assertEqual("display_name 1", res["display_name"])
# do a no-op first
yield defer.ensureDeferred(self.store.update_device("user_id", "device_id"))
res = yield self.store.get_device("user_id", "device_id")
res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
self.assertEqual("display_name 1", res["display_name"])
# do the update
@@ -127,7 +127,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
)
# check it worked
res = yield self.store.get_device("user_id", "device_id")
res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
self.assertEqual("display_name 2", res["display_name"])
@defer.inlineCallbacks

View File

@@ -35,21 +35,34 @@ class ProfileStoreTestCase(unittest.TestCase):
def test_displayname(self):
yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
yield self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
yield defer.ensureDeferred(
self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
)
self.assertEquals(
"Frank", (yield self.store.get_profile_displayname(self.u_frank.localpart))
"Frank",
(
yield defer.ensureDeferred(
self.store.get_profile_displayname(self.u_frank.localpart)
)
),
)
@defer.inlineCallbacks
def test_avatar_url(self):
yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))
yield self.store.set_profile_avatar_url(
self.u_frank.localpart, "http://my.site/here"
yield defer.ensureDeferred(
self.store.set_profile_avatar_url(
self.u_frank.localpart, "http://my.site/here"
)
)
self.assertEquals(
"http://my.site/here",
(yield self.store.get_profile_avatar_url(self.u_frank.localpart)),
(
yield defer.ensureDeferred(
self.store.get_profile_avatar_url(self.u_frank.localpart)
)
),
)

View File

@@ -53,7 +53,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
"user_type": None,
"deactivated": 0,
},
(yield self.store.get_user_by_id(self.user_id)),
(yield defer.ensureDeferred(self.store.get_user_by_id(self.user_id))),
)
@defer.inlineCallbacks

View File

@@ -54,12 +54,14 @@ class RoomStoreTestCase(unittest.TestCase):
"creator": self.u_creator.to_string(),
"is_public": True,
},
(yield self.store.get_room(self.room.to_string())),
(yield defer.ensureDeferred(self.store.get_room(self.room.to_string()))),
)
@defer.inlineCallbacks
def test_get_room_unknown_room(self):
self.assertIsNone((yield self.store.get_room("!uknown:test")),)
self.assertIsNone(
(yield defer.ensureDeferred(self.store.get_room("!uknown:test")))
)
@defer.inlineCallbacks
def test_get_room_with_stats(self):
@@ -69,12 +71,22 @@ class RoomStoreTestCase(unittest.TestCase):
"creator": self.u_creator.to_string(),
"public": True,
},
(yield self.store.get_room_with_stats(self.room.to_string())),
(
yield defer.ensureDeferred(
self.store.get_room_with_stats(self.room.to_string())
)
),
)
@defer.inlineCallbacks
def test_get_room_with_stats_unknown_room(self):
self.assertIsNone((yield self.store.get_room_with_stats("!uknown:test")),)
self.assertIsNone(
(
yield defer.ensureDeferred(
self.store.get_room_with_stats("!uknown:test")
)
),
)
class RoomEventsStoreTestCase(unittest.TestCase):

37
tests/test_utils/http.py Normal file
View File

@@ -0,0 +1,37 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
# Copyright 2020 The Matrix.org Foundation C.I.C
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.web.server import Request
def convert_request_args_to_form_data(request: Request) -> bytes:
"""Converts query arguments from a request to formatted HTML form data
Ref: https://developer.mozilla.org/en-US/docs/Learn/Forms/Sending_and_retrieving_form_data
Args:
The request to pull arguments from
Returns:
The HTML form body data representation of the request's arguments
"""
body = b""
for key, value in request.args.items():
arg = b"%s=%s&" % (key, value[0])
body += arg
# Remove the last '&' sign
return body[:-1]