1
0
This commit is contained in:
Brendan Abolivier
2021-12-06 16:21:14 +00:00
parent 783a4c9dfe
commit 22630d35dd
9 changed files with 43 additions and 42 deletions

View File

@@ -87,7 +87,7 @@ class AccountValidityHandler:
# Mark users as inactive when they expired. Check once every hour
if self._account_validity_enabled:
def mark_expired_users_as_inactive():
def mark_expired_users_as_inactive() -> Awaitable:
# run as a background process to allow async functions to work
return run_as_background_process(
"_mark_expired_users_as_inactive",
@@ -438,12 +438,9 @@ class AccountValidityHandler:
return expiration_ts
async def _mark_expired_users_as_inactive(self):
async def _mark_expired_users_as_inactive(self) -> None:
"""Iterate over active, expired users. Mark them as inactive in order to hide them
from the user directory.
Returns:
Deferred
"""
# Get active, expired users
active_expired_users = await self.store.get_expired_users()

View File

@@ -462,7 +462,7 @@ class IdentityHandler:
return session_id
def rewrite_id_server_url(self, url: str, add_https=False) -> str:
def rewrite_id_server_url(self, url: str, add_https: bool = False) -> str:
"""Given an identity server URL, optionally add a protocol scheme
before rewriting it according to the rewrite_identity_server_urls
config option
@@ -1109,7 +1109,7 @@ class IdentityHandler:
id_server_url: str,
email: str,
user_id: str,
):
) -> None:
"""Bind an email to a fully qualified user ID using the internal API of an
instance of Sydent.

View File

@@ -630,7 +630,7 @@ class RegistrationHandler:
"""
await self._auto_join_rooms(user_id)
async def appservice_register(self, user_localpart: str, as_token: str):
async def appservice_register(self, user_localpart: str, as_token: str) -> str:
# FIXME: this should be factored out and merged with normal register()
user = UserID(user_localpart, self.hs.hostname)
user_id = user.to_string()

View File

@@ -17,7 +17,7 @@ import logging
import random
import re
from http import HTTPStatus
from typing import TYPE_CHECKING, Optional, Tuple
from typing import TYPE_CHECKING, Any, Optional, Tuple
from urllib.parse import urlparse
from twisted.web.server import Request
@@ -294,7 +294,7 @@ class PasswordRestServlet(RestServlet):
return 200, {}
def on_OPTIONS(self, _):
def on_OPTIONS(self, _: Any) -> Tuple[int, JsonDict]:
return 200, {}
@@ -844,25 +844,27 @@ class ThreepidDeleteRestServlet(RestServlet):
class ThreepidLookupRestServlet(RestServlet):
PATTERNS = [re.compile("^/_matrix/client/unstable/account/3pid/lookup$")]
def __init__(self, hs):
def __init__(self, hs: "HomeServer") -> None:
super(ThreepidLookupRestServlet, self).__init__()
self.auth = hs.get_auth()
self.identity_handler = hs.get_identity_handler()
async def on_GET(self, request):
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
"""Proxy a /_matrix/identity/api/v1/lookup request to an identity
server
"""
await self.auth.get_user_by_req(request)
# Verify query parameters
query_params = request.args
# Mypy will complain that request.args is of an incompatible type with JsonDict
# because Twisted is badly typed, so we just ignore it.
query_params: JsonDict = request.args # type: ignore[assignment]
assert_params_in_dict(query_params, [b"medium", b"address", b"id_server"])
# Retrieve needed information from query parameters
medium = parse_string(request, "medium")
address = parse_string(request, "address")
id_server = parse_string(request, "id_server")
medium = parse_string(request, "medium", required=True)
address = parse_string(request, "address", required=True)
id_server = parse_string(request, "id_server", required=True)
# Proxy the request to the identity server. lookup_3pid handles checking
# if the lookup is allowed so we don't need to do it here.
@@ -874,12 +876,12 @@ class ThreepidLookupRestServlet(RestServlet):
class ThreepidBulkLookupRestServlet(RestServlet):
PATTERNS = [re.compile("^/_matrix/client/unstable/account/3pid/bulk_lookup$")]
def __init__(self, hs):
def __init__(self, hs: "HomeServer") -> None:
super(ThreepidBulkLookupRestServlet, self).__init__()
self.auth = hs.get_auth()
self.identity_handler = hs.get_identity_handler()
async def on_POST(self, request):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
"""Proxy a /_matrix/identity/api/v1/bulk_lookup request to an identity
server
"""

View File

@@ -79,7 +79,7 @@ class ProfileDisplaynameRestServlet(RestServlet):
return 200, {}
def on_OPTIONS(self, request, user_id):
def on_OPTIONS(self, request: SynapseRequest, user_id: str) -> Tuple[int, JsonDict]:
return 200, {}
@@ -134,7 +134,7 @@ class ProfileAvatarURLRestServlet(RestServlet):
return 200, {}
def on_OPTIONS(self, request, user_id):
def on_OPTIONS(self, request: SynapseRequest, user_id: str) -> Tuple[int, JsonDict]:
return 200, {}

View File

@@ -903,13 +903,13 @@ class RegisterRestServlet(RestServlet):
return 200, result
def cap(name):
def cap(name: str) -> str:
"""Capitalise parts of a name containing different words, including those
separated by hyphens.
For example, 'John-Doe'
Args:
name (str): The name to parse
The name to parse
"""
if not name:
return name
@@ -923,13 +923,13 @@ def cap(name):
return capatilized_name
def _map_email_to_displayname(address):
def _map_email_to_displayname(address: str) -> str:
"""Custom mapping from an email address to a user displayname
Args:
address (str): The email address to process
address: The email address to process
Returns:
str: The new displayname
The new displayname
"""
# Split the part before and after the @ in the email.
# Replace all . with spaces in the first part

View File

@@ -104,7 +104,7 @@ class SingleUserInfoServlet(RestServlet):
PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/info$")
def __init__(self, hs):
def __init__(self, hs: "HomeServer") -> None:
super(SingleUserInfoServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
@@ -115,7 +115,9 @@ class SingleUserInfoServlet(RestServlet):
if not registry.query_handlers.get("user_info"):
registry.register_query_handler("user_info", self._on_federation_query)
async def on_GET(self, request, user_id):
async def on_GET(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
# Ensure the user is authenticated
await self.auth.get_user_by_req(request)
@@ -131,14 +133,14 @@ class SingleUserInfoServlet(RestServlet):
user_id_to_info = await self.store.get_info_for_users([user_id])
return 200, user_id_to_info[user_id]
async def _on_federation_query(self, args):
async def _on_federation_query(self, args: JsonDict) -> JsonDict:
"""Called when a request for user information appears over federation
Args:
args (dict): Dictionary of query arguments provided by the request
args: Dictionary of query arguments provided by the request
Returns:
Deferred[dict]: Deactivation and expiration information for a given user
Deactivation and expiration information for a given user
"""
user_id = args.get("user_id")
if not user_id:
@@ -162,14 +164,14 @@ class UserInfoServlet(RestServlet):
PATTERNS = client_patterns("/users/info$", unstable=True, releases=())
def __init__(self, hs):
def __init__(self, hs: "HomeServer") -> None:
super(UserInfoServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.transport_layer = hs.get_federation_transport_client()
async def on_POST(self, request):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
# Ensure the user is authenticated
await self.auth.get_user_by_req(request)

View File

@@ -16,7 +16,7 @@
import logging
import random
import re
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union
import attr
@@ -453,7 +453,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
async def get_info_for_users(
self,
user_ids: List[str],
user_ids: Iterable[str],
):
"""Return the user info for a given set of users

View File

@@ -32,7 +32,7 @@ from synapse.rest.client import (
)
from synapse.server import HomeServer
from synapse.storage.roommember import ProfileInfo
from synapse.types import create_requester
from synapse.types import JsonDict, create_requester
from synapse.util import Clock
from tests import unittest
@@ -910,7 +910,7 @@ class UserInfoTestCase(unittest.FederatingHomeserverTestCase):
account.register_servlets,
]
def default_config(self):
def default_config(self) -> JsonDict:
config = super().default_config()
# Set accounts to expire after a week
@@ -920,12 +920,12 @@ class UserInfoTestCase(unittest.FederatingHomeserverTestCase):
}
return config
def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
super(UserInfoTestCase, self).prepare(reactor, clock, hs)
self.store = hs.get_datastore()
self.handler = hs.get_user_directory_handler()
def test_user_info(self):
def test_user_info(self) -> None:
"""Test /users/info for local users from the Client-Server API"""
user_one, user_two, user_three, user_three_token = self.setup_test_users()
@@ -954,7 +954,7 @@ class UserInfoTestCase(unittest.FederatingHomeserverTestCase):
self.assertFalse(user_three_info["deactivated"])
self.assertFalse(user_three_info["expired"])
def test_user_info_federation(self):
def test_user_info_federation(self) -> None:
"""Test that /users/info can be called from the Federation API, and
and that we can query remote users from the Client-Server API
"""
@@ -983,7 +983,7 @@ class UserInfoTestCase(unittest.FederatingHomeserverTestCase):
self.assertFalse(user_three_info["deactivated"])
self.assertFalse(user_three_info["expired"])
def setup_test_users(self):
def setup_test_users(self) -> Tuple[str, str, str, str]:
"""Create an admin user and three test users, each with a different state"""
# Create an admin user to expire other users with
@@ -1007,7 +1007,7 @@ class UserInfoTestCase(unittest.FederatingHomeserverTestCase):
return user_one, user_two, user_three, user_three_token
def expire(self, user_id_to_expire, admin_tok):
def expire(self, user_id_to_expire: str, admin_tok: str) -> None:
url = "/_synapse/admin/v1/account_validity/validity"
request_data = {
"user_id": user_id_to_expire,
@@ -1017,7 +1017,7 @@ class UserInfoTestCase(unittest.FederatingHomeserverTestCase):
channel = self.make_request("POST", url, request_data, access_token=admin_tok)
self.assertEquals(channel.result["code"], b"200", channel.result)
def deactivate(self, user_id, tok):
def deactivate(self, user_id: str, tok: str) -> None:
request_data = {
"auth": {"type": "m.login.password", "user": user_id, "password": "pass"},
"erase": False,