Compare commits: anoa/user_... → anoa/halp (21 commits)

Commits: fd835d20a0, 4d8bf7d021, 1b4458ed26, 990178f58b, d6addba84e, 3f4e350cd9, 88b9807ba4, 2e6c90ff84, e3c91a3c55, 4c6c56dc58, 56efa9ec71, 1d79f7b22b, 0a2b11f361, b41b0512f9, 8cc8ee4448, 3568e1897c, 622946e881, 6140b32397, d9a19fc696, 5f7a834a50, 9003eb4bcd

UPGRADE.rst (24 lines changed)

@@ -75,6 +75,30 @@ for example:

    wget https://packages.matrix.org/debian/pool/main/m/matrix-synapse-py3/matrix-synapse-py3_1.3.0+stretch1_amd64.deb
    dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb

Upgrading to v1.20.0
====================

New HTML templates
------------------

A new HTML template,
`password_reset_confirmation.html <https://github.com/matrix-org/synapse/blob/develop/synapse/res/templates/password_reset_confirmation.html>`_,
has been added to the ``synapse/res/templates`` directory. If you are using a
custom template directory, you may want to copy the template over and modify it.

Note that as of v1.20.0, templates do not need to be included in custom template
directories for Synapse to start. The default templates will be used if a custom
template cannot be found.

This page will appear to the user after clicking a password reset link that has
been emailed to them.

To complete password reset, the page must include a way to make a `POST`
request to
``/_matrix/client/unstable/password_reset/{medium}/submit_token_confirm``
with the query parameters from the original link. See the file itself for more
details.
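
For illustration only (not part of the upgrade notes), a minimal sketch of the request such a page ends up making, using Python's ``requests``. The parameter names (``sid``, ``token``, ``client_secret`` and the ``medium`` path segment) match the hidden form fields in the bundled template; the homeserver URL and values are placeholders::

    import requests

    # sid/token/client_secret come from the link in the password reset email;
    # the bundled template forwards them back as hidden form fields.
    resp = requests.post(
        "https://homeserver.example.com/_matrix/client/unstable/"
        "password_reset/email/submit_token_confirm",
        data={"sid": "<sid>", "token": "<token>", "client_secret": "<client_secret>"},
    )
    print(resp.status_code)
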

Upgrading to v1.18.0
====================

@@ -1 +0,0 @@
Support `identifier` dictionary fields in User-Interactive Authentication flows. Relax requirement of the `user` parameter.

changelog.d/8004.feature (new file, 1 line)
@@ -0,0 +1 @@
Require the user to confirm that their password should be reset after clicking the email confirmation link.

changelog.d/8130.misc (new file, 1 line)
@@ -0,0 +1 @@
Update the test federation client to handle streaming responses.

changelog.d/8157.feature (new file, 1 line)
@@ -0,0 +1 @@
Add support for shadow-banning users (ignoring any message send requests).

changelog.d/8162.misc (new file, 1 line)
@@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

changelog.d/8167.misc (new file, 1 line)
@@ -0,0 +1 @@
Fix tests that were broken due to the merge of 1.19.1.

changelog.d/8171.misc (new file, 1 line)
@@ -0,0 +1 @@
Make `SlavedIdTracker.advance` have the same interface as `MultiWriterIDGenerator`.

@@ -2021,9 +2021,13 @@ email:
# * The contents of password reset emails sent by the homeserver:
# 'password_reset.html' and 'password_reset.txt'
#
# * HTML pages for success and failure that a user will see when they follow
# the link in the password reset email: 'password_reset_success.html' and
# 'password_reset_failure.html'
# * An HTML page that a user will see when they follow the link in the password
# reset email. The user will be asked to confirm the action before their
# password is reset: 'password_reset_confirmation.html'
#
# * HTML pages for success and failure that a user will see when they confirm
# the password reset flow using the page above: 'password_reset_success.html'
# and 'password_reset_failure.html'
#
# * The contents of address verification emails sent during registration:
# 'registration.html' and 'registration.txt'

@@ -21,10 +21,12 @@ import argparse
import base64
import json
import sys
from typing import Any, Optional
from urllib import parse as urlparse

import nacl.signing
import requests
import signedjson.types
import srvlookup
import yaml
from requests.adapters import HTTPAdapter

@@ -69,7 +71,9 @@ def encode_canonical_json(value):
    ).encode("UTF-8")


def sign_json(json_object, signing_key, signing_name):
def sign_json(
    json_object: Any, signing_key: signedjson.types.SigningKey, signing_name: str
) -> Any:
    signatures = json_object.pop("signatures", {})
    unsigned = json_object.pop("unsigned", None)

@@ -122,7 +126,14 @@ def read_signing_keys(stream):
    return keys


def request_json(method, origin_name, origin_key, destination, path, content):
def request(
    method: Optional[str],
    origin_name: str,
    origin_key: signedjson.types.SigningKey,
    destination: str,
    path: str,
    content: Optional[str],
) -> requests.Response:
    if method is None:
        if content is None:
            method = "GET"

@@ -159,11 +170,14 @@ def request_json(method, origin_name, origin_key, destination, path, content):
    if method == "POST":
        headers["Content-Type"] = "application/json"

    result = s.request(
        method=method, url=dest, headers=headers, verify=False, data=content
    return s.request(
        method=method,
        url=dest,
        headers=headers,
        verify=False,
        data=content,
        stream=True,
    )
    sys.stderr.write("Status Code: %d\n" % (result.status_code,))
    return result.json()


def main():

@@ -222,7 +236,7 @@ def main():
    with open(args.signing_key_path) as f:
        key = read_signing_keys(f)[0]

    result = request_json(
    result = request(
        args.method,
        args.server_name,
        key,

@@ -231,7 +245,12 @@ def main():
        content=args.body,
    )

    json.dump(result, sys.stdout)
    sys.stderr.write("Status Code: %d\n" % (result.status_code,))

    for chunk in result.iter_content():
        # we write raw utf8 to stdout.
        sys.stdout.buffer.write(chunk)

    print("")
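
The switch to `stream=True` above means the test federation client no longer parses the body as a single JSON document but copies it through as it arrives. A minimal sketch of that consumption pattern with `requests` (the URL is a placeholder):

import sys

import requests

resp = requests.get("https://example.com/some/streaming/endpoint", stream=True)
for chunk in resp.iter_content():
    # raw bytes straight through to stdout, as in the updated main() above
    sys.stdout.buffer.write(chunk)
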

@@ -198,6 +198,9 @@ class EmailConfig(Config):
            "add_threepid_template_text", "add_threepid.txt"
        )

        password_reset_template_confirmation_html = (
            "password_reset_confirmation.html"
        )
        password_reset_template_failure_html = email_config.get(
            "password_reset_template_failure_html", "password_reset_failure.html"
        )

@@ -228,6 +231,7 @@ class EmailConfig(Config):
                self.email_registration_template_text,
                self.email_add_threepid_template_html,
                self.email_add_threepid_template_text,
                self.email_password_reset_template_confirmation_html,
                self.email_password_reset_template_failure_html,
                self.email_registration_template_failure_html,
                self.email_add_threepid_template_failure_html,

@@ -242,6 +246,7 @@ class EmailConfig(Config):
                registration_template_text,
                add_threepid_template_html,
                add_threepid_template_text,
                password_reset_template_confirmation_html,
                password_reset_template_failure_html,
                registration_template_failure_html,
                add_threepid_template_failure_html,

@@ -404,9 +409,13 @@ class EmailConfig(Config):
# * The contents of password reset emails sent by the homeserver:
# 'password_reset.html' and 'password_reset.txt'
#
# * HTML pages for success and failure that a user will see when they follow
# the link in the password reset email: 'password_reset_success.html' and
# 'password_reset_failure.html'
# * An HTML page that a user will see when they follow the link in the password
# reset email. The user will be asked to confirm the action before their
# password is reset: 'password_reset_confirmation.html'
#
# * HTML pages for success and failure that a user will see when they confirm
# the password reset flow using the page above: 'password_reset_success.html'
# and 'password_reset_failure.html'
#
# * The contents of address verification emails sent during registration:
# 'registration.html' and 'registration.txt'

@@ -38,14 +38,12 @@ from synapse.api.ratelimiting import Ratelimiter
from synapse.handlers.ui_auth import INTERACTIVE_AUTH_CHECKERS
from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
from synapse.http.server import finish_request, respond_with_html
from synapse.http.servlet import assert_params_in_dict
from synapse.http.site import SynapseRequest
from synapse.logging.context import defer_to_thread
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.module_api import ModuleApi
from synapse.types import Requester, UserID
from synapse.util import stringutils as stringutils
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.threepids import canonicalise_email

from ._base import BaseHandler

@@ -53,82 +51,6 @@ from ._base import BaseHandler
logger = logging.getLogger(__name__)


def client_dict_convert_legacy_fields_to_identifier(
    submission: Dict[str, Union[str, Dict]]
):
    """
    Convert a legacy-formatted login submission to an identifier dict.

    Legacy login submissions (used in both login and user-interactive authentication)
    provide user-identifying information at the top-level instead of in an `indentifier`
    property. This is now deprecated and replaced with identifiers:
    https://matrix.org/docs/spec/client_server/r0.6.1#identifier-types

    Args:
        submission: The client dict to convert. Passed by reference and modified

    Raises:
        SynapseError: If the format of the client dict is invalid
    """
    if "user" in submission:
        submission["identifier"] = {"type": "m.id.user", "user": submission.pop("user")}

    if "medium" in submission and "address" in submission:
        submission["identifier"] = {
            "type": "m.id.thirdparty",
            "medium": submission.pop("medium"),
            "address": submission.pop("address"),
        }

    # We've converted valid, legacy login submissions to an identifier. If the
    # dict still doesn't have an identifier, it's invalid
    assert_params_in_dict(submission, required=["identifier"])

    # Ensure the identifier has a type
    if "type" not in submission["identifier"]:
        raise SynapseError(
            400, "'identifier' dict has no key 'type'", errcode=Codes.MISSING_PARAM,
        )


def login_id_phone_to_thirdparty(identifier: Dict[str, str]) -> Dict[str, str]:
    """Convert a phone login identifier type to a generic threepid identifier.

    Args:
        identifier: Login identifier dict of type 'm.id.phone'

    Returns:
        An equivalent m.id.thirdparty identifier dict.
    """
    if "type" not in identifier:
        raise SynapseError(
            400, "Invalid phone-type identifier", errcode=Codes.MISSING_PARAM
        )

    if "country" not in identifier or (
        # XXX: We used to require `number` instead of `phone`. The spec
        # defines `phone`. So accept both
        "phone" not in identifier
        and "number" not in identifier
    ):
        raise SynapseError(
            400, "Invalid phone-type identifier", errcode=Codes.INVALID_PARAM
        )

    # Accept both "phone" and "number" as valid keys in m.id.phone
    phone_number = identifier.get("phone", identifier.get("number"))

    # Convert user-provided phone number to a consistent representation
    msisdn = phone_number_to_msisdn(identifier["country"], phone_number)

    # Return the new dictionary
    return {
        "type": "m.id.thirdparty",
        "medium": "msisdn",
        "address": msisdn,
    }


class AuthHandler(BaseHandler):
    SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
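
For reference, an illustrative sketch of what the two conversion helpers in this hunk do to a legacy submission (all values are made up):

# A legacy password submission with a top-level "user" field...
submission = {"user": "alice", "password": "wonderland"}
client_dict_convert_legacy_fields_to_identifier(submission)
# ...is rewritten in place to use an identifier dict:
# {"identifier": {"type": "m.id.user", "user": "alice"}, "password": "wonderland"}

# A phone identifier is normalised to a generic threepid identifier:
login_id_phone_to_thirdparty({"type": "m.id.phone", "country": "GB", "phone": "07700 900123"})
# -> {"type": "m.id.thirdparty", "medium": "msisdn", "address": "447700900123"}
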

@@ -397,7 +319,7 @@ class AuthHandler(BaseHandler):
        # otherwise use whatever was last provided.
        #
        # This was designed to allow the client to omit the parameters
        # and just supply the session in subsequent calls. So it splits
        # and just supply the session in subsequent calls so it split
        # auth between devices by just sharing the session, (eg. so you
        # could continue registration from your phone having clicked the
        # email auth link on there). It's probably too open to abuse

@@ -602,129 +524,16 @@ class AuthHandler(BaseHandler):
            res = await checker.check_auth(authdict, clientip=clientip)
            return res

        # We don't have a checker for the auth type provided by the client
        # Assume that it is `m.login.password`.
        if login_type != LoginType.PASSWORD:
            raise SynapseError(
                400, "Unknown authentication type", errcode=Codes.INVALID_PARAM,
            )
        # build a v1-login-style dict out of the authdict and fall back to the
        # v1 code
        user_id = authdict.get("user")

        password = authdict.get("password")
        if password is None:
            raise SynapseError(
                400,
                "Missing parameter for m.login.password dict: 'password'",
                errcode=Codes.INVALID_PARAM,
            )

        # Retrieve the user ID using details provided in the authdict

        # Deprecation notice: Clients used to be able to simply provide a
        # `user` field which pointed to a user_id or localpart. This has
        # been deprecated in favour of an `identifier` key, which is a
        # dictionary providing information on how to identify a single
        # user.
        # https://matrix.org/docs/spec/client_server/r0.6.1#identifier-types
        #
        # We convert old-style dicts to new ones here
        client_dict_convert_legacy_fields_to_identifier(authdict)

        # Extract a user ID from the values in the identifier
        username = await self.username_from_identifier(authdict["identifier"], password)

        if username is None:
            raise SynapseError(400, "Valid username not found")

        # Now that we've found the username, validate that the password is correct
        canonical_id, _ = await self.validate_login(username, authdict)
        if user_id is None:
            raise SynapseError(400, "", Codes.MISSING_PARAM)

        (canonical_id, callback) = await self.validate_login(user_id, authdict)
        return canonical_id

    async def username_from_identifier(
        self, identifier: Dict[str, str], password: Optional[str] = None
    ) -> Optional[str]:
        """Given a dictionary containing an identifier from a client, extract the
        possibly unqualified username of the user that it identifies. Does *not*
        guarantee that the user exists.

        If this identifier dict contains a threepid, we attempt to ask password
        auth providers about it or, failing that, look up an associated user in
        the database.

        Args:
            identifier: The identifier dictionary provided by the client
            password: The user provided password if one exists. Used for asking
                password auth providers for usernames from 3pid+password combos.

        Returns:
            A username if one was found, or None otherwise

        Raises:
            SynapseError: If the identifier dict is invalid
        """

        # Convert phone type identifiers to generic threepid identifiers, which
        # will be handled in the next step
        if identifier["type"] == "m.id.phone":
            identifier = login_id_phone_to_thirdparty(identifier)

        # Convert a threepid identifier to an user identifier
        if identifier["type"] == "m.id.thirdparty":
            address = identifier.get("address")
            medium = identifier.get("medium")

            if not medium or not address:
                # An error would've already been raised in
                # `login_id_thirdparty_from_phone` if the original submission
                # was a phone identifier
                raise SynapseError(
                    400, "Invalid thirdparty identifier", errcode=Codes.INVALID_PARAM,
                )

            if medium == "email":
                # For emails, transform the address to lowercase.
                # We store all email addresses as lowercase in the DB.
                # (See add_threepid in synapse/handlers/auth.py)
                address = address.lower()

            # Check for auth providers that support 3pid login types
            if password is not None:
                canonical_user_id, _ = await self.check_password_provider_3pid(
                    medium, address, password,
                )
                if canonical_user_id:
                    # Authentication through password provider and 3pid succeeded
                    return canonical_user_id

            # Check local store
            user_id = await self.hs.get_datastore().get_user_id_by_threepid(
                medium, address
            )
            if not user_id:
                # We were unable to find a user_id that belonged to the threepid returned
                # by the password auth provider
                return None

            identifier = {"type": "m.id.user", "user": user_id}

        # By this point, the identifier should be a `m.id.user`: if it's anything
        # else, we haven't understood it.
        if identifier["type"] != "m.id.user":
            raise SynapseError(
                400, "Unknown login identifier type", errcode=Codes.INVALID_PARAM,
            )

        # User identifiers have a "user" key
        user = identifier.get("user")
        if user is None:
            raise SynapseError(
                400,
                "User identifier is missing 'user' key",
                errcode=Codes.INVALID_PARAM,
            )

        return user

    def _get_params_recaptcha(self) -> dict:
        return {"public_key": self.hs.config.recaptcha_public_key}
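
For context, the identifier dicts handled above take one of these shapes from the client-server spec (illustrative values only):

# m.id.user: a localpart or full user ID
user_identifier = {"type": "m.id.user", "user": "alice"}

# m.id.thirdparty: an email address or other 3PID
threepid_identifier = {"type": "m.id.thirdparty", "medium": "email", "address": "alice@example.com"}

# m.id.phone: converted above into an m.id.thirdparty msisdn identifier
phone_identifier = {"type": "m.id.phone", "country": "GB", "phone": "07700 900123"}
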

@@ -889,8 +698,7 @@ class AuthHandler(BaseHandler):
        m.login.password auth types.

        Args:
            username: a localpart or fully qualified user ID - what is provided by the
                client
            username: username supplied by the user
            login_submission: the whole of the login submission
                (including 'type' and other relevant fields)
        Returns:

@@ -902,10 +710,10 @@ class AuthHandler(BaseHandler):
            LoginError if there was an authentication problem.
        """

        # We need a fully qualified User ID for some method calls here
        qualified_user_id = username
        if not qualified_user_id.startswith("@"):
            qualified_user_id = UserID(qualified_user_id, self.hs.hostname).to_string()
        if username.startswith("@"):
            qualified_user_id = username
        else:
            qualified_user_id = UserID(username, self.hs.hostname).to_string()

        login_type = login_submission.get("type")
        known_login_type = False

@@ -14,6 +14,7 @@
# limitations under the License.

import logging
import random

from synapse.api.errors import (
    AuthError,

@@ -213,8 +214,14 @@ class BaseProfileHandler(BaseHandler):
    async def set_avatar_url(
        self, target_user, requester, new_avatar_url, by_admin=False
    ):
        """target_user is the user whose avatar_url is to be changed;
        auth_user is the user attempting to make this change."""
        """Set a new avatar URL for a user.

        Args:
            target_user (UserID): the user whose avatar URL is to be changed.
            requester (Requester): The user attempting to make this change.
            new_avatar_url (str): The avatar URL to give this user.
            by_admin (bool): Whether this change was made by an administrator.
        """
        if not self.hs.is_mine(target_user):
            raise SynapseError(400, "User is not hosted on this homeserver")

@@ -278,6 +285,12 @@ class BaseProfileHandler(BaseHandler):

        await self.ratelimit(requester)

        # Do not actually update the room state for shadow-banned users.
        if requester.shadow_banned:
            # We randomly sleep a bit just to annoy the requester.
            await self.clock.sleep(random.randint(1, 10))
            return

        room_ids = await self.store.get_rooms_for_user(target_user.to_string())

        for room_id in room_ids:
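
The shadow-ban guard added above follows a small, reusable pattern; a hedged sketch of it as a standalone helper (the helper name is hypothetical, not part of Synapse):

import random

async def drop_if_shadow_banned(requester, clock) -> bool:
    """Return True if the request should be silently dropped."""
    if requester.shadow_banned:
        # Sleep a random 1-10 seconds so the failure isn't obviously instant.
        await clock.sleep(random.randint(1, 10))
        return True
    return False

# Caller sketch:
#     if await drop_if_shadow_banned(requester, self.clock):
#         return
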

@@ -380,7 +380,7 @@ class RoomMemberHandler(object):
        # later on.
        content = dict(content)

        if not self.allow_per_room_profiles:
        if not self.allow_per_room_profiles or requester.shadow_banned:
            # Strip profile data, knowing that new profile data will be added to the
            # event's content in event_creation_handler.create_event() using the target's
            # global profile.

@@ -21,9 +21,9 @@ class SlavedIdTracker(object):
        self.step = step
        self._current = _load_current_id(db_conn, table, column, step)
        for table, column in extra_tables:
            self.advance(_load_current_id(db_conn, table, column))
            self.advance(None, _load_current_id(db_conn, table, column))

    def advance(self, new_id):
    def advance(self, instance_name, new_id):
        self._current = (max if self.step > 0 else min)(self._current, new_id)

    def get_current_token(self):
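
A brief sketch of why the extra argument appears: `MultiWriterIdGenerator.advance` takes `(instance_name, new_id)`, so the single-writer tracker accepts (and simply does not use) `instance_name` to keep the two interchangeable at call sites. Trimmed to the relevant method, with a hypothetical class name:

class SlavedIdTrackerSketch:
    def __init__(self, current: int, step: int = 1):
        self._current = current
        self.step = step

    def advance(self, instance_name, new_id):
        # instance_name is unused here; it exists for interface parity.
        self._current = (max if self.step > 0 else min)(self._current, new_id)
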

@@ -41,12 +41,12 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved

    def process_replication_rows(self, stream_name, instance_name, token, rows):
        if stream_name == TagAccountDataStream.NAME:
            self._account_data_id_gen.advance(token)
            self._account_data_id_gen.advance(instance_name, token)
            for row in rows:
                self.get_tags_for_user.invalidate((row.user_id,))
                self._account_data_stream_cache.entity_has_changed(row.user_id, token)
        elif stream_name == AccountDataStream.NAME:
            self._account_data_id_gen.advance(token)
            self._account_data_id_gen.advance(instance_name, token)
            for row in rows:
                if not row.room_id:
                    self.get_global_account_data_by_type_for_user.invalidate(

@@ -46,7 +46,7 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):

    def process_replication_rows(self, stream_name, instance_name, token, rows):
        if stream_name == ToDeviceStream.NAME:
            self._device_inbox_id_gen.advance(token)
            self._device_inbox_id_gen.advance(instance_name, token)
            for row in rows:
                if row.entity.startswith("@"):
                    self._device_inbox_stream_cache.entity_has_changed(

@@ -50,10 +50,10 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto

    def process_replication_rows(self, stream_name, instance_name, token, rows):
        if stream_name == DeviceListsStream.NAME:
            self._device_list_id_gen.advance(token)
            self._device_list_id_gen.advance(instance_name, token)
            self._invalidate_caches_for_devices(token, rows)
        elif stream_name == UserSignatureStream.NAME:
            self._device_list_id_gen.advance(token)
            self._device_list_id_gen.advance(instance_name, token)
            for row in rows:
                self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
        return super().process_replication_rows(stream_name, instance_name, token, rows)

@@ -40,7 +40,7 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):

    def process_replication_rows(self, stream_name, instance_name, token, rows):
        if stream_name == GroupServerStream.NAME:
            self._group_updates_id_gen.advance(token)
            self._group_updates_id_gen.advance(instance_name, token)
            for row in rows:
                self._group_updates_stream_cache.entity_has_changed(row.user_id, token)

@@ -44,7 +44,7 @@ class SlavedPresenceStore(BaseSlavedStore):

    def process_replication_rows(self, stream_name, instance_name, token, rows):
        if stream_name == PresenceStream.NAME:
            self._presence_id_gen.advance(token)
            self._presence_id_gen.advance(instance_name, token)
            for row in rows:
                self.presence_stream_cache.entity_has_changed(row.user_id, token)
                self._get_presence_for_user.invalidate((row.user_id,))

@@ -30,7 +30,7 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
        assert isinstance(self._push_rules_stream_id_gen, SlavedIdTracker)

        if stream_name == PushRulesStream.NAME:
            self._push_rules_stream_id_gen.advance(token)
            self._push_rules_stream_id_gen.advance(instance_name, token)
            for row in rows:
                self.get_push_rules_for_user.invalidate((row.user_id,))
                self.get_push_rules_enabled_for_user.invalidate((row.user_id,))

@@ -34,5 +34,5 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):

    def process_replication_rows(self, stream_name, instance_name, token, rows):
        if stream_name == PushersStream.NAME:
            self._pushers_id_gen.advance(token)
            self._pushers_id_gen.advance(instance_name, token)
        return super().process_replication_rows(stream_name, instance_name, token, rows)

@@ -46,7 +46,7 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):

    def process_replication_rows(self, stream_name, instance_name, token, rows):
        if stream_name == ReceiptsStream.NAME:
            self._receipts_id_gen.advance(token)
            self._receipts_id_gen.advance(instance_name, token)
            for row in rows:
                self.invalidate_caches_for_receipt(
                    row.room_id, row.receipt_type, row.user_id

@@ -33,6 +33,6 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore):

    def process_replication_rows(self, stream_name, instance_name, token, rows):
        if stream_name == PublicRoomsStream.NAME:
            self._public_room_id_gen.advance(token)
            self._public_room_id_gen.advance(instance_name, token)

        return super().process_replication_rows(stream_name, instance_name, token, rows)

synapse/res/templates/password_reset_confirmation.html (new file, 16 lines)
@@ -0,0 +1,16 @@
<html>
<head></head>
<body>
<!--Use a hidden form to resubmit the information necessary to reset the password-->
<form action="/_matrix/client/unstable/password_reset/{{ medium }}/submit_token_confirm" method="post">
    <input type="hidden" name="sid" value="{{ sid }}">
    <input type="hidden" name="token" value="{{ token }}">
    <input type="hidden" name="client_secret" value="{{ client_secret }}">

    <p>You have requested to <strong>reset your Matrix account password</strong>. Click the link below to confirm this action. <br /><br />
        If you did not mean to do this, please close this page and your password will not be changed.</p>
    <p><button type="submit">Confirm changing my password</button></p>
</form>
</body>
</html>
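
For illustration, the new template can be rendered standalone with Jinja2 to see the page a user gets; the template directory path and the values below are placeholders:

from jinja2 import Environment, FileSystemLoader

env = Environment(loader=FileSystemLoader("synapse/res/templates"))
template = env.get_template("password_reset_confirmation.html")
html = template.render(
    medium="email", sid="<sid>", token="<token>", client_secret="<client_secret>"
)
print(html)  # the confirmation page with a hidden form POSTing back the same values
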

@@ -18,7 +18,6 @@ from typing import Awaitable, Callable, Dict, Optional

from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.api.ratelimiting import Ratelimiter
from synapse.handlers.auth import client_dict_convert_legacy_fields_to_identifier
from synapse.http.server import finish_request
from synapse.http.servlet import (
    RestServlet,

@@ -29,11 +28,56 @@ from synapse.http.site import SynapseRequest
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.rest.well_known import WellKnownBuilder
from synapse.types import JsonDict, UserID
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.threepids import canonicalise_email

logger = logging.getLogger(__name__)


def login_submission_legacy_convert(submission):
    """
    If the input login submission is an old style object
    (ie. with top-level user / medium / address) convert it
    to a typed object.
    """
    if "user" in submission:
        submission["identifier"] = {"type": "m.id.user", "user": submission["user"]}
        del submission["user"]

    if "medium" in submission and "address" in submission:
        submission["identifier"] = {
            "type": "m.id.thirdparty",
            "medium": submission["medium"],
            "address": submission["address"],
        }
        del submission["medium"]
        del submission["address"]


def login_id_thirdparty_from_phone(identifier):
    """
    Convert a phone login identifier type to a generic threepid identifier
    Args:
        identifier(dict): Login identifier dict of type 'm.id.phone'

    Returns: Login identifier dict of type 'm.id.threepid'
    """
    if "country" not in identifier or (
        # The specification requires a "phone" field, while Synapse used to require a "number"
        # field. Accept both for backwards compatibility.
        "phone" not in identifier
        and "number" not in identifier
    ):
        raise SynapseError(400, "Invalid phone-type identifier")

    # Accept both "phone" and "number" as valid keys in m.id.phone
    phone_number = identifier.get("phone", identifier["number"])

    msisdn = phone_number_to_msisdn(identifier["country"], phone_number)

    return {"type": "m.id.thirdparty", "medium": "msisdn", "address": msisdn}


class LoginRestServlet(RestServlet):
    PATTERNS = client_patterns("/login$", v1=True)
    CAS_TYPE = "m.login.cas"

@@ -123,8 +167,7 @@ class LoginRestServlet(RestServlet):
            result = await self._do_token_login(login_submission)
        else:
            result = await self._do_other_login(login_submission)
        except KeyError as e:
            logger.debug("KeyError during login: %s", e)
        except KeyError:
            raise SynapseError(400, "Missing JSON keys.")

        well_known_data = self._well_known_builder.get_well_known()

@@ -151,14 +194,27 @@ class LoginRestServlet(RestServlet):
            login_submission.get("address"),
            login_submission.get("user"),
        )
        # Convert deprecated authdict formats to the current scheme
        client_dict_convert_legacy_fields_to_identifier(login_submission)
        login_submission_legacy_convert(login_submission)

        if "identifier" not in login_submission:
            raise SynapseError(400, "Missing param: identifier")

        identifier = login_submission["identifier"]
        if "type" not in identifier:
            raise SynapseError(400, "Login identifier has no type")

        # convert phone type identifiers to generic threepids
        if identifier["type"] == "m.id.phone":
            identifier = login_id_thirdparty_from_phone(identifier)

        # convert threepid identifiers to user IDs
        if identifier["type"] == "m.id.thirdparty":
            address = identifier.get("address")
            medium = identifier.get("medium")

            if medium is None or address is None:
                raise SynapseError(400, "Invalid thirdparty identifier")

        # Check whether this attempt uses a threepid, if so, check if our failed attempt
        # ratelimiter allows another attempt at this time
        medium = login_submission.get("medium")
        address = login_submission.get("address")
        if medium and address:
            # For emails, canonicalise the address.
            # We store all email addresses canonicalised in the DB.
            # (See add_threepid in synapse/handlers/auth.py)

@@ -168,41 +224,74 @@ class LoginRestServlet(RestServlet):
            except ValueError as e:
                raise SynapseError(400, str(e))

            # We also apply account rate limiting using the 3PID as a key, as
            # otherwise using 3PID bypasses the ratelimiting based on user ID.
            self._failed_attempts_ratelimiter.ratelimit((medium, address), update=False)

        # Extract a localpart or user ID from the values in the identifier
        username = await self.auth_handler.username_from_identifier(
            login_submission["identifier"], login_submission.get("password")
        )
            # Check for login providers that support 3pid login types
            (
                canonical_user_id,
                callback_3pid,
            ) = await self.auth_handler.check_password_provider_3pid(
                medium, address, login_submission["password"]
            )
            if canonical_user_id:
                # Authentication through password provider and 3pid succeeded

        if not username:
            if medium and address:
                # The user attempted to login via threepid and failed
                # Record this failed attempt using the threepid as a key, as otherwise
                # the user could bypass the ratelimiter by not providing a username
                self._failed_attempts_ratelimiter.can_do_action(
                    (medium, address.lower())
                result = await self._complete_login(
                    canonical_user_id, login_submission, callback_3pid
                )
                return result

            raise LoginError(403, "Unauthorized threepid", errcode=Codes.FORBIDDEN)
            # No password providers were able to handle this 3pid
            # Check local store
            user_id = await self.hs.get_datastore().get_user_id_by_threepid(
                medium, address
            )
            if not user_id:
                logger.warning(
                    "unknown 3pid identifier medium %s, address %r", medium, address
                )
                # We mark that we've failed to log in here, as
                # `check_password_provider_3pid` might have returned `None` due
                # to an incorrect password, rather than the account not
                # existing.
                #
                # If it returned None but the 3PID was bound then we won't hit
                # this code path, which is fine as then the per-user ratelimit
                # will kick in below.
                self._failed_attempts_ratelimiter.can_do_action((medium, address))
                raise LoginError(403, "", errcode=Codes.FORBIDDEN)

        # The login failed for another reason
        raise LoginError(403, "Invalid login", errcode=Codes.FORBIDDEN)
            identifier = {"type": "m.id.user", "user": user_id}

        # We were able to extract a username successfully
        # Check if we've hit the failed ratelimit for this user ID
        self._failed_attempts_ratelimiter.ratelimit(username.lower(), update=False)
        # by this point, the identifier should be an m.id.user: if it's anything
        # else, we haven't understood it.
        if identifier["type"] != "m.id.user":
            raise SynapseError(400, "Unknown login identifier type")
        if "user" not in identifier:
            raise SynapseError(400, "User identifier is missing 'user' key")

        if identifier["user"].startswith("@"):
            qualified_user_id = identifier["user"]
        else:
            qualified_user_id = UserID(identifier["user"], self.hs.hostname).to_string()

        # Check if we've hit the failed ratelimit (but don't update it)
        self._failed_attempts_ratelimiter.ratelimit(
            qualified_user_id.lower(), update=False
        )

        try:
            canonical_user_id, callback = await self.auth_handler.validate_login(
                username, login_submission
                identifier["user"], login_submission
            )
        except LoginError:
            # The user has failed to log in, so we need to update the rate
            # limiter. Using `can_do_action` avoids us raising a ratelimit
            # exception and masking the LoginError. This just records the attempt.
            # The actual rate-limiting happens above
            self._failed_attempts_ratelimiter.can_do_action(username.lower())
            # exception and masking the LoginError. The actual ratelimiting
            # should have happened above.
            self._failed_attempts_ratelimiter.can_do_action(qualified_user_id.lower())
            raise

        result = await self._complete_login(

@@ -220,7 +309,7 @@ class LoginRestServlet(RestServlet):
        create_non_existent_users: bool = False,
    ) -> Dict[str, str]:
        """Called when we've successfully authed the user and now need to
        actually log them in (e.g. create devices). This gets called on
        actually login them in (e.g. create devices). This gets called on
        all successful logins.

        Applies the ratelimiting for successful login attempts against an
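
For context, the two login submission formats this servlet accepts (values are placeholders); the second, legacy form is what `login_submission_legacy_convert` rewrites into the first:

# Current format: an identifier dict
login_submission = {
    "type": "m.login.password",
    "identifier": {"type": "m.id.user", "user": "alice"},
    "password": "wonderland",
}

# Deprecated format: top-level "user"
legacy_submission = {
    "type": "m.login.password",
    "user": "alice",
    "password": "wonderland",
}
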

@@ -17,7 +17,9 @@
import logging
import random
from http import HTTPStatus
from typing import TYPE_CHECKING

from twisted.web.server import Request

from synapse.api.constants import LoginType
from synapse.api.errors import (
    Codes,

@@ -38,6 +40,9 @@ from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.stringutils import assert_valid_client_secret, random_string
from synapse.util.threepids import canonicalise_email, check_3pid_allowed

if TYPE_CHECKING:
    from synapse.server import HomeServer

from ._base import client_patterns, interactive_auth_handler

logger = logging.getLogger(__name__)

@@ -157,14 +162,14 @@ class PasswordResetSubmitTokenServlet(RestServlet):
            hs (synapse.server.HomeServer): server
        """
        super(PasswordResetSubmitTokenServlet, self).__init__()
        self.hs = hs
        self.auth = hs.get_auth()
        self.config = hs.config
        self.clock = hs.get_clock()
        self.store = hs.get_datastore()
        if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
            self._failure_email_template = (
                self.config.email_password_reset_template_failure_html

        self._threepid_behaviour_email = hs.config.threepid_behaviour_email
        self._local_threepid_handling_disabled_due_to_email_config = (
            hs.config.local_threepid_handling_disabled_due_to_email_config
        )
        if self._threepid_behaviour_email == ThreepidBehaviour.LOCAL:
            self._confirmation_email_template = (
                hs.config.email_password_reset_template_confirmation_html
            )

    async def on_GET(self, request, medium):

@@ -173,20 +178,91 @@ class PasswordResetSubmitTokenServlet(RestServlet):
            raise SynapseError(
                400, "This medium is currently not supported for password resets"
            )
        if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
            if self.config.local_threepid_handling_disabled_due_to_email_config:
        if self._threepid_behaviour_email == ThreepidBehaviour.OFF:
            if self._local_threepid_handling_disabled_due_to_email_config:
                logger.warning(
                    "Password reset emails have been disabled due to lack of an email config"
                )
            raise SynapseError(
                400, "Email-based password resets are disabled on this server"
            )
        elif self._threepid_behaviour_email == ThreepidBehaviour.REMOTE:
            raise SynapseError(
                400,
                "Password resets for this homeserver are handled by a separate program",
            )

        sid = parse_string(request, "sid", required=True)
        token = parse_string(request, "token", required=True)
        client_secret = parse_string(request, "client_secret", required=True)
        assert_valid_client_secret(client_secret)

        # Show a confirmation page, just in case someone accidentally clicked this link when
        # they didn't mean to
        template_vars = {
            "sid": sid,
            "token": token,
            "client_secret": client_secret,
            "medium": medium,
        }
        respond_with_html(
            request, 200, self._confirmation_email_template.render(**template_vars)
        )


class PasswordResetConfirmationSubmitTokenServlet(RestServlet):
    """Handles confirmation of 3PID validation token submission.

    A user will land on PasswordResetSubmitTokenServlet, confirm the password reset, then
    submit the same parameters to this servlet.
    """

    PATTERNS = client_patterns(
        "/password_reset/email/submit_token_confirm$", releases=(), unstable=True,
    )

    def __init__(self, hs: "HomeServer"):
        """
        Args:
            hs: server
        """
        super().__init__()
        self.clock = hs.get_clock()
        self.store = hs.get_datastore()
        self._threepid_behaviour_email = hs.config.threepid_behaviour_email
        self._local_threepid_handling_disabled_due_to_email_config = (
            hs.config.local_threepid_handling_disabled_due_to_email_config
        )
        if self._threepid_behaviour_email == ThreepidBehaviour.LOCAL:
            self._email_password_reset_template_success_html = (
                hs.config.email_password_reset_template_success_html_content
            )
            self._failure_email_template = (
                hs.config.email_password_reset_template_failure_html
            )

    async def on_POST(self, request: Request):
        if self._threepid_behaviour_email == ThreepidBehaviour.OFF:
            if self._local_threepid_handling_disabled_due_to_email_config:
                logger.warning(
                    "Password reset emails have been disabled due to lack of an email config"
                )
            raise SynapseError(
                400, "Email-based password resets are disabled on this server"
            )
        elif self._threepid_behaviour_email == ThreepidBehaviour.REMOTE:
            raise SynapseError(
                400,
                "Password resets for this homeserver are handled by a separate program",
            )

        logger.info("ARGS: %s, CONTENT: %s, HEADERS: %s", request.args, request.content,
                    request.getAllHeaders())

        sid = parse_string(request, "sid", required=True)
        token = parse_string(request, "token", required=True)
        client_secret = parse_string(request, "client_secret", required=True)

        # Attempt to validate a 3PID session
        try:
            # Mark the session as valid

@@ -207,7 +283,7 @@ class PasswordResetSubmitTokenServlet(RestServlet):
                return None

            # Otherwise show the success template
            html = self.config.email_password_reset_template_success_html_content
            html = self._email_password_reset_template_success_html
            status_code = 200
        except ThreepidValidationError as e:
            status_code = e.code

@@ -891,6 +967,7 @@ class WhoamiRestServlet(RestServlet):
def register_servlets(hs, http_server):
    EmailPasswordRequestTokenRestServlet(hs).register(http_server)
    PasswordResetSubmitTokenServlet(hs).register(http_server)
    PasswordResetConfirmationSubmitTokenServlet(hs).register(http_server)
    PasswordRestServlet(hs).register(http_server)
    DeactivateAccountRestServlet(hs).register(http_server)
    EmailThreepidRequestTokenRestServlet(hs).register(http_server)

@@ -29,9 +29,11 @@ from typing import (
    Tuple,
    TypeVar,
    Union,
    overload,
)

from prometheus_client import Histogram
from typing_extensions import Literal

from twisted.enterprise import adbapi
from twisted.internet import defer

@@ -1020,14 +1022,36 @@ class DatabasePool(object):

        return txn.execute_batch(sql, args)

    def simple_select_one(
    @overload
    async def simple_select_one(
        self,
        table: str,
        keyvalues: Dict[str, Any],
        retcols: Iterable[str],
        allow_none: Literal[False] = False,
        desc: str = "simple_select_one",
    ) -> Dict[str, Any]:
        ...

    @overload
    async def simple_select_one(
        self,
        table: str,
        keyvalues: Dict[str, Any],
        retcols: Iterable[str],
        allow_none: Literal[True] = True,
        desc: str = "simple_select_one",
    ) -> Optional[Dict[str, Any]]:
        ...

    async def simple_select_one(
        self,
        table: str,
        keyvalues: Dict[str, Any],
        retcols: Iterable[str],
        allow_none: bool = False,
        desc: str = "simple_select_one",
    ) -> defer.Deferred:
    ) -> Optional[Dict[str, Any]]:
        """Executes a SELECT query on the named table, which is expected to
        return a single row, returning multiple columns from it.

@@ -1038,18 +1062,18 @@ class DatabasePool(object):
            allow_none: If true, return None instead of failing if the SELECT
                statement returns no rows
        """
        return self.runInteraction(
        return await self.runInteraction(
            desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none
        )

    def simple_select_one_onecol(
    async def simple_select_one_onecol(
        self,
        table: str,
        keyvalues: Dict[str, Any],
        retcol: Iterable[str],
        allow_none: bool = False,
        desc: str = "simple_select_one_onecol",
    ) -> defer.Deferred:
    ) -> Optional[Any]:
        """Executes a SELECT query on the named table, which is expected to
        return a single row, returning a single column from it.

@@ -1061,7 +1085,7 @@ class DatabasePool(object):
            statement returns no rows
            desc: description of the transaction, for logging and metrics
        """
        return self.runInteraction(
        return await self.runInteraction(
            desc,
            self.simple_select_one_onecol_txn,
            table,
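
The `@overload` pair above lets mypy narrow the return type on the literal value of `allow_none`; a self-contained sketch of the same pattern (the function and values are hypothetical):

from typing import Any, Dict, Optional, overload

from typing_extensions import Literal

_FAKE_TABLE = {"@alice:example.com": {"displayname": "Alice"}}


@overload
async def select_one(key: str, allow_none: Literal[False] = False) -> Dict[str, Any]:
    ...


@overload
async def select_one(key: str, allow_none: Literal[True] = True) -> Optional[Dict[str, Any]]:
    ...


async def select_one(key: str, allow_none: bool = False) -> Optional[Dict[str, Any]]:
    # Callers passing allow_none=True get Optional[...]; others get a plain dict.
    row = _FAKE_TABLE.get(key)
    if row is None and not allow_none:
        raise RuntimeError("expected exactly one row for %s" % key)
    return row
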

@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Dict, Iterable, List, Optional, Set, Tuple
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple

from synapse.api.errors import Codes, StoreError
from synapse.logging.opentracing import (

@@ -47,7 +47,7 @@ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"


class DeviceWorkerStore(SQLBaseStore):
    def get_device(self, user_id: str, device_id: str):
    async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]:
        """Retrieve a device. Only returns devices that are not marked as
        hidden.

@@ -55,11 +55,11 @@ class DeviceWorkerStore(SQLBaseStore):
            user_id: The ID of the user which owns the device
            device_id: The ID of the device to retrieve
        Returns:
            defer.Deferred for a dict containing the device information
            A dict containing the device information
        Raises:
            StoreError: if the device is not found
        """
        return self.db_pool.simple_select_one(
        return await self.db_pool.simple_select_one(
            table="devices",
            keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
            retcols=("user_id", "device_id", "display_name"),

@@ -656,11 +656,13 @@ class DeviceWorkerStore(SQLBaseStore):
        )

    @cached(max_entries=10000)
    def get_device_list_last_stream_id_for_remote(self, user_id: str):
    async def get_device_list_last_stream_id_for_remote(
        self, user_id: str
    ) -> Optional[Any]:
        """Get the last stream_id we got for a user. May be None if we haven't
        got any information for them.
        """
        return self.db_pool.simple_select_one_onecol(
        return await self.db_pool.simple_select_one_onecol(
            table="device_lists_remote_extremeties",
            keyvalues={"user_id": user_id},
            retcol="stream_id",
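
With these storage methods now `async`, callers await them directly instead of handling a Deferred; a minimal usage sketch (the wrapper function is hypothetical):

async def show_device(store, user_id: str, device_id: str) -> None:
    # store is a DeviceWorkerStore-like object; get_device now returns a dict.
    device = await store.get_device(user_id, device_id)
    print(device["display_name"])
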

@@ -59,8 +59,8 @@ class DirectoryWorkerStore(SQLBaseStore):

        return RoomAliasMapping(room_id, room_alias.to_string(), servers)

    def get_room_alias_creator(self, room_alias):
        return self.db_pool.simple_select_one_onecol(
    async def get_room_alias_creator(self, room_alias: str) -> str:
        return await self.db_pool.simple_select_one_onecol(
            table="room_aliases",
            keyvalues={"room_alias": room_alias},
            retcol="creator",

@@ -223,15 +223,15 @@ class EndToEndRoomKeyStore(SQLBaseStore):

        return ret

    def count_e2e_room_keys(self, user_id, version):
    async def count_e2e_room_keys(self, user_id: str, version: str) -> int:
        """Get the number of keys in a backup version.

        Args:
            user_id (str): the user whose backup we're querying
            version (str): the version ID of the backup we're querying about
            user_id: the user whose backup we're querying
            version: the version ID of the backup we're querying about
        """

        return self.db_pool.simple_select_one_onecol(
        return await self.db_pool.simple_select_one_onecol(
            table="e2e_room_keys",
            keyvalues={"user_id": user_id, "version": version},
            retcol="COUNT(*)",

@@ -113,25 +113,25 @@ class EventsWorkerStore(SQLBaseStore):

    def process_replication_rows(self, stream_name, instance_name, token, rows):
        if stream_name == EventsStream.NAME:
            self._stream_id_gen.advance(token)
            self._stream_id_gen.advance(instance_name, token)
        elif stream_name == BackfillStream.NAME:
            self._backfill_id_gen.advance(-token)
            self._backfill_id_gen.advance(instance_name, -token)

        super().process_replication_rows(stream_name, instance_name, token, rows)

    def get_received_ts(self, event_id):
    async def get_received_ts(self, event_id: str) -> Optional[int]:
        """Get received_ts (when it was persisted) for the event.

        Raises an exception for unknown events.

        Args:
            event_id (str)
            event_id: The event ID to query.

        Returns:
            Deferred[int|None]: Timestamp in milliseconds, or None for events
            that were persisted before received_ts was implemented.
            Timestamp in milliseconds, or None for events that were persisted
            before received_ts was implemented.
        """
        return self.db_pool.simple_select_one_onecol(
        return await self.db_pool.simple_select_one_onecol(
            table="events",
            keyvalues={"event_id": event_id},
            retcol="received_ts",

@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple

from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json

@@ -28,8 +28,8 @@ _DEFAULT_ROLE_ID = ""


class GroupServerWorkerStore(SQLBaseStore):
    def get_group(self, group_id):
        return self.db_pool.simple_select_one(
    async def get_group(self, group_id: str) -> Optional[Dict[str, Any]]:
        return await self.db_pool.simple_select_one(
            table="groups",
            keyvalues={"group_id": group_id},
            retcols=(

@@ -351,8 +351,10 @@ class GroupServerWorkerStore(SQLBaseStore):
        )
        return bool(result)

    def is_user_admin_in_group(self, group_id, user_id):
        return self.db_pool.simple_select_one_onecol(
    async def is_user_admin_in_group(
        self, group_id: str, user_id: str
    ) -> Optional[bool]:
        return await self.db_pool.simple_select_one_onecol(
            table="group_users",
            keyvalues={"group_id": group_id, "user_id": user_id},
            retcol="is_admin",

@@ -360,10 +362,12 @@ class GroupServerWorkerStore(SQLBaseStore):
            desc="is_user_admin_in_group",
        )

    def is_user_invited_to_local_group(self, group_id, user_id):
    async def is_user_invited_to_local_group(
        self, group_id: str, user_id: str
    ) -> Optional[bool]:
        """Has the group server invited a user?
        """
        return self.db_pool.simple_select_one_onecol(
        return await self.db_pool.simple_select_one_onecol(
            table="group_invites",
            keyvalues={"group_id": group_id, "user_id": user_id},
            retcol="user_id",

@@ -12,6 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional

from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool

@@ -37,12 +39,13 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
    def __init__(self, database: DatabasePool, db_conn, hs):
        super(MediaRepositoryStore, self).__init__(database, db_conn, hs)

    def get_local_media(self, media_id):
    async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]:
        """Get the metadata for a local piece of media

        Returns:
            None if the media_id doesn't exist.
        """
        return self.db_pool.simple_select_one(
        return await self.db_pool.simple_select_one(
            "local_media_repository",
            {"media_id": media_id},
            (

@@ -191,8 +194,10 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
            desc="store_local_thumbnail",
        )

    def get_cached_remote_media(self, origin, media_id):
        return self.db_pool.simple_select_one(
    async def get_cached_remote_media(
        self, origin, media_id: str
    ) -> Optional[Dict[str, Any]]:
        return await self.db_pool.simple_select_one(
            "remote_media_cache",
            {"media_origin": origin, "media_id": media_id},
            (

@@ -99,17 +99,18 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
        return users

    @cached(num_args=1)
    def user_last_seen_monthly_active(self, user_id):
    async def user_last_seen_monthly_active(self, user_id: str) -> int:
        """
        Checks if a given user is part of the monthly active user group
        Arguments:
            user_id (str): user to add/update
        Return:
            Deferred[int] : timestamp since last seen, None if never seen
        Checks if a given user is part of the monthly active user group

        Arguments:
            user_id: user to add/update

        Return:
            Timestamp since last seen, None if never seen
        """

        return self.db_pool.simple_select_one_onecol(
        return await self.db_pool.simple_select_one_onecol(
            table="monthly_active_users",
            keyvalues={"user_id": user_id},
            retcol="timestamp",

@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional

from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore

@@ -19,7 +20,7 @@ from synapse.storage.databases.main.roommember import ProfileInfo


class ProfileWorkerStore(SQLBaseStore):
    async def get_profileinfo(self, user_localpart):
    async def get_profileinfo(self, user_localpart: str) -> ProfileInfo:
        try:
            profile = await self.db_pool.simple_select_one(
                table="profiles",

@@ -38,24 +39,26 @@ class ProfileWorkerStore(SQLBaseStore):
            avatar_url=profile["avatar_url"], display_name=profile["displayname"]
        )

    def get_profile_displayname(self, user_localpart):
        return self.db_pool.simple_select_one_onecol(
    async def get_profile_displayname(self, user_localpart: str) -> str:
        return await self.db_pool.simple_select_one_onecol(
            table="profiles",
            keyvalues={"user_id": user_localpart},
            retcol="displayname",
            desc="get_profile_displayname",
        )

    def get_profile_avatar_url(self, user_localpart):
        return self.db_pool.simple_select_one_onecol(
    async def get_profile_avatar_url(self, user_localpart: str) -> str:
        return await self.db_pool.simple_select_one_onecol(
            table="profiles",
            keyvalues={"user_id": user_localpart},
            retcol="avatar_url",
            desc="get_profile_avatar_url",
        )

    def get_from_remote_profile_cache(self, user_id):
        return self.db_pool.simple_select_one(
    async def get_from_remote_profile_cache(
        self, user_id: str
    ) -> Optional[Dict[str, Any]]:
        return await self.db_pool.simple_select_one(
            table="remote_profile_cache",
            keyvalues={"user_id": user_id},
            retcols=("displayname", "avatar_url"),
@@ -71,8 +71,10 @@ class ReceiptsWorkerStore(SQLBaseStore):
)

@cached(num_args=3)
def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type):
return self.db_pool.simple_select_one_onecol(
async def get_last_receipt_event_id_for_user(
self, user_id: str, room_id: str, receipt_type: str
) -> Optional[str]:
return await self.db_pool.simple_select_one_onecol(
table="receipts_linearized",
keyvalues={
"room_id": room_id,
@@ -17,7 +17,7 @@

import logging
import re
from typing import Awaitable, Dict, List, Optional
from typing import Any, Awaitable, Dict, List, Optional

from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
@@ -46,8 +46,8 @@ class RegistrationWorkerStore(SQLBaseStore):
)

@cached()
def get_user_by_id(self, user_id):
return self.db_pool.simple_select_one(
async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
return await self.db_pool.simple_select_one(
table="users",
keyvalues={"name": user_id},
retcols=[
@@ -1259,12 +1259,12 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
desc="del_user_pending_deactivation",
)

def get_user_pending_deactivation(self):
async def get_user_pending_deactivation(self) -> Optional[str]:
"""
Gets one user from the table of users waiting to be parted from all the rooms
they're in.
"""
return self.db_pool.simple_select_one_onecol(
return await self.db_pool.simple_select_one_onecol(
"users_pending_deactivation",
keyvalues={},
retcol="user_id",
@@ -14,6 +14,7 @@
# limitations under the License.

import logging
from typing import Optional

from synapse.storage._base import SQLBaseStore

@@ -21,8 +22,8 @@ logger = logging.getLogger(__name__)


class RejectionsStore(SQLBaseStore):
def get_rejection_reason(self, event_id):
return self.db_pool.simple_select_one_onecol(
async def get_rejection_reason(self, event_id: str) -> Optional[str]:
return await self.db_pool.simple_select_one_onecol(
table="rejections",
retcol="reason",
keyvalues={"event_id": event_id},
@@ -73,15 +73,15 @@ class RoomWorkerStore(SQLBaseStore):

self.config = hs.config

def get_room(self, room_id):
async def get_room(self, room_id: str) -> dict:
"""Retrieve a room.

Args:
room_id (str): The ID of the room to retrieve.
room_id: The ID of the room to retrieve.
Returns:
A dict containing the room information, or None if the room is unknown.
"""
return self.db_pool.simple_select_one(
return await self.db_pool.simple_select_one(
table="rooms",
keyvalues={"room_id": room_id},
retcols=("room_id", "is_public", "creator"),
@@ -330,8 +330,8 @@ class RoomWorkerStore(SQLBaseStore):
return ret_val

@cached(max_entries=10000)
def is_room_blocked(self, room_id):
return self.db_pool.simple_select_one_onecol(
async def is_room_blocked(self, room_id: str) -> Optional[bool]:
return await self.db_pool.simple_select_one_onecol(
table="blocked_rooms",
keyvalues={"room_id": room_id},
retcol="1",
@@ -260,8 +260,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return event.content.get("canonical_alias")

@cached(max_entries=50000)
def _get_state_group_for_event(self, event_id):
return self.db_pool.simple_select_one_onecol(
async def _get_state_group_for_event(self, event_id: str) -> Optional[int]:
return await self.db_pool.simple_select_one_onecol(
table="event_to_state_groups",
keyvalues={"event_id": event_id},
retcol="state_group",
@@ -211,11 +211,11 @@ class StatsStore(StateDeltasStore):

return len(rooms_to_work_on)

def get_stats_positions(self):
async def get_stats_positions(self) -> int:
"""
Returns the stats processor positions.
"""
return self.db_pool.simple_select_one_onecol(
return await self.db_pool.simple_select_one_onecol(
table="stats_incremental_position",
keyvalues={},
retcol="stream_id",
@@ -300,7 +300,7 @@ class StatsStore(StateDeltasStore):
return slice_list

@cached()
def get_earliest_token_for_stats(self, stats_type, id):
async def get_earliest_token_for_stats(self, stats_type: str, id: str) -> int:
"""
Fetch the "earliest token". This is used by the room stats delta
processor to ignore deltas that have been processed between the
@@ -308,11 +308,11 @@ class StatsStore(StateDeltasStore):
being calculated.

Returns:
Deferred[int]
The earliest token.
"""
table, id_col = TYPE_TO_TABLE[stats_type]

return self.db_pool.simple_select_one_onecol(
return await self.db_pool.simple_select_one_onecol(
"%s_current" % (table,),
keyvalues={id_col: id},
retcol="completed_delta_stream_id",
@@ -15,6 +15,7 @@

import logging
import re
from typing import Any, Dict, Optional

from synapse.api.constants import EventTypes, JoinRules
from synapse.storage.database import DatabasePool
@@ -527,8 +528,8 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
)

@cached()
def get_user_in_directory(self, user_id):
return self.db_pool.simple_select_one(
async def get_user_in_directory(self, user_id: str) -> Optional[Dict[str, Any]]:
return await self.db_pool.simple_select_one(
table="user_directory",
keyvalues={"user_id": user_id},
retcols=("display_name", "avatar_url"),
@@ -663,8 +664,8 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
users.update(rows)
return list(users)

def get_user_directory_stream_pos(self):
return self.db_pool.simple_select_one_onecol(
async def get_user_directory_stream_pos(self) -> int:
return await self.db_pool.simple_select_one_onecol(
table="user_directory_stream_pos",
keyvalues={},
retcol="stream_id",
@@ -71,7 +71,9 @@ class ProfileTestCase(unittest.TestCase):

@defer.inlineCallbacks
def test_get_my_name(self):
yield self.store.set_profile_displayname(self.frank.localpart, "Frank")
yield defer.ensureDeferred(
self.store.set_profile_displayname(self.frank.localpart, "Frank")
)

displayname = yield defer.ensureDeferred(
self.handler.get_displayname(self.frank)
@@ -104,7 +106,12 @@ class ProfileTestCase(unittest.TestCase):
)

self.assertEquals(
(yield self.store.get_profile_displayname(self.frank.localpart)), "Frank",
(
yield defer.ensureDeferred(
self.store.get_profile_displayname(self.frank.localpart)
)
),
"Frank",
)

@defer.inlineCallbacks
@@ -112,10 +119,17 @@ class ProfileTestCase(unittest.TestCase):
self.hs.config.enable_set_displayname = False

# Setting displayname for the first time is allowed
yield self.store.set_profile_displayname(self.frank.localpart, "Frank")
yield defer.ensureDeferred(
self.store.set_profile_displayname(self.frank.localpart, "Frank")
)

self.assertEquals(
(yield self.store.get_profile_displayname(self.frank.localpart)), "Frank",
(
yield defer.ensureDeferred(
self.store.get_profile_displayname(self.frank.localpart)
)
),
"Frank",
)

# Setting displayname a second time is forbidden
@@ -158,7 +172,9 @@ class ProfileTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_incoming_fed_query(self):
yield defer.ensureDeferred(self.store.create_profile("caroline"))
yield self.store.set_profile_displayname("caroline", "Caroline")
yield defer.ensureDeferred(
self.store.set_profile_displayname("caroline", "Caroline")
)

response = yield defer.ensureDeferred(
self.query_handlers["profile"](
@@ -170,8 +186,10 @@ class ProfileTestCase(unittest.TestCase):

@defer.inlineCallbacks
def test_get_my_avatar(self):
yield self.store.set_profile_avatar_url(
self.frank.localpart, "http://my.server/me.png"
yield defer.ensureDeferred(
self.store.set_profile_avatar_url(
self.frank.localpart, "http://my.server/me.png"
)
)
avatar_url = yield defer.ensureDeferred(self.handler.get_avatar_url(self.frank))

@@ -188,7 +206,11 @@ class ProfileTestCase(unittest.TestCase):
)

self.assertEquals(
(yield self.store.get_profile_avatar_url(self.frank.localpart)),
(
yield defer.ensureDeferred(
self.store.get_profile_avatar_url(self.frank.localpart)
)
),
"http://my.server/pic.gif",
)

@@ -202,7 +224,11 @@ class ProfileTestCase(unittest.TestCase):
)

self.assertEquals(
(yield self.store.get_profile_avatar_url(self.frank.localpart)),
(
yield defer.ensureDeferred(
self.store.get_profile_avatar_url(self.frank.localpart)
)
),
"http://my.server/me.png",
)

@@ -211,12 +237,18 @@ class ProfileTestCase(unittest.TestCase):
self.hs.config.enable_set_avatar_url = False

# Setting displayname for the first time is allowed
yield self.store.set_profile_avatar_url(
self.frank.localpart, "http://my.server/me.png"
yield defer.ensureDeferred(
self.store.set_profile_avatar_url(
self.frank.localpart, "http://my.server/me.png"
)
)

self.assertEquals(
(yield self.store.get_profile_avatar_url(self.frank.localpart)),
(
yield defer.ensureDeferred(
self.store.get_profile_avatar_url(self.frank.localpart)
)
),
"http://my.server/me.png",
)
@@ -144,9 +144,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):

self.datastore.get_users_in_room = get_users_in_room

self.datastore.get_user_directory_stream_pos.return_value = (
self.datastore.get_user_directory_stream_pos.side_effect = (
# we deliberately return a non-None stream pos to avoid doing an initial_spam
defer.succeed(1)
lambda: make_awaitable(1)
)

self.datastore.get_current_state_deltas.return_value = (0, None)
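The mock in this hunk has to hand back a fresh awaitable on every call, which is why `return_value = defer.succeed(1)` becomes `side_effect = lambda: make_awaitable(1)`. A self-contained sketch of the same pattern using plain `asyncio` and `unittest.mock` (the `make_awaitable` below is a stand-in for the repo's test helper, not its actual implementation):

```python
import asyncio
from unittest.mock import Mock


async def make_awaitable(value):
    # The simplest stand-in: a coroutine that immediately returns the value.
    return value


datastore = Mock()
# side_effect runs on every call, so each call hands back a fresh awaitable;
# a single coroutine (or Deferred) stored in return_value could only be
# consumed by the first caller.
datastore.get_user_directory_stream_pos.side_effect = lambda: make_awaitable(1)


async def main():
    assert await datastore.get_user_directory_stream_pos() == 1
    assert await datastore.get_user_directory_stream_pos() == 1


asyncio.run(main())
```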
@@ -35,7 +35,7 @@ class ModuleApiTestCase(HomeserverTestCase):
# Check that the new user exists with all provided attributes
self.assertEqual(user_id, "@bob:test")
self.assertTrue(access_token)
self.assertTrue(self.store.get_user_by_id(user_id))
self.assertTrue(self.get_success(self.store.get_user_by_id(user_id)))

# Check that the email was assigned
emails = self.get_success(self.store.user_get_threepids(user_id))
tests/rest/client/test_shadow_banned.py (new file, 272 lines)
@@ -0,0 +1,272 @@
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from mock import Mock, patch

import synapse.rest.admin
from synapse.api.constants import EventTypes
from synapse.rest.client.v1 import directory, login, profile, room
from synapse.rest.client.v2_alpha import room_upgrade_rest_servlet

from tests import unittest


class _ShadowBannedBase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, homeserver):
# Create two users, one of which is shadow-banned.
self.banned_user_id = self.register_user("banned", "test")
self.banned_access_token = self.login("banned", "test")

self.store = self.hs.get_datastore()

self.get_success(
self.store.db_pool.simple_update(
table="users",
keyvalues={"name": self.banned_user_id},
updatevalues={"shadow_banned": True},
desc="shadow_ban",
)
)

self.other_user_id = self.register_user("otheruser", "pass")
self.other_access_token = self.login("otheruser", "pass")


# To avoid the tests timing out don't add a delay to "annoy the requester".
@patch("random.randint", new=lambda a, b: 0)
class RoomTestCase(_ShadowBannedBase):
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
directory.register_servlets,
login.register_servlets,
room.register_servlets,
room_upgrade_rest_servlet.register_servlets,
]

def test_invite(self):
"""Invites from shadow-banned users don't actually get sent."""

# The create works fine.
room_id = self.helper.create_room_as(
self.banned_user_id, tok=self.banned_access_token
)

# Inviting the user completes successfully.
self.helper.invite(
room=room_id,
src=self.banned_user_id,
tok=self.banned_access_token,
targ=self.other_user_id,
)

# But the user wasn't actually invited.
invited_rooms = self.get_success(
self.store.get_invited_rooms_for_local_user(self.other_user_id)
)
self.assertEqual(invited_rooms, [])

def test_invite_3pid(self):
"""Ensure that a 3PID invite does not attempt to contact the identity server."""
identity_handler = self.hs.get_handlers().identity_handler
identity_handler.lookup_3pid = Mock(
side_effect=AssertionError("This should not get called")
)

# The create works fine.
room_id = self.helper.create_room_as(
self.banned_user_id, tok=self.banned_access_token
)

# Inviting the user completes successfully.
request, channel = self.make_request(
"POST",
"/rooms/%s/invite" % (room_id,),
{"id_server": "test", "medium": "email", "address": "test@test.test"},
access_token=self.banned_access_token,
)
self.render(request)
self.assertEquals(200, channel.code, channel.result)

# This should have raised an error earlier, but double check this wasn't called.
identity_handler.lookup_3pid.assert_not_called()

def test_create_room(self):
"""Invitations during a room creation should be discarded, but the room still gets created."""
# The room creation is successful.
request, channel = self.make_request(
"POST",
"/_matrix/client/r0/createRoom",
{"visibility": "public", "invite": [self.other_user_id]},
access_token=self.banned_access_token,
)
self.render(request)
self.assertEquals(200, channel.code, channel.result)
room_id = channel.json_body["room_id"]

# But the user wasn't actually invited.
invited_rooms = self.get_success(
self.store.get_invited_rooms_for_local_user(self.other_user_id)
)
self.assertEqual(invited_rooms, [])

# Since a real room was created, the other user should be able to join it.
self.helper.join(room_id, self.other_user_id, tok=self.other_access_token)

# Both users should be in the room.
users = self.get_success(self.store.get_users_in_room(room_id))
self.assertCountEqual(users, ["@banned:test", "@otheruser:test"])

def test_message(self):
"""Messages from shadow-banned users don't actually get sent."""

room_id = self.helper.create_room_as(
self.other_user_id, tok=self.other_access_token
)

# The user should be in the room.
self.helper.join(room_id, self.banned_user_id, tok=self.banned_access_token)

# Sending a message should complete successfully.
result = self.helper.send_event(
room_id=room_id,
type=EventTypes.Message,
content={"msgtype": "m.text", "body": "with right label"},
tok=self.banned_access_token,
)
self.assertIn("event_id", result)
event_id = result["event_id"]

latest_events = self.get_success(
self.store.get_latest_event_ids_in_room(room_id)
)
self.assertNotIn(event_id, latest_events)

def test_upgrade(self):
"""A room upgrade should fail, but look like it succeeded."""

# The create works fine.
room_id = self.helper.create_room_as(
self.banned_user_id, tok=self.banned_access_token
)

request, channel = self.make_request(
"POST",
"/_matrix/client/r0/rooms/%s/upgrade" % (room_id,),
{"new_version": "6"},
access_token=self.banned_access_token,
)
self.render(request)
self.assertEquals(200, channel.code, channel.result)
# A new room_id should be returned.
self.assertIn("replacement_room", channel.json_body)

new_room_id = channel.json_body["replacement_room"]

# It doesn't really matter what API we use here, we just want to assert
# that the room doesn't exist.
summary = self.get_success(self.store.get_room_summary(new_room_id))
# The summary should be empty since the room doesn't exist.
self.assertEqual(summary, {})


# To avoid the tests timing out don't add a delay to "annoy the requester".
@patch("random.randint", new=lambda a, b: 0)
class ProfileTestCase(_ShadowBannedBase):
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
login.register_servlets,
profile.register_servlets,
room.register_servlets,
]

def test_displayname(self):
"""Profile changes should succeed, but don't end up in a room."""
original_display_name = "banned"
new_display_name = "new name"

# Join a room.
room_id = self.helper.create_room_as(
self.banned_user_id, tok=self.banned_access_token
)

# The update should succeed.
request, channel = self.make_request(
"PUT",
"/_matrix/client/r0/profile/%s/displayname" % (self.banned_user_id,),
{"displayname": new_display_name},
access_token=self.banned_access_token,
)
self.render(request)
self.assertEquals(200, channel.code, channel.result)
self.assertEqual(channel.json_body, {})

# The user's display name should be updated.
request, channel = self.make_request(
"GET", "/profile/%s/displayname" % (self.banned_user_id,)
)
self.render(request)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.json_body["displayname"], new_display_name)

# But the display name in the room should not be.
message_handler = self.hs.get_message_handler()
event = self.get_success(
message_handler.get_room_data(
self.banned_user_id,
room_id,
"m.room.member",
self.banned_user_id,
False,
)
)
self.assertEqual(
event.content, {"membership": "join", "displayname": original_display_name}
)

def test_room_displayname(self):
"""Changes to state events for a room should be processed, but not end up in the room."""
original_display_name = "banned"
new_display_name = "new name"

# Join a room.
room_id = self.helper.create_room_as(
self.banned_user_id, tok=self.banned_access_token
)

# The update should succeed.
request, channel = self.make_request(
"PUT",
"/_matrix/client/r0/rooms/%s/state/m.room.member/%s"
% (room_id, self.banned_user_id),
{"membership": "join", "displayname": new_display_name},
access_token=self.banned_access_token,
)
self.render(request)
self.assertEquals(200, channel.code, channel.result)
self.assertIn("event_id", channel.json_body)

# The display name in the room should not be changed.
message_handler = self.hs.get_message_handler()
event = self.get_success(
message_handler.get_room_data(
self.banned_user_id,
room_id,
"m.room.member",
self.banned_user_id,
False,
)
)
self.assertEqual(
event.content, {"membership": "join", "displayname": original_display_name}
)
@@ -267,7 +267,9 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):

auth = {
"type": "m.login.password",
"identifier": {"type": "m.id.user", "user": user_id},
# https://github.com/matrix-org/synapse/issues/5665
# "identifier": {"type": "m.id.user", "user": user_id},
"user": user_id,
"password": password,
"session": channel.json_body["session"],
}
@@ -21,13 +21,13 @@
import json
from urllib import parse as urlparse

from mock import Mock, patch
from mock import Mock

import synapse.rest.admin
from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.handlers.pagination import PurgeStatus
from synapse.rest.client.v1 import directory, login, profile, room
from synapse.rest.client.v2_alpha import account, room_upgrade_rest_servlet
from synapse.rest.client.v2_alpha import account
from synapse.types import JsonDict, RoomAlias, UserID
from synapse.util.stringutils import random_string

@@ -684,38 +684,39 @@ class RoomJoinRatelimitTestCase(RoomBase):
]

@unittest.override_config(
{"rc_joins": {"local": {"per_second": 3, "burst_count": 3}}}
{"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
)
def test_join_local_ratelimit(self):
"""Tests that local joins are actually rate-limited."""
for i in range(5):
for i in range(3):
self.helper.create_room_as(self.user_id)

self.helper.create_room_as(self.user_id, expect_code=429)

@unittest.override_config(
{"rc_joins": {"local": {"per_second": 3, "burst_count": 3}}}
{"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
)
def test_join_local_ratelimit_profile_change(self):
"""Tests that sending a profile update into all of the user's joined rooms isn't
rate-limited by the rate-limiter on joins."""

# Create and join more rooms than the rate-limiting config allows in a second.
# Create and join as many rooms as the rate-limiting config allows in a second.
room_ids = [
self.helper.create_room_as(self.user_id),
self.helper.create_room_as(self.user_id),
self.helper.create_room_as(self.user_id),
]
self.reactor.advance(1)
room_ids = room_ids + [
self.helper.create_room_as(self.user_id),
self.helper.create_room_as(self.user_id),
self.helper.create_room_as(self.user_id),
]
# Let some time for the rate-limiter to forget about our multi-join.
self.reactor.advance(2)
# Add one to make sure we're joined to more rooms than the config allows us to
# join in a second.
room_ids.append(self.helper.create_room_as(self.user_id))

# Create a profile for the user, since it hasn't been done on registration.
store = self.hs.get_datastore()
store.create_profile(UserID.from_string(self.user_id).localpart)
self.get_success(
store.create_profile(UserID.from_string(self.user_id).localpart)
)

# Update the display name for the user.
path = "/_matrix/client/r0/profile/%s/displayname" % self.user_id
@@ -738,7 +739,7 @@ class RoomJoinRatelimitTestCase(RoomBase):
self.assertEquals(channel.json_body["displayname"], "John Doe")

@unittest.override_config(
{"rc_joins": {"local": {"per_second": 3, "burst_count": 3}}}
{"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}}
)
def test_join_local_ratelimit_idempotent(self):
"""Tests that the room join endpoints remain idempotent despite rate-limiting
@@ -754,7 +755,7 @@ class RoomJoinRatelimitTestCase(RoomBase):
for path in paths_to_test:
# Make sure we send more requests than the rate-limiting config would allow
# if all of these requests ended up joining the user to a room.
for i in range(6):
for i in range(4):
request, channel = self.make_request("POST", path % room_id, {})
self.render(request)
self.assertEquals(channel.code, 200)
@@ -2059,158 +2060,3 @@ class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
"""An alias which does not point to the room raises a SynapseError."""
self._set_canonical_alias({"alias": "@unknown:test"}, expected_code=400)
self._set_canonical_alias({"alt_aliases": ["@unknown:test"]}, expected_code=400)


# To avoid the tests timing out don't add a delay to "annoy the requester".
@patch("random.randint", new=lambda a, b: 0)
class ShadowBannedTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets_for_client_rest_resource,
directory.register_servlets,
login.register_servlets,
room.register_servlets,
room_upgrade_rest_servlet.register_servlets,
]

def prepare(self, reactor, clock, homeserver):
self.banned_user_id = self.register_user("banned", "test")
self.banned_access_token = self.login("banned", "test")

self.store = self.hs.get_datastore()

self.get_success(
self.store.db_pool.simple_update(
table="users",
keyvalues={"name": self.banned_user_id},
updatevalues={"shadow_banned": True},
desc="shadow_ban",
)
)

self.other_user_id = self.register_user("otheruser", "pass")
self.other_access_token = self.login("otheruser", "pass")

def test_invite(self):
"""Invites from shadow-banned users don't actually get sent."""

# The create works fine.
room_id = self.helper.create_room_as(
self.banned_user_id, tok=self.banned_access_token
)

# Inviting the user completes successfully.
self.helper.invite(
room=room_id,
src=self.banned_user_id,
tok=self.banned_access_token,
targ=self.other_user_id,
)

# But the user wasn't actually invited.
invited_rooms = self.get_success(
self.store.get_invited_rooms_for_local_user(self.other_user_id)
)
self.assertEqual(invited_rooms, [])

def test_invite_3pid(self):
"""Ensure that a 3PID invite does not attempt to contact the identity server."""
identity_handler = self.hs.get_handlers().identity_handler
identity_handler.lookup_3pid = Mock(
side_effect=AssertionError("This should not get called")
)

# The create works fine.
room_id = self.helper.create_room_as(
self.banned_user_id, tok=self.banned_access_token
)

# Inviting the user completes successfully.
request, channel = self.make_request(
"POST",
"/rooms/%s/invite" % (room_id,),
{"id_server": "test", "medium": "email", "address": "test@test.test"},
access_token=self.banned_access_token,
)
self.render(request)
self.assertEquals(200, channel.code, channel.result)

# This should have raised an error earlier, but double check this wasn't called.
identity_handler.lookup_3pid.assert_not_called()

def test_create_room(self):
"""Invitations during a room creation should be discarded, but the room still gets created."""
# The room creation is successful.
request, channel = self.make_request(
"POST",
"/_matrix/client/r0/createRoom",
{"visibility": "public", "invite": [self.other_user_id]},
access_token=self.banned_access_token,
)
self.render(request)
self.assertEquals(200, channel.code, channel.result)
room_id = channel.json_body["room_id"]

# But the user wasn't actually invited.
invited_rooms = self.get_success(
self.store.get_invited_rooms_for_local_user(self.other_user_id)
)
self.assertEqual(invited_rooms, [])

# Since a real room was created, the other user should be able to join it.
self.helper.join(room_id, self.other_user_id, tok=self.other_access_token)

# Both users should be in the room.
users = self.get_success(self.store.get_users_in_room(room_id))
self.assertCountEqual(users, ["@banned:test", "@otheruser:test"])

def test_message(self):
"""Messages from shadow-banned users don't actually get sent."""

room_id = self.helper.create_room_as(
self.other_user_id, tok=self.other_access_token
)

# The user should be in the room.
self.helper.join(room_id, self.banned_user_id, tok=self.banned_access_token)

# Sending a message should complete successfully.
result = self.helper.send_event(
room_id=room_id,
type=EventTypes.Message,
content={"msgtype": "m.text", "body": "with right label"},
tok=self.banned_access_token,
)
self.assertIn("event_id", result)
event_id = result["event_id"]

latest_events = self.get_success(
self.store.get_latest_event_ids_in_room(room_id)
)
self.assertNotIn(event_id, latest_events)

def test_upgrade(self):
"""A room upgrade should fail, but look like it succeeded."""

# The create works fine.
room_id = self.helper.create_room_as(
self.banned_user_id, tok=self.banned_access_token
)

request, channel = self.make_request(
"POST",
"/_matrix/client/r0/rooms/%s/upgrade" % (room_id,),
{"new_version": "6"},
access_token=self.banned_access_token,
)
self.render(request)
self.assertEquals(200, channel.code, channel.result)
# A new room_id should be returned.
self.assertIn("replacement_room", channel.json_body)

new_room_id = channel.json_body["replacement_room"]

# It doesn't really matter what API we use here, we just want to assert
# that the room doesn't exist.
summary = self.get_success(self.store.get_room_summary(new_room_id))
# The summary should be empty since the room doesn't exist.
self.assertEqual(summary, {})
@@ -16,6 +16,7 @@
# limitations under the License.

import json
from urllib.parse import urlencode
import os
import re
from email.parser import Parser
@@ -70,6 +71,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()

@unittest.INFO
def test_basic_password_reset(self):
"""Test basic password reset flow
"""
@@ -250,10 +252,33 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
# Remove the host
path = link.replace("https://example.com", "")

# Load the password reset confirmation page
request, channel = self.make_request("GET", path, shorthand=False)
self.render(request)
self.assertEquals(200, channel.code, channel.result)

# Replace the path with the confirmation path
path = "/_matrix/client/unstable/password_reset/email/submit_token_confirm"

form_args = []
for key, value_list in request.args.items():
for value in value_list:
arg = (key, value)
form_args.append(arg)

print("form_args:", form_args)
print("encoded form_args:", urlencode(form_args))

# Confirm the password reset
request, channel = self.make_request(
"POST",
path,
content=urlencode(form_args).encode("utf8"),
shorthand=False,
)
self.render(request)
self.assertEquals(200, channel.code, channel.result)

def _get_link_from_email(self):
assert self.email_attempts, "No emails have been sent"
@@ -38,6 +38,11 @@ class DummyRecaptchaChecker(UserInteractiveAuthChecker):
return succeed(True)


class DummyPasswordChecker(UserInteractiveAuthChecker):
def check_auth(self, authdict, clientip):
return succeed(authdict["identifier"]["user"])


class FallbackAuthTests(unittest.HomeserverTestCase):

servlets = [
@@ -161,6 +166,9 @@ class UIAuthTests(unittest.HomeserverTestCase):
]

def prepare(self, reactor, clock, hs):
auth_handler = hs.get_auth_handler()
auth_handler.checkers[LoginType.PASSWORD] = DummyPasswordChecker(hs)

self.user_pass = "pass"
self.user = self.register_user("test", self.user_pass)
self.user_tok = self.login("test", self.user_pass)
@@ -593,89 +593,6 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):

self.assertEqual(len(self.email_attempts), 0)

def test_deactivated_user_using_user_identifier(self):
self.email_attempts = []

(user_id, tok) = self.create_user()

request_data = json.dumps(
{
"auth": {
"type": "m.login.password",
"identifier": {"type": "m.id.user", "user": user_id},
"password": "monkey",
},
"erase": False,
}
)
request, channel = self.make_request(
"POST", "account/deactivate", request_data, access_token=tok
)
self.render(request)
self.assertEqual(request.code, 200)

self.reactor.advance(datetime.timedelta(days=8).total_seconds())

self.assertEqual(len(self.email_attempts), 0)

def test_deactivated_user_using_thirdparty_identifier(self):
self.email_attempts = []

(user_id, tok) = self.create_user()

request_data = json.dumps(
{
"auth": {
"type": "m.login.password",
"identifier": {
"type": "m.id.thirdparty",
"medium": "email",
"address": "kermit@example.com",
},
"password": "monkey",
},
"erase": False,
}
)
request, channel = self.make_request(
"POST", "account/deactivate", request_data, access_token=tok
)
self.render(request)
self.assertEqual(request.code, 200)

self.reactor.advance(datetime.timedelta(days=8).total_seconds())

self.assertEqual(len(self.email_attempts), 0)

def test_deactivated_user_using_phone_identifier(self):
self.email_attempts = []

(user_id, tok) = self.create_user()

request_data = json.dumps(
{
"auth": {
"type": "m.login.password",
"identifier": {
"type": "m.id.phone",
"country": "GB",
"phone": "077-009-00001",
},
"password": "monkey",
},
"erase": False,
}
)
request, channel = self.make_request(
"POST", "account/deactivate", request_data, access_token=tok
)
self.render(request)
self.assertEqual(request.code, 200)

self.reactor.advance(datetime.timedelta(days=8).total_seconds())

self.assertEqual(len(self.email_attempts), 0)

def create_user(self):
user_id = self.register_user("kermit", "monkey")
tok = self.login("kermit", "monkey")
@@ -691,15 +608,6 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
added_at=now,
)
)
self.get_success(
self.store.user_add_threepid(
user_id=user_id,
medium="msisdn",
address="447700900001",
validated_at=now,
added_at=now,
)
)
return user_id, tok

def test_manual_email_send_expired_account(self):
@@ -1,6 +1,7 @@
import json
import logging
from io import BytesIO
from json.decoder import JSONDecodeError

import attr
from zope.interface import implementer
@@ -195,7 +196,19 @@ def make_request(
)

if content:
req.requestHeaders.addRawHeader(b"Content-Type", b"application/json")
content_is_json = True
try:
json.loads(content)
except JSONDecodeError:
content_is_json = False

print("Content is json?", content_is_json, path)
if content_is_json:
req.requestHeaders.addRawHeader(b"Content-Type", b"application/json")
else:
req.requestHeaders.addRawHeader(
b"Content-Type", b"application/x-www-form-urlencoded"
)

req.requestReceived(method, path, b"1.1")
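The new branch above decides the Content-Type by attempting to parse the request body as JSON and falling back to form data. A standalone sketch of that detection logic, with an illustrative helper name (the real make_request does this inline against a twisted request object):

```python
import json
from json.decoder import JSONDecodeError


def guess_content_type(content) -> bytes:
    """Return the Content-Type header value to use for a test request body."""
    try:
        json.loads(content)
    except JSONDecodeError:
        # Not valid JSON, so assume it is an urlencoded HTML form body.
        return b"application/x-www-form-urlencoded"
    return b"application/json"


assert guess_content_type('{"a": 1}') == b"application/json"
assert guess_content_type("token=abc&sid=123") == b"application/x-www-form-urlencoded"
```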
@@ -97,8 +97,10 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.rowcount = 1
self.mock_txn.__iter__ = Mock(return_value=iter([("Value",)]))

value = yield self.datastore.db_pool.simple_select_one_onecol(
table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol"
value = yield defer.ensureDeferred(
self.datastore.db_pool.simple_select_one_onecol(
table="tablename", keyvalues={"keycol": "TheKey"}, retcol="retcol"
)
)

self.assertEquals("Value", value)
@@ -111,10 +113,12 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.rowcount = 1
self.mock_txn.fetchone.return_value = (1, 2, 3)

ret = yield self.datastore.db_pool.simple_select_one(
table="tablename",
keyvalues={"keycol": "TheKey"},
retcols=["colA", "colB", "colC"],
ret = yield defer.ensureDeferred(
self.datastore.db_pool.simple_select_one(
table="tablename",
keyvalues={"keycol": "TheKey"},
retcols=["colA", "colB", "colC"],
)
)

self.assertEquals({"colA": 1, "colB": 2, "colC": 3}, ret)
@@ -127,11 +131,13 @@ class SQLBaseStoreTestCase(unittest.TestCase):
self.mock_txn.rowcount = 0
self.mock_txn.fetchone.return_value = None

ret = yield self.datastore.db_pool.simple_select_one(
table="tablename",
keyvalues={"keycol": "Not here"},
retcols=["colA"],
allow_none=True,
ret = yield defer.ensureDeferred(
self.datastore.db_pool.simple_select_one(
table="tablename",
keyvalues={"keycol": "Not here"},
retcols=["colA"],
allow_none=True,
)
)

self.assertFalse(ret)
@@ -38,7 +38,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
self.store.store_device("user_id", "device_id", "display_name")
)

res = yield self.store.get_device("user_id", "device_id")
res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
self.assertDictContainsSubset(
{
"user_id": "user_id",
@@ -111,12 +111,12 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
self.store.store_device("user_id", "device_id", "display_name 1")
)

res = yield self.store.get_device("user_id", "device_id")
res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
self.assertEqual("display_name 1", res["display_name"])

# do a no-op first
yield defer.ensureDeferred(self.store.update_device("user_id", "device_id"))
res = yield self.store.get_device("user_id", "device_id")
res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
self.assertEqual("display_name 1", res["display_name"])

# do the update
@@ -127,7 +127,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
)

# check it worked
res = yield self.store.get_device("user_id", "device_id")
res = yield defer.ensureDeferred(self.store.get_device("user_id", "device_id"))
self.assertEqual("display_name 2", res["display_name"])

@defer.inlineCallbacks
@@ -35,21 +35,34 @@ class ProfileStoreTestCase(unittest.TestCase):
def test_displayname(self):
yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))

yield self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
yield defer.ensureDeferred(
self.store.set_profile_displayname(self.u_frank.localpart, "Frank")
)

self.assertEquals(
"Frank", (yield self.store.get_profile_displayname(self.u_frank.localpart))
"Frank",
(
yield defer.ensureDeferred(
self.store.get_profile_displayname(self.u_frank.localpart)
)
),
)

@defer.inlineCallbacks
def test_avatar_url(self):
yield defer.ensureDeferred(self.store.create_profile(self.u_frank.localpart))

yield self.store.set_profile_avatar_url(
self.u_frank.localpart, "http://my.site/here"
yield defer.ensureDeferred(
self.store.set_profile_avatar_url(
self.u_frank.localpart, "http://my.site/here"
)
)

self.assertEquals(
"http://my.site/here",
(yield self.store.get_profile_avatar_url(self.u_frank.localpart)),
(
yield defer.ensureDeferred(
self.store.get_profile_avatar_url(self.u_frank.localpart)
)
),
)
@@ -53,7 +53,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
"user_type": None,
"deactivated": 0,
},
(yield self.store.get_user_by_id(self.user_id)),
(yield defer.ensureDeferred(self.store.get_user_by_id(self.user_id))),
)

@defer.inlineCallbacks
@@ -54,12 +54,14 @@ class RoomStoreTestCase(unittest.TestCase):
"creator": self.u_creator.to_string(),
"is_public": True,
},
(yield self.store.get_room(self.room.to_string())),
(yield defer.ensureDeferred(self.store.get_room(self.room.to_string()))),
)

@defer.inlineCallbacks
def test_get_room_unknown_room(self):
self.assertIsNone((yield self.store.get_room("!uknown:test")),)
self.assertIsNone(
(yield defer.ensureDeferred(self.store.get_room("!uknown:test")))
)

@defer.inlineCallbacks
def test_get_room_with_stats(self):
@@ -69,12 +71,22 @@ class RoomStoreTestCase(unittest.TestCase):
"creator": self.u_creator.to_string(),
"public": True,
},
(yield self.store.get_room_with_stats(self.room.to_string())),
(
yield defer.ensureDeferred(
self.store.get_room_with_stats(self.room.to_string())
)
),
)

@defer.inlineCallbacks
def test_get_room_with_stats_unknown_room(self):
self.assertIsNone((yield self.store.get_room_with_stats("!uknown:test")),)
self.assertIsNone(
(
yield defer.ensureDeferred(
self.store.get_room_with_stats("!uknown:test")
)
),
)


class RoomEventsStoreTestCase(unittest.TestCase):
tests/test_utils/http.py (new file, 37 lines)
@@ -0,0 +1,37 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
# Copyright 2020 The Matrix.org Foundation C.I.C
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from twisted.web.server import Request


def convert_request_args_to_form_data(request: Request) -> bytes:
"""Converts query arguments from a request to formatted HTML form data

Ref: https://developer.mozilla.org/en-US/docs/Learn/Forms/Sending_and_retrieving_form_data

Args:
The request to pull arguments from

Returns:
The HTML form body data representation of the request's arguments
"""
body = b""
for key, value in request.args.items():
arg = b"%s=%s&" % (key, value[0])
body += arg

# Remove the last '&' sign
return body[:-1]
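For reference, a small usage sketch of `convert_request_args_to_form_data` (the request stand-in and argument values are made up; a real twisted `Request` exposes the same bytes-keyed `args` mapping of lists, and only the first value of each argument is kept):

```python
from types import SimpleNamespace

# Stand-in for a twisted Request: args maps bytes keys to lists of bytes values.
fake_request = SimpleNamespace(
    args={b"token": [b"abcdef"], b"client_secret": [b"monkey"], b"sid": [b"1"]}
)

body = convert_request_args_to_form_data(fake_request)
assert body == b"token=abcdef&client_secret=monkey&sid=1"
```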