Merge commit 'd34c6e127' into anoa/dinsic_release_1_31_0
changelog.d/9093.misc (new file)
@@ -0,0 +1 @@
Add type hints to media repository.

changelog.d/9109.feature (new file)
@@ -0,0 +1 @@
Add support for multiple SSO Identity Providers.

changelog.d/9112.misc (new file)
@@ -0,0 +1 @@
Improve `UsernamePickerTestCase`.
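For context on the "multiple SSO Identity Providers" entry, the OIDC config hunks below add `idp_id` and `idp_name` fields to the provider schema. A minimal sketch of a config dict that would satisfy `OIDC_PROVIDER_CONFIG_SCHEMA` (field names are taken from the diff; the issuer and client values are illustrative placeholders):

# Illustrative only: field names come from OIDC_PROVIDER_CONFIG_SCHEMA below;
# issuer/client values are made up.
oidc_config = {
    "enabled": True,
    "idp_id": "oidc",           # 1-128 chars, restricted character set
    "idp_name": "OIDC",         # user-facing name for the login page
    "discover": True,
    "issuer": "https://id.example.com/",
    "client_id": "synapse",
    "client_secret": "secret",  # placeholder
}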
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2020 Quentin Gliech
# Copyright 2020 The Matrix.org Foundation C.I.C.
# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import string
from typing import Optional, Type

import attr
@@ -38,7 +39,7 @@ class OIDCConfig(Config):

oidc_config = config.get("oidc_config")
if oidc_config and oidc_config.get("enabled", False):
validate_config(OIDC_PROVIDER_CONFIG_SCHEMA, oidc_config, "oidc_config")
validate_config(OIDC_PROVIDER_CONFIG_SCHEMA, oidc_config, ("oidc_config",))
self.oidc_provider = _parse_oidc_config_dict(oidc_config)

if not self.oidc_provider:
@@ -205,6 +206,8 @@ OIDC_PROVIDER_CONFIG_SCHEMA = {
"type": "object",
"required": ["issuer", "client_id", "client_secret"],
"properties": {
"idp_id": {"type": "string", "minLength": 1, "maxLength": 128},
"idp_name": {"type": "string"},
"discover": {"type": "boolean"},
"issuer": {"type": "string"},
"client_id": {"type": "string"},
@@ -277,7 +280,17 @@ def _parse_oidc_config_dict(oidc_config: JsonDict) -> "OidcProviderConfig":
"methods: %s" % (", ".join(missing_methods),)
)

# MSC2858 will apply certain limits in what can be used as an IdP id, so let's
# enforce those limits now.
idp_id = oidc_config.get("idp_id", "oidc")
valid_idp_chars = set(string.ascii_letters + string.digits + "-._~")

if any(c not in valid_idp_chars for c in idp_id):
raise ConfigError('idp_id may only contain A-Z, a-z, 0-9, "-", ".", "_", "~"')

return OidcProviderConfig(
idp_id=idp_id,
idp_name=oidc_config.get("idp_name", "OIDC"),
discover=oidc_config.get("discover", True),
issuer=oidc_config["issuer"],
client_id=oidc_config["client_id"],
@@ -296,8 +309,15 @@ def _parse_oidc_config_dict(oidc_config: JsonDict) -> "OidcProviderConfig":
)


@attr.s
@attr.s(slots=True, frozen=True)
class OidcProviderConfig:
# a unique identifier for this identity provider. Used in the 'user_external_ids'
# table, as well as the query/path parameter used in the login protocol.
idp_id = attr.ib(type=str)

# user-facing name for this identity provider.
idp_name = attr.ib(type=str)

# whether the OIDC discovery mechanism is used to discover endpoints
discover = attr.ib(type=bool)

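A self-contained sketch of the character check added above (names mirror the diff; this is an illustration, not the exact Synapse code path):

import string

# Mirrors the valid_idp_chars check added in _parse_oidc_config_dict above.
VALID_IDP_CHARS = set(string.ascii_letters + string.digits + "-._~")

def check_idp_id(idp_id: str) -> None:
    # Reject any character outside the MSC2858-safe set.
    if any(c not in VALID_IDP_CHARS for c in idp_id):
        raise ValueError(
            'idp_id may only contain A-Z, a-z, 0-9, "-", ".", "_", "~"'
        )

check_idp_id("oidc")        # ok
check_idp_id("my-idp.v2")   # ok
# check_idp_id("bad id")    # would raise: space is not permitted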
@@ -175,7 +175,7 @@ class OidcHandler:
session_data = self._token_generator.verify_oidc_session_token(
session, state
)
except MacaroonDeserializationException as e:
except (MacaroonDeserializationException, ValueError) as e:
logger.exception("Invalid session")
self._sso_handler.render_error(request, "invalid_session", str(e))
return
@@ -253,10 +253,10 @@ class OidcProvider:
self._server_name = hs.config.server_name # type: str

# identifier for the external_ids table
self.idp_id = "oidc"
self.idp_id = provider.idp_id

# user-facing name of this auth provider
self.idp_name = "OIDC"
self.idp_name = provider.idp_name

self._sso_handler = hs.get_sso_handler()

@@ -656,6 +656,7 @@ class OidcProvider:
cookie = self._token_generator.generate_oidc_session_token(
state=state,
session_data=OidcSessionData(
idp_id=self.idp_id,
nonce=nonce,
client_redirect_url=client_redirect_url.decode(),
ui_auth_session_id=ui_auth_session_id,
@@ -924,6 +925,7 @@ class OidcSessionTokenGenerator:
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = session")
macaroon.add_first_party_caveat("state = %s" % (state,))
macaroon.add_first_party_caveat("idp_id = %s" % (session_data.idp_id,))
macaroon.add_first_party_caveat("nonce = %s" % (session_data.nonce,))
macaroon.add_first_party_caveat(
"client_redirect_url = %s" % (session_data.client_redirect_url,)
@@ -952,6 +954,9 @@ class OidcSessionTokenGenerator:

Returns:
The data extracted from the session cookie

Raises:
ValueError if an expected caveat is missing from the macaroon.
"""
macaroon = pymacaroons.Macaroon.deserialize(session)

@@ -960,6 +965,7 @@ class OidcSessionTokenGenerator:
v.satisfy_exact("type = session")
v.satisfy_exact("state = %s" % (state,))
v.satisfy_general(lambda c: c.startswith("nonce = "))
v.satisfy_general(lambda c: c.startswith("idp_id = "))
v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
# Sometimes there's a UI auth session ID, it seems to be OK to attempt
# to always satisfy this.
@@ -968,9 +974,9 @@ class OidcSessionTokenGenerator:

v.verify(macaroon, self._macaroon_secret_key)

# Extract the `nonce`, `client_redirect_url`, and maybe the
# `ui_auth_session_id` from the token.
# Extract the session data from the token.
nonce = self._get_value_from_macaroon(macaroon, "nonce")
idp_id = self._get_value_from_macaroon(macaroon, "idp_id")
client_redirect_url = self._get_value_from_macaroon(
macaroon, "client_redirect_url"
)
@@ -983,6 +989,7 @@ class OidcSessionTokenGenerator:

return OidcSessionData(
nonce=nonce,
idp_id=idp_id,
client_redirect_url=client_redirect_url,
ui_auth_session_id=ui_auth_session_id,
)
@@ -998,7 +1005,7 @@ class OidcSessionTokenGenerator:
The extracted value

Raises:
Exception: if the caveat was not in the macaroon
ValueError: if the caveat was not in the macaroon
"""
prefix = key + " = "
for caveat in macaroon.caveats:
@@ -1019,6 +1026,9 @@ class OidcSessionTokenGenerator:
class OidcSessionData:
"""The attributes which are stored in a OIDC session cookie"""

# the Identity Provider being used
idp_id = attr.ib(type=str)

# The `nonce` parameter passed to the OIDC provider.
nonce = attr.ib(type=str)

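For readers unfamiliar with the macaroon mechanics above, a minimal standalone sketch of the generate/verify round trip with pymacaroons (location, identifier, and caveat values here are placeholders, not Synapse's actual ones):

import pymacaroons

SECRET = "macaroon-secret-key"  # placeholder; Synapse uses its configured macaroon secret

# Generate: bind the session to a state value and an idp_id, as in the diff.
m = pymacaroons.Macaroon(location="example.com", identifier="key", key=SECRET)
m.add_first_party_caveat("gen = 1")
m.add_first_party_caveat("type = session")
m.add_first_party_caveat("state = abc123")
m.add_first_party_caveat("idp_id = oidc")
token = m.serialize()

# Verify: exact caveats must match; general caveats accept a prefix check.
v = pymacaroons.Verifier()
v.satisfy_exact("gen = 1")
v.satisfy_exact("type = session")
v.satisfy_exact("state = abc123")
v.satisfy_general(lambda c: c.startswith("idp_id = "))
assert v.verify(pymacaroons.Macaroon.deserialize(token), SECRET)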
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2019 New Vector Ltd
# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,10 +17,11 @@
import logging
import os
import urllib
from typing import Awaitable
from typing import Awaitable, Dict, Generator, List, Optional, Tuple

from twisted.internet.interfaces import IConsumer
from twisted.protocols.basic import FileSender
from twisted.web.http import Request

from synapse.api.errors import Codes, SynapseError, cs_error
from synapse.http.server import finish_request, respond_with_json
@@ -46,7 +47,7 @@ TEXT_CONTENT_TYPES = [
]


def parse_media_id(request):
def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]:
try:
# This allows users to append e.g. /test.png to the URL. Useful for
# clients that parse the URL to see content type.
@@ -69,7 +70,7 @@ def parse_media_id(request):
)


def respond_404(request):
def respond_404(request: Request) -> None:
respond_with_json(
request,
404,
@@ -79,8 +80,12 @@ def respond_404(request):


async def respond_with_file(
request, media_type, file_path, file_size=None, upload_name=None
):
request: Request,
media_type: str,
file_path: str,
file_size: Optional[int] = None,
upload_name: Optional[str] = None,
) -> None:
logger.debug("Responding with %r", file_path)

if os.path.isfile(file_path):
@@ -98,15 +103,20 @@ async def respond_with_file(
respond_404(request)


def add_file_headers(request, media_type, file_size, upload_name):
def add_file_headers(
request: Request,
media_type: str,
file_size: Optional[int],
upload_name: Optional[str],
) -> None:
"""Adds the correct response headers in preparation for responding with the
media.

Args:
request (twisted.web.http.Request)
media_type (str): The media/content type.
file_size (int): Size in bytes of the media, if known.
upload_name (str): The name of the requested file, if any.
request
media_type: The media/content type.
file_size: Size in bytes of the media, if known.
upload_name: The name of the requested file, if any.
"""

def _quote(x):
@@ -153,7 +163,8 @@ def add_file_headers(request, media_type, file_size, upload_name):
# select private. don't bother setting Expires as all our
# clients are smart enough to be happy with Cache-Control
request.setHeader(b"Cache-Control", b"public,max-age=86400,s-maxage=86400")
request.setHeader(b"Content-Length", b"%d" % (file_size,))
if file_size is not None:
request.setHeader(b"Content-Length", b"%d" % (file_size,))

# Tell web crawlers to not index, archive, or follow links in media. This
# should help to prevent things in the media repo from showing up in web
@@ -184,7 +195,7 @@ _FILENAME_SEPARATOR_CHARS = {
}


def _can_encode_filename_as_token(x):
def _can_encode_filename_as_token(x: str) -> bool:
for c in x:
# from RFC2616:
#
@@ -206,17 +217,21 @@ def _can_encode_filename_as_token(x):


async def respond_with_responder(
request, responder, media_type, file_size, upload_name=None
):
request: Request,
responder: "Optional[Responder]",
media_type: str,
file_size: Optional[int],
upload_name: Optional[str] = None,
) -> None:
"""Responds to the request with given responder. If responder is None then
returns 404.

Args:
request (twisted.web.http.Request)
responder (Responder|None)
media_type (str): The media/content type.
file_size (int|None): Size in bytes of the media. If not known it should be None
upload_name (str|None): The name of the requested file, if any.
request
responder
media_type: The media/content type.
file_size: Size in bytes of the media. If not known it should be None
upload_name: The name of the requested file, if any.
"""
if request._disconnected:
logger.warning(
@@ -308,22 +323,22 @@ class FileInfo:
self.thumbnail_type = thumbnail_type


def get_filename_from_headers(headers):
def get_filename_from_headers(headers: Dict[bytes, List[bytes]]) -> Optional[str]:
"""
Get the filename of the downloaded file by inspecting the
Content-Disposition HTTP header.

Args:
headers (dict[bytes, list[bytes]]): The HTTP request headers.
headers: The HTTP request headers.

Returns:
A Unicode string of the filename, or None.
The filename, or None.
"""
content_disposition = headers.get(b"Content-Disposition", [b""])

# No header, bail out.
if not content_disposition[0]:
return
return None

_, params = _parse_header(content_disposition[0])

@@ -356,17 +371,16 @@ def get_filename_from_headers(headers):
return upload_name


def _parse_header(line):
def _parse_header(line: bytes) -> Tuple[bytes, Dict[bytes, bytes]]:
"""Parse a Content-type like header.

Cargo-culted from `cgi`, but works on bytes rather than strings.

Args:
line (bytes): header to be parsed
line: header to be parsed

Returns:
Tuple[bytes, dict[bytes, bytes]]:
the main content-type, followed by the parameter dictionary
The main content-type, followed by the parameter dictionary
"""
parts = _parseparam(b";" + line)
key = next(parts)
@@ -386,16 +400,16 @@ def _parse_header(line):
return key, pdict


def _parseparam(s):
def _parseparam(s: bytes) -> Generator[bytes, None, None]:
"""Generator which splits the input on ;, respecting double-quoted sequences

Cargo-culted from `cgi`, but works on bytes rather than strings.

Args:
s (bytes): header to be parsed
s: header to be parsed

Returns:
Iterable[bytes]: the split input
The split input
"""
while s[:1] == b";":
s = s[1:]

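As an illustration of what `get_filename_from_headers` extracts from a Content-Disposition header, here is a standalone sketch using the stdlib `cgi.parse_header` (the docstrings above note the private helpers were adapted from `cgi` to work on bytes):

# Standalone illustration of the Content-Disposition handling above.
from cgi import parse_header

header = 'inline; filename="screenshot.png"'
_, params = parse_header(header)
print(params.get("filename"))  # -> screenshot.png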
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2018 Will Hunt <will@half-shot.uk>
# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,22 +15,29 @@
# limitations under the License.
#

from typing import TYPE_CHECKING

from twisted.web.http import Request

from synapse.http.server import DirectServeJsonResource, respond_with_json

if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer


class MediaConfigResource(DirectServeJsonResource):
isLeaf = True

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
config = hs.get_config()
self.clock = hs.get_clock()
self.auth = hs.get_auth()
self.limits_dict = {"m.upload.size": config.max_upload_size}

async def _async_render_GET(self, request):
async def _async_render_GET(self, request: Request) -> None:
await self.auth.get_user_by_req(request)
respond_with_json(request, 200, self.limits_dict, send_cors=True)

async def _async_render_OPTIONS(self, request):
async def _async_render_OPTIONS(self, request: Request) -> None:
respond_with_json(request, 200, {}, send_cors=True)

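For reference, the body served by `MediaConfigResource._async_render_GET` above is just `limits_dict`; assuming a 50 MB `max_upload_size`, a successful GET on the media config endpoint would return something like:

# Illustrative response body (the size value depends on max_upload_size).
{"m.upload.size": 52428800}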
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,24 +14,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING

from twisted.web.http import Request

import synapse.http.servlet
from synapse.http.server import DirectServeJsonResource, set_cors_headers
from synapse.http.servlet import parse_boolean

from ._base import parse_media_id, respond_404

if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
from synapse.rest.media.v1.media_repository import MediaRepository

logger = logging.getLogger(__name__)


class DownloadResource(DirectServeJsonResource):
isLeaf = True

def __init__(self, hs, media_repo):
def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
super().__init__()
self.media_repo = media_repo
self.server_name = hs.hostname

async def _async_render_GET(self, request):
async def _async_render_GET(self, request: Request) -> None:
set_cors_headers(request)
request.setHeader(
b"Content-Security-Policy",
@@ -49,9 +57,7 @@ class DownloadResource(DirectServeJsonResource):
if server_name == self.server_name:
await self.media_repo.get_local_media(request, media_id, name)
else:
allow_remote = synapse.http.servlet.parse_boolean(
request, "allow_remote", default=True
)
allow_remote = parse_boolean(request, "allow_remote", default=True)
if not allow_remote:
logger.info(
"Rejecting request for remote media %s/%s due to allow_remote",

@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,11 +17,12 @@
import functools
import os
import re
from typing import Callable, List

NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d")


def _wrap_in_base_path(func):
def _wrap_in_base_path(func: "Callable[..., str]") -> "Callable[..., str]":
"""Takes a function that returns a relative path and turns it into an
absolute path based on the location of the primary media store
"""
@@ -41,12 +43,18 @@ class MediaFilePaths:
to write to the backup media store (when one is configured)
"""

def __init__(self, primary_base_path):
def __init__(self, primary_base_path: str):
self.base_path = primary_base_path

def default_thumbnail_rel(
self, default_top_level, default_sub_type, width, height, content_type, method
):
self,
default_top_level: str,
default_sub_type: str,
width: int,
height: int,
content_type: str,
method: str,
) -> str:
top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
return os.path.join(
@@ -55,12 +63,14 @@ class MediaFilePaths:

default_thumbnail = _wrap_in_base_path(default_thumbnail_rel)

def local_media_filepath_rel(self, media_id):
def local_media_filepath_rel(self, media_id: str) -> str:
return os.path.join("local_content", media_id[0:2], media_id[2:4], media_id[4:])

local_media_filepath = _wrap_in_base_path(local_media_filepath_rel)

def local_media_thumbnail_rel(self, media_id, width, height, content_type, method):
def local_media_thumbnail_rel(
self, media_id: str, width: int, height: int, content_type: str, method: str
) -> str:
top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
return os.path.join(
@@ -86,7 +96,7 @@ class MediaFilePaths:
media_id[4:],
)

def remote_media_filepath_rel(self, server_name, file_id):
def remote_media_filepath_rel(self, server_name: str, file_id: str) -> str:
return os.path.join(
"remote_content", server_name, file_id[0:2], file_id[2:4], file_id[4:]
)
@@ -94,8 +104,14 @@ class MediaFilePaths:
remote_media_filepath = _wrap_in_base_path(remote_media_filepath_rel)

def remote_media_thumbnail_rel(
self, server_name, file_id, width, height, content_type, method
):
self,
server_name: str,
file_id: str,
width: int,
height: int,
content_type: str,
method: str,
) -> str:
top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
return os.path.join(
@@ -113,7 +129,7 @@ class MediaFilePaths:
# Should be removed after some time, when most of the thumbnails are stored
# using the new path.
def remote_media_thumbnail_rel_legacy(
self, server_name, file_id, width, height, content_type
self, server_name: str, file_id: str, width: int, height: int, content_type: str
):
top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type)
@@ -126,7 +142,7 @@ class MediaFilePaths:
file_name,
)

def remote_media_thumbnail_dir(self, server_name, file_id):
def remote_media_thumbnail_dir(self, server_name: str, file_id: str) -> str:
return os.path.join(
self.base_path,
"remote_thumbnail",
@@ -136,7 +152,7 @@ class MediaFilePaths:
file_id[4:],
)

def url_cache_filepath_rel(self, media_id):
def url_cache_filepath_rel(self, media_id: str) -> str:
if NEW_FORMAT_ID_RE.match(media_id):
# Media id is of the form <DATE><RANDOM_STRING>
# E.g.: 2017-09-28-fsdRDt24DS234dsf
@@ -146,7 +162,7 @@ class MediaFilePaths:

url_cache_filepath = _wrap_in_base_path(url_cache_filepath_rel)

def url_cache_filepath_dirs_to_delete(self, media_id):
def url_cache_filepath_dirs_to_delete(self, media_id: str) -> List[str]:
"The dirs to try and remove if we delete the media_id file"
if NEW_FORMAT_ID_RE.match(media_id):
return [os.path.join(self.base_path, "url_cache", media_id[:10])]
@@ -156,7 +172,9 @@ class MediaFilePaths:
os.path.join(self.base_path, "url_cache", media_id[0:2]),
]

def url_cache_thumbnail_rel(self, media_id, width, height, content_type, method):
def url_cache_thumbnail_rel(
self, media_id: str, width: int, height: int, content_type: str, method: str
) -> str:
# Media id is of the form <DATE><RANDOM_STRING>
# E.g.: 2017-09-28-fsdRDt24DS234dsf

@@ -178,7 +196,7 @@ class MediaFilePaths:

url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel)

def url_cache_thumbnail_directory(self, media_id):
def url_cache_thumbnail_directory(self, media_id: str) -> str:
# Media id is of the form <DATE><RANDOM_STRING>
# E.g.: 2017-09-28-fsdRDt24DS234dsf

@@ -195,7 +213,7 @@ class MediaFilePaths:
media_id[4:],
)

def url_cache_thumbnail_dirs_to_delete(self, media_id):
def url_cache_thumbnail_dirs_to_delete(self, media_id: str) -> List[str]:
"The dirs to try and remove if we delete the media_id thumbnails"
# Media id is of the form <DATE><RANDOM_STRING>
# E.g.: 2017-09-28-fsdRDt24DS234dsf

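The path helpers above shard media IDs into nested directories to keep directory fan-out manageable; a worked example of the layout produced by `local_media_filepath_rel` (a sketch using the same slicing shown above):

import os

def local_media_filepath_rel(media_id: str) -> str:
    # Same slicing as MediaFilePaths above: first two chars, next two,
    # then the remainder.
    return os.path.join("local_content", media_id[0:2], media_id[2:4], media_id[4:])

print(local_media_filepath_rel("GerZNDnDZVjsOtar"))
# -> local_content/Ge/rZ/NDnDZVjsOtar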
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,12 +13,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import errno
import logging
import os
import shutil
from typing import IO, Dict, List, Optional, Tuple
from io import BytesIO
from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple

import twisted.internet.error
import twisted.web.http
@@ -56,6 +56,9 @@ from .thumbnail_resource import ThumbnailResource
from .thumbnailer import Thumbnailer, ThumbnailError
from .upload_resource import UploadResource

if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer

logger = logging.getLogger(__name__)


@@ -63,7 +66,7 @@ UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000


class MediaRepository:
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.client = hs.get_federation_http_client()
@@ -73,16 +76,16 @@ class MediaRepository:
self.max_upload_size = hs.config.max_upload_size
self.max_image_pixels = hs.config.max_image_pixels

self.primary_base_path = hs.config.media_store_path
self.filepaths = MediaFilePaths(self.primary_base_path)
self.primary_base_path = hs.config.media_store_path # type: str
self.filepaths = MediaFilePaths(self.primary_base_path) # type: MediaFilePaths

self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.thumbnail_requirements = hs.config.thumbnail_requirements

self.remote_media_linearizer = Linearizer(name="media_remote")

self.recently_accessed_remotes = set()
self.recently_accessed_locals = set()
self.recently_accessed_remotes = set() # type: Set[Tuple[str, str]]
self.recently_accessed_locals = set() # type: Set[str]

self.federation_domain_whitelist = hs.config.federation_domain_whitelist

@@ -113,7 +116,7 @@ class MediaRepository:
"update_recently_accessed_media", self._update_recently_accessed
)

async def _update_recently_accessed(self):
async def _update_recently_accessed(self) -> None:
remote_media = self.recently_accessed_remotes
self.recently_accessed_remotes = set()

@@ -124,12 +127,12 @@ class MediaRepository:
local_media, remote_media, self.clock.time_msec()
)

def mark_recently_accessed(self, server_name, media_id):
def mark_recently_accessed(self, server_name: Optional[str], media_id: str) -> None:
"""Mark the given media as recently accessed.

Args:
server_name (str|None): Origin server of media, or None if local
media_id (str): The media ID of the content
server_name: Origin server of media, or None if local
media_id: The media ID of the content
"""
if server_name:
self.recently_accessed_remotes.add((server_name, media_id))
@@ -459,7 +462,14 @@ class MediaRepository:
def _get_thumbnail_requirements(self, media_type):
return self.thumbnail_requirements.get(media_type, ())

def _generate_thumbnail(self, thumbnailer, t_width, t_height, t_method, t_type):
def _generate_thumbnail(
self,
thumbnailer: Thumbnailer,
t_width: int,
t_height: int,
t_method: str,
t_type: str,
) -> Optional[BytesIO]:
m_width = thumbnailer.width
m_height = thumbnailer.height

@@ -470,22 +480,20 @@ class MediaRepository:
m_height,
self.max_image_pixels,
)
return
return None

if thumbnailer.transpose_method is not None:
m_width, m_height = thumbnailer.transpose()

if t_method == "crop":
t_byte_source = thumbnailer.crop(t_width, t_height, t_type)
return thumbnailer.crop(t_width, t_height, t_type)
elif t_method == "scale":
t_width, t_height = thumbnailer.aspect(t_width, t_height)
t_width = min(m_width, t_width)
t_height = min(m_height, t_height)
t_byte_source = thumbnailer.scale(t_width, t_height, t_type)
else:
t_byte_source = None
return thumbnailer.scale(t_width, t_height, t_type)

return t_byte_source
return None

async def generate_local_exact_thumbnail(
self,
@@ -776,7 +784,7 @@ class MediaRepository:

return {"width": m_width, "height": m_height}

async def delete_old_remote_media(self, before_ts):
async def delete_old_remote_media(self, before_ts: int) -> Dict[str, int]:
old_media = await self.store.get_remote_media_before(before_ts)

deleted = 0
@@ -928,7 +936,7 @@ class MediaRepositoryResource(Resource):
within a given rectangle.
"""

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
# If we're not configured to use it, raise if we somehow got here.
if not hs.config.can_load_media_repo:
raise ConfigError("Synapse is not configured to use a media repo.")

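The `# type:` annotations added above are comment-style hints, which mypy accepts for instance attributes without changing runtime behaviour. A minimal sketch of the same idiom outside Synapse:

from typing import Set, Tuple

class Tracker:
    def __init__(self) -> None:
        # Comment-style annotations, as in the diff above, type the
        # attributes without requiring variable-annotation syntax.
        self.recent_remotes = set()  # type: Set[Tuple[str, str]]
        self.recent_locals = set()  # type: Set[str]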
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vecotr Ltd
# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,6 +18,8 @@ import os
import shutil
from typing import IO, TYPE_CHECKING, Any, Optional, Sequence

from twisted.internet.defer import Deferred
from twisted.internet.interfaces import IConsumer
from twisted.protocols.basic import FileSender

from synapse.logging.context import defer_to_thread, make_deferred_yieldable
@@ -270,7 +272,7 @@ class MediaStorage:
return self.filepaths.local_media_filepath_rel(file_info.file_id)


def _write_file_synchronously(source, dest):
def _write_file_synchronously(source: IO, dest: IO) -> None:
"""Write `source` to the file like `dest` synchronously. Should be called
from a thread.

@@ -286,14 +288,14 @@ class FileResponder(Responder):
"""Wraps an open file that can be sent to a request.

Args:
open_file (file): A file like object to be streamed to the client,
open_file: A file like object to be streamed to the client,
is closed when finished streaming.
"""

def __init__(self, open_file):
def __init__(self, open_file: IO):
self.open_file = open_file

def write_to_consumer(self, consumer):
def write_to_consumer(self, consumer: IConsumer) -> Deferred:
return make_deferred_yieldable(
FileSender().beginFileTransfer(self.open_file, consumer)
)

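`_write_file_synchronously` is a plain stream copy meant to run off the reactor thread via `defer_to_thread`; a standalone sketch of that kind of copy (assuming the usual seek-and-copy pattern, not necessarily Synapse's exact body):

import shutil
from typing import IO

def write_file_synchronously(source: IO, dest: IO) -> None:
    # Rewind the source, then copy it into dest; intended to be called
    # from a worker thread rather than the Twisted reactor thread.
    source.seek(0)
    shutil.copyfileobj(source, dest)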
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
import errno
import fnmatch
@@ -23,12 +23,13 @@ import re
import shutil
import sys
import traceback
from typing import Dict, Optional
from typing import TYPE_CHECKING, Any, Dict, Generator, Iterable, Optional, Union
from urllib import parse as urlparse

import attr

from twisted.internet.error import DNSLookupError
from twisted.web.http import Request

from synapse.api.errors import Codes, SynapseError
from synapse.http.client import SimpleHttpClient
@@ -41,6 +42,7 @@ from synapse.http.servlet import parse_integer, parse_string
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.rest.media.v1._base import get_filename_from_headers
from synapse.rest.media.v1.media_storage import MediaStorage
from synapse.util import json_encoder
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.expiringcache import ExpiringCache
@@ -48,6 +50,12 @@ from synapse.util.stringutils import random_string

from ._base import FileInfo

if TYPE_CHECKING:
from lxml import etree

from synapse.app.homeserver import HomeServer
from synapse.rest.media.v1.media_repository import MediaRepository

logger = logging.getLogger(__name__)

_charset_match = re.compile(br"<\s*meta[^>]*charset\s*=\s*([a-z0-9-]+)", flags=re.I)
@@ -119,7 +127,12 @@ class OEmbedError(Exception):
class PreviewUrlResource(DirectServeJsonResource):
isLeaf = True

def __init__(self, hs, media_repo, media_storage):
def __init__(
self,
hs: "HomeServer",
media_repo: "MediaRepository",
media_storage: MediaStorage,
):
super().__init__()

self.auth = hs.get_auth()
@@ -165,11 +178,11 @@ class PreviewUrlResource(DirectServeJsonResource):
self._start_expire_url_cache_data, 10 * 1000
)

async def _async_render_OPTIONS(self, request):
async def _async_render_OPTIONS(self, request: Request) -> None:
request.setHeader(b"Allow", b"OPTIONS, GET")
respond_with_json(request, 200, {}, send_cors=True)

async def _async_render_GET(self, request):
async def _async_render_GET(self, request: Request) -> None:

# XXX: if get_user_by_req fails, what should we do in an async render?
requester = await self.auth.get_user_by_req(request)
@@ -449,7 +462,7 @@ class PreviewUrlResource(DirectServeJsonResource):
logger.warning("Error downloading oEmbed metadata from %s: %r", url, e)
raise OEmbedError() from e

async def _download_url(self, url: str, user):
async def _download_url(self, url: str, user: str) -> Dict[str, Any]:
# TODO: we should probably honour robots.txt... except in practice
# we're most likely being explicitly triggered by a human rather than a
# bot, so are we really a robot?
@@ -579,7 +592,7 @@ class PreviewUrlResource(DirectServeJsonResource):
"expire_url_cache_data", self._expire_url_cache_data
)

async def _expire_url_cache_data(self):
async def _expire_url_cache_data(self) -> None:
"""Clean up expired url cache content, media and thumbnails.
"""
# TODO: Delete from backup media store
@@ -675,7 +688,9 @@ class PreviewUrlResource(DirectServeJsonResource):
logger.debug("No media removed from url cache")


def decode_and_calc_og(body, media_uri, request_encoding=None) -> Dict[str, str]:
def decode_and_calc_og(
body: bytes, media_uri: str, request_encoding: Optional[str] = None
) -> Dict[str, Optional[str]]:
# If there's no body, nothing useful is going to be found.
if not body:
return {}
@@ -696,7 +711,7 @@ def decode_and_calc_og(body, media_uri, request_encoding=None) -> Dict[str, str]
return og


def _calc_og(tree, media_uri):
def _calc_og(tree, media_uri: str) -> Dict[str, Optional[str]]:
# suck our tree into lxml and define our OG response.

# if we see any image URLs in the OG response, then spider them
@@ -800,7 +815,9 @@ def _calc_og(tree, media_uri):
for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE)
)
og["og:description"] = summarize_paragraphs(text_nodes)
else:
elif og["og:description"]:
# This must be a non-empty string at this point.
assert isinstance(og["og:description"], str)
og["og:description"] = summarize_paragraphs([og["og:description"]])

# TODO: delete the url downloads to stop diskfilling,
@@ -808,7 +825,9 @@ def _calc_og(tree, media_uri):
return og


def _iterate_over_text(tree, *tags_to_ignore):
def _iterate_over_text(
tree, *tags_to_ignore: Iterable[Union[str, "etree.Comment"]]
) -> Generator[str, None, None]:
"""Iterate over the tree returning text nodes in a depth first fashion,
skipping text nodes inside certain tags.
"""
@@ -842,32 +861,32 @@ def _iterate_over_text(tree, *tags_to_ignore):
)


def _rebase_url(url, base):
base = list(urlparse.urlparse(base))
url = list(urlparse.urlparse(url))
if not url[0]: # fix up schema
url[0] = base[0] or "http"
if not url[1]: # fix up hostname
url[1] = base[1]
if not url[2].startswith("/"):
url[2] = re.sub(r"/[^/]+$", "/", base[2]) + url[2]
return urlparse.urlunparse(url)
def _rebase_url(url: str, base: str) -> str:
base_parts = list(urlparse.urlparse(base))
url_parts = list(urlparse.urlparse(url))
if not url_parts[0]: # fix up schema
url_parts[0] = base_parts[0] or "http"
if not url_parts[1]: # fix up hostname
url_parts[1] = base_parts[1]
if not url_parts[2].startswith("/"):
url_parts[2] = re.sub(r"/[^/]+$", "/", base_parts[2]) + url_parts[2]
return urlparse.urlunparse(url_parts)


def _is_media(content_type):
if content_type.lower().startswith("image/"):
return True
def _is_media(content_type: str) -> bool:
return content_type.lower().startswith("image/")


def _is_html(content_type):
def _is_html(content_type: str) -> bool:
content_type = content_type.lower()
if content_type.startswith("text/html") or content_type.startswith(
return content_type.startswith("text/html") or content_type.startswith(
"application/xhtml"
):
return True
)


def summarize_paragraphs(text_nodes, min_size=200, max_size=500):
def summarize_paragraphs(
text_nodes: Iterable[str], min_size: int = 200, max_size: int = 500
) -> Optional[str]:
# Try to get a summary of between 200 and 500 words, respecting
# first paragraph and then word boundaries.
# TODO: Respect sentences?

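A quick worked example of the `_rebase_url` behaviour above (a standalone sketch of the same urlparse-based logic):

import re
from urllib import parse as urlparse

def rebase_url(url: str, base: str) -> str:
    # Same logic as _rebase_url above: fill in a missing scheme, host,
    # and relative path from the base URL.
    base_parts = list(urlparse.urlparse(base))
    url_parts = list(urlparse.urlparse(url))
    if not url_parts[0]:
        url_parts[0] = base_parts[0] or "http"
    if not url_parts[1]:
        url_parts[1] = base_parts[1]
    if not url_parts[2].startswith("/"):
        url_parts[2] = re.sub(r"/[^/]+$", "/", base_parts[2]) + url_parts[2]
    return urlparse.urlunparse(url_parts)

print(rebase_url("img/logo.png", "https://example.com/blog/post.html"))
# -> https://example.com/blog/img/logo.png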
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,10 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import abc
import logging
import os
import shutil
from typing import Optional
from typing import TYPE_CHECKING, Optional

from synapse.config._base import Config
from synapse.logging.context import defer_to_thread, run_in_background
@@ -27,13 +28,17 @@ from .media_storage import FileResponder

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer

class StorageProvider:

class StorageProvider(metaclass=abc.ABCMeta):
"""A storage provider is a service that can store uploaded media and
retrieve them.
"""

async def store_file(self, path: str, file_info: FileInfo):
@abc.abstractmethod
async def store_file(self, path: str, file_info: FileInfo) -> None:
"""Store the file described by file_info. The actual contents can be
retrieved by reading the file in file_info.upload_path.

@@ -42,6 +47,7 @@ class StorageProvider:
file_info: The metadata of the file.
"""

@abc.abstractmethod
async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
"""Attempt to fetch the file described by file_info and stream it
into writer.
@@ -78,10 +84,10 @@ class StorageProviderWrapper(StorageProvider):
self.store_synchronous = store_synchronous
self.store_remote = store_remote

def __str__(self):
def __str__(self) -> str:
return "StorageProviderWrapper[%s]" % (self.backend,)

async def store_file(self, path, file_info):
async def store_file(self, path: str, file_info: FileInfo) -> None:
if not file_info.server_name and not self.store_local:
return None

@@ -91,7 +97,7 @@ class StorageProviderWrapper(StorageProvider):
if self.store_synchronous:
# store_file is supposed to return an Awaitable, but guard
# against improper implementations.
return await maybe_awaitable(self.backend.store_file(path, file_info))
await maybe_awaitable(self.backend.store_file(path, file_info)) # type: ignore
else:
# TODO: Handle errors.
async def store():
@@ -103,9 +109,8 @@ class StorageProviderWrapper(StorageProvider):
logger.exception("Error storing file")

run_in_background(store)
return None

async def fetch(self, path, file_info):
async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
# store_file is supposed to return an Awaitable, but guard
# against improper implementations.
return await maybe_awaitable(self.backend.fetch(path, file_info))
@@ -115,11 +120,11 @@ class FileStorageProviderBackend(StorageProvider):
"""A storage provider that stores files in a directory on a filesystem.

Args:
hs (HomeServer)
hs
config: The config returned by `parse_config`.
"""

def __init__(self, hs, config):
def __init__(self, hs: "HomeServer", config: str):
self.hs = hs
self.cache_directory = hs.config.media_store_path
self.base_directory = config
@@ -127,7 +132,7 @@ class FileStorageProviderBackend(StorageProvider):
def __str__(self):
return "FileStorageProviderBackend[%s]" % (self.base_directory,)

async def store_file(self, path, file_info):
async def store_file(self, path: str, file_info: FileInfo) -> None:
"""See StorageProvider.store_file"""

primary_fname = os.path.join(self.cache_directory, path)
@@ -137,19 +142,21 @@ class FileStorageProviderBackend(StorageProvider):
if not os.path.exists(dirname):
os.makedirs(dirname)

return await defer_to_thread(
await defer_to_thread(
self.hs.get_reactor(), shutil.copyfile, primary_fname, backup_fname
)

async def fetch(self, path, file_info):
async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
"""See StorageProvider.fetch"""

backup_fname = os.path.join(self.base_directory, path)
if os.path.isfile(backup_fname):
return FileResponder(open(backup_fname, "rb"))

return None

@staticmethod
def parse_config(config):
def parse_config(config: dict) -> str:
"""Called on startup to parse config supplied. This should parse
the config and raise if there is a problem.

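With `StorageProvider` now an abstract base class, a third-party backend must implement both abstract methods. A minimal sketch of such a subclass (the class itself is hypothetical; `StorageProvider`, `FileInfo`, `Responder`, and `logger` are the names shown in the diff above):

# Hypothetical minimal backend implementing the abstract interface above.
class NullStorageProviderBackend(StorageProvider):
    """Accepts every file but stores nothing; a shape example only."""

    async def store_file(self, path: str, file_info: FileInfo) -> None:
        # A real backend would copy the local file at `path` elsewhere.
        logger.info("Discarding %s", path)

    async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
        # Returning None tells the wrapper this backend has no copy.
        return None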
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014 - 2016 OpenMarket Ltd
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,10 +16,14 @@


import logging
from typing import TYPE_CHECKING

from twisted.web.http import Request

from synapse.api.errors import SynapseError
from synapse.http.server import DirectServeJsonResource, set_cors_headers
from synapse.http.servlet import parse_integer, parse_string
from synapse.rest.media.v1.media_storage import MediaStorage

from ._base import (
FileInfo,
@@ -28,13 +33,22 @@ from ._base import (
respond_with_responder,
)

if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
from synapse.rest.media.v1.media_repository import MediaRepository

logger = logging.getLogger(__name__)


class ThumbnailResource(DirectServeJsonResource):
isLeaf = True

def __init__(self, hs, media_repo, media_storage):
def __init__(
self,
hs: "HomeServer",
media_repo: "MediaRepository",
media_storage: MediaStorage,
):
super().__init__()

self.store = hs.get_datastore()
@@ -43,7 +57,7 @@ class ThumbnailResource(DirectServeJsonResource):
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.server_name = hs.hostname

async def _async_render_GET(self, request):
async def _async_render_GET(self, request: Request) -> None:
set_cors_headers(request)
server_name, media_id, _ = parse_media_id(request)
width = parse_integer(request, "width", required=True)
@@ -73,8 +87,14 @@ class ThumbnailResource(DirectServeJsonResource):
self.media_repo.mark_recently_accessed(server_name, media_id)

async def _respond_local_thumbnail(
self, request, media_id, width, height, method, m_type
):
self,
request: Request,
media_id: str,
width: int,
height: int,
method: str,
m_type: str,
) -> None:
media_info = await self.store.get_local_media(media_id)

if not media_info:
@@ -114,13 +134,13 @@ class ThumbnailResource(DirectServeJsonResource):

async def _select_or_generate_local_thumbnail(
self,
request,
media_id,
desired_width,
desired_height,
desired_method,
desired_type,
):
request: Request,
media_id: str,
desired_width: int,
desired_height: int,
desired_method: str,
desired_type: str,
) -> None:
media_info = await self.store.get_local_media(media_id)

if not media_info:
@@ -178,14 +198,14 @@ class ThumbnailResource(DirectServeJsonResource):

async def _select_or_generate_remote_thumbnail(
self,
request,
server_name,
media_id,
desired_width,
desired_height,
desired_method,
desired_type,
):
request: Request,
server_name: str,
media_id: str,
desired_width: int,
desired_height: int,
desired_method: str,
desired_type: str,
) -> None:
media_info = await self.media_repo.get_remote_media_info(server_name, media_id)

thumbnail_infos = await self.store.get_remote_media_thumbnails(
@@ -239,8 +259,15 @@ class ThumbnailResource(DirectServeJsonResource):
raise SynapseError(400, "Failed to generate thumbnail.")

async def _respond_remote_thumbnail(
self, request, server_name, media_id, width, height, method, m_type
):
self,
request: Request,
server_name: str,
media_id: str,
width: int,
height: int,
method: str,
m_type: str,
) -> None:
# TODO: Don't download the whole remote file
# We should proxy the thumbnail from the remote server instead of
# downloading the remote file and generating our own thumbnails.
@@ -275,12 +302,12 @@ class ThumbnailResource(DirectServeJsonResource):

def _select_thumbnail(
self,
desired_width,
desired_height,
desired_method,
desired_type,
desired_width: int,
desired_height: int,
desired_method: str,
desired_type: str,
thumbnail_infos,
):
) -> dict:
d_w = desired_width
d_h = desired_height

@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,6 +15,7 @@
# limitations under the License.
import logging
from io import BytesIO
from typing import Tuple

from PIL import Image

@@ -39,7 +41,7 @@ class Thumbnailer:

FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"}

def __init__(self, input_path):
def __init__(self, input_path: str):
try:
self.image = Image.open(input_path)
except OSError as e:
@@ -59,11 +61,11 @@ class Thumbnailer:
# A lot of parsing errors can happen when parsing EXIF
logger.info("Error parsing image EXIF information: %s", e)

def transpose(self):
def transpose(self) -> Tuple[int, int]:
"""Transpose the image using its EXIF Orientation tag

Returns:
Tuple[int, int]: (width, height) containing the new image size in pixels.
A tuple containing the new image size in pixels as (width, height).
"""
if self.transpose_method is not None:
self.image = self.image.transpose(self.transpose_method)
@@ -73,7 +75,7 @@ class Thumbnailer:
self.image.info["exif"] = None
return self.image.size

def aspect(self, max_width, max_height):
def aspect(self, max_width: int, max_height: int) -> Tuple[int, int]:
"""Calculate the largest size that preserves aspect ratio which
fits within the given rectangle::

@@ -91,7 +93,7 @@ class Thumbnailer:
else:
return (max_height * self.width) // self.height, max_height

def _resize(self, width, height):
def _resize(self, width: int, height: int) -> Image:
# 1-bit or 8-bit color palette images need converting to RGB
# otherwise they will be scaled using nearest neighbour which
# looks awful
@@ -99,7 +101,7 @@ class Thumbnailer:
self.image = self.image.convert("RGB")
return self.image.resize((width, height), Image.ANTIALIAS)

def scale(self, width, height, output_type):
def scale(self, width: int, height: int, output_type: str) -> BytesIO:
"""Rescales the image to the given dimensions.

Returns:
@@ -108,7 +110,7 @@ class Thumbnailer:
scaled = self._resize(width, height)
return self._encode_image(scaled, output_type)

def crop(self, width, height, output_type):
def crop(self, width: int, height: int, output_type: str) -> BytesIO:
"""Rescales and crops the image to the given dimensions preserving
aspect::
(w_in / h_in) = (w_scaled / h_scaled)
@@ -136,7 +138,7 @@ class Thumbnailer:
cropped = scaled_image.crop((crop_left, 0, crop_right, height))
return self._encode_image(cropped, output_type)

def _encode_image(self, output_image, output_type):
def _encode_image(self, output_image: Image, output_type: str) -> BytesIO:
output_bytes_io = BytesIO()
fmt = self.FORMATS[output_type]
if fmt == "JPEG":

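A worked example of the `aspect` calculation above: fitting a 1000x500 image into a 300x300 box scales the binding dimension and derives the other from the aspect ratio. A sketch of the same integer arithmetic (the first branch is reconstructed to mirror the visible else branch):

def aspect(width: int, height: int, max_width: int, max_height: int):
    # Pick whichever dimension is the binding constraint, then derive
    # the other side from the aspect ratio with integer division.
    if max_width * height < max_height * width:
        return max_width, (max_width * height) // width
    else:
        return (max_height * width) // height, max_height

print(aspect(1000, 500, 300, 300))  # -> (300, 150)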
@@ -1,5 +1,6 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014-2016 OpenMarket Ltd
|
||||
# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -14,18 +15,25 @@
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from twisted.web.http import Request
|
||||
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.http.server import DirectServeJsonResource, respond_with_json
|
||||
from synapse.http.servlet import parse_string
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.app.homeserver import HomeServer
|
||||
from synapse.rest.media.v1.media_repository import MediaRepository
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UploadResource(DirectServeJsonResource):
|
||||
isLeaf = True
|
||||
|
||||
def __init__(self, hs, media_repo):
|
||||
def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
|
||||
super().__init__()
|
||||
|
||||
self.media_repo = media_repo
|
||||
@@ -37,10 +45,10 @@ class UploadResource(DirectServeJsonResource):
|
||||
self.max_upload_size = hs.config.max_upload_size
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
async def _async_render_OPTIONS(self, request):
|
||||
async def _async_render_OPTIONS(self, request: Request) -> None:
|
||||
respond_with_json(request, 200, {}, send_cors=True)
|
||||
|
||||
async def _async_render_POST(self, request):
|
||||
async def _async_render_POST(self, request: Request) -> None:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
# TODO: The checks here are a bit late. The content will have
|
||||
# already been uploaded to a tmp file at this point
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014-2016 OpenMarket Ltd
|
||||
# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -169,7 +170,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
||||
|
||||
async def get_local_media_before(
|
||||
self, before_ts: int, size_gt: int, keep_profiles: bool,
|
||||
) -> Optional[List[str]]:
|
||||
) -> List[str]:
|
||||
|
||||
# to find files that have never been accessed (last_access_ts IS NULL)
|
||||
# compare with `created_ts`
|
||||
|
||||
@@ -13,20 +13,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
import re
|
||||
from typing import Dict, Optional
|
||||
from urllib.parse import parse_qs, urlencode, urlparse
|
||||
from typing import Optional
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
from mock import ANY, Mock, patch
|
||||
|
||||
import pymacaroons
|
||||
|
||||
from twisted.web.resource import Resource
|
||||
|
||||
from synapse.api.errors import RedirectException
|
||||
from synapse.handlers.sso import MappingException
|
||||
from synapse.rest.client.v1 import login
|
||||
from synapse.rest.synapse.client.pick_username import pick_username_resource
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import UserID
|
||||
|
||||
@@ -848,6 +842,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
return self.handler._token_generator.generate_oidc_session_token(
|
||||
state=state,
|
||||
session_data=OidcSessionData(
|
||||
idp_id="oidc",
|
||||
nonce=nonce,
|
||||
client_redirect_url=client_redirect_url,
|
||||
ui_auth_session_id=ui_auth_session_id,
|
||||
@@ -855,116 +850,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
|
||||
)
|
||||
|
class UsernamePickerTestCase(HomeserverTestCase):
    if not HAS_OIDC:
        skip = "requires OIDC"

    servlets = [login.register_servlets]

    def default_config(self):
        config = super().default_config()
        config["public_baseurl"] = BASE_URL
        oidc_config = {
            "enabled": True,
            "client_id": CLIENT_ID,
            "client_secret": CLIENT_SECRET,
            "issuer": ISSUER,
            "scopes": SCOPES,
            "user_mapping_provider": {
                "config": {"display_name_template": "{{ user.displayname }}"}
            },
        }

        # Update this config with what's in the default config so that
        # override_config works as expected.
        oidc_config.update(config.get("oidc_config", {}))
        config["oidc_config"] = oidc_config

        # whitelist this client URI so we redirect straight to it rather than
        # serving a confirmation page
        config["sso"] = {"client_whitelist": ["https://whitelisted.client"]}
        return config

    def create_resource_dict(self) -> Dict[str, Resource]:
        d = super().create_resource_dict()
        d["/_synapse/client/pick_username"] = pick_username_resource(self.hs)
        return d

    def test_username_picker(self):
        """Test the happy path of a username picker flow."""
        client_redirect_url = "https://whitelisted.client"

        # first of all, mock up an OIDC callback to the OidcHandler, which should
        # raise a RedirectException
        userinfo = {"sub": "tester", "displayname": "Jonny"}
        f = self.get_failure(
            _make_callback_with_userinfo(
                self.hs, userinfo, client_redirect_url=client_redirect_url
            ),
            RedirectException,
        )

        # check the Location and cookies returned by the RedirectException
        self.assertEqual(f.value.location, b"/_synapse/client/pick_username")
        cookieheader = f.value.cookies[0]
        regex = re.compile(b"^username_mapping_session=([a-zA-Z]+);")
        m = regex.search(cookieheader)
        if not m:
            self.fail("cookie header %s does not match %s" % (cookieheader, regex))

        # introspect the sso handler a bit to check that the username mapping session
        # looks ok.
        session_id = m.group(1).decode("ascii")
        username_mapping_sessions = self.hs.get_sso_handler()._username_mapping_sessions
        self.assertIn(
            session_id, username_mapping_sessions, "session id not found in map"
        )
        session = username_mapping_sessions[session_id]
        self.assertEqual(session.remote_user_id, "tester")
        self.assertEqual(session.display_name, "Jonny")
        self.assertEqual(session.client_redirect_url, client_redirect_url)

        # the expiry time should be about 15 minutes away
        expected_expiry = self.clock.time_msec() + (15 * 60 * 1000)
        self.assertApproximates(session.expiry_time_ms, expected_expiry, tolerance=1000)

        # Now, submit a username to the username picker, which should serve a redirect
        # back to the client
        submit_path = f.value.location + b"/submit"
        content = urlencode({b"username": b"bobby"}).encode("utf8")
        chan = self.make_request(
            "POST",
            path=submit_path,
            content=content,
            content_is_form=True,
            custom_headers=[
                ("Cookie", cookieheader),
                # old versions of twisted don't do form-parsing without a valid
                # content-length header.
                ("Content-Length", str(len(content))),
            ],
        )
        self.assertEqual(chan.code, 302, chan.result)
        location_headers = chan.headers.getRawHeaders("Location")
        # ensure that the returned location starts with the requested redirect URL
        self.assertEqual(
            location_headers[0][: len(client_redirect_url)], client_redirect_url
        )

        # fish the login token out of the returned redirect uri
        parts = urlparse(location_headers[0])
        query = parse_qs(parts.query)
        login_token = query["loginToken"][0]

        # finally, submit the matrix login token to the login API, which gives us our
        # matrix access token, mxid, and device id.
        chan = self.make_request(
            "POST", "/login", content={"type": "m.login.token", "token": login_token},
        )
        self.assertEqual(chan.code, 200, chan.result)
        self.assertEqual(chan.json_body["user_id"], "@bobby:test")

async def _make_callback_with_userinfo(
    hs: HomeServer, userinfo: dict, client_redirect_url: str = "http://client/redirect"
) -> None:
@@ -990,7 +875,7 @@ async def _make_callback_with_userinfo(
    session = handler._token_generator.generate_oidc_session_token(
        state=state,
        session_data=OidcSessionData(
            nonce="nonce", client_redirect_url=client_redirect_url,
            idp_id="oidc", nonce="nonce", client_redirect_url=client_redirect_url,
        ),
    )
    request = _build_callback_request("code", state, session)

@@ -17,6 +17,7 @@ import time
import urllib.parse
from html.parser import HTMLParser
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from urllib.parse import parse_qs, urlencode, urlparse

from mock import Mock

@@ -30,13 +31,14 @@ from synapse.rest.client.v1 import login, logout
from synapse.rest.client.v2_alpha import devices, register
from synapse.rest.client.v2_alpha.account import WhoamiRestServlet
from synapse.rest.synapse.client.pick_idp import PickIdpResource
from synapse.rest.synapse.client.pick_username import pick_username_resource
from synapse.types import create_requester

from tests import unittest
from tests.handlers.test_oidc import HAS_OIDC
from tests.handlers.test_saml import has_saml2
from tests.rest.client.v1.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG
from tests.unittest import override_config, skip_unless
from tests.unittest import HomeserverTestCase, override_config, skip_unless

try:
    import jwt
@@ -1060,3 +1062,104 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
        channel = self.make_request(b"POST", LOGIN_URL, params)

        self.assertEquals(channel.result["code"], b"401", channel.result)


@skip_unless(HAS_OIDC, "requires OIDC")
class UsernamePickerTestCase(HomeserverTestCase):
    """Tests for the username picker flow of SSO login"""

    servlets = [login.register_servlets]

    def default_config(self):
        config = super().default_config()
        config["public_baseurl"] = BASE_URL

        config["oidc_config"] = {}
        config["oidc_config"].update(TEST_OIDC_CONFIG)
        config["oidc_config"]["user_mapping_provider"] = {
            "config": {"display_name_template": "{{ user.displayname }}"}
        }

        # whitelist this client URI so we redirect straight to it rather than
        # serving a confirmation page
        config["sso"] = {"client_whitelist": ["https://whitelisted.client"]}
        return config

    def create_resource_dict(self) -> Dict[str, Resource]:
        from synapse.rest.oidc import OIDCResource

        d = super().create_resource_dict()
        d["/_synapse/client/pick_username"] = pick_username_resource(self.hs)
        d["/_synapse/oidc"] = OIDCResource(self.hs)
        return d

    def test_username_picker(self):
        """Test the happy path of a username picker flow."""
        client_redirect_url = "https://whitelisted.client"

        # do the start of the login flow
        channel = self.helper.auth_via_oidc(
            {"sub": "tester", "displayname": "Jonny"}, client_redirect_url
        )

        # that should redirect to the username picker
        self.assertEqual(channel.code, 302, channel.result)
        picker_url = channel.headers.getRawHeaders("Location")[0]
        self.assertEqual(picker_url, "/_synapse/client/pick_username")

        # ... with a username_mapping_session cookie
        cookies = {}  # type: Dict[str, str]
        channel.extract_cookies(cookies)
        self.assertIn("username_mapping_session", cookies)
        session_id = cookies["username_mapping_session"]

        # introspect the sso handler a bit to check that the username mapping session
        # looks ok.
        username_mapping_sessions = self.hs.get_sso_handler()._username_mapping_sessions
        self.assertIn(
            session_id, username_mapping_sessions, "session id not found in map",
        )
        session = username_mapping_sessions[session_id]
        self.assertEqual(session.remote_user_id, "tester")
        self.assertEqual(session.display_name, "Jonny")
        self.assertEqual(session.client_redirect_url, client_redirect_url)

        # the expiry time should be about 15 minutes away
        expected_expiry = self.clock.time_msec() + (15 * 60 * 1000)
        self.assertApproximates(session.expiry_time_ms, expected_expiry, tolerance=1000)

        # Now, submit a username to the username picker, which should serve a redirect
        # back to the client
        submit_path = picker_url + "/submit"
        content = urlencode({b"username": b"bobby"}).encode("utf8")
        chan = self.make_request(
            "POST",
            path=submit_path,
            content=content,
            content_is_form=True,
            custom_headers=[
                ("Cookie", "username_mapping_session=" + session_id),
                # old versions of twisted don't do form-parsing without a valid
                # content-length header.
                ("Content-Length", str(len(content))),
            ],
        )
        self.assertEqual(chan.code, 302, chan.result)
        location_headers = chan.headers.getRawHeaders("Location")
        # ensure that the returned location starts with the requested redirect URL
        self.assertEqual(
            location_headers[0][: len(client_redirect_url)], client_redirect_url
        )

        # fish the login token out of the returned redirect uri
        parts = urlparse(location_headers[0])
        query = parse_qs(parts.query)
        login_token = query["loginToken"][0]

        # finally, submit the matrix login token to the login API, which gives us our
        # matrix access token, mxid, and device id.
        chan = self.make_request(
            "POST", "/login", content={"type": "m.login.token", "token": login_token},
        )
        self.assertEqual(chan.code, 200, chan.result)
        self.assertEqual(chan.json_body["user_id"], "@bobby:test")

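The rewritten test leans on FakeChannel.extract_cookies rather than hand-parsing the Set-Cookie header with a regex. A hedged sketch of what such a helper has to do for the assertions above to hold (the real implementation lives in the test framework and may differ):

    def extract_cookies(self, cookies: Dict[str, str]) -> None:
        """Parse the received Set-Cookie headers into a name -> value map."""
        for header in self.headers.getRawHeaders("Set-Cookie") or []:
            # keep only the name=value pair, dropping attributes like Path
            name, _, value = header.split(";")[0].partition("=")
            cookies[name] = value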
@@ -363,10 +363,10 @@ class RestHelper:
        the normal places.
        """
        client_redirect_url = "https://x"
        channel = self.auth_via_oidc(remote_user_id, client_redirect_url)
        channel = self.auth_via_oidc({"sub": remote_user_id}, client_redirect_url)

        # expect a confirmation page
        assert channel.code == 200
        assert channel.code == 200, channel.result

        # fish the matrix login token out of the body of the confirmation page
        m = re.search(
@@ -390,7 +390,7 @@ class RestHelper:

    def auth_via_oidc(
        self,
        remote_user_id: str,
        user_info_dict: JsonDict,
        client_redirect_url: Optional[str] = None,
        ui_auth_session_id: Optional[str] = None,
    ) -> FakeChannel:
@@ -411,7 +411,8 @@ class RestHelper:
        the normal places.

        Args:
            remote_user_id: the remote id that the OIDC provider should present
            user_info_dict: the remote userinfo that the OIDC provider should present.
                Typically this should be '{"sub": "<remote user id>"}'.
            client_redirect_url: for a login flow, the client redirect URL to pass to
                the login redirect endpoint
            ui_auth_session_id: if set, we will perform a UI Auth flow. The session id
@@ -457,7 +458,7 @@ class RestHelper:
            # a dummy OIDC access token
            ("https://issuer.test/token", {"access_token": "TEST"}),
            # and then one to the user_info endpoint, which returns our remote user id.
            ("https://issuer.test/userinfo", {"sub": remote_user_id}),
            ("https://issuer.test/userinfo", user_info_dict),
        ]

        async def mock_req(method: str, uri: str, data=None, headers=None):

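Taken together, callers now hand auth_via_oidc the whole userinfo dict rather than a bare remote user id, which is what lets the username-picker test inject a displayname. A usage sketch pieced together from the call sites in this diff:

    # plain login flow, with an extra userinfo field for the mapping provider
    channel = self.helper.auth_via_oidc(
        {"sub": "tester", "displayname": "Jonny"},
        client_redirect_url="https://whitelisted.client",
    )

    # UI-auth flow, identified by the session id of the pending request
    channel = self.helper.auth_via_oidc(
        {"sub": "tester"}, ui_auth_session_id=session_id
    )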
@@ -411,7 +411,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
        # run the UIA-via-SSO flow
        session_id = channel.json_body["session"]
        channel = self.helper.auth_via_oidc(
            remote_user_id=remote_user_id, ui_auth_session_id=session_id
            {"sub": remote_user_id}, ui_auth_session_id=session_id
        )

        # that should serve a confirmation page